Commit

2.1D upscale decoder (#37)
* pixelshuffle decoder

* Allow sampling multiple patches from the same stack (#35)

* sample multiple patches from one stack

* do not use type annotations from future
it breaks jsonargparse

* fix channel stacking for non-training samples

* remove batch size from model
the metrics will be automatically reduced by lightning

* add flop counting script

* 3D output head

* add datamodule target dims mode

* remove unused argument and configure drop path

* move architecture argument to model level

* DLMBL 2023 exercise (#36)

* updated intro and paths

* updated figures, tested data loader

* setup.sh fetches correct dataset

* finalized the exercise outline

* semi-final exercise

* parts 1 and 2 tested, part 3 outline ready

* clearer variables, train with larger patch size

* fix typo

* clarify variable names

* trying to log graph

* match example size with training

* reuse globals

* fix reference

* log sample images from the first batch

* wider model

* low LR solution

* fix path

* seed everything

* fix test dataset without masks

* metrics solution
this needs a new test dataset

* fetch test data, compute metrics

* bypass cellpose import error due to numpy version conflicts

* final exercise

* moved files

* fixed formatting - ready for review

* viscy -> VisCy (#34) (#39)

Introducing capitalization to highlight vision and single-cell aspects of the pipeline.

* trying to log graph

* log graph

* black

---------

Co-authored-by: Shalin Mehta <[email protected]>
Co-authored-by: Shalin Mehta <[email protected]>

* fix channel dimension size for example input

#40

* fix argument linking

* 3D prediction writer
sliding windows are blended with a uniform average (see the sketch after this list)

* update network diagram

* upgrade flop counting

* shallow 3D (2.5D) SSIM metric

* ms-ssim

* mixed loss

* fix arguments

* fix inheritance

* fix weight checking

* squeeze metric

* aggregate metrics

* optional clamp to stabilize gradient of MS-SSIM

* fix calling

* increase epsilon

* disable autocast for loss

* restore relu for clamping

* plot all architectures with network_diagram script
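
To make the uniform-average blending mentioned above concrete, here is a minimal sketch (a hypothetical numpy helper for illustration, not the VisCy implementation): each window's prediction is accumulated into a sum volume, a count volume tracks per-voxel coverage, and the two are divided at the end.

import numpy as np


def blend_windows(volume_shape, windows, positions):
    # Uniform-average blending of overlapping sliding-window predictions.
    # `windows` are 3D prediction patches; `positions` are their (z, y, x)
    # origins in the output volume. Hypothetical helper for illustration.
    accum = np.zeros(volume_shape, dtype=np.float32)
    count = np.zeros(volume_shape, dtype=np.float32)
    for window, (z, y, x) in zip(windows, positions):
        d, h, w = window.shape
        accum[z : z + d, y : y + h, x : x + w] += window
        count[z : z + d, y : y + h, x : x + w] += 1.0
    # Guard against uncovered voxels to avoid division by zero.
    return accum / np.maximum(count, 1.0)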

---------
Co-authored-by: Shalin Mehta <[email protected]>
ziw-liu and mattersoflight authored Aug 30, 2023
1 parent 933013c commit b4ec13c
Showing 19 changed files with 3,131 additions and 410 deletions.
Binary file added docs/figures/phase_to_nuclei_membrane.png
41 changes: 41 additions & 0 deletions examples/demo_dlmbl/convert-solution.py
@@ -0,0 +1,41 @@
import argparse

from nbconvert.exporters import NotebookExporter
from nbconvert.preprocessors import ClearOutputPreprocessor, TagRemovePreprocessor
from traitlets.config import Config


def get_arg_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument('input_file')
    parser.add_argument('output_file')

    return parser


def convert(input_file, output_file):
    c = Config()
    c.TagRemovePreprocessor.remove_cell_tags = ("solution",)
    c.TagRemovePreprocessor.enabled = True
    c.ClearOutputPreprocessor.enabled = True
    c.NotebookExporter.preprocessors = [
        "nbconvert.preprocessors.TagRemovePreprocessor",
        "nbconvert.preprocessors.ClearOutputPreprocessor",
    ]

    exporter = NotebookExporter(config=c)
    exporter.register_preprocessor(TagRemovePreprocessor(config=c), True)
    exporter.register_preprocessor(ClearOutputPreprocessor(), True)

    # Use the configured exporter; constructing a fresh NotebookExporter here
    # would discard the registered preprocessors.
    output = exporter.from_filename(input_file)
    with open(output_file, 'w') as f:
        f.write(output[0])


if __name__ == "__main__":
    parser = get_arg_parser()
    args = parser.parse_args()

    convert(args.input_file, args.output_file)
    print(f'Converted {args.input_file} to {args.output_file}')
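
The converter takes positional input and output paths, so a typical invocation (with hypothetical notebook names) would be:

    python examples/demo_dlmbl/convert-solution.py solution.ipynb exercise.ipynb

It strips cells tagged "solution" and clears all cell outputs before writing the exercise notebook.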
96 changes: 96 additions & 0 deletions examples/demo_dlmbl/debug_log_graph.py
@@ -0,0 +1,96 @@

# %% Imports and paths

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchview
import torchvision
from iohub import open_ome_zarr
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger

# TensorBoard utilities: viewing dashboards in the notebook and writing logs.
from tensorboard import notebook  # for viewing tensorboard in notebook
from torch.utils.tensorboard import SummaryWriter  # for logging to tensorboard

# HCSDataModule makes it easy to load data during training.
from viscy.light.data import HCSDataModule

# Trainer class and UNet.
from viscy.light.engine import VSTrainer, VSUNet

seed_everything(42, workers=True)

# Paths to data and log directory
data_path = Path(
    "~/data/04_image_translation/HEK_nuclei_membrane_pyramid.zarr/"
).expanduser()

log_dir = Path("~/data/04_image_translation/logs/").expanduser()

# Create log directory if needed, and launch tensorboard
log_dir.mkdir(parents=True, exist_ok=True)

# fmt: off
%reload_ext tensorboard
%tensorboard --logdir {log_dir} --port 6007 --bind_all
# fmt: on

# %% The entire training loop is contained in this cell.

GPU_ID = 0
BATCH_SIZE = 10
YX_PATCH_SIZE = (512, 512)


# Dictionary that specifies key parameters of the model.
phase2fluor_config = {
    "architecture": "2D",
    "num_filters": [24, 48, 96, 192, 384],
    "in_channels": 1,
    "out_channels": 2,
    "residual": True,
    "dropout": 0.1,  # dropout randomly turns off weights to avoid overfitting the model to the data.
    "task": "reg",  # reg = regression task.
}

phase2fluor_model = VSUNet(
    model_config=phase2fluor_config.copy(),
    batch_size=BATCH_SIZE,
    loss_function=torch.nn.functional.l1_loss,
    schedule="WarmupCosine",
    log_num_samples=10,  # Number of samples from each batch to log to tensorboard.
    example_input_yx_shape=YX_PATCH_SIZE,
)

# Reinitialize the data module.
phase2fluor_data = HCSDataModule(
    data_path,
    source_channel="Phase",
    target_channel=["Nuclei", "Membrane"],
    z_window_size=1,
    split_ratio=0.8,
    batch_size=BATCH_SIZE,
    num_workers=8,
    architecture="2D",
    yx_patch_size=YX_PATCH_SIZE,
    augment=True,
)
phase2fluor_data.setup("fit")


# Train for 3 epochs to check that the graph can be logged.
trainer = VSTrainer(
    accelerator="gpu", devices=[GPU_ID], max_epochs=3, default_root_dir=log_dir
)

# The trainer takes the model and the data module as inputs.
trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)

# %% Is example_input_array present?
print(
    f"{phase2fluor_model.example_input_array.shape}, "
    f"{phase2fluor_model.example_input_array.dtype}"
)
trainer.logger.log_graph(phase2fluor_model, phase2fluor_model.example_input_array)
# %%
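
A quick way to confirm that the run wrote something to inspect is to list the TensorBoard event files under the log directory (a sketch, assuming Lightning's default logger layout):

from pathlib import Path

log_dir = Path("~/data/04_image_translation/logs/").expanduser()
# TensorBoard event files are named events.out.tfevents.*; any hits mean the
# logger wrote summaries (including the graph, if graph logging succeeded).
for event_file in sorted(log_dir.rglob("events.out.tfevents.*")):
    print(event_file)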