Single-cell representation learning #153

Merged on Oct 17, 2024 (37 commits)
37b07a1
Merging code related to figures (#146)
mattersoflight Aug 28, 2024
8ebe86c
produce a report of useful visualizations to assess the dimensionalit…
mattersoflight Aug 28, 2024
6e7d61f
Remove obsolete scripts for contrastive phenotyping (#150)
ziw-liu Aug 30, 2024
1f269c7
SSL: fix MLP head and remove L2 normalization (#145)
ziw-liu Aug 31, 2024
4bfbf8b
created and updated classify_feb_embeddings.py
alishbaimran Sep 8, 2024
634b955
Module and scripts for evaluating representations (#156)
mattersoflight Sep 10, 2024
9639961
delete duplicate file
mattersoflight Sep 10, 2024
2f85eec
Merge branch 'main' into representation
ziw-liu Sep 10, 2024
083897c
lint
ziw-liu Sep 10, 2024
4521afc
fix import paths
ziw-liu Sep 10, 2024
19c4559
rename translation tests
ziw-liu Sep 10, 2024
63d9f5a
rename translation metrics
ziw-liu Sep 10, 2024
ee826b5
Sample positive and negative samples with a time offset for the tripl…
ziw-liu Sep 11, 2024
e2175b4
add fig for mitosis
Soorya19Pradeep Sep 17, 2024
2a6cd20
add script to save image patches
Soorya19Pradeep Sep 17, 2024
767b12c
add save patches as npy
Soorya19Pradeep Sep 18, 2024
2759584
save figure at 300dpi
Soorya19Pradeep Sep 18, 2024
74fa3d7
Linear probing (#160)
ziw-liu Sep 20, 2024
10219d3
Tweak attribution visualization (#170)
ziw-liu Sep 26, 2024
42a0cb5
UMAP line plot to assess temporal smoothness in features space (#176)
ziw-liu Sep 27, 2024
d5017ab
fixed import error
Soorya19Pradeep Sep 25, 2024
a58ab83
formatted with black
Soorya19Pradeep Sep 25, 2024
df9a5cd
reduce to single arrow on plot
Soorya19Pradeep Sep 27, 2024
17a2e48
remove reduntant script
Soorya19Pradeep Sep 27, 2024
ad74176
Fixes on correlation of PCA and UMAP components to computed_feature s…
Soorya19Pradeep Sep 27, 2024
582952d
updated eval module & cosine sim figures (#168)
alishbaimran Sep 27, 2024
2960633
Fixup representation (#180)
ziw-liu Oct 9, 2024
b9d159b
Unified CLI entry point (#182)
ziw-liu Oct 16, 2024
25b10e1
Remove outdated comment
ziw-liu Oct 16, 2024
9c86140
updating the dlmbl notebooks
edyoshikun Oct 17, 2024
43dee17
updating dependendencies to allow viscy>0.2 in examples
edyoshikun Oct 17, 2024
3b9063c
updating phase contrast demo notebook.
edyoshikun Oct 17, 2024
ef82427
updating references to main
edyoshikun Oct 17, 2024
beb1c49
Store UMAP embeddings in SSL predictions (#184)
ziw-liu Oct 17, 2024
8b0c6e7
Add representation section to readme (#186)
mattersoflight Oct 17, 2024
131f996
Merge branch 'main' into representation
ziw-liu Oct 17, 2024
15b9386
fix link syntax in readme
ziw-liu Oct 17, 2024
31 changes: 30 additions & 1 deletion README.md
@@ -102,6 +102,36 @@ The robust virtual staining models (i.e *VSCyto2D*, *VSCyto3D*, *VSNeuromast*),

A full illustration of the virtual staining pipeline can be found [here](https://github.com/mehta-lab/VisCy/blob/dde3e27482e58a30f7c202e56d89378031180c75/docs/virtual_staining.md).

## Image representation learning

We are currently developing self-supervised representation learning to map cell state dynamics in response to perturbations,
with a focus on cell and organelle remodeling due to viral infection.

See our recent work on a temporally regularized contrastive sampling method
for representation learning on [arXiv](https://arxiv.org/abs/2410.11281).

<details>
<summary> Pradeep, Imran, Liu et al., 2024 </summary>

<pre><code>
@misc{pradeep_contrastive_2024,
title={Contrastive learning of cell state dynamics in response to perturbations},
author={Soorya Pradeep and Alishba Imran and Ziwen Liu and Taylla Milena Theodoro and Eduardo Hirata-Miyasaki and Ivan Ivanov and Madhura Bhave and Sudip Khadka and Hunter Woosley and Carolina Arias and Shalin B. Mehta},
year={2024},
eprint={2410.11281},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2410.11281},
}
</code></pre>
</details>
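
As a rough illustration of what time-offset triplet sampling can look like (the idea named in the commit "Sample positive and negative samples with a time offset"), the sketch below draws a positive from the same track a fixed number of frames after the anchor and a negative from a different track at that later frame. The `Observation` type, the `sample_triplet` helper, the toy track table, and the exact choice of negatives are assumptions made for illustration. They are not VisCy's API, and the paper linked above remains the reference for the actual sampling rules.

```python
import random
from typing import NamedTuple


class Observation(NamedTuple):
    """One tracked cell at one time point (hypothetical type for this sketch)."""

    track_id: int
    t: int


def sample_triplet(observations, anchor, time_interval, rng):
    """Return (anchor, positive, negative) for one anchor observation.

    Assumption: the positive is the same track `time_interval` frames later,
    and the negative is a different track at that later frame.
    """
    target_t = anchor.t + time_interval
    positive = Observation(anchor.track_id, target_t)
    if positive not in observations:
        raise ValueError("anchor track has no observation at t + time_interval")
    # Sort the candidates so the draw is reproducible for a given seed.
    negatives = sorted(
        o for o in observations if o.t == target_t and o.track_id != anchor.track_id
    )
    if not negatives:
        raise ValueError("no other track is available at t + time_interval")
    return anchor, positive, rng.choice(negatives)


# Toy table: 3 tracks observed over 3 frames.
observations = {Observation(track, t) for track in range(3) for t in range(3)}
rng = random.Random(42)
anchor, positive, negative = sample_triplet(
    observations, Observation(track_id=0, t=0), time_interval=1, rng=rng
)
print(anchor, positive, negative)
```

Larger `time_interval` values would pair the anchor with a later view of the same cell, which is one way to encourage embeddings that vary smoothly over time.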

### Workflow demo

[Exploration of learned embeddings with napari-iohub](https://drive.google.com/file/d/16WSoTvXJ-siLb7iyOueOag_cKn9Iwckc/view?usp=drive_link)

![DynaCLR](https://github.com/mehta-lab/VisCy/blob/9eaab7eca50d684d8a473ad9da089aeab0e8f6a0/docs/figures/dynaCLR_schematic.png?raw=true)

## Installation

1. We recommend using a new Conda/virtual environment.
@@ -148,4 +178,3 @@ for reading and writing data in [OME-Zarr](https://www.nature.com/articles/s4159
The full functionality is tested on Linux `x86_64` with NVIDIA Ampere GPUs (CUDA 12.4).
Some features (e.g. mixed precision and distributed training) may not be available with other setups,
see [PyTorch documentation](https://pytorch.org) for details.

115 changes: 0 additions & 115 deletions applications/contrastive_phenotyping/contrastive_cli/fit.yml

This file was deleted.

117 changes: 117 additions & 0 deletions applications/contrastive_phenotyping/contrastive_cli/fit_ctc_mps.yml
@@ -0,0 +1,117 @@
# See help here on how to configure hyper-parameters with config files:
# https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced.html
seed_everything: 42
trainer:
  accelerator: gpu
  strategy: auto
  devices: 1
  num_nodes: 1
  precision: 32-true
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    # Nesting the logger config like this is equivalent to
    # supplying the following argument to `lightning.pytorch.Trainer`:
    # logger=TensorBoardLogger(
    #     "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations",
    #     log_graph=True,
    #     version="vanilla",
    # )
    init_args:
      save_dir: /Users/ziwen.liu/Projects/test-time
      # this is the name of the experiment.
      # The logs will be saved in `save_dir/lightning_logs/version`
      version: time_interval_1
      log_graph: True
  callbacks:
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: step
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: loss/val
        every_n_epochs: 1
        save_top_k: 4
        save_last: true
  fast_dev_run: false
  max_epochs: 100
  log_every_n_steps: 10
  enable_checkpointing: true
  inference_mode: true
  use_distributed_sampler: true
  # synchronize batchnorm parameters across multiple GPUs.
  # important for contrastive learning to normalize the tensors across the whole batch.
  sync_batchnorm: true
model:
  class_path: viscy.representation.engine.ContrastiveModule
  init_args:
    encoder:
      class_path: viscy.representation.contrastive.ContrastiveEncoder
      init_args:
        backbone: convnext_tiny
        in_channels: 1
        in_stack_depth: 1
        stem_kernel_size: [1, 4, 4]
        stem_stride: [1, 4, 4]
        embedding_dim: 768
        projection_dim: 32
        drop_path_rate: 0.0
    loss_function:
      class_path: torch.nn.TripletMarginLoss
      init_args:
        margin: 0.5
    lr: 0.0002
    log_batches_per_epoch: 3
    log_samples_per_batch: 2
    example_input_array_shape: [1, 1, 1, 128, 128]
data:
  class_path: viscy.data.triplet.TripletDataModule
  init_args:
    data_path: /Users/ziwen.liu/Downloads/Hela_CTC.zarr
    tracks_path: /Users/ziwen.liu/Downloads/Hela_CTC.zarr
    source_channel:
      - DIC
    z_range: [0, 1]
    batch_size: 16
    num_workers: 4
    initial_yx_patch_size: [256, 256]
    final_yx_patch_size: [128, 128]
    time_interval: 1
    normalizations:
      - class_path: viscy.transforms.NormalizeSampled
        init_args:
          keys: [DIC]
          level: fov_statistics
          subtrahend: mean
          divisor: std
    augmentations:
      - class_path: viscy.transforms.RandAffined
        init_args:
          keys: [DIC]
          prob: 0.8
          scale_range: [0, 0.2, 0.2]
          rotate_range: [3.14, 0.0, 0.0]
          shear_range: [0.0, 0.01, 0.01]
          padding_mode: zeros
      - class_path: viscy.transforms.RandAdjustContrastd
        init_args:
          keys: [DIC]
          prob: 0.5
          gamma: [0.8, 1.2]
      - class_path: viscy.transforms.RandScaleIntensityd
        init_args:
          keys: [DIC]
          prob: 0.5
          factors: 0.5
      - class_path: viscy.transforms.RandGaussianSmoothd
        init_args:
          keys: [DIC]
          prob: 0.5
          sigma_x: [0.25, 0.75]
          sigma_y: [0.25, 0.75]
          sigma_z: [0.0, 0.0]
      - class_path: viscy.transforms.RandGaussianNoised
        init_args:
          keys: [DIC]
          prob: 0.5
          mean: 0.0
          std: 0.2
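
The training config above selects `torch.nn.TripletMarginLoss` with `margin: 0.5`. For readers unfamiliar with that loss, here is a minimal pure-Python sketch of the quantity it is based on, `max(d(anchor, positive) - d(anchor, negative) + margin, 0)` with Euclidean distance. It is a simplification for illustration only: it handles a single triplet, omits the small epsilon PyTorch adds inside the distance, and skips batch reduction. The helper names and the toy vectors are hypothetical.

```python
import math


def euclidean(u, v):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_margin_loss(anchor, positive, negative, margin=0.5):
    """Hinge on the gap between anchor-positive and anchor-negative distances."""
    gap = euclidean(anchor, positive) - euclidean(anchor, negative)
    return max(gap + margin, 0.0)


# Toy 2-D vectors standing in for embeddings.
anchor = [0.0, 0.0]
positive = [0.3, 0.4]       # distance 0.5 from the anchor
far_negative = [3.0, 4.0]   # distance 5.0 from the anchor
near_negative = [0.0, 0.6]  # distance 0.6 from the anchor

# The far negative already clears the margin, so it contributes no loss.
easy = triplet_margin_loss(anchor, positive, far_negative)
# The near negative is only 0.1 farther than the positive, so the loss is
# roughly margin - 0.1.
hard = triplet_margin_loss(anchor, positive, near_negative)
print(easy, hard)
```

Only triplets whose negative is not yet farther than the positive by at least the margin produce a gradient, which is why the margin is worth tuning alongside the sampling strategy.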
50 changes: 0 additions & 50 deletions applications/contrastive_phenotyping/contrastive_cli/predict.yml

This file was deleted.

@@ -0,0 +1,48 @@
seed_everything: 42
trainer:
  accelerator: gpu
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 32-true
  callbacks:
    - class_path: viscy.representation.embedding_writer.EmbeddingWriter
      init_args:
        output_path: /Users/ziwen.liu/Projects/test-time/predict/time_interval_1.zarr
  inference_mode: true
model:
  class_path: viscy.representation.engine.ContrastiveModule
  init_args:
    encoder:
      class_path: viscy.representation.contrastive.ContrastiveEncoder
      init_args:
        backbone: convnext_tiny
        in_channels: 1
        in_stack_depth: 1
        stem_kernel_size: [1, 4, 4]
        stem_stride: [1, 4, 4]
        embedding_dim: 768
        projection_dim: 32
        drop_path_rate: 0.0
    example_input_array_shape: [1, 1, 1, 128, 128]
data:
  class_path: viscy.data.triplet.TripletDataModule
  init_args:
    data_path: /Users/ziwen.liu/Downloads/Hela_CTC.zarr
    tracks_path: /Users/ziwen.liu/Downloads/Hela_CTC.zarr
    source_channel: DIC
    z_range: [0, 1]
    batch_size: 16
    num_workers: 4
    initial_yx_patch_size: [128, 128]
    final_yx_patch_size: [128, 128]
    time_interval: 1
    normalizations:
      - class_path: viscy.transforms.NormalizeSampled
        init_args:
          keys: [DIC]
          level: fov_statistics
          subtrahend: mean
          divisor: std
return_predictions: false
ckpt_path: /Users/ziwen.liu/Projects/test-time/lightning_logs/time_interval_1/checkpoints/last.ckpt
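
The `ckpt_path` in this prediction config lines up with the logger settings of the training config: its comment says logs are saved in `save_dir/lightning_logs/version`, and the `ModelCheckpoint` callback sets `save_last: true`. The sketch below shows how the two configs relate by rebuilding the path from `save_dir` and `version`. The `last_checkpoint_path` helper is hypothetical, and the `checkpoints/last.ckpt` layout is read off the path in the config rather than taken from the Lightning documentation.

```python
from pathlib import PurePosixPath


def last_checkpoint_path(save_dir, version, name="lightning_logs"):
    """Build the expected path of the last checkpoint for one logged run.

    Assumes the layout `save_dir/name/version/checkpoints/last.ckpt`,
    as seen in the two configs in this pull request.
    """
    return PurePosixPath(save_dir) / name / version / "checkpoints" / "last.ckpt"


# Values copied from the training config's logger `init_args`.
path = last_checkpoint_path(
    save_dir="/Users/ziwen.liu/Projects/test-time",
    version="time_interval_1",
)
print(path)
```

Deriving the path this way keeps the training and prediction configs in sync when `version` changes between experiments.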