Cellular infection phenotyping using annotated viral sensor data & label-free images (#70)

* refactor data loading into its own module
* update type annotations
* move the logging module out
* move old logging into utils
* rename tests to match module name
* bump torch
* draft FCMAE encoder
* add stem to the encoder
* wip: masked stem layernorm
* wip: patchify masked features for linear
* use MLP from timm
* hack: POC training script for FCMAE
* fix mask for fitting
* remove training script
* default architecture
* fine-tuning options
* fix CLI for fine-tuning
* draft combined data module
* fix import
* manual validation loss reduction
* update linting (the new black version has different rules)
* update development guide
* update type hints
* bump iohub
* draft CTMC v1 dataset
* update tests
* move test_data
* remove path conversion
* configurable normalizations (#68)
  * initial commit adding the normalization
  * add dataset_statistics to each FOV to facilitate the configurable augmentations
  * fix indentation
  * ruff
  * test preprocessing
  * remove redundant field
  * cleanup
  * Co-authored-by: Ziwen Liu <[email protected]>
* fix CTMC dataloading
* add example CTMC v1 loading script
* change the normalization and augmentation defaults from None to an empty list
* invert intensity transform
* concatenated data module
* subsample videos
* LIVECell dataset
* all sample fields are optional
* fix multi-dataloader validation
* lint
* fix preprocessing for varying array shapes (i.e. the AICS dataset)
* update loading scripts
* fix CombineMode
* added model and annotation code draft
* changed to a simple UNet model
* start with fewer augmentations
* added README file
* added TensorBoard logging
* added validation step
* changed to the VisCy 2D UNet
* used CrossEntropyLoss with one-hot encoding
* added sample image logging
* attempt to build a magicgui annotation tool
* renamed the infection annotation tool
* added normalization and augmentations
* added model testing code
* removed the annotation refiner
* corrected conversion of class labels to int
* corrected the prediction module
* cleaned up the code and comments for the LightningUNet
* removed confusion matrix code that raised a runtime error with the model
* moved scripts to the viscy.scripts.infection_phenotyping module to enable imports across scripts
* combine the Lightning modules for training and prediction, fix the DDP exception
* add all the stubs for computing and logging a confusion matrix per cell
* separated training and test scripts
* lightning module
* corrected the test confusion-matrix computation
* corrected the test module
* separated test and prediction scripts
* changed the confusion matrix computation
* fix merge error
* split 2D and 2.5D model scripts
* added covnext script
* fix model input parameter
* update input file
* add augmentations
* refactor infection_classification code to viscy/applications
* changes made for BJ5 classification
* format code
* add explicit packaging list
* rename testing script
* update readme
* move function to preprocessing
* format code
* formatting
* histogram with dask
* fix index and test
* fix import
* black
* fix float comparison
* clean up headers
* clean up imports
* add argument to change the number of classes

Co-authored-by: Ziwen Liu <[email protected]>
Co-authored-by: Eduardo Hirata-Miyasaki <[email protected]>
Co-authored-by: Shalin Mehta <[email protected]>
Co-authored-by: Ziwen Liu <[email protected]>
1 parent dbf4ddc, commit 4401f33. Showing 12 changed files with 1,588 additions and 0 deletions.
applications/infection_classification/Infection_classification_25DModel.py (106 additions, 0 deletions)
# %%
import lightning.pytorch as pl
import torch
import torch.nn as nn
from applications.infection_classification.classify_infection_25D import (
    SemanticSegUNet25D,
)
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from viscy.data.hcs import HCSDataModule
from viscy.preprocessing.pixel_ratio import sematic_class_weights
from viscy.transforms import NormalizeSampled, RandWeightedCropd

# %% Create a dataloader and visualize the batches.

# Set the path to the dataset
dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr"

# %% Create the data module

# Create an instance of HCSDataModule
data_module = HCSDataModule(
    dataset_path,
    source_channel=["Phase", "HSP90"],
    target_channel=["Inf_mask"],
    yx_patch_size=[512, 512],
    split_ratio=0.8,
    z_window_size=5,
    architecture="2.5D",
    num_workers=3,
    batch_size=32,
    normalizations=[
        NormalizeSampled(
            keys=["Phase", "HSP90"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=4,
            spatial_size=[-1, 512, 512],
            keys=["Phase", "HSP90"],
            w_key="Inf_mask",
        )
    ],
)

# Compute class weights from the pixel ratio of the annotation channel
pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask")

# Prepare the data
data_module.prepare_data()

# Setup the data
data_module.setup(stage="fit")

# Create the dataloaders
train_dm = data_module.train_dataloader()
val_dm = data_module.val_dataloader()

# %% Define the logger
logger = TensorBoardLogger(
    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/",
    name="logs",
)

# Pass the logger to the Trainer
trainer = pl.Trainer(
    logger=logger,
    max_epochs=200,
    default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
    log_every_n_steps=1,
    devices=1,  # Use a single GPU to avoid a runtime exception from distributed training when the node has multiple GPUs
)

# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
    filename="checkpoint_{epoch:02d}",
    save_top_k=-1,
    verbose=True,
    monitor="loss/validate",
    mode="min",
)

# Add the checkpoint callback to the trainer
trainer.callbacks.append(checkpoint_callback)

# Instantiate the model with class-weighted cross-entropy loss
model = SemanticSegUNet25D(
    in_channels=2,
    out_channels=3,
    loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)),
)

print(model)

# %% Run training.

trainer.fit(model, data_module)

# %%
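A quick post-training check on one validation batch can confirm that the model produces per-pixel class labels. This is a minimal sketch, not part of this commit; the "source" key, the tensor layout, and the class ordering (background, uninfected, infected) are assumptions based on how HCSDataModule and the loss are configured above.

model.eval()
with torch.inference_mode():
    batch = next(iter(val_dm))  # assumed: dict with a "source" tensor of roughly (B, 2, 5, 512, 512)
    logits = model(batch["source"].to(model.device))  # assuming forward takes the source stack and returns per-class logits
    prediction = torch.argmax(logits, dim=1)  # per-pixel class index (assumed order: background, uninfected, infected)
print(prediction.shape, prediction.unique())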
applications/infection_classification/Infection_classification_2Dmodel.py (123 additions, 0 deletions)
# %%
import lightning.pytorch as pl
import torch
import torch.nn as nn
from applications.infection_classification.classify_infection_2D import (
    SemanticSegUNet2D,
)
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from viscy.data.hcs import HCSDataModule
from viscy.preprocessing.pixel_ratio import sematic_class_weights
from viscy.transforms import (
    NormalizeSampled,
    RandGaussianSmoothd,
    RandScaleIntensityd,
    RandWeightedCropd,
)

# %% Calculate the ratio of background, uninfected, and infected pixels in the input dataset

# Set the path to the dataset
dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/4-human_annotation/train_data.zarr"

# %% Create an instance of HCSDataModule

data_module = HCSDataModule(
    dataset_path,
    source_channel=["TXR_Density3D", "Phase3D"],
    target_channel=["Inf_mask"],
    yx_patch_size=[128, 128],
    split_ratio=0.7,
    z_window_size=1,
    architecture="2D",
    num_workers=1,
    batch_size=256,
    normalizations=[
        NormalizeSampled(
            keys=["Phase3D", "TXR_Density3D"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=16,
            spatial_size=[-1, 128, 128],
            keys=["TXR_Density3D", "Phase3D", "Inf_mask"],
            w_key="Inf_mask",
        ),
        RandScaleIntensityd(
            keys=["TXR_Density3D", "Phase3D"],
            factors=[0.5, 0.5],
            prob=0.5,
        ),
        RandGaussianSmoothd(
            keys=["TXR_Density3D", "Phase3D"],
            prob=0.5,
            sigma_x=[0.5, 1.0],
            sigma_y=[0.5, 1.0],
            sigma_z=[0.5, 1.0],
        ),
    ],
)

# Compute class weights from the pixel ratio of the annotation channel
pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask")

# Prepare the data
data_module.prepare_data()

# Setup the data
data_module.setup(stage="fit")

# Create the dataloaders
train_dm = data_module.train_dataloader()
val_dm = data_module.val_dataloader()

# %% Set up for training

# Define the logger
logger = TensorBoardLogger(
    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/",
    name="logs",
)

# Pass the logger to the Trainer
trainer = pl.Trainer(
    logger=logger,
    max_epochs=500,
    default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/",
    log_every_n_steps=1,
    devices=1,  # Use a single GPU to avoid a runtime exception from distributed training when the node has multiple GPUs
)

# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/",
    filename="checkpoint_{epoch:02d}",
    save_top_k=-1,
    verbose=True,
    monitor="loss/validate",
    mode="min",
)

# Add the checkpoint callback to the trainer
trainer.callbacks.append(checkpoint_callback)

# Instantiate the model with class-weighted cross-entropy loss
model = SemanticSegUNet2D(
    in_channels=2,
    out_channels=3,
    loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)),
)

# Visualize the model
print(model)

# %% Run training.

trainer.fit(model, data_module)

# %%
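Before the trainer.fit call above kicks off a long run, it can help to pull a single batch and confirm the tensor layout and the class weights. This is a minimal sketch, not part of this commit; the "source" and "target" keys and the shapes are assumptions about the HCSDataModule batch format.

batch = next(iter(train_dm))
print("source:", batch["source"].shape)  # expected roughly (B, 2, 1, 128, 128) for this 2D configuration
print("target:", batch["target"].shape)
print("class weights:", torch.tensor(pixel_ratio))  # one weight per class, as passed to CrossEntropyLoss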
applications/infection_classification/Infection_classification_covnextModel.py (107 additions, 0 deletions)
# %%
# import sys
# sys.path.append("/hpc/mydata/soorya.pradeep/viscy_infection_phenotyping/Viscy/")
import lightning.pytorch as pl
import torch
import torch.nn as nn
from applications.infection_classification.classify_infection_covnext import (
    SemanticSegUNet22D,
)
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from viscy.data.hcs import HCSDataModule
from viscy.preprocessing.pixel_ratio import sematic_class_weights
from viscy.transforms import NormalizeSampled, RandWeightedCropd

# %% Create a dataloader and visualize the batches.

# Set the path to the dataset
dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_all_curated_train.zarr"

# %% Create the data module

# Create an instance of HCSDataModule
data_module = HCSDataModule(
    dataset_path,
    source_channel=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
    target_channel=["Inf_mask"],
    yx_patch_size=[256, 256],
    split_ratio=0.8,
    z_window_size=5,
    architecture="2.2D",
    num_workers=3,
    batch_size=16,
    normalizations=[
        NormalizeSampled(
            keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=4,
            spatial_size=[-1, 256, 256],
            keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
            w_key="Inf_mask",
        )
    ],
)

# Compute class weights from the pixel ratio of the annotation channel
pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask")

# Prepare the data
data_module.prepare_data()

# Setup the data
data_module.setup(stage="fit")

# Create the dataloaders
train_dm = data_module.train_dataloader()
val_dm = data_module.val_dataloader()

# %% Define the logger
logger = TensorBoardLogger(
    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/",
    name="logs",
)

# Pass the logger to the Trainer
trainer = pl.Trainer(
    logger=logger,
    max_epochs=200,
    default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
    log_every_n_steps=1,
    devices=1,  # Use a single GPU to avoid a runtime exception from distributed training when the node has multiple GPUs
)

# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
    filename="checkpoint_{epoch:02d}",
    save_top_k=-1,
    verbose=True,
    monitor="loss/validate",
    mode="min",
)

# Add the checkpoint callback to the trainer
trainer.callbacks.append(checkpoint_callback)

# Instantiate the model with class-weighted cross-entropy loss
model = SemanticSegUNet22D(
    in_channels=4,
    out_channels=3,
    loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)),
)

print(model)

# %% Run training.

trainer.fit(model, data_module)

# %%
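If the 200-epoch run is interrupted, it can be resumed from any checkpoint written by the ModelCheckpoint callback by re-running the training cell with ckpt_path set. This is a minimal sketch, not part of this commit, and the checkpoint filename below is hypothetical.

trainer.fit(
    model,
    data_module,
    ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/checkpoint_epoch=99.ckpt",  # hypothetical filename
)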