SSL: fix MLP head and remove L2 normalization #145

Merged: ziw-liu merged 18 commits into representation from refine_projection_head on Aug 31, 2024

@ziw-liu (Collaborator) commented Aug 28, 2024

Fix the sequence of batchnorm and linear in the last MLP layer of the contrastive encoder model.
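
As a rough illustration of the layer order discussed here, the sketch below builds a two-layer projection head in PyTorch in which every linear layer is followed by batch normalization, so the last block reads linear, then batch norm. This assumes the intended order is the SimCLR-style one; the comment above does not spell it out. The helper name `projection_mlp`, the dimensions, and the ReLU activation are hypothetical and may not match the code in this PR.

```python
import torch
from torch import nn


def projection_mlp(in_dims: int, hidden_dims: int, out_dims: int) -> nn.Sequential:
    """Two-layer projection head; each linear layer is followed by batch norm.

    Hypothetical helper for illustration only (names and sizes are assumptions).
    """
    return nn.Sequential(
        nn.Linear(in_dims, hidden_dims),
        nn.BatchNorm1d(hidden_dims),
        nn.ReLU(inplace=True),
        # Last layer: project first, then normalize the projected features.
        nn.Linear(hidden_dims, out_dims),
        nn.BatchNorm1d(out_dims),
    )


head = projection_mlp(in_dims=768, hidden_dims=768, out_dims=32)
features = torch.randn(4, 768)  # a batch of 4 encoder feature vectors
projections = head(features)
print(projections.shape)  # torch.Size([4, 32])
```

Placing the batch norm after the final linear layer standardizes the projected features across the batch rather than the inputs to the projection.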

Remove L2 normalization before computing triplet loss. This works best when also reducing the dimension of projections.
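
To make the effect of L2 normalization concrete, here is a small pure-Python sketch of a triplet margin loss evaluated on raw and on L2-normalized projections. The function names, the margin of 0.5, and the toy vectors are assumptions for illustration, not values from this PR.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def triplet_margin_loss(anchor, positive, negative, margin=0.5):
    """max(d(a, p) - d(a, n) + margin, 0) with Euclidean distance d."""
    d_pos = math.dist(anchor, positive)
    d_neg = math.dist(anchor, negative)
    return max(d_pos - d_neg + margin, 0.0)


# Toy projections that differ only in magnitude along one axis.
anchor, positive, negative = [1.0, 0.0], [2.0, 0.0], [10.0, 0.0]

raw_loss = triplet_margin_loss(anchor, positive, negative)
normalized_loss = triplet_margin_loss(
    l2_normalize(anchor), l2_normalize(positive), l2_normalize(negative)
)
print(raw_loss)         # 0.0: the negative is already far enough away
print(normalized_loss)  # 0.5: all three vectors collapse to the same unit vector
```

In this toy case the normalized projections lose the magnitude difference that separates the negative from the anchor, which is one way the two variants of the loss can behave differently.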

Refactor the light module into representation and translation modules to separate the pipeline code for different tasks.

Fix #139, fix #138.

Reopened from #141 to rename base branch.

@ziw-liu requested a review from mattersoflight August 28, 2024 00:45
@ziw-liu marked this pull request as ready for review August 28, 2024 00:45
@ziw-liu added the bug, enhancement, and breaking labels Aug 28, 2024
@mattersoflight changed the base branch from main to representation August 28, 2024 14:50

@mattersoflight (Member) left a comment

Thanks @ziw-liu for this refactor. The reorganization into viscy.data, viscy.representation, and viscy.translation makes a lot of sense. Since contrastive models have already been trained with this branch, it makes sense to merge it into the representation branch. I have a training run in progress but have not evaluated it yet.

TODO before merging:
@Soorya19Pradeep, please test that you can train a model with this branch.
@ziw-liu, we need the following clean-up:

  • The training and prediction scripts in applications.contrastive_phenotyping are obsolete and should be deleted.
  • There are two versions of the DLMBL exercise: examples/virtual_staining/dlmbl_exercise and examples/virtual_staining/img2img_translation. One of them is old and should be deleted.

After merging, @ziw-liu, please test that the refactor has not broken the image translation task. I ran a few scripts (count_flops.py, network_diagram.py), but we need proper testing with a few training runs.

@Soorya19Pradeep (Contributor)

@ziw-liu, @mattersoflight, I was able to start training with this branch too, and it has run for a few epochs so far.

@mattersoflight (Member)

@ziw-liu This is ready to merge into representation. After this merge, please continue building the time sampling strategy on the representation branch. Let's discuss the design in #123.

@ziw-liu force-pushed the refine_projection_head branch from 62212f6 to 93dc2f7 August 31, 2024 13:19
@ziw-liu merged commit 1f269c7 into representation Aug 31, 2024
4 checks passed
@ziw-liu deleted the refine_projection_head branch August 31, 2024 13:22
ziw-liu added a commit that referenced this pull request Oct 17, 2024
* Merging code related to figures (#146)

* notes on standard report

* Add code for generating figures

---------

Co-authored-by: Alishba Imran <[email protected]>

* produce a report of useful visualizations to assess the dimensionality and features learned by embeddings (#140)

* notes on standard report

* add lib of computed features

* correlates PCA with computed features

* compute for all timepoints

* compute correlation

* remove cv library usage

* remove edge detection

* convert to dataframe

* for entire well

* add std_dev feature

* fix patch size

---------

Co-authored-by: Soorya Pradeep <[email protected]>

* Remove obsolete scripts for contrastive phenotyping (#150)

* remove obsolete training and prediction scripts

* lint contrastive scripts

* SSL: fix MLP head and remove L2 normalization (#145)

* draft projection head per #139 (Update the projection head (normalization and size).)

* reorganize comments in example fit config

* configurable stem stride and projection dimensions

* update type hint and docstring for ContrastiveEncoder

* clarify embedding_dim

* use the forward method directly for projected

* normalize projections only when fitting
the projected features saved during prediction are now *not* normalized

* remove unused logger

* refactor training code into translation and representation modules

* extract image logging functions

* use AdamW instead of Adam for contrastive learning

* inline single-use argument

* fix normalization

* fix MLP layer order

* fix output dimensions

* remove L2 normalization before computing loss

* compute rank of features and projections

* documentation

---------

Co-authored-by: Shalin Mehta <[email protected]>

* created and updated classify_feb_embeddings.py

* Module and scripts for evaluating representations (#156)

* docstring

* move scripts from contrastive_scripts to viscy/scripts

* organize files in applications/contrastive_phenotyping

* delete unused evaluation code

* more cleanup

* refactor evaluation metrics for translation task

* refactor viscy.evaluation -> viscy.translation.evaluation_metrics and viscy.representation.evaluation

* WIP: representation evaluation module

* WIP: representation eval - docstrings in numpy format

* WIP: more documentation

* refactor: feature_extractor moved to viscy.representation.evaluation

* lint

* bug fix

* refactored common computations and dataset

* add imbalanced-learn dependency to metrics

* refactor classification of embeddings

* organize viscy.representation.evaluation

* ruff

* Soorya's plotting script

* WIP: combine two versions of plot_embeddings.py

* simplify viscy.representation.evaluation - move LCA to its own module

* refactor of viscy.representation.evaluation

* refactored and tested PCA and UMAP plots

---------

Co-authored-by: Soorya Pradeep <[email protected]>

* delete duplicate file

* lint

* fix import paths

* rename translation tests

* rename translation metrics

* Sample positive and negative samples with a time offset for the triplet contrastive task (#154)

* wip: sample positive and negative samples from another time point

* configure time interval in triplet data module

* vectorized anchor filtering

* conditional augmentation for anchor
anchor is augmented if the positive is another time point

* example training script for the CTC dataset
this is optimized to run on MPS

* add example CTC prediction config for MPS

* add fig for mitosis

* add script to save image patches

* add save patches as npy

* save figure at 300dpi

* Linear probing (#160)

* refactor linear probing with lightning

* test convenience function

* always convert to long before onehot

* use onehot only during training

* supply trainer through argument to avoid wrapping

* only log per epoch

* example script for linear probing

* add comment about loss curve

* fix sample filtering order for select tracks

* add script to visualize integrated gradients

* plot integrated gradients over time

* Use sklearn's logistic regression for linear probing (#169)

* use binary logistic regression to initialize the linear layer

* plot integrated gradients from a binary classifier

* add cmap to 'visual' requirements

* move model assembling to lca

* rename init argument

* disable feature scaling

* update test and evaluation scripts to use new API

* add docstrings to LCA

* Tweak attribution visualization (#170)

* add matplotlib style sheet for figure making

* add cell division attribution

* add matplotlib style sheet

* move attribution computation to lca

* tweak contrast limits and text

* add captum to optional dependencies

* move attribution function to a method of the classifier

* add script to show organelle dynamics

* add occlusion attribution

* more generic save path

* add uninfected cell

* tweak subplot spacing

* UMAP line plot to assess temporal smoothness in features space (#176)

* add matplotlib style sheet for figure making

* add cell division attribution

* add matplotlib style sheet

* move attribution computation to lca

* tweak contrast limits and text

* add captum to optional dependencies

* move attribution function to a method of the classifier

* add script to show organelle dynamics

* add occlusion attribution

* more generic save path

* add uninfected cell

* tweak subplot spacing

* lower case titles

* reduce UMAP components to 2 and add indices

* add script to make the bridge gaps figure

* fixed import error

* formatted with black

* reduce to single arrow on plot

* remove redundant script

* Fixes on correlation of PCA and UMAP components to computed_feature script (#159)

* reduce initial patch size

* add radial profiling

* add function descriptions

* add umap correlation

* add def comments

* change umap for all data

* add script for 1 chan

* add p-value analysis

* add PCA analysis

* remove duplicate script

* Refactor and format code

* Format code

* Removed umap correlation

* note for future refactor

---------

Co-authored-by: Ziwen Liu <[email protected]>

* updated eval module & cosine sim figures (#168)

* updated files

* format fixed for tests

* updated scripts

* umap dist code

* bug fixes and linting

* logistic regression script

* add infection figure script

* Add script for generating infection figure and perform prediction on the June dataset

* Format code

* Black format evaluation module and fix import in figure_cell_infection script

* Refactor scatterplot colors and markers

* Calculate model accuracy

* Add script for appendix video

* formatted code

* updated displacement funcs for full embeddings

* script for displacement computation

* fix style

* fix docstring format

---------

Co-authored-by: Shalin Mehta <[email protected]>
Co-authored-by: Soorya Pradeep <[email protected]>
Co-authored-by: Ziwen Liu <[email protected]>

* Fixup representation (#180)

* fix docstrings and type hint for the ContrastiveEncoder

* refactor the representation evaluation module into submodules

* move shared image logging into utils

* fix line end

* fix import paths in example notebooks

* Unified CLI entry point (#182)

* remove obsolete metrics script for translation

* move cellpose annotation script

* consolidate CLI documentation

* remove old CLI help

* move translation CLI to its own module

* move contrastive CLI to its own module

* remove old CLI module

* remove global entry script

* share trainer class between tasks

* move cli from init to main

* inherit base CLI class for tasks

* improve type hint and docstring

* restore global CLI entry point

* special case subclass mode for preprocessing

* remove separate entry points

* add CLI description message

* make the setup function private

* fix subclass mode detection

* remove unused arguments from custom subcommands

* use generic path in example

* fix docstring style

* update virtual staining example configs

* update CTC SSL example configs

* update infection SSL example configs

* Remove outdated comment

* updating the dlmbl notebooks

* updating dependencies to allow viscy>0.2 in examples

* updating phase contrast demo notebook.

* updating references to main

* Store UMAP embeddings in SSL predictions (#184)

* extract function for computing umap

* specific return type for predict step

* write umap in prediction

* raise log level for umap computation

* fix key conversion

* Add representation section to readme (#186)

* draft readme

* direct link dynaCLR schematic

* add DynaCLR schematic figure

* add static schematic and link to video

---------

Co-authored-by: Ziwen Liu <[email protected]>
Co-authored-by: Ziwen Liu <[email protected]>

* fix link syntax in readme

---------

Co-authored-by: Shalin Mehta <[email protected]>
Co-authored-by: Alishba Imran <[email protected]>
Co-authored-by: Soorya Pradeep <[email protected]>
Co-authored-by: Alishba Imran <[email protected]>
Co-authored-by: Soorya19Pradeep <[email protected]>
Co-authored-by: Eduardo Hirata-Miyasaki <[email protected]>
ziw-liu added a commit that referenced this pull request Oct 21, 2024
Labels: breaking (Breaking changes), bug (Something isn't working), enhancement (New feature or request)
Development

Successfully merging this pull request may close these issues:

  • Update the projection head (normalization and size). (#139)
  • Hard-coded stem stride in ContrastiveEncoder (#138)
3 participants