Notebook 6: Ray tracing and interactive data visualization on the GPU
Here we illustrate how one can build custom widgets for interactive data visualization. We use the Paicos CUDA GPU implementation of ray tracing to achieve the necessary speed.
Required Python packages
This notebook requires that you have the GPU requirements installed and available on your system and that you have modified your Paicos user settings to load GPU functionality on startup. Please see the details here: https://paicos.readthedocs.io/en/latest/installation.html#gpu-cuda-requirements
This notebook also requires a working installation of ipywidgets, which can sometimes be cumbersome to set up (it works for me in a classic Jupyter notebook, but you may run into trouble if you are using JupyterLab). You can simply try
pip install ipywidgets
and if that does not work, follow the detailed instructions found here: https://ipywidgets.readthedocs.io/en/stable/user_install.html
You can check that ipywidgets is working by uncommenting and executing the following example:
[3]:
# from ipywidgets import interact
# @interact
# def greet(name="World", count=5):
#     for _ in range(count):
#         print(f"Hello, {name}!")
# greet()
[1]:
%matplotlib widget
[2]:
import cupy as cp
from numba import cuda
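The import cell above raises an ImportError if CuPy or Numba is missing. The following is a minimal sketch of a more forgiving check. It uses only `importlib.util.find_spec` and Numba's `numba.cuda.is_available()`. The helper name `gpu_stack_status` is hypothetical, and it does not check the Paicos user settings mentioned above.

```python
import importlib.util


def gpu_stack_status():
    """Report which GPU packages are importable and whether Numba sees a CUDA device.

    Hypothetical helper, not part of Paicos. It never raises: missing
    packages or a missing driver are simply reported as False.
    """
    status = {
        # find_spec locates a top-level package without importing it
        "cupy_installed": importlib.util.find_spec("cupy") is not None,
        "numba_installed": importlib.util.find_spec("numba") is not None,
        "cuda_device": False,
    }
    if status["numba_installed"]:
        try:
            from numba import cuda
            # is_available() returns False (instead of raising) when no
            # usable CUDA driver or device is found
            status["cuda_device"] = bool(cuda.is_available())
        except Exception:
            status["cuda_device"] = False
    return status


print(gpu_stack_status())
```

If `cuda_device` is False, the GPU projector further down will not be usable on this machine.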
Do a manual pre-selection
We start by loading a snapshot and selecting only part of it. This is necessary because the GPU has limited memory and the GPU ray tracer class builds a binary tree spanning the entire snapshot it is given.
[3]:
import paicos as pa
import numpy as np
pa.use_units(True)
snapnum = 247
try:
    snap = pa.Snapshot(pa.data_dir + 'highres', snapnum)
except FileNotFoundError as e:
    print(e)
    err_msg = ('This example is much more fun with a large data set.\nPlease see: '
               + 'https://github.com/tberlok/paicos/tree/main/data/highres/README.md'
               + ' for download instructions.\n'
               + 'For now we simply load the low resolution data set.')
    print(err_msg)
    snap = pa.Snapshot(pa.data_dir, snapnum)
center = snap.Cat.Group['GroupPos'][0]
R200c = snap.Cat.Group['Group_R_Crit200'][0]
widths = np.array([10000, 10000, 10000]) * R200c.uq
# Create subset of snapshot
index = pa.util.get_index_of_radial_range(snap['0_Coordinates'], center, 0, np.max(widths)*np.cbrt(3))
snap = snap.select(index, parttype=0)
# Pixel dimensions of image
nx = ny = 1024
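The pre-selection above keeps every cell within a sphere of radius `np.max(widths)*np.cbrt(3)` around the center. Rotations preserve distances from the center. A sphere therefore contains the image cube for every orientation as long as its radius is at least the half-diagonal, \(\sqrt{3}\,w/2\) for a cube of side \(w\). The NumPy sketch below checks this numerically. `select_within_radius` is a hypothetical helper without periodic wrapping. It is not the implementation of `pa.util.get_index_of_radial_range`.

```python
import numpy as np


def select_within_radius(pos, center, radius):
    """Boolean mask of points within `radius` of `center` (no periodic wrapping).

    Hypothetical stand-in for a radial pre-selection, not the Paicos function.
    """
    return np.linalg.norm(pos - center, axis=1) <= radius


def rotation_about_z(degrees):
    t = np.radians(degrees)
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


cube_width = 1.0
# The 8 corners of a cube of side `cube_width` centred on the origin
signs = np.array([[sx, sy, sz] for sx in (-1, 1) for sy in (-1, 1) for sz in (-1, 1)])
corners = 0.5 * cube_width * signs

# Rotations preserve lengths, so the rotated corners stay at the half-diagonal
half_diagonal = np.sqrt(3) / 2 * cube_width
rotated = corners @ rotation_about_z(30).T
print(np.allclose(np.linalg.norm(rotated, axis=1), half_diagonal))  # True

# The radius used in the notebook, max(widths) * cbrt(3), exceeds this minimum
preselect_radius = cube_width * np.cbrt(3)
print(preselect_radius > half_diagonal)                                    # True
print(select_within_radius(rotated, np.zeros(3), preselect_radius).all())  # True
```

Any rotation would do here. One about \(z\) keeps the example short.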
Initialize the GPU projector
Here we use a Paicos orientation class to initialize the view such that the width of the image is along the \(x\)-coordinate of the simulation and the height of the image is along the \(y\)-coordinate. The depth of the image is in the \(z\)-direction.
The orientation class has methods for rotating the view around \(x\), \(y\), and \(z\) or around the axes of its local coordinate system. When an orientation instance has been passed to an ImageCreator (such as the projector below), calling these methods results in a rotation around the center of the image.
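One way to picture such an orientation is as a rotation matrix. Its columns are the image's horizontal, vertical and depth directions, expressed in simulation coordinates. The sketch below is a simplified, hypothetical stand-in for `pa.Orientation` and is not the Paicos implementation. It builds the matrix from `normal_vector` and `perp_vector1`, assumes `perp_vector2 = np.cross(normal_vector, perp_vector1)`, and rotates about a local axis using Rodrigues' formula. The handedness and the sign convention for positive angles are assumptions. Multiplying by the inverse (transposed) matrix maps offsets from the image center into image coordinates. This is the role `inverse_rotation_matrix` plays when the widget further down overlays halo positions.

```python
import numpy as np


def rodrigues(axis, degrees):
    """Rotation matrix for a right-handed rotation by `degrees` about `axis`."""
    axis = np.asarray(axis, dtype=float)
    axis = axis / np.linalg.norm(axis)
    t = np.radians(degrees)
    # Cross-product matrix K, so that K @ v == np.cross(axis, v)
    K = np.array([[0.0, -axis[2], axis[1]],
                  [axis[2], 0.0, -axis[0]],
                  [-axis[1], axis[0], 0.0]])
    return np.eye(3) + np.sin(t) * K + (1.0 - np.cos(t)) * (K @ K)


class SketchOrientation:
    """Hypothetical, simplified orientation (not the Paicos Orientation class).

    Columns of `rotation_matrix` are perp_vector1 (image width),
    perp_vector2 (image height) and normal_vector (depth) in simulation
    coordinates. Assumes the two input vectors are perpendicular.
    """

    def __init__(self, normal_vector, perp_vector1):
        n = np.asarray(normal_vector, dtype=float)
        p1 = np.asarray(perp_vector1, dtype=float)
        n, p1 = n / np.linalg.norm(n), p1 / np.linalg.norm(p1)
        p2 = np.cross(n, p1)  # assumed handedness, gives p1 x p2 == n
        self.rotation_matrix = np.column_stack([p1, p2, n])

    @property
    def inverse_rotation_matrix(self):
        # The matrix is orthonormal, so its inverse is its transpose
        return self.rotation_matrix.T

    def rotate_around_local_axis(self, column, degrees):
        # column 0, 1, 2 -> perp_vector1, perp_vector2, normal_vector
        axis = self.rotation_matrix[:, column]
        self.rotation_matrix = rodrigues(axis, degrees) @ self.rotation_matrix


orientation_sketch = SketchOrientation(normal_vector=[0, 0, 1], perp_vector1=[1, 0, 0])
orientation_sketch.rotate_around_local_axis(2, 90)  # spin the image about its depth axis

# Map a simulation-frame offset from the image center into
# (horizontal, vertical, depth) image coordinates
offset = np.array([[0.0, 1.0, 0.0]])
print(np.matmul(orientation_sketch.inverse_rotation_matrix, offset.T).T)
```

With these assumptions, after the 90 degree spin, a point offset along \(+y\) ends up on the image's horizontal axis.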
[4]:
orientation = pa.Orientation(normal_vector=[0, 0, 1], perp_vector1=[1, 0, 0])
projector = pa.GpuRayProjector(snap, center, widths, orientation, npix=nx, threadsperblock=8, do_pre_selection=False)
Attempting to get derived variable: 0_Volume... [DONE]
[5]:
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import ipywidgets as widgets
Some code for sorting the FoF and Subfind catalogues
This code is only used to plot the 20 most massive FoF groups or subhalos inside the projection cube. It mainly supports the widget, so there is no need to read it.
[6]:
def get_group_and_sub_indices():
    info = {}
    if hasattr(projector.snap.Cat, 'Sub'):
        sub_in_region_bool = pa.util.get_index_of_rotated_cubic_region_plus_thin_layer(
            projector.snap.Cat.Sub['SubhaloPos'],
            projector.center, projector.widths,
            projector.snap.Cat.Sub['SubhaloHalfmassRad'],
            projector.snap.box, projector.orientation)
        info['Subhalo_ids'] = np.arange(sub_in_region_bool.shape[0])[sub_in_region_bool]
        info['SubhaloPos'] = projector.snap.Cat.Sub['SubhaloPos'][sub_in_region_bool]
        info['SubhaloHalfmassRad'] = projector.snap.Cat.Sub['SubhaloHalfmassRad'][sub_in_region_bool]
        info['SubhaloMass'] = projector.snap.Cat.Sub['SubhaloMass'][sub_in_region_bool]
        # Sort according to mass
        sort_index = np.argsort(info['SubhaloMass'])[::-1]
        info['Subhalo_ids'] = info['Subhalo_ids'][sort_index]
        info['SubhaloPos'] = info['SubhaloPos'][sort_index]
        info['SubhaloHalfmassRad'] = info['SubhaloHalfmassRad'][sort_index]
        info['SubhaloMass'] = info['SubhaloMass'][sort_index]
    if hasattr(projector.snap.Cat, 'Group'):
        group_in_region_bool = pa.util.get_index_of_rotated_cubic_region_plus_thin_layer(
            projector.snap.Cat.Group['GroupPos'],
            projector.center, projector.widths,
            projector.snap.Cat.Group['Group_R_Crit200'],
            projector.snap.box, projector.orientation)
        info['Group_ids'] = np.arange(group_in_region_bool.shape[0])[group_in_region_bool]
        info['GroupPos'] = projector.snap.Cat.Group['GroupPos'][group_in_region_bool]
        info['Group_R_Crit200'] = projector.snap.Cat.Group['Group_R_Crit200'][group_in_region_bool]
        info['Group_M_Crit200'] = projector.snap.Cat.Group['Group_M_Crit200'][group_in_region_bool]
        # Sort according to mass
        sort_index = np.argsort(info['Group_M_Crit200'])[::-1]
        info['Group_ids'] = info['Group_ids'][sort_index]
        info['GroupPos'] = info['GroupPos'][sort_index]
        info['Group_R_Crit200'] = info['Group_R_Crit200'][sort_index]
        info['Group_M_Crit200'] = info['Group_M_Crit200'][sort_index]
    return info
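The catalogue filtering above relies on `pa.util.get_index_of_rotated_cubic_region_plus_thin_layer`, whose implementation is not reproduced here. The sketch below is a hypothetical NumPy analogue built on one assumption: an object is kept when its offset from the center, wrapped periodically and rotated into the image frame, lies within half the image width along each axis plus the object's own radius. It then orders the selected objects by mass and keeps at most 20 with non-zero mass. This gives the same result as the widget's plotting loop, which stops with `break` at the 20th entry or at the first massless one.

```python
import numpy as np


def in_rotated_cube_plus_layer(pos, center, widths, radii, box, rotation_matrix):
    """Hypothetical analogue of a 'rotated cubic region plus thin layer' test.

    `rotation_matrix` has the image axes as columns (simulation coordinates).
    """
    d = pos - center
    d = (d + 0.5 * box) % box - 0.5 * box   # periodic minimum-image offsets
    local = d @ rotation_matrix              # row-wise R^T d: image-frame coordinates
    half = 0.5 * np.asarray(widths)
    return np.all(np.abs(local) <= half + radii[:, None], axis=1)


def top_n_by_mass(ids, masses, n=20):
    """Ids of the `n` most massive objects, skipping objects with zero mass."""
    order = np.argsort(masses)[::-1]         # most massive first
    order = order[masses[order] > 0][:n]
    return ids[order]


demo_box = 100.0
demo_center = np.array([1.0, 50.0, 50.0])
demo_widths = np.array([8.0, 8.0, 8.0])
demo_pos = np.array([[99.0, 50.0, 50.0],   # 2 units away across the periodic boundary
                     [7.0, 50.0, 50.0],    # 6 units away: beyond half-width + radius
                     [5.0, 50.0, 50.0]])   # exactly on the face of the cube
demo_radii = np.array([0.0, 0.5, 0.0])
demo_masses = np.array([1.0, 0.0, 3.0])
demo_ids = np.array([10, 11, 12])

demo_mask = in_rotated_cube_plus_layer(demo_pos, demo_center, demo_widths,
                                       demo_radii, demo_box, np.eye(3))
print(demo_mask)                                              # [ True False  True]
print(top_n_by_mass(demo_ids[demo_mask], demo_masses[demo_mask]))  # [12 10]
```

The `demo_` prefixes avoid overwriting `center` and `widths`, which the widget below still uses.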
The interactive widget
The rather long code below defines an interactive IPython widget with a number of (hopefully) self-explanatory buttons.
These can zoom in/out, rotate the image, change the width, height and depth of the projection, etc.
Ticking the ‘Recording’ checkbox saves a png and/or hdf5 file every time the view changes (if the corresponding boxes are ticked). The files are saved in the directory entered in the text box just to the right of the png checkbox. The recording also writes a .log file with a series of commands that can be used to reproduce an interactive session. This makes it possible to use an interactive session as the starting point for creating an animation of a simulation snapshot.
We have left all the code for the widget visible instead of saving it somewhere else and importing it. We hope that it will in this way be easier for someone to extend/modify the code to their own needs.
The data included with Paicos has very low resolution, so this notebook does not really showcase how well the GPU code performs at higher resolution. We have tried a simulation with \(12^3\) times better mass resolution (equivalent to the resolution in the TNG300 simulation) and find that an A100 GPU is fast enough to give a smooth user experience. Download instructions for this data set can be found at https://github.com/tberlok/paicos/tree/main/data/highres/README.md.
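The .log file written while recording has one record per line. The records follow the formats produced by the `mylogger(...)` calls in the widget code below, for example `zoom,2.0` or `rotate_around_perp_vector2,15.0`. Some lines join two records with `,#,`, and recording starts by writing a blank line. The helpers `parse_log_line` and `replay` and the handler table below are hypothetical and not part of Paicos. They sketch one way to read such a file and dispatch each record to a function of your choice. What each record should do on replay is left to the handlers.

```python
def parse_log_line(line):
    """Split one line of the widget's .log file into a list of (name, args).

    Hypothetical helper. Records joined by ',#,' become separate tuples.
    Numeric arguments are converted to float; anything else stays a string.
    """
    records = []
    for record in line.strip().split(',#,'):
        if not record:
            continue  # blank line, e.g. the one written when recording starts
        name, *fields = record.split(',')
        args = []
        for field in fields:
            try:
                args.append(float(field))
            except ValueError:
                args.append(field)
        records.append((name, args))
    return records


def replay(lines, handlers):
    """Call handlers[name](*args) for every known record; return unknown names."""
    unknown = []
    for line in lines:
        for name, args in parse_log_line(line):
            if name in handlers:
                handlers[name](*args)
            else:
                unknown.append(name)
    return unknown


sample_log = """
zoom,2.0
rotate_around_perp_vector2,15.0
move_center_sim_coordinates,1.0,0.0,-2.5,#,center_on_group,3,10.0,20.0,30.0
double_resolution
"""

calls = []
handlers = {
    'zoom': lambda factor: calls.append(('zoom', factor)),
    'rotate_around_perp_vector2': lambda deg: calls.append(('rotate_p2', deg)),
    'double_resolution': lambda: calls.append(('double_resolution',)),
}
unknown = replay(sample_log.splitlines(), handlers)
print(calls)    # [('zoom', 2.0), ('rotate_p2', 15.0), ('double_resolution',)]
print(unknown)  # ['move_center_sim_coordinates', 'center_on_group']
```

In a real replay, the handlers would call the corresponding projector or orientation methods. Returning the unknown names makes it easy to spot records that are not handled yet.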
[7]:
def update():
    proj = projector.project_variable(var_str.value)
    extent = projector.centered_extent
    if to_physical.value:
        proj = proj.to_physical
        extent = extent.to_physical
    if to_cgs.value:
        proj = proj.cgs
        extent = extent.cgs
    if to_astro_units.value:
        proj = proj.astro
        extent = extent.astro
    fig = plt.figure(1)
    plt.clf()
    # Deal with color limits and do plot
    if fix_climits.value:
        vmin.disabled = False
        vmax.disabled = False
        # Fall back to the data range unless positive limits have been entered
        if not (vmin.value > 0 and vmax.value > 0):
            vmin.value = proj.value.min()
            vmax.value = proj.value.max()
        im = plt.imshow(proj.value, extent=extent.value, origin='lower',
                        norm=LogNorm(vmin.value, vmax.value), cmap=cmap_str.value)
    else:
        vmin.value = proj.value.min()
        vmax.value = proj.value.max()
        vmin.disabled = True
        vmax.disabled = True
        im = plt.imshow(proj.value, extent=extent.value, origin='lower',
                        norm=LogNorm(), cmap=cmap_str.value)
    # Labels
    plt.xlabel(extent.label())
    plt.ylabel(extent.label())
    # Colorbar (raw strings avoid invalid escape sequences in the LaTeX label)
    cb = plt.colorbar()
    cb.set_label(proj.label(r'\mathrm{' + var_str.value.replace('_', r'\_') + r'}\,'))
    # Title
    title_str = f'Snapnum: {snapnum}, Age: {snap.age:1.2f}, Redshift: {snap.z:1.2f}'
    plt.title(title_str)
    # Add subs/groups
    # TODO: Get rid of mostly duplicate code for groups/subhalos
    select_center.disabled = True
    if hasattr(projector.snap, 'Cat'):
        info = get_group_and_sub_indices()
        if 'Group_ids' in info and show_groups.value:
            select_center.disabled = False
            orientation = projector.orientation
            points = info['GroupPos'] - projector.center
            points = np.matmul(orientation.inverse_rotation_matrix, points.T).T
            if to_physical.value:
                points = points.to_physical
                info['Group_R_Crit200'] = info['Group_R_Crit200'].to_physical
            if to_cgs.value:
                points = points.cgs
                info['Group_R_Crit200'] = info['Group_R_Crit200'].cgs
            if to_astro_units.value:
                points = points.astro
                info['Group_R_Crit200'] = info['Group_R_Crit200'].astro
            ax = plt.gca()
            options = []
            for ii in range(points.shape[0]):
                if ii >= 20 or info['Group_M_Crit200'][ii].value == 0:
                    break
                circ = plt.Circle((points[ii, 0].value, points[ii, 1].value),
                                  info['Group_R_Crit200'][ii].value, color='k', fill=False)
                ax.add_patch(circ)
                plt.text(points[ii, 0].value, points[ii, 1].value, f"G{info['Group_ids'][ii]}", fontsize=6)
                options.append(f"G{info['Group_ids'][ii]}")
            select_center.options = list(options)
        if 'Subhalo_ids' in info and show_subs.value:
            select_center.disabled = False
            orientation = projector.orientation
            points = info['SubhaloPos'] - projector.center
            points = np.matmul(orientation.inverse_rotation_matrix, points.T).T
            if to_physical.value:
                points = points.to_physical
                info['SubhaloHalfmassRad'] = info['SubhaloHalfmassRad'].to_physical
            if to_cgs.value:
                points = points.cgs
                info['SubhaloHalfmassRad'] = info['SubhaloHalfmassRad'].cgs
            if to_astro_units.value:
                points = points.astro
                info['SubhaloHalfmassRad'] = info['SubhaloHalfmassRad'].astro
            ax = plt.gca()
            options = []
            for ii in range(points.shape[0]):
                if ii >= 20 or info['SubhaloMass'][ii].value == 0:
                    break
                circ = plt.Circle((points[ii, 0].value, points[ii, 1].value),
                                  info['SubhaloHalfmassRad'][ii].value, color='k', fill=False)
                ax.add_patch(circ)
                plt.text(points[ii, 0].value, points[ii, 1].value, f"S{info['Subhalo_ids'][ii]}", fontsize=6)
                options.append(f"S{info['Subhalo_ids'][ii]}")
            select_center.options = list(options)
    if recording.value:
        if hdf5.value:
            image_file = pa.ArepoImage(projector, basedir=outfolder.value,
                                       basename=f'{basename.value}_{var_str.value}_frame_{frame_counter.value}')
            image_file.save_image(var_str.value, proj)
            # Move from temporary filename to final filename
            image_file.finalize()
            if frame_counter.value == 0:
                mylogger(image_file.filename)
                mylogger(outfolder.value)
                mylogger(basename.value)
                mylogger(var_str.value)
                org_center = projector.center - projector._diff_center
                mylogger(f"org_center,{org_center[0].value},{org_center[1].value},{org_center[2].value}")
        if png.value:
            plt.savefig(f'{outfolder.value}/{basename.value}_{var_str.value}_frame_{frame_counter.value}_{projector.snap.snapnum}.png', dpi=700)
        frame_counter.value += 1
    plt.show()
def mylogger(string, line_ending='\n', mode='a'):
    with open(f'{outfolder.value}/{basename.value}.log', mode) as f:
        f.write(string + line_ending)
width = widgets.FloatSlider(value=projector.width.value, min=1, max=2*widths[0].value, step=1, description="width", continuous_update=False)
height = widgets.FloatSlider(value=projector.height.value, min=1, max=2*widths[1].value, step=1, description="height", continuous_update=False)
depth = widgets.FloatSlider(value=projector.depth.value, min=1, max=2*widths[2].value, step=1, description="depth", continuous_update=False)
zoom_slider = widgets.FloatSlider(value=1, min=0.1, max=10, step=0.1, description="Zoom factor")
center_horizontal = widgets.BoundedFloatText(
value=0,
min=-widths[0].value,
max=widths[0].value,
step=1,
description='Horizontal',
disabled=False
)
center_vertical = widgets.BoundedFloatText(
value=0,
min=-widths[0].value,
max=widths[0].value,
step=1,
description='Vertical',
disabled=False
)
center_depth = widgets.BoundedFloatText(
value=0,
min=-widths[0].value,
max=widths[0].value,
step=1,
description='Depth',
disabled=False
)
def click_step_button(b):
    if center_horizontal.value != 0:
        projector.move_center_along_perp_vector1(center_horizontal.value * projector.width.uq)
    if center_vertical.value != 0:
        projector.move_center_along_perp_vector2(center_vertical.value * projector.width.uq)
    if center_depth.value != 0:
        projector.move_center_along_normal_vector(center_depth.value * projector.width.uq)
    if recording.value:
        mylogger(f'move_center,{center_horizontal.value},{center_vertical.value},{center_depth.value}')
    update()
def press_reset_center_button(b):
    diff = projector._diff_center
    org_center = projector.center - diff
    if recording.value:
        mylogger(f"move_center_sim_coordinates,{diff[0].value},{diff[1].value},{diff[2].value}", line_ending=',#,')
        mylogger(f"reset_center_to_org_center,{org_center[0].value},{org_center[1].value},{org_center[2].value}")
    projector.center = org_center
    projector._diff_center[:] = 0
    update()
step_button = widgets.Button(description="Move center")
recenter_button = widgets.Button(description="Reset center")
step_button.on_click(click_step_button)
recenter_button.on_click(press_reset_center_button)
## Create dropdown for parttype 0 only (TODO: Remove vectors and tensors from list, or add box for selecting component)
avail_list = []
for key in snap._auto_list:
    if key[0] == '0':
        avail_list.append(key)
var_str = widgets.Dropdown(options=avail_list,
value='0_Density',
description='Field:',
)
select_center = widgets.Dropdown(options=[],
value=None,
description='Center on:',
disabled=True
)
# Cmap dropdown
cmap_str = widgets.Dropdown(options=plt.colormaps(),
value='viridis',
description='Cmap:',
)
# vmin and vmax
vmin = widgets.FloatText(
description='vmin:',
disabled=True
)
vmax = widgets.FloatText(
description='vmax:',
disabled=True
)
fix_climits = widgets.Checkbox(
value=False,
description='Fix climits',
disabled=False,
indent=False
)
#
button_left = widgets.Button(description="Pan left")
button_right = widgets.Button(description="Pan right")
button_up = widgets.Button(description="Pan up")
button_down = widgets.Button(description="Pan down")
button_clock_wise = widgets.Button(description="Clockwise")
button_anti_clock_wise = widgets.Button(description="Anti-clockwise")
step_size_in_degrees = widgets.FloatSlider(value=15, min=0, max=90, step=0.5, description="Step (Degrees)")
button_update = widgets.Button(description="Update")
def call_update(b):
    update()
button_update.on_click(call_update)
def call_double_resolution(b):
    projector.double_resolution
    if recording.value:
        mylogger('double_resolution')
    update()
def call_half_resolution(b):
    projector.half_resolution
    if recording.value:
        mylogger('half_resolution')
    update()
button_double = widgets.Button(description="Double res")
button_double.on_click(call_double_resolution)
button_half = widgets.Button(description="Half res")
button_half.on_click(call_half_resolution)
def call_zoom(b):
    zoom_button_was_pressed.value = True
    projector.zoom(zoom_slider.value)
    if recording.value:
        mylogger(f'zoom,{zoom_slider.value}')
    # Check that new widths are not completely unreasonable!
    # Do check
    width.value = projector.width.value
    height.value = projector.height.value
    update()
    zoom_button_was_pressed.value = False
button_zoom = widgets.Button(description="Zoom in/out")
button_zoom.on_click(call_zoom)
zoom_button_was_pressed = widgets.Checkbox(
value=False, description='internal boolean for avoiding calling call_update twice when zooming',
)
to_physical = widgets.Checkbox(
value=False,
description='physical units',
disabled=False,
indent=False
)
to_cgs = widgets.Checkbox(
value=False,
description='cgs units',
disabled=False,
indent=False
)
to_astro_units = widgets.Checkbox(
value=False,
description="'astro' units",
disabled=False,
indent=False
)
to_physical.observe(call_update, names=['value'])
to_cgs.observe(call_update, names=['value'])
to_astro_units.observe(call_update, names=['value'])
def change_width(change):
    projector.width = width.value * projector.width.uq
    if not zoom_button_was_pressed.value:
        if recording.value:
            mylogger(f'width,{width.value}')
        update()
def change_height(change):
    projector.height = height.value * projector.height.uq
    if not zoom_button_was_pressed.value:
        if recording.value:
            mylogger(f'height,{height.value}')
        update()
def change_depth(change):
    projector.depth = depth.value * projector.depth.uq
    if recording.value:
        mylogger(f'depth,{depth.value}')
    update()
def change_var_str(change):
    if recording.value:
        mylogger(var_str.value)
    update()
def pan_left(b):
    projector.orientation.rotate_around_perp_vector2(degrees=step_size_in_degrees.value)
    if recording.value:
        mylogger(f'rotate_around_perp_vector2,{step_size_in_degrees.value}')
    update()
def pan_right(b):
    projector.orientation.rotate_around_perp_vector2(degrees=-step_size_in_degrees.value)
    if recording.value:
        mylogger(f'rotate_around_perp_vector2,{-step_size_in_degrees.value}')
    update()
def pan_up(b):
    projector.orientation.rotate_around_perp_vector1(degrees=step_size_in_degrees.value)
    if recording.value:
        mylogger(f'rotate_around_perp_vector1,{step_size_in_degrees.value}')
    update()
def pan_down(b):
    projector.orientation.rotate_around_perp_vector1(degrees=-step_size_in_degrees.value)
    if recording.value:
        mylogger(f'rotate_around_perp_vector1,{-step_size_in_degrees.value}')
    update()
def clock_wise(b):
    projector.orientation.rotate_around_normal_vector(degrees=step_size_in_degrees.value)
    if recording.value:
        mylogger(f'rotate_around_normal_vector,{step_size_in_degrees.value}')
    update()
def anti_clock_wise(b):
    projector.orientation.rotate_around_normal_vector(degrees=-step_size_in_degrees.value)
    if recording.value:
        mylogger(f'rotate_around_normal_vector,{-step_size_in_degrees.value}')
    update()
button_left.on_click(pan_left)
button_right.on_click(pan_right)
button_up.on_click(pan_up)
button_down.on_click(pan_down)
button_clock_wise.on_click(clock_wise)
button_anti_clock_wise.on_click(anti_clock_wise)
# Subhalo, groups
show_groups = widgets.Checkbox(
value=False,
description='Show FoF groups',
disabled=False,
indent=False
)
show_subs = widgets.Checkbox(
value=False,
description='Show subhalos',
disabled=False,
indent=False
)
show_groups.observe(call_update, names=['value'])
show_subs.observe(call_update, names=['value'])
# Save hdf5/save
frame_counter = widgets.IntSlider(value=0)
recording = widgets.Checkbox(
value=False,
description='Recording',
disabled=False,
indent=False
)
def reset_counter(b):
    """
    Reset counter if recording is stopped,
    calculate first image if recording is started
    """
    if not recording.value:
        frame_counter.value = 0
        basename.disabled = False
        outfolder.disabled = False
    else:
        basename.disabled = True
        outfolder.disabled = True
        mylogger('', mode='w')
        update()
recording.observe(reset_counter, names=['value'])
hdf5 = widgets.Checkbox(
value=False,
description='hdf5',
disabled=False,
indent=False
)
png = widgets.Checkbox(
value=False,
description='png',
disabled=False,
indent=False
)
out = widgets.interactive_output(update, {},)
def change_center_using_cat(change):
    # Clearing the options below sets the value to None, which calls this
    # observer again; there is nothing to do in that case
    if select_center.value is None:
        return
    info = get_group_and_sub_indices()
    if select_center.value[0] == 'G':
        gr_id = int(select_center.value[1:])
        new_center = projector.snap.Cat.Group['GroupPos'][gr_id].T
        if recording.value:
            diff = new_center - projector._center
            mylogger(f"move_center_sim_coordinates,{diff[0].value},{diff[1].value},{diff[2].value}", line_ending=',#,')
            mylogger(f"center_on_group,{gr_id},{new_center[0].value},{new_center[1].value},{new_center[2].value}")
        projector._diff_center += new_center - projector._center
        projector.center = new_center.copy
        show_groups.value = False
        select_center.options = []
        select_center.value = None
    elif select_center.value[0] == 'S':
        sub_id = int(select_center.value[1:])
        new_center = projector.snap.Cat.Sub['SubhaloPos'][sub_id].T
        if recording.value:
            diff = new_center - projector._center
            mylogger(f"move_center_sim_coordinates,{diff[0].value},{diff[1].value},{diff[2].value}", line_ending=',#,')
            mylogger(f"center_on_sub,{sub_id},{new_center[0].value},{new_center[1].value},{new_center[2].value}")
        projector._diff_center += new_center - projector._center
        projector.center = new_center.copy
        show_subs.value = False
        select_center.options = []
        select_center.value = None
    # update()
width.observe(change_width, names=['value'])
height.observe(change_height, names=['value'])
depth.observe(change_depth, names=['value'])
var_str.observe(change_var_str, names=['value'])
select_center.observe(change_center_using_cat, names=['value'])
basename = widgets.Text(value='image')
outfolder = widgets.Text(value='./')
display(out,
widgets.HBox([button_update, var_str, button_zoom, zoom_slider]),
widgets.HBox([button_left, button_right, button_up, button_down, button_clock_wise, button_anti_clock_wise]),
widgets.HBox([width, height, depth]),
widgets.HBox([step_button, center_horizontal, center_vertical, center_depth, recenter_button]),
widgets.HBox([fix_climits, vmin, vmax]),
widgets.HBox([step_size_in_degrees, to_physical, to_cgs, to_astro_units]),
widgets.HBox([cmap_str, button_double, button_half]),
widgets.HBox([recording, hdf5, png, outfolder, basename]),
widgets.HBox([show_groups, show_subs, select_center]))
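In the widget code above, the `zoom_button_was_pressed` checkbox acts as a guard flag. `call_zoom` sets the width and height sliders. Without the guard, each slider's observer would trigger a full re-projection before the final `update()`. The pure-Python sketch below isolates that pattern so that it runs without Jupyter or a GPU. `ObservableValue` is a hypothetical stand-in for an ipywidgets slider, and the `demo_` names avoid overwriting the widget variables above.

```python
class ObservableValue:
    """Tiny stand-in for a widget value with observers (hypothetical, not ipywidgets)."""

    def __init__(self, value):
        self._value = value
        self._observers = []

    def observe(self, callback):
        self._observers.append(callback)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new):
        if new != self._value:  # like traitlets, only notify on an actual change
            self._value = new
            for callback in self._observers:
                callback(new)


renders = []                       # one entry per (expensive) redraw
zoom_in_progress = {'flag': False}
demo_width = ObservableValue(100.0)
demo_height = ObservableValue(100.0)


def on_size_change(_new):
    # Mirrors change_width/change_height: skip the redraw while zooming
    if not zoom_in_progress['flag']:
        renders.append((demo_width.value, demo_height.value))


demo_width.observe(on_size_change)
demo_height.observe(on_size_change)


def demo_zoom(factor):
    zoom_in_progress['flag'] = True
    try:
        demo_width.value = demo_width.value / factor    # observer fires, no redraw
        demo_height.value = demo_height.value / factor  # observer fires, no redraw
    finally:
        zoom_in_progress['flag'] = False
    renders.append((demo_width.value, demo_height.value))  # single redraw


demo_zoom(2.0)
print(renders)  # [(50.0, 50.0)]
```

Wrapping the guarded section in `try`/`finally` is one way to make sure the flag is cleared even if a redraw raises. In `call_zoom`, the flag is reset only after `update()` returns.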