diff --git a/part_1/solution.py b/part_1/solution.py index d9e77d9..6c9b6d7 100644 --- a/part_1/solution.py +++ b/part_1/solution.py @@ -1615,7 +1615,7 @@ def composite_nuc_mem( ) -> np.ndarray: c_nuc = rescale_clip(image[0]) * nuc_color c_mem = rescale_clip(image[1]) * mem_color - return c_nuc + c_mem + return rescale_intensity(c_nuc + c_mem, out_range=(0, 1)) def clip_p(image: np.ndarray) -> np.ndarray: return rescale_intensity(image.clip(*np.percentile(image, [1, 99]))) @@ -1645,6 +1645,297 @@ def clip_p(image: np.ndarray) -> np.ndarray: a.axis("off") plt.tight_layout() +#%% [markdown] tags=[] +#
+# +# ### Range of validity +# - +# +#
+#%% + +# %% +from monai.transforms import GaussianSmooth +from monai.networks.layers import GaussianFilter +def clip_highlight(image: np.ndarray) -> np.ndarray: + return rescale_intensity(image.clip(0, np.percentile(image, 99.5))) + + +# %% +YX_PATCH_SIZE = (256*2,256*2) +# Re-load the dataloader +phase2fluor_2D_data = HCSDataModule( + data_path, + architecture="UNeXt2_2D", + source_channel=source_channel, + target_channel=target_channel, + z_window_size=1, + split_ratio=0.8, + batch_size=1, + num_workers=8, + yx_patch_size=YX_PATCH_SIZE, + augmentations=[], + normalizations=normalizations, + # ground_truth_masks='' +) +phase2fluor_2D_data.setup("test") + + +# %% [markdown] tags=["task"] +#
+# +# ### Task 3.2: Select a sample batch to test the range of validty of the model +# +# - Select a test batch from the `test_dataloader` by changing the `batch_number`
+# - Examine the plot of the source and target images of the batch
+# +# Note the 2D images have different focus
+#
+#%% tags=["task"] +# ########## TODO ############## +batch_number= 3 +# ####################### +# Iterate through the test dataloader to get the desired batch +i = 0 +for batch in phase2fluor_2D_data.test_dataloader(): + # break if we reach the desired batch + if i == batch_number-1: + break + i += 1 + +# Plot the batch source and target images +f, ax = plt.subplots(1, 2, figsize=(8, 12)) +target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA) +ax[0].imshow( + batch["source"][0, 0, 0,y_slice,x_slice].cpu().numpy(), cmap="gray", vmin=-15, vmax=15 +) +ax[1].imshow(clip_highlight(target_composite[0,y_slice,x_slice])) +for a in ax.ravel(): + a.axis("off") +f.tight_layout() +plt.show() + +# %% [markdown] tags=[] +#
+# +# ### Task 3.3: Using the selected batch to test the model's range of validity +# +# - Given the selected batch use `monai.networks.layers.GaussianFilter` to blur the images with different sigmas. +# Check the documentation here
+# - Plot the source and predicted images comparing the source, target and added perturbations
+# - How is the model's predictions given the perturbations?
+#
+#%% tags=["task"] +# ########## TODO ############## +# Try out different multiples of 256 to visualize larger/smaller crops +n = 3 +# ############################## +# Center cropping the image +y_slice = slice(Y//2-256*n//2, Y//2+256*n//2) +x_slice = slice(X//2-256*n//2, X//2+256*n//2) + +f, ax = plt.subplots(3, 2, figsize=(8, 12)) + +target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA) +ax[0, 0].imshow( + batch["source"][0, 0, 0,y_slice,x_slice].cpu().numpy(), cmap="gray", vmin=-15, vmax=15 +) +ax[0, 1].imshow(clip_highlight(target_composite[0,y_slice,x_slice])) +ax[0,0].set_title('Source and target') + +# no perturbation +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice] + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[1, 0].imshow(phase[0,0,0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[1, 1].imshow(pred_composite[0]) +ax[1,0].set_title('No perturbation') + + +# Select a sigma for the Gaussian filtering +# ########## TODO ############## +# Tensor dimensions (B,C,D,H,W). +# Hint: Use the GaussianFilter layer to blur the phase image. Provide the num spatial dimensions and sigmas +# Hint: Spatial (D,H,W) +gaussian_blur = GaussianFilter(....) +# ############################# +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice] + phase = gaussian_blur(phase) + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[2, 1].imshow(pred_composite[0]) + +#%% tags=["solution"] +# ########## SOLUTION ############## +# Try out different multiples of 256 to visualize larger/smaller crops +n = 3 +# ############################## +# Center cropping the image +y_slice = slice(Y//2-256*n//2, Y//2+256*n//2) +x_slice = slice(X//2-256*n//2, X//2+256*n//2) + +f, ax = plt.subplots(3, 2, figsize=(8, 12)) + +target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA) +ax[0, 0].imshow( + batch["source"][0, 0, 0,y_slice,x_slice].cpu().numpy(), cmap="gray", vmin=-15, vmax=15 +) +ax[0, 1].imshow(clip_highlight(target_composite[0,y_slice,x_slice])) +ax[0,0].set_title('Source and target') + +# no perturbation +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice] + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[1, 0].imshow(phase[0,0,0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[1, 1].imshow(pred_composite[0]) +ax[1,0].set_title('No perturbation') + + +# Select a sigma for the Gaussian filtering +# ########## SOLUTION ############## +# Tensor dimensions (B,C,D,H,W). +# Hint: Use the GaussianFilter layer to blur the phase image. Provide the num spatial dimensions and sigma +# Hint: Spatial (D,H,W). Apply the same sigma to H,W +gaussian_blur = GaussianFilter(spatial_dims=3, sigma=(0,2,2)) +# ############################# +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice] + phase = gaussian_blur(phase) + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[2, 1].imshow(pred_composite[0]) +# %% [markdown] tags=[] +#
+# +# ### Task 3.3: Using the selected batch to test the model's range of validity +# +# - Scale the pixel values up/down of the phase image
+# - Plot the source and predicted images comparing the source, target and added perturbations
+# - How is the model's predictions given the perturbations?
+#
+ +#%% tags=["task"] +n = 3 +y_slice = slice(Y//2, Y//2+256*n) +x_slice = slice(X//2, X//2+256*n) +f, ax = plt.subplots(3, 2, figsize=(8, 12)) + +target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA) +ax[0, 0].imshow( + batch["source"][0, 0, 0,y_slice,x_slice].cpu().numpy(), cmap="gray", vmin=-15, vmax=15 +) +ax[0, 1].imshow(clip_highlight(target_composite[0,y_slice,x_slice])) +ax[0,0].set_title('Source and target') + +# no perturbation +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice] + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[1, 0].imshow(phase[0,0,0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[1, 1].imshow(pred_composite[0]) +ax[1,0].set_title('No perturbation') + + +# 2-sigma gaussian blur +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice] + # ########## TODO ############## + # Hint: Scale the phase intensity + phase = phase * ...... + # ####################### + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[2, 1].imshow(pred_composite[0]) +ax[2,0].set_title('Gaussian Blur Sigma=2') + +#%% [markdown] +# ########## TODO ############## +# - How is the model's predictions given the blurring and scaling perturbations?
+ + +#%% tags=["solution"] + +# ########## SOLUTIONS FOR ALL POSSIBLE PLOTTINGS ############## +# This plots all perturbations + +n = 3 +y_slice = slice(Y//2, Y//2+256*n) +x_slice = slice(X//2, X//2+256*n) +f, ax = plt.subplots(6, 2, figsize=(8, 12)) + +target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA) +ax[0, 0].imshow( + batch["source"][0, 0, 0,y_slice,x_slice].cpu().numpy(), cmap="gray", vmin=-15, vmax=15 +) +ax[0, 1].imshow(clip_highlight(target_composite[0,y_slice,x_slice])) +ax[0,0].set_title('Source and target') + +# no perturbation +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice] + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[1, 0].imshow(phase[0,0,0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[1, 1].imshow(pred_composite[0]) +ax[1,0].set_title('No perturbation') + + +# 2-sigma gaussian blur +gaussian_blur = GaussianFilter(spatial_dims=3, sigma=(0,2,2)) +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice] + phase = gaussian_blur(phase) + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[2, 1].imshow(pred_composite[0]) +ax[2,0].set_title('Gaussian Blur Sigma=2') + + +# 5-sigma gaussian blur +gaussian_blur = GaussianFilter(spatial_dims=3, sigma=(0,5,5)) +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice] + phase = gaussian_blur(phase) + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[3, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[3, 1].imshow(pred_composite[0]) +ax[3,0].set_title('Gaussian Blur Sigma=5') + + +# 0.1x scaling +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice] + phase = phase*0.1 + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[4, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[4, 1].imshow(pred_composite[0]) +ax[4,0].set_title('0.1x scaling') + +# 10x scaling +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice] + phase = phase* 10 + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[5, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[5, 1].imshow(pred_composite[0]) +ax[5,0].set_title('10x scaling') + +for a in ax.ravel(): + a.axis("off") + +f.tight_layout() # %% [markdown] tags=[] #
@@ -1654,4 +1945,5 @@ def clip_p(image: np.ndarray) -> np.ndarray: # Congratulations! You have trained an image translation model, evaluated its performance, and explored what the network has learned. -#
\ No newline at end of file +# +# %%