From 4401f3315f5ff3e6018ec8570780809862a68e26 Mon Sep 17 00:00:00 2001 From: Soorya19Pradeep <101817974+Soorya19Pradeep@users.noreply.github.com> Date: Wed, 10 Jul 2024 15:22:38 -0700 Subject: [PATCH] Cellular infection phenotyping using annotated viral sensor data & label-free images (#70) * refactor data loading into its own module * update type annotations * move the logging module out * move old logging into utils * rename tests to match module name * bump torch * draft fcmae encoder * add stem to the encoder * wip: masked stem layernorm * wip: patchify masked features for linear * use mlp from timm * hack: POC training script for FCMAE * fix mask for fitting * remove training script * default architecture * fine-tuning options * fix cli for finetuning * draft combined data module * fix import * manual validation loss reduction * update linting new black version has different rules * update development guide * update type hints * bump iohub * draft ctmc v1 dataset * update tests * move test_data * remove path conversion * configurable normalizations (#68) * inital commit adding the normalization. * adding dataset_statistics to each fov to facilitate the configurable augmentations * fix indentation * ruff * test preprocessing * remove redundant field * cleanup --------- Co-authored-by: Ziwen Liu * fix ctmc dataloading * add example ctmc v1 loading script * changing the normalization and augmentations default from None to empty list. * invert intensity transform * concatenated data module * subsample videos * livecell dataset * all sample fields are optional * fix multi-dataloader validation * lint * fixing preprocessing for varying array shapes (i.e aics dataset) * update loading scripts * fix CombineMode * added model and annotation code draft * chnaged to simple unet model * start with lesser augmentations * added readme file * added tensorboard logging * added validation step * chnaged to viscy 2d unet * used crossentropyloss with one-hot encoding * added sample image logging * attempt to build magicgui annotation * renamed infection annotation tool * added normalization and augmentations * added model testing code * removed annotation refiner * corrected conversion of class to int * corrected prediction module * cleaned up the code and comments for the LightningUNet * removed confusion matrix code, finding runtime error with model * moved scripts to viscy.scripts.infection_phenotyping module to enable imports across scripts * combine the lightning modules for training and prediction, fix the DDP exception * all the stubs for computing and logging confusion matrix per cell * separated training and test scripts * lightning module * corrected test cm compute * corrected test module * separated test and prediction scripts * changed confusion matrix compute * fix merge error * split 2D and 2.5D model scripts * added covnext script * fix model input parameter * update input file * add augmentations * refactor infection_classification code to viscy/applications * changes made for BJ5 classification * format code * add explicit packaging list * rename testing script * update readme * move function to preprocessing * format code * formatting * histogram with dask * fix index and test * fix import * black * fix float comp * clean up headers * clean up import * add argument to change number of classes --------- Co-authored-by: Ziwen Liu Co-authored-by: Eduardo Hirata-Miyasaki Co-authored-by: Shalin Mehta Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> --- .../Infection_classification_25DModel.py | 106 +++++ .../Infection_classification_2Dmodel.py | 123 ++++++ .../Infection_classification_covnextModel.py | 107 ++++++ .../classify_infection_25D.py | 356 +++++++++++++++++ .../classify_infection_2D.py | 362 +++++++++++++++++ .../classify_infection_covnext.py | 363 ++++++++++++++++++ .../infection_classifier_testing.py | 59 +++ .../predict_infection_classifier.py | 59 +++ .../infection_classification/readme.md | 13 + pyproject.toml | 3 + tests/preprocessing/test_pixel_ratio.py | 15 + viscy/preprocessing/pixel_ratio.py | 22 ++ 12 files changed, 1588 insertions(+) create mode 100644 applications/infection_classification/Infection_classification_25DModel.py create mode 100644 applications/infection_classification/Infection_classification_2Dmodel.py create mode 100644 applications/infection_classification/Infection_classification_covnextModel.py create mode 100644 applications/infection_classification/classify_infection_25D.py create mode 100644 applications/infection_classification/classify_infection_2D.py create mode 100644 applications/infection_classification/classify_infection_covnext.py create mode 100644 applications/infection_classification/infection_classifier_testing.py create mode 100644 applications/infection_classification/predict_infection_classifier.py create mode 100644 applications/infection_classification/readme.md create mode 100644 tests/preprocessing/test_pixel_ratio.py create mode 100644 viscy/preprocessing/pixel_ratio.py diff --git a/applications/infection_classification/Infection_classification_25DModel.py b/applications/infection_classification/Infection_classification_25DModel.py new file mode 100644 index 00000000..a4e712f5 --- /dev/null +++ b/applications/infection_classification/Infection_classification_25DModel.py @@ -0,0 +1,106 @@ +# %% +import lightning.pytorch as pl +import torch +import torch.nn as nn +from applications.infection_classification.classify_infection_25D import ( + SemanticSegUNet25D, +) +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger + +from viscy.data.hcs import HCSDataModule +from viscy.preprocessing.pixel_ratio import sematic_class_weights +from viscy.transforms import NormalizeSampled, RandWeightedCropd + +# %% Create a dataloader and visualize the batches. + +# Set the path to the dataset +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr" + +# %% create data module + +# Create an instance of HCSDataModule +data_module = HCSDataModule( + dataset_path, + source_channel=["Phase", "HSP90"], + target_channel=["Inf_mask"], + yx_patch_size=[512, 512], + split_ratio=0.8, + z_window_size=5, + architecture="2.5D", + num_workers=3, + batch_size=32, + normalizations=[ + NormalizeSampled( + keys=["Phase", "HSP90"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], + augmentations=[ + RandWeightedCropd( + num_samples=4, + spatial_size=[-1, 512, 512], + keys=["Phase", "HSP90"], + w_key="Inf_mask", + ) + ], +) + +pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask") + +# Prepare the data +data_module.prepare_data() + +# Setup the data +data_module.setup(stage="fit") + +# Create a dataloader +train_dm = data_module.train_dataloader() + +val_dm = data_module.val_dataloader() + + +# %% Define the logger +logger = TensorBoardLogger( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/", + name="logs", +) + +# Pass the logger to the Trainer +trainer = pl.Trainer( + logger=logger, + max_epochs=200, + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", + log_every_n_steps=1, + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +) + +# Define the checkpoint callback +checkpoint_callback = ModelCheckpoint( + dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", + filename="checkpoint_{epoch:02d}", + save_top_k=-1, + verbose=True, + monitor="loss/validate", + mode="min", +) + +# Add the checkpoint callback to the trainer +trainer.callbacks.append(checkpoint_callback) + +# Fit the model +model = SemanticSegUNet25D( + in_channels=2, + out_channels=3, + loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), +) + +print(model) + +# %% Run training. + +trainer.fit(model, data_module) + +# %% diff --git a/applications/infection_classification/Infection_classification_2Dmodel.py b/applications/infection_classification/Infection_classification_2Dmodel.py new file mode 100644 index 00000000..333718aa --- /dev/null +++ b/applications/infection_classification/Infection_classification_2Dmodel.py @@ -0,0 +1,123 @@ +# %% +import lightning.pytorch as pl +import torch +import torch.nn as nn +from applications.infection_classification.classify_infection_2D import ( + SemanticSegUNet2D, +) +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger + +from viscy.data.hcs import HCSDataModule +from viscy.preprocessing.pixel_ratio import sematic_class_weights +from viscy.transforms import ( + NormalizeSampled, + RandGaussianSmoothd, + RandScaleIntensityd, + RandWeightedCropd, +) + +# %% calculate the ratio of background, uninfected and infected pixels in the input dataset + +# Set the path to the dataset +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/4-human_annotation/train_data.zarr" + +# %% Create an instance of HCSDataModule + +data_module = HCSDataModule( + dataset_path, + source_channel=["TXR_Density3D", "Phase3D"], + target_channel=["Inf_mask"], + yx_patch_size=[128, 128], + split_ratio=0.7, + z_window_size=1, + architecture="2D", + num_workers=1, + batch_size=256, + normalizations=[ + NormalizeSampled( + keys=["Phase3D", "TXR_Density3D"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], + augmentations=[ + RandWeightedCropd( + num_samples=16, + spatial_size=[-1, 128, 128], + keys=["TXR_Density3D", "Phase3D", "Inf_mask"], + w_key="Inf_mask", + ), + RandScaleIntensityd( + keys=["TXR_Density3D", "Phase3D"], + factors=[0.5, 0.5], + prob=0.5, + ), + RandGaussianSmoothd( + keys=["TXR_Density3D", "Phase3D"], + prob=0.5, + sigma_x=[0.5, 1.0], + sigma_y=[0.5, 1.0], + sigma_z=[0.5, 1.0], + ), + ], +) +pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask") + +# Prepare the data +data_module.prepare_data() + +# Setup the data +data_module.setup(stage="fit") + +# Create a dataloader +train_dm = data_module.train_dataloader() + +val_dm = data_module.val_dataloader() + +# %% Set up for training + +# define the logger +logger = TensorBoardLogger( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/", + name="logs", +) + +# Pass the logger to the Trainer +trainer = pl.Trainer( + logger=logger, + max_epochs=500, + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/", + log_every_n_steps=1, + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +) + +# Define the checkpoint callback +checkpoint_callback = ModelCheckpoint( + dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/", + filename="checkpoint_{epoch:02d}", + save_top_k=-1, + verbose=True, + monitor="loss/validate", + mode="min", +) + +# Add the checkpoint callback to the trainer +trainer.callbacks.append(checkpoint_callback) + +# Fit the model +model = SemanticSegUNet2D( + in_channels=2, + out_channels=3, + loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), +) + +# visualize the model +print(model) + +# %% Run training. + +trainer.fit(model, data_module) + +# %% diff --git a/applications/infection_classification/Infection_classification_covnextModel.py b/applications/infection_classification/Infection_classification_covnextModel.py new file mode 100644 index 00000000..bfe20362 --- /dev/null +++ b/applications/infection_classification/Infection_classification_covnextModel.py @@ -0,0 +1,107 @@ +# %% +# import sys +# sys.path.append("/hpc/mydata/soorya.pradeep/viscy_infection_phenotyping/Viscy/") +import lightning.pytorch as pl +import torch +import torch.nn as nn +from applications.infection_classification.classify_infection_covnext import ( + SemanticSegUNet22D, +) +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger + +from viscy.data.hcs import HCSDataModule +from viscy.preprocessing.pixel_ratio import sematic_class_weights +from viscy.transforms import NormalizeSampled, RandWeightedCropd + +# %% Create a dataloader and visualize the batches. + +# Set the path to the dataset +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_all_curated_train.zarr" + +# %% craete data module + +# Create an instance of HCSDataModule +data_module = HCSDataModule( + dataset_path, + source_channel=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"], + target_channel=["Inf_mask"], + yx_patch_size=[256, 256], + split_ratio=0.8, + z_window_size=5, + architecture="2.2D", + num_workers=3, + batch_size=16, + normalizations=[ + NormalizeSampled( + keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], + augmentations=[ + RandWeightedCropd( + num_samples=4, + spatial_size=[-1, 256, 256], + keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"], + w_key="Inf_mask", + ) + ], +) +pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask") + +# Prepare the data +data_module.prepare_data() + +# Setup the data +data_module.setup(stage="fit") + +# Create a dataloader +train_dm = data_module.train_dataloader() + +val_dm = data_module.val_dataloader() + + +# %% Define the logger +logger = TensorBoardLogger( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/", + name="logs", +) + +# Pass the logger to the Trainer +trainer = pl.Trainer( + logger=logger, + max_epochs=200, + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", + log_every_n_steps=1, + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +) + +# Define the checkpoint callback +checkpoint_callback = ModelCheckpoint( + dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", + filename="checkpoint_{epoch:02d}", + save_top_k=-1, + verbose=True, + monitor="loss/validate", + mode="min", +) + +# Add the checkpoint callback to the trainer`` +trainer.callbacks.append(checkpoint_callback) + +# Fit the model +model = SemanticSegUNet22D( + in_channels=4, + out_channels=3, + loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), +) + +print(model) + +# %% Run training. + +trainer.fit(model, data_module) + +# %% diff --git a/applications/infection_classification/classify_infection_25D.py b/applications/infection_classification/classify_infection_25D.py new file mode 100644 index 00000000..e16f56f4 --- /dev/null +++ b/applications/infection_classification/classify_infection_25D.py @@ -0,0 +1,356 @@ +# import torchview +from typing import Literal, Sequence + +import cv2 +import lightning.pytorch as pl +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from matplotlib.cm import get_cmap +from monai.transforms import DivisiblePad +from skimage.exposure import rescale_intensity +from skimage.measure import label, regionprops +from torch import Tensor + +from viscy.data.hcs import Sample +from viscy.unet.networks.Unet25D import Unet25d + +# %% Methods to compute confusion matrix per cell using torchmetrics + + +# The confusion matrix at the single-cell resolution. +def confusion_matrix_per_cell( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Compute confusion matrix per cell. + + Args: + y_true (torch.Tensor): Ground truth label image (BXHXW). + y_pred (torch.Tensor): Predicted label image (BXHXW). + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Confusion matrix per cell (BXCXC). + """ + # Convert the image class to the nuclei class + confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes) + confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell) + return confusion_matrix_per_cell + + +# These images can be logged with prediction. +def compute_confusion_matrix( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Convert the class of the image to the class of the nuclei. + + Args: + label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Label images with a consensus class at the centroid of nuclei. + """ + + batch_size = y_true.size(0) + # find centroids of nuclei from y_true + conf_mat = np.zeros((num_classes, num_classes)) + for i in range(batch_size): + y_true_cpu = y_true[i].cpu().numpy() + y_pred_cpu = y_pred[i].cpu().numpy() + y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) + y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) + y_pred_resized = cv2.resize( + y_pred_reshaped, + dsize=y_true_reshaped.shape[::-1], + interpolation=cv2.INTER_NEAREST, + ) + y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) + + # find objects in every image + label_img = label(y_true_reshaped) + regions = regionprops(label_img) + + # Find centroids, pixel coordinates from the ground truth. + for region in regions: + if region.area > 0: + row, col = region.centroid + pred_id = y_pred_resized[int(row), int(col)] + test_id = y_true_reshaped[int(row), int(col)] + + if pred_id == 1 and test_id == 1: + conf_mat[1, 1] += 1 + if pred_id == 1 and test_id == 2: + conf_mat[0, 1] += 1 + if pred_id == 2 and test_id == 1: + conf_mat[1, 0] += 1 + if pred_id == 2 and test_id == 2: + conf_mat[0, 0] += 1 + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + return conf_mat + + +def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + index_to_label_dict = dict( + enumerate(index_to_label_dict) + ) # Convert list to dictionary + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + # plt.show(fig) # Show the figure + return fig + + +# Define a 25d unet model for infection classification as a lightning module. + + +class SemanticSegUNet25D(pl.LightningModule): + # Model for semantic segmentation. + def __init__( + self, + in_channels: int, # Number of input channels + out_channels: int, # Number of output channels + lr: float = 1e-3, # Learning rate + loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function + schedule: Literal[ + "WarmupCosine", "Constant" + ] = "Constant", # Learning rate schedule + log_batches_per_epoch: int = 2, # Number of batches to log per epoch + log_samples_per_batch: int = 2, # Number of samples to log per batch + ckpt_path: str = None, # Path to the checkpoint + ): + super(SemanticSegUNet25D, self).__init__() # Call the superclass initializer + # Initialize the UNet model + self.unet_model = Unet25d( + in_channels=in_channels, + out_channels=out_channels, + num_blocks=4, + num_block_layers=4, + ) + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights + self.lr = lr # Set the learning rate + # Set the loss function to CrossEntropyLoss if none is provided + self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() + self.schedule = schedule # Set the learning rate schedule + self.log_batches_per_epoch = ( + log_batches_per_epoch # Set the number of batches to log per epoch + ) + self.log_samples_per_batch = ( + log_samples_per_batch # Set the number of samples to log per batch + ) + self.training_step_outputs = [] # Initialize the list of training step outputs + self.validation_step_outputs = ( + [] + ) # Initialize the list of validation step outputs + + self.pred_cm = None # Initialize the confusion matrix + self.index_to_label_dict = ["Infected", "Uninfected"] + + # Define the forward pass + def forward(self, x): + return self.unet_model(x) # Pass the input through the UNet model + + # Define the optimizer + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr + ) # Use the Adam optimizer + return optimizer + + # Define the training step + def training_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the training step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the training loss + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return train_loss # Return the training loss + + def validation_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the validation step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the validation loss + self.log( + "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True + ) + return loss # Return the validation loss + + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + # Define the prediction step + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse( + self.forward(source) + ) # Predict and remove padding. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + # Go from probabilities/one-hot encoded data to class labels. + labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels + # prob_chan = prob_pred[:, 2, :, :] + # prob_chan = prob_chan.unsqueeze(1) + return labels_pred # log the class predicted image + # return prob_chan # log the probability predicted image + + def on_test_start(self): + self.pred_cm = torch.zeros((2, 2)) + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + def test_step(self, batch: Sample): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse(self.forward(source)) + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels + + target = self._predict_pad(batch["target"]) # Extract the target from the batch + pred_cm = confusion_matrix_per_cell( + target, labels_pred, num_classes=2 + ) # Calculate the confusion matrix per cell + self.pred_cm += pred_cm # Append the confusion matrix to pred_cm + + self.logger.experiment.add_figure( + "Confusion Matrix per Cell", + plot_confusion_matrix(pred_cm, self.index_to_label_dict), + self.current_epoch, + ) + + # Accumulate the confusion matrix at the end of test epoch and log. + def on_test_end(self): + confusion_matrix_sum = self.pred_cm + self.logger.experiment.add_figure( + "Confusion Matrix", + plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), + self.current_epoch, + ) + + # Define what happens at the end of a training epoch + def on_train_epoch_end(self): + self._log_samples( + "train_samples", self.training_step_outputs + ) # Log the training samples + self.training_step_outputs = [] # Reset the list of training step outputs + + # Define what happens at the end of a validation epoch + def on_validation_epoch_end(self): + self._log_samples( + "val_samples", self.validation_step_outputs + ) # Log the validation samples + self.validation_step_outputs = [] # Reset the list of validation step outputs + + # Define a method to detach a sample + def _detach_sample(self, imgs: Sequence[Tensor]): + # Detach the images and convert them to numpy arrays + num_samples = 3 + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] + + # Define a method to log samples + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + images_grid = [] # Initialize the list of image grids + for sample_images in imgs: # For each sample image + images_row = [] # Initialize the list of image rows + for i, image in enumerate( + sample_images + ): # For each image in the sample images + cm_name = "gray" if i == 0 else "inferno" # Set the colormap name + if image.ndim == 2: # If the image is 2D + image = image[np.newaxis] # Add a new axis + for channel in image: # For each channel in the image + channel = rescale_intensity( + channel, out_range=(0, 1) + ) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[ + ..., :3 + ] # Render the channel + images_row.append( + render + ) # Append the render to the list of image rows + images_grid.append( + np.concatenate(images_row, axis=1) + ) # Append the concatenated image rows to the list of image grids + grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids + # Log the image grid + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + + +# %% diff --git a/applications/infection_classification/classify_infection_2D.py b/applications/infection_classification/classify_infection_2D.py new file mode 100644 index 00000000..afd97ab7 --- /dev/null +++ b/applications/infection_classification/classify_infection_2D.py @@ -0,0 +1,362 @@ +# %% lightning moules for infection classification using the viscy library + +# import torchview +from typing import Literal, Sequence + +import cv2 +import lightning.pytorch as pl +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from matplotlib.cm import get_cmap +from monai.transforms import DivisiblePad +from skimage.exposure import rescale_intensity +from skimage.measure import label, regionprops +from torch import Tensor + +# from viscy.unet.networks.Unet25D import Unet25d +from viscy.data.hcs import Sample +from viscy.unet.networks.Unet2D import Unet2d + +# +# %% Methods to compute confusion matrix per cell using torchmetrics + + +# The confusion matrix at the single-cell resolution. +def confusion_matrix_per_cell( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Compute confusion matrix per cell. + + Args: + y_true (torch.Tensor): Ground truth label image (BXHXW). + y_pred (torch.Tensor): Predicted label image (BXHXW). + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Confusion matrix per cell (BXCXC). + """ + # Convert the image class to the nuclei class + confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes) + confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell) + return confusion_matrix_per_cell + + +# confusion matrix computation +def compute_confusion_matrix( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Convert the class of the image to the class of the nuclei. + + Args: + label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Label images with a consensus class at the centroid of nuclei. + """ + + batch_size = y_true.size(0) + # find centroids of nuclei from y_true + conf_mat = np.zeros((num_classes, num_classes)) + for i in range(batch_size): + y_true_cpu = y_true[i].cpu().numpy() + y_pred_cpu = y_pred[i].cpu().numpy() + y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) + y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) + + y_pred_resized = cv2.resize( + y_pred_reshaped, + dsize=y_true_reshaped.shape[::-1], + interpolation=cv2.INTER_NEAREST, + ) + y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) + + # find objects in every image + label_img = label(y_true_reshaped) + regions = regionprops(label_img) + + # Find centroids, pixel coordinates from the ground truth. + for region in regions: + if region.area > 0: + row, col = region.centroid + pred_id = y_pred_resized[int(row), int(col)] + test_id = y_true_reshaped[int(row), int(col)] + + if pred_id == 1 and test_id == 1: + conf_mat[1, 1] += 1 + if pred_id == 1 and test_id == 2: + conf_mat[0, 1] += 1 + if pred_id == 2 and test_id == 1: + conf_mat[1, 0] += 1 + if pred_id == 2 and test_id == 2: + conf_mat[0, 0] += 1 + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + return conf_mat + + +# plot the computed confusion matrix +def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + index_to_label_dict = dict( + enumerate(index_to_label_dict) + ) # Convert list to dictionary + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + return fig + + +class SemanticSegUNet2D(pl.LightningModule): + + # Model for semantic segmentation. + + def __init__( + self, + in_channels: int, # Number of input channels + out_channels: int, # Number of output channels + lr: float = 1e-4, # Learning rate + loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function + schedule: Literal[ + "WarmupCosine", "Constant" + ] = "Constant", # Learning rate schedule + log_batches_per_epoch: int = 2, # Number of batches to log per epoch + log_samples_per_batch: int = 2, # Number of samples to log per batch + ckpt_path: str = None, # Path to the checkpoint + ): + super(SemanticSegUNet2D, self).__init__() # Call the superclass initializer + # Initialize the UNet model + self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights + self.lr = lr # Set the learning rate + # Set the loss function to CrossEntropyLoss if none is provided + self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() + self.schedule = schedule # Set the learning rate schedule + self.log_batches_per_epoch = ( + log_batches_per_epoch # Set the number of batches to log per epoch + ) + self.log_samples_per_batch = ( + log_samples_per_batch # Set the number of samples to log per batch + ) + self.training_step_outputs = [] # Initialize the list of training step outputs + self.validation_step_outputs = ( + [] + ) # Initialize the list of validation step outputs + + self.pred_cm = None # Initialize the confusion matrix + self.index_to_label_dict = ["Infected", "Uninfected"] + + # Define the forward pass + def forward(self, x): + return self.unet_model(x) # Pass the input through the UNet model + + # Define the optimizer + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr + ) # Use the Adam optimizer + return optimizer + + # Define the training step + def training_step(self, batch: Sample, batch_idx: int): + """ + The training step for the model. + This method is called for every batch during the training process. + """ + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the training step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the training loss + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return train_loss # Return the training loss + + def validation_step(self, batch: Sample, batch_idx: int): + """ + The validation step for the model. + """ + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the validation step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the validation loss + self.log( + "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True + ) + return loss # Return the validation loss + + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + # Define the prediction step + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse( + self.forward(source) + ) # Predict and remove padding. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + # Go from probabilities/one-hot encoded data to class labels. + labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels + + return labels_pred # log the class predicted image + + def on_test_start(self): + self.pred_cm = torch.zeros((2, 2)) + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + def test_step(self, batch: Sample): + source = self._predict_pad(batch["source"]) # Pad the source + # predict_writer(batch["source"], f"test_source_{self.i_num}.npy") + logits = self._predict_pad.inverse(self.forward(source)) + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels + + target = self._predict_pad(batch["target"]) # Extract the target from the batch + pred_cm = confusion_matrix_per_cell( + target, labels_pred, num_classes=2 + ) # Calculate the confusion matrix per cell + self.pred_cm += pred_cm # Append the confusion matrix to pred_cm + + self.logger.experiment.add_figure( + "Confusion Matrix per Cell", + plot_confusion_matrix(pred_cm, self.index_to_label_dict), + self.current_epoch, + ) + + # Accumulate the confusion matrix at the end of test epoch and log. + def on_test_end(self): + confusion_matrix_sum = self.pred_cm + self.logger.experiment.add_figure( + "Confusion Matrix", + plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), + self.current_epoch, + ) + + # Define what happens at the end of a training epoch + def on_train_epoch_end(self): + self._log_samples( + "train_samples", self.training_step_outputs + ) # Log the training samples + self.training_step_outputs = [] # Reset the list of training step outputs + + # Define what happens at the end of a validation epoch + def on_validation_epoch_end(self): + self._log_samples( + "val_samples", self.validation_step_outputs + ) # Log the validation samples + self.validation_step_outputs = [] # Reset the list of validation step outputs + + # Define a method to detach a sample + def _detach_sample(self, imgs: Sequence[Tensor]): + # Detach the images and convert them to numpy arrays + num_samples = 3 + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] + + # Define a method to log samples + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + images_grid = [] # Initialize the list of image grids + for sample_images in imgs: # For each sample image + images_row = [] # Initialize the list of image rows + for i, image in enumerate( + sample_images + ): # For each image in the sample images + cm_name = "gray" if i == 0 else "inferno" # Set the colormap name + if image.ndim == 2: # If the image is 2D + image = image[np.newaxis] # Add a new axis + for channel in image: # For each channel in the image + channel = rescale_intensity( + channel, out_range=(0, 1) + ) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[ + ..., :3 + ] # Render the channel + images_row.append( + render + ) # Append the render to the list of image rows + images_grid.append( + np.concatenate(images_row, axis=1) + ) # Append the concatenated image rows to the list of image grids + grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids + # Log the image grid + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + + +# %% diff --git a/applications/infection_classification/classify_infection_covnext.py b/applications/infection_classification/classify_infection_covnext.py new file mode 100644 index 00000000..5eddb236 --- /dev/null +++ b/applications/infection_classification/classify_infection_covnext.py @@ -0,0 +1,363 @@ +# import torchview +from typing import Literal, Sequence + +import cv2 +import lightning.pytorch as pl +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from matplotlib.cm import get_cmap +from monai.transforms import DivisiblePad +from skimage.exposure import rescale_intensity +from skimage.measure import label, regionprops +from torch import Tensor + +from viscy.data.hcs import Sample +from viscy.light.engine import VSUNet + +# +# %% Methods to compute confusion matrix per cell using torchmetrics + + +# The confusion matrix at the single-cell resolution. +def confusion_matrix_per_cell( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Compute confusion matrix per cell. + + Args: + y_true (torch.Tensor): Ground truth label image (BXHXW). + y_pred (torch.Tensor): Predicted label image (BXHXW). + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Confusion matrix per cell (BXCXC). + """ + # Convert the image class to the nuclei class + confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes) + confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell) + return confusion_matrix_per_cell + + +# These images can be logged with prediction. +def compute_confusion_matrix( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Convert the class of the image to the class of the nuclei. + + Args: + label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Label images with a consensus class at the centroid of nuclei. + """ + + batch_size = y_true.size(0) + # find centroids of nuclei from y_true + conf_mat = np.zeros((num_classes, num_classes)) + for i in range(batch_size): + y_true_cpu = y_true[i].cpu().numpy() + y_pred_cpu = y_pred[i].cpu().numpy() + y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) + y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) + y_pred_resized = cv2.resize( + y_pred_reshaped, + dsize=y_true_reshaped.shape[::-1], + interpolation=cv2.INTER_NEAREST, + ) + y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) + + # find objects in every image + label_img = label(y_true_reshaped) + regions = regionprops(label_img) + + # Find centroids, pixel coordinates from the ground truth. + for region in regions: + if region.area > 0: + row, col = region.centroid + pred_id = y_pred_resized[int(row), int(col)] + test_id = y_true_reshaped[int(row), int(col)] + + if pred_id == 1 and test_id == 1: + conf_mat[1, 1] += 1 + if pred_id == 1 and test_id == 2: + conf_mat[0, 1] += 1 + if pred_id == 2 and test_id == 1: + conf_mat[1, 0] += 1 + if pred_id == 2 and test_id == 2: + conf_mat[0, 0] += 1 + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + return conf_mat + + +def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + index_to_label_dict = dict( + enumerate(index_to_label_dict) + ) # Convert list to dictionary + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + # plt.show(fig) # Show the figure + return fig + + +# Define a 25d unet model for infection classification as a lightning module. + + +class SemanticSegUNet22D(pl.LightningModule): + # Model for semantic segmentation. + def __init__( + self, + in_channels: int, # Number of input channels + out_channels: int, # Number of output channels + lr: float = 1e-3, # Learning rate + loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function + schedule: Literal[ + "WarmupCosine", "Constant" + ] = "Constant", # Learning rate schedule + log_batches_per_epoch: int = 2, # Number of batches to log per epoch + log_samples_per_batch: int = 2, # Number of samples to log per batch + ckpt_path: str = None, # Path to the checkpoint + ): + super(SemanticSegUNet22D, self).__init__() # Call the superclass initializer + # Initialize the UNet model + self.unet_model = VSUNet( + architecture="2.2D", + model_config={ + "in_channels": in_channels, + "out_channels": out_channels, + "in_stack_depth": 5, + "backbone": "convnextv2_tiny", + "stem_kernel_size": (5, 4, 4), + "decoder_mode": "pixelshuffle", + "head_expansion_ratio": 4, + }, + ) + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights + self.lr = lr # Set the learning rate + # Set the loss function to CrossEntropyLoss if none is provided + self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() + self.schedule = schedule # Set the learning rate schedule + self.log_batches_per_epoch = ( + log_batches_per_epoch # Set the number of batches to log per epoch + ) + self.log_samples_per_batch = ( + log_samples_per_batch # Set the number of samples to log per batch + ) + self.training_step_outputs = [] # Initialize the list of training step outputs + self.validation_step_outputs = ( + [] + ) # Initialize the list of validation step outputs + + self.pred_cm = None # Initialize the confusion matrix + self.index_to_label_dict = ["Infected", "Uninfected"] + + # Define the forward pass + def forward(self, x): + return self.unet_model(x) # Pass the input through the UNet model + + # Define the optimizer + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr + ) # Use the Adam optimizer + return optimizer + + # Define the training step + def training_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the training step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the training loss + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return train_loss # Return the training loss + + def validation_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the validation step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the validation loss + self.log( + "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True + ) + return loss # Return the validation loss + + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + # Define the prediction step + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse( + self.forward(source) + ) # Predict and remove padding. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + # Go from probabilities/one-hot encoded data to class labels. + labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels + # prob_chan = prob_pred[:, 2, :, :] + # prob_chan = prob_chan.unsqueeze(1) + return labels_pred # log the class predicted image + # return prob_chan # log the probability predicted image + + def on_test_start(self): + self.pred_cm = torch.zeros((2, 2)) + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + def test_step(self, batch: Sample): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse(self.forward(source)) + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels + + target = self._predict_pad(batch["target"]) # Extract the target from the batch + pred_cm = confusion_matrix_per_cell( + target, labels_pred, num_classes=2 + ) # Calculate the confusion matrix per cell + self.pred_cm += pred_cm # Append the confusion matrix to pred_cm + + self.logger.experiment.add_figure( + "Confusion Matrix per Cell", + plot_confusion_matrix(pred_cm, self.index_to_label_dict), + self.current_epoch, + ) + + # Accumulate the confusion matrix at the end of test epoch and log. + def on_test_end(self): + confusion_matrix_sum = self.pred_cm + self.logger.experiment.add_figure( + "Confusion Matrix", + plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), + self.current_epoch, + ) + + # Define what happens at the end of a training epoch + def on_train_epoch_end(self): + self._log_samples( + "train_samples", self.training_step_outputs + ) # Log the training samples + self.training_step_outputs = [] # Reset the list of training step outputs + + # Define what happens at the end of a validation epoch + def on_validation_epoch_end(self): + self._log_samples( + "val_samples", self.validation_step_outputs + ) # Log the validation samples + self.validation_step_outputs = [] # Reset the list of validation step outputs + + # Define a method to detach a sample + def _detach_sample(self, imgs: Sequence[Tensor]): + # Detach the images and convert them to numpy arrays + num_samples = 3 + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] + + # Define a method to log samples + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + images_grid = [] # Initialize the list of image grids + for sample_images in imgs: # For each sample image + images_row = [] # Initialize the list of image rows + for i, image in enumerate( + sample_images + ): # For each image in the sample images + cm_name = "gray" if i == 0 else "inferno" # Set the colormap name + if image.ndim == 2: # If the image is 2D + image = image[np.newaxis] # Add a new axis + for channel in image: # For each channel in the image + channel = rescale_intensity( + channel, out_range=(0, 1) + ) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[ + ..., :3 + ] # Render the channel + images_row.append( + render + ) # Append the render to the list of image rows + images_grid.append( + np.concatenate(images_row, axis=1) + ) # Append the concatenated image rows to the list of image grids + grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids + # Log the image grid + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + + +# %% diff --git a/applications/infection_classification/infection_classifier_testing.py b/applications/infection_classification/infection_classifier_testing.py new file mode 100644 index 00000000..fea8326d --- /dev/null +++ b/applications/infection_classification/infection_classifier_testing.py @@ -0,0 +1,59 @@ +# %% +import lightning.pytorch as pl +from applications.infection_classification.classify_infection_2D import ( + SemanticSegUNet2D, +) +from pytorch_lightning.loggers import TensorBoardLogger + +from viscy.data.hcs import HCSDataModule +from viscy.transforms import NormalizeSampled + +# %% test the model on the test set +test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/4-human_annotation/test_data.zarr" + +data_module = HCSDataModule( + data_path=test_datapath, + source_channel=["TXR_Density3D", "Phase3D"], + target_channel=["Inf_mask"], + split_ratio=0.7, + z_window_size=1, + architecture="2D", + num_workers=1, + batch_size=1, + normalizations=[ + NormalizeSampled( + keys=["TXR_Density3D", "Phase3D"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], +) + +data_module.setup(stage="test") + +# %% create trainer and input + +logger = TensorBoardLogger( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/", + name="logs", +) + +trainer = pl.Trainer( + logger=logger, + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs", + log_every_n_steps=1, + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +) + +model = SemanticSegUNet2D( + in_channels=2, + out_channels=3, + ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/checkpoint_epoch=206.ckpt", +) + +# %% test the model + +trainer.test(model=model, datamodule=data_module) + +# %% diff --git a/applications/infection_classification/predict_infection_classifier.py b/applications/infection_classification/predict_infection_classifier.py new file mode 100644 index 00000000..458fc670 --- /dev/null +++ b/applications/infection_classification/predict_infection_classifier.py @@ -0,0 +1,59 @@ +# %% + +import lightning.pytorch as pl +from applications.infection_classification.classify_infection_2D import ( + SemanticSegUNet2D, +) + +from viscy.data.hcs import HCSDataModule +from viscy.light.predict_writer import HCSPredictionWriter +from viscy.transforms import NormalizeSampled + +# %% # %% write the predictions to a zarr file + +pred_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/A549_63X/2024_02_04_A549_DENV_ZIKV_timelapse/0-train_test_data/2024_02_04_A549_DENV_ZIKV_timelapse_test_2D.zarr" + +data_module = HCSDataModule( + data_path=pred_datapath, + source_channel=["RFP", "Phase3D"], + target_channel=["Inf_mask"], + split_ratio=0.7, + z_window_size=1, + architecture="2D", + num_workers=1, + batch_size=1, + normalizations=[ + NormalizeSampled( + keys=["RFP", "Phase3D"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], +) + +data_module.setup(stage="predict") + +model = SemanticSegUNet2D( + in_channels=2, + out_channels=3, + ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/A549_63X/2024_02_04_A549_DENV_ZIKV_timelapse/1-model_train/logs/version_0/checkpoints/epoch=199-step=800.ckpt", +) + +# %% perform prediction + +output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/A549_63X/2024_02_04_A549_DENV_ZIKV_timelapse/2-predict_infection/2024_02_04_A549_DENV_ZIKV_timelapse_pred_2D_new.zarr" + +trainer = pl.Trainer( + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/A549_63X/2024_02_04_A549_DENV_ZIKV_timelapse/1-model_train/logs", + callbacks=[HCSPredictionWriter(output_path, write_input=False)], + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +) + +trainer.predict( + model=model, + datamodule=data_module, + return_predictions=False, +) + +# %% diff --git a/applications/infection_classification/readme.md b/applications/infection_classification/readme.md new file mode 100644 index 00000000..a5d317b7 --- /dev/null +++ b/applications/infection_classification/readme.md @@ -0,0 +1,13 @@ +# Infection Classification Model + +This repository contains the code for developing the infection classification model used in the infection phenotyping project. Infection classification models can be trained on human annotated ground truth with fluorescence sensor channel and phase channel to predict the state of infection of single cells. The pixels are predicted to be background (class 0), uninfected (class 1) or infected (class 2) by the model. + +## Overview + +The following scripts are available: + +Training: `infection_classification_*model.py` file implements a machine learning model for classifying infections based on various features. The model is trained on a labeled dataset, with fluorescence and label-free images. + +Testing: `infection_classifier_testing.py` file tests the 2D infection classification model trained on a 2D dataset. + +Prediction: `predict_classifier_testing.py` is an example script to perform prediction using 2D data and 2D model. It can be used to predict the infection type for new samples. diff --git a/pyproject.toml b/pyproject.toml index 7f48ea7e..ecb0c3d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,9 @@ dev = [ [project.scripts] viscy = "viscy.cli.cli:main" +[tool.setuptools] +packages = ["viscy"] + [tool.setuptools_scm] write_to = "viscy/_version.py" diff --git a/tests/preprocessing/test_pixel_ratio.py b/tests/preprocessing/test_pixel_ratio.py new file mode 100644 index 00000000..2dce7afe --- /dev/null +++ b/tests/preprocessing/test_pixel_ratio.py @@ -0,0 +1,15 @@ +from numpy.testing import assert_allclose + +from viscy.preprocessing.pixel_ratio import sematic_class_weights + + +def test_sematic_class_weights(small_hcs_dataset): + weights = sematic_class_weights(small_hcs_dataset, "GFP") + assert weights.shape == (3,) + assert_allclose(weights[0], 1.0) + # infinity + assert weights[1] > 1.0 + assert weights[2] > 1.0 + assert sematic_class_weights( + small_hcs_dataset, "GFP", num_classes=2 + ).shape == (2,) diff --git a/viscy/preprocessing/pixel_ratio.py b/viscy/preprocessing/pixel_ratio.py new file mode 100644 index 00000000..29c2ed41 --- /dev/null +++ b/viscy/preprocessing/pixel_ratio.py @@ -0,0 +1,22 @@ +import dask.array as da +from iohub.ngff import open_ome_zarr +from numpy.typing import NDArray + + +def sematic_class_weights( + dataset_path: str, target_channel: str, num_classes: int = 3 +) -> NDArray: + """Computes class balancing weights for semantic segmentation. + The weights can be used for cross-entropy loss. + + :param str dataset_path: HCS OME-Zarr dataset path + :param str target_channel: target channel name + :param int num_classes: number of classes + :return NDArray: inverted ratio of background, uninfected and infected pixels + """ + dataset = open_ome_zarr(dataset_path) + arrays = [da.from_zarr(pos["0"]) for _, pos in dataset.positions()] + imgs = da.stack(arrays, axis=0)[:, :, dataset.get_channel_index(target_channel)] + ratio, _ = da.histogram(imgs, bins=range(num_classes + 1), density=True) + weights = 1 / ratio + return weights.compute()