Commit 71840d6

Commit from GitHub Actions (Build Notebooks)

edyoshikun committed Aug 24, 2024
1 parent 2ff67f8 commit 71840d6
Showing 2 changed files with 434 additions and 17 deletions.
232 changes: 223 additions & 9 deletions part_1/exercise.ipynb
@@ -1750,24 +1750,90 @@
" ax.set_yticks([])\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
"plt.show()\n",
"\n",
"\n",
"\n",
"# %%[markdown] tags=[]\n",
"# <div class=\"alert alert-info\">\n",
"#\n",
"# <h3> Task 3.1: Let's look at what the model is learning </h3>\n",
"# \n",
"# - Here we will visualize the encoder feature maps of the trained model. <br>\n",
"# - We will use PCA to visualize the feature maps by mapping the first 3 principal components to a colormap <br>\n",
"# </div>"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e6ace44",
"metadata": {
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"\"\"\"\n",
"Script to visualize the encoder feature maps of a trained model.\n",
"Using PCA to visualize feature maps is inspired by\n",
"https://doi.org/10.48550/arXiv.2304.07193 (Oquab et al., 2023).\n",
"\"\"\"\n",
"from matplotlib.patches import Rectangle\n",
"from skimage.exposure import rescale_intensity\n",
"from skimage.transform import downscale_local_mean\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.manifold import TSNE\n",
"from typing import NamedTuple\n",
"\n",
"def feature_map_pca(feature_map: np.array, n_components: int = 8) -> PCA:\n",
" \"\"\"\n",
" Compute PCA on a feature map.\n",
" :param np.array feature_map: (C, H, W) feature map\n",
" :param int n_components: number of components to keep\n",
" :return: PCA: fit sklearn PCA object\n",
" \"\"\"\n",
" # (C, H, W) -> (C, H*W)\n",
" feat = feature_map.reshape(feature_map.shape[0], -1)\n",
" pca = PCA(n_components=n_components)\n",
" pca.fit(feat)\n",
" return pca\n",
"\n",
"def pcs_to_rgb(feat: np.ndarray, n_components: int = 8) -> np.ndarray:\n",
" pca = feature_map_pca(feat[0], n_components=n_components)\n",
" pc_first_3 = pca.components_[:3].reshape(3, *feat.shape[-2:])\n",
" return np.stack(\n",
" [rescale_intensity(pc, out_range=(0, 1)) for pc in pc_first_3], axis=-1\n",
" )"
]
},
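{
"cell_type": "markdown",
"id": "pca-rgb-demo-note",
"metadata": {},
"source": [
"As an optional sanity check, the helper above can be applied to a random `(1, C, H, W)` array standing in for a real encoder feature map; the first three principal components become the R, G, and B channels. The array shape and number of components below are arbitrary choices for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pca-rgb-demo",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: apply pcs_to_rgb to a synthetic feature map (random values, arbitrary shape)\n",
"demo_feat = np.random.rand(1, 32, 64, 64).astype(np.float32)\n",
"demo_rgb = pcs_to_rgb(demo_feat, n_components=4)\n",
"print(demo_rgb.shape)  # (64, 64, 3): first 3 PCs mapped to RGB\n",
"plt.imshow(demo_rgb)\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},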
{
"cell_type": "code",
"execution_count": null,
"id": "febc728e",
"metadata": {},
"outputs": [],
"source": [
"# Load the test dataset\n",
"test_data_path = top_dir / \"06_image_translation/part1/test/a549_hoechst_cellmask_test.zarr\"\n",
"test_dataset = open_ome_zarr(test_data_path)\n",
"\n",
"# Looking at the test dataset\n",
"print('Test dataset:')\n",
"test_dataset.print_tree()"
]
},
{
"cell_type": "markdown",
"id": "24c959e2",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
},
"source": [
"<div class=\"alert alert-success\">\n",
"\n",
"<h2>\n",
"🎉 The end of the notebook 🎉\n",
"Continue to Part 2: Image translation with generative models.\n",
"</h2>\n",
"<div class=\"alert alert-info\">\n",
"\n",
"Congratulations! You have trained an image translation model and evaluated its performance.\n",
"- Change the `fov` and `crop` size to visualize the feature maps of the encoder and decoder <br>\n",
"Note: the crop should be a multiple of 384\n",
"</div>"
]
},
@@ -1777,7 +1843,155 @@
"id": "2d5426ea",
"metadata": {},
"outputs": [],
"source": []
"source": [
"# Load one position\n",
"row = 0\n",
"col = 0\n",
"center_index = 2\n",
"n = 1\n",
"crop = 384 * n\n",
"fov = 10\n",
"\n",
"# normalize phase\n",
"norm_meta = test_dataset.zattrs[\"normalization\"][\"Phase3D\"][\"dataset_statistics\"]\n",
"\n",
"# Get the OME-Zarr metadata\n",
"Y,X = test_dataset[f\"0/0/{fov}\"].data.shape[-2:]\n",
"test_dataset.channel_names\n",
"phase_idx= test_dataset.channel_names.index('Phase3D')\n",
"assert crop//2 < Y and crop//2 < Y , \"Crop size larger than the image. Check the image shape\"\n",
"\n",
"phase_img = test_dataset[f\"0/0/{fov}/0\"][:, phase_idx:phase_idx+1,0:1, Y//2 - crop // 2 : Y//2 + crop // 2, X//2 - crop // 2 : X//2 + crop // 2]\n",
"fluo = test_dataset[f\"0/0/{fov}/0\"][0, 1:3, 0, Y//2 - crop // 2 : Y//2 + crop // 2, X//2 - crop // 2 : X//2 + crop // 2]\n",
"\n",
"phase_img = (phase_img - norm_meta[\"median\"]) / norm_meta[\"iqr\"]\n",
"plt.imshow(phase_img[0,0,0], cmap=\"gray\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c885db9c",
"metadata": {},
"outputs": [],
"source": [
"# TODO: modify the path to your model checkpoint\n",
"# phase2fluor_model_ckpt = natsorted(glob(\n",
"# str(top_dir/\"06_image_translation/backup/phase2fluor/version_3/checkpoints/*.ckpt\")\n",
"# ))[-1]\n",
"\n",
"# TODO: rerun with pretrained\n",
"pretrained_model_ckpt = (\n",
" top_dir / \"06_image_translation/part1/pretrained_models/VSCyto2D/epoch=399-step=23200.ckpt\"\n",
")\n",
"\n",
"# load model\n",
"model = VSUNet.load_from_checkpoint(\n",
" pretrained_model_ckpt,\n",
" architecture=\"UNeXt2_2D\",\n",
" model_config=phase2fluor_config.copy(),\n",
" accelerator=\"gpu\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eb461c25",
"metadata": {},
"outputs": [],
"source": [
"# extract features\n",
"with torch.inference_mode():\n",
" # encoder\n",
" encoder_features = model.model.encoder(torch.from_numpy(phase_img.astype(np.float32)).to(model.device))[0]\n",
" encoder_features_np = [f.detach().cpu().numpy() for f in encoder_features]\n",
" \n",
" # Print the encoder features shapes\n",
" for f in encoder_features_np:\n",
" print(f.shape)\n",
"\n",
" # decoder\n",
" features = encoder_features.copy()\n",
" features.reverse()\n",
" feat = features[0]\n",
" features.append(None)\n",
" decoder_features_np = []\n",
" for skip, stage in zip(features[1:], model.model.decoder.decoder_stages):\n",
" feat = stage(feat, skip)\n",
" decoder_features_np.append(feat.detach().cpu().numpy())\n",
" for f in decoder_features_np:\n",
" print(f.shape)\n",
" prediction = model.model.head(feat).detach().cpu().numpy()\n",
" \n",
"# Defining the colors for plotting\n",
"class Color(NamedTuple):\n",
" r: float\n",
" g: float\n",
" b: float\n",
"\n",
"BOP_ORANGE = Color(0.972549, 0.6784314, 0.1254902)\n",
"BOP_BLUE = Color(BOP_ORANGE.b, BOP_ORANGE.g, BOP_ORANGE.r)\n",
"GREEN = Color(0.0, 1.0, 0.0)\n",
"MAGENTA = Color(1.0, 0.0, 1.0)\n",
"\n",
"# Defining the functions to rescale the image and composite the nuclear and membrane images\n",
"def rescale_clip(image: torch.Tensor) -> np.ndarray:\n",
" return rescale_intensity(image, out_range=(0, 1))[..., None].repeat(3, axis=-1)\n",
"\n",
"def composite_nuc_mem(\n",
" image: torch.Tensor, nuc_color: Color, mem_color: Color\n",
") -> np.ndarray:\n",
" c_nuc = rescale_clip(image[0]) * nuc_color\n",
" c_mem = rescale_clip(image[1]) * mem_color\n",
" return c_nuc + c_mem\n",
"\n",
"def clip_p(image: np.ndarray) -> np.ndarray:\n",
" return rescale_intensity(image.clip(*np.percentile(image, [1, 99])))\n",
"\n",
"# Plot the PCA to RGB of the feature maps\n",
"\n",
"f, ax = plt.subplots(10, 1, figsize=(5, 25))\n",
"n_components = 4\n",
"ax[0].imshow(phase_img[0, 0, 0], cmap=\"gray\")\n",
"ax[0].set_title(f\"Phase {phase_img.shape[1:]}\")\n",
"ax[-1].imshow(clip_p(composite_nuc_mem(fluo, GREEN, MAGENTA)))\n",
"ax[-1].set_title(\"Fluorescence\")\n",
"\n",
"for level, feat in enumerate(encoder_features_np):\n",
" ax[level + 1].imshow(pcs_to_rgb(feat, n_components=n_components))\n",
" ax[level + 1].set_title(f\"Encoder stage {level+1} {feat.shape[1:]}\")\n",
"\n",
"for level, feat in enumerate(decoder_features_np):\n",
" ax[5 + level].imshow(pcs_to_rgb(feat, n_components=n_components))\n",
" ax[5 + level].set_title(f\"Decoder stage {level+1} {feat.shape[1:]}\")\n",
"\n",
"pred_comp = composite_nuc_mem(prediction[0, :, 0], BOP_BLUE, BOP_ORANGE)\n",
"ax[-2].imshow(clip_p(pred_comp))\n",
"ax[-2].set_title(f\"Prediction {prediction.shape[1:]}\")\n",
"\n",
"for a in ax.ravel():\n",
" a.axis(\"off\")\n",
"plt.tight_layout()"
]
},
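{
"cell_type": "markdown",
"id": "pca-explained-variance-note",
"metadata": {},
"source": [
"Optionally, to gauge how much of each encoder feature map the RGB view actually captures, you can inspect the explained variance of the fitted PCA (a rough check, reusing `feature_map_pca` from above)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pca-explained-variance",
"metadata": {},
"outputs": [],
"source": [
"# Rough check: fraction of variance explained by the first few principal\n",
"# components of each encoder feature map (the first 3 are shown as RGB above)\n",
"for level, feat in enumerate(encoder_features_np):\n",
"    pca = feature_map_pca(feat[0], n_components=4)\n",
"    ratios = \", \".join(f\"{r:.2f}\" for r in pca.explained_variance_ratio_)\n",
"    print(f\"Encoder stage {level + 1}: explained variance ratios = {ratios}\")"
]
},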
{
"cell_type": "markdown",
"id": "12d96459",
"metadata": {
"tags": []
},
"source": [
"<div class=\"alert alert-success\">\n",
"\n",
"<h2>\n",
"🎉 The end of the notebook 🎉\n",
"</h2>\n",
"\n",
"Congratulations! You have trained an image translation model, evaluated its performance, and explored what the network has learned. \n",
"\n",
"</div>"
]
}
],
"metadata": {