Single-cell phenotyping with contrastive learning (#113)
* first draft of contrastive learning model

* fixed stem and projection head, drafted lightning module

* Contrastive_dataloader (#99)

* initial dataloader.py

* Update dataloader_test.py

* Update dataloader_test.py

* Update dataloader_test.py

* Update dataloader_test.py

* rename training script

* move contrastive network to viscy.representation module

* Update hcs.py

* refactored class names

* correct imports

* cleaner names for model arch and module

* new imports

* Fixed epoch loss logging and WandB integration in ContrastiveModule

* updated training_script.py

* Update hcs.py

* contrastive.py

* engine.py

* script to test data i/o speed from different filesystems

* moved applications folder to viscy.applications so that `pip install -e .` works

* add resnet50 to ContrastiveEncoder

* rename training_script.py to training_script_resnet.py

* test dataloader on lustre and vast

* move training_script_resnet to viscy.applications so that `pip install -e .` works

* refined the tests for contrastive dataloader

* sbatch script for dataloader

* delete redundant module

* nits: updated the model construction of contrastive resnet encoder.

* Updated training script, HCS data handling, engine, and contrastive representation

* Fix normalization, visualization issues, logging and multi-channel prediction

* updated training and prediction

* update training and prediction script

* formatting

* combine the application directories

* lint

* replace notebook with script

* format script

* rename scripts conflicting with pytest

* lint application scripts

* do not filter all warnings

* log instead of print

* split data modules by task

* clean up imports

* update typing

* use pathlib

* remove redundant file

* updated predict.py

* better typing

* wip: triplet dataset

* avoid forward references: this might increase code analysis time a tiny bit, but should have no effect at runtime (see the sketch after this commit message)

* check that z range is valid
and fix indexing

* clean up and explain random sampling

* sample dict instead of tuple and include track index (see the sketch after this commit message)

* take out generic HCS methods for reuse

* implement TripletDataModule

* use new batch type in engine

* better typing

* read normalization metadata

* docstring for data module

* drop normalization metadata after transformation

* remove unused import

* fix initial crop size

* Infection state (#118)

* updated prediction code

* updated predict code

* updated code

* fixed the stem and forward pass (#115)

* fixed the stem and forward pass

* update forward calls to encoder

* self.encoder -> self.model

* nits

* l2 normalize projections (see the sketch after this commit message)

* black compliance

* black compliance

* WIP: Save progress before merging

* updated contrastive.py

* stem update

* updated predict code

* Delete viscy/applications/contrastive_phenotyping/PCA.ipynb

* push updated dataloader test

* pca deleted

* training and dataloader test

* updated structure

* deleted files

* updated training merged files

* removed commented code

* removed unneeded code

* removed unneeded code

* removed comments

* snake_case

* fixed CI issues

* removed num_fovs

---------

Co-authored-by: Shalin Mehta <[email protected]>

---------

Co-authored-by: Shalin Mehta <[email protected]>
Co-authored-by: Alishba Imran <[email protected]>
Co-authored-by: Alishba Imran <[email protected]>
Co-authored-by: Alishba Imran <[email protected]>
Co-authored-by: Duo Peng <[email protected]>
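
For the "avoid forward references" item above, a minimal sketch of the two annotation styles, using `DataLoader` as a stand-in type rather than the class the commit actually changed:

import torch
from torch.utils.data import DataLoader, TensorDataset

def with_forward_ref(loader: "DataLoader") -> int:
    # string annotation (forward reference): resolved only by type checkers, never at runtime
    return len(loader.dataset)

def with_direct_import(loader: DataLoader) -> int:
    # plain annotation using the imported class: analyzers resolve it eagerly,
    # but runtime behavior is identical
    return len(loader.dataset)

loader = DataLoader(TensorDataset(torch.zeros(4, 1)), batch_size=2)
assert with_forward_ref(loader) == with_direct_import(loader) == 4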
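
For the "sample dict instead of tuple and include track index" item, a hypothetical sketch of the sample structure; the key names and shapes are assumptions for illustration, not the dataset's actual schema:

import torch

# One triplet sample as a dict (keys and shapes are illustrative only).
sample = {
    "anchor": torch.randn(2, 15, 224, 224),    # (channels, Z, Y, X) crop
    "positive": torch.randn(2, 15, 224, 224),
    "negative": torch.randn(2, 15, 224, 224),
    "index": {"fov_name": "/B/4/0", "track_id": 7},  # which FOV and track the anchor came from
}
print(sample["anchor"].shape, sample["index"]["track_id"])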
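
For the "l2 normalize projections" item, a minimal sketch of scaling the projection-head output to unit length; the tensor shapes are illustrative, not the model's actual dimensions:

import torch
import torch.nn.functional as F

projections = torch.randn(8, 32)                    # raw projection-head output: (batch, dim)
projections = F.normalize(projections, p=2, dim=1)  # scale each embedding to unit L2 norm
print(projections.norm(dim=1))                      # all ones, up to floating-point precision
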
6 people authored and edyoshikun committed Aug 7, 2024
1 parent f1f0288 commit 3184b76
Showing 13 changed files with 2,049 additions and 63 deletions.
106 changes: 106 additions & 0 deletions applications/contrastive_phenotyping/graphs_ConvNeXt_ResNet.py
@@ -0,0 +1,106 @@
# %% Imports and paths.
import timm
import torch
import torchview

from viscy.light.engine import ContrastiveModule
from viscy.representation.contrastive import ContrastiveEncoder, UNeXt2Stem

# %load_ext autoreload
# %autoreload 2
# %% Initialize the ConvNeXt encoder and draw its graph.
contra_model = ContrastiveEncoder(
    backbone="convnext_tiny"
)  # other options: convnext_tiny, resnet50
print(contra_model)
model_graph = torchview.draw_graph(
    contra_model,
    torch.randn(1, 2, 15, 224, 224),
    depth=3,  # adjust depth to zoom in.
    device="cpu",
)
# Render and display the model graph.
model_graph.resize_graph(scale=2.5)
model_graph.visual_graph

# %% Initialize a resnet50 encoder and draw its graph.
contra_model = ContrastiveEncoder(
    backbone="resnet50", in_stack_depth=16, stem_kernel_size=(4, 3, 3)
)  # the first ResNet stage expects 64 channels, which is not a multiple of 3, hence the 16-slice stack and depth-4 stem kernel
print(contra_model)
model_graph = torchview.draw_graph(
    contra_model,
    torch.randn(1, 2, 16, 224, 224),
    depth=3,  # adjust depth to zoom in.
    device="cpu",
)
# Render and display the model graph.
model_graph.resize_graph(scale=2.5)
model_graph.visual_graph


# %% Initialize the lightning module and view the model.
contrastive_module = ContrastiveModule()
print(contrastive_module.encoder)

# %%
model_graph = torchview.draw_graph(
    contrastive_module.encoder,
    torch.randn(1, 2, 15, 200, 200),
    depth=3,  # adjust depth to zoom in.
    device="cpu",
)
# Render and display the model graph.
model_graph.visual_graph

# %% Playground

# List the pretrained models available in timm.
available_models = timm.list_models(pretrained=True)

stem = UNeXt2Stem(
    in_channels=2, out_channels=96, kernel_size=(5, 2, 2), in_stack_depth=15
)
print(stem)
stem_graph = torchview.draw_graph(
    stem,
    torch.randn(1, 2, 15, 256, 256),
    depth=2,  # adjust depth to zoom in.
    device="cpu",
)
# Render and display the stem graph.
stem_graph.visual_graph
# %%
encoder = timm.create_model(
    "convnext_tiny",
    pretrained=True,
    features_only=False,
    num_classes=200,
)

print(encoder)

# %%

# Swap the 2D ConvNeXt stem for the 3D UNeXt2 stem so the encoder accepts a z-stack input.
encoder.stem = stem

model_graph = torchview.draw_graph(
    encoder,
    torch.randn(1, 2, 15, 256, 256),
    depth=2,  # adjust depth to zoom in.
    device="cpu",
)
# Render and display the model graph.
model_graph.visual_graph
# %%
# Bypass the stem entirely and feed stem-shaped features directly to the encoder body.
encoder.stem = torch.nn.Identity()

encoder_graph = torchview.draw_graph(
    encoder,
    torch.randn(1, 96, 128, 128),
    depth=2,  # adjust depth to zoom in.
    device="cpu",
)
# Render and display the encoder graph.
encoder_graph.visual_graph

# %%
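
The cells above only display the graphs interactively. Below is a minimal sketch, not part of the committed file, of saving one of them to disk: in current torchview releases `visual_graph` is a graphviz Digraph, so it can be rendered to a file (the output filename here is arbitrary).

# %% Optional: save a rendered graph to disk.
model_graph = torchview.draw_graph(
    ContrastiveEncoder(backbone="convnext_tiny"),
    torch.randn(1, 2, 15, 224, 224),
    depth=3,
    device="cpu",
)
model_graph.visual_graph.render("contrastive_convnext_graph", format="svg", cleanup=True)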