diff --git a/_sources/notebooks/Tutorial6_Registration/Tutorial6_Registration_student.ipynb b/_sources/notebooks/Tutorial6_Registration/Tutorial6_Registration_student.ipynb index 383fa65..931eced 100644 --- a/_sources/notebooks/Tutorial6_Registration/Tutorial6_Registration_student.ipynb +++ b/_sources/notebooks/Tutorial6_Registration/Tutorial6_Registration_student.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": "bbf79dc9", "metadata": {"user_expressions": []}, "source": ["# Tutorial 6\n", "## June 18, 2024\n", "In this tutorial you will develop, train, and evaluate a CNN that learns to perform deformable image registration in chest X-ray images. "]}, {"cell_type": "markdown", "id": "0131100c", "metadata": {"user_expressions": []}, "source": ["First, let's take care of the necessities:\n", "- If you're using Google Colab, make sure to select a GPU Runtime.\n", "- Connect to Weights & Biases using the code below.\n", "- Install a few libraries that we will use in this tutorial."]}, {"cell_type": "code", "execution_count": null, "id": "19dc1ed4", "metadata": {}, "outputs": [], "source": ["import os\n", "import wandb\n", "\n", "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\n", "wandb.login()"]}, {"cell_type": "code", "execution_count": null, "id": "4f42a830", "metadata": {}, "outputs": [], "source": ["!pip install dival\n", "!pip install kornia\n", "!pip install monai"]}, {"cell_type": "markdown", "id": "8d93b495", "metadata": {"user_expressions": []}, "source": ["## Part 1 - Registration"]}, {"cell_type": "code", "execution_count": null, "id": "a9968f5a", "metadata": {}, "outputs": [], "source": ["import monai\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import wandb"]}, {"cell_type": "markdown", "id": "7865deb7", "metadata": {"user_expressions": []}, "source": ["We will register chest X-ray images. We will reuse the data of **Tutorial 3**. As always, we first set the paths. This should be the path ending in 'ribs'. If you don't have the data set anymore, you can download it using the lines below:"]}, {"cell_type": "code", "execution_count": null, "id": "3eee66c0", "metadata": {}, "outputs": [], "source": ["!wget https://surfdrive.surf.nl/files/index.php/s/Y4psc2pQnfkJuoT/download -O Tutorial_3.zip\n", "!unzip -qo Tutorial_3.zip\n", "data_path = \"ribs\""]}, {"cell_type": "code", "execution_count": null, "id": "469e6d1f", "metadata": {}, "outputs": [], "source": ["# ONLY IF YOU USE JUPYTER: ADD PATH \u2328\ufe0f\n", "data_path = r'/Users/jmwolterink/Downloads/ribs'# WHEREDIDYOUPUTTHEDATA?"]}, {"cell_type": "code", "execution_count": null, "id": "dd90ec3d", "metadata": {}, "outputs": [], "source": ["# ONLY IF YOU USE COLAB: ADD PATH \u2328\ufe0f\n", "from google.colab import drive\n", "\n", "drive.mount('/content/drive')\n", "data_path = r'/content/drive/My Drive/Tutorial3'"]}, {"cell_type": "code", "execution_count": null, "id": "5d513239", "metadata": {}, "outputs": [], "source": ["# check if data_path exists:\n", "import os\n", "\n", "if not os.path.exists(data_path):\n", " print(\"Please update your data path to an existing folder.\")\n", "elif not set([\"train\", \"val\", \"test\"]).issubset(set(os.listdir(data_path))):\n", " print(\"Please update your data path to the correct folder (should contain train, val and test folders).\")\n", "else:\n", " print(\"Congrats! 
You selected the correct folder :)\")"]}, {"cell_type": "markdown", "id": "c75ef28b", "metadata": {}, "source": ["### Data management\n", "\n", "In this part we prepare all the tools needed to load and visualize our samples. One thing we *could* do is perform **inter**-patient registration, i.e., register two chest X-ray images of different patients. However, this is a very challenging problem. Instead, to make our life a bit easier, we will perform **intra**-patient registration: register two images of the same patient. For each patient, we make a synthetic moving image by applying some random elastic deformations. To build this data set, we we used the [Rand2DElasticd](https://docs.monai.io/en/stable/transforms.html#rand2delastic) transform on both the image and the mask. We will use a neural network to learn the deformation field between the fixed image and the moving image.\n", ""]}, {"cell_type": "markdown", "id": "a441968c", "metadata": {"user_expressions": []}, "source": ["Similarly as in **Tutorial 3**, make a dictionary of the image file names."]}, {"cell_type": "code", "execution_count": null, "id": "ed3a37b4", "metadata": {}, "outputs": [], "source": ["import os\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import glob\n", "import monai\n", "from PIL import Image\n", "import torch\n", "\n", "def build_dict_ribs(data_path, mode='train'):\n", " \"\"\"\n", " This function returns a list of dictionaries, each dictionary containing the keys 'img' and 'mask' \n", " that returns the path to the corresponding image.\n", " \n", " Args:\n", " data_path (str): path to the root folder of the data set.\n", " mode (str): subset used. Must correspond to 'train', 'val' or 'test'.\n", " \n", " Returns:\n", " (List[Dict[str, str]]) list of the dictionnaries containing the paths of X-ray images and masks.\n", " \"\"\"\n", " # test if mode is correct\n", " if mode not in [\"train\", \"val\", \"test\"]:\n", " raise ValueError(f\"Please choose a mode in ['train', 'val', 'test']. 
Current mode is {mode}.\")\n", " \n", " # define empty dictionary\n", " dicts = []\n", " # list all .png files in directory, including the path\n", " paths_xray = glob.glob(os.path.join(data_path, mode, 'img', '*.png'))\n", " # make a corresponding list for all the mask files\n", " for xray_path in paths_xray:\n", " if mode == 'test':\n", " suffix = 'val'\n", " else:\n", " suffix = mode\n", " # find the binary mask that belongs to the original image, based on indexing in the filename\n", " image_index = os.path.split(xray_path)[1].split('_')[-1].split('.')[0]\n", " # define path to mask file based on this index and add to list of mask paths\n", " mask_path = os.path.join(data_path, mode, 'mask', f'VinDr_RibCXR_{suffix}_{image_index}.png')\n", " if os.path.exists(mask_path):\n", " dicts.append({'fixed': xray_path, 'moving': xray_path, 'fixed_mask': mask_path, 'moving_mask': mask_path})\n", " return dicts\n", "\n", "class LoadRibData(monai.transforms.Transform):\n", " \"\"\"\n", " This custom Monai transform loads the data from the rib segmentation dataset.\n", " Defining a custom transform is simple; just overwrite the __init__ function and __call__ function.\n", " \"\"\"\n", " def __init__(self, keys=None):\n", " pass\n", "\n", " def __call__(self, sample):\n", " fixed = Image.open(sample['fixed']).convert('L') # import as grayscale image\n", " fixed = np.array(fixed, dtype=np.uint8)\n", " moving = Image.open(sample['moving']).convert('L') # import as grayscale image\n", " moving = np.array(moving, dtype=np.uint8) \n", " fixed_mask = Image.open(sample['fixed_mask']).convert('L') # import as grayscale image\n", " fixed_mask = np.array(fixed_mask, dtype=np.uint8)\n", " moving_mask = Image.open(sample['moving_mask']).convert('L') # import as grayscale image\n", " moving_mask = np.array(moving_mask, dtype=np.uint8) \n", " # mask has value 255 on rib pixels. Convert to binary array\n", " fixed_mask[np.where(fixed_mask==255)] = 1\n", " moving_mask[np.where(moving_mask==255)] = 1 \n", " return {'fixed': fixed, 'moving': moving, 'fixed_mask': fixed_mask, 'moving_mask': moving_mask, 'img_meta_dict': {'affine': np.eye(2)}, \n", " 'mask_meta_dict': {'affine': np.eye(2)}}"]}, {"cell_type": "markdown", "id": "db485cff", "metadata": {"user_expressions": []}, "source": ["Then we make a training dataset like before. The `Rand2DElasticd` transform here determines how much deformation is in the 'moving' image. 
"]}, {"cell_type": "code", "execution_count": null, "id": "db413cd7", "metadata": {}, "outputs": [], "source": ["train_dict_list = build_dict_ribs(data_path, mode='train')\n", "\n", "# constructDataset from list of paths + transform\n", "transform = monai.transforms.Compose(\n", "[\n", " LoadRibData(),\n", " monai.transforms.AddChanneld(keys=['fixed', 'moving', 'fixed_mask', 'moving_mask']),\n", " monai.transforms.Resized(keys=['fixed', 'moving', 'fixed_mask', 'moving_mask'], spatial_size=(256, 256), mode=['bilinear', 'bilinear', 'nearest', 'nearest']),\n", " monai.transforms.HistogramNormalized(keys=['fixed', 'moving']),\n", " monai.transforms.ScaleIntensityd(keys=['fixed', 'moving'], minv=0.0, maxv=1.0),\n", " monai.transforms.Rand2DElasticd(keys=['moving', 'moving_mask'], spacing=(64, 64), \n", " magnitude_range=(-8, 8), prob=1, mode=['bilinear', 'nearest']), \n", "])\n", "train_dataset = monai.data.Dataset(train_dict_list, transform=transform)"]}, {"cell_type": "markdown", "id": "6475f780", "metadata": {"user_expressions": []}, "source": [":::{admonition} Exercise\n", ":class: tip\n", "Visualize fixed and moving training images associated to their comparison image with the `visualize_fmc_sample` function below.\n", "\n", "Try different methods to create the comparison image. How well do these different methods allow you to qualitatively assess the quality of the registration?\n", "\n", "More information on this method is available in [the scikit-image documentation](https://scikit-image.org/docs/stable/api/skimage.util.html#skimage.util.compare_images).\n", "\n", ":::"]}, {"cell_type": "code", "execution_count": null, "id": "51ef974b", "metadata": {"lines_to_next_cell": 2}, "outputs": [], "source": ["def visualize_fmc_sample(sample, method=\"checkerboard\"):\n", " \"\"\"\n", " Plot three images: fixed, moving and comparison.\n", " \n", " Args:\n", " sample (dict): sample of dataset created with `build_dataset`.\n", " method (str): method used by `skimage.util.compare_image`.\n", " \"\"\"\n", " import skimage.util as skut \n", " \n", " skut_methods = [\"diff\", \"blend\", \"checkerboard\"]\n", " if method not in skut_methods:\n", " raise ValueError(f\"Method must be chosen in {skut_methods}.\\n\"\n", " f\"Current value is {method}.\")\n", " \n", " \n", " fixed = np.squeeze(sample['fixed'])\n", " moving = np.squeeze(sample['moving'])\n", " comp_checker = skut.compare_images(fixed, moving, method=method)\n", " axs = plt.figure(constrained_layout=True, figsize=(15, 5)).subplot_mosaic(\"FMC\")\n", " axs['F'].imshow(fixed, cmap='gray')\n", " axs['F'].set_title('Fixed')\n", " axs['M'].imshow(moving, cmap='gray')\n", " axs['M'].set_title('Moving')\n", " axs['C'].imshow(comp_checker, cmap='gray')\n", " axs['C'].set_title('Comparison')\n", " plt.show()"]}, {"cell_type": "code", "execution_count": null, "id": "cb60c63e", "metadata": {}, "outputs": [], "source": ["sample = train_dataset[0]\n", "for method in [\"diff\", \"blend\", \"checkerboard\"]:\n", " print(f\"Method {method}\")\n", " visualize_fmc_sample(sample, method=method)"]}, {"cell_type": "markdown", "id": "129c46e6", "metadata": {"user_expressions": []}, "source": ["Now we apply a little trick. Because applying the random deformation in each training iteration will be very costly, we only apply the deformation once and we make a new dataset based on the deformed images. 
Running the cell below may take a few minutes."]}, {"cell_type": "code", "execution_count": null, "id": "f3ec6b68", "metadata": {}, "outputs": [], "source": ["import tqdm\n", "\n", "train_loader = monai.data.DataLoader(train_dataset, batch_size=1, shuffle=False)\n", "\n", "samples = []\n", "for train_batch in tqdm.tqdm(train_loader):\n", " samples.append(train_batch)\n", "\n", "# Make a new dataset and dataloader using the transformed images\n", "train_dataset = monai.data.Dataset(samples, transform=monai.transforms.SqueezeDimd(keys=['fixed', 'moving', 'fixed_mask', 'moving_mask']))\n", "train_loader = monai.data.DataLoader(train_dataset, batch_size=16, shuffle=False)"]}, {"cell_type": "markdown", "id": "879f3ee3", "metadata": {"user_expressions": []}, "source": [":::{admonition} Exercise\n", ":class: tip\n", "Create `val_dataset` and `val_loader`, corresponding to the `DataSet` and `DataLoader` for your validation set. The transforms can be the same as in the training set.\n", ":::"]}, {"cell_type": "code", "execution_count": null, "id": "1090bfc3", "metadata": {"tags": ["student"]}, "outputs": [], "source": ["# Your code goes here"]}, {"cell_type": "markdown", "id": "eaebbc90", "metadata": {"user_expressions": []}, "source": ["### Model\n", "\n", "As model, we'll use a U-Net. The input/output structure is quite different from what we've seen before:\n", "- the network takes as input two images: the *moving* and *fixed* images.\n", "- it outputs one tensor representing the *deformation field*.\n", "\n", "\n", "\n", "\n", "This *deformation field* can be applied to the *moving* image with the `monai.networks.blocks.Warp` block of Monai.\n", "\n", "\n", "\n", "\n", "This deformed moving image is then compared to the *fixed* image: if they are similar, the deformation field is correctly registering the moving image on the fixed image. Keep in mind that this is done on **training** data, and we want the U-Net to learn to predict a proper deformation field given two new and unseen images. So we're not optimizing for a pair of images as would be done in conventional iterative registration, but training a model that can generalize.\n", "\n", "\n"]}, {"cell_type": "markdown", "id": "b51070ce", "metadata": {"user_expressions": []}, "source": ["Before starting, let's check that you can work on a GPU by runnning the following cell:\n", "- if the device is \"cuda\" you are working on a GPU,\n", "- if the device is \"cpu\" call a teacher."]}, {"cell_type": "code", "execution_count": null, "id": "06e3d24f", "metadata": {"lines_to_next_cell": 2}, "outputs": [], "source": ["if torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", "elif torch.backends.mps.is_available():\n", " device = torch.device(\"mps\")\n", " os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"]=\"1\"\n", "else:\n", " device = \"cpu\"\n", "print(f'The used device is {device}')"]}, {"cell_type": "markdown", "id": "b47a520d", "metadata": {"user_expressions": []}, "source": [":::{admonition} Exercise\n", ":class: tip\n", "Construct a U-Net with suitable settings and name it `model`. 
Check that you can correctly apply its output to the input moving image with the `warp_layer`!\n", ":::"]}, {"cell_type": "code", "execution_count": null, "id": "803e5cce", "metadata": {"tags": ["student"]}, "outputs": [], "source": ["model = # FILL IN\n", "\n", "warp_layer = monai.networks.blocks.Warp().to(device)"]}, {"cell_type": "markdown", "id": "28352f01", "metadata": {}, "source": ["### Objective function\n", "\n", "We evaluate the similarity between the fixed image and the deformed moving image with the `MSELoss()`. The L1 or SSIM losses seen in the previous section could also be used. Furthermore, the deformation field is regularized with `BendingEnergyLoss`. This is a penalty that takes the smoothness of the deformation field into account: if it's not smooth enough, the bending energy is high. Thus, our model will favor smooth deformation fields.\n", "\n", "Finally, we pick an optimizer, in this case again an Adam optimizer."]}, {"cell_type": "code", "execution_count": null, "id": "45d79ea1", "metadata": {}, "outputs": [], "source": ["image_loss = torch.nn.MSELoss()\n", "regularization = monai.losses.BendingEnergyLoss()\n", "optimizer = torch.optim.Adam(model.parameters(), 1e-3)"]}, {"cell_type": "markdown", "id": "8e4f8204", "metadata": {"user_expressions": []}, "source": [":::{admonition} Exercise\n", ":class: tip\n", "Add a learning rate scheduler that lowers the learning rate by a factor ten every 100 epochs.\n", ":::"]}, {"cell_type": "code", "execution_count": null, "id": "2035ad19", "metadata": {"tags": ["student"]}, "outputs": [], "source": ["# Your code goes here"]}, {"cell_type": "markdown", "id": "56f1b410", "metadata": {"user_expressions": []}, "source": ["To warp the moving image using the predicted deformation field and *then* compute the loss between the deformed image and the fixed image, we define a forward function which does all this. The output of this function is `pred_image`. 
"]}, {"cell_type": "code", "execution_count": null, "id": "2159eb8a", "metadata": {}, "outputs": [], "source": ["def forward(batch_data, model):\n", " \"\"\"\n", " Applies the model to a batch of data.\n", " \n", " Args:\n", " batch_data (dict): a batch of samples computed by a DataLoader.\n", " model (Module): a model computing the deformation field.\n", " \n", " Returns:\n", " ddf (Tensor): batch of deformation fields.\n", " pred_image (Tensor): batch of deformed moving images.\n", " \n", " \"\"\"\n", " fixed_image = batch_data[\"fixed\"].to(device).float()\n", " moving_image = batch_data[\"moving\"].to(device).float()\n", " \n", " # predict DDF\n", " ddf = model(torch.cat((moving_image, fixed_image), dim=1))\n", "\n", " # warp moving image and label with the predicted ddf\n", " pred_image = warp_layer(moving_image, ddf)\n", "\n", " return ddf, pred_image"]}, {"cell_type": "markdown", "id": "1b0aea1e", "metadata": {"user_expressions": []}, "source": ["You can supervise the training process in W&B, in which at each epoch a batch of validation images are used to compute the comparison images of your choice, based on the parameter `method`."]}, {"cell_type": "code", "execution_count": null, "id": "71b9cc62", "metadata": {}, "outputs": [], "source": ["def log_to_wandb(epoch, train_loss, val_loss, pred_batch, fixed_batch, method=\"checkerboard\"):\n", " \"\"\" Function that logs ongoing training variables to W&B \"\"\"\n", " import skimage.util as skut\n", " \n", " log_imgs = []\n", " for fixed_pt, pred_pt in zip(pred_batch, fixed_batch):\n", " fixed_np = np.squeeze(fixed_pt.cpu().detach())\n", " pred_np = np.squeeze(pred_pt.cpu().detach())\n", " comp_checker = skut.compare_images(fixed_np, pred_np, method=method)\n", " log_imgs.append(wandb.Image(comp_checker))\n", "\n", " # Send epoch, losses and images to W&B\n", " wandb.log({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss, 'results': log_imgs})"]}, {"cell_type": "markdown", "id": "fb0a53ac", "metadata": {"user_expressions": []}, "source": ["### Training time\n", "\n", "Use the following cells to train your network. 
You may choose different parameters to improve the performance!"]}, {"cell_type": "code", "execution_count": null, "id": "0aee95b5", "metadata": {}, "outputs": [], "source": ["# Choose your parameters\n", "\n", "max_epochs = 200\n", "reg_weight = 0 # By default 0, but you can investigate what it does"]}, {"cell_type": "code", "execution_count": null, "id": "b915ac88", "metadata": {}, "outputs": [], "source": ["from tqdm import tqdm\n", "\n", "run = wandb.init(\n", " project='tutorial4_registration',\n", " config={\n", " 'lr': optimizer.param_groups[0][\"lr\"],\n", " 'batch_size': train_loader.batch_size,\n", " 'regularization': reg_weight,\n", " 'loss_function': str(image_loss)\n", " }\n", ")\n", "# Do not hesitate to enrich this list of settings to be able to correctly keep track of your experiments!\n", "# For example you should add information on your model...\n", "\n", "run_id = run.id # We remember here the run ID to be able to write the evaluation metrics\n", "\n", "for epoch in tqdm(range(max_epochs)): \n", " model.train()\n", " epoch_loss = 0\n", " for batch_data in train_loader:\n", " optimizer.zero_grad()\n", "\n", " ddf, pred_image = forward(batch_data, model)\n", "\n", " fixed_image = batch_data[\"fixed\"].to(device).float()\n", " reg = regularization(ddf)\n", " loss = image_loss(pred_image, fixed_image) + reg_weight * reg\n", " loss.backward()\n", " optimizer.step()\n", " epoch_loss += loss.item()\n", "\n", " epoch_loss /= len(train_loader)\n", "\n", " model.eval()\n", " val_epoch_loss = 0\n", " for batch_data in val_loader:\n", " ddf, pred_image = forward(batch_data, model)\n", " fixed_image = batch_data[\"fixed\"].to(device).float()\n", " reg = regularization(ddf)\n", " loss = image_loss(pred_image, fixed_image) + reg_weight * reg\n", " val_epoch_loss += loss.item()\n", " val_epoch_loss /= len(val_loader)\n", "\n", " log_to_wandb(epoch, epoch_loss, val_epoch_loss, pred_image, fixed_image)\n", " \n", "run.finish() "]}, {"cell_type": "markdown", "id": "e5641636", "metadata": {"user_expressions": []}, "source": ["### Evaluation of the trained model\n", "\n", "Now that the model has been trained, it's time to evaluate its performance. Use the code below to visualize samples and deformation fields. \n", "\n", ":::{admonition} Exercise\n", ":class: tip\n", "Are you satisfied with these registration results? Do they seem anatomically plausible? 
Try out different regularization factors (`reg_weight`) and see what they do to the registration.\n", ":::"]}, {"cell_type": "markdown", "id": "4c2f5785", "metadata": {"tags": ["student"]}, "source": ["Answer: "]}, {"cell_type": "code", "execution_count": null, "id": "cf302a04", "metadata": {}, "outputs": [], "source": ["def visualize_prediction(sample, model, method=\"checkerboard\"):\n", " \"\"\"\n", " Plot three images: fixed, moving and comparison.\n", " \n", " Args:\n", " sample (dict): sample of dataset created with `build_dataset`.\n", " model (Module): a model computing the deformation field.\n", " method (str): method used by `skimage.util.compare_image`.\n", " \"\"\"\n", " import skimage.util as skut \n", " \n", " skut_methods = [\"diff\", \"blend\", \"checkerboard\"]\n", " if method not in skut_methods:\n", " raise ValueError(f\"Method must be chosen in {skut_methods}.\\n\"\n", " f\"Current value is {method}.\")\n", " \n", " model.eval()\n", " \n", " # Compute deformation field + deformed image\n", " batch_data = {\n", " \"fixed\": sample[\"fixed\"].unsqueeze(0),\n", " \"moving\": sample[\"moving\"].unsqueeze(0),\n", " }\n", " ddf, pred_image = forward(batch_data, model)\n", " ddf = ddf.detach().cpu().numpy().squeeze()\n", " ddf = np.linalg.norm(ddf, axis=0).squeeze()\n", " \n", " # Squeeze images\n", " fixed = np.squeeze(sample[\"fixed\"])\n", " moving = np.squeeze(sample[\"moving\"]) \n", " deformed = np.squeeze(pred_image.detach().cpu())\n", " \n", " # Generate comparison image\n", " comp_checker = skut.compare_images(fixed, deformed, method=method, n_tiles=(4, 4))\n", " \n", " # Plot everything\n", " fig, axs = plt.subplots(1, 5, figsize=(18, 5)) \n", " axs[0].imshow(fixed, cmap='gray')\n", " axs[0].set_title('Fixed')\n", " axs[1].imshow(moving, cmap='gray')\n", " axs[1].set_title('Moving')\n", " axs[2].imshow(deformed, cmap='gray')\n", " axs[2].set_title('Deformed')\n", " axs[3].imshow(comp_checker, cmap='gray')\n", " axs[3].set_title('Comparison') \n", " dpl = axs[4].imshow(ddf, clim=(0, 10))\n", " fig.colorbar(dpl, ax=axs[4])\n", " plt.show() \n", " plt.show()\n", "for sample in val_dataset:\n", " visualize_prediction(sample, model)"]}, {"cell_type": "markdown", "id": "1e6898fa", "metadata": {"user_expressions": []}, "source": [":::{admonition} Exercise\n", ":class: tip\n", "Compute the Jacobian determinant at each image voxel. How many of these are negative? Can you improve upon this?\n", ":::"]}, {"cell_type": "markdown", "id": "e3954ac2", "metadata": {"user_expressions": []}, "source": ["## Part 2 - Equivariance\n", "In this part, we are going to use some concepts that you've learned in the lecture on geometric deep learning. We are going to look at the equivariance properties of a neural network architecture that you should by now be very familiar with: the U-Net. We will again use the chest X-ray segmentation problem. 
Because training a network is not the focus here, we have pretrained a network that you can use for these experiments."]}, {"cell_type": "markdown", "id": "32158df7", "metadata": {"user_expressions": []}, "source": ["### Data loading\n", "We will again use the same utility functions as in Tutorial 3 to build a dictionary of files and load rib data."]}, {"cell_type": "code", "execution_count": null, "id": "4ca8332c", "metadata": {}, "outputs": [], "source": ["import os\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import glob\n", "import monai\n", "from PIL import Image\n", "import torch\n", "\n", "def build_dict_ribs(data_path, mode='train'):\n", " \"\"\"\n", " This function returns a list of dictionaries, each dictionary containing the keys 'img' and 'mask' \n", " that returns the path to the corresponding image.\n", " \n", " Args:\n", " data_path (str): path to the root folder of the data set.\n", " mode (str): subset used. Must correspond to 'train', 'val' or 'test'.\n", " \n", " Returns:\n", " (List[Dict[str, str]]) list of the dictionaries containing the paths of X-ray images and masks.\n", " \"\"\"\n", " # test if mode is correct\n", " if mode not in [\"train\", \"val\", \"test\"]:\n", " raise ValueError(f\"Please choose a mode in ['train', 'val', 'test']. Current mode is {mode}.\")\n", " \n", " # define empty dictionary\n", " dicts = []\n", " # list all .png files in directory, including the path\n", " paths_xray = glob.glob(os.path.join(data_path, mode, 'img', '*.png'))\n", " # make a corresponding list for all the mask files\n", " for xray_path in paths_xray:\n", " if mode == 'test':\n", " suffix = 'val'\n", " else:\n", " suffix = mode\n", " # find the binary mask that belongs to the original image, based on indexing in the filename\n", " image_index = os.path.split(xray_path)[1].split('_')[-1].split('.')[0]\n", " # define path to mask file based on this index and add to list of mask paths\n", " mask_path = os.path.join(data_path, mode, 'mask', f'VinDr_RibCXR_{suffix}_{image_index}.png')\n", " if os.path.exists(mask_path):\n", " dicts.append({'img': xray_path, 'mask': mask_path})\n", " return dicts\n", "\n", "class LoadRibData(monai.transforms.Transform):\n", " \"\"\"\n", " This custom Monai transform loads the data from the rib segmentation dataset.\n", " Defining a custom transform is simple; just overwrite the __init__ function and __call__ function.\n", " \"\"\"\n", " def __init__(self, keys=None):\n", " pass\n", "\n", " def __call__(self, sample):\n", " image = Image.open(sample['img']).convert('L') # import as grayscale image\n", " image = np.array(image, dtype=np.uint8)\n", " mask = Image.open(sample['mask']).convert('L') # import as grayscale image\n", " mask = np.array(mask, dtype=np.uint8)\n", " # mask has value 255 on rib pixels. Convert to binary array\n", " mask[np.where(mask==255)] = 1\n", " return {'img': image, 'mask': mask, 'img_meta_dict': {'affine': np.eye(2)}, \n", " 'mask_meta_dict': {'affine': np.eye(2)}}"]}, {"cell_type": "markdown", "id": "781d7268", "metadata": {"user_expressions": []}, "source": ["Use the cell below to make a validation loader with a single image. 
This is sufficient for the small experiment that you will perform."]}, {"cell_type": "code", "execution_count": null, "id": "f86f9bd6", "metadata": {}, "outputs": [], "source": ["validation_dict_list = build_dict_ribs(data_path, mode='val')\n", "validation_transform = monai.transforms.Compose(\n", " [\n", " LoadRibData(),\n", " monai.transforms.AddChanneld(keys=['img', 'mask']),\n", " monai.transforms.HistogramNormalized(keys=['img']), \n", " monai.transforms.ScaleIntensityd(keys=['img'], minv=0, maxv=1),\n", " monai.transforms.Zoomd(keys=['img', 'mask'], zoom=0.25, mode=['bilinear', 'nearest'], keep_size=False),\n", " # monai.transforms.RandSpatialCropd(keys=['img', 'mask'], roi_size=[384, 384], random_size=False)\n", " monai.transforms.SpatialCropd(keys=['img', 'mask'], roi_center=[300, 300], roi_size=[384 + 64, 384]) \n", " ]\n", ")\n", "validation_data = monai.data.CacheDataset([validation_dict_list[3]], transform=validation_transform)\n", "validation_loader = monai.data.DataLoader(validation_data, batch_size=1, shuffle=False)"]}, {"cell_type": "markdown", "id": "f1f2ab8b", "metadata": {"user_expressions": []}, "source": ["### Loading a pretrained model\n", "We have already trained a model for you, the parameters of which were shared in JupyterLab as well.\n", "**Note**: if you downloaded the data set yourself, the model should be in the same folder as the images.\n", "If you already downloaded the data set but not the model, the model file is available [here](https://surfdrive.surf.nl/files/index.php/s/613zrvr0RDYZDqp)."]}, {"cell_type": "code", "execution_count": null, "id": "aa828ba3", "metadata": {}, "outputs": [], "source": ["pretrained_file = path.join(data_path, \"trainedUNet.pt\")"]}, {"cell_type": "markdown", "id": "ad89a231", "metadata": {"user_expressions": []}, "source": ["Next, we initialize a standard U-Net architecture and load the parameters of the pretrained network using the `load_state_dict` function."]}, {"cell_type": "code", "execution_count": null, "id": "fe582e8a", "metadata": {}, "outputs": [], "source": ["import torch\n", "import monai\n", "\n", "# Check whether we're using a GPU\n", "if torch.cuda.is_available():\n", " n_gpus = torch.cuda.device_count() # Total number of GPUs\n", " gpu_idx = random.randint(0, n_gpus - 1) # Random GPU index\n", " device = torch.device(f'cuda:{gpu_idx}')\n", " print('Using GPU: {}'.format(device))\n", "else:\n", " device = torch.device('cpu')\n", " print('GPU not found. Using CPU.')\n", "\n", "model = monai.networks.nets.UNet(\n", " spatial_dims=2,\n", " in_channels=1,\n", " out_channels=1,\n", " channels = (8, 16, 32, 64, 128),\n", " strides=(2, 2, 2, 2),\n", " num_res_units=2,\n", " dropout=0.5\n", ").to(device)\n", "\n", "model.load_state_dict(torch.load(pretrained_file))\n", "model.eval()"]}, {"cell_type": "markdown", "id": "568f8363", "metadata": {}, "source": ["Let's use the pretrained network to segment (part of) our image. 
Run the cell below."]}, {"cell_type": "code", "execution_count": null, "id": "98d06063", "metadata": {}, "outputs": [], "source": ["for sample in validation_loader:\n", "\n", " img = sample['img'][:, :, :384, :384] \n", " mask = sample['mask'][:, :, :384, :384]\n", " output_noshift = torch.sigmoid(model(img.to(device))).detach().cpu().numpy().squeeze() \n", " \n", " fig, ax = plt.subplots(1,2, figsize = [12, 10]) \n", " # Plot X-ray image\n", " ax[0].imshow(img.squeeze(), 'gray')\n", " # Plot ground truth\n", " mask = np.squeeze(mask)\n", " overlay_mask = np.ma.masked_where(mask == 0, mask == 1)\n", " ax[0].imshow(overlay_mask, 'Greens', alpha = 0.7, clim=[0,1], interpolation='nearest')\n", " ax[0].set_title('Ground truth')\n", " # Plot output\n", " overlay_output = np.ma.masked_where(output_noshift < 0.1, output_noshift > 0.99)\n", " ax[1].imshow(img.squeeze(), 'gray')\n", " ax[1].imshow(overlay_output.squeeze(), 'Reds', alpha = 0.7, clim=[0,1])\n", " ax[1].set_title('Prediction')\n", " plt.show() "]}, {"cell_type": "markdown", "id": "c2560bea", "metadata": {}, "source": ["As you can see, segmentation isn't perfect, but that's also not the goal of this exercise. What we are going to look into is the translation equivariance (**Lecture 8**) of the U-Net. That is: if you translate the image by $d$ pixels, does the output also simply change by $d$ pixels. Note that this is a nice feature to have for a segmentation network: in principle we'd want our network to give us the same label for a pixel regardless of where the image was cut. The image below visualizes this principle. For segmentation of the pixels in the orange square, it shouldn't matter if we provide the red square or the green square as input to the U-Net.\n", "\n", ""]}, {"cell_type": "markdown", "id": "66a36fea", "metadata": {"user_expressions": []}, "source": [":::{admonition} Exercise\n", ":class: tip\n", "What do you think will happen to the U-Net's prediction if we give it a slightly shifted version of the image as input?\n", ":::"]}, {"cell_type": "markdown", "id": "c004ad47", "metadata": {"user_expressions": []}, "source": ["Now we make a small script that performs the above experiment. First, we obtain the segmentation in the red box and we call this `output_noshift`. Then we shift the green box by an offset and each time obtain a segmentation in this box using the same model. We start small with a shift/offset of just a **single pixel**.\n", "\n", ":::{admonition} Exercise\n", ":class: tip\n", "Run the cell below and observe the outputs. 
Can you spot differences between the two segmentation masks?\n", ":::"]}, {"cell_type": "code", "execution_count": null, "id": "b54446fd", "metadata": {}, "outputs": [], "source": ["offset = 1\n", "\n", "for sample in validation_loader:\n", "\n", " # Original image\n", " img = sample['img'][:, :, :384, :384] \n", " mask = sample['mask'][:, :, :384, :384]\n", " output_noshift = torch.sigmoid(model(img.to(device))).detach().cpu().numpy().squeeze() \n", "\n", " # Plot X-ray image\n", " fig, ax = plt.subplots(1,2, figsize = [12, 10]) \n", " ax[0].imshow(img.squeeze(), 'gray')\n", " # Plot ground truth\n", " mask = np.squeeze(mask)\n", " overlay_mask = np.ma.masked_where(mask == 0, mask == 1)\n", " ax[0].imshow(overlay_mask, 'Greens', alpha = 0.7, clim=[0,1], interpolation='nearest')\n", " ax[0].set_title('Ground truth')\n", " # Plot output\n", " overlay_output = np.ma.masked_where(output_noshift < 0.1, output_noshift >0.99)\n", " ax[1].imshow(img.squeeze(), 'gray')\n", " ax[1].imshow(overlay_output.squeeze(), 'Reds', alpha = 0.7, clim=[0,1])\n", " ax[1].set_title('Prediction')\n", " plt.show()\n", " \n", " # Shifted image\n", " img = sample['img'][:, :, offset:offset+384, :384]\n", " mask = sample['mask'][:, :, offset:offset+384, :384]\n", " output = torch.sigmoid(model(img.to(device))).detach().cpu().numpy().squeeze()\n", "\n", " # Plot X-ray image\n", " fig, ax = plt.subplots(1,2, figsize = [12, 10])\n", " ax[0].imshow(img.squeeze(), 'gray')\n", " # Plot ground truth\n", " mask = np.squeeze(mask)\n", " overlay_mask = np.ma.masked_where(mask == 0, mask == 1)\n", " ax[0].imshow(overlay_mask, 'Greens', alpha = 0.7, clim=[0,1], interpolation='nearest')\n", " ax[0].set_title('Ground truth shifted')\n", " # Plot output\n", " overlay_output = np.ma.masked_where(output < 0.1, output >0.99)\n", " ax[1].imshow(img.squeeze(), 'gray')\n", " ax[1].imshow(overlay_output.squeeze(), 'Reds', alpha = 0.7, clim=[0,1])\n", " ax[1].set_title('Prediction shifted')\n", " plt.show()"]}, {"cell_type": "markdown", "id": "a308c53b", "metadata": {"user_expressions": []}, "source": ["To highlight the differences between both segmentation masks a bit more, we make a difference image. We correct for the shift applied so that we're not comparing apples and oranges. The next cell shows the difference image between the original image and what we get when we process an image that is shifted by one pixel.\n", "\n", ":::{admonition} Exercise\n", ":class: tip\n", "Given these results, is a U-Net translation equivariant, invariant, or neither?\n", ":::"]}, {"cell_type": "code", "execution_count": null, "id": "c02bc159", "metadata": {}, "outputs": [], "source": ["plt.figure(figsize=(6, 6))\n", "diffout = output_noshift[offset:, :384] - output[:-offset, :384]\n", "plt.imshow(diffout, cmap='seismic', clim=[-1, 1])\n", "plt.title('Offset {}'.format(offset))\n", "plt.colorbar()\n", "plt.show()"]}, {"cell_type": "markdown", "id": "c21a98d5", "metadata": {"user_expressions": []}, "source": ["We can repeat this for larger offsets. Let's take offsets up to 64 pixels, and each time compute the difference between the original and shifted image, in a subimage that should be unaffected by the shift. We store the L1 norm of the difference image in an array `norms` and plot these as a function of offset.\n", "\n", ":::{admonition} Exercise\n", ":class: tip\n", "The resulting plot shows that the U-Net is equivariant for none of the translations. This is due to a combination of border effects and downsampling layers. 
However, the plot also shows a particular pattern, in which the norm *dips* every 16 pixels of offset. Can you explain this based on the U-Net architecture? \n", ":::"]}, {"cell_type": "code", "execution_count": null, "id": "6ca12518", "metadata": {}, "outputs": [], "source": ["norms = []\n", "offsets = []\n", "plot_differences = False # Set to True to plot difference images for every offset\n", "\n", "img = sample['img'][:, :, :384, :384] \n", "mask = sample['mask'][:, :, :384, :384]\n", "output_noshift = torch.sigmoid(model(img.to(device))).detach().cpu().numpy().squeeze() \n", "\n", "for offset in range(1, 65):\n", " for sample in validation_loader:\n", " img = sample['img'][:, :, offset:offset+384, :384]\n", " mask = sample['mask'][:, :, offset:offset+384, :384]\n", "\n", " output = torch.sigmoid(model(img.to(device))).detach().cpu().numpy().squeeze() \n", "\n", " diffout = (output_noshift[offset:, :384] - output[:-offset, :384])[100:284, 100:284]\n", " offsets.append(offset)\n", " norms.append(np.sum(np.abs(diffout)))\n", " if plot_differences:\n", " plt.figure()\n", " plt.imshow(diffout, cmap='seismic', clim=[-1, 1])\n", " plt.title(f\"Offset {offset}\")\n", " plt.colorbar()\n", " plt.show()\n", "\n", "plt.figure()\n", "plt.plot(offsets, norms)\n", "plt.xlabel('Offset')\n", "plt.ylabel('Difference')\n", "plt.show()"]}], "metadata": {"kernelspec": {"display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3"}}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": "7eed9088", "metadata": {"user_expressions": []}, "source": ["# Tutorial 6\n", "## June 20, 2024\n", "In this tutorial you will develop, train, and evaluate a CNN that learns to perform deformable image registration in chest X-ray images. "]}, {"cell_type": "markdown", "id": "a453a035", "metadata": {"user_expressions": []}, "source": ["First, let's take care of the necessities:\n", "- If you're using Google Colab, make sure to select a GPU Runtime.\n", "- Connect to Weights & Biases using the code below.\n", "- Install a few libraries that we will use in this tutorial."]}, {"cell_type": "code", "execution_count": null, "id": "33583e16", "metadata": {}, "outputs": [], "source": ["import os\n", "import wandb\n", "\n", "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\n", "wandb.login()"]}, {"cell_type": "code", "execution_count": null, "id": "21ccb1ce", "metadata": {}, "outputs": [], "source": ["!pip install monai"]}, {"cell_type": "markdown", "id": "3ad13220", "metadata": {"user_expressions": []}, "source": ["## Part 1 - Registration"]}, {"cell_type": "code", "execution_count": null, "id": "6f0a8af8", "metadata": {}, "outputs": [], "source": ["import monai\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import wandb"]}, {"cell_type": "markdown", "id": "f570b141", "metadata": {"user_expressions": []}, "source": ["We will register chest X-ray images. We will reuse the data of **Tutorial 3**. As always, we first set the paths. This should be the path ending in 'ribs'. 
If you don't have the data set anymore, you can download it using the lines below:"]}, {"cell_type": "code", "execution_count": null, "id": "5b428321", "metadata": {}, "outputs": [], "source": ["!wget https://surfdrive.surf.nl/files/index.php/s/Y4psc2pQnfkJuoT/download -O Tutorial_3.zip\n", "!unzip -qo Tutorial_3.zip\n", "data_path = \"ribs\""]}, {"cell_type": "code", "execution_count": null, "id": "a40867ae", "metadata": {}, "outputs": [], "source": ["# ONLY IF YOU USE JUPYTER: ADD PATH \u2328\ufe0f\n", "data_path = r'ribs'# WHEREDIDYOUPUTTHEDATA?"]}, {"cell_type": "code", "execution_count": null, "id": "d01c02c2", "metadata": {}, "outputs": [], "source": ["# ONLY IF YOU USE COLAB: ADD PATH \u2328\ufe0f\n", "from google.colab import drive\n", "\n", "drive.mount('/content/drive')\n", "data_path = r'/content/drive/My Drive/Tutorial3'"]}, {"cell_type": "code", "execution_count": null, "id": "2e6f09c6", "metadata": {}, "outputs": [], "source": ["# check if data_path exists:\n", "import os\n", "\n", "if not os.path.exists(data_path):\n", " print(\"Please update your data path to an existing folder.\")\n", "elif not set([\"train\", \"val\", \"test\"]).issubset(set(os.listdir(data_path))):\n", " print(\"Please update your data path to the correct folder (should contain train, val and test folders).\")\n", "else:\n", " print(\"Congrats! You selected the correct folder :)\")"]}, {"cell_type": "markdown", "id": "c0511cdc", "metadata": {"user_expressions": []}, "source": ["### Data management\n", "\n", "In this part we prepare all the tools needed to load and visualize our samples. One thing we *could* do is perform **inter**-patient registration, i.e., register two chest X-ray images of different patients. However, this is a very challenging problem. Instead, to make our life a bit easier, we will perform **intra**-patient registration: register two images of the same patient. For each patient, we make a synthetic moving image by applying some random elastic deformations. To build this data set, we we used the [Rand2DElasticd](https://docs.monai.io/en/stable/transforms.html#rand2delastic) transform on both the image and the mask. We will use a neural network to learn the deformation field between the fixed image and the moving image.\n", "
\n", "\n", "
"]}, {"cell_type": "markdown", "id": "ad9d9466", "metadata": {"user_expressions": []}, "source": ["Similarly as in **Tutorial 3**, make a dictionary of the image file names."]}, {"cell_type": "code", "execution_count": null, "id": "69604174", "metadata": {}, "outputs": [], "source": ["import os\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import glob\n", "import monai\n", "from PIL import Image\n", "import torch\n", "\n", "def build_dict_ribs(data_path, mode='train'):\n", " \"\"\"\n", " This function returns a list of dictionaries, each dictionary containing the keys 'img' and 'mask' \n", " that returns the path to the corresponding image.\n", " \n", " Args:\n", " data_path (str): path to the root folder of the data set.\n", " mode (str): subset used. Must correspond to 'train', 'val' or 'test'.\n", " \n", " Returns:\n", " (List[Dict[str, str]]) list of the dictionnaries containing the paths of X-ray images and masks.\n", " \"\"\"\n", " # test if mode is correct\n", " if mode not in [\"train\", \"val\", \"test\"]:\n", " raise ValueError(f\"Please choose a mode in ['train', 'val', 'test']. Current mode is {mode}.\")\n", " \n", " # define empty dictionary\n", " dicts = []\n", " # list all .png files in directory, including the path\n", " paths_xray = glob.glob(os.path.join(data_path, mode, 'img', '*.png'))\n", " # make a corresponding list for all the mask files\n", " for xray_path in paths_xray:\n", " if mode == 'test':\n", " suffix = 'val'\n", " else:\n", " suffix = mode\n", " # find the binary mask that belongs to the original image, based on indexing in the filename\n", " image_index = os.path.split(xray_path)[1].split('_')[-1].split('.')[0]\n", " # define path to mask file based on this index and add to list of mask paths\n", " mask_path = os.path.join(data_path, mode, 'mask', f'VinDr_RibCXR_{suffix}_{image_index}.png')\n", " if os.path.exists(mask_path):\n", " dicts.append({'fixed': xray_path, 'moving': xray_path, 'fixed_mask': mask_path, 'moving_mask': mask_path})\n", " return dicts\n", "\n", "class LoadRibData(monai.transforms.Transform):\n", " \"\"\"\n", " This custom Monai transform loads the data from the rib segmentation dataset.\n", " Defining a custom transform is simple; just overwrite the __init__ function and __call__ function.\n", " \"\"\"\n", " def __init__(self, keys=None):\n", " pass\n", "\n", " def __call__(self, sample):\n", " fixed = Image.open(sample['fixed']).convert('L') # import as grayscale image\n", " fixed = np.array(fixed, dtype=np.uint8)\n", " moving = Image.open(sample['moving']).convert('L') # import as grayscale image\n", " moving = np.array(moving, dtype=np.uint8) \n", " fixed_mask = Image.open(sample['fixed_mask']).convert('L') # import as grayscale image\n", " fixed_mask = np.array(fixed_mask, dtype=np.uint8)\n", " moving_mask = Image.open(sample['moving_mask']).convert('L') # import as grayscale image\n", " moving_mask = np.array(moving_mask, dtype=np.uint8) \n", " # mask has value 255 on rib pixels. Convert to binary array\n", " fixed_mask[np.where(fixed_mask==255)] = 1\n", " moving_mask[np.where(moving_mask==255)] = 1 \n", " return {'fixed': fixed, 'moving': moving, 'fixed_mask': fixed_mask, 'moving_mask': moving_mask, 'img_meta_dict': {'affine': np.eye(2)}, \n", " 'mask_meta_dict': {'affine': np.eye(2)}}"]}, {"cell_type": "markdown", "id": "ac0a2501", "metadata": {"user_expressions": []}, "source": ["Then we make a training dataset like before. 
The `Rand2DElasticd` transform here determines how much deformation is in the 'moving' image. "]}, {"cell_type": "code", "execution_count": null, "id": "355e3e2c", "metadata": {}, "outputs": [], "source": ["train_dict_list = build_dict_ribs(data_path, mode='train')\n", "\n", "# constructDataset from list of paths + transform\n", "transform = monai.transforms.Compose(\n", "[\n", " LoadRibData(),\n", " monai.transforms.AddChanneld(keys=['fixed', 'moving', 'fixed_mask', 'moving_mask']),\n", " monai.transforms.Resized(keys=['fixed', 'moving', 'fixed_mask', 'moving_mask'], spatial_size=(256, 256), mode=['bilinear', 'bilinear', 'nearest', 'nearest']),\n", " monai.transforms.HistogramNormalized(keys=['fixed', 'moving']),\n", " monai.transforms.ScaleIntensityd(keys=['fixed', 'moving'], minv=0.0, maxv=1.0),\n", " monai.transforms.Rand2DElasticd(keys=['moving', 'moving_mask'], spacing=(64, 64), \n", " magnitude_range=(-8, 8), prob=1, mode=['bilinear', 'nearest']), \n", "])\n", "train_dataset = monai.data.Dataset(train_dict_list, transform=transform)"]}, {"cell_type": "markdown", "id": "3e616f15", "metadata": {"user_expressions": []}, "source": [":::{admonition} Exercise\n", ":class: tip\n", "Visualize fixed and moving training images associated to their comparison image with the `visualize_fmc_sample` function below.\n", "\n", "Try different methods to create the comparison image. How well do these different methods allow you to qualitatively assess the quality of the registration?\n", "\n", "More information on this method is available in [the scikit-image documentation](https://scikit-image.org/docs/stable/api/skimage.util.html#skimage.util.compare_images).\n", "\n", ":::"]}, {"cell_type": "code", "execution_count": null, "id": "e99e50a7", "metadata": {"lines_to_next_cell": 2}, "outputs": [], "source": ["def visualize_fmc_sample(sample, method=\"checkerboard\"):\n", " \"\"\"\n", " Plot three images: fixed, moving and comparison.\n", " \n", " Args:\n", " sample (dict): sample of dataset created with `build_dataset`.\n", " method (str): method used by `skimage.util.compare_image`.\n", " \"\"\"\n", " import skimage.util as skut \n", " \n", " skut_methods = [\"diff\", \"blend\", \"checkerboard\"]\n", " if method not in skut_methods:\n", " raise ValueError(f\"Method must be chosen in {skut_methods}.\\n\"\n", " f\"Current value is {method}.\")\n", " \n", " \n", " fixed = np.squeeze(sample['fixed'])\n", " moving = np.squeeze(sample['moving'])\n", " comp_checker = skut.compare_images(fixed, moving, method=method)\n", " axs = plt.figure(constrained_layout=True, figsize=(15, 5)).subplot_mosaic(\"FMC\")\n", " axs['F'].imshow(fixed, cmap='gray')\n", " axs['F'].set_title('Fixed')\n", " axs['M'].imshow(moving, cmap='gray')\n", " axs['M'].set_title('Moving')\n", " axs['C'].imshow(comp_checker, cmap='gray')\n", " axs['C'].set_title('Comparison')\n", " plt.show()"]}, {"cell_type": "code", "execution_count": null, "id": "99085909", "metadata": {}, "outputs": [], "source": ["sample = train_dataset[0]\n", "for method in [\"diff\", \"blend\", \"checkerboard\"]:\n", " print(f\"Method {method}\")\n", " visualize_fmc_sample(sample, method=method)"]}, {"cell_type": "markdown", "id": "45138071", "metadata": {"user_expressions": []}, "source": ["Now we apply a little trick. Because applying the random deformation in each training iteration will be very costly, we only apply the deformation once and we make a new dataset based on the deformed images. 
Running the cell below may take a few minutes."]}, {"cell_type": "code", "execution_count": null, "id": "4a45ce6d", "metadata": {}, "outputs": [], "source": ["import tqdm\n", "\n", "train_loader = monai.data.DataLoader(train_dataset, batch_size=1, shuffle=False)\n", "\n", "samples = []\n", "for train_batch in tqdm.tqdm(train_loader):\n", " samples.append(train_batch)\n", "\n", "# Make a new dataset and dataloader using the transformed images\n", "train_dataset = monai.data.Dataset(samples, transform=monai.transforms.SqueezeDimd(keys=['fixed', 'moving', 'fixed_mask', 'moving_mask']))\n", "train_loader = monai.data.DataLoader(train_dataset, batch_size=16, shuffle=False)"]}, {"cell_type": "markdown", "id": "cd6ac0da", "metadata": {"user_expressions": []}, "source": [":::{admonition} Exercise\n", ":class: tip\n", "Create `val_dataset` and `val_loader`, corresponding to the `DataSet` and `DataLoader` for your validation set. The transforms can be the same as in the training set.\n", ":::"]}, {"cell_type": "code", "execution_count": null, "id": "6003311e", "metadata": {"tags": ["student"]}, "outputs": [], "source": ["# Your code goes here"]}, {"cell_type": "markdown", "id": "465afbfc", "metadata": {"user_expressions": []}, "source": ["### Model\n", "\n", "As model, we'll use a U-Net. The input/output structure is quite different from what we've seen before:\n", "- the network takes as input two images: the *moving* and *fixed* images.\n", "- it outputs one tensor representing the *deformation field*.\n", "\n", "\n", "\n", "\n", "This *deformation field* can be applied to the *moving* image with the `monai.networks.blocks.Warp` block of Monai.\n", "\n", "\n", "\n", "\n", "This deformed moving image is then compared to the *fixed* image: if they are similar, the deformation field is correctly registering the moving image on the fixed image. Keep in mind that this is done on **training** data, and we want the U-Net to learn to predict a proper deformation field given two new and unseen images. So we're not optimizing for a pair of images as would be done in conventional iterative registration, but training a model that can generalize.\n", "\n", "\n"]}, {"cell_type": "markdown", "id": "96c640a7", "metadata": {"user_expressions": []}, "source": ["Before starting, let's check that you can work on a GPU by runnning the following cell:\n", "- if the device is \"cuda\" you are working on a GPU,\n", "- if the device is \"cpu\" call a teacher."]}, {"cell_type": "code", "execution_count": null, "id": "4ab7677d", "metadata": {"lines_to_next_cell": 2}, "outputs": [], "source": ["if torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", "elif torch.backends.mps.is_available():\n", " device = torch.device(\"mps\")\n", " os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"]=\"1\"\n", "else:\n", " device = \"cpu\"\n", "print(f'The used device is {device}')"]}, {"cell_type": "markdown", "id": "41d64c58", "metadata": {"user_expressions": []}, "source": [":::{admonition} Exercise\n", ":class: tip\n", "Construct a U-Net with suitable settings and name it `model`. 
Keep in mind that you want to be able to correctly apply its output to the input moving image with the `warp_layer`!\n", ":::"]}, {"cell_type": "code", "execution_count": null, "id": "eecd46dd", "metadata": {"tags": ["student"]}, "outputs": [], "source": ["model = # FILL IN\n", "\n", "warp_layer = monai.networks.blocks.Warp().to(device)"]}, {"cell_type": "markdown", "id": "6398d534", "metadata": {}, "source": ["### Objective function\n", "\n", "We evaluate the similarity between the fixed image and the deformed moving image with the `MSELoss()`. The L1 or SSIM losses seen in the previous section could also be used. Furthermore, the deformation field is regularized with `BendingEnergyLoss`. This is a penalty that takes the smoothness of the deformation field into account: if it's not smooth enough, the bending energy is high. Thus, our model will favor smooth deformation fields.\n", "\n", "Finally, we pick an optimizer, in this case again an Adam optimizer."]}, {"cell_type": "code", "execution_count": null, "id": "6c3cceea", "metadata": {}, "outputs": [], "source": ["image_loss = torch.nn.MSELoss()\n", "regularization = monai.losses.BendingEnergyLoss()\n", "optimizer = torch.optim.Adam(model.parameters(), 1e-3)"]}, {"cell_type": "markdown", "id": "4046caa6", "metadata": {"user_expressions": []}, "source": [":::{admonition} Exercise\n", ":class: tip\n", "Add a learning rate scheduler that lowers the learning rate by a factor ten every 100 epochs.\n", ":::"]}, {"cell_type": "code", "execution_count": null, "id": "383554f7", "metadata": {"tags": ["student"]}, "outputs": [], "source": ["# Your code goes here"]}, {"cell_type": "markdown", "id": "e9de16a0", "metadata": {"user_expressions": []}, "source": ["To warp the moving image using the predicted deformation field and *then* compute the loss between the deformed image and the fixed image, we define a forward function which does all this. The output of this function is `pred_image`. 
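\n\nOnce you have defined `model` and run the next cell (which defines `forward`), a quick shape check on one batch can help. This is only a sketch, assuming a 2-channel-input / 2-channel-output U-Net, which matches what `monai.networks.blocks.Warp` expects for 2D images:\n\n```python\nbatch_data = next(iter(train_loader))\nddf, pred_image = forward(batch_data, model)\nprint(ddf.shape)         # displacement field, e.g. torch.Size([16, 2, 256, 256])\nprint(pred_image.shape)  # warped moving image, e.g. torch.Size([16, 1, 256, 256])\n```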
"]}, {"cell_type": "code", "execution_count": null, "id": "2b03f401", "metadata": {}, "outputs": [], "source": ["def forward(batch_data, model):\n", " \"\"\"\n", " Applies the model to a batch of data.\n", " \n", " Args:\n", " batch_data (dict): a batch of samples computed by a DataLoader.\n", " model (Module): a model computing the deformation field.\n", " \n", " Returns:\n", " ddf (Tensor): batch of deformation fields.\n", " pred_image (Tensor): batch of deformed moving images.\n", " \n", " \"\"\"\n", " fixed_image = batch_data[\"fixed\"].to(device).float()\n", " moving_image = batch_data[\"moving\"].to(device).float()\n", " \n", " # predict DDF\n", " ddf = model(torch.cat((moving_image, fixed_image), dim=1))\n", "\n", " # warp moving image and label with the predicted ddf\n", " pred_image = warp_layer(moving_image, ddf)\n", "\n", " return ddf, pred_image"]}, {"cell_type": "markdown", "id": "c06b57a8", "metadata": {"user_expressions": []}, "source": ["You can supervise the training process in W&B, in which at each epoch a batch of validation images are used to compute the comparison images of your choice, based on the parameter `method`."]}, {"cell_type": "code", "execution_count": null, "id": "975c85da", "metadata": {}, "outputs": [], "source": ["def log_to_wandb(epoch, train_loss, val_loss, pred_batch, fixed_batch, method=\"checkerboard\"):\n", " \"\"\" Function that logs ongoing training variables to W&B \"\"\"\n", " import skimage.util as skut\n", " \n", " log_imgs = []\n", " for fixed_pt, pred_pt in zip(pred_batch, fixed_batch):\n", " fixed_np = np.squeeze(fixed_pt.cpu().detach())\n", " pred_np = np.squeeze(pred_pt.cpu().detach())\n", " comp_checker = skut.compare_images(fixed_np, pred_np, method=method)\n", " log_imgs.append(wandb.Image(comp_checker))\n", "\n", " # Send epoch, losses and images to W&B\n", " wandb.log({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss, 'results': log_imgs})"]}, {"cell_type": "markdown", "id": "b73a0737", "metadata": {"user_expressions": []}, "source": ["### Training time\n", "\n", "Use the following cells to train your network. 
You may choose different parameters to improve the performance!"]}, {"cell_type": "code", "execution_count": null, "id": "f738f685", "metadata": {}, "outputs": [], "source": ["# Choose your parameters\n", "\n", "max_epochs = 200\n", "reg_weight = 0 # By default 0, but you can investigate what it does"]}, {"cell_type": "code", "execution_count": null, "id": "c58e44c4", "metadata": {}, "outputs": [], "source": ["from tqdm import tqdm\n", "\n", "run = wandb.init(\n", " project='tutorial4_registration',\n", " config={\n", " 'lr': optimizer.param_groups[0][\"lr\"],\n", " 'batch_size': train_loader.batch_size,\n", " 'regularization': reg_weight,\n", " 'loss_function': str(image_loss)\n", " }\n", ")\n", "# Do not hesitate to enrich this list of settings to be able to correctly keep track of your experiments!\n", "# For example you should add information on your model...\n", "\n", "run_id = run.id # We remember here the run ID to be able to write the evaluation metrics\n", "\n", "for epoch in tqdm(range(max_epochs)): \n", " model.train()\n", " epoch_loss = 0\n", " for batch_data in train_loader:\n", " optimizer.zero_grad()\n", "\n", " ddf, pred_image = forward(batch_data, model)\n", "\n", " fixed_image = batch_data[\"fixed\"].to(device).float()\n", " reg = regularization(ddf)\n", " loss = image_loss(pred_image, fixed_image) + reg_weight * reg\n", " loss.backward()\n", " optimizer.step()\n", " epoch_loss += loss.item()\n", "\n", " epoch_loss /= len(train_loader)\n", "\n", " model.eval()\n", " val_epoch_loss = 0\n", " for batch_data in val_loader:\n", " ddf, pred_image = forward(batch_data, model)\n", " fixed_image = batch_data[\"fixed\"].to(device).float()\n", " reg = regularization(ddf)\n", " loss = image_loss(pred_image, fixed_image) + reg_weight * reg\n", " val_epoch_loss += loss.item()\n", " val_epoch_loss /= len(val_loader)\n", "\n", " log_to_wandb(epoch, epoch_loss, val_epoch_loss, pred_image, fixed_image)\n", " \n", "run.finish() "]}, {"cell_type": "markdown", "id": "cf8c343c", "metadata": {"user_expressions": []}, "source": ["### Evaluation of the trained model\n", "\n", "Now that the model has been trained, it's time to evaluate its performance. Use the code below to visualize samples and deformation fields. \n", "\n", ":::{admonition} Exercise\n", ":class: tip\n", "Are you satisfied with these registration results? Do they seem anatomically plausible? 
Try out different regularization factors (`reg_weight`) and see what they do to the registration.\n", ":::"]}, {"cell_type": "markdown", "id": "c87eacd5", "metadata": {"tags": ["student"]}, "source": ["Answer: "]}, {"cell_type": "code", "execution_count": null, "id": "0eb57e3c", "metadata": {}, "outputs": [], "source": ["def visualize_prediction(sample, model, method=\"checkerboard\"):\n", "    \"\"\"\n", "    Plot the fixed, moving and deformed images, together with a comparison image and the deformation field magnitude.\n", "    \n", "    Args:\n", "        sample (dict): sample of dataset created with `build_dataset`.\n", "        model (Module): a model computing the deformation field.\n", "        method (str): method used by `skimage.util.compare_images`.\n", "    \"\"\"\n", "    import skimage.util as skut\n", "    \n", "    skut_methods = [\"diff\", \"blend\", \"checkerboard\"]\n", "    if method not in skut_methods:\n", "        raise ValueError(f\"Method must be chosen in {skut_methods}.\\n\"\n", "                         f\"Current value is {method}.\")\n", "    \n", "    model.eval()\n", "    \n", "    # Compute deformation field + deformed image\n", "    batch_data = {\n", "        \"fixed\": sample[\"fixed\"].unsqueeze(0),\n", "        \"moving\": sample[\"moving\"].unsqueeze(0),\n", "    }\n", "    ddf, pred_image = forward(batch_data, model)\n", "    ddf = ddf.detach().cpu().numpy().squeeze()\n", "    ddf = np.linalg.norm(ddf, axis=0).squeeze()\n", "    \n", "    # Squeeze images\n", "    fixed = np.squeeze(sample[\"fixed\"])\n", "    moving = np.squeeze(sample[\"moving\"])\n", "    deformed = np.squeeze(pred_image.detach().cpu())\n", "    \n", "    # Generate comparison image\n", "    comp_checker = skut.compare_images(fixed, deformed, method=method, n_tiles=(4, 4))\n", "    \n", "    # Plot everything\n", "    fig, axs = plt.subplots(1, 5, figsize=(18, 5))\n", "    axs[0].imshow(fixed, cmap='gray')\n", "    axs[0].set_title('Fixed')\n", "    axs[1].imshow(moving, cmap='gray')\n", "    axs[1].set_title('Moving')\n", "    axs[2].imshow(deformed, cmap='gray')\n", "    axs[2].set_title('Deformed')\n", "    axs[3].imshow(comp_checker, cmap='gray')\n", "    axs[3].set_title('Comparison')\n", "    dpl = axs[4].imshow(ddf, clim=(0, 10))\n", "    fig.colorbar(dpl, ax=axs[4])\n", "    plt.show()\n", "\n", "for sample in val_dataset:\n", "    visualize_prediction(sample, model)"]}, {"cell_type": "markdown", "id": "d2e86aae", "metadata": {"user_expressions": []}, "source": [":::{admonition} Exercise\n", ":class: tip\n", "Compute the Jacobian determinant at each image voxel. How many of these are negative? Can you improve upon this?\n", ":::"]}, {"cell_type": "markdown", "id": "fec2efad", "metadata": {"user_expressions": []}, "source": ["## Part 2 - Equivariance\n", "In this part, we are going to use some concepts that you've learned in the lecture on geometric deep learning. We are going to look at the equivariance properties of a neural network architecture that you should by now be very familiar with: the U-Net. We will again use the chest X-ray segmentation problem. 
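"]}, {"cell_type": "markdown", "id": "added-equiv-toy-md", "metadata": {"user_expressions": []}, "source": ["To make translation equivariance concrete before we start, the next cell is a small, self-contained toy example (independent of our X-ray data): a single convolution layer with circular padding maps a circularly shifted input to a correspondingly shifted output. This is exactly the property we will probe for the full U-Net below; since that network uses standard zero padding and downsampling, we should not expect it to satisfy the property perfectly."]}, {"cell_type": "code", "execution_count": null, "id": "added-equiv-toy-code", "metadata": {}, "outputs": [], "source": ["# Toy illustration of translation equivariance for a single convolution layer.\n", "import torch\n", "import torch.nn as nn\n", "\n", "torch.manual_seed(0)\n", "conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, padding_mode='circular', bias=False)\n", "\n", "x = torch.rand(1, 1, 32, 32)\n", "shift = 5\n", "x_shifted = torch.roll(x, shifts=shift, dims=-1)  # shift the input along the last axis\n", "\n", "with torch.no_grad():\n", "    y = conv(x)\n", "    y_shifted = conv(x_shifted)\n", "\n", "# For an equivariant layer, shifting the output of x matches the output of the shifted input.\n", "print(torch.max(torch.abs(torch.roll(y, shifts=shift, dims=-1) - y_shifted)))"]}, {"cell_type": "markdown", "id": "added-part2-cont", "metadata": {"user_expressions": []}, "source": ["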
Because training a network is not the focus here, we have pretrained a network that you can use for these experiments."]}, {"cell_type": "markdown", "id": "f7e0f55f", "metadata": {"user_expressions": []}, "source": ["### Data loading\n", "We will again use the same utility functions as in Tutorial 3 to build a dictionary of files and load rib data."]}, {"cell_type": "code", "execution_count": null, "id": "e48d7658", "metadata": {}, "outputs": [], "source": ["import os\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import glob\n", "import monai\n", "from PIL import Image\n", "import torch\n", "\n", "def build_dict_ribs(data_path, mode='train'):\n", " \"\"\"\n", " This function returns a list of dictionaries, each dictionary containing the keys 'img' and 'mask' \n", " that returns the path to the corresponding image.\n", " \n", " Args:\n", " data_path (str): path to the root folder of the data set.\n", " mode (str): subset used. Must correspond to 'train', 'val' or 'test'.\n", " \n", " Returns:\n", " (List[Dict[str, str]]) list of the dictionaries containing the paths of X-ray images and masks.\n", " \"\"\"\n", " # test if mode is correct\n", " if mode not in [\"train\", \"val\", \"test\"]:\n", " raise ValueError(f\"Please choose a mode in ['train', 'val', 'test']. Current mode is {mode}.\")\n", " \n", " # define empty dictionary\n", " dicts = []\n", " # list all .png files in directory, including the path\n", " paths_xray = glob.glob(os.path.join(data_path, mode, 'img', '*.png'))\n", " # make a corresponding list for all the mask files\n", " for xray_path in paths_xray:\n", " if mode == 'test':\n", " suffix = 'val'\n", " else:\n", " suffix = mode\n", " # find the binary mask that belongs to the original image, based on indexing in the filename\n", " image_index = os.path.split(xray_path)[1].split('_')[-1].split('.')[0]\n", " # define path to mask file based on this index and add to list of mask paths\n", " mask_path = os.path.join(data_path, mode, 'mask', f'VinDr_RibCXR_{suffix}_{image_index}.png')\n", " if os.path.exists(mask_path):\n", " dicts.append({'img': xray_path, 'mask': mask_path})\n", " return dicts\n", "\n", "class LoadRibData(monai.transforms.Transform):\n", " \"\"\"\n", " This custom Monai transform loads the data from the rib segmentation dataset.\n", " Defining a custom transform is simple; just overwrite the __init__ function and __call__ function.\n", " \"\"\"\n", " def __init__(self, keys=None):\n", " pass\n", "\n", " def __call__(self, sample):\n", " image = Image.open(sample['img']).convert('L') # import as grayscale image\n", " image = np.array(image, dtype=np.uint8)\n", " mask = Image.open(sample['mask']).convert('L') # import as grayscale image\n", " mask = np.array(mask, dtype=np.uint8)\n", " # mask has value 255 on rib pixels. Convert to binary array\n", " mask[np.where(mask==255)] = 1\n", " return {'img': image, 'mask': mask, 'img_meta_dict': {'affine': np.eye(2)}, \n", " 'mask_meta_dict': {'affine': np.eye(2)}}"]}, {"cell_type": "markdown", "id": "39c92a63", "metadata": {"user_expressions": []}, "source": ["Use the cell below to make a validation loader with a single image. 
This is sufficient for the small experiment that you will perform."]}, {"cell_type": "code", "execution_count": null, "id": "b6479aea", "metadata": {}, "outputs": [], "source": ["validation_dict_list = build_dict_ribs(data_path, mode='val')\n", "validation_transform = monai.transforms.Compose(\n", "    [\n", "        LoadRibData(),\n", "        monai.transforms.AddChanneld(keys=['img', 'mask']),\n", "        monai.transforms.HistogramNormalized(keys=['img']),\n", "        monai.transforms.ScaleIntensityd(keys=['img'], minv=0, maxv=1),\n", "        monai.transforms.Zoomd(keys=['img', 'mask'], zoom=0.25, mode=['bilinear', 'nearest'], keep_size=False),\n", "        # monai.transforms.RandSpatialCropd(keys=['img', 'mask'], roi_size=[384, 384], random_size=False)\n", "        monai.transforms.SpatialCropd(keys=['img', 'mask'], roi_center=[300, 300], roi_size=[384 + 64, 384])\n", "    ]\n", ")\n", "validation_data = monai.data.CacheDataset([validation_dict_list[3]], transform=validation_transform)\n", "validation_loader = monai.data.DataLoader(validation_data, batch_size=1, shuffle=False)"]}, {"cell_type": "markdown", "id": "d2d19163", "metadata": {"user_expressions": []}, "source": ["### Loading a pretrained model\n", "We have already trained a model for you, the parameters of which were shared in JupyterLab as well.\n", "**Note**: if you downloaded the data set yourself, the model should be in the same folder as the images.\n", "If you already downloaded the data set but not the model, the model file is available [here](https://surfdrive.surf.nl/files/index.php/s/613zrvr0RDYZDqp). If you fetch it with the `wget` cell below instead, adjust `pretrained_file` so that it points to the downloaded file."]}, {"cell_type": "code", "execution_count": null, "id": "d4ed33e6", "metadata": {}, "outputs": [], "source": ["!wget -O trainedUNet.pt https://surfdrive.surf.nl/files/index.php/s/613zrvr0RDYZDqp/download"]}, {"cell_type": "code", "execution_count": null, "id": "44eeff73", "metadata": {}, "outputs": [], "source": ["import os\n", "\n", "pretrained_file = os.path.join(data_path, \"trainedUNet.pt\")"]}, {"cell_type": "markdown", "id": "90c87fe9", "metadata": {"user_expressions": []}, "source": ["Next, we initialize a standard U-Net architecture and load the parameters of the pretrained network using the `load_state_dict` function."]}, {"cell_type": "code", "execution_count": null, "id": "79886327", "metadata": {}, "outputs": [], "source": ["import torch\n", "import monai\n", "import random\n", "\n", "# Check whether we're using a GPU\n", "if torch.cuda.is_available():\n", "    n_gpus = torch.cuda.device_count()  # Total number of GPUs\n", "    gpu_idx = random.randint(0, n_gpus - 1)  # Random GPU index\n", "    device = torch.device(f'cuda:{gpu_idx}')\n", "    print('Using GPU: {}'.format(device))\n", "else:\n", "    device = torch.device('cpu')\n", "    print('GPU not found. Using CPU.')\n", "\n", "model = monai.networks.nets.UNet(\n", "    spatial_dims=2,\n", "    in_channels=1,\n", "    out_channels=1,\n", "    channels=(8, 16, 32, 64, 128),\n", "    strides=(2, 2, 2, 2),\n", "    num_res_units=2,\n", "    dropout=0.5\n", ").to(device)\n", "\n", "model.load_state_dict(torch.load(pretrained_file, map_location=device))\n", "model.eval()"]}, {"cell_type": "markdown", "id": "169250c5", "metadata": {}, "source": ["Let's use the pretrained network to segment (part of) our image. 
Run the cell below."]}, {"cell_type": "code", "execution_count": null, "id": "5fb63192", "metadata": {}, "outputs": [], "source": ["for sample in validation_loader:\n", "\n", " img = sample['img'][:, :, :384, :384] \n", " mask = sample['mask'][:, :, :384, :384]\n", " output_noshift = torch.sigmoid(model(img.to(device))).detach().cpu().numpy().squeeze() \n", " \n", " fig, ax = plt.subplots(1,2, figsize = [12, 10]) \n", " # Plot X-ray image\n", " ax[0].imshow(img.squeeze(), 'gray')\n", " # Plot ground truth\n", " mask = np.squeeze(mask)\n", " overlay_mask = np.ma.masked_where(mask == 0, mask == 1)\n", " ax[0].imshow(overlay_mask, 'Greens', alpha = 0.7, clim=[0,1], interpolation='nearest')\n", " ax[0].set_title('Ground truth')\n", " # Plot output\n", " overlay_output = np.ma.masked_where(output_noshift < 0.1, output_noshift > 0.99)\n", " ax[1].imshow(img.squeeze(), 'gray')\n", " ax[1].imshow(overlay_output.squeeze(), 'Reds', alpha = 0.7, clim=[0,1])\n", " ax[1].set_title('Prediction')\n", " plt.show() "]}, {"cell_type": "markdown", "id": "178a689a", "metadata": {}, "source": ["As you can see, segmentation isn't perfect, but that's also not the goal of this exercise. What we are going to look into is the translation equivariance (**Lecture 8**) of the U-Net. That is: if you translate the image by $d$ pixels, does the output also simply change by $d$ pixels. Note that this is a nice feature to have for a segmentation network: in principle we'd want our network to give us the same label for a pixel regardless of where the image was cut. The image below visualizes this principle. For segmentation of the pixels in the orange square, it shouldn't matter if we provide the red square or the green square as input to the U-Net.\n", "\n", ""]}, {"cell_type": "markdown", "id": "3d3f0270", "metadata": {"user_expressions": []}, "source": [":::{admonition} Exercise\n", ":class: tip\n", "What do you think will happen to the U-Net's prediction if we give it a slightly shifted version of the image as input?\n", ":::"]}, {"cell_type": "markdown", "id": "1c8d0a36", "metadata": {"user_expressions": []}, "source": ["Now we make a small script that performs the above experiment. First, we obtain the segmentation in the red box and we call this `output_noshift`. Then we shift the green box by an offset and each time obtain a segmentation in this box using the same model. We start small with a shift/offset of just a **single pixel**.\n", "\n", ":::{admonition} Exercise\n", ":class: tip\n", "Run the cell below and observe the outputs. 
Can you spot differences between the two segmentation masks?\n", ":::"]}, {"cell_type": "code", "execution_count": null, "id": "c1007982", "metadata": {}, "outputs": [], "source": ["offset = 1\n", "\n", "for sample in validation_loader:\n", "\n", " # Original image\n", " img = sample['img'][:, :, :384, :384] \n", " mask = sample['mask'][:, :, :384, :384]\n", " output_noshift = torch.sigmoid(model(img.to(device))).detach().cpu().numpy().squeeze() \n", "\n", " # Plot X-ray image\n", " fig, ax = plt.subplots(1,2, figsize = [12, 10]) \n", " ax[0].imshow(img.squeeze(), 'gray')\n", " # Plot ground truth\n", " mask = np.squeeze(mask)\n", " overlay_mask = np.ma.masked_where(mask == 0, mask == 1)\n", " ax[0].imshow(overlay_mask, 'Greens', alpha = 0.7, clim=[0,1], interpolation='nearest')\n", " ax[0].set_title('Ground truth')\n", " # Plot output\n", " overlay_output = np.ma.masked_where(output_noshift < 0.1, output_noshift >0.99)\n", " ax[1].imshow(img.squeeze(), 'gray')\n", " ax[1].imshow(overlay_output.squeeze(), 'Reds', alpha = 0.7, clim=[0,1])\n", " ax[1].set_title('Prediction')\n", " plt.show()\n", " \n", " # Shifted image\n", " img = sample['img'][:, :, offset:offset+384, :384]\n", " mask = sample['mask'][:, :, offset:offset+384, :384]\n", " output = torch.sigmoid(model(img.to(device))).detach().cpu().numpy().squeeze()\n", "\n", " # Plot X-ray image\n", " fig, ax = plt.subplots(1,2, figsize = [12, 10])\n", " ax[0].imshow(img.squeeze(), 'gray')\n", " # Plot ground truth\n", " mask = np.squeeze(mask)\n", " overlay_mask = np.ma.masked_where(mask == 0, mask == 1)\n", " ax[0].imshow(overlay_mask, 'Greens', alpha = 0.7, clim=[0,1], interpolation='nearest')\n", " ax[0].set_title('Ground truth shifted')\n", " # Plot output\n", " overlay_output = np.ma.masked_where(output < 0.1, output >0.99)\n", " ax[1].imshow(img.squeeze(), 'gray')\n", " ax[1].imshow(overlay_output.squeeze(), 'Reds', alpha = 0.7, clim=[0,1])\n", " ax[1].set_title('Prediction shifted')\n", " plt.show()"]}, {"cell_type": "markdown", "id": "f6505ba6", "metadata": {"user_expressions": []}, "source": ["To highlight the differences between both segmentation masks a bit more, we make a difference image. We correct for the shift applied so that we're not comparing apples and oranges. The next cell shows the difference image between the original image and what we get when we process an image that is shifted by one pixel.\n", "\n", ":::{admonition} Exercise\n", ":class: tip\n", "Given these results, is a U-Net translation equivariant, invariant, or neither?\n", ":::"]}, {"cell_type": "code", "execution_count": null, "id": "cebe0524", "metadata": {}, "outputs": [], "source": ["plt.figure(figsize=(6, 6))\n", "diffout = output_noshift[offset:, :384] - output[:-offset, :384]\n", "plt.imshow(diffout, cmap='seismic', clim=[-1, 1])\n", "plt.title('Offset {}'.format(offset))\n", "plt.colorbar()\n", "plt.show()"]}, {"cell_type": "markdown", "id": "7cdd74a7", "metadata": {"user_expressions": []}, "source": ["We can repeat this for larger offsets. Let's take offsets up to 64 pixels, and each time compute the difference between the original and shifted image, in a subimage that should be unaffected by the shift. We store the L1 norm of the difference image in an array `norms` and plot these as a function of offset.\n", "\n", ":::{admonition} Exercise\n", ":class: tip\n", "The resulting plot shows that the U-Net is equivariant for none of the translations. This is due to a combination of border effects and downsampling layers. 
However, the plot also shows a particular pattern, in which the norm *dips* every 16 pixels of offset. Can you explain this based on the U-Net architecture? \n", ":::"]}, {"cell_type": "code", "execution_count": null, "id": "ddeb7bc3", "metadata": {}, "outputs": [], "source": ["norms = []\n", "offsets = []\n", "plot_differences = False # Set to True to plot difference images for every offset\n", "\n", "img = sample['img'][:, :, :384, :384] \n", "mask = sample['mask'][:, :, :384, :384]\n", "output_noshift = torch.sigmoid(model(img.to(device))).detach().cpu().numpy().squeeze() \n", "\n", "for offset in range(1, 65):\n", " for sample in validation_loader:\n", " img = sample['img'][:, :, offset:offset+384, :384]\n", " mask = sample['mask'][:, :, offset:offset+384, :384]\n", "\n", " output = torch.sigmoid(model(img.to(device))).detach().cpu().numpy().squeeze() \n", "\n", " diffout = (output_noshift[offset:, :384] - output[:-offset, :384])[100:284, 100:284]\n", " offsets.append(offset)\n", " norms.append(np.sum(np.abs(diffout)))\n", " if plot_differences:\n", " plt.figure()\n", " plt.imshow(diffout, cmap='seismic', clim=[-1, 1])\n", " plt.title(f\"Offset {offset}\")\n", " plt.colorbar()\n", " plt.show()\n", "\n", "plt.figure()\n", "plt.plot(offsets, norms)\n", "plt.xlabel('Offset')\n", "plt.ylabel('Difference')\n", "plt.show()"]}], "metadata": {"kernelspec": {"display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3"}}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/notebooks/Tutorial6_Registration/Tutorial6_Registration_student.html b/notebooks/Tutorial6_Registration/Tutorial6_Registration_student.html index eaef384..9367493 100644 --- a/notebooks/Tutorial6_Registration/Tutorial6_Registration_student.html +++ b/notebooks/Tutorial6_Registration/Tutorial6_Registration_student.html @@ -412,7 +412,7 @@

Contents