Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infection state #118

Merged
merged 27 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
db63713
updated prediction code
alishbaimran Jul 23, 2024
79f7b4e
updated predict code
alishbaimran Jul 23, 2024
36864a2
updated code
alishbaimran Jul 23, 2024
240293c
fixed the stem and forward pass (#115)
mattersoflight Jul 26, 2024
144ab18
WIP: Save progress before merging
alishbaimran Jul 27, 2024
4bccadd
updated predict code
alishbaimran Jul 27, 2024
32d50ff
updated contrastive.py
alishbaimran Jul 27, 2024
6073505
stem update
alishbaimran Jul 27, 2024
92e1f88
updated predict code
alishbaimran Jul 29, 2024
eb3443b
Delete viscy/applications/contrastive_phenotyping/PCA.ipynb
alishbaimran Jul 29, 2024
43a032a
pushing dataloader test updated
alishbaimran Jul 29, 2024
44ea0f6
pca deleted
alishbaimran Jul 29, 2024
d3c6f70
training and dataloader test
alishbaimran Jul 29, 2024
ee50300
Merge branch 'infection_state' of https://github.com/mehta-lab/VisCy …
alishbaimran Jul 29, 2024
4d787f2
updated structure
alishbaimran Jul 29, 2024
a7bda86
deleted files
alishbaimran Jul 29, 2024
6242796
Merge branch 'contrastive_phenotyping' into infection_state
alishbaimran Jul 29, 2024
acd4840
updated training merged files
alishbaimran Jul 31, 2024
445f839
removed commented code
alishbaimran Jul 31, 2024
770ae59
Merge branch 'contrastive_phenotyping' into infection_state
alishbaimran Jul 31, 2024
05f59d7
removed uneeded code
alishbaimran Jul 31, 2024
20388ea
removed uneeded code
alishbaimran Jul 31, 2024
60f5e9c
Merge branch 'infection_state' of https://github.com/mehta-lab/VisCy …
alishbaimran Jul 31, 2024
bd49833
removed comments
alishbaimran Jul 31, 2024
d0be5dc
snake_case
alishbaimran Jul 31, 2024
5ac447b
fixed CI issues
alishbaimran Jul 31, 2024
61792a5
removed num_fovs
alishbaimran Jul 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 93 additions & 59 deletions applications/contrastive_phenotyping/predict.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,125 @@
from argparse import ArgumentParser
from pathlib import Path

import numpy as np
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.strategies import DDPStrategy

from viscy.data.hcs import ContrastiveDataModule
from viscy.data.triplet import TripletDataModule, TripletDataset
from viscy.light.engine import ContrastiveModule

import os
from torch.multiprocessing import Manager
from viscy.transforms import (
NormalizeSampled,
RandAdjustContrastd,
RandAffined,
RandGaussianNoised,
RandGaussianSmoothd,
RandScaleIntensityd,
RandWeightedCropd,
)

normalizations = [
# Normalization for Phase3D using mean and std
NormalizeSampled(
keys=["Phase3D"],
level="fov_statistics",
subtrahend="mean",
divisor="std",
),
# Normalization for RFP using median and IQR
NormalizeSampled(
keys=["RFP"],
level="fov_statistics",
subtrahend="median",
divisor="iqr",
),
]

def main(hparams):
# Set paths
# this CSV defines the order in which embeddings should be processed. Currently using num_workers = 1 to keep order
top_dir = Path("/hpc/projects/intracellular_dashboard/viral-sensor/")
timesteps_csv_path = "/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/expanded_transitioning_cells_metadata.csv"
predict_base_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/all_annotations_patch.zarr"
checkpoint_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/infection_score/updated_multiple_channels/contrastive_model-test-epoch=97-val_loss=0.00.ckpt"

# Data parameters
channels = 2
x = 200
y = 200
z_range = (28, 43)
batch_size = 12
channel_names = ["RFP", "Phase3D"]
# /hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/expanded_final_track_timesteps.csv
# /hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/uninfected_cells.csv
# /hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/expanded_transitioning_cells_metadata.csv
checkpoint_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/infection_score/contrastive_model-test-epoch=09-val_loss=0.00.ckpt"

# non-rechunked data
data_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/2.1-register/registered.zarr"

# updated tracking data
tracks_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/5-finaltrack/track_labels_final.zarr"

source_channel = ["RFP", "Phase3D"]
z_range = (26, 38)
batch_size = 15 # match the number of fovs being processed such that no data is left
# set to 15 for full, 12 for infected, and 8 for uninfected

# Initialize the data module for prediction
data_module = ContrastiveDataModule(
base_path=str(predict_base_path),
channels=channels,
x=x,
y=y,
timesteps_csv_path=timesteps_csv_path,
channel_names=channel_names,
batch_size=batch_size,
z_range=z_range,
predict_base_path=predict_base_path,
analysis=True, # for self-supervised results
data_module = TripletDataModule(
data_path=data_path,
tracks_path=tracks_path,
source_channel=source_channel,
z_range=z_range,
initial_yx_patch_size=(512, 512),
final_yx_patch_size=(224, 224),
batch_size=batch_size,
num_workers=hparams.num_workers,
normalizations=normalizations,
)

data_module.setup(stage="predict")

print(f"Total prediction dataset size: {len(data_module.predict_dataset)}")

# Load the model from checkpoint
model = ContrastiveModule.load_from_checkpoint(str(checkpoint_path), predict=True)
backbone = "resnet50"
in_stack_depth = 12
stem_kernel_size = (5, 3, 3)
model = ContrastiveModule.load_from_checkpoint(
str(checkpoint_path),
predict=True,
backbone=backbone,
in_channels=len(source_channel),
in_stack_depth=in_stack_depth,
stem_kernel_size=stem_kernel_size,
tracks_path = tracks_path,
)

model.eval()
model.encoder.predict = True

# Initialize the trainer
trainer = Trainer(
accelerator="gpu",
devices=1,
num_nodes=1,
strategy=DDPStrategy(),
strategy=DDPStrategy(find_unused_parameters=False),
callbacks=[TQDMProgressBar(refresh_rate=1)],
)

# Run prediction
predictions = trainer.predict(model, datamodule=data_module)

# Collect features and projections
features_list = []
projections_list = []

for batch_idx, batch in enumerate(predictions):
features, projections = batch
features_list.append(features.cpu().numpy())
projections_list.append(projections.cpu().numpy())

all_features = np.concatenate(features_list, axis=0)
all_projections = np.concatenate(projections_list, axis=0)

# Save features and projections
# Save in sub-folder instead for the specific FOV

# for saving visualizations embeddings
base_dir = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/5-finaltrack/test_visualizations"
features_path = os.path.join(base_dir, 'B', '4', '2', 'before_projected_embeddings', 'test_epoch88_predicted_features.npy')
projections_path = os.path.join(base_dir, 'B', '4', '2', 'projected_embeddings', 'test_epoch88_predicted_projections.npy')

np.save("/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/ss1_epoch97_predicted_features.npy", all_features)
np.save("/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/ss1_epoch97_predicted_projections.npy", all_projections)

trainer.predict(model, datamodule=data_module)

# # Collect features and projections
# features_list = []
# projections_list = []

# for batch_idx, batch in enumerate(predictions):
# features, projections = batch
# features_list.append(features.cpu().numpy())
# projections_list.append(projections.cpu().numpy())
# all_features = np.concatenate(features_list, axis=0)
# all_projections = np.concatenate(projections_list, axis=0)

# # for saving visualizations embeddings
# base_dir = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/5-finaltrack/test_visualizations"
# features_path = os.path.join(base_dir, 'B', '4', '2', 'before_projected_embeddings', 'test_epoch88_predicted_features.npy')
# projections_path = os.path.join(base_dir, 'B', '4', '2', 'projected_embeddings', 'test_epoch88_predicted_projections.npy')

# np.save("/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/embeddings/resnet_uninf_rfp_epoch99_predicted_features.npy", all_features)
# np.save("/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/embeddings/resnet_uninf_rfp_epoch99_predicted_projections.npy", all_projections)

if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--backbone", type=str, default="convnext_tiny")
parser.add_argument("--backbone", type=str, default="resnet50")
parser.add_argument("--margin", type=float, default=0.5)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--schedule", type=str, default="Constant")
Expand All @@ -94,8 +128,8 @@ def main(hparams):
parser.add_argument("--max_epochs", type=int, default=100)
parser.add_argument("--accelerator", type=str, default="gpu")
parser.add_argument("--devices", type=int, default=1)
parser.add_argument("--num_nodes", type=int, default=2)
parser.add_argument("--num_nodes", type=int, default=1)
parser.add_argument("--log_every_n_steps", type=int, default=1)
parser.add_argument("--num_workers", type=int, default=15)
args = parser.parse_args()

main(args)
Loading
Loading