diff --git a/part_1/solution.py b/part_1/solution.py
index 15d297a..d9e77d9 100644
--- a/part_1/solution.py
+++ b/part_1/solution.py
@@ -1470,15 +1470,188 @@ def min_max_scale(image:ArrayLike)->ArrayLike:
plt.tight_layout()
plt.show()
+
+
+# %%[markdown] tags=[]
+#
+#
+#
Task 3.1: Let's look at what the model is learning
+#
+# - Here we will visualize the encoder feature maps of the trained model.
+# - We will use PCA to visualize the feature maps by mapping the first 3 principal components to a colormap
+#
+
+#%%
+"""
+Script to visualize the encoder feature maps of a trained model.
+Using PCA to visualize feature maps is inspired by
+https://doi.org/10.48550/arXiv.2304.07193 (Oquab et al., 2023).
+"""
+from matplotlib.patches import Rectangle
+from skimage.exposure import rescale_intensity
+from skimage.transform import downscale_local_mean
+from sklearn.decomposition import PCA
+from sklearn.manifold import TSNE
+from typing import NamedTuple
+
+def feature_map_pca(feature_map: np.array, n_components: int = 8) -> PCA:
+ """
+ Compute PCA on a feature map.
+ :param np.array feature_map: (C, H, W) feature map
+ :param int n_components: number of components to keep
+ :return: PCA: fit sklearn PCA object
+ """
+ # (C, H, W) -> (C, H*W)
+ feat = feature_map.reshape(feature_map.shape[0], -1)
+ pca = PCA(n_components=n_components)
+ pca.fit(feat)
+ return pca
+
+def pcs_to_rgb(feat: np.ndarray, n_components: int = 8) -> np.ndarray:
+ pca = feature_map_pca(feat[0], n_components=n_components)
+ pc_first_3 = pca.components_[:3].reshape(3, *feat.shape[-2:])
+ return np.stack(
+ [rescale_intensity(pc, out_range=(0, 1)) for pc in pc_first_3], axis=-1
+ )
+#%%
+# Load the test dataset
+test_data_path = top_dir / "06_image_translation/part1/test/a549_hoechst_cellmask_test.zarr"
+test_dataset = open_ome_zarr(test_data_path)
+
+# Looking at the test dataset
+print('Test dataset:')
+test_dataset.print_tree()
+
+#%% [markdown] tags=[]
+#
+#
+# - Change the `fov` and `crop` size to visualize the feature maps of the encoder and decoder
+# Note: the crop should be a multiple of 384
+#
+#%%
+# Load one position
+row = 0
+col = 0
+center_index = 2
+n = 1
+crop = 384 * n
+fov = 10
+
+# normalize phase
+norm_meta = test_dataset.zattrs["normalization"]["Phase3D"]["dataset_statistics"]
+
+# Get the OME-Zarr metadata
+Y,X = test_dataset[f"0/0/{fov}"].data.shape[-2:]
+test_dataset.channel_names
+phase_idx= test_dataset.channel_names.index('Phase3D')
+assert crop//2 < Y and crop//2 < Y , "Crop size larger than the image. Check the image shape"
+
+phase_img = test_dataset[f"0/0/{fov}/0"][:, phase_idx:phase_idx+1,0:1, Y//2 - crop // 2 : Y//2 + crop // 2, X//2 - crop // 2 : X//2 + crop // 2]
+fluo = test_dataset[f"0/0/{fov}/0"][0, 1:3, 0, Y//2 - crop // 2 : Y//2 + crop // 2, X//2 - crop // 2 : X//2 + crop // 2]
+
+phase_img = (phase_img - norm_meta["median"]) / norm_meta["iqr"]
+plt.imshow(phase_img[0,0,0], cmap="gray")
+
+# %%
+# TODO: modify the path to your model checkpoint
+# phase2fluor_model_ckpt = natsorted(glob(
+# str(top_dir/"06_image_translation/backup/phase2fluor/version_3/checkpoints/*.ckpt")
+# ))[-1]
+
+# TODO: rerun with pretrained
+pretrained_model_ckpt = (
+ top_dir / "06_image_translation/part1/pretrained_models/VSCyto2D/epoch=399-step=23200.ckpt"
+)
+
+# load model
+model = VSUNet.load_from_checkpoint(
+ pretrained_model_ckpt,
+ architecture="UNeXt2_2D",
+ model_config=phase2fluor_config.copy(),
+ accelerator="gpu",
+)
+
+# %%
+# extract features
+with torch.inference_mode():
+ # encoder
+ encoder_features = model.model.encoder(torch.from_numpy(phase_img.astype(np.float32)).to(model.device))[0]
+ encoder_features_np = [f.detach().cpu().numpy() for f in encoder_features]
+
+ # Print the encoder features shapes
+ for f in encoder_features_np:
+ print(f.shape)
+
+ # decoder
+ features = encoder_features.copy()
+ features.reverse()
+ feat = features[0]
+ features.append(None)
+ decoder_features_np = []
+ for skip, stage in zip(features[1:], model.model.decoder.decoder_stages):
+ feat = stage(feat, skip)
+ decoder_features_np.append(feat.detach().cpu().numpy())
+ for f in decoder_features_np:
+ print(f.shape)
+ prediction = model.model.head(feat).detach().cpu().numpy()
+
+# Defining the colors for plotting
+class Color(NamedTuple):
+ r: float
+ g: float
+ b: float
+
+BOP_ORANGE = Color(0.972549, 0.6784314, 0.1254902)
+BOP_BLUE = Color(BOP_ORANGE.b, BOP_ORANGE.g, BOP_ORANGE.r)
+GREEN = Color(0.0, 1.0, 0.0)
+MAGENTA = Color(1.0, 0.0, 1.0)
+
+# Defining the functions to rescale the image and composite the nuclear and membrane images
+def rescale_clip(image: torch.Tensor) -> np.ndarray:
+ return rescale_intensity(image, out_range=(0, 1))[..., None].repeat(3, axis=-1)
+
+def composite_nuc_mem(
+ image: torch.Tensor, nuc_color: Color, mem_color: Color
+) -> np.ndarray:
+ c_nuc = rescale_clip(image[0]) * nuc_color
+ c_mem = rescale_clip(image[1]) * mem_color
+ return c_nuc + c_mem
+
+def clip_p(image: np.ndarray) -> np.ndarray:
+ return rescale_intensity(image.clip(*np.percentile(image, [1, 99])))
+
+# Plot the PCA to RGB of the feature maps
+
+f, ax = plt.subplots(10, 1, figsize=(5, 25))
+n_components = 4
+ax[0].imshow(phase_img[0, 0, 0], cmap="gray")
+ax[0].set_title(f"Phase {phase_img.shape[1:]}")
+ax[-1].imshow(clip_p(composite_nuc_mem(fluo, GREEN, MAGENTA)))
+ax[-1].set_title("Fluorescence")
+
+for level, feat in enumerate(encoder_features_np):
+ ax[level + 1].imshow(pcs_to_rgb(feat, n_components=n_components))
+ ax[level + 1].set_title(f"Encoder stage {level+1} {feat.shape[1:]}")
+
+for level, feat in enumerate(decoder_features_np):
+ ax[5 + level].imshow(pcs_to_rgb(feat, n_components=n_components))
+ ax[5 + level].set_title(f"Decoder stage {level+1} {feat.shape[1:]}")
+
+pred_comp = composite_nuc_mem(prediction[0, :, 0], BOP_BLUE, BOP_ORANGE)
+ax[-2].imshow(clip_p(pred_comp))
+ax[-2].set_title(f"Prediction {prediction.shape[1:]}")
+
+for a in ax.ravel():
+ a.axis("off")
+plt.tight_layout()
+
# %% [markdown] tags=[]
#
#
# 🎉 The end of the notebook 🎉
-# Continue to Part 2: Image translation with generative models.
#
-# Congratulations! You have trained an image translation model and evaluated its performance.
-#
+# Congratulations! You have trained an image translation model, evaluated its performance, and explored what the network has learned.
-# %%
+#
\ No newline at end of file