Notebook 6: Ray tracing and interactive data visualization on the GPU

Here we illustrate how one can build custom widgets for interactive data visualization. We use the Paicos CUDA GPU implementation of ray tracing to achieve the necessary speed.

Required Python packages

This notebook requires that you have the GPU requirements installed and available on your system and that you have modified your Paicos user settings to load GPU functionality on startup. Please see the details here: https://paicos.readthedocs.io/en/latest/installation.html#gpu-cuda-requirements

This notebook also requires a working installation of ipywidgets, which can be cumbersome to set up (it works for me in a classic Jupyter Notebook, but you may run into trouble if you are using JupyterLab). You can first simply try

pip install ipywidgets

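and if that does not work, follow the detailed instructions found here: https://ipywidgets.readthedocs.io/en/stable/user_install.html

The %matplotlib widget backend used below is provided by the ipympl package, which you may also need to install (pip install ipympl).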

You can check that ipywidgets is working by uncommenting and executing the following example:

[3]:
# from ipywidgets import interact

# @interact
# def greet(name="World", count=5):
#     for _ in range(count):
#         print(f"Hello, {name}!")

# greet()
[1]:
%matplotlib widget
[2]:
import cupy as cp
from numba import cuda

Do a manual pre-selection

We start by loading a snapshot and selecting only part of it. This restriction is needed because GPU memory is limited and the GPU ray tracer class builds a binary tree spanning the entire snapshot it is given.

[3]:
import paicos as pa
import numpy as np

pa.use_units(True)

snapnum = 247
try:
    snap = pa.Snapshot(pa.data_dir + 'highres', snapnum)
except FileNotFoundError as e:
    print(e)
    err_msg = ('This example is much more fun with a large data set.\nPlease see: '
               + 'https://github.com/tberlok/paicos/tree/main/data/highres/README.md'
               + ' for download instructions.\n'
               + 'For now we simply load the low-resolution data set.')
    print(err_msg)
    snap = pa.Snapshot(pa.data_dir, snapnum)
center = snap.Cat.Group['GroupPos'][0]
R200c = snap.Cat.Group['Group_R_Crit200'][0]
widths = np.array([10000, 10000, 10000]) * R200c.uq

# Create subset of snapshot
index = pa.util.get_index_of_radial_range(snap['0_Coordinates'], center, 0, np.max(widths)*np.cbrt(3))
snap = snap.select(index, parttype=0)

# Pixel dimensions of image
nx = ny = 1024
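
The pre-selection above keeps every gas cell within a sphere of radius np.max(widths)*np.cbrt(3) around the center. A cube of side w fits inside a sphere of radius \(\sqrt{3}\,w/2\), which is its half-diagonal. Any larger radius therefore selects a superset of the projection cube. The NumPy sketch below illustrates this idea with a hypothetical helper, radial_mask. It is not the implementation of pa.util.get_index_of_radial_range, and for simplicity it ignores periodic boundary conditions.

```python
import numpy as np


def radial_mask(pos, center, r_min, r_max):
    """Boolean mask of points with r_min <= |pos - center| < r_max.

    Hypothetical helper for illustration; ignores periodic boundaries.
    """
    r = np.linalg.norm(pos - center, axis=1)
    return (r >= r_min) & (r < r_max)


# A cube of side w centred on `center` fits inside a sphere whose radius is
# its half-diagonal, sqrt(3)/2 * w. The notebook uses cbrt(3) * w, which is larger.
w = 10.0
half_diagonal = np.sqrt(3) / 2 * w
assert np.cbrt(3) * w > half_diagonal

rng = np.random.default_rng(0)
center = np.array([50.0, 50.0, 50.0])
pos = rng.uniform(0.0, 100.0, size=(10_000, 3))

mask = radial_mask(pos, center, 0.0, np.cbrt(3) * w)
in_cube = np.all(np.abs(pos - center) <= w / 2, axis=1)

# Every point inside the cube is also inside the selection sphere
print(mask.sum(), in_cube.sum())
```

The sphere keeps more cells than strictly necessary, but testing a radius is cheaper than testing a rotated cube. It also keeps the selection valid if the view is later rotated about the center.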

Initialize the GPU projector

Here we use a Paicos orientation class to initialize the view such that the width of the image is along the \(x\)-coordinate of the simulation and the height of the image is along the \(y\)-coordinate. The depth of the image is in the \(z\)-direction.

The orientation class has methods for rotating the view around \(x\), \(y\), and \(z\), or around the axes of its local coordinate system. Once an orientation instance has been passed to an ImageCreator (such as the projector below), calling these methods rotates the view around the center of the image.
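
The sketch below shows what such a rotation does. It assumes the orientation is an orthonormal triad: perp_vector1 along the image width, perp_vector2 along the image height and normal_vector along the depth. It also assumes the convention perp_vector2 = normal_vector × perp_vector1, which reproduces the setup used below (width along \(x\), height along \(y\)). The class Frame and the function rotation_matrix are hypothetical stand-ins built on Rodrigues' rotation formula; they are not the Paicos Orientation API.

```python
import numpy as np


def rotation_matrix(axis, degrees):
    """Right-handed rotation about `axis` by `degrees` (Rodrigues' formula)."""
    axis = np.asarray(axis, dtype=float)
    axis = axis / np.linalg.norm(axis)
    theta = np.radians(degrees)
    # Cross-product matrix: K @ v == np.cross(axis, v)
    K = np.array([[0.0, -axis[2], axis[1]],
                  [axis[2], 0.0, -axis[0]],
                  [-axis[1], axis[0], 0.0]])
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)


class Frame:
    """Hypothetical stand-in for an orientation: an orthonormal image frame."""

    def __init__(self, normal_vector, perp_vector1):
        self.normal_vector = np.asarray(normal_vector, dtype=float)
        self.perp_vector1 = np.asarray(perp_vector1, dtype=float)
        # Assumed convention: perp_vector2 completes a right-handed triad
        self.perp_vector2 = np.cross(self.normal_vector, self.perp_vector1)

    def rotate_around(self, axis, degrees):
        # Rotate all three frame vectors by the same rotation
        R = rotation_matrix(axis, degrees)
        self.normal_vector = R @ self.normal_vector
        self.perp_vector1 = R @ self.perp_vector1
        self.perp_vector2 = R @ self.perp_vector2

    def rotate_around_normal_vector(self, degrees):
        self.rotate_around(self.normal_vector, degrees)

    @property
    def matrix(self):
        # Rows: image width, image height and depth directions
        return np.array([self.perp_vector1, self.perp_vector2, self.normal_vector])


frame = Frame(normal_vector=[0, 0, 1], perp_vector1=[1, 0, 0])
frame.rotate_around_normal_vector(90)

# Image coordinates (horizontal, vertical, depth) of a point relative to the image center
center = np.zeros(3)
point = np.array([0.0, 2.0, 0.0])
image_coords = frame.matrix @ (point - center)
print(image_coords)
```

Image coordinates are measured relative to the image center, so rotating the triad in this way rotates the view around that center.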

[4]:
orientation = pa.Orientation(normal_vector=[0, 0, 1], perp_vector1=[1, 0, 0])
projector = pa.GpuRayProjector(snap, center, widths, orientation, npix=nx, threadsperblock=8, do_pre_selection=False)
Attempting to get derived variable: 0_Volume... [DONE]

[5]:
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import ipywidgets as widgets

Some code for sorting FoF and subfind catalogues

This code is only used to plot the 20 most massive FoF groups or subhalos inside the projection cube. It mainly serves the widget, so feel free to skip it.

[6]:
def get_group_and_sub_indices():
    info = {}
    if hasattr(projector.snap.Cat, 'Sub'):
        sub_in_region_bool = pa.util.get_index_of_rotated_cubic_region_plus_thin_layer(projector.snap.Cat.Sub['SubhaloPos'],
                                                        projector.center, projector.widths,
                                                        projector.snap.Cat.Sub['SubhaloHalfmassRad'],
                                                        projector.snap.box, projector.orientation)

        info['Subhalo_ids'] = np.arange(sub_in_region_bool.shape[0])[sub_in_region_bool]
        info['SubhaloPos'] = projector.snap.Cat.Sub['SubhaloPos'][sub_in_region_bool]
        info['SubhaloHalfmassRad'] = projector.snap.Cat.Sub['SubhaloHalfmassRad'][sub_in_region_bool]
        info['SubhaloMass'] = projector.snap.Cat.Sub['SubhaloMass'][sub_in_region_bool]
        # Sort according to mass
        sort_index = np.argsort(info['SubhaloMass'])[::-1]
        info['Subhalo_ids'] = info['Subhalo_ids'][sort_index]
        info['SubhaloPos'] = info['SubhaloPos'][sort_index]
        info['SubhaloHalfmassRad'] = info['SubhaloHalfmassRad'][sort_index]
        info['SubhaloMass'] = info['SubhaloMass'][sort_index]

    if hasattr(projector.snap.Cat, 'Group'):
        group_in_region_bool = pa.util.get_index_of_rotated_cubic_region_plus_thin_layer(projector.snap.Cat.Group['GroupPos'],
                                                projector.center, projector.widths,
                                                projector.snap.Cat.Group['Group_R_Crit200'],
                                                projector.snap.box, projector.orientation)
        info['Group_ids'] = np.arange(group_in_region_bool.shape[0])[group_in_region_bool]
        info['GroupPos'] = projector.snap.Cat.Group['GroupPos'][group_in_region_bool]
        info['Group_R_Crit200'] = projector.snap.Cat.Group['Group_R_Crit200'][group_in_region_bool]
        info['Group_M_Crit200'] = projector.snap.Cat.Group['Group_M_Crit200'][group_in_region_bool]
        # Sort according to mass
        sort_index = np.argsort(info['Group_M_Crit200'])[::-1]
        info['Group_ids'] = info['Group_ids'][sort_index]
        info['GroupPos'] = info['GroupPos'][sort_index]
        info['Group_R_Crit200'] = info['Group_R_Crit200'][sort_index]
        info['Group_M_Crit200'] = info['Group_M_Crit200'][sort_index]
    return info
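
get_group_and_sub_indices sorts every catalogue field with the same index array, ordered by descending mass. The plotting loop in update() then keeps at most 20 entries and stops at the first zero mass. The sketch below condenses this into a hypothetical helper, top_n_by_mass. It works on a dict of plain NumPy arrays rather than on Paicos quantities.

```python
import numpy as np


def top_n_by_mass(catalogue, mass_key, n=20):
    """Sort all arrays in `catalogue` by descending mass and keep the n most massive.

    Hypothetical helper. `catalogue` is a dict of equally long arrays. Entries
    with zero mass are dropped, mirroring the `break` on zero mass in update().
    """
    # One shared index array keeps the fields aligned with each other
    order = np.argsort(catalogue[mass_key])[::-1]
    sorted_cat = {key: np.asarray(val)[order] for key, val in catalogue.items()}
    # After the descending sort, zero masses sit at the end, so this trims the tail
    keep = sorted_cat[mass_key] > 0
    return {key: val[keep][:n] for key, val in sorted_cat.items()}


cat = {
    'Group_ids': np.array([0, 1, 2, 3]),
    'Group_M_Crit200': np.array([5.0, 0.0, 9.0, 1.0]),
}
top = top_n_by_mass(cat, 'Group_M_Crit200', n=2)
print(top['Group_ids'])  # [2 0]
```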

The interactive widget

The rather long code below defines an interactive IPython widget with a number of (hopefully) self-explanatory buttons.

These let you zoom in and out, rotate the view, change the width, height and depth of the image, and so on.
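
Several controls in the widget change each other's values. The zoom button, for example, updates the width and height sliders, and each slider has an observer that re-renders the image. To avoid rendering several times per zoom, the widget ticks the internal zoom_button_was_pressed checkbox while it adjusts the sliders. The plain-Python sketch below reproduces this guard pattern. Observable is a hypothetical, stripped-down stand-in for an ipywidgets trait. The zoom function here simply divides the sizes by the zoom factor, which is an assumption and not necessarily what projector.zoom does.

```python
class Observable:
    """Hypothetical, minimal stand-in for a widget trait that notifies observers."""

    def __init__(self, value):
        self._value = value
        self._observers = []

    def observe(self, callback):
        self._observers.append(callback)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new):
        old, self._value = self._value, new
        # Like traitlets, only notify when the value actually changes
        if new != old:
            for callback in self._observers:
                callback(old, new)


redraws = []
width = Observable(100.0)
height = Observable(100.0)
zooming = False  # plays the role of zoom_button_was_pressed


def redraw():
    redraws.append((width.value, height.value))


def on_size_change(old, new):
    # Slider moved by the user: redraw. Moved by the zoom handler: skip.
    if not zooming:
        redraw()


width.observe(on_size_change)
height.observe(on_size_change)


def zoom(factor):
    global zooming
    zooming = True
    width.value = width.value / factor    # assumed zoom semantics
    height.value = height.value / factor
    redraw()  # a single redraw for the whole zoom
    zooming = False


zoom(2.0)
width.value = 80.0  # a manual slider change still triggers a redraw
print(redraws)  # [(50.0, 50.0), (80.0, 50.0)]
```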

Ticking the 'Recording' checkbox saves a PNG and/or HDF5 file every time the view changes, depending on which of the png and hdf5 boxes are ticked. The files are saved in the directory entered in the text box just to the right of the png checkbox, and their names start with the base name entered in the box after that. The recording also writes a .log file with a series of commands that can be used to reproduce an interactive session. This makes it possible to use an interactive session as the starting point for an animation of a simulation snapshot.
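
The .log file is plain text written by the mylogger function in the widget cell. Each line holds one command, followed by any numeric arguments separated by commas (e.g. zoom,2.0 or rotate_around_perp_vector2,15.0). When the view is re-centered, two descriptions are written on a single line, joined by ',#,'. The hypothetical parser below splits such a log into (command, arguments) tuples as a first step towards replaying a session. It assumes that everything after ',#,' is an informational comment that a replay can ignore. The sample log is made up for illustration.

```python
def parse_session_log(text):
    """Split a recorded session log into (command, args) tuples.

    Hypothetical helper. Assumes the format written by `mylogger`: one
    comma-separated command per line, arguments after the command name.
    """
    commands = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue  # e.g. the empty line written when recording starts
        # Assumption: the part after ',#,' is a comment, so keep only the first part
        command = line.split(',#,')[0]
        name, *raw_args = command.split(',')
        args = []
        for arg in raw_args:
            try:
                args.append(float(arg))
            except ValueError:
                args.append(arg)  # keep non-numeric arguments as strings
        commands.append((name, args))
    return commands


log = """
image_file.hdf5
./
image
0_Density
org_center,100.0,200.0,300.0
zoom,2.0
rotate_around_perp_vector2,15.0
move_center_sim_coordinates,1.0,-2.0,0.5,#,center_on_group,3,101.0,198.0,300.5
double_resolution
"""

for name, args in parse_session_log(log):
    print(name, args)
```

A replay script could then map each command name onto the corresponding projector or orientation call used in the widget cell.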

We have left all the widget code visible instead of moving it to a separate module and importing it, in the hope that this makes it easier to extend or modify the code for your own needs.

The data included with Paicos has very low resolution, so this notebook does not really showcase how well the GPU code performs at higher resolution. We have tried a simulation with \(12^3\) times better mass resolution (equivalent to the resolution of the TNG300 simulation) and found that an A100 GPU is fast enough to give a smooth user experience. Download instructions for this data set can be found here: https://github.com/tberlok/paicos/tree/main/data/highres/README.md
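
This gives a rough idea of what the higher mass resolution means spatially. Assume cells of roughly equal mass \(m\) (Arepo refines and derefines cells towards a target mass). The typical cell size at fixed gas density \(\rho\) then scales as \((m/\rho)^{1/3}\), so

\[
\frac{\Delta x_\mathrm{low}}{\Delta x_\mathrm{high}} = \left(\frac{m_\mathrm{low}}{m_\mathrm{high}}\right)^{1/3} = \left(12^3\right)^{1/3} = 12,
\]

while the number of cells in a fixed volume grows by a factor of \(12^3 = 1728\). That growth is what makes both the GPU memory budget and the speed of the ray tracer matter.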

[7]:

from IPython.display import display


def update():
    proj = projector.project_variable(var_str.value)
    extent = projector.centered_extent

    if to_physical.value:
        proj = proj.to_physical
        extent = extent.to_physical
    if to_cgs.value:
        proj = proj.cgs
        extent = extent.cgs
    if to_astro_units.value:
        proj = proj.astro
        extent = extent.astro

    fig = plt.figure(1)
    plt.clf()

    # Deal with color limits and do plot
    if fix_climits.value:
        vmin.disabled = False
        vmax.disabled = False
        # Only initialize the limits if they have not been set to positive values yet
        if not (vmin.value > 0 and vmax.value > 0):
            vmin.value = proj.value.min()
            vmax.value = proj.value.max()
        im = plt.imshow(proj.value, extent=extent.value, origin='lower',
                        norm=LogNorm(vmin.value, vmax.value), cmap=cmap_str.value)
    else:
        vmin.value = proj.value.min()
        vmax.value = proj.value.max()
        vmin.disabled = True
        vmax.disabled = True
        im = plt.imshow(proj.value, extent=extent.value, origin='lower',
                        norm=LogNorm(), cmap=cmap_str.value)

    # Labels
    plt.xlabel(extent.label())
    plt.ylabel(extent.label())

    # Colorbar (raw strings avoid invalid escape sequences such as '\_')
    cb = plt.colorbar()
    cb.set_label(proj.label(r'\mathrm{' + var_str.value.replace('_', r'\_') + r'}\,'))

    # Title
    title_str = f'Snapnum: {snapnum}, Age: {snap.age:1.2f}, Redshift: {snap.z:1.2f}'
    plt.title(title_str)

    # Add subs/groups
    # TODO: Get rid of mostly duplicate code for groups/subhalos
    select_center.disabled = True
    if hasattr(projector.snap, 'Cat'):
        info = get_group_and_sub_indices()

        if 'Group_ids' in info and show_groups.value:
            select_center.disabled = False
            orientation = projector.orientation
            # Positions relative to the image center, rotated into the image frame
            points = info['GroupPos'] - projector.center
            points = np.matmul(orientation.inverse_rotation_matrix, points.T).T
            if to_physical.value:
                points = points.to_physical
                info['Group_R_Crit200'] = info['Group_R_Crit200'].to_physical
            if to_cgs.value:
                points = points.cgs
                info['Group_R_Crit200'] = info['Group_R_Crit200'].cgs
            if to_astro_units.value:
                points = points.astro
                info['Group_R_Crit200'] = info['Group_R_Crit200'].astro
            ax = plt.gca()
            options = []
            # Mark (at most) the 20 most massive groups with a circle of radius R200c
            for ii in range(points.shape[0]):
                if ii >= 20 or info['Group_M_Crit200'][ii].value == 0:
                    break
                circ = plt.Circle((points[ii, 0].value, points[ii, 1].value),
                                  info['Group_R_Crit200'][ii].value, color='k', fill=False)
                ax.add_patch(circ)
                plt.text(points[ii, 0].value, points[ii, 1].value,
                         f"G{info['Group_ids'][ii]}", fontsize=6)
                options.append(f"G{info['Group_ids'][ii]}")
            select_center.options = list(options)

        if 'Subhalo_ids' in info and show_subs.value:
            select_center.disabled = False
            orientation = projector.orientation
            # Positions relative to the image center, rotated into the image frame
            points = info['SubhaloPos'] - projector.center
            points = np.matmul(orientation.inverse_rotation_matrix, points.T).T
            if to_physical.value:
                points = points.to_physical
                info['SubhaloHalfmassRad'] = info['SubhaloHalfmassRad'].to_physical
            if to_cgs.value:
                points = points.cgs
                info['SubhaloHalfmassRad'] = info['SubhaloHalfmassRad'].cgs
            if to_astro_units.value:
                points = points.astro
                info['SubhaloHalfmassRad'] = info['SubhaloHalfmassRad'].astro
            ax = plt.gca()
            options = []
            # Mark (at most) the 20 most massive subhalos with a circle of radius equal to the half-mass radius
            for ii in range(points.shape[0]):
                if ii >= 20 or info['SubhaloMass'][ii].value == 0:
                    break
                circ = plt.Circle((points[ii, 0].value, points[ii, 1].value),
                                  info['SubhaloHalfmassRad'][ii].value, color='k', fill=False)
                ax.add_patch(circ)
                plt.text(points[ii, 0].value, points[ii, 1].value,
                         f"S{info['Subhalo_ids'][ii]}", fontsize=6)
                options.append(f"S{info['Subhalo_ids'][ii]}")
            select_center.options = list(options)

    # Save frames while recording
    if recording.value:
        if hdf5.value:
            image_file = pa.ArepoImage(projector, basedir=outfolder.value,
                                       basename=f'{basename.value}_{var_str.value}_frame_{frame_counter.value}')
            image_file.save_image(var_str.value, proj)
            # Move from temporary filename to final filename
            image_file.finalize()
            # image_file only exists when saving hdf5, so the log header is written here
            if frame_counter.value == 0:
                mylogger(image_file.filename)
                mylogger(outfolder.value)
                mylogger(basename.value)
                mylogger(var_str.value)
                org_center = projector.center - projector._diff_center
                mylogger(f"org_center,{org_center[0].value},{org_center[1].value},{org_center[2].value}")
        if png.value:
            plt.savefig(f'{outfolder.value}/{basename.value}_{var_str.value}_frame_{frame_counter.value}'
                        f'_{projector.snap.snapnum}.png', dpi=700)
        frame_counter.value += 1

    plt.show()


def mylogger(string, line_ending='\n', mode='a'):
    with open(f'{outfolder.value}/{basename.value}.log', mode) as f:
        f.write(string + line_ending)


# Sliders for the image dimensions
width = widgets.FloatSlider(value=projector.width.value, min=1, max=2*widths[0].value, step=1,
                            description="width", continuous_update=False)
height = widgets.FloatSlider(value=projector.height.value, min=1, max=2*widths[1].value, step=1,
                             description="height", continuous_update=False)
depth = widgets.FloatSlider(value=projector.depth.value, min=1, max=2*widths[2].value, step=1,
                            description="depth", continuous_update=False)
zoom_slider = widgets.FloatSlider(value=1, min=0.1, max=10, step=0.1, description="Zoom factor")

# Offsets for moving the center (in units of the image width)
center_horizontal = widgets.BoundedFloatText(
    value=0, min=-widths[0].value, max=widths[0].value, step=1,
    description='Horizontal', disabled=False
)
center_vertical = widgets.BoundedFloatText(
    value=0, min=-widths[0].value, max=widths[0].value, step=1,
    description='Vertical', disabled=False
)
center_depth = widgets.BoundedFloatText(
    value=0, min=-widths[0].value, max=widths[0].value, step=1,
    description='Depth', disabled=False
)


def click_step_button(b):
    if center_horizontal.value != 0:
        projector.move_center_along_perp_vector1(center_horizontal.value * projector.width.uq)
    if center_vertical.value != 0:
        projector.move_center_along_perp_vector2(center_vertical.value * projector.width.uq)
    if center_depth.value != 0:
        projector.move_center_along_normal_vector(center_depth.value * projector.width.uq)
    if recording.value:
        mylogger(f'move_center,{center_horizontal.value},{center_vertical.value},{center_depth.value}')
    update()


def press_reset_center_button(b):
    diff = projector._diff_center
    org_center = projector.center - diff
    if recording.value:
        mylogger(f"move_center_sim_coordinates,{diff[0].value},{diff[1].value},{diff[2].value}",
                 line_ending=',#,')
        mylogger(f"reset_center_to_org_center,{org_center[0].value},{org_center[1].value},{org_center[2].value}")
    projector.center = org_center
    projector._diff_center[:] = 0
    update()


step_button = widgets.Button(description="Move center")
recenter_button = widgets.Button(description="Reset center")
step_button.on_click(click_step_button)
recenter_button.on_click(press_reset_center_button)

# Create dropdown for parttype 0 only
# (TODO: Remove vectors and tensors from list, or add box for selecting component)
avail_list = []
for key in snap._auto_list:
    if key[0] == '0':
        avail_list.append(key)

var_str = widgets.Dropdown(options=avail_list, value='0_Density', description='Field:')

select_center = widgets.Dropdown(options=[], value=None, description='Center on:', disabled=True)

# Cmap dropdown
cmap_str = widgets.Dropdown(options=plt.colormaps(), value='viridis', description='Cmap:')

# vmin and vmax
vmin = widgets.FloatText(description='vmin:', disabled=True)
vmax = widgets.FloatText(description='vmax:', disabled=True)

fix_climits = widgets.Checkbox(value=False, description='Fix climits', disabled=False, indent=False)

# Buttons for rotating the view
button_left = widgets.Button(description="Pan left")
button_right = widgets.Button(description="Pan right")
button_up = widgets.Button(description="Pan up")
button_down = widgets.Button(description="Pan down")
button_clock_wise = widgets.Button(description="Clockwise")
button_anti_clock_wise = widgets.Button(description="Anti-clockwise")
step_size_in_degrees = widgets.FloatSlider(value=15, min=0, max=90, step=0.5, description="Step (Degrees)")

button_update = widgets.Button(description="Update")


def call_update(b):
    update()


button_update.on_click(call_update)


def call_double_resolution(b):
    projector.double_resolution
    if recording.value:
        mylogger('double_resolution')
    update()


def call_half_resolution(b):
    projector.half_resolution
    if recording.value:
        mylogger('half_resolution')
    update()


button_double = widgets.Button(description="Double res")
button_double.on_click(call_double_resolution)
button_half = widgets.Button(description="Half res")
button_half.on_click(call_half_resolution)


def call_zoom(b):
    zoom_button_was_pressed.value = True
    projector.zoom(zoom_slider.value)
    if recording.value:
        mylogger(f'zoom,{zoom_slider.value}')
    # TODO: Check that the new widths are not completely unreasonable
    width.value = projector.width.value
    height.value = projector.height.value
    update()
    zoom_button_was_pressed.value = False


button_zoom = widgets.Button(description="Zoom in/out")
button_zoom.on_click(call_zoom)

zoom_button_was_pressed = widgets.Checkbox(
    value=False,
    description='internal boolean for avoiding calling call_update twice when zooming',
)

to_physical = widgets.Checkbox(value=False, description='physical units', disabled=False, indent=False)
to_cgs = widgets.Checkbox(value=False, description='cgs units', disabled=False, indent=False)
to_astro_units = widgets.Checkbox(value=False, description="'astro' units", disabled=False, indent=False)

to_physical.observe(call_update, names=['value'])
to_cgs.observe(call_update, names=['value'])
to_astro_units.observe(call_update, names=['value'])


def change_width(change):
    projector.width = width.value * projector.width.uq
    if not zoom_button_was_pressed.value:
        if recording.value:
            mylogger(f'width,{width.value}')
        update()


def change_height(change):
    projector.height = height.value * projector.height.uq
    if not zoom_button_was_pressed.value:
        if recording.value:
            mylogger(f'height,{height.value}')
        update()


def change_depth(change):
    projector.depth = depth.value * projector.depth.uq
    if recording.value:
        mylogger(f'depth,{depth.value}')
    update()


def change_var_str(change):
    if recording.value:
        mylogger(var_str.value)
    update()


def pan_left(b):
    projector.orientation.rotate_around_perp_vector2(degrees=step_size_in_degrees.value)
    if recording.value:
        mylogger(f'rotate_around_perp_vector2,{step_size_in_degrees.value}')
    update()


def pan_right(b):
    projector.orientation.rotate_around_perp_vector2(degrees=-step_size_in_degrees.value)
    if recording.value:
        mylogger(f'rotate_around_perp_vector2,{-step_size_in_degrees.value}')
    update()


def pan_up(b):
    projector.orientation.rotate_around_perp_vector1(degrees=step_size_in_degrees.value)
    if recording.value:
        mylogger(f'rotate_around_perp_vector1,{step_size_in_degrees.value}')
    update()


def pan_down(b):
    projector.orientation.rotate_around_perp_vector1(degrees=-step_size_in_degrees.value)
    if recording.value:
        mylogger(f'rotate_around_perp_vector1,{-step_size_in_degrees.value}')
    update()


def clock_wise(b):
    projector.orientation.rotate_around_normal_vector(degrees=step_size_in_degrees.value)
    if recording.value:
        mylogger(f'rotate_around_normal_vector,{step_size_in_degrees.value}')
    update()


def anti_clock_wise(b):
    projector.orientation.rotate_around_normal_vector(degrees=-step_size_in_degrees.value)
    if recording.value:
        mylogger(f'rotate_around_normal_vector,{-step_size_in_degrees.value}')
    update()


button_left.on_click(pan_left)
button_right.on_click(pan_right)
button_up.on_click(pan_up)
button_down.on_click(pan_down)
button_clock_wise.on_click(clock_wise)
button_anti_clock_wise.on_click(anti_clock_wise)

# Subhalo, groups
show_groups = widgets.Checkbox(value=False, description='Show FoF groups', disabled=False, indent=False)
show_subs = widgets.Checkbox(value=False, description='Show subhalos', disabled=False, indent=False)

show_groups.observe(call_update, names=['value'])
show_subs.observe(call_update, names=['value'])

# Save hdf5/png
# IntText is unbounded; an IntSlider (default max=100) would stop counting at 100
frame_counter = widgets.IntText(value=0)
recording = widgets.Checkbox(value=False, description='Recording', disabled=False, indent=False)


def reset_counter(b):
    """
    Reset counter if recording is stopped, calculate first
    image if recording is started
    """
    if not recording.value:
        frame_counter.value = 0
        basename.disabled = False
        outfolder.disabled = False
    else:
        basename.disabled = True
        outfolder.disabled = True
        mylogger('', mode='w')
        update()


recording.observe(reset_counter, names=['value'])

hdf5 = widgets.Checkbox(value=False, description='hdf5', disabled=False, indent=False)
png = widgets.Checkbox(value=False, description='png', disabled=False, indent=False)

out = widgets.interactive_output(update, {})


def change_center_using_cat(change):
    # The observer also fires when the options are cleared (value becomes None)
    if select_center.value is None:
        return
    if select_center.value[0] == 'G':
        gr_id = int(select_center.value[1:])
        new_center = projector.snap.Cat.Group['GroupPos'][gr_id].T
        if recording.value:
            diff = new_center - projector._center
            mylogger(f"move_center_sim_coordinates,{diff[0].value},{diff[1].value},{diff[2].value}",
                     line_ending=',#,')
            mylogger(f"center_on_group,{gr_id},{new_center[0].value},{new_center[1].value},{new_center[2].value}")
        projector._diff_center += new_center - projector._center
        projector.center = new_center.copy
        show_groups.value = False
        select_center.options = []
        select_center.value = None
    elif select_center.value[0] == 'S':
        sub_id = int(select_center.value[1:])
        new_center = projector.snap.Cat.Sub['SubhaloPos'][sub_id].T
        if recording.value:
            diff = new_center - projector._center
            mylogger(f"move_center_sim_coordinates,{diff[0].value},{diff[1].value},{diff[2].value}",
                     line_ending=',#,')
            mylogger(f"center_on_sub,{sub_id},{new_center[0].value},{new_center[1].value},{new_center[2].value}")
        projector._diff_center += new_center - projector._center
        projector.center = new_center.copy
        show_subs.value = False
        select_center.options = []
        select_center.value = None
    # update()


width.observe(change_width, names=['value'])
height.observe(change_height, names=['value'])
depth.observe(change_depth, names=['value'])
# change_var_str re-renders the image when the selected field changes
var_str.observe(change_var_str, names=['value'])
select_center.observe(change_center_using_cat, names=['value'])

basename = widgets.Text(value='image')
outfolder = widgets.Text(value='./')

display(out,
        widgets.HBox([button_update, var_str, button_zoom, zoom_slider]),
        widgets.HBox([button_left, button_right, button_up, button_down,
                      button_clock_wise, button_anti_clock_wise]),
        widgets.HBox([width, height, depth]),
        widgets.HBox([step_button, center_horizontal, center_vertical, center_depth, recenter_button]),
        widgets.HBox([fix_climits, vmin, vmax]),
        widgets.HBox([step_size_in_degrees, to_physical, to_cgs, to_astro_units]),
        widgets.HBox([cmap_str, button_double, button_half]),
        widgets.HBox([recording, hdf5, png, outfolder, basename]),
        widgets.HBox([show_groups, show_subs, select_center]))