diff --git a/part_1/solution.py b/part_1/solution.py
index d9e77d9..6c9b6d7 100644
--- a/part_1/solution.py
+++ b/part_1/solution.py
@@ -1615,7 +1615,7 @@ def composite_nuc_mem(
) -> np.ndarray:
c_nuc = rescale_clip(image[0]) * nuc_color
c_mem = rescale_clip(image[1]) * mem_color
- return c_nuc + c_mem
+ return rescale_intensity(c_nuc + c_mem, out_range=(0, 1))
def clip_p(image: np.ndarray) -> np.ndarray:
return rescale_intensity(image.clip(*np.percentile(image, [1, 99])))
@@ -1645,6 +1645,297 @@ def clip_p(image: np.ndarray) -> np.ndarray:
a.axis("off")
plt.tight_layout()
+#%% [markdown] tags=[]
+#
+#
+# ### Range of validity
+# -
+#
+#
+#%%
+
+# %%
+from monai.transforms import GaussianSmooth
+from monai.networks.layers import GaussianFilter
+def clip_highlight(image: np.ndarray) -> np.ndarray:
+ return rescale_intensity(image.clip(0, np.percentile(image, 99.5)))
+
+
+# %%
+YX_PATCH_SIZE = (256*2,256*2)
+# Re-load the dataloader
+phase2fluor_2D_data = HCSDataModule(
+ data_path,
+ architecture="UNeXt2_2D",
+ source_channel=source_channel,
+ target_channel=target_channel,
+ z_window_size=1,
+ split_ratio=0.8,
+ batch_size=1,
+ num_workers=8,
+ yx_patch_size=YX_PATCH_SIZE,
+ augmentations=[],
+ normalizations=normalizations,
+ # ground_truth_masks=''
+)
+phase2fluor_2D_data.setup("test")
+
+
+# %% [markdown] tags=["task"]
+#
+#
+# ### Task 3.2: Select a sample batch to test the range of validty of the model
+#
+# - Select a test batch from the `test_dataloader` by changing the `batch_number`
+# - Examine the plot of the source and target images of the batch
+#
+# Note the 2D images have different focus
+#
+#%% tags=["task"]
+# ########## TODO ##############
+batch_number= 3
+# #######################
+# Iterate through the test dataloader to get the desired batch
+i = 0
+for batch in phase2fluor_2D_data.test_dataloader():
+ # break if we reach the desired batch
+ if i == batch_number-1:
+ break
+ i += 1
+
+# Plot the batch source and target images
+f, ax = plt.subplots(1, 2, figsize=(8, 12))
+target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
+ax[0].imshow(
+ batch["source"][0, 0, 0,y_slice,x_slice].cpu().numpy(), cmap="gray", vmin=-15, vmax=15
+)
+ax[1].imshow(clip_highlight(target_composite[0,y_slice,x_slice]))
+for a in ax.ravel():
+ a.axis("off")
+f.tight_layout()
+plt.show()
+
+# %% [markdown] tags=[]
+#
+#
+# ### Task 3.3: Using the selected batch to test the model's range of validity
+#
+# - Given the selected batch use `monai.networks.layers.GaussianFilter` to blur the images with different sigmas.
+# Check the documentation
here
+# - Plot the source and predicted images comparing the source, target and added perturbations
+# - How is the model's predictions given the perturbations?
+#
+#%% tags=["task"]
+# ########## TODO ##############
+# Try out different multiples of 256 to visualize larger/smaller crops
+n = 3
+# ##############################
+# Center cropping the image
+y_slice = slice(Y//2-256*n//2, Y//2+256*n//2)
+x_slice = slice(X//2-256*n//2, X//2+256*n//2)
+
+f, ax = plt.subplots(3, 2, figsize=(8, 12))
+
+target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
+ax[0, 0].imshow(
+ batch["source"][0, 0, 0,y_slice,x_slice].cpu().numpy(), cmap="gray", vmin=-15, vmax=15
+)
+ax[0, 1].imshow(clip_highlight(target_composite[0,y_slice,x_slice]))
+ax[0,0].set_title('Source and target')
+
+# no perturbation
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[1, 0].imshow(phase[0,0,0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[1, 1].imshow(pred_composite[0])
+ax[1,0].set_title('No perturbation')
+
+
+# Select a sigma for the Gaussian filtering
+# ########## TODO ##############
+# Tensor dimensions (B,C,D,H,W).
+# Hint: Use the GaussianFilter layer to blur the phase image. Provide the num spatial dimensions and sigmas
+# Hint: Spatial (D,H,W)
+gaussian_blur = GaussianFilter(....)
+# #############################
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
+ phase = gaussian_blur(phase)
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[2, 1].imshow(pred_composite[0])
+
+#%% tags=["solution"]
+# ########## SOLUTION ##############
+# Try out different multiples of 256 to visualize larger/smaller crops
+n = 3
+# ##############################
+# Center cropping the image
+y_slice = slice(Y//2-256*n//2, Y//2+256*n//2)
+x_slice = slice(X//2-256*n//2, X//2+256*n//2)
+
+f, ax = plt.subplots(3, 2, figsize=(8, 12))
+
+target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
+ax[0, 0].imshow(
+ batch["source"][0, 0, 0,y_slice,x_slice].cpu().numpy(), cmap="gray", vmin=-15, vmax=15
+)
+ax[0, 1].imshow(clip_highlight(target_composite[0,y_slice,x_slice]))
+ax[0,0].set_title('Source and target')
+
+# no perturbation
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[1, 0].imshow(phase[0,0,0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[1, 1].imshow(pred_composite[0])
+ax[1,0].set_title('No perturbation')
+
+
+# Select a sigma for the Gaussian filtering
+# ########## SOLUTION ##############
+# Tensor dimensions (B,C,D,H,W).
+# Hint: Use the GaussianFilter layer to blur the phase image. Provide the num spatial dimensions and sigma
+# Hint: Spatial (D,H,W). Apply the same sigma to H,W
+gaussian_blur = GaussianFilter(spatial_dims=3, sigma=(0,2,2))
+# #############################
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
+ phase = gaussian_blur(phase)
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[2, 1].imshow(pred_composite[0])
+# %% [markdown] tags=[]
+#
+#
+# ### Task 3.3: Using the selected batch to test the model's range of validity
+#
+# - Scale the pixel values up/down of the phase image
+# - Plot the source and predicted images comparing the source, target and added perturbations
+# - How is the model's predictions given the perturbations?
+#
+
+#%% tags=["task"]
+n = 3
+y_slice = slice(Y//2, Y//2+256*n)
+x_slice = slice(X//2, X//2+256*n)
+f, ax = plt.subplots(3, 2, figsize=(8, 12))
+
+target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
+ax[0, 0].imshow(
+ batch["source"][0, 0, 0,y_slice,x_slice].cpu().numpy(), cmap="gray", vmin=-15, vmax=15
+)
+ax[0, 1].imshow(clip_highlight(target_composite[0,y_slice,x_slice]))
+ax[0,0].set_title('Source and target')
+
+# no perturbation
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[1, 0].imshow(phase[0,0,0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[1, 1].imshow(pred_composite[0])
+ax[1,0].set_title('No perturbation')
+
+
+# 2-sigma gaussian blur
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
+ # ########## TODO ##############
+ # Hint: Scale the phase intensity
+ phase = phase * ......
+ # #######################
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[2, 1].imshow(pred_composite[0])
+ax[2,0].set_title('Gaussian Blur Sigma=2')
+
+#%% [markdown]
+# ########## TODO ##############
+# - How is the model's predictions given the blurring and scaling perturbations?
+
+
+#%% tags=["solution"]
+
+# ########## SOLUTIONS FOR ALL POSSIBLE PLOTTINGS ##############
+# This plots all perturbations
+
+n = 3
+y_slice = slice(Y//2, Y//2+256*n)
+x_slice = slice(X//2, X//2+256*n)
+f, ax = plt.subplots(6, 2, figsize=(8, 12))
+
+target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
+ax[0, 0].imshow(
+ batch["source"][0, 0, 0,y_slice,x_slice].cpu().numpy(), cmap="gray", vmin=-15, vmax=15
+)
+ax[0, 1].imshow(clip_highlight(target_composite[0,y_slice,x_slice]))
+ax[0,0].set_title('Source and target')
+
+# no perturbation
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[1, 0].imshow(phase[0,0,0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[1, 1].imshow(pred_composite[0])
+ax[1,0].set_title('No perturbation')
+
+
+# 2-sigma gaussian blur
+gaussian_blur = GaussianFilter(spatial_dims=3, sigma=(0,2,2))
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
+ phase = gaussian_blur(phase)
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[2, 1].imshow(pred_composite[0])
+ax[2,0].set_title('Gaussian Blur Sigma=2')
+
+
+# 5-sigma gaussian blur
+gaussian_blur = GaussianFilter(spatial_dims=3, sigma=(0,5,5))
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
+ phase = gaussian_blur(phase)
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[3, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[3, 1].imshow(pred_composite[0])
+ax[3,0].set_title('Gaussian Blur Sigma=5')
+
+
+# 0.1x scaling
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
+ phase = phase*0.1
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[4, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[4, 1].imshow(pred_composite[0])
+ax[4,0].set_title('0.1x scaling')
+
+# 10x scaling
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
+ phase = phase* 10
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[5, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[5, 1].imshow(pred_composite[0])
+ax[5,0].set_title('10x scaling')
+
+for a in ax.ravel():
+ a.axis("off")
+
+f.tight_layout()
# %% [markdown] tags=[]
#
@@ -1654,4 +1945,5 @@ def clip_p(image: np.ndarray) -> np.ndarray:
# Congratulations! You have trained an image translation model, evaluated its performance, and explored what the network has learned.
-#
\ No newline at end of file
+#
+# %%