Cellular infection phenotyping using annotated viral sensor data & label-free images (#70)

* refactor data loading into its own module

* update type annotations

* move the logging module out

* move old logging into utils

* rename tests to match module name

* bump torch

* draft fcmae encoder

* add stem to the encoder

* wip: masked stem layernorm

* wip: patchify masked features for linear

* use mlp from timm

* hack: POC training script for FCMAE

* fix mask for fitting

* remove training script

* default architecture

* fine-tuning options

* fix cli for finetuning

* draft combined data module

* fix import

* manual validation loss reduction

* update linting
new black version has different rules

* update development guide

* update type hints

* bump iohub

* draft ctmc v1 dataset

* update tests

* move test_data

* remove path conversion

* configurable normalizations (#68)

* initial commit adding the normalization.

* adding dataset_statistics to each fov to facilitate the configurable augmentations

* fix indentation

* ruff

* test preprocessing

* remove redundant field

* cleanup

---------

Co-authored-by: Ziwen Liu <[email protected]>

* fix ctmc dataloading

* add example ctmc v1 loading script

* changing the normalization and augmentations default from None to empty list.

* invert intensity transform

* concatenated data module

* subsample videos

* livecell dataset

* all sample fields are optional

* fix multi-dataloader validation

* lint

* fixing preprocessing for varying array shapes (i.e aics dataset)

* update loading scripts

* fix CombineMode

* added model and annotation code draft

* changed to simple unet model

* start with lesser augmentations

* added readme file

* added tensorboard logging

* added validation step

* changed to viscy 2d unet

* used crossentropyloss with one-hot encoding

* added sample image logging

* attempt to build magicgui annotation

* renamed infection annotation tool

* added normalization and augmentations

* added model testing code

* removed annotation refiner

* corrected conversion of class to int

* corrected prediction module

* cleaned up the code and comments for the LightningUNet

* removed confusion matrix code, finding runtime error with model

* moved scripts to viscy.scripts.infection_phenotyping module to enable imports across scripts

* combine the lightning modules for training and prediction, fix the DDP exception

* all the stubs for computing and logging confusion matrix per cell

* separated training and test scripts

* lightning module

* corrected test cm compute

* corrected test module

* separated test and prediction scripts

* changed confusion matrix compute

* fix merge error

* split 2D and 2.5D model scripts

* added covnext script

* fix model input parameter

* update input file

* add augmentations

* refactor infection_classification code to viscy/applications

* changes made for BJ5 classification

* format code

* add explicit packaging list

* rename testing script

* update readme

* move function to preprocessing

* format code

* formatting

* histogram with dask

* fix index and test

* fix import

* black

* fix float comp

* clean up headers

* clean up import

* add argument to change number of classes

---------

Co-authored-by: Ziwen Liu <[email protected]>
Co-authored-by: Eduardo Hirata-Miyasaki <[email protected]>
Co-authored-by: Shalin Mehta <[email protected]>
Co-authored-by: Ziwen Liu <[email protected]>
5 people authored Jul 10, 2024
1 parent dbf4ddc commit 4401f33
Showing 12 changed files with 1,588 additions and 0 deletions.
@@ -0,0 +1,106 @@
# %%
import lightning.pytorch as pl
import torch
import torch.nn as nn
from applications.infection_classification.classify_infection_25D import (
    SemanticSegUNet25D,
)
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from viscy.data.hcs import HCSDataModule
from viscy.preprocessing.pixel_ratio import sematic_class_weights
from viscy.transforms import NormalizeSampled, RandWeightedCropd

# %% Create a dataloader and visualize the batches.

# Set the path to the dataset
dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr"

# %% create data module

# Create an instance of HCSDataModule
data_module = HCSDataModule(
    dataset_path,
    source_channel=["Phase", "HSP90"],
    target_channel=["Inf_mask"],
    yx_patch_size=[512, 512],
    split_ratio=0.8,
    z_window_size=5,
    architecture="2.5D",
    num_workers=3,
    batch_size=32,
    normalizations=[
        NormalizeSampled(
            keys=["Phase", "HSP90"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=4,
            spatial_size=[-1, 512, 512],
            keys=["Phase", "HSP90"],
            w_key="Inf_mask",
        )
    ],
)

pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask")

# Prepare the data
data_module.prepare_data()

# Setup the data
data_module.setup(stage="fit")

# Create a dataloader
train_dm = data_module.train_dataloader()

val_dm = data_module.val_dataloader()


# %% Define the logger
logger = TensorBoardLogger(
    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/",
    name="logs",
)

# Pass the logger to the Trainer
trainer = pl.Trainer(
    logger=logger,
    max_epochs=200,
    default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
    log_every_n_steps=1,
    devices=1,  # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs
)

# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
    filename="checkpoint_{epoch:02d}",
    save_top_k=-1,
    verbose=True,
    monitor="loss/validate",
    mode="min",
)

# Add the checkpoint callback to the trainer
trainer.callbacks.append(checkpoint_callback)

# Define the model with a class-weighted cross-entropy loss
model = SemanticSegUNet25D(
    in_channels=2,
    out_channels=3,
    loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)),
)

print(model)

# %% Run training.

trainer.fit(model, data_module)

# %%
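
The script above passes the output of `sematic_class_weights` to `nn.CrossEntropyLoss` as per-class weights, but the body of that function is not part of the files shown here. The following is only a minimal sketch of one common way to derive such weights, inverse pixel-frequency weighting. It is not claimed to be what `viscy.preprocessing.pixel_ratio` does. The function name `inverse_frequency_weights`, the toy mask, and the label convention (0 = background, 1 = uninfected, 2 = infected) are assumptions made for illustration.

```python
from collections import Counter


def inverse_frequency_weights(mask, num_classes=3):
    """Return one weight per class, proportional to 1 / pixel fraction.

    Hypothetical helper for illustration; not the viscy implementation.
    """
    counts = Counter(label for row in mask for label in row)
    total = sum(counts.values())
    # Fraction of pixels that belong to each class.
    fractions = [counts.get(c, 0) / total for c in range(num_classes)]
    # Rare classes get large weights; absent classes get weight 0.
    raw = [1.0 / f if f > 0 else 0.0 for f in fractions]
    norm = sum(raw)
    return [w / norm for w in raw]


# Toy 2x4 mask (assumed labels: 0 = background, 1 = uninfected, 2 = infected).
toy_mask = [
    [0, 0, 0, 0],
    [0, 0, 1, 2],
]
weights = inverse_frequency_weights(toy_mask)
print(weights)
```

A list like `weights` could then be wrapped in `torch.tensor(...)` and handed to the loss, in the same way `pixel_ratio` is used above. Weighting by inverse frequency keeps the dominant background class from swamping the loss.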
@@ -0,0 +1,123 @@
# %%
import lightning.pytorch as pl
import torch
import torch.nn as nn
from applications.infection_classification.classify_infection_2D import (
    SemanticSegUNet2D,
)
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from viscy.data.hcs import HCSDataModule
from viscy.preprocessing.pixel_ratio import sematic_class_weights
from viscy.transforms import (
    NormalizeSampled,
    RandGaussianSmoothd,
    RandScaleIntensityd,
    RandWeightedCropd,
)

# %% calculate the ratio of background, uninfected and infected pixels in the input dataset

# Set the path to the dataset
dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/4-human_annotation/train_data.zarr"

# %% Create an instance of HCSDataModule

data_module = HCSDataModule(
    dataset_path,
    source_channel=["TXR_Density3D", "Phase3D"],
    target_channel=["Inf_mask"],
    yx_patch_size=[128, 128],
    split_ratio=0.7,
    z_window_size=1,
    architecture="2D",
    num_workers=1,
    batch_size=256,
    normalizations=[
        NormalizeSampled(
            keys=["Phase3D", "TXR_Density3D"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=16,
            spatial_size=[-1, 128, 128],
            keys=["TXR_Density3D", "Phase3D", "Inf_mask"],
            w_key="Inf_mask",
        ),
        RandScaleIntensityd(
            keys=["TXR_Density3D", "Phase3D"],
            factors=[0.5, 0.5],
            prob=0.5,
        ),
        RandGaussianSmoothd(
            keys=["TXR_Density3D", "Phase3D"],
            prob=0.5,
            sigma_x=[0.5, 1.0],
            sigma_y=[0.5, 1.0],
            sigma_z=[0.5, 1.0],
        ),
    ],
)
pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask")

# Prepare the data
data_module.prepare_data()

# Setup the data
data_module.setup(stage="fit")

# Create a dataloader
train_dm = data_module.train_dataloader()

val_dm = data_module.val_dataloader()

# %% Set up for training

# define the logger
logger = TensorBoardLogger(
    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/",
    name="logs",
)

# Pass the logger to the Trainer
trainer = pl.Trainer(
    logger=logger,
    max_epochs=500,
    default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/",
    log_every_n_steps=1,
    devices=1,  # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs
)

# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/",
    filename="checkpoint_{epoch:02d}",
    save_top_k=-1,
    verbose=True,
    monitor="loss/validate",
    mode="min",
)

# Add the checkpoint callback to the trainer
trainer.callbacks.append(checkpoint_callback)

# Define the model with a class-weighted cross-entropy loss
model = SemanticSegUNet2D(
    in_channels=2,
    out_channels=3,
    loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)),
)

# visualize the model
print(model)

# %% Run training.

trainer.fit(model, data_module)

# %%
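
All three scripts configure `NormalizeSampled` with `subtrahend="median"` and `divisor="iqr"`. As a rough illustration of what that arithmetic amounts to, the sketch below normalizes a flat list of intensities by subtracting the median and dividing by the interquartile range. It assumes the statistics are computed directly from the values passed in, whereas the `level="fov_statistics"` setting suggests the real transform reads precomputed per-FOV statistics. The function name `normalize_median_iqr` is hypothetical.

```python
import statistics


def normalize_median_iqr(values):
    """Subtract the median and divide by the interquartile range (IQR).

    Hypothetical helper for illustration; not the viscy implementation.
    """
    median = statistics.median(values)
    # Quartiles Q1, Q2, Q3 using the inclusive (population-style) method.
    q1, _, q3 = statistics.quantiles(values, n=4, method="inclusive")
    iqr = q3 - q1
    if iqr == 0:
        raise ValueError("IQR is zero; cannot normalize a constant signal")
    return [(v - median) / iqr for v in values]


pixels = [1.0, 2.0, 3.0, 4.0, 5.0]
normalized = normalize_median_iqr(pixels)
print(normalized)  # [-1.0, -0.5, 0.0, 0.5, 1.0]
```

Median and IQR are less sensitive to a few very bright pixels than mean and standard deviation, which is a plausible reason to prefer them for fluorescence and phase channels.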
@@ -0,0 +1,107 @@
# %%
# import sys
# sys.path.append("/hpc/mydata/soorya.pradeep/viscy_infection_phenotyping/Viscy/")
import lightning.pytorch as pl
import torch
import torch.nn as nn
from applications.infection_classification.classify_infection_covnext import (
    SemanticSegUNet22D,
)
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from viscy.data.hcs import HCSDataModule
from viscy.preprocessing.pixel_ratio import sematic_class_weights
from viscy.transforms import NormalizeSampled, RandWeightedCropd

# %% Create a dataloader and visualize the batches.

# Set the path to the dataset
dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_all_curated_train.zarr"

# %% create data module

# Create an instance of HCSDataModule
data_module = HCSDataModule(
    dataset_path,
    source_channel=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
    target_channel=["Inf_mask"],
    yx_patch_size=[256, 256],
    split_ratio=0.8,
    z_window_size=5,
    architecture="2.2D",
    num_workers=3,
    batch_size=16,
    normalizations=[
        NormalizeSampled(
            keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=4,
            spatial_size=[-1, 256, 256],
            keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
            w_key="Inf_mask",
        )
    ],
)
pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask")

# Prepare the data
data_module.prepare_data()

# Setup the data
data_module.setup(stage="fit")

# Create a dataloader
train_dm = data_module.train_dataloader()

val_dm = data_module.val_dataloader()


# %% Define the logger
logger = TensorBoardLogger(
    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/",
    name="logs",
)

# Pass the logger to the Trainer
trainer = pl.Trainer(
    logger=logger,
    max_epochs=200,
    default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
    log_every_n_steps=1,
    devices=1,  # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs
)

# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
    filename="checkpoint_{epoch:02d}",
    save_top_k=-1,
    verbose=True,
    monitor="loss/validate",
    mode="min",
)

# Add the checkpoint callback to the trainer
trainer.callbacks.append(checkpoint_callback)

# Define the model with a class-weighted cross-entropy loss
model = SemanticSegUNet22D(
    in_channels=4,
    out_channels=3,
    loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)),
)

print(model)

# %% Run training.

trainer.fit(model, data_module)

# %%
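
Each script uses `RandWeightedCropd` with `w_key="Inf_mask"`, which, as the name suggests, draws crops preferentially from regions where the weight map is large. The sketch below illustrates only the core idea, sampling crop-center coordinates with probability proportional to a 2D weight map. It ignores border handling and the actual cropping. The function name `sample_weighted_centers`, the uniform fallback for an all-zero map, and the toy weight map are assumptions for illustration and do not describe the library's internals.

```python
import random


def sample_weighted_centers(weight_map, num_samples, seed=None):
    """Draw (y, x) crop centers with probability proportional to weight.

    Hypothetical helper for illustration; not the library implementation.
    """
    rng = random.Random(seed)
    coords = [
        (y, x) for y, row in enumerate(weight_map) for x in range(len(row))
    ]
    weights = [w for row in weight_map for w in row]
    if sum(weights) <= 0:
        # Assumed fallback: sample uniformly when the map carries no weight.
        return [rng.choice(coords) for _ in range(num_samples)]
    return rng.choices(coords, weights=weights, k=num_samples)


# Toy 3x4 weight map with a single labeled pixel at row 1, column 2.
toy_weights = [
    [0, 0, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 0],
]
centers = sample_weighted_centers(toy_weights, num_samples=4, seed=0)
print(centers)
```

With an infection mask as the weight map, sampling of this kind concentrates training patches on annotated cells instead of empty background.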