From 2ff67f821f780cdbdcc7531d6c5f57095e9c2124 Mon Sep 17 00:00:00 2001 From: edyoshikun Date: Sat, 24 Aug 2024 01:30:09 +0000 Subject: [PATCH] adding the feature maps PCA to RGB. --- part_1/solution.py | 181 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 177 insertions(+), 4 deletions(-) diff --git a/part_1/solution.py b/part_1/solution.py index 15d297a..d9e77d9 100644 --- a/part_1/solution.py +++ b/part_1/solution.py @@ -1470,15 +1470,188 @@ def min_max_scale(image:ArrayLike)->ArrayLike: plt.tight_layout() plt.show() + + +# %%[markdown] tags=[] +#
+# +#

Task 3.1: Let's look at what the model is learning

+# +# - Here we will visualize the encoder feature maps of the trained model.
+# - We will use PCA to visualize the feature maps by mapping the first 3 principal components to a colormap
+#
+ +#%% +""" +Script to visualize the encoder feature maps of a trained model. +Using PCA to visualize feature maps is inspired by +https://doi.org/10.48550/arXiv.2304.07193 (Oquab et al., 2023). +""" +from matplotlib.patches import Rectangle +from skimage.exposure import rescale_intensity +from skimage.transform import downscale_local_mean +from sklearn.decomposition import PCA +from sklearn.manifold import TSNE +from typing import NamedTuple + +def feature_map_pca(feature_map: np.array, n_components: int = 8) -> PCA: + """ + Compute PCA on a feature map. + :param np.array feature_map: (C, H, W) feature map + :param int n_components: number of components to keep + :return: PCA: fit sklearn PCA object + """ + # (C, H, W) -> (C, H*W) + feat = feature_map.reshape(feature_map.shape[0], -1) + pca = PCA(n_components=n_components) + pca.fit(feat) + return pca + +def pcs_to_rgb(feat: np.ndarray, n_components: int = 8) -> np.ndarray: + pca = feature_map_pca(feat[0], n_components=n_components) + pc_first_3 = pca.components_[:3].reshape(3, *feat.shape[-2:]) + return np.stack( + [rescale_intensity(pc, out_range=(0, 1)) for pc in pc_first_3], axis=-1 + ) +#%% +# Load the test dataset +test_data_path = top_dir / "06_image_translation/part1/test/a549_hoechst_cellmask_test.zarr" +test_dataset = open_ome_zarr(test_data_path) + +# Looking at the test dataset +print('Test dataset:') +test_dataset.print_tree() + +#%% [markdown] tags=[] +#
+# +# - Change the `fov` and `crop` size to visualize the feature maps of the encoder and decoder
+# Note: the crop should be a multiple of 384 +#
+#%% +# Load one position +row = 0 +col = 0 +center_index = 2 +n = 1 +crop = 384 * n +fov = 10 + +# normalize phase +norm_meta = test_dataset.zattrs["normalization"]["Phase3D"]["dataset_statistics"] + +# Get the OME-Zarr metadata +Y,X = test_dataset[f"0/0/{fov}"].data.shape[-2:] +test_dataset.channel_names +phase_idx= test_dataset.channel_names.index('Phase3D') +assert crop//2 < Y and crop//2 < Y , "Crop size larger than the image. Check the image shape" + +phase_img = test_dataset[f"0/0/{fov}/0"][:, phase_idx:phase_idx+1,0:1, Y//2 - crop // 2 : Y//2 + crop // 2, X//2 - crop // 2 : X//2 + crop // 2] +fluo = test_dataset[f"0/0/{fov}/0"][0, 1:3, 0, Y//2 - crop // 2 : Y//2 + crop // 2, X//2 - crop // 2 : X//2 + crop // 2] + +phase_img = (phase_img - norm_meta["median"]) / norm_meta["iqr"] +plt.imshow(phase_img[0,0,0], cmap="gray") + +# %% +# TODO: modify the path to your model checkpoint +# phase2fluor_model_ckpt = natsorted(glob( +# str(top_dir/"06_image_translation/backup/phase2fluor/version_3/checkpoints/*.ckpt") +# ))[-1] + +# TODO: rerun with pretrained +pretrained_model_ckpt = ( + top_dir / "06_image_translation/part1/pretrained_models/VSCyto2D/epoch=399-step=23200.ckpt" +) + +# load model +model = VSUNet.load_from_checkpoint( + pretrained_model_ckpt, + architecture="UNeXt2_2D", + model_config=phase2fluor_config.copy(), + accelerator="gpu", +) + +# %% +# extract features +with torch.inference_mode(): + # encoder + encoder_features = model.model.encoder(torch.from_numpy(phase_img.astype(np.float32)).to(model.device))[0] + encoder_features_np = [f.detach().cpu().numpy() for f in encoder_features] + + # Print the encoder features shapes + for f in encoder_features_np: + print(f.shape) + + # decoder + features = encoder_features.copy() + features.reverse() + feat = features[0] + features.append(None) + decoder_features_np = [] + for skip, stage in zip(features[1:], model.model.decoder.decoder_stages): + feat = stage(feat, skip) + decoder_features_np.append(feat.detach().cpu().numpy()) + for f in decoder_features_np: + print(f.shape) + prediction = model.model.head(feat).detach().cpu().numpy() + +# Defining the colors for plotting +class Color(NamedTuple): + r: float + g: float + b: float + +BOP_ORANGE = Color(0.972549, 0.6784314, 0.1254902) +BOP_BLUE = Color(BOP_ORANGE.b, BOP_ORANGE.g, BOP_ORANGE.r) +GREEN = Color(0.0, 1.0, 0.0) +MAGENTA = Color(1.0, 0.0, 1.0) + +# Defining the functions to rescale the image and composite the nuclear and membrane images +def rescale_clip(image: torch.Tensor) -> np.ndarray: + return rescale_intensity(image, out_range=(0, 1))[..., None].repeat(3, axis=-1) + +def composite_nuc_mem( + image: torch.Tensor, nuc_color: Color, mem_color: Color +) -> np.ndarray: + c_nuc = rescale_clip(image[0]) * nuc_color + c_mem = rescale_clip(image[1]) * mem_color + return c_nuc + c_mem + +def clip_p(image: np.ndarray) -> np.ndarray: + return rescale_intensity(image.clip(*np.percentile(image, [1, 99]))) + +# Plot the PCA to RGB of the feature maps + +f, ax = plt.subplots(10, 1, figsize=(5, 25)) +n_components = 4 +ax[0].imshow(phase_img[0, 0, 0], cmap="gray") +ax[0].set_title(f"Phase {phase_img.shape[1:]}") +ax[-1].imshow(clip_p(composite_nuc_mem(fluo, GREEN, MAGENTA))) +ax[-1].set_title("Fluorescence") + +for level, feat in enumerate(encoder_features_np): + ax[level + 1].imshow(pcs_to_rgb(feat, n_components=n_components)) + ax[level + 1].set_title(f"Encoder stage {level+1} {feat.shape[1:]}") + +for level, feat in enumerate(decoder_features_np): + ax[5 + level].imshow(pcs_to_rgb(feat, n_components=n_components)) + ax[5 + level].set_title(f"Decoder stage {level+1} {feat.shape[1:]}") + +pred_comp = composite_nuc_mem(prediction[0, :, 0], BOP_BLUE, BOP_ORANGE) +ax[-2].imshow(clip_p(pred_comp)) +ax[-2].set_title(f"Prediction {prediction.shape[1:]}") + +for a in ax.ravel(): + a.axis("off") +plt.tight_layout() + # %% [markdown] tags=[] #
#

# 🎉 The end of the notebook 🎉 -# Continue to Part 2: Image translation with generative models. #

-# Congratulations! You have trained an image translation model and evaluated its performance. -#
+# Congratulations! You have trained an image translation model, evaluated its performance, and explored what the network has learned. -# %% +# \ No newline at end of file