From 6021e972ec766ed301a0b73efa54796ba15127e7 Mon Sep 17 00:00:00 2001 From: Talon Chandler Date: Wed, 18 Dec 2024 15:49:13 -0800 Subject: [PATCH] Transfer function visuals (#178) * Bump torch to unpin numpy (#176) * bump torch to unpin numpy * add SPEC-0 conformant numpy requirement * Bump torch to unpin numpy (#176) * bump torch to unpin numpy * add SPEC-0 conformant numpy requirement * first-pass scripts * cleanup greens * clean transfer function support * fix naming issue * Wrap-safe transfer functions (#175) * helper functions * fluorescence wrap safety * 3d phase wrap safety * fix axial nyquist bug * 2d phase wrap safety * fix interaction between padding and wrap safety * green's tensor surfaces * dark theme default --------- Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> --- pyproject.toml | 4 +- .../models/isotropic_fluorescent_thick_3d.py | 5 +- waveorder/models/isotropic_thin_3d.py | 1 - waveorder/scripts/visualize-greens.py | 139 ++++++++++ waveorder/scripts/visualize-support.py | 241 ++++++++++++++++++ 5 files changed, 386 insertions(+), 4 deletions(-) create mode 100644 waveorder/scripts/visualize-greens.py create mode 100644 waveorder/scripts/visualize-support.py diff --git a/pyproject.toml b/pyproject.toml index 30dcade..803ec99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,12 +45,12 @@ classifiers = [ "Operating System :: MacOS", ] dependencies = [ - "numpy>=1.21, <2", + "numpy>=1.24", "matplotlib>=3.1.1", "scipy>=1.3.0", "pywavelets>=1.1.1", "ipywidgets>=7.5.1", - "torch>=2.2.1", + "torch>=2.4.1", ] dynamic = ["version"] diff --git a/waveorder/models/isotropic_fluorescent_thick_3d.py b/waveorder/models/isotropic_fluorescent_thick_3d.py index 6c78f3d..88c171f 100644 --- a/waveorder/models/isotropic_fluorescent_thick_3d.py +++ b/waveorder/models/isotropic_fluorescent_thick_3d.py @@ -110,7 +110,10 @@ def _calculate_wrap_unsafe_transfer_function( def visualize_transfer_function(viewer, optical_transfer_function, zyx_scale): add_transfer_function_to_viewer( - viewer, torch.real(optical_transfer_function), zyx_scale, clim_factor=0.05 + viewer, + torch.real(optical_transfer_function), + zyx_scale, + clim_factor=0.05, ) diff --git a/waveorder/models/isotropic_thin_3d.py b/waveorder/models/isotropic_thin_3d.py index e96a499..5ba9d22 100644 --- a/waveorder/models/isotropic_thin_3d.py +++ b/waveorder/models/isotropic_thin_3d.py @@ -5,7 +5,6 @@ from torch import Tensor from waveorder import optics, sampling, util -from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer def generate_test_phantom( diff --git a/waveorder/scripts/visualize-greens.py b/waveorder/scripts/visualize-greens.py new file mode 100644 index 0000000..4ed59a2 --- /dev/null +++ b/waveorder/scripts/visualize-greens.py @@ -0,0 +1,139 @@ +from skimage import measure +import napari +from napari.experimental import link_layers +import numpy as np +import torch +import os +from waveorder import util, optics +from scipy.ndimage import gaussian_filter + +# Parameters +# all lengths must use consistent units e.g. um +output_dirpath = "./greens_plots" +os.makedirs(output_dirpath, exist_ok=True) +grid_size = 300 +blur_width = grid_size // 35 # blurring to smooth sharp corners +zyx_shape = 3 * (grid_size,) +yx_pixel_size = 6.5 / 63 +z_pixel_size = 6.5 / 63 +wavelength_illumination = 0.532 +index_of_refraction_media = 1.3 +threshold = 0.5 # for marching cubes + + +# Calculate coordinate grids +zyx_pixel_size = (z_pixel_size, yx_pixel_size, yx_pixel_size) +y_frequencies, x_frequencies = util.generate_frequencies( + zyx_shape[1:], yx_pixel_size +) +radial_frequencies = torch.sqrt(x_frequencies**2 + y_frequencies**2) +z_position_list = torch.fft.ifftshift( + (torch.arange(zyx_shape[0]) - zyx_shape[0] // 2) * z_pixel_size +) +z_frequencies = torch.fft.fftfreq(zyx_shape[0], d=z_pixel_size) + +freq_shape = z_position_list.shape + x_frequencies.shape +z_broadcast = torch.broadcast_to(z_frequencies[:, None, None], freq_shape) +y_broadcast = torch.broadcast_to(y_frequencies[None, :, :], freq_shape) +x_broadcast = torch.broadcast_to(x_frequencies[None, :, :], freq_shape) +nu_rr = torch.sqrt(z_broadcast**2 + y_broadcast**2 + x_broadcast**2) + +freq_voxel_size = [1 / (d * n) for d, n in zip(zyx_pixel_size, zyx_shape)] + +# Calculate Greens tensor spectrum +G_3D = optics.generate_greens_tensor_spectrum( + zyx_shape=zyx_shape, + zyx_pixel_size=zyx_pixel_size, + wavelength=wavelength_illumination / index_of_refraction_media, +) + +# Mask to zero outside of a spherical shell +wavelength = wavelength_illumination / index_of_refraction_media +nu_max = (33 / 32) / (wavelength) +nu_min = (31 / 32) / (wavelength) +mask = torch.logical_and(nu_rr < nu_max, nu_rr > nu_min) +G_3D *= mask + +# Split into positve and negative imaginary parts +G3D_imag = torch.imag(torch.fft.fftshift(G_3D, dim=(-3, -2, -1))) +G_pos = G3D_imag * (G3D_imag > 0) +G_neg = G3D_imag * (G3D_imag < 0) + +# Blur to reduce aliasing +sigma = ( + 0, + 0, +) + 3 * (blur_width,) +G_pos = gaussian_filter(np.array(G_pos), sigma=sigma) +G_neg = gaussian_filter(np.array(G_neg), sigma=sigma) + +# Add to napari +viewer = napari.Viewer() + +viewer.theme = "light" +viewer.dims.ndisplay = 3 +viewer.camera.zoom = 100 + +for i in range(3): + for j in range(3): + name = f"{i}_{j}" + + volume = G_pos[i, j] + verts, faces, normals, _ = measure.marching_cubes( + volume, level=threshold * np.max(volume) + ) + viewer.add_surface( + (verts, faces), + name=name + "-positive-surface", + colormap="greens", + scale=freq_voxel_size, + shading="smooth", + ) + + volume = -G_neg[i, j] + if i != j: + verts, faces, normals, _ = measure.marching_cubes( + volume, level=threshold * np.max(volume) + ) + viewer.add_surface( + (verts, faces), + opacity=1.0, + name=name + "-negative-surface", + colormap="I Purple", + scale=freq_voxel_size, + blending="translucent", + shading="smooth", + ) + else: + viewer.add_surface( + ( + np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]]), + np.array([[0, 1, 2]]), + ), + opacity=1.0, + name=name + "-dummy-surface", + colormap="gray", + scale=freq_voxel_size, + blending="translucent", + shading="smooth", + ) + link_layers(viewer.layers[-2:]) + + print(f"Screenshotting {i}_{j}") + viewer.camera.set_view_direction( + view_direction=[-1, -1, -1], up_direction=[0, 0, 1] + ) + viewer.screenshot( + os.path.join(output_dirpath, f"{i}_{j}.png"), scale=2 + ) + viewer.layers[-1].visible = False + viewer.layers[-2].visible = False + +# Show in complete grid +for layer in viewer.layers: + layer.visible = True +viewer.grid.enabled = True +viewer.grid.stride = 2 +viewer.grid.shape = (-1, 3) +viewer.theme = "dark" +napari.run() diff --git a/waveorder/scripts/visualize-support.py b/waveorder/scripts/visualize-support.py new file mode 100644 index 0000000..6bf9094 --- /dev/null +++ b/waveorder/scripts/visualize-support.py @@ -0,0 +1,241 @@ +import napari +import numpy as np +import os +import matplotlib.pyplot as plt + + +def plot_otf_support( + ill_na, + det_na, + N_theta=100, + N_phi=50, + top_cmap="green", + bottom_cmap="purple", + top_azimuth_vals=None, + bottom_azimuth_vals=None, +): + # check azimuth values + if top_azimuth_vals is None: + top_azimuth_vals = np.linspace(0, 2, N_phi) % 1.0 + else: + assert len(top_azimuth_vals) == N_phi + if bottom_azimuth_vals is None: + bottom_azimuth_vals = np.linspace(0, 2, N_phi) % 1.0 + else: + assert len(bottom_azimuth_vals) == N_phi + + # key points (transverse, axial) coordinates + points = np.array( + [ + [0, 0], + [det_na - ill_na, (1 - ill_na**2) ** 0.5 - (1 - det_na**2) ** 0.5], + [det_na + ill_na, (1 - ill_na**2) ** 0.5 - (1 - det_na**2) ** 0.5], + [2 * ill_na, 0], + ] + ) + + # arc centers + centers = np.array( + [ + [-ill_na, (1 - ill_na**2) ** 0.5], + [det_na, -((1 - det_na**2) ** 0.5)], + [ill_na, (1 - ill_na**2) ** 0.5], + ] + ) + + # angles of arcs + thetas = [] + for j, center in enumerate(centers): + start_point = points[j] + end_point = points[j + 1] + + theta_start = np.arctan2( + start_point[1] - center[1], start_point[0] - center[0] + ) + theta_end = np.arctan2( + end_point[1] - center[1], end_point[0] - center[0] + ) + + thetas.append((theta_start, theta_end)) + + # compute final points + arc_lengths = [np.abs(theta[1] - theta[0]) for theta in thetas] + total_arc_length = np.sum(arc_lengths) + + theta_coords = [ + np.linspace( + theta[0], + theta[1], + np.int8(np.floor(N_theta * arc_length / total_arc_length)), + ) + for theta, arc_length in zip(thetas, arc_lengths) + ] + + xz_points = [] + for j, center in enumerate(centers): + for theta_coord in theta_coords[j]: + x = center[0] + np.cos(theta_coord) + y = center[1] + np.sin(theta_coord) + xz_points.append([x, y]) + xz_points = np.array(xz_points) + + phi = np.linspace(0, 2 * np.pi, N_phi, endpoint=False) + + # Compute 3D points + points_3d = np.zeros((N_phi, xz_points.shape[0], 3)) + faces = [] + for i, xz_point in enumerate(xz_points): + for j, angle in enumerate(phi): + points_3d[j, i, 0] = xz_point[1] + points_3d[j, i, 1] = xz_point[0] * np.sin(angle) + points_3d[j, i, 2] = xz_point[0] * np.cos(angle) + + next_i = i + 1 + next_j = (j + 1) % N_phi + + faces.append([(j, i), (next_j, i), (j, next_i)]) + faces.append([(next_j, i), (next_j, next_i), (j, next_i)]) + + # handle indexing + mesh = [] + for face in faces: + try: + ravel_face = [ + np.ravel_multi_index(vertex, (N_phi, N_theta - 1)) + for vertex in face + ] + except: + continue # print(face) + mesh.append(ravel_face) + mesh = np.array(mesh) + + top_values = np.tile(top_azimuth_vals, (N_theta - 1, 1)).T + bottom_values = np.tile(bottom_azimuth_vals, (N_theta - 1, 1)).T + + points_3d = points_3d.reshape(-1, 3) + top_values = top_values.reshape(-1) + bottom_values = bottom_values.reshape(-1) + + # Add negative surface first + points_3d_copy = points_3d.copy() + points_3d_copy[:, 0] *= -1 # flip z + viewer.add_surface( + (points_3d_copy, mesh, bottom_values), + opacity=0.75, + colormap=bottom_cmap, + blending="translucent", + shading="smooth", + ) + + viewer.add_surface( + (points_3d, mesh, top_values), + opacity=0.75, + colormap=top_cmap, + blending="translucent", + shading="smooth", + ) + + # import pdb; pdb.set_trace() + # viewer.screenshot(filename) + + +viewer = napari.Viewer() + +# Main loops +output_dir = "./output" +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +N_phi = 50 +N_theta = 100 +det_na = 0.75 +ill_na = 0.5 + +my_colors = [ + [["green", "purple"], [np.ones((N_phi,)), np.ones((N_phi,))]], + [ + ["hsv", "hsv"], + [ + (np.linspace(0, 2, N_phi) + 0) % 1.0, + -(np.linspace(0, 2, N_phi) + 0) % 1.0, + ], + ], + [ + ["hsv", "hsv"], + [ + (np.linspace(0, 2, N_phi) + 0.5) % 1.0, + -(np.linspace(0, 2, N_phi) + 0.5) % 1.0, + ], + ], + [ + ["hsv", "hsv"], + [ + (np.linspace(0, 2, N_phi) + 0) % 1.0, + -(np.linspace(0, 2, N_phi) + 0) % 1.0, + ], + ], + [["red", "red"], [np.ones((N_phi,)), np.ones((N_phi,))]], + [["green", "purple"], [np.ones((N_phi,)), np.ones((N_phi,))]], + [ + ["hsv", "hsv"], + [ + (np.linspace(0, 2, N_phi) + 0.5) % 1.0, + -(np.linspace(0, 2, N_phi) + 0.5) % 1.0, + ], + ], + [["green", "purple"], [np.ones((N_phi,)), np.ones((N_phi,))]], + [["cyan", "cyan"], [np.ones((N_phi,)), np.ones((N_phi,))]], + [["green", "purple"], [np.ones((N_phi,)), np.ones((N_phi,))]], + [ + ["hsv", "hsv"], + [ + (np.linspace(0, 2, N_phi) + 0) % 1.0, + -(np.linspace(0, 2, N_phi) + 0) % 1.0, + ], + ], + [ + ["hsv", "hsv"], + [ + (np.linspace(0, 2, N_phi) + 0.5) % 1.0, + -(np.linspace(0, 2, N_phi) + 0.5) % 1.0, + ], + ], +] + +for my_color in my_colors: + + plot_otf_support( + ill_na, + det_na, + N_theta=N_theta, + N_phi=N_phi, + top_cmap=my_color[0][0], + bottom_cmap=my_color[0][1], + top_azimuth_vals=my_color[1][0], + bottom_azimuth_vals=my_color[1][1], + ) + +viewer.theme = "dark" +viewer.dims.ndisplay = 3 +viewer.camera.set_view_direction( + view_direction=[-0.1, -1, -1], up_direction=[1, 0, 0] +) +viewer.camera.zoom = 250 +viewer.grid.enabled = True +viewer.grid.stride = 2 +viewer.grid.shape = (-1, 3) + + +input("Press Enter to close...") +plot_otf_support( + det_na * 0.98, + det_na, + N_theta=N_theta, + N_phi=N_phi, + top_cmap="red", + bottom_cmap="red", + top_azimuth_vals=np.ones((N_phi,)), + bottom_azimuth_vals=np.ones((N_phi,)), +) + +input("Press Enter to close...")