diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 00000000..6694b2a5 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,14 @@ +coverage: + precision: 2 + round: down + range: "70...100" + + status: + project: yes + patch: no + changes: no + +comment: + layout: "header, reach, diff, flags, files, footer" + behavior: default + require_changes: no diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml new file mode 100644 index 00000000..2900ff96 --- /dev/null +++ b/.github/workflows/pr.yml @@ -0,0 +1,35 @@ +name: Lint and Test + +on: pull_request + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: psf/black@stable + with: + src: "viscy" + options: "--check --verbose" + - uses: chartboost/ruff-action@v1 + with: + src: "viscy" + + test: + needs: [lint] + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10"] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ".[metrics,dev]" + - name: Test with pytest + run: pytest -v diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..ac64f851 --- /dev/null +++ b/.gitignore @@ -0,0 +1,40 @@ +.idea +.DS_Store +__pycache__/ +.ipynb_checkpoints/ +.vscode + +# written by setuptools_scm +*/_version.py + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..013f5fdd --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,19 @@ +# Contributing to viscy + +## Development installation + +Clone or fork the repository, +then make an editable installation with optional dependencies: + +```sh +# in project root directory +pip install ".[dev,metrics]" +``` + +## Testing + +Run tests with `pytest`: + +```sh +pytest +``` diff --git a/LICENSE b/LICENSE index 5de97048..4520d7a9 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ BSD 3-Clause License -Copyright (c) 2023, Computational Microscopy Platform (Mehta Lab), CZ Biohub +Copyright (c) 2023, CZ Biohub SF Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/README.md b/README.md new file mode 100644 index 00000000..a0a02b8b --- /dev/null +++ b/README.md @@ -0,0 +1,41 @@ +# viscy + +viscy is a machine learning toolkit to solve computer vision problems +in high-throughput imaging of cells. + +## Installation + +(Optional) create a new virtual/Conda environment. + +Clone this repository and install viscy: + +```sh +git clone https://github.com/mehta-lab/viscy.git +pip install viscy +``` + +Verify installation by accessing the CLI help message: + +```sh +viscy --help +``` + +For development installation, see [the contributing guide](CONTRIBUTING.md). + +Full functionality is only tested on Linux `x86_64` with NVIDIA Ampere GPUs (CUDA 12.0). +Some features (e.g. mixed precision and distributed training) may not work with other setups, +see [PyTorch documentation](https://pytorch.org) for details. + +## Predicting sub-cellular structure + +Training a model for the segmentation of sub-cellular landmarks +such as nuclei and membrane +directly can require laborious manual annotation. +We use fluorescent markers as a proxy of human-annotated masks +and turn this instance segmentation problem into +an image-to-image translation (I2I) problem. + +viscy features an end-to-end pipeline to design, train and evaluate +I2I models in a declarative manner. +It supports 2D, 2.5D (3D encoder, 2D decoder) and 3D U-Nets, +as well as 3D networks with anisotropic filters. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..fd9ec29b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,51 @@ +[build-system] +requires = ["setuptools", "setuptools-scm[toml]"] +build-backend = "setuptools.build_meta" + +[project] +name = "viscy" +description = "Learning vision for cells" +readme = "README.md" +# cannot build on 3.11 due to https://github.com/cheind/py-lapsolver/pull/18 +requires-python = ">=3.9,!=3.11" +license = { file = "LICENSE" } +authors = [{ name = "CZ Biohub SF", email = "compmicro@czbiohub.org" }] +dependencies = [ + "iohub==0.1.0.dev3", + "torch>=2.0.0", + "torchvision>=0.15.1", + "tensorboard>=2.13.0", + "lightning>=2.0.1", + "monai>=1.2.0", + "jsonargparse[signatures]>=4.20.1", + "scikit-image>=0.19.2", + "matplotlib", +] +dynamic = ["version"] + +[project.optional-dependencies] +metrics = [ + "cellpose==2.1.0", + "lapsolver==1.1.0", + "scikit-learn>=1.1.3", + "scipy>=1.8.0", + "torchmetrics[detection]>=1.0.0", +] +dev = ["pytest", "pytest-cov", "hypothesis", "profilehooks", "onnxruntime"] + +[project.scripts] +viscy = "viscy.cli.cli:main" + +[tool.setuptools_scm] +write_to = "viscy/_version.py" + +[tool.black] +src = ["viscy"] +line-length = 88 + +[tool.ruff] +src = ["viscy", "tests"] +extend-select = ["I001"] + +[tool.ruff.isort] +known-first-party = ["viscy"] diff --git a/tests/ conftest.py b/tests/ conftest.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..d15ea6cf --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for viscy""" diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/cli/preprocess_script_tests.py b/tests/cli/preprocess_script_tests.py new file mode 100644 index 00000000..da1edaed --- /dev/null +++ b/tests/cli/preprocess_script_tests.py @@ -0,0 +1,621 @@ +import os +import unittest +from tempfile import TemporaryDirectory + +import cv2 +import numpy as np +from iohub.ngff import open_ome_zarr + +import viscy.cli.preprocess_script as pp +import viscy.utils.aux_utils as aux_utils + + +class TestPreprocessScript(unittest.TestCase): + def setUp(self): + """ + Set up a directory with some images to resample + """ + self.tempdir = TemporaryDirectory() + self.temp_path = self.tempdir.path + self.image_dir = self.temp_path + self.output_dir = os.path.join(self.temp_path, "out_dir") + self.tempdir.makedir(self.output_dir) + self.zarr_dir = os.path.join(self.temp_path, "zarr_dir") + self.tempdir.makedir(self.zarr_dir) + # Start frames meta file + self.meta_name = "frames_meta.csv" + self.frames_meta = aux_utils.make_dataframe() + # Write images + self.time_idx = 0 + self.pos_ids = [7, 8, 10] + self.channel_ids = [0, 1, 2, 3] + self.slice_ids = [0, 1, 2, 3, 4, 5] + self.im = 1500 * np.ones((30, 20), dtype=np.uint16) + self.im[10:20, 5:15] = 3000 + # Create the same data in zarr format + zarr_writer = open_ome_zarr( + self.zarr_dir, + mode="w-", + layout="hcs", + channel_names=["ch0", "ch1", "ch2", "ch3"], + ) + for p in self.pos_ids: + zarr_writer.create_zarr_root("test_name_pos{}".format(p)) + zarr_writer.init_array( + position=p, + data_shape=(1, 4, 6, 30, 20), + chunk_size=(1, 1, 1, 30, 20), + dtype="uint16", + ) + for c in self.channel_ids: + for z in self.slice_ids: + im_name = aux_utils.get_im_name( + channel_idx=c, + slice_idx=z, + time_idx=self.time_idx, + pos_idx=p, + ) + im = self.im + c * 100 + cv2.imwrite(os.path.join(self.temp_path, im_name), im) + # Write zarr + zarr_writer.write(im, p=p, t=0, c=c, z=z) + # Create metadata + meta_row = aux_utils.parse_idx_from_name( + im_name=im_name, + dir_name=self.image_dir, + ) + meta_row["mean"] = np.nanmean(im) + meta_row["std"] = np.nanstd(im) + self.frames_meta = self.frames_meta.append( + meta_row, + ignore_index=True, + ) + # Write metadata + self.frames_meta.to_csv( + os.path.join(self.image_dir, self.meta_name), + sep=",", + ) + # Make input masks + self.input_mask_channel = 111 + self.input_mask_dir = os.path.join(self.temp_path, "input_mask_dir") + self.tempdir.makedir(self.input_mask_dir) + # Must have at least two foreground classes in mask for weight map to work + mask = np.zeros((30, 20), dtype=np.uint16) + mask[5:10, 5:15] = 1 + mask[20:25, 5:10] = 2 + mask_meta = aux_utils.make_dataframe() + for p in self.pos_ids: + for z in self.slice_ids: + im_name = aux_utils.get_im_name( + channel_idx=self.input_mask_channel, + slice_idx=z, + time_idx=self.time_idx, + pos_idx=p, + ) + cv2.imwrite( + os.path.join(self.input_mask_dir, im_name), + mask, + ) + mask_meta = mask_meta.append( + aux_utils.parse_idx_from_name( + im_name=im_name, + dir_name=self.input_mask_dir, + ), + ignore_index=True, + ) + mask_meta.to_csv( + os.path.join(self.input_mask_dir, self.meta_name), + sep=",", + ) + # Create preprocessing config + self.pp_config = { + "output_dir": self.output_dir, + "input_dir": self.image_dir, + "file_format": "tiff", + "channel_ids": [0, 1, 3], + "num_workers": 4, + "masks": { + "channels": [3], + "str_elem_radius": 3, + }, + "tile": { + "tile_size": [10, 10], + "step_size": [10, 10], + "depths": [1, 1, 1], + "mask_depth": 1, + "image_format": "zyx", + "normalize_channels": [True, True, True], + }, + "normalize": { + "normalize_im": "stack", + }, + } + # Create base config, generated party from pp_config in script + self.base_config = { + "input_dir": self.image_dir, + "output_dir": self.output_dir, + "file_format": "tiff", + "slice_ids": -1, + "time_ids": -1, + "pos_ids": -1, + "channel_ids": self.pp_config["channel_ids"], + "uniform_struct": True, + "int2strlen": 3, + "num_workers": 4, + "normalize_channels": [True, True, True], + } + + def tearDown(self): + """ + Tear down temporary folder and file structure + """ + self.tempdir.cleanup() + + def test_pre_process(self): + out_config, runtime = pp.pre_process(self.pp_config) + self.assertIsInstance(runtime, np.float) + self.assertEqual( + self.base_config["input_dir"], + self.image_dir, + ) + self.assertEqual( + self.base_config["channel_ids"], + self.pp_config["channel_ids"], + ) + self.assertEqual( + out_config["masks"]["mask_dir"], + os.path.join(self.output_dir, "mask_channels_3"), + ) + self.assertEqual( + out_config["tile"]["tile_dir"], + os.path.join(self.output_dir, "tiles_10-10_step_10-10"), + ) + # Make sure new mask channel assignment is correct + self.assertEqual(out_config["masks"]["mask_channel"], 4) + # Check that masks are generated + mask_dir = out_config["masks"]["mask_dir"] + mask_meta = aux_utils.read_meta(mask_dir) + mask_names = os.listdir(mask_dir) + mask_names = [mn for mn in mask_names if "overlay" not in mn] + mask_names.pop(mask_names.index("frames_meta.csv")) + # Validate that all masks are there + self.assertEqual( + len(mask_names), + len(self.slice_ids) * len(self.pos_ids), + ) + for p in self.pos_ids: + for z in self.slice_ids: + im_name = aux_utils.get_im_name( + channel_idx=out_config["masks"]["mask_channel"], + slice_idx=z, + time_idx=self.time_idx, + pos_idx=p, + ) + im = cv2.imread( + os.path.join(mask_dir, im_name), + cv2.IMREAD_ANYDEPTH, + ) + self.assertTupleEqual(im.shape, (30, 20)) + self.assertTrue(im.dtype == "uint8") + self.assertTrue(im_name in mask_names) + self.assertTrue(im_name in mask_meta["file_name"].tolist()) + # Check tiles + tile_dir = out_config["tile"]["tile_dir"] + tile_meta = aux_utils.read_meta(tile_dir) + # 4 processed channels (0, 1, 3, 4), 6 tiles per image + expected_rows = 4 * 6 * len(self.slice_ids) * len(self.pos_ids) + self.assertEqual(tile_meta.shape[0], expected_rows) + # Check indices + self.assertListEqual( + tile_meta.channel_idx.unique().tolist(), + [0, 1, 3, 4], + ) + self.assertListEqual( + tile_meta.pos_idx.unique().tolist(), + self.pos_ids, + ) + self.assertListEqual( + tile_meta.slice_idx.unique().tolist(), + self.slice_ids, + ) + self.assertListEqual( + tile_meta.time_idx.unique().tolist(), + [self.time_idx], + ) + self.assertListEqual( + list(tile_meta), + [ + "channel_idx", + "slice_idx", + "time_idx", + "file_name", + "pos_idx", + "row_start", + "col_start", + "dir_name", + ], + ) + self.assertListEqual( + tile_meta.row_start.unique().tolist(), + [0, 10, 20], + ) + self.assertListEqual( + tile_meta.col_start.unique().tolist(), + [0, 10], + ) + # Read one tile and check format + # r = row start/end idx, c = column start/end, sl = slice start/end + # sl0-1 signifies depth of 1 + im = np.load( + os.path.join( + tile_dir, + "im_c001_z000_t000_p007_r10-20_c10-20_sl0-1.npy", + ) + ) + self.assertTupleEqual(im.shape, (1, 10, 10)) + self.assertTrue(im.dtype == np.float64) + + def test_pre_process_zarr(self): + self.pp_config["input_dir"] = self.zarr_dir + self.pp_config["file_format"] = "zarr" + out_config, runtime = pp.pre_process(self.pp_config) + self.assertIsInstance(runtime, np.float) + self.assertEqual( + self.base_config["channel_ids"], + self.pp_config["channel_ids"], + ) + self.assertEqual( + out_config["masks"]["mask_dir"], + os.path.join(self.output_dir, "mask_channels_3"), + ) + self.assertEqual( + out_config["tile"]["tile_dir"], + os.path.join(self.output_dir, "tiles_10-10_step_10-10"), + ) + # Make sure new mask channel assignment is correct + self.assertEqual(out_config["masks"]["mask_channel"], 4) + # Check that masks are generated + mask_dir = out_config["masks"]["mask_dir"] + mask_meta = aux_utils.read_meta(mask_dir) + mask_names = os.listdir(mask_dir) + mask_names = [mn for mn in mask_names if "overlay" not in mn] + mask_names.pop(mask_names.index("frames_meta.csv")) + # Validate that all masks are there + self.assertEqual( + len(mask_names), + len(self.slice_ids) * len(self.pos_ids), + ) + for p in self.pos_ids: + for z in self.slice_ids: + im_name = aux_utils.get_im_name( + channel_idx=out_config["masks"]["mask_channel"], + slice_idx=z, + time_idx=self.time_idx, + pos_idx=p, + ) + im = cv2.imread( + os.path.join(mask_dir, im_name), + cv2.IMREAD_ANYDEPTH, + ) + self.assertTupleEqual(im.shape, (30, 20)) + self.assertTrue(im.dtype == "uint8") + self.assertTrue(im_name in mask_names) + self.assertTrue(im_name in mask_meta["file_name"].tolist()) + # Check tiles + tile_dir = out_config["tile"]["tile_dir"] + tile_meta = aux_utils.read_meta(tile_dir) + # 4 processed channels (0, 1, 3, 4), 6 tiles per image + expected_rows = 4 * 6 * len(self.slice_ids) * len(self.pos_ids) + self.assertEqual(tile_meta.shape[0], expected_rows) + # Check indices + self.assertListEqual( + tile_meta.channel_idx.unique().tolist(), + [0, 1, 3, 4], + ) + self.assertListEqual( + tile_meta.pos_idx.unique().tolist(), + self.pos_ids, + ) + self.assertListEqual( + tile_meta.slice_idx.unique().tolist(), + self.slice_ids, + ) + self.assertListEqual( + tile_meta.time_idx.unique().tolist(), + [self.time_idx], + ) + self.assertListEqual( + list(tile_meta), + [ + "channel_idx", + "slice_idx", + "time_idx", + "file_name", + "pos_idx", + "row_start", + "col_start", + "dir_name", + ], + ) + self.assertListEqual( + tile_meta.row_start.unique().tolist(), + [0, 10, 20], + ) + self.assertListEqual( + tile_meta.col_start.unique().tolist(), + [0, 10], + ) + # Read one tile and check format + # r = row start/end idx, c = column start/end, sl = slice start/end + # sl0-1 signifies depth of 1 + im = np.load( + os.path.join( + tile_dir, + "im_c001_z000_t000_p007_r10-20_c10-20_sl0-1.npy", + ) + ) + self.assertTupleEqual(im.shape, (1, 10, 10)) + self.assertTrue(im.dtype == np.float64) + + def test_pre_process_no_meta(self): + # Remove frames metadata and make sure it regenerates + os.remove(os.path.join(self.image_dir, "frames_meta.csv")) + out_config, runtime = pp.pre_process(self.pp_config) + self.assertIsInstance(runtime, np.float) + self.assertEqual( + self.base_config["input_dir"], + self.image_dir, + ) + frames_meta = aux_utils.read_meta(self.image_dir) + self.assertTupleEqual(frames_meta.shape, (72, 8)) + + def test_pre_process_intensity_meta(self): + cur_config = self.pp_config + # Use preexisiting masks with more than one class, otherwise + # weight map generation doesn't work + cur_config["normalize"] = { + "normalize_im": "volume", + "min_fraction": 0.1, + } + cur_config["metadata"] = { + "block_size": 10, + } + out_config, runtime = pp.pre_process(cur_config) + intensity_meta = aux_utils.read_meta(self.image_dir, "intensity_meta.csv") + expected_rows = [ + "channel_idx", + "pos_idx", + "slice_idx", + "time_idx", + "channel_name", + "dir_name", + "file_name", + "row_idx", + "col_idx", + "intensity", + "fg_frac", + "zscore_median", + "zscore_iqr", + "intensity_norm", + ] + self.assertListEqual(list(intensity_meta), expected_rows) + + def test_pre_process_weight_maps(self): + cur_config = self.pp_config + # Use preexisiting masks with more than one class, otherwise + # weight map generation doesn't work + cur_config["masks"] = { + "mask_dir": self.input_mask_dir, + "mask_channel": self.input_mask_channel, + } + cur_config["make_weight_map"] = True + out_config, runtime = pp.pre_process(cur_config) + + # Check weights dir + self.assertEqual( + out_config["weights"]["weights_dir"], + os.path.join(self.output_dir, "mask_channels_5"), + ) + weights_meta = aux_utils.read_meta(out_config["weights"]["weights_dir"]) + # Check indices + self.assertListEqual( + weights_meta.channel_idx.unique().tolist(), + [5], + ) + self.assertListEqual( + weights_meta.pos_idx.unique().tolist(), + self.pos_ids, + ) + self.assertListEqual( + weights_meta.slice_idx.unique().tolist(), + self.slice_ids, + ) + self.assertListEqual( + weights_meta.time_idx.unique().tolist(), + [self.time_idx], + ) + # Load one weights file and check contents + im = np.load( + os.path.join( + out_config["weights"]["weights_dir"], + "im_c005_z002_t000_p007.npy", + ) + ) + self.assertTupleEqual(im.shape, (30, 20)) + self.assertTrue(im.dtype == np.float64) + # Check tiles + tile_dir = out_config["tile"]["tile_dir"] + tile_meta = aux_utils.read_meta(tile_dir) + # 5 processed channels (0, 1, 3, 111, 112), 6 tiles per image + expected_rows = 5 * 6 * len(self.slice_ids) * len(self.pos_ids) + self.assertEqual(tile_meta.shape[0], expected_rows) + # Check indices + self.assertListEqual( + tile_meta.channel_idx.unique().tolist(), + [0, 1, 3, 4, 5], + ) + self.assertListEqual( + tile_meta.pos_idx.unique().tolist(), + self.pos_ids, + ) + self.assertListEqual( + tile_meta.slice_idx.unique().tolist(), + self.slice_ids, + ) + self.assertListEqual( + tile_meta.time_idx.unique().tolist(), + [self.time_idx], + ) + # Load a weight tile + im = np.load( + os.path.join( + tile_dir, + "im_c005_z002_t000_p008_r0-10_c10-20_sl0-1.npy", + ) + ) + self.assertTupleEqual(im.shape, (1, 10, 10)) + self.assertTrue(im.dtype == np.float) + # Load one tile + im = np.load( + os.path.join( + tile_dir, + "im_c004_z002_t000_p008_r0-10_c10-20_sl0-1.npy", + ) + ) + self.assertTupleEqual(im.shape, (1, 10, 10)) + self.assertTrue(im.dtype == bool) + + def test_pre_process_resize2d(self): + cur_config = self.pp_config + cur_config["resize"] = { + "scale_factor": 2, + "resize_3d": False, + } + cur_config["make_weight_map"] = False + out_config, runtime = pp.pre_process(cur_config) + + self.assertIsInstance(runtime, np.float) + self.assertEqual( + out_config["resize"]["resize_dir"], + os.path.join(self.output_dir, "resized_images"), + ) + resize_dir = out_config["resize"]["resize_dir"] + # Check that all images have been resized + resize_meta = aux_utils.read_meta(resize_dir) + # 3 resized channels + expected_rows = 3 * len(self.slice_ids) * len(self.pos_ids) + self.assertEqual(resize_meta.shape[0], expected_rows) + # Load an image and make sure it's twice as big + im = cv2.imread( + os.path.join(resize_dir, "im_c003_z002_t000_p010.png"), + cv2.IMREAD_ANYDEPTH, + ) + self.assertTupleEqual(im.shape, (60, 40)) + self.assertTrue(im.dtype, np.uint8) + # There should now be 2*2 the amount of tiles, same shape + tile_dir = out_config["tile"]["tile_dir"] + tile_meta = aux_utils.read_meta(tile_dir) + # 4 processed channels (0, 1, 3, 4), 24 tiles per image + expected_rows = 4 * 24 * len(self.slice_ids) * len(self.pos_ids) + self.assertEqual(tile_meta.shape[0], expected_rows) + # Load a tile and assert shape + im = np.load( + os.path.join( + tile_dir, + "im_c001_z000_t000_p007_r40-50_c20-30_sl0-1.npy", + ) + ) + self.assertTupleEqual(im.shape, (1, 10, 10)) + self.assertTrue(im.dtype == np.float64) + + def test_pre_process_resize3d(self): + cur_config = self.pp_config + cur_config["resize"] = { + "scale_factor": [1, 1.5, 1], + "resize_3d": True, + } + cur_config["tile"] = { + "tile_size": [10, 10], + "step_size": [10, 10], + "depths": [1, 1, 1], + "mask_depth": 1, + "image_format": "zyx", + "normalize_channels": [True, True, True], + } + out_config, runtime = pp.pre_process(cur_config) + + self.assertIsInstance(runtime, np.float) + self.assertEqual( + out_config["resize"]["resize_dir"], + os.path.join(self.output_dir, "resized_images"), + ) + # Load a resized image and assert shape + im_path = os.path.join( + out_config["resize"]["resize_dir"], + "im_c000_z000_t000_p007_1.0-1.5-1.0.npy", + ) + im = np.load(im_path) + # shape should be 30, 20*1.5, z=6) + self.assertTupleEqual(im.shape, (30, 30, 6)) + + self.assertEqual( + out_config["masks"]["mask_dir"], + os.path.join(self.output_dir, "mask_channels_3"), + ) + self.assertEqual(out_config["masks"]["mask_channel"], 4) + + self.assertEqual( + out_config["tile"]["tile_dir"], + os.path.join(self.output_dir, "tiles_10-10_step_10-10"), + ) + im_path = os.path.join( + out_config["tile"]["tile_dir"], + "im_c000_z000_t000_p008_r0-10_c0-10_sl0-6.npy", + ) + # A tile channels first should have shape (6, 10, 10) + tile = np.load(im_path) + self.assertTupleEqual(tile.shape, (6, 10, 10)) + self.assertTrue(tile.dtype == np.float64) + + def test_pre_process_nonisotropic(self): + base_config = self.base_config + base_config["uniform_struct"] = False + out_config, runtime = pp.pre_process(self.pp_config) + + self.assertIsInstance(runtime, np.float) + self.assertEqual( + out_config["masks"]["mask_dir"], + os.path.join(self.output_dir, "mask_channels_3"), + ) + self.assertEqual(out_config["masks"]["mask_channel"], 4) + self.assertEqual( + out_config["tile"]["tile_dir"], + os.path.join(self.output_dir, "tiles_10-10_step_10-10"), + ) + + def test_save_config(self): + cur_config = self.pp_config + cur_config["masks"]["mask_dir"] = os.path.join( + self.output_dir, "mask_channels_3" + ) + cur_config["tile"]["tile_dir"] = os.path.join( + self.output_dir, "tiles_10-10_step_10-10" + ) + pp.save_config(cur_config, 11.1) + # Load json back up + saved_info = aux_utils.read_json( + os.path.join(self.output_dir, "preprocessing_info.json"), + ) + self.assertEqual(len(saved_info), 1) + saved_config = saved_info[0]["config"] + self.assertDictEqual(saved_config, cur_config) + # Save one more config + cur_config["input_dir"] = cur_config["tile"]["tile_dir"] + pp.save_config(cur_config, 666.66) + # Load json back up + saved_info = aux_utils.read_json( + os.path.join(self.output_dir, "preprocessing_info.json"), + ) + self.assertEqual(len(saved_info), 2) + saved_config = saved_info[1]["config"] + self.assertDictEqual(saved_config, cur_config) diff --git a/tests/evaluation/__init__.py b/tests/evaluation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/evaluation/test_evaluation_metrics.py b/tests/evaluation/test_evaluation_metrics.py new file mode 100644 index 00000000..30ef15b3 --- /dev/null +++ b/tests/evaluation/test_evaluation_metrics.py @@ -0,0 +1,103 @@ +import numpy as np +import pytest +import torch +from skimage import data, measure + +from viscy.evaluation.evaluation_metrics import ( + POD_metric, + VOI_metric, + labels_to_detection, + labels_to_masks, + mean_average_precision, +) + + +@pytest.fixture(scope="session") +def labels_numpy() -> tuple[np.ndarray]: + return ( + measure.label(image).astype(np.int16) + for image in (data.binary_blobs(), data.horse() == 0) + ) + + +@pytest.fixture(scope="session") +def labels_tensor(labels_numpy) -> tuple[torch.ShortTensor]: + return (torch.from_numpy(labels) for labels in labels_numpy) + + +def _is_within_unit(metric: float): + return 1 > metric > 0 + + +def test_VOI_metric(labels_numpy): + """Variation of information: smaller is better.""" + for labels in labels_numpy: + assert VOI_metric(labels, labels)[0] == 0 + assert VOI_metric(labels, np.zeros_like(labels))[0] > 0.9 + + +def test_POD_metric(labels_numpy): + """Test POD_metric()""" + for labels in labels_numpy: + ( + true_positives, + false_positives, + false_negatives, + precision, + recall, + f1_score, + ) = POD_metric(labels, labels) + assert true_positives == labels.max() + assert precision == recall == f1_score == 1 + assert false_negatives == 0 + assert false_positives == 0 + wrong_labels = np.copy(labels) + wrong_labels[wrong_labels == 1] = 0 + wrong_labels[0, 0] == 1 + ( + true_positives, + false_positives, + false_negatives, + precision, + recall, + f1_score, + ) = POD_metric(labels, wrong_labels) + for i, metric in enumerate((precision, recall, f1_score)): + assert _is_within_unit(metric), i + assert true_positives < labels.max() + assert false_negatives > 0 + assert false_positives > 0 + + +def test_labels_to_masks(labels_tensor: torch.ShortTensor): + for labels in labels_tensor: + masks = labels_to_masks(labels) + assert masks.shape == (int(labels.max()), *labels.shape) + assert masks.dtype == torch.bool + assert torch.equal(masks[0], labels == 1) + with pytest.raises(ValueError): + _ = labels_to_masks(torch.randint(0, 5, (3, 32, 32), dtype=torch.short)) + + +def test_labels_to_masks_more_dims(): + with pytest.raises(ValueError): + _ = labels_to_masks(torch.randint(0, 42, (1, 4, 4), dtype=torch.short)) + + +def test_labels_to_detection(labels_tensor: torch.ShortTensor): + for labels in labels_tensor: + num_boxes = int(labels.max()) + detection = labels_to_detection(labels) + assert set(detection) == {"boxes", "scores", "labels", "masks"} + assert detection["boxes"].shape == (num_boxes, 4) + assert detection["scores"].shape == (num_boxes,) + assert detection["labels"].shape == (num_boxes,) + assert detection["masks"].device == labels.device + + +def test_mean_average_precision(labels_tensor: torch.ShortTensor): + for labels in labels_tensor: + coco_metrics = mean_average_precision(labels, labels) + assert coco_metrics["map"] == 1 + assert _is_within_unit(coco_metrics["mar_1"]) + assert coco_metrics["mar_10"] == 1 diff --git a/tests/preprocessing/__init__.py b/tests/preprocessing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/preprocessing/generate_masks_tests.py b/tests/preprocessing/generate_masks_tests.py new file mode 100644 index 00000000..45f2e4f4 --- /dev/null +++ b/tests/preprocessing/generate_masks_tests.py @@ -0,0 +1,214 @@ +import os +import unittest +import warnings + +import nose.tools +import numpy as np +import numpy.testing +import pandas as pd +import skimage.io as sk_im_io +from testfixtures import TempDirectory + +from viscy.preprocessing.generate_masks import MaskProcessor +from viscy.utils import aux_utils as aux_utils + + +class TestMaskProcessor(unittest.TestCase): + def setUp(self): + """Set up a directory for mask generation,""" + + self.tempdir = TempDirectory() + self.temp_path = self.tempdir.path + self.meta_fname = "frames_meta.csv" + frames_meta = aux_utils.make_dataframe() + + # create an image with bimodal hist + x = np.linspace(-4, 4, 32) + y = x.copy() + z = np.linspace(-3, 3, 8) + xx, yy, zz = np.meshgrid(x, y, z) + sph = xx**2 + yy**2 + zz**2 + fg = (sph <= 8) * (8 - sph) + fg[fg > 1e-8] = (fg[fg > 1e-8] / np.max(fg)) * 127 + 128 + fg = np.around(fg).astype("uint8") + bg = np.around((sph > 8) * sph).astype("uint8") + object1 = fg + bg + + # create an image with a rect + rec = np.zeros(sph.shape) + rec[3:30, 14:18, 3:6] = 120 + rec[14:18, 3:30, 3:6] = 120 + + self.sph_object = object1 + self.rec_object = rec + + self.channel_ids = [1, 2] + self.time_ids = 0 + self.pos_ids = 1 + self.int2str_len = 3 + + for z in range(sph.shape[2]): + im_name = aux_utils.get_im_name( + time_idx=self.time_ids, + channel_idx=1, + slice_idx=z, + pos_idx=self.pos_ids, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + sk_im_io.imsave( + os.path.join(self.temp_path, im_name), + object1[:, :, z].astype("uint8"), + ) + frames_meta = frames_meta.append( + aux_utils.parse_idx_from_name(im_name=im_name, dir_name=self.temp_path), + ignore_index=True, + ) + for z in range(rec.shape[2]): + im_name = aux_utils.get_im_name( + time_idx=self.time_ids, + channel_idx=2, + slice_idx=z, + pos_idx=self.pos_ids, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + sk_im_io.imsave( + os.path.join(self.temp_path, im_name), + rec[:, :, z].astype("uint8"), + ) + frames_meta = frames_meta.append( + aux_utils.parse_idx_from_name(im_name=im_name, dir_name=self.temp_path), + ignore_index=True, + ) + # Write metadata + frames_meta.to_csv(os.path.join(self.temp_path, self.meta_fname), sep=",") + + self.output_dir = os.path.join(self.temp_path, "mask_dir") + self.mask_gen_inst = MaskProcessor( + input_dir=self.temp_path, + output_dir=self.output_dir, + channel_ids=self.channel_ids, + ) + + def tearDown(self): + """Tear down temporary folder and file structure""" + TempDirectory.cleanup_all() + self.assertFalse(os.path.isdir(self.temp_path)) + + def test_init(self): + """Test init""" + self.assertEqual(self.mask_gen_inst.input_dir, self.temp_path) + self.assertEqual(self.mask_gen_inst.output_dir, self.output_dir) + nose.tools.assert_equal(self.mask_gen_inst.mask_channel, 3) + nose.tools.assert_equal( + self.mask_gen_inst.mask_dir, + os.path.join(self.output_dir, "mask_channels_1-2"), + ) + self.assertListEqual(self.channel_ids, self.channel_ids) + nose.tools.assert_equal(self.mask_gen_inst.nested_id_dict, None) + + def test_get_mask_dir(self): + """Test get_mask_dir""" + mask_dir = os.path.join(self.output_dir, "mask_channels_1-2") + nose.tools.assert_equal(self.mask_gen_inst.get_mask_dir(), mask_dir) + + def test_get_mask_channel(self): + """Test get_mask_channel""" + nose.tools.assert_equal(self.mask_gen_inst.get_mask_channel(), 3) + + def test_generate_masks_uni(self): + """Test generate masks""" + self.mask_gen_inst.generate_masks(str_elem_radius=1) + frames_meta = pd.read_csv( + os.path.join(self.mask_gen_inst.get_mask_dir(), "frames_meta.csv"), + index_col=0, + ) + # 8 slices and 3 channels + exp_len = 8 + nose.tools.assert_equal(len(frames_meta), exp_len) + for idx in range(exp_len): + nose.tools.assert_equal( + "im_c003_z00{}_t000_p001.npy".format(idx), + frames_meta.iloc[idx]["file_name"], + ) + + def test_generate_masks_nonuni(self): + """Test generate_masks with non-uniform structure""" + rec = self.rec_object[:, :, 3:6] + channel_ids = 0 + time_ids = 0 + pos_ids = [1, 2] + frames_meta = aux_utils.make_dataframe() + + for z in range(self.sph_object.shape[2]): + im_name = aux_utils.get_im_name( + time_idx=time_ids, + channel_idx=channel_ids, + slice_idx=z, + pos_idx=pos_ids[0], + ) + sk_im_io.imsave( + os.path.join(self.temp_path, im_name), + self.sph_object[:, :, z].astype("uint8"), + ) + frames_meta = frames_meta.append( + aux_utils.parse_idx_from_name(im_name=im_name, dir_name=self.temp_path), + ignore_index=True, + ) + for z in range(rec.shape[2]): + im_name = aux_utils.get_im_name( + time_idx=time_ids, + channel_idx=channel_ids, + slice_idx=z, + pos_idx=pos_ids[1], + ) + sk_im_io.imsave( + os.path.join(self.temp_path, im_name), rec[:, :, z].astype("uint8") + ) + frames_meta = frames_meta.append( + aux_utils.parse_idx_from_name(im_name=im_name, dir_name=self.temp_path), + ignore_index=True, + ) + # Write metadata + frames_meta.to_csv(os.path.join(self.temp_path, self.meta_fname), sep=",") + + self.output_dir = os.path.join(self.temp_path, "mask_dir") + mask_gen_inst = MaskProcessor( + input_dir=self.temp_path, + output_dir=self.output_dir, + channel_ids=channel_ids, + uniform_struct=False, + ) + exp_nested_id_dict = {0: {0: {1: [0, 1, 2, 3, 4, 5, 6, 7], 2: [0, 1, 2]}}} + numpy.testing.assert_array_equal( + mask_gen_inst.nested_id_dict[0][0][1], exp_nested_id_dict[0][0][1] + ) + numpy.testing.assert_array_equal( + mask_gen_inst.nested_id_dict[0][0][2], exp_nested_id_dict[0][0][2] + ) + + mask_gen_inst.generate_masks(str_elem_radius=1) + + frames_meta = pd.read_csv( + os.path.join(mask_gen_inst.get_mask_dir(), "frames_meta.csv"), + index_col=0, + ) + # pos1: 8 slices, pos2: 3 slices + exp_len = 8 + 3 + nose.tools.assert_equal(len(frames_meta), exp_len) + mask_fnames = frames_meta["file_name"].tolist() + exp_mask_fnames = [ + "im_c001_z000_t000_p001.npy", + "im_c001_z000_t000_p002.npy", + "im_c001_z001_t000_p001.npy", + "im_c001_z001_t000_p002.npy", + "im_c001_z002_t000_p001.npy", + "im_c001_z002_t000_p002.npy", + "im_c001_z003_t000_p001.npy", + "im_c001_z004_t000_p001.npy", + "im_c001_z005_t000_p001.npy", + "im_c001_z006_t000_p001.npy", + "im_c001_z007_t000_p001.npy", + ] + nose.tools.assert_list_equal(mask_fnames, exp_mask_fnames) diff --git a/tests/preprocessing/resize_images_tests.py b/tests/preprocessing/resize_images_tests.py new file mode 100644 index 00000000..835f3de4 --- /dev/null +++ b/tests/preprocessing/resize_images_tests.py @@ -0,0 +1,188 @@ +import os +import unittest + +import cv2 +import numpy as np +import pandas as pd +from testfixtures import TempDirectory + +import viscy.preprocessing.resize_images as resize_images +import viscy.utils.aux_utils as aux_utils + + +class TestResizeImages(unittest.TestCase): + def setUp(self): + """ + Set up a directory with some images to resample + """ + self.tempdir = TempDirectory() + self.temp_path = self.tempdir.path + self.output_dir = os.path.join(self.temp_path, "out_dir") + # Start frames meta file + self.meta_name = "frames_meta.csv" + self.frames_meta = aux_utils.make_dataframe() + # Write images + self.time_idx = 5 + self.slice_idx = 6 + self.pos_idx = 7 + self.im = 1500 * np.ones((30, 20), dtype=np.uint16) + + for c in range(4): + for p in range(self.pos_idx, self.pos_idx + 2): + im_name = aux_utils.get_im_name( + channel_idx=c, + slice_idx=self.slice_idx, + time_idx=self.time_idx, + pos_idx=p, + ) + cv2.imwrite(os.path.join(self.temp_path, im_name), self.im + c * 100) + self.frames_meta = self.frames_meta.append( + aux_utils.parse_idx_from_name( + im_name=im_name, dir_name=self.temp_path + ), + ignore_index=True, + ) + # Write metadata + self.frames_meta.to_csv( + os.path.join(self.temp_path, self.meta_name), + sep=",", + ) + + def tearDown(self): + """ + Tear down temporary folder and file structure + """ + TempDirectory.cleanup_all() + nose.tools.assert_equal(os.path.isdir(self.temp_path), False) + + def test_downsample(self): + # Half the image size + scale_factor = 0.5 + resize_inst = resize_images.ImageResizer( + input_dir=self.temp_path, + output_dir=self.output_dir, + scale_factor=scale_factor, + ) + self.assertEqual(resize_inst.time_ids, self.time_idx) + self.assertListEqual(resize_inst.channel_ids.tolist(), [0, 1, 2, 3]) + self.assertEqual(resize_inst.slice_ids, self.slice_idx) + self.assertListEqual(resize_inst.pos_ids.tolist(), [7, 8]) + resize_dir = resize_inst.get_resize_dir() + self.assertEqual(os.path.join(self.output_dir, "resized_images"), resize_dir) + # Resize + resize_inst.resize_frames() + # Validate + new_shape = tuple([int(scale_factor * x) for x in self.im.shape]) + for i, row in self.frames_meta.iterrows(): + file_name = os.path.join(resize_dir, row["file_name"]) + im = cv2.imread(file_name, cv2.IMREAD_ANYDEPTH) + self.assertTupleEqual(new_shape, im.shape) + self.assertEqual(im.dtype, self.im.dtype) + im_expected = self.im + row["channel_idx"] * 100 + im_expected = cv2.resize(im_expected, (new_shape[1], new_shape[0])) + np.testing.assert_array_equal(im, im_expected) + + def test_upsample(self): + # Half the image size + scale_factor = 2.0 + resize_inst = resize_images.ImageResizer( + input_dir=self.temp_path, + output_dir=self.output_dir, + scale_factor=scale_factor, + ) + self.assertEqual(resize_inst.time_ids, self.time_idx) + self.assertListEqual(resize_inst.channel_ids.tolist(), [0, 1, 2, 3]) + self.assertEqual(resize_inst.slice_ids, self.slice_idx) + self.assertListEqual(resize_inst.pos_ids.tolist(), [7, 8]) + resize_dir = resize_inst.get_resize_dir() + self.assertEqual(os.path.join(self.output_dir, "resized_images"), resize_dir) + # Resize + resize_inst.resize_frames() + # Validate + new_shape = tuple([int(scale_factor * x) for x in self.im.shape]) + for i, row in self.frames_meta.iterrows(): + file_name = os.path.join(resize_dir, row["file_name"]) + im = cv2.imread(file_name, cv2.IMREAD_ANYDEPTH) + self.assertTupleEqual(new_shape, im.shape) + self.assertEqual(im.dtype, self.im.dtype) + im_expected = self.im + row["channel_idx"] * 100 + im_expected = cv2.resize(im_expected, (new_shape[1], new_shape[0])) + np.testing.assert_array_equal(im, im_expected) + + def test_resize_volumes(self): + """Test resizing volumes""" + + # set up a volume with 5 slices, 2 channels + slice_ids = [0, 1, 2, 3, 4] + channel_ids = [2, 3] + resize_dir = os.path.join(self.output_dir, "resized_images") + frames_meta = aux_utils.make_dataframe() + exp_meta_dict = [] + for c in channel_ids: + for s in slice_ids: + im_name = aux_utils.get_im_name( + channel_idx=c, + slice_idx=s, + time_idx=self.time_idx, + pos_idx=self.pos_idx, + ) + cv2.imwrite(os.path.join(self.temp_path, im_name), self.im + c * 100) + frames_meta = frames_meta.append( + aux_utils.parse_idx_from_name( + im_name=im_name, dir_name=self.temp_path + ), + ignore_index=True, + ) + op_fname = "im_c00{}_z000_t005_p007_3.3-0.8-1.0.npy".format(c) + exp_meta_dict.append( + { + "time_idx": self.time_idx, + "pos_idx": self.pos_idx, + "channel_idx": c, + "slice_idx": 0, + "file_name": op_fname, + "mean": np.mean(self.im) + c * 100, + "std": float(0), + "dir_name": resize_dir, + } + ) + exp_meta_df = pd.DataFrame.from_dict(exp_meta_dict) + # Write metadata + frames_meta.to_csv( + os.path.join(self.temp_path, self.meta_name), + sep=",", + ) + + scale_factor = [3.3, 0.8, 1.0] + resize_inst = resize_images.ImageResizer( + input_dir=self.temp_path, + output_dir=self.output_dir, + scale_factor=scale_factor, + ) + + # save all slices in one volume + resize_inst.resize_volumes() + saved_meta = aux_utils.read_meta(resize_dir) + pd.testing.assert_frame_equal(saved_meta, exp_meta_df) + + # num_slices_subvolume = 3, save vol chunks + exp_meta_dict = [] + for c in channel_ids: + for s in [0, 2]: + op_fname = "im_c00{}_z00{}_t005_p007_3.3-0.8-1.0.npy".format(c, s) + exp_meta_dict.append( + { + "time_idx": self.time_idx, + "pos_idx": self.pos_idx, + "channel_idx": c, + "slice_idx": s, + "file_name": op_fname, + "mean": np.mean(self.im) + c * 100, + "std": float(0), + "dir_name": resize_dir, + } + ) + exp_meta_df = pd.DataFrame.from_dict(exp_meta_dict) + resize_inst.resize_volumes(num_slices_subvolume=3) + saved_meta = aux_utils.read_meta(resize_dir) + pd.testing.assert_frame_equal(saved_meta, exp_meta_df) diff --git a/tests/torch_unet/networks/Unet25D_tests.py b/tests/torch_unet/networks/Unet25D_tests.py new file mode 100644 index 00000000..d190474f --- /dev/null +++ b/tests/torch_unet/networks/Unet25D_tests.py @@ -0,0 +1,270 @@ +import collections +import itertools +import unittest + +import numpy as np +import torch + +import viscy.utils.cli_utils as io_utils +from viscy.unet.networks.Unet25D import Unet25d + + +class TestUnet25d(unittest.TestCase): + """ + Testing class for all configurations of the 2.5D Unet architecture + Functionality of core PyTorch and nummpy operations assumed to be + complete and sound. + """ + + def SetUp(self): + """ + Set up inputs and block configurations + """ + # possible inputs and output shapes + self.pass_inputs = { + "standard": [torch.ones((1, 1, 5, 256, 256)), (1, 1, 1, 256, 256)], + "batch": [torch.ones((3, 1, 5, 256, 256)), (3, 1, 1, 256, 256)], + "multichannel": [torch.ones((1, 1, 5, 256, 256)), (1, 2, 5, 256, 256)], + "multichannel_flat": [torch.ones((1, 2, 5, 256, 256)), (1, 2, 5, 256, 256)], + } + self.fail_inputs = { + "nonsquare": [torch.ones((1, 1, 5, 128, 256)), (1, 1, 1, 128, 256)], + "nonsquare_arbitrary": [ + torch.ones((1, 1, 5, 128, 316)), + (1, 1, 1, 128, 316), + ], + "wrong_dims": [torch.ones((1, 1, 1, 1)), (1, 1, 1, 1)], + } + # possible configurations + self.configs = { + "xy_kernel_size": ((1, 1), (3, 5), (3, 3)), + "residual": (True, False), + "dropout": (False, 0.25), + "num_blocks": (1, 2, 4), + "num_block_layers": (1, 3), # True yields padding error in pytorch 1.10 + "num_filters": ([],), + "task": ("reg", "seg"), + } + + def _get_outputs(self, kwargs): + """ + Template testing class + + :param list kwargs: list of arguments for 25D Unet object + + :return numpy.ndarray inputs: inputs to Unet + :return numpy.ndarray outputs: outputs from Unet, respective + :return tuple exp_out: expected output + """ + input_, exp_out_shape = ( + self.pass_inputs["standard"][0], + self.pass_inputs["standard"][1], + ) + + in_channels = input_.shape[1] + out_channels = exp_out_shape[1] + + in_depth = input_.shape[2] + out_depth = exp_out_shape[2] + + network = Unet25d(in_channels, out_channels, in_depth, out_depth, *kwargs) + + try: + output = network(input_) + input_, output = input_.detach().numpy(), output.detach().numpy() + exp_out = output + return input_, output, exp_out + except Exception as e: + self.excep = e + input_.detach().numpy() + return input_, np.ones((1, 1)), np.zeros((1, 1)) + + def _get_output_shapes(self, kwargs, pass_): + """ + Gets outputs for all inputs of type 'pass_' + + If inputs expected to fail, exp_out_shape will be False + + :param list kwargs: list of arguments for Unet25d object + :param boolean pass_: whether inputs are expected to pass tests + + :return list inputs: list of inputs to Unet + :return list outputs: list of outputs from Unet, respective + :return list exp_out_shapes: list of expected output shapes from + Unet, respective + """ + inputs, outputs, exp_out_shapes = [], [], [] + test_inputs = self.pass_inputs if pass_ else self.fail_inputs + for test in test_inputs: + input_, exp_out_shape = test_inputs[test][0], test_inputs[test][1] + + in_channels = input_.shape[1] + out_channels = exp_out_shape[1] + + in_depth = input_.shape[2] + out_depth = exp_out_shape[2] + + network = Unet25d(in_channels, out_channels, in_depth, out_depth, *kwargs) + + try: + output = network(input_) + inputs.append(input_) + outputs.append(output) + exp_out_shapes.append(exp_out_shape) + except Exception as e: + self.excep = e + inputs.append(input_) + outputs.append(False) + exp_out_shapes.append(exp_out_shape if pass_ else False) + + return inputs, outputs, exp_out_shapes + + def _get_residual_params(self, kwargs, resid_index): + """ + Gets parameters of residual and nonresidual networks + + :param list kwargs: list of arguments for Unet25d object + :param int resid_index: index of residual parameter in kwargs + + :return nn.module.parameter params: trainable parameters of non-residual block + :return nn.module.parameter resid_params: trainable parameters of residual block + """ + input_, exp_out_shape = ( + self.pass_inputs["standard"][0], + self.pass_inputs["standard"][1], + ) + + in_channels = input_.shape[1] + out_channels = exp_out_shape[1] + + in_depth = input_.shape[2] + out_depth = exp_out_shape[2] + + resid_kwargs, kwargs = list(kwargs), list(kwargs) + kwargs[resid_index] = False + resid_kwargs[resid_index] = True + + try: + network = Unet25d(in_channels, out_channels, in_depth, out_depth, *kwargs) + resid_network = Unet25d( + in_channels, out_channels, in_depth, out_depth, *resid_kwargs + ) + + return network.parameters(), resid_network.parameters() + except Exception as e: + self.excep = e + return None, None + + def _all_test_configurations(self, test, verbose=True): + """ + Run specified test on all possible 25D Unet input configurations. + Send failure information to stdout. + + Current tests: + - Initialization and input->output for cartesian product of parameters + - shape matching (single-channel, multi-channel) + - residual (same number of trainable params) + - kernel shapes (nonsquare doesnt break functionality) + + :param str test: which test to run. Must be within {'passing', 'failing', 'residual'} + :param bool verbose: Verbosity of str output + """ + self.SetUp() + + configs_list = [self.configs[key] for key in self.configs] + configs_list = list(itertools.product(*configs_list)) + failed_tests = collections.defaultdict(lambda: []) + + print("Testing", len(configs_list), "configurations:") if verbose else None + + for i, args in enumerate(configs_list): + if test == "passing": + # test passing shapes + _, outputs, exp_out_shapes = self._get_output_shapes(args, True) + out_shapes = [ + ar.detach().numpy().shape if isinstance(ar, torch.Tensor) else ar + for ar in outputs + ] + try: + out_shapes = np.array(out_shapes, dtype=object) + exp_out_shapes = np.array(exp_out_shapes, dtype=object) + fail_message = ( + f"'Passing' input tests failed on config {i+1} \n args: {args}" + ) + np.testing.assert_array_equal( + out_shapes, exp_out_shapes, fail_message + ) + except: + failed_tests[i].append(args) + failed_tests[i].append(self.excep) + elif test == "failing": + # test failing shapes + _, outputs, exp_out_shapes = self._get_output_shapes(args, False) + out_shapes = [ + ar.detach().numpy().shape if isinstance(ar, torch.Tensor) else ar + for ar in outputs + ] + try: + out_shapes = np.array(out_shapes, dtype=object) + exp_out_shapes = np.array(exp_out_shapes, dtype=object) + fail_message = ( + f"\t'Failing' tests failed on config {i+1} \n args: {args}" + ) + np.testing.assert_array_equal( + out_shapes, exp_out_shapes, fail_message + ) + except: + failed_tests[i].append(args) + failed_tests[i].append(self.excep) + elif test == "residual": + # test residual + resid_index = 2 + if args[resid_index] == False: + params, resid_params = self._get_residual_params(args, resid_index) + try: + fail_message = f"\t Residual params tests failed on config {i+1} \n args: {args}" + np.testing.assert_equal( + len(list(params)), len(list(resid_params)), fail_message + ) + except: + failed_tests[i].append(args) + failed_tests[i].append(self.excep) + + io_utils.show_progress_bar(configs_list, i, process="testing", interval=10) + if verbose: + print( + f"Testing complete! {len(configs_list)-len(failed_tests)}/{len(configs_list)} passed." + ) + if len(failed_tests) > 0: + print(f"Failed messages:") + for key in failed_tests: + print(f"Config {key}: {failed_tests[key]}") + + # -------------- Tests -----------------# + + def test_residual(self): + """ + Test residual functionality 25D Unet + + Test that residual blocks do not contain additional parameters + Runs test with every possible block configuration. + """ + self._all_test_configurations(test="residual") + + def test_passing(self): + """ + Test passing input functionality 25D Unet + + Test input-output functionality and expected output shape of all passing input shapes. + Runs test with every possible block configuration. + """ + self._all_test_configurations(test="passing") + + def test_failing(self): + """ + Test failing input handling 25D Unet + + Checks to see if all failing input types are caught by conv block. + Runs test with every possible block configuration. + """ + self._all_test_configurations(test="failing") diff --git a/tests/torch_unet/networks/Unet2D_tests.py b/tests/torch_unet/networks/Unet2D_tests.py new file mode 100644 index 00000000..70c6cc3d --- /dev/null +++ b/tests/torch_unet/networks/Unet2D_tests.py @@ -0,0 +1,256 @@ +import collections +import itertools +import unittest + +import numpy as np +import torch + +import viscy.utils.cli_utils as io_utils +from viscy.unet.networks.Unet2D import Unet2d + + +class TestUnet2d(unittest.TestCase): + """ + Testing class for all configurations of the 2D Unet implementaion + Functionality of core PyTorch and nummpy operations assumed to be + complete and sound. + """ + + def SetUp(self): + """ + Set up inputs and block configurations + """ + # possible inputs and output shapes + self.pass_inputs = { + "standard": [torch.ones((1, 1, 256, 256)), (1, 1, 256, 256)], + "batch": [torch.ones((3, 1, 256, 256)), (3, 1, 256, 256)], + "multichannel": [torch.ones((1, 1, 256, 256)), (1, 2, 256, 256)], + "multichannel_flat": [torch.ones((1, 2, 256, 256)), (1, 2, 256, 256)], + } + self.fail_inputs = { + "nonsquare": [torch.ones((1, 1, 128, 256)), (1, 1, 128, 256)], + "nonsquare_arbitrary": [torch.ones((1, 1, 128, 316)), (1, 1, 128, 316)], + "wrong_dims": [torch.ones((1, 1, 1)), (1, 1, 1)], + } + # possible configurations + self.configs = { + "xy_kernel_size": ((1, 1), (3, 5), (3, 3)), + "residual": (True, False), + "dropout": (False, 0.25), + "num_blocks": (1, 2, 4), + "num_block_layers": (1, 3), # True yields padding error in pytorch 1.10 + "num_filters": ([],), + "task": ("reg", "seg"), + } + + def _get_outputs(self, kwargs): + """ + Template testing class + + :param list kwargs: list of arguments for Unet object + + :return numpy.ndarray inputs: inputs to Unet + :return numpy.ndarray outputs: outputs from Unet, respective + :return tuple exp_out: expected output + """ + input_, exp_out_shape = ( + self.pass_inputs["standard"][0], + self.pass_inputs["standard"][1], + ) + + in_channels = input_.shape[1] + out_channels = exp_out_shape[1] + + network = Unet2d(in_channels, out_channels, *kwargs) + + try: + output = network(input_) + input_, output = input_.detach().numpy(), output.detach().numpy() + exp_out = output + return input_, output, exp_out + except Exception as e: + self.excep = e + input_.detach().numpy() + return input_, np.ones((1, 1)), np.zeros((1, 1)) + + def _get_output_shapes(self, kwargs, pass_): + """ + Gets outputs for all inputs of type 'pass_' + + If inputs expected to fail, exp_out_shape will be False + + :param list kwargs: list of arguments for Unet2d object + :param boolean pass_: whether inputs are expected to pass tests + + :return list inputs: list of inputs to Unet + :return list outputs: list of outputs from Unet, respective + :return list exp_out_shapes: list of expected output shapes from + Unet, respective + """ + inputs, outputs, exp_out_shapes = [], [], [] + test_inputs = self.pass_inputs if pass_ else self.fail_inputs + for test in test_inputs: + input_, exp_out_shape = test_inputs[test][0], test_inputs[test][1] + + in_channels = input_.shape[1] + out_channels = exp_out_shape[1] + + network = Unet2d(in_channels, out_channels, *kwargs) + + try: + output = network(input_) + inputs.append(input_) + outputs.append(output) + exp_out_shapes.append(exp_out_shape) + except Exception as e: + self.excep = e + inputs.append(input_) + outputs.append(False) + exp_out_shapes.append(exp_out_shape if pass_ else False) + + return inputs, outputs, exp_out_shapes + + def _get_residual_params(self, kwargs, resid_index): + """ + Gets parameters of residual and nonresidual networks + + :param list kwargs: list of arguments for Unet2d object + :param int resid_index: index of residual parameter in kwargs + + :return nn.module.parameter params: trainable parameters of non-residual block + :return nn.module.parameter resid_params: trainable parameters of residual block + """ + input_, exp_out_shape = ( + self.pass_inputs["standard"][0], + self.pass_inputs["standard"][1], + ) + + in_channels = input_.shape[1] + out_channels = exp_out_shape[1] + + resid_kwargs, kwargs = list(kwargs), list(kwargs) + kwargs[resid_index] = False + resid_kwargs[resid_index] = True + + try: + network = Unet2d(in_channels, out_channels, *kwargs) + resid_network = Unet2d(in_channels, out_channels, *resid_kwargs) + + return network.parameters(), resid_network.parameters() + except Exception as e: + self.excep = e + return None, None + + def _all_test_configurations(self, test, verbose=True): + """ + Run specified test on all possible 2d Unet input configurations. + Send failure information to stdout. + + Current tests: + - Initialization and input->output for cartesian product of parameters + - shape matching (single-channel, multi-channel) + - residual (same number of trainable params) + - kernel shapes (nonsquare doesnt break functionality) + + :param str test: which test to run. Must be within {'passing', 'failing', 'residual'} + :param bool verbose: Verbosity of str output + """ + self.SetUp() + + configs_list = [self.configs[key] for key in self.configs] + configs_list = list(itertools.product(*configs_list)) + failed_tests = collections.defaultdict(lambda: []) + + print("Testing", len(configs_list), "configurations:") if verbose else None + + for i, args in enumerate(configs_list): + if test == "passing": + # test passing shapes + _, outputs, exp_out_shapes = self._get_output_shapes(args, True) + out_shapes = [ + ar.detach().numpy().shape if isinstance(ar, torch.Tensor) else ar + for ar in outputs + ] + try: + out_shapes = np.array(out_shapes, dtype=object) + exp_out_shapes = np.array(exp_out_shapes, dtype=object) + fail_message = ( + f"'Passing' input tests failed on config {i+1} \n args: {args}" + ) + np.testing.assert_array_equal( + out_shapes, exp_out_shapes, fail_message + ) + except: + failed_tests[i].append(args) + failed_tests[i].append(self.excep) + elif test == "failing": + # test failing shapes + _, outputs, exp_out_shapes = self._get_output_shapes(args, False) + out_shapes = [ + ar.detach().numpy().shape if isinstance(ar, torch.Tensor) else ar + for ar in outputs + ] + try: + out_shapes = np.array(out_shapes, dtype=object) + exp_out_shapes = np.array(exp_out_shapes, dtype=object) + fail_message = ( + f"\t'Failing' tests failed on config {i+1} \n args: {args}" + ) + np.testing.assert_array_equal( + out_shapes, exp_out_shapes, fail_message + ) + except: + failed_tests[i].append(args) + failed_tests[i].append(self.excep) + elif test == "residual": + # test residual + resid_index = 2 + if args[resid_index] == False: + params, resid_params = self._get_residual_params(args, resid_index) + try: + fail_message = f"\t Residual params tests failed on config {i+1} \n args: {args}" + np.testing.assert_equal( + len(list(params)), len(list(resid_params)), fail_message + ) + except: + failed_tests[i].append(args) + failed_tests[i].append(self.excep) + + io_utils.show_progress_bar(configs_list, i, process="testing", interval=10) + if verbose: + print( + f"Testing complete! {len(configs_list)-len(failed_tests)}/{len(configs_list)} passed." + ) + if len(failed_tests) > 0: + print(f"Failed messages:") + for key in failed_tests: + print(f"Config {key}: {failed_tests[key]}") + + # -------------- Tests -----------------# + + def test_residual(self): + """ + Test residual functionality 2D Unet + + Test that residual blocks do not contain additional parameters + Runs test with every possible block configuration. + """ + self._all_test_configurations(test="residual") + + def test_passing(self): + """ + Test passing input functionality 2D Unet + + Test input-output functionality and expected output shape of all passing input shapes. + Runs test with every possible block configuration. + """ + self._all_test_configurations(test="passing") + + def test_failing(self): + """ + Test failing input handling 2D Unet + + Checks to see if all failing input types are caught by conv block. + Runs test with every possible block configuration. + """ + self._all_test_configurations(test="failing") diff --git a/tests/torch_unet/networks/layers/ConvBlock2D_tests.py b/tests/torch_unet/networks/layers/ConvBlock2D_tests.py new file mode 100644 index 00000000..59aabd8b --- /dev/null +++ b/tests/torch_unet/networks/layers/ConvBlock2D_tests.py @@ -0,0 +1,255 @@ +import collections +import itertools +import unittest + +import numpy as np +import torch + +import viscy.utils.cli_utils as io_utils +from viscy.unet.networks.layers.ConvBlock2D import ConvBlock2D + + +class TestConvBlock2D(unittest.TestCase): + """ + Testing class for all configurations of the 2d conv block + Functionality of core PyTorch and nummpy operations assumed to be + complete and sound. + """ + + def SetUp(self): + """ + Set up inputs and block configurations + """ + # possible inputs and output shapes + self.pass_inputs = { + "standard": [torch.ones((1, 1, 256, 256)), (1, 4, 256, 256)], + "down": [torch.ones((1, 8, 16, 16)), (1, 4, 16, 16)], + "batch": [torch.ones((8, 1, 16, 16)), (8, 4, 16, 16)], + "small": [torch.ones((1, 1, 8, 8)), (1, 4, 8, 8)], + } + self.fail_inputs = { + "nonsquare": [torch.ones((1, 1, 16, 8)), (1, 1, 16, 8)], + "wrong_dims": [torch.ones((1, 1, 1)), (1, 1, 1)], + } + # possible configurations + self.configs = { + "dropout": (False, 0.25), + "norm": ("batch", "instance"), + "residual": (False, True), + "activation": ("relu", "leakyrelu", "selu"), + "transpose": [False], # True yields padding error in pytorch 1.10 + "kernel_size": (1, (3, 3), (5, 3)), + "num_layers": (1, 5), + "filter_steps": ("linear", "first", "last"), + } + + def _get_outputs(self, kwargs): + """ + Template testing class + + :param list kwargs: list of arguments for ConvBlock2D object + + :return numpy.ndarray inputs: inputs to convblock + :return numpy.ndarray outputs: outputs from convblock, respective + :return tuple exp_out: expected output + """ + input_, exp_out_shape = ( + self.pass_inputs["standard"][0], + self.pass_inputs["standard"][1], + ) + + in_filters = input_.shape[1] + out_filters = exp_out_shape[1] + + block = ConvBlock2D(in_filters, out_filters, *kwargs) + + try: + output = block(input_) + input_, output = input_.detach().numpy(), output.detach().numpy() + exp_out = output + return input_, output, exp_out + except Exception as e: + self.excep = e + input_.detach().numpy() + return input_, np.asarray([0]), np.asarray([1]) + + def _get_input_shapes(self, kwargs, pass_): + """ + Gets outputs for all inputs of type 'pass_' + + If inputs expected to fail, exp_out_shape will be False + + :param list kwargs: list of arguments for ConvBlock2D object + :param boolean pass_: whether inputs are expected to pass tests + + :return list inputs: list of inputs to convblock + :return list outputs: list of outputs from convblock, respective + :return list exp_out_shapes: list of expected output shapes from + convblock, respective + """ + inputs, outputs, exp_out_shapes = [], [], [] + test_inputs = self.pass_inputs if pass_ else self.fail_inputs + for test in test_inputs: + input_, exp_out_shape = test_inputs[test][0], test_inputs[test][1] + + in_filters = input_.shape[1] + out_filters = exp_out_shape[1] + + block = ConvBlock2D(in_filters, out_filters, *kwargs) + + try: + output = block(input_) + inputs.append(input_) + outputs.append(output) + exp_out_shapes.append(exp_out_shape) + except Exception as e: + self.excep = e + inputs.append(input_) + outputs.append(False) + exp_out_shapes.append(exp_out_shape if pass_ else False) + + return inputs, outputs, exp_out_shapes + + def _get_residual_params(self, kwargs, resid_index): + """ + Gets parameters of residual and nonresidual blocks + + :param list kwargs: list of arguments for ConvBlock2D object + :param int resid_index: index of residual parameter in kwargs + + :return nn.module.parameter params: trainable parameters of non-residual block + :return nn.module.parameter resid_params: trainable parameters of residual block + """ + input_, exp_out_shape = ( + self.pass_inputs["standard"][0], + self.pass_inputs["standard"][1], + ) + + in_filters = input_.shape[1] + out_filters = exp_out_shape[1] + + resid_kwargs, kwargs = list(kwargs), list(kwargs) + kwargs[resid_index] = False + resid_kwargs[resid_index] = True + + try: + block = ConvBlock2D(in_filters, out_filters, *kwargs) + resid_block = ConvBlock2D(in_filters, out_filters, *resid_kwargs) + + return block.parameters(), resid_block.parameters() + except Exception as e: + self.excep = e + return None, None + + def _all_test_configurations(self, test, verbose=True): + """ + Run specified test on all possible ConvBlock2D input configurations. + + Current tests: + - input->output for cartesian product of parameters + - shape matching (upsampling, downsampling) + - residual (same number of trainable params) + - kernel shapes (nonsquare doesnt break functionality) + + :param str test: which test to run. Must be within {'passing', 'failing', 'residual'} + :param bool verbose: Verbosity of str output + """ + self.SetUp() + + configs_list = [self.configs[key] for key in self.configs] + configs_list = list(itertools.product(*configs_list)) + failed_tests = collections.defaultdict(lambda: []) + + print("\n Testing", len(configs_list), "configurations:") if verbose else None + for i, args in enumerate(configs_list): + if test == "pasing": + # test passing shapes + _, outputs, exp_out_shapes = self._get_input_shapes(args, True) + out_shapes = [ + ar.detach().numpy().shape if isinstance(ar, torch.Tensor) else ar + for ar in outputs + ] + try: + out_shapes = np.array(out_shapes, dtype=object) + exp_out_shapes = np.array(exp_out_shapes, dtype=object) + fail_message = ( + f"'Passing' input tests failed on config {i+1} \n args: {args}" + ) + np.testing.assert_array_equal( + out_shapes, exp_out_shapes, fail_message + ) + except: + failed_tests[i].append(args) + failed_tests[i].append(self.excep) + if test == "failing": + # test failing shapes + _, outputs, exp_out_shapes = self._get_input_shapes(args, False) + out_shapes = [ + ar.detach().numpy().shape if isinstance(ar, torch.Tensor) else ar + for ar in outputs + ] + try: + out_shapes = np.array(out_shapes, dtype=object) + exp_out_shapes = np.array(exp_out_shapes, dtype=object) + fail_message = ( + f"\t'Failing' tests failed on config {i+1} \n args: {args}" + ) + np.testing.assert_array_equal( + out_shapes, exp_out_shapes, fail_message + ) + except: + failed_tests[i].append(args) + failed_tests[i].append(self.excep) + if test == "residual": + # test residual + resid_index = 2 + if args[resid_index] == False: + params, resid_params = self._get_residual_params(args, resid_index) + try: + fail_message = f"\t Residual params tests failed on config {i+1} \n args: {args}" + np.testing.assert_equal( + len(list(params)), len(list(resid_params)), fail_message + ) + except: + failed_tests[i].append(args) + failed_tests[i].append(self.excep) + + io_utils.show_progress_bar(configs_list, i, process="testing", interval=10) + + if verbose: + print( + f"Testing complete! {len(configs_list)-len(failed_tests)}/{len(configs_list)} passed." + ) + if len(failed_tests) > 0: + print(f"Failed messages:") + for key in failed_tests: + print(f"Config {key}: {failed_tests[key]}") + + # -------------- Tests -----------------# + + def test_residual(self): + """ + Test residual functionality 2D ConvBlock + + Test that residual blocks do not contain additional parameters + Runs test with every possible block configuration. + """ + self._all_test_configurations(test="residual") + + def test_passing(self): + """ + Test passing input functionality 2D ConvBlock + + Test input-output functionality and expected output shape of all passing input shapes. + Runs test with every possible block configuration. + """ + self._all_test_configurations(test="passing") + + def test_failing(self): + """ + Test failing input handling 2D ConvBlock + + Checks to see if all failing input types are caught by conv block. + Runs test with every possible block configuration. + """ + self._all_test_configurations(test="failing") diff --git a/tests/torch_unet/networks/layers/ConvBlock3D_tests.py b/tests/torch_unet/networks/layers/ConvBlock3D_tests.py new file mode 100644 index 00000000..c4832d9a --- /dev/null +++ b/tests/torch_unet/networks/layers/ConvBlock3D_tests.py @@ -0,0 +1,251 @@ +import collections +import itertools +import unittest + +import numpy as np +import torch + +import viscy.utils.cli_utils as io_utils +from viscy.unet.networks.layers.ConvBlock3D import ConvBlock3D + + +class TestConvBlock3D(unittest.TestCase): + """ + Testing class for all configurations of the 3d conv block + Functionality of core PyTorch and nummpy operations assumed to be + complete and sound. + """ + + def SetUp(self): + """ + Set up inputs and block configurations + """ + # possible inputs and output shapes + self.pass_inputs = { + "standard": [torch.ones((1, 1, 5, 256, 256)), (1, 4, 5, 256, 256)], + "down": [torch.ones((1, 8, 3, 16, 16)), (1, 4, 3, 16, 16)], + "batch": [torch.ones((8, 1, 3, 16, 16)), (8, 4, 3, 16, 16)], + "deep": [torch.ones((1, 1, 30, 16, 16)), (1, 4, 30, 16, 16)], + "small": [torch.ones((1, 1, 4, 4, 4)), (1, 4, 4, 4, 4)], + } + self.fail_inputs = { + "nonsquare": [torch.ones((1, 1, 4, 16, 8)), (1, 1, 4, 16, 8)], + "wrong_dims": [torch.ones((1, 1, 1)), (1, 1, 1)], + } + # possible configurations + self.configs = { + "dropout": (False, 0.25), + "norm": ("batch", "instance"), + "residual": (True, False), + "activation": ("relu", "leakyrelu", "selu"), + "transpose": [False], # True yields padding error in pytorch 1.10 + "kernel_size": (1, (3, 3, 3), (3, 3, 5)), + "num_layers": (1, 5), + "filter_steps": ("linear", "first", "last"), + } + + def _get_outputs(self, kwargs): + """ + Template testing class + + :param list kwargs: list of arguments for ConvBlock3D object + + :return numpy.ndarray inputs: inputs to convblock + :return numpy.ndarray outputs: outputs from convblock, respective + :return tuple exp_out: expected output + """ + input_, exp_out_shape = ( + self.pass_inputs["standard"][0], + self.pass_inputs["standard"][1], + ) + + in_filters = input_.shape[1] + out_filters = exp_out_shape[1] + + block = ConvBlock3D(in_filters, out_filters, *kwargs) + + try: + output = block(input_) + input_, output = input_.detach().numpy(), output.detach().numpy() + exp_out = output + return input_, output, exp_out + except: + input_.detach().numpy() + return input_, np.ones((1, 1)), np.zeros((1, 1)) + + def _get_output_shapes(self, kwargs, pass_): + """ + Gets outputs for all inputs of type 'pass_' + + If inputs expected to fail, exp_out_shape will be False + + :param list kwargs: list of arguments for ConvBlock3D object + :param boolean pass_: whether inputs are expected to pass tests + + :return list inputs: list of inputs to convblock + :return list outputs: list of outputs from convblock, respective + :return list exp_out_shapes: list of expected output shapes from + convblock, respective + """ + inputs, outputs, exp_out_shapes = [], [], [] + test_inputs = self.pass_inputs if pass_ else self.fail_inputs + for test in test_inputs: + input_, exp_out_shape = test_inputs[test][0], test_inputs[test][1] + + in_filters = input_.shape[1] + out_filters = exp_out_shape[1] + + block = ConvBlock3D(in_filters, out_filters, *kwargs) + + try: + output = block(input_) + inputs.append(input_) + outputs.append(output) + exp_out_shapes.append(exp_out_shape) + except: + inputs.append(input_) + outputs.append(False) + exp_out_shapes.append(exp_out_shape if pass_ else False) + + return inputs, outputs, exp_out_shapes + + def _get_residual_params(self, kwargs, resid_index): + """ + Gets parameters of residual and nonresidual blocks + + :param list kwargs: list of arguments for ConvBlock3D object + :param int resid_index: index of residual parameter in kwargs + + :return nn.module.parameter params: trainable parameters of non-residual block + :return nn.module.parameter resid_params: trainable parameters of residual block + """ + input_, exp_out_shape = ( + self.pass_inputs["standard"][0], + self.pass_inputs["standard"][1], + ) + + in_filters = input_.shape[1] + out_filters = exp_out_shape[1] + + resid_kwargs, kwargs = list(kwargs), list(kwargs) + kwargs[resid_index] = False + resid_kwargs[resid_index] = True + + try: + block = ConvBlock3D(in_filters, out_filters, *kwargs) + resid_block = ConvBlock3D(in_filters, out_filters, *resid_kwargs) + + return block.parameters(), resid_block.parameters() + except: + return None, None + + def _all_test_configurations(self, test, verbose=True): + """ + Run specified test on all possible ConvBlock3D input configurations. + Send failure information to stdout. + + Current tests: + - input->output for cartesian product of parameters + - shape matching (upsampling, downsampling) + - residual (same number of trainable params) + - kernel shapes (nonsquare doesnt break functionality) + + :param str test: which test to run. Must be within {'passing', 'failing', 'residual'} + :param bool verbose: Verbosity of str output + """ + self.SetUp() + + configs_list = [self.configs[key] for key in self.configs] + configs_list = list(itertools.product(*configs_list)) + failed_tests = collections.defaultdict(lambda: []) + + print("Testing", len(configs_list), "configurations:") if verbose else None + + for i, args in enumerate(configs_list): + if test == "passing": + # test passing shapes + _, outputs, exp_out_shapes = self._get_output_shapes(args, True) + out_shapes = [ + ar.detach().numpy().shape if isinstance(ar, torch.Tensor) else ar + for ar in outputs + ] + try: + out_shapes = np.array(out_shapes, dtype=object) + exp_out_shapes = np.array(exp_out_shapes, dtype=object) + fail_message = ( + f"'Passing' input tests failed on config {i+1} \n args: {args}" + ) + np.testing.assert_array_equal( + out_shapes, exp_out_shapes, fail_message + ) + except: + failed_tests[i].append(args) + elif test == "failing": + # test failing shapes + _, outputs, exp_out_shapes = self._get_output_shapes(args, False) + out_shapes = [ + ar.detach().numpy().shape if isinstance(ar, torch.Tensor) else ar + for ar in outputs + ] + try: + out_shapes = np.array(out_shapes, dtype=object) + exp_out_shapes = np.array(exp_out_shapes, dtype=object) + fail_message = ( + f"\t'Failing' tests failed on config {i+1} \n args: {args}" + ) + np.testing.assert_array_equal( + out_shapes, exp_out_shapes, fail_message + ) + except: + failed_tests[i].append(args) + elif test == "residual": + # test residual + resid_index = 2 + if args[resid_index] == False: + params, resid_params = self._get_residual_params(args, resid_index) + try: + fail_message = f"\t Residual params tests failed on config {i+1} \n args: {args}" + np.testing.assert_equal( + len(list(params)), len(list(resid_params)), fail_message + ) + except: + failed_tests[i].append(args) + + io_utils.show_progress_bar(configs_list, i, process="testing", interval=10) + if verbose: + print( + f"Testing complete! {len(configs_list)-len(failed_tests)}/{len(configs_list)} passed." + ) + if len(failed_tests) > 0: + print(f"Failed messages:") + for key in failed_tests: + print(f"Config {key}: {failed_tests[key]}") + + # -------------- Tests -----------------# + + def test_residual(self): + """ + Test residual functionality 3D ConvBlock + + Test that residual blocks do not contain additional parameters + Runs test with every possible block configuration. + """ + self._all_test_configurations(test="residual") + + def test_passing(self): + """ + Test passing input functionality 3D ConvBlock + + Test input-output functionality and expected output shape of all passing input shapes. + Runs test with every possible block configuration. + """ + self._all_test_configurations(test="passing") + + def test_failing(self): + """ + Test failing input handling 3D ConvBlock + + Checks to see if all failing input types are caught by conv block. + Runs test with every possible block configuration. + """ + self._all_test_configurations(test="failing") diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/utils/image_utils_tests.py b/tests/utils/image_utils_tests.py new file mode 100644 index 00000000..f90ab0e3 --- /dev/null +++ b/tests/utils/image_utils_tests.py @@ -0,0 +1,48 @@ +import numpy as np + +from viscy.utils.image_utils import grid_sample_pixel_values, preprocess_image + + +def test_grid_sample_pixel_values(): + im = np.zeros((15, 20)) + row_ids, col_ids, sample_values = grid_sample_pixel_values( + im, + grid_spacing=5, + ) + assert row_ids.tolist() == [5, 5, 5, 10, 10, 10] + assert col_ids.tolist() == [5, 10, 15, 5, 10, 15] + assert sample_values.tolist() == [0, 0, 0, 0, 0, 0] + + +def test_preprocess_image(self): + im = np.zeros((5, 10, 15, 1)) + im[:, :5, :, :] = 10 + im_proc = preprocess_image( + im, + hist_clip_limits=(0, 100), + ) + self.assertEqual(np.mean(im), np.mean(im_proc)) + self.assertTupleEqual(im_proc.shape, (5, 10, 15)) + + +def test_preprocess_image_norm(self): + im = np.zeros((5, 10, 15)) + im[:, :5, :] = 10 + im_proc = preprocess_image( + im, + normalize_im="dataset", + ) + self.assertEqual(np.mean(im_proc), 0.0) + self.assertTupleEqual(im.shape, im_proc.shape) + + +def test_preprocess_image_mask(self): + im = np.zeros((5, 10, 15)) + im[:, :5, :] = 10 + im_proc = preprocess_image( + im, + is_mask=True, + ) + self.assertEqual(np.mean(im_proc), 0.5) + self.assertTupleEqual(im.shape, im_proc.shape) + self.assertTrue(im_proc.dtype == bool) diff --git a/tests/utils/masks_utils_tests.py b/tests/utils/masks_utils_tests.py new file mode 100644 index 00000000..71d437db --- /dev/null +++ b/tests/utils/masks_utils_tests.py @@ -0,0 +1,56 @@ +import nose.tools +import numpy as np +from skimage import draw +from skimage.filters import gaussian + +from viscy.utils.masks import ( + create_unimodal_mask, + get_unet_border_weight_map, + get_unimodal_threshold, +) + +uni_thr_tst_image = np.zeros((31, 31)) +uni_thr_tst_image[5:10, 8:16] = 127 +uni_thr_tst_image[11:21, 2:12] = 97 +uni_thr_tst_image[8:12, 3:7] = 31 +uni_thr_tst_image[17:29, 17:29] = 61 +uni_thr_tst_image[3:14, 17:29] = 47 + + +def test_get_unimodal_threshold(): + input_image = gaussian(uni_thr_tst_image, 1) + best_thr = get_unimodal_threshold(input_image) + nose.tools.assert_equal(np.floor(best_thr), 3.0) + + +def test_unimodal_thresholding(): + input_image = gaussian(uni_thr_tst_image, 1) + mask = create_unimodal_mask(input_image, str_elem_size=0) + nose.tools.assert_equal(input_image.shape, mask.shape) + nose.tools.assert_true(mask.dtype, bool) + # Check that mask is somewhat close to simple thresholding + thresh_im = input_image > 3.04 + nose.tools.assert_true( + np.abs(np.mean(mask) - np.mean(thresh_im)) < 0.1, + ) + + +def test_get_unet_border_weight_map(): + # Creating a test image with 3 circles + # 2 close to each other and one far away + radius = 10 + params = [(20, 16, radius), (44, 16, radius), (47, 47, radius)] + mask = np.zeros((64, 64), dtype=np.uint8) + for i, (cx, cy, radius) in enumerate(params): + rr, cc = draw.disk((cx, cy), radius) + mask[rr, cc] = i + 1 + + weight_map = get_unet_border_weight_map(mask) + + max_weight_map = np.max(weight_map) + # weight map between 20, 16 and 44, 16 should be maximum + # as there is more weight when two objects boundaries overlap + y_coord = params[0][1] + for x_coord in range(params[0][0] + radius, params[1][0] - radius): + distance_near_intersection = weight_map[x_coord, y_coord] + nose.tools.assert_equal(max_weight_map, distance_near_intersection) diff --git a/tests/utils/mp_utils_tests.py b/tests/utils/mp_utils_tests.py new file mode 100644 index 00000000..89b45255 --- /dev/null +++ b/tests/utils/mp_utils_tests.py @@ -0,0 +1,152 @@ +import os +import unittest +import warnings + +import numpy as np +import numpy.testing +import skimage.io as sk_im_io +from testfixtures import TempDirectory + +import viscy.utils.aux_utils as aux_utils +import viscy.utils.image_utils as image_utils +import viscy.utils.mp_utils as mp_utils +from viscy.utils.masks import create_otsu_mask + + +class TestMpUtilsBaseClass(unittest.TestCase): + def get_sphere(self, shape=(32, 32, 8)): + # create an image with bimodal hist + x = np.linspace(-4, 4, shape[0]) + y = x.copy() + z = np.linspace(-3, 3, shape[2]) + xx, yy, zz = np.meshgrid(x, y, z) + sph = xx**2 + yy**2 + zz**2 + fg = (sph <= shape[2]) * (shape[2] - sph) + fg[fg > 1e-8] = (fg[fg > 1e-8] / np.max(fg)) * 127 + 128 + fg = np.around(fg).astype("uint8") + bg = np.around((sph > shape[2]) * sph).astype("uint8") + sph = fg + bg + return sph + + def get_rect(self, shape=(32, 32, 8)): + rect = np.zeros(shape) + rect[3:30, 14:18, 3:6] = 120 + rect[14:18, 3:30, 3:6] = 120 + return rect + + def get_name(self, ch_idx, sl_idx, time_idx, pos_idx): + im_name = ( + "im_c" + + str(ch_idx).zfill(self.int2str_len) + + "_z" + + str(sl_idx).zfill(self.int2str_len) + + "_t" + + str(time_idx).zfill(self.int2str_len) + + "_p" + + str(pos_idx).zfill(self.int2str_len) + + ".png" + ) + return im_name + + def setUp(self): + """Set up a directory for mask generation""" + + self.tempdir = TempDirectory() + self.temp_path = self.tempdir.path + self.meta_fname = "frames_meta.csv" + self.frames_meta = aux_utils.make_dataframe() + self.channel_ids = [1, 2] + self.time_ids = 0 + self.pos_ids = 1 + self.int2str_len = 3 + + def write_data_in_meta_csv(self, array, frames_meta, ch_idx): + for z in range(array.shape[2]): + im_name = self.get_name(ch_idx, z, self.time_ids, self.pos_ids) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + sk_im_io.imsave( + os.path.join(self.temp_path, im_name), + array[:, :, z].astype("uint8"), + ) + frames_meta = frames_meta.append( + aux_utils.parse_idx_from_name(im_name=im_name, dir_name=self.temp_path), + ignore_index=True, + ) + return frames_meta + + def tearDown(self): + """Tear down temporary folder and file structure""" + TempDirectory.cleanup_all() + self.assertFalse(os.path.isdir(self.temp_path)) + + +class TestMpUtilsOtsu(TestMpUtilsBaseClass): + def setUp(self): + super().setUp() + + def write_mask_data(self): + self.sph_object = self.get_sphere() + self.rect_object = self.get_rect() + + frames_meta = self.write_data_in_meta_csv(self.sph_object, self.frames_meta, 1) + frames_meta = self.write_data_in_meta_csv(self.rect_object, frames_meta, 2) + self.frames_meta = frames_meta + # Write metadata + self.frames_meta.to_csv(os.path.join(self.temp_path, self.meta_fname), sep=",") + self.output_dir = os.path.join(self.temp_path, "mask_dir") + os.makedirs(self.output_dir, exist_ok=True) + + def test_create_save_mask_otsu(self): + """test create_save_mask otsu""" + self.write_mask_data() + for sl_idx in range(8): + channels_meta_sub = aux_utils.get_sub_meta( + frames_metadata=self.frames_meta, + time_ids=self.time_ids, + channel_ids=self.channel_ids, + slice_ids=sl_idx, + pos_ids=self.pos_ids, + ) + cur_meta = mp_utils.create_save_mask( + channels_meta_sub=channels_meta_sub, + str_elem_radius=1, + mask_dir=self.output_dir, + mask_channel_idx=3, + int2str_len=3, + mask_type="otsu", + mask_ext=".png", + ) + fname = aux_utils.get_im_name( + time_idx=self.time_ids, + channel_idx=3, + slice_idx=sl_idx, + pos_idx=self.pos_ids, + ) + self.assertEqual(cur_meta["channel_idx"], 3) + self.assertEqual(cur_meta["slice_idx"], sl_idx) + self.assertEqual(cur_meta["time_idx"], self.time_ids) + self.assertEqual(cur_meta["pos_idx"], self.pos_ids) + self.assertEqual(cur_meta["file_name"], fname) + # Check that mask file has been written + op_fname = os.path.join(self.output_dir, fname) + self.assertTrue(os.path.exists(op_fname)) + # Read mask iamge + mask_image = image_utils.read_image(op_fname) + if mask_image.dtype != bool: + mask_image = mask_image > 0 + input_image = ( + self.sph_object[:, :, sl_idx], + self.rect_object[:, :, sl_idx], + ) + mask_stack = np.stack( + [ + create_otsu_mask(input_image[0], str_elem_size=1), + create_otsu_mask(input_image[1], str_elem_size=1), + ] + ) + mask_exp = np.any(mask_stack, axis=0) + numpy.testing.assert_array_equal( + mask_image, + mask_exp, + ) diff --git a/tests/utils/test_aux_utils.py b/tests/utils/test_aux_utils.py new file mode 100644 index 00000000..09cfea5e --- /dev/null +++ b/tests/utils/test_aux_utils.py @@ -0,0 +1,13 @@ +import json +from pathlib import Path + +from viscy.utils.aux_utils import read_config + + +def test_read_config(tmp_path: Path): + config = tmp_path / "config.yml" + # The function doesn't care about file format, names just have to start with im_ + test_config = {"param": 10} + config.write_text(json.dumps(test_config)) + config = read_config(str(config)) + assert config == test_config diff --git a/viscy/__init__.py b/viscy/__init__.py new file mode 100644 index 00000000..31573ed3 --- /dev/null +++ b/viscy/__init__.py @@ -0,0 +1 @@ +"""Learning vision for cells""" diff --git a/viscy/cli/__init__.py b/viscy/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/viscy/cli/cli.py b/viscy/cli/cli.py new file mode 100644 index 00000000..acb72774 --- /dev/null +++ b/viscy/cli/cli.py @@ -0,0 +1,47 @@ +from datetime import datetime + +import torch +from jsonargparse import lazy_instance +from lightning.pytorch.cli import LightningCLI +from lightning.pytorch.loggers import TensorBoardLogger + +from viscy.light.data import HCSDataModule +from viscy.light.engine import VSTrainer, VSUNet + + +class VSLightningCLI(LightningCLI): + """Extending lightning CLI arguments and defualts.""" + + @staticmethod + def subcommands() -> dict[str, set[str]]: + subcommands = LightningCLI.subcommands() + subcommands["export"] = {"model", "dataloaders", "datamodule"} + return subcommands + + def add_arguments_to_parser(self, parser): + parser.link_arguments("data.batch_size", "model.batch_size") + parser.link_arguments("data.yx_patch_size", "model.example_input_yx_shape") + parser.link_arguments("model.model_config.architecture", "data.architecture") + parser.set_defaults( + { + "trainer.logger": lazy_instance( + TensorBoardLogger, + save_dir="", + version=datetime.now().strftime(r"%Y%m%d-%H%M%S"), + log_graph=True, + ) + } + ) + + +def main(): + torch.set_float32_matmul_precision("high") + _ = VSLightningCLI( + model_class=VSUNet, + datamodule_class=HCSDataModule, + trainer_class=VSTrainer, + ) + + +if __name__ == "__main__": + main() diff --git a/viscy/cli/curator_script.py b/viscy/cli/curator_script.py new file mode 100644 index 00000000..1c35da2d --- /dev/null +++ b/viscy/cli/curator_script.py @@ -0,0 +1,187 @@ +# %% script to generate your ground truth directory for viscy prediction evaluation +# After inference, the predictions generated are stored as zarr store. +# Evaluation metrics can be computed by comparison of prediction +# to human proof read ground truth. + +import argparse +import os + +import imageio as iio +import iohub.ngff as ngff +import numpy as np +from PIL import Image + +import viscy.evaluation.evaluation_metrics as metrics +import viscy.utils.aux_utils as aux_utils + +# from waveorder.focus import focus_from_transverse_band + +# %% read the below details from the config file + + +def parse_args(): + """ + Parse command line arguments + In python namespaces are implemented as dictionaries + + :return: namespace containing the arguments passed. + """ + parser = argparse.ArgumentParser() + + parser.add_argument( + "--config", + type=str, + help="path to yaml configuration file", + ) + args = parser.parse_args() + return args + + +def main(config): + """ + pick the focus slice from n_pos number of positions, cellpose segment, + and save as TIFFs. + also save segmentation input and label-free image as tifs for ground truth curation + segment fluorescence predictions and store mask as new channel + """ + + torch_config = aux_utils.read_config(config) + + zarr_dir = torch_config["data"]["data_path"] + pred_dir = torch_config["evaluation_metrics"]["pred_dir"] + ground_truth_chans = torch_config["data"]["target_channel"] + labelFree_chan = torch_config["data"]["source_channel"] + PosList = torch_config["evaluation_metrics"]["PosList"] + z_list = torch_config["evaluation_metrics"]["z_list"] + cp_model = torch_config["evaluation_metrics"]["cp_model"] + metric_channel = torch_config["evaluation_metrics"]["metric_channel"] + + # if torch_config["evaluation_metrics"]["NA_det"] is None: + # NA_det = 1.3 + # lambda_illu = 0.4 + # pxl_sz = 0.103 + # else: + # NA_det = torch_config["evaluation_metrics"]["NA_det"] + # lambda_illu = torch_config["evaluation_metrics"]["lambda_illu"] + # pxl_sz = torch_config["evaluation_metrics"]["pxl_sz"] + + ground_truth_subdir = "ground_truth" + path_split_head_tail = os.path.split(pred_dir) + target_zarr_dir, _zarr_name = path_split_head_tail[0] + + if not os.path.exists(os.path.join(target_zarr_dir, ground_truth_subdir)): + os.mkdir( + os.path.join(target_zarr_dir, ground_truth_subdir) + ) # create dir to store single page tifs + plate = ngff.open_ome_zarr(store_path=zarr_dir, mode="r+") + chan_names = plate.channel_names + + for position, pos_data in plate.positions(): + im = pos_data.data + # im = plate.data + out_shape = im.shape + # zarr_pos_len = reader.get_num_positions() + try: + assert len(PosList) > out_shape[0] + except AssertionError: + print( + "number of positions listed in config exceeds " + "the number of positions in the dataset" + ) + pos = int(position.split("/")[-1]) + for gt_chan in ground_truth_chans: + if pos in PosList: + idx = PosList.index(pos) + target_data = im[0, chan_names.index(gt_chan), ...] + Z, Y, X = target_data.shape + focus_idx_target = z_list[idx] + # focus_idx_target = focus_from_transverse_band( + # target_data, NA_det, lambda_illu, pxl_sz + # ) + target_focus_slice = target_data[ + focus_idx_target, :, : + ] # FL focus slice image + + im_target = Image.fromarray( + target_focus_slice + ) # save focus slice as single page tif + save_name = ( + "_p" + str(format(pos, "03d")) + "_z" + str(focus_idx_target) + ) + im_target.save( + os.path.join( + target_zarr_dir, + ground_truth_subdir, + gt_chan + save_name + ".tif", + ) + ) + + source_focus_slice = im[ + 0, chan_names.index(labelFree_chan[0]), focus_idx_target, :, : + ] # lable-free focus slice image + im_source = Image.fromarray( + source_focus_slice + ) # save focus slice as single page tif + im_source.save( + os.path.join( + target_zarr_dir, + ground_truth_subdir, + labelFree_chan[0] + save_name + ".tif", + ) + ) # save for reference + + cp_mask = metrics.cpmask_array( + target_focus_slice, cp_model + ) # cellpose segmnetation for binary mask + iio.imwrite( + os.path.join( + target_zarr_dir, + ground_truth_subdir, + gt_chan + save_name + "_cp_mask.png", + ), + cp_mask, + ) # save binary mask as numpy or png + + # segment prediction and add mask as channel to pred_dir + pred_plate = ngff.open_ome_zarr(store_path=pred_dir, mode="r+") + # im_pred = pred_plate.data + chan_names = pred_plate.channel_names + + predseg_data = ngff.open_ome_zarr( + os.path.join(target_zarr_dir, metric_channel + "_pred.zarr"), + layout="hcs", + mode="w-", + channel_names=chan_names, + ) + for position, pos_data in pred_plate.positions(): + row, col, fov = position.split("/") + new_pos = predseg_data.create_position(row, col, fov) + + if int(fov) in PosList: + idx = PosList.index(int(fov)) + raw_data = pos_data.data + target_data = raw_data[:, :, z_list[idx]] + _, _, Y, X = target_data.shape + new_pos.create_image("0", target_data[np.newaxis, :]) + + chan_no = len(chan_names) + with ngff.open_ome_zarr( + os.path.join(target_zarr_dir, metric_channel + "_pred.zarr"), mode="r+" + ) as dataset: + for _, position in dataset.positions(): + data = position.data + new_channel_array = np.zeros((1, 1, Y, X)) + + cp_mask = metrics.cpmask_array( + data[0, chan_names.index(metric_channel), 0, :, :], cp_model + ) + new_channel_array[0, 0, :, :] = cp_mask + + new_channel_name = metric_channel + "_cp_mask" + position.append_channel(new_channel_name, resize_arrays=True) + position["0"][:, chan_no] = new_channel_array + + +if __name__ == "__main__": + args = parse_args() + main(args.config) diff --git a/viscy/cli/fit_example.yml b/viscy/cli/fit_example.yml new file mode 100644 index 00000000..a8274bcb --- /dev/null +++ b/viscy/cli/fit_example.yml @@ -0,0 +1,62 @@ +# lightning.pytorch==2.0.1 +seed_everything: true +trainer: + accelerator: auto + strategy: auto # ddp_find_unused_parameters_true for more GPUs + devices: auto + num_nodes: 1 + precision: 32-true + callbacks: null + fast_dev_run: false + max_epochs: null + min_epochs: null + max_steps: -1 + min_steps: null + max_time: null + limit_train_batches: null + limit_val_batches: null + limit_test_batches: null + limit_predict_batches: null + overfit_batches: 0.0 + val_check_interval: null + check_val_every_n_epoch: 1 + num_sanity_val_steps: null + log_every_n_steps: null + enable_checkpointing: null + enable_progress_bar: null + enable_model_summary: null + accumulate_grad_batches: 1 + gradient_clip_val: null + gradient_clip_algorithm: null + deterministic: null + benchmark: null + inference_mode: true + use_distributed_sampler: true + profiler: null + detect_anomaly: false + barebones: false + plugins: null + sync_batchnorm: true + reload_dataloaders_every_n_epochs: 0 + default_root_dir: null +model: + model_config: {} + loss_function: null + lr: 0.001 + schedule: Constant + log_num_samples: 8 +data: + data_path: null + source_channel: null + target_channel: null + z_window_size: null + split_ratio: null + batch_size: 16 + num_workers: 8 + yx_patch_size: + - 256 + - 256 + augment: true + caching: false + normalize_source: false +ckpt_path: null diff --git a/viscy/cli/metrics_script.py b/viscy/cli/metrics_script.py new file mode 100644 index 00000000..a9cb1afd --- /dev/null +++ b/viscy/cli/metrics_script.py @@ -0,0 +1,132 @@ +# %% script to generate your ground truth directory for viscy prediction evaluation +# After inference, the predictions generated are stored as zarr store. +# Evaluation metrics can be computed by comparison of prediction to +# human proof read ground truth. + +import argparse +import os + +import imageio as iio +import iohub.ngff as ngff +import pandas as pd + +import viscy.evaluation.evaluation_metrics as metrics +import viscy.utils.aux_utils as aux_utils + +# %% read the below details from the config file + + +def parse_args(): + """ + Parse command line arguments + In python namespaces are implemented as dictionaries + + :return: namespace containing the arguments passed. + """ + parser = argparse.ArgumentParser() + + parser.add_argument( + "--config", + type=str, + help="path to yaml configuration file", + ) + args = parser.parse_args() + return args + + +def main(config): + """ + pick focus slice mask from pred_zarr from slice number stored on png mask name + input pred mask & corrected ground truth mask to metrics computation + store the metrics values as csv file to corresponding positions in list + Info to be stored: + 1. position no, + 2. eval metrics values + """ + + torch_config = aux_utils.read_config(config) + + pred_dir = torch_config["evaluation_metrics"]["pred_dir"] + metric_channel = torch_config["evaluation_metrics"]["metric_channel"] + PosList = torch_config["evaluation_metrics"]["PosList"] + z_list = torch_config["evaluation_metrics"]["z_list"] + metrics_list = torch_config["evaluation_metrics"]["metrics"] + ground_truth_chans = torch_config["data"]["target_channel"] + ground_truth_subdir = "ground_truth" + + d_pod = [ + "OD_true_positives", + "OD_false_positives", + "OD_false_negatives", + "OD_precision", + "OD_recall", + "OD_f1_score", + ] + + metric_map = { + "ssim": metrics.ssim_metric, + "corr": metrics.corr_metric, + "r2": metrics.r2_metric, + "mse": metrics.mse_metric, + "mae": metrics.mae_metric, + "dice": metrics.dice_metric, + "IoU": metrics.IOU_metric, + "VI": metrics.VOI_metric, + "POD": metrics.POD_metric, + } + + path_split_head_tail = os.path.split(pred_dir) + target_zarr_dir = path_split_head_tail[0] + pred_plate = ngff.open_ome_zarr( + store_path=os.path.join(target_zarr_dir, metric_channel + "_pred.zarr"), + mode="r+", + ) + chan_names = pred_plate.channel_names + metric_chan_mask = metric_channel + "_cp_mask" + ground_truth_dir = os.path.join(target_zarr_dir, ground_truth_subdir) + + col_val = metrics_list[:] + if "POD" in col_val: + col_val.remove("POD") + for i in range(len(d_pod)): + col_val.insert(i + metrics_list.index("POD"), d_pod[i]) + df_metrics = pd.DataFrame(columns=col_val, index=PosList) + + for position, pos_data in pred_plate.positions(): + pos = int(position.split("/")[-1]) + + if pos in PosList: + idx = PosList.index(pos) + raw_data = pos_data.data + pred_mask = raw_data[0, chan_names.index(metric_chan_mask)] + + z_slice_no = z_list[idx] + gt_mask_save_name = ( + ground_truth_chans[0] + + "_p" + + str(format(pos, "03d")) + + "_z" + + str(z_slice_no) + + "_cp_mask.png" + ) + + gt_mask = iio.imread(os.path.join(ground_truth_dir, gt_mask_save_name)) + + pos_metric_list = [] + for metric_name in metrics_list: + metric_fn = metric_map[metric_name] + cur_metric_list = metric_fn( + gt_mask, + pred_mask[0], + ) + pos_metric_list = pos_metric_list + cur_metric_list + + df_metrics.loc[pos] = pos_metric_list + + csv_filename = os.path.join(ground_truth_dir, "GT_metrics.csv") + df_metrics.to_csv(csv_filename) + + +if __name__ == "__main__": + args = parse_args() + main(args.config) diff --git a/viscy/cli/predict_example.yml b/viscy/cli/predict_example.yml new file mode 100644 index 00000000..3a74da51 --- /dev/null +++ b/viscy/cli/predict_example.yml @@ -0,0 +1,69 @@ +# lightning.pytorch==2.0.1 +predict: + seed_everything: true + trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 32-true + callbacks: + - class_path: viscy.light.predict_writer.HCSPredictionWriter + init_args: + output_store: null + write_input: false + write_interval: batch + fast_dev_run: false + max_epochs: null + min_epochs: null + max_steps: -1 + min_steps: null + max_time: null + limit_train_batches: null + limit_val_batches: null + limit_test_batches: null + limit_predict_batches: null + overfit_batches: 0.0 + val_check_interval: null + check_val_every_n_epoch: 1 + num_sanity_val_steps: null + log_every_n_steps: null + enable_checkpointing: null + enable_progress_bar: null + enable_model_summary: null + accumulate_grad_batches: 1 + gradient_clip_val: null + gradient_clip_algorithm: null + deterministic: null + benchmark: null + inference_mode: true + use_distributed_sampler: true + profiler: null + detect_anomaly: false + barebones: false + plugins: null + sync_batchnorm: false + reload_dataloaders_every_n_epochs: 0 + default_root_dir: null + model: + model_config: {} + loss_function: null + lr: 0.001 + schedule: Constant + log_num_samples: 8 + data: + data_path: null + source_channel: null + target_channel: null + z_window_size: null + split_ratio: null + batch_size: 16 + num_workers: 8 + yx_patch_size: + - 256 + - 256 + augment: true + caching: false + normalize_source: false + return_predictions: null + ckpt_path: null diff --git a/viscy/cli/preprocess_script.py b/viscy/cli/preprocess_script.py new file mode 100644 index 00000000..82c69ea7 --- /dev/null +++ b/viscy/cli/preprocess_script.py @@ -0,0 +1,143 @@ +"""Script for preprocessing stack""" +import argparse +import time + +import iohub.ngff as ngff + +import viscy.utils.aux_utils as aux_utils +import viscy.utils.meta_utils as meta_utils +from viscy.preprocessing.generate_masks import MaskProcessor + + +def parse_args(): + """Parse command line arguments + + In python namespaces are implemented as dictionaries + :return: namespace containing the arguments passed. + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + type=str, + help="path to yaml configuration file", + ) + args = parser.parse_args() + return args + + +def pre_process(torch_config): + """ + Preprocess data. Possible options are: + + normalize: Calculate values for on-the-fly normalization on a FOV & + dataset level + create_masks: Generate binary masks from given input channels + + This script will preprocess your dataset, save auxilary data and + associated metadata for on-the-fly processing during training. Masks + will be saved both as an additional channel and as an array tracked in + custom metadata. + + :param dict torch_config: 'master' torch config with subfields for all steps + of data analysis + :raises AssertionError: If 'masks' in preprocess_config contains both channels + and mask_dir (the former is for generating masks from a channel) + """ + time_start = time.time() + plate = ngff.open_ome_zarr(torch_config["zarr_dir"], layout="hcs", mode="r") + preprocess_config = torch_config["preprocessing"] + + # ----------------- Generate normalization values ----------------- + if "normalize" in preprocess_config: + print("Computing Normalization Values: ------------- \n") + # collect params + normalize_config = preprocess_config["normalize"] + + norm_num_workers = 4 + if "num_workers" in normalize_config: + norm_num_workers = normalize_config["num_workers"] + + norm_channel_ids = -1 + if "channel_ids" in normalize_config: + norm_channel_ids = normalize_config["channel_ids"] + + norm_block_size = 32 + if "block_size" in normalize_config: + norm_block_size = normalize_config["block_size"] + + meta_utils.generate_normalization_metadata( + zarr_dir=torch_config["zarr_dir"], + num_workers=norm_num_workers, + channel_ids=norm_channel_ids, + grid_spacing=norm_block_size, + ) + + # ------------------------Generate masks------------------------- + if "masks" in preprocess_config: + print("Generating Masks: ------------- \n") + # collect params + mask_config = preprocess_config["masks"] + + mask_channel_ids = -1 + if "channel_ids" in mask_config: + mask_channel_ids = mask_config["channel_ids"] + + mask_time_ids = -1 + if "time_ids" in mask_config: + mask_time_ids = mask_config["time_ids"] + + mask_pos_ids = -1 + + mask_num_workers = 4 + if "num_workers" in mask_config: + mask_num_workers = mask_config["num_workers"] + + mask_type = "unimodal" + if "thresholding_type" in mask_config: + mask_type = mask_config["thresholding_type"] + + overwrite_ok = True + if "allow_overwrite_old_mask" in mask_config: + overwrite_ok = mask_config["allow_overwrite_old_mask"] + + structuring_radius = 5 + if "structure_element_radius" in mask_config: + structuring_radius = mask_config["structure_element_radius"] + + # validate + if mask_type not in { + "unimodal", + "otsu", + "mem_detection", + "borders_weight_loss_map", + }: + raise ValueError( + f"Thresholding type {mask_type} must be one of: " + f"{['unimodal', 'otsu', 'mem_detection', 'borders_weight_loss_map']}" + ) + + # generate masks + mask_generator = MaskProcessor( + zarr_dir=torch_config["zarr_dir"], + channel_ids=mask_channel_ids, + time_ids=mask_time_ids, + pos_ids=mask_pos_ids, + num_workers=mask_num_workers, + mask_type=mask_type, + overwrite_ok=overwrite_ok, + ) + mask_generator.generate_masks(structure_elem_radius=structuring_radius) + + # ----------------------Generate weight map----------------------- + # TODO: determine if weight map generation should be offered in simultaneity + # with binary mask generation + + plate.close() + return time.time() - time_start + + +if __name__ == "__main__": + args = parse_args() + torch_config = aux_utils.read_config(args.config) + runtime = pre_process(torch_config) + print(f"Preprocessing complete. Runtime: {runtime:.2f} seconds") diff --git a/viscy/cli/readme.md b/viscy/cli/readme.md new file mode 100644 index 00000000..06395af2 --- /dev/null +++ b/viscy/cli/readme.md @@ -0,0 +1,28 @@ +# Command-line interface + +Access CLI help message by: + +```sh +viscy --help +``` + +## Exporting models to ONNX + +Current implementation will export a checkpoint to ONNX IR version 9 +and OP set version 18 with: + +```sh +viscy export -c config.yaml +``` + +Use argument `export_path` to configure where the output is stored. + +### Notes + +* For CPU sharing reasons, running an ONNX model +requires an exclusive node on HPC OR a non-distributed system (e.g. a PC). + +* Models must be located in a lighting training logs directory +with a valid `config.yaml` in order to be initialized. +This can be "hacked" by locating the config in a directory +called `checkpoints` beneath a valid config's directory. diff --git a/viscy/data_organization.md b/viscy/data_organization.md new file mode 100644 index 00000000..f69611d8 --- /dev/null +++ b/viscy/data_organization.md @@ -0,0 +1,173 @@ +# Data Organization for Virtual Staining + +Here we document our conventions for storing data, metadata, configs, and models. + +## Data flow in the pipeline + +TODO: Following diagram captures the planned flow of data and metadata through the new version of the pipeline. + +### Preprocessing + +Parameters are provided via the CLI, and stored in the attributes of +the OME-Zarr datasets using [iohub](https://github.com/czbiohub/iohub). +The metadata is in json format that is practical to edit by hand. + +### Training + +We use the PyTorch Lightning framework for training, +which provides [good defaults](https://pytorch-lightning.readthedocs.io/en/1.6.5/common/lightning_cli.html) +for CLI and training configs, +and organized TensorBoard [logs](https://lightning.ai/docs/pytorch/stable/extensions/logging.html). + +### Inference + +The inference module does not depend on lightning, but just on PyTorch. +Parameters are provided with CLI and stored with the OME-Zarr datasets, +similar to preprocessing. + +> This can be further incorporated into the lightning pipeline + +### Evaluation + +Evaluating the models requires human proof-reading of ground truth. +Currently computing the evaluation metrics does not depend on PyTorch. + +> This can be further incorporated into the lightning pipeline + +### Deployment + +For run-time deployment, we export the model to the ONNX format. + +## Data hierarchy (Lightning Framework) + +Data generated by the pipeline is stored on the file system following this schema: + +```yaml +# project root directory +virtual_staining: + # registerd, deconvolved, preprocessed OME-Zarr stores + datasets: + train: + yyyymmdd_dataname0.zarr + yyyymmdd_dataname1.zarr + ... + test: + yyyymmdd_dataname0.zarr + yyyymmdd_dataname1.zarr + ... + # FOVs with human-proof-read ground truth segmentation masks + ground_truth: + yyyymmdd_dataname0_gt: + fovs.zarr + .tif + _mask.png + .tif + _mask.png + ... + yyyymmdd_dataname1_gt: + ... + ... + # computational experiments + models: + experiment0: + # working training configs (may not specify default fields) + training_config0.yaml + training_config1.yaml + ... + lightning_logs: + # output of one run + yyyymmdd-hhmmss: + # model weights + checkpoints: + epoch0.ckpt + epoch1.ckpt + ... + # autosaved full config + config.yaml + yyyymmdd-hhmmss: + ... + # Inference and/or Evaluation of selected models. + test: + # config for prediction with test dataset. + test_.yml # config used for inference, optionally copies ground truth and input for evaluation. This config will follow the lightning CLI/config format. + + # inference output on test dataset, may include copies of input and ground truth to facilitate visualization of model performance. + test_.zarr # Not all test datasets need to have human curated ground truth. + ... + + # config for evaluation: checkpoint path, test data path that have ground turth included, and choice of metrics. + evaluation_.yaml + ... + + # evaluation metrics + evaluation_metrics_.csv + ... + # (optional) tensorboard logs generated to visualize distribution of metrics or specific samples of input, prediction, ground truth. + evaluation_logs: + # (optional) exported models for deployment + deployment: + _.onnx + README + experiment1: + ... + ... +``` + +## Data hierarchy (Gunpowder Framework) + +This data hierarchy is deprecated, and is documented for archiving purposes. +The hierarchy organizes subdirectories of config files, models/training logs and data first according to the data related to the computational experiment, then by the specific experiment. Each set of config files should have a corresponding sibling-level training log, and parent-level dataset in their respective directories + +```yaml +# project root directory +torch_microDL: + + #training and test data + data: + _: # data-level sibling to dataset dir in config files + # Due to evolving data format, no single standard for each dataset's format + # Generally these directories are populated by one of the following: + # tile directories (from old preprocessing) + # single page tiff directories (from old raw data) + # zarr stores (from new dataloading) + : + : + ... + ... + + # configuration files (preprocessing, training, inference, etc) + config_files: + _: # data_name is often an abbreviated tag to the microscopy experiment sourcing this data + _: # config files often stored under additional subdirectories. No standard format + config0_.yml + config1_.yml + ... + ... + ... + + # training logs and saved models + models: + _: # data-level sibling to dataset dir in config files + _: + model_0: + # Sometimes there is an additional subdirectory here deliniating different runs. + # There are often *many* training models. These should be cleaned or sorted by size. + training_model_: + + data_splits.yml + saved_model_ep__testloss_.pt + saved_model_ep__testloss_.pt + ... + prediction_ep_.png + prediction_ep_.png + ... + inference_results_: + + .tiff + .tiff + ... + ... + ... + ... + +``` diff --git a/viscy/evaluation/__init__.py b/viscy/evaluation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/viscy/evaluation/evaluation.py b/viscy/evaluation/evaluation.py new file mode 100644 index 00000000..becc7cee --- /dev/null +++ b/viscy/evaluation/evaluation.py @@ -0,0 +1,204 @@ +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +import viscy.evaluation.evaluation_metrics as inference_metrics + + +class TorchEvaluator(object): + """ + Handles all procedures involved with model evaluation. + + Params: + :param dict torch_config: master config file + """ + + def __init__(self, torch_config, device=None) -> None: + self.torch_config = torch_config + + self.zarr_dir = self.torch_config["zarr_dir"] + self.network_config = self.torch_config["model"] + self.training_config = self.torch_config["training"] + self.dataset_config = self.torch_config["dataset"] + self.inference_config = self.torch_config["inference"] + self.preprocessing_config = self.torch_config["preprocessing"] + + self.inference_metrics = {} + self.log_writer = SummaryWriter(log_dir=self.save_folder) + + def get_save_location(self): + """ + Sets save location as specified in config files. + """ + # TODO implement + return + # TODO Change the functionality of saving to put inference in the actual + # train directory the model comes from. Not a big fan + + # model_dir = os.path.dirname(self.inference_config["model_dir"]) + # save_to_train_save_dir = self.inference_config["save_preds_to_model_dir"] + + # if save_to_train_save_dir: + # save_dir = model_dir + # elif "custom_save_preds_dir" in self.inference_config: + # custom_save_dir = self.inference_config["custom_save_preds_dir"] + # save_dir = custom_save_dir + # else: + # raise ValueError( + # "Must provide custom_save_preds_dir if save_preds_to" + # "_model_dir is False." + # ) + + # now = aux_utils.get_timestamp() + # self.save_folder = os.path.join(save_dir, f"inference_results_{now}") + # if not os.path.exists(self.save_folder): + # os.makedirs(self.save_folder) + + def _collapse_metrics_dict(self, metrics_dict): + """ + Collapses metrics dict in the form of + {metric_name: {index: metric,...}} + to the form + {metric_name: np.ndarray[metric1, metrics2,...]} + + :param dict metrics_dict: dict of metrics in the first format + + :return dict collapsed_metrics_dict: dict of metrics in the second format + """ + collapsed_metrics_dict = {} + for metric_name in metrics_dict: + val_dict = metrics_dict[metric_name] + values = [val_dict[index] for index in val_dict] + collapsed_metrics_dict[metric_name] = np.array(values) + + return collapsed_metrics_dict + + def _get_metrics( + self, + target, + prediction, + metrics_list, + metrics_orientations, + path="unspecified", + window=None, + ): + """ + Gets metrics for this target_/prediction pair in all the specified orientations + for all the specified metrics. + + :param np.ndarray target: 5d target array (on cpu) + :param np.ndarray prediction: 5d prediction array (on cpu) + :param list metrics_list: list of strings + indicating the name of a desired metric, + for options see inference.evaluation_metrics. MetricsEstimator docstring + :param list metrics_orientations: list of strings + indicating the orientation to compute, + for options see inference.evaluation_metrics. MetricsEstimator docstring + :param tuple window: spatial window of this target/prediction pair + in the larger arrays they come from. + + :return dict prediction_metrics: dict mapping orientation -> pd.dataframe + of metrics for that orientation + """ + metrics_estimator = inference_metrics.MetricsEstimator(metrics_list) + prediction_metrics = {} + + # transpose target and prediction to be in xyz format + # NOTE: This expects target and pred to be in the format bczyx! + target = np.transpose(target, (0, 1, -2, -1, -3)) + prediction = np.transpose(prediction, (0, 1, -2, -1, -3)) + + zstart, zend = window[0][0], window[0][0] + window[1][0] # end = start + length + pred_name = f"slice_{zstart}-{zend}" + + if "xy" in metrics_orientations: + metrics_estimator.estimate_xy_metrics( + target=target, + prediction=prediction, + pred_name=pred_name, + ) + metrics_xy = self._collapse_metrics_dict( + metrics_estimator.get_metrics_xy().to_dict() + ) + prediction_metrics["xy"] = metrics_xy + + if "xyz" in metrics_orientations: + metrics_estimator.estimate_xyz_metrics( + target=target, + prediction=prediction, + pred_name=pred_name, + ) + metrics_xyz = self._collapse_metrics_dict( + metrics_estimator.get_metrics_xyz().to_dict() + ) + prediction_metrics["xyz"] = metrics_xyz + + if "xz" in metrics_orientations: + metrics_estimator.estimate_xz_metrics( + target=target, + prediction=prediction, + pred_name=pred_name, + ) + metrics_xz = self._collapse_metrics_dict( + metrics_estimator.get_metrics_xz().to_dict() + ) + prediction_metrics["xz"] = metrics_xz + + if "yz" in metrics_orientations: + metrics_estimator.estimate_yz_metrics( + target=target, + prediction=prediction, + pred_name=pred_name, + ) + metrics_yz = self._collapse_metrics_dict( + metrics_estimator.get_metrics_yz().to_dict() + ) + prediction_metrics["yz"] = metrics_yz + + # format metrics + tag = path + f"_{window}" + self.inference_metrics[tag] = prediction_metrics + + return prediction_metrics + + def record_metrics(self, sample_information): + """ + Handles metric recording in tensorboard. + + Metrics are saved position by position. + If multiple scalar metric values are stored for a + particular metric in a particular position, + they are plotted along the axis they are calculated on. + + :param list sample_information: list of tuples containing information about + each sample in the form + (position_group, position_path, normalization_meta, window) + """ + for info_tuple in sample_information: + _, position_path, normalization_meta, window = info_tuple + position = position_path.split("/")[-1] + sample_metrics = self.inference_metrics[position_path + f"_{window}"] + + for orientation in sample_metrics: + scalar_dict = sample_metrics[orientation] + pred_name = scalar_dict.pop("pred_name")[0] + + # generate a unique plot & tag for each orientation + main_tag = f"{position}/{orientation}_{pred_name}" + + # Need to plot a line if metrics calculated along an axis + if scalar_dict[list(scalar_dict.keys())[0]].shape[0] == 1: + self.writer.add_scalars( + main_tag=main_tag, + tag_scalar_dict=scalar_dict, + ) + else: + axis_length = scalar_dict[list(scalar_dict.keys())[0]].shape[0] + for i in range(axis_length): + scalar_dict_i = {} + for key in scalar_dict.keys(): + scalar_dict_i[key] = scalar_dict[key][i] + self.writer.add_scalars( + main_tag=main_tag, + tag_scalar_dict=scalar_dict_i, + global_step=i, + ) diff --git a/viscy/evaluation/evaluation_metrics.py b/viscy/evaluation/evaluation_metrics.py new file mode 100644 index 00000000..deed3394 --- /dev/null +++ b/viscy/evaluation/evaluation_metrics.py @@ -0,0 +1,171 @@ +"""Metrics for model evaluation""" +import numpy as np +import torch +from lapsolver import solve_dense +from skimage.measure import label, regionprops +from torchmetrics.detection import MeanAveragePrecision +from torchvision.ops import masks_to_boxes + + +def VOI_metric(target, prediction): + """variation of information metric + Reports overlap between predicted and ground truth mask + : param np.array target: ground truth mask + : param np.array prediction: model infered FL image cellpose mask + : return float VI: VI for image masks + """ + # cellpose segmentation of predicted image: outputs labl mask + pred_bin = prediction > 0 + target_bin = target > 0 + + # convert to binary mask + im_targ_mask = target_bin > 0 + im_pred_mask = pred_bin > 0 + + # compute entropy from pred_mask + marg_pred = np.histogramdd(np.ravel(im_pred_mask), bins=256)[0] / im_pred_mask.size + marg_pred = list(filter(lambda p: p > 0, np.ravel(marg_pred))) + entropy_pred = -np.sum(np.multiply(marg_pred, np.log2(marg_pred))) + + # compute entropy from target_mask + marg_targ = np.histogramdd(np.ravel(im_targ_mask), bins=256)[0] / im_targ_mask.size + marg_targ = list(filter(lambda p: p > 0, np.ravel(marg_targ))) + entropy_targ = -np.sum(np.multiply(marg_targ, np.log2(marg_targ))) + + # intersection entropy + im_intersection = np.logical_and(im_pred_mask, im_targ_mask) + im_inters_informed = im_intersection * im_targ_mask * im_pred_mask + + marg_intr = ( + np.histogramdd(np.ravel(im_inters_informed), bins=256)[0] + / im_inters_informed.size + ) + marg_intr = list(filter(lambda p: p > 0, np.ravel(marg_intr))) + entropy_intr = -np.sum(np.multiply(marg_intr, np.log2(marg_intr))) + + # variation of entropy/information + VI = entropy_pred + entropy_targ - (2 * entropy_intr) + + return [VI] + + +def POD_metric(target_bin, pred_bin): + # pred_bin = cpmask_array(prediction) + + # relabel mask for ordered labelling across images for efficient LAP mapping + props_pred = regionprops(label(pred_bin)) + props_targ = regionprops(label(target_bin)) + + # construct empty cost matrix based on the number of objects being mapped + n_predObj = len(props_pred) + n_targObj = len(props_targ) + dim_cost = max(n_predObj, n_targObj) + + # calculate cost based on proximity of centroid b/w objects + cost_matrix = np.zeros((dim_cost, dim_cost)) + a = 0 + b = 0 + lab_targ = [] # enumerate the labels from labelled ground truth mask + lab_pred = [] # enumerate the labels from labelled predicted image mask + lab_targ_major_axis = [] # store the major axis of target masks + for props_t in props_targ: + y_t, x_t = props_t.centroid + lab_targ.append(props_t.label) + lab_targ_major_axis.append(props_t.axis_major_length) + for props_p in props_pred: + y_p, x_p = props_p.centroid + lab_pred.append(props_p.label) + # using centroid distance as measure for mapping + cost_matrix[a, b] = np.sqrt(((y_t - y_p) ** 2) + ((x_t - x_p) ** 2)) + b = b + 1 + a = a + 1 + b = 0 + + distance_threshold = np.mean(lab_targ_major_axis) / 2 + + # LAPsolver for minimizing cost matrix of objects + rids, cids = solve_dense(cost_matrix) + + # filter out rid and cid pairs that exceed distance threshold + matching_targ = [] + matching_pred = [] + for rid, cid in zip(rids, cids): + if cost_matrix[rid, cid] <= distance_threshold: + matching_targ.append(rid) + matching_pred.append(cid) + + true_positives = len(matching_pred) + false_positives = n_predObj - len(matching_pred) + false_negatives = n_targObj - len(matching_targ) + precision = true_positives / (true_positives + false_positives) + recall = true_positives / (true_positives + false_negatives) + f1_score = 2 * (precision * recall / (precision + recall)) + + return [ + true_positives, + false_positives, + false_negatives, + precision, + recall, + f1_score, + ] + + +def labels_to_masks(labels: torch.ShortTensor) -> torch.BoolTensor: + """Convert integer labels to a stack of boolean masks. + + :param torch.ShortTensor labels: 2D labels where each value is an object + (0 is background) + :return torch.BoolTensor: Boolean masks of shape (objects, H, W) + """ + if labels.ndim != 2: + raise ValueError(f"Labels must be 2D, got shape {labels.shape}.") + masks = torch.zeros( + (labels.max(), *labels.shape), dtype=torch.bool, device=labels.device + ) + # TODO: optimize this? + for segment in range(labels.max()): + # start from label value 1, i.e. skip background label + masks[segment] = labels == (segment + 1) + return masks + + +def labels_to_detection(labels: torch.ShortTensor) -> dict[str, torch.Tensor]: + """Convert integer labels to a torchvision/torchmetrics detection dictionary. + + :param torch.ShortTensor labels: 2D labels where each value is an object + (0 is background) + :return dict[str, torch.Tensor]: detection boxes, scores, labels, and masks + """ + masks = labels_to_masks(labels) + boxes = masks_to_boxes(masks) + return { + "boxes": boxes, + # dummy confidence scores + "scores": torch.ones( + (boxes.shape[0],), dtype=torch.float32, device=boxes.device + ), + # dummy class labels + "labels": torch.zeros( + (boxes.shape[0],), dtype=torch.uint8, device=boxes.device + ), + "masks": masks, + } + + +def mean_average_precision( + pred_labels: torch.ShortTensor, target_labels: torch.ShortTensor, **kwargs +) -> dict[str, torch.Tensor]: + """Compute the mAP metric for instance segmentation. + + :param torch.ShortTensor pred_labels: 2D integer prediction labels + :param torch.ShortTensor target_labels: 2D integer prediction labels + :param dict **kwargs: keyword arguments passed to + :py:class:`torchmetrics.detection.MeanAveragePrecision` + :return dict[str, torch.Tensor]: COCO-style metrics + """ + map_metric = MeanAveragePrecision(box_format="xyxy", iou_type="segm", **kwargs) + map_metric.update( + [labels_to_detection(pred_labels)], [labels_to_detection(target_labels)] + ) + return map_metric.compute() diff --git a/viscy/light/data.py b/viscy/light/data.py new file mode 100644 index 00000000..b95d41c1 --- /dev/null +++ b/viscy/light/data.py @@ -0,0 +1,522 @@ +import logging +import os +import re +import tempfile +from glob import glob +from typing import Callable, Iterable, Literal, Sequence, TypedDict, Union + +import numpy as np +import torch +import zarr +from imageio import imread +from iohub.ngff import ImageArray, Plate, Position, open_ome_zarr +from lightning.pytorch import LightningDataModule +from monai.data import set_track_meta +from monai.transforms import ( + CenterSpatialCropd, + Compose, + InvertibleTransform, + MapTransform, + RandAdjustContrastd, + RandAffined, + RandGaussianSmoothd, + RandWeightedCropd, +) +from torch.utils.data import DataLoader, Dataset + + +def _ensure_channel_list(str_or_seq: Union[str, Sequence[str]]): + if isinstance(str_or_seq, str): + return [str_or_seq] + try: + return list(str_or_seq) + except TypeError: + raise TypeError( + "Channel argument must be a string or sequence of strings. " + f"Got {str_or_seq}." + ) + + +def _search_int_in_str(pattern: str, file_name: str) -> str: + """Search image indices in a file name with regex patterns and strip leading zeros. + E.g. ``'001'`` -> ``1``""" + match = re.search(pattern, file_name) + if match: + return match.group() + else: + raise ValueError(f"Cannot find pattern {pattern} in {file_name}.") + + +class ChannelMap(TypedDict, total=False): + source: Union[str, Sequence[str]] + # optional + target: Union[str, Sequence[str]] + + +class Sample(TypedDict, total=False): + index: tuple[str, int, int] + # optional + source: torch.Tensor + target: torch.Tensor + labels: torch.Tensor + + +class NormalizeSampled(MapTransform, InvertibleTransform): + """Dictionary transform to only normalize target (fluorescence) channel. + + :param Union[str, Iterable[str]] keys: keys to normalize + :param dict[str, dict] norm_meta: Plate normalization metadata + written in preprocessing + """ + + def __init__( + self, keys: Union[str, Iterable[str]], norm_meta: dict[str, dict] + ) -> None: + if set(keys) > set(norm_meta.keys()): + raise KeyError(f"{keys} is not a subset of {norm_meta.keys()}") + super().__init__(keys, allow_missing_keys=False) + self.norm_meta = norm_meta + + def _stat(self, key: str) -> dict: + return self.norm_meta[key]["dataset_statistics"] + + def __call__(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + d = dict(data) + for key in self.keys: + d[key] = (d[key] - self._stat(key)["median"]) / self._stat(key)["iqr"] + return d + + def inverse(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + d = dict(data) + for key in self.keys: + d[key] = (d[key] * self._stat(key)["iqr"]) + self._stat(key)["median"] + + +class SlidingWindowDataset(Dataset): + """Torch dataset where each element is a window of + (C, Z, Y, X) where C=2 (source and target) and Z is ``z_window_size``. + + :param list[Position] positions: FOVs to include in dataset + :param ChannelMap channels: source and target channel names, + e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` + :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D + :param Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] transform: + a callable that transforms data, defaults to None + """ + + def __init__( + self, + positions: list[Position], + channels: ChannelMap, + z_window_size: int, + transform: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, + ) -> None: + super().__init__() + self.positions = positions + self.channels = {k: _ensure_channel_list(v) for k, v in channels.items()} + self.source_ch_idx = [ + positions[0].get_channel_index(c) for c in channels["source"] + ] + self.target_ch_idx = ( + [positions[0].get_channel_index(c) for c in channels["target"]] + if "target" in channels + else None + ) + self.z_window_size = z_window_size + self.transform = transform + self._get_windows() + + def _get_windows(self) -> None: + """Count the sliding windows along T and Z, + and build an index-to-window LUT.""" + w = 0 + self.window_keys = [] + self.window_arrays = [] + for fov in self.positions: + img_arr = fov["0"] + ts = img_arr.frames + zs = img_arr.slices - self.z_window_size + 1 + w += ts * zs + self.window_keys.append(w) + self.window_arrays.append(img_arr) + self._max_window = w + + def _find_window(self, index: int) -> tuple[int, int]: + """Look up window given index.""" + window_idx = sorted(self.window_keys + [index + 1]).index(index + 1) + w = self.window_keys[window_idx] + tz = index - self.window_keys[window_idx - 1] if window_idx > 0 else index + return self.window_arrays[self.window_keys.index(w)], tz + + def _read_img_window( + self, img: ImageArray, ch_idx: list[str], tz: int + ) -> tuple[tuple[torch.Tensor], tuple[str, int, int]]: + """Read image window as tensor. + + :param ImageArray img: NGFF image array + :param list[int] channels: list of channel indices to read, + output channel ordering will reflect the sequence + :param int tz: window index within the FOV, counted Z-first + :return tuple[torch.Tensor], tuple[str, int, int]: + tuple of (C=1, Z, Y, X) image tensors, + tuple of image name, time index, and Z index + """ + zs = img.shape[-3] - self.z_window_size + 1 + t = (tz + zs) // zs - 1 + z = tz - t * zs + data = img.oindex[ + slice(t, t + 1), + [int(i) for i in ch_idx], + slice(z, z + self.z_window_size), + ] + return torch.from_numpy(data).unbind(dim=1), (img.name, t, z) + + def __len__(self) -> int: + return self._max_window + + def _stack_channels( + self, sample_images: dict[str, torch.Tensor], key: str + ) -> torch.Tensor: + return torch.stack([sample_images[ch][0] for ch in self.channels[key]]) + + def __getitem__(self, index: int) -> Sample: + img, tz = self._find_window(index) + ch_names = self.channels["source"].copy() + ch_idx = self.source_ch_idx.copy() + if self.target_ch_idx is not None: + ch_names.extend(self.channels["target"]) + ch_idx.extend(self.target_ch_idx) + images, sample_index = self._read_img_window(img, ch_idx, tz) + sample_images = {k: v for k, v in zip(ch_names, images)} + if self.target_ch_idx is not None: + # FIXME: this uses the first target channel as weight for performance + # since adding a reference to a tensor does not copy + # maybe write a weight map in preprocessing to use more information? + sample_images["weight"] = sample_images[self.channels["target"][0]] + if self.transform: + sample_images = self.transform(sample_images) + if isinstance(sample_images, list): + sample_images = sample_images[0] + if "weight" in sample_images: + del sample_images["weight"] + sample = { + "index": sample_index, + "source": self._stack_channels(sample_images, "source"), + } + if self.target_ch_idx is not None: + sample["target"] = self._stack_channels(sample_images, "target") + return sample + + def __del__(self): + """Close the Zarr store when the dataset instance gets GC'ed.""" + self.positions[0].zgroup.store.close() + + +class MaskTestDataset(SlidingWindowDataset): + """Torch dataset where each element is a window of + (C, Z, Y, X) where C=2 (source and target) and Z is ``z_window_size``. + This a testing stage version of :py:class:`viscy.light.data.SlidingWindowDataset`, + and can only be used with batch size 1 for efficiency (no padding for collation), + since the mask is not available for each stack. + + :param list[Position] positions: FOVs to include in dataset + :param ChannelMap channels: source and target channel names, + e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` + :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D + :param Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] transform: + a callable that transforms data, defaults to None + """ + + def __init__( + self, + positions: list[Position], + channels: ChannelMap, + z_window_size: int, + transform: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, + ground_truth_masks: str = None, + ) -> None: + super().__init__(positions, channels, z_window_size, transform) + self.masks = {} + for img_path in glob(os.path.join(ground_truth_masks, "*cp_masks.png")): + img_name = os.path.basename(img_path) + position_name = _search_int_in_str(r"(?<=_p)\d{3}", img_name) + # TODO: specify time index in the file name + t_idx = 0 + # TODO: record channel name + # channel_name = re.search(r"^.+(?=_p\d{3})", img_name).group() + z_idx = _search_int_in_str(r"(?<=_z)\d+", img_name) + self.masks[(int(position_name), int(t_idx), int(z_idx))] = img_path + logging.info(str(self.masks)) + + def __getitem__(self, index: int) -> Sample: + sample = super().__getitem__(index) + img_name, t_idx, z_idx = sample["index"] + position_name = int(img_name.split("/")[-2]) + key = (position_name, int(t_idx), int(z_idx) + self.z_window_size // 2) + if img_path := self.masks.get(key): + sample["labels"] = torch.from_numpy(imread(img_path).astype(np.int16)) + return sample + + +class HCSDataModule(LightningDataModule): + """Lightning data module for a preprocessed HCS NGFF Store. + + :param str data_path: path to the data store + :param Union[str, Sequence[str]] source_channel: name(s) of the source channel, + e.g. ``'Phase'`` + :param Union[str, Sequence[str]] target_channel: name(s) of the target channel, + e.g. ``['Nuclei', 'Membrane']`` + :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D + :param float split_ratio: split ratio of the training subset in the fit stage, + e.g. 0.8 means a 80/20 split between training/validation + :param int batch_size: batch size, defaults to 16 + :param int num_workers: number of data-loading workers, defaults to 8 + :param Literal["2.5D", "2D", "3D"] architecture: U-Net architecture, + defaults to "2.5D" + :param tuple[int, int] yx_patch_size: patch size in (Y, X), + defaults to (256, 256) + :param bool augment: whether to apply augmentation in training, + defaults to True + :param bool caching: whether to decompress all the images and cache the result, + defaults to False + :param str ground_truth_masks: path to the ground truth segmentation masks, + defaults to None + """ + + def __init__( + self, + data_path: str, + source_channel: Union[str, Sequence[str]], + target_channel: Union[str, Sequence[str]], + z_window_size: int, + split_ratio: float, + batch_size: int = 16, + num_workers: int = 8, + architecture: Literal["2.5D", "2D", "3D"] = "2.5D", + yx_patch_size: tuple[int, int] = (256, 256), + augment: bool = True, + caching: bool = False, + normalize_source: bool = False, + ground_truth_masks: str = None, + ): + super().__init__() + self.data_path = data_path + self.source_channel = _ensure_channel_list(source_channel) + self.target_channel = _ensure_channel_list(target_channel) + self.batch_size = batch_size + self.num_workers = num_workers + self.target_2d = True if architecture == "2.5D" else False + self.z_window_size = z_window_size + self.split_ratio = split_ratio + self.yx_patch_size = yx_patch_size + self.augment = augment + self.caching = caching + self.normalize_source = normalize_source + self.ground_truth_masks = ground_truth_masks + self.tmp_zarr = None + + def prepare_data(self): + if not self.caching: + return + # setup logger + logger = logging.getLogger("viscy_data") + logger.propagate = False + logger.setLevel(logging.DEBUG) + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + logger.addHandler(console_handler) + os.mkdir(self.trainer.logger.log_dir) + file_handler = logging.FileHandler( + os.path.join(self.trainer.logger.log_dir, "data.log") + ) + file_handler.setLevel(logging.DEBUG) + logger.addHandler(file_handler) + # cache in temporary directory + self.tmp_zarr = os.path.join( + tempfile.gettempdir(), os.path.basename(self.data_path) + ) + logger.info(f"Caching dataset at {self.tmp_zarr}.") + tmp_store = zarr.NestedDirectoryStore(self.tmp_zarr) + with open_ome_zarr(self.data_path, mode="r") as lazy_plate: + _, skipped, _ = zarr.copy( + lazy_plate.zgroup, + zarr.open(tmp_store, mode="a"), + name="/", + log=logger.debug, + if_exists="skip_initialized", + compressor=None, + ) + if skipped > 0: + logger.warning( + f"Skipped {skipped} items when caching. Check debug log for details." + ) + + def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + channels = {"source": self.source_channel} + dataset_settings = dict(channels=channels, z_window_size=self.z_window_size) + if stage in ("fit", "validate"): + self._setup_fit(dataset_settings) + elif stage == "test": + self._setup_test(dataset_settings) + elif stage == "predict": + self._setup_predict(dataset_settings) + else: + raise NotImplementedError(f"{stage} stage") + + def _setup_eval(self, dataset_settings: dict) -> tuple[Plate, MapTransform]: + dataset_settings["channels"]["target"] = self.target_channel + data_path = self.tmp_zarr if self.tmp_zarr else self.data_path + plate = open_ome_zarr(data_path, mode="r") + # disable metadata tracking in MONAI for performance + set_track_meta(False) + # define training stage transforms + norm_keys = self.target_channel + if self.normalize_source: + norm_keys += self.source_channel + normalize_transform = NormalizeSampled( + norm_keys, + plate.zattrs["normalization"], + ) + return plate, normalize_transform + + def _setup_fit(self, dataset_settings: dict): + plate, normalize_transform = self._setup_eval(dataset_settings) + fit_transform = self._fit_transform() + train_transform = Compose( + [normalize_transform] + self._train_transform() + fit_transform + ) + val_transform = Compose([normalize_transform] + fit_transform) + # shuffle positions, randomness is handled globally + positions = [pos for _, pos in plate.positions()] + shuffled_indices = torch.randperm(len(positions)) + positions = list(positions[i] for i in shuffled_indices) + num_train_fovs = int(len(positions) * self.split_ratio) + # train/val split + self.train_dataset = SlidingWindowDataset( + positions[:num_train_fovs], + transform=train_transform, + **dataset_settings, + ) + self.val_dataset = SlidingWindowDataset( + positions[num_train_fovs:], transform=val_transform, **dataset_settings + ) + + def _setup_test(self, dataset_settings): + if self.batch_size != 1: + logging.warning(f"Ignoring batch size {self.batch_size} in test stage.") + plate, normalize_transform = self._setup_eval(dataset_settings) + self.test_dataset = MaskTestDataset( + [p for _, p in plate.positions()], + transform=normalize_transform, + ground_truth_masks=self.ground_truth_masks, + **dataset_settings, + ) + + def _setup_predict(self, dataset_settings: dict): + # track metadata for inverting transform + set_track_meta(True) + if self.caching: + logging.warning("Ignoring caching config in 'predict' stage.") + plate = open_ome_zarr(self.data_path, mode="r") + predict_transform = ( + NormalizeSampled( + self.source_channel, + plate.zattrs["normalization"], + ) + if self.normalize_source + else None + ) + self.predict_dataset = SlidingWindowDataset( + [p for _, p in plate.positions()], + transform=predict_transform, + **dataset_settings, + ) + + def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample: + if self.trainer.predicting or isinstance(batch, torch.Tensor): + # skipping example input array + return batch + if self.target_2d: + # slice the center during training or testing + z_index = self.z_window_size // 2 + batch["target"] = batch["target"][:, :, slice(z_index, z_index + 1)] + return batch + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + persistent_workers=True, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + persistent_workers=True, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=1, + num_workers=self.num_workers, + shuffle=False, + ) + + def predict_dataloader(self): + return DataLoader( + self.predict_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def _fit_transform(self): + return [ + CenterSpatialCropd( + keys=self.source_channel + self.target_channel, + roi_size=( + -1, + self.yx_patch_size[0], + self.yx_patch_size[1], + ), + ) + ] + + def _train_transform(self) -> list[Callable]: + transforms = [ + RandWeightedCropd( + keys=self.source_channel + self.target_channel, + w_key="weight", + spatial_size=(-1, self.yx_patch_size[0] * 2, self.yx_patch_size[1] * 2), + num_samples=1, + ) + ] + if self.augment: + transforms.extend( + [ + RandAffined( + keys=self.source_channel + self.target_channel, + prob=0.5, + rotate_range=(np.pi, 0, 0), + shear_range=(0, (0.05), (0.05)), + scale_range=(0, 0.3, 0.3), + ), + RandAdjustContrastd( + keys=self.source_channel, prob=0.3, gamma=(0.75, 1.5) + ), + RandGaussianSmoothd( + keys=self.source_channel, + prob=0.3, + sigma_x=(0.05, 0.25), + sigma_y=(0.05, 0.25), + sigma_z=((0.05, 0.25)), + ), + ] + ) + return transforms diff --git a/viscy/light/engine.py b/viscy/light/engine.py new file mode 100644 index 00000000..658d1831 --- /dev/null +++ b/viscy/light/engine.py @@ -0,0 +1,338 @@ +import logging +import os +from typing import Callable, Literal, Sequence + +import numpy as np +import torch +import torch.nn.functional as F +from cellpose.models import CellposeModel +from imageio import imwrite +from lightning.pytorch import LightningDataModule, LightningModule, Trainer +from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized +from matplotlib.cm import get_cmap +from monai.optimizers import WarmupCosineSchedule +from monai.transforms import DivisiblePad +from skimage.exposure import rescale_intensity +from torch.onnx import OperatorExportTypes +from torch.optim.lr_scheduler import ConstantLR +from torchmetrics.functional import ( + accuracy, + cosine_similarity, + dice, + jaccard_index, + mean_absolute_error, + mean_squared_error, + pearson_corrcoef, + r2_score, + structural_similarity_index_measure, +) + +from viscy.evaluation.evaluation_metrics import mean_average_precision +from viscy.light.data import Sample +from viscy.unet.networks.Unet25D import Unet25d +from viscy.unet.utils.model import ModelDefaults25D, define_model + + +class VSTrainer(Trainer): + def export( + self, + model: LightningModule, + export_path: str, + ckpt_path: str, + format="onnx", + datamodule: LightningDataModule = None, + dataloaders: Sequence = None, + ): + """Export the model for deployment (currently only ONNX is supported). + + :param LightningModule model: module to export + :param str export_path: output file name + :param str ckpt_path: model checkpoint + :param str format: format (currently only ONNX is supported), defaults to "onnx" + :param LightningDataModule datamodule: placeholder for datamodule, + defaults to None + :param Sequence dataloaders: placeholder for dataloaders, defaults to None + """ + if dataloaders or datamodule: + logging.debug("Ignoring datamodule and dataloaders during export.") + if not format.lower() == "onnx": + raise NotImplementedError(f"Export format '{format}'") + model = _maybe_unwrap_optimized(model) + self.strategy._lightning_module = model + model.load_state_dict(torch.load(ckpt_path)["state_dict"]) + model.eval() + model.to_onnx( + export_path, + input_sample=model.example_input_array, + export_params=True, + opset_version=18, + operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK, + input_names=["input"], + output_names=["output"], + dynamic_axes={ + "input": { + 0: "batch_size", + 1: "channels", + 3: "num_rows", + 4: "num_cols", + }, + "output": { + 0: "batch_size", + 1: "channels", + 3: "num_rows", + 4: "num_cols", + }, + }, + ) + logging.info(f"ONNX exported at {export_path}") + + +class VSUNet(LightningModule): + """Regression U-Net module for virtual staining. + + :param dict model_config: model config, + defaults to :py:class:`viscy.unet.utils.model.ModelDefaults25D` + :param int batch_size: batch size, defaults to 16 + :param Callable[[torch.Tensor, torch.Tensor], torch.Tensor] loss_function: + loss function in training/validation, defaults to L2 (mean squared error) + :param float lr: learning rate in training, defaults to 1e-3 + :param Literal['WarmupCosine', 'Constant'] schedule: + learning rate scheduler, defaults to "Constant" + :param int log_num_samples: + number of image samples to log each training/validation epoch, defaults to 8 + :param Sequence[int] example_input_yx_shape: + XY shape of the example input for network graph tracing, defaults to (256, 256) + :param str test_cellpose_model_path: + path to the CellPose model for testing segmentation, defaults to None + :param float test_cellpose_diameter: + diameter parameter of the CellPose model for testing segmentation, + defaults to None + :param bool test_evaluate_cellpose: + evaluate the performance of the CellPose model instead of the trained model + in test stage, defaults to False + """ + + def __init__( + self, + model_config: dict = {}, + batch_size: int = 16, + loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None, + lr: float = 1e-3, + schedule: Literal["WarmupCosine", "Constant"] = "Constant", + log_num_samples: int = 8, + example_input_yx_shape: Sequence[int] = (256, 256), + test_cellpose_model_path: str = None, + test_cellpose_diameter: float = None, + test_evaluate_cellpose: bool = False, + ) -> None: + super().__init__() + self.model = define_model(Unet25d, ModelDefaults25D(), model_config) + # TODO: handle num_outputs in metrics + # self.out_channels = self.model.terminal_block.out_filters + self.batch_size = batch_size + self.loss_function = loss_function if loss_function else F.mse_loss + self.lr = lr + self.schedule = schedule + self.log_num_samples = log_num_samples + self.training_step_outputs = [] + self.validation_step_outputs = [] + # required to log the graph + self.example_input_array = torch.rand( + 1, + 1, + (model_config.get("in_stack_depth") or 5), + *example_input_yx_shape, + ) + self.test_cellpose_model_path = test_cellpose_model_path + self.test_cellpose_diameter = test_cellpose_diameter + self.test_evaluate_cellpose = test_evaluate_cellpose + + def forward(self, x) -> torch.Tensor: + return self.model(x) + + def training_step(self, batch: Sample, batch_idx: int): + source = batch["source"] + target = batch["target"] + pred = self.forward(source) + loss = self.loss_function(pred, target) + self.log( + "loss/train", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + batch_size=self.batch_size, + sync_dist=True, + ) + if batch_idx < self.log_num_samples: + self.training_step_outputs.append( + self._detach_sample((source, target, pred)) + ) + return loss + + def validation_step(self, batch: Sample, batch_idx: int): + source = batch["source"] + target = batch["target"] + pred = self.forward(source) + loss = self.loss_function(pred, target) + self.log("loss/validate", loss, batch_size=self.batch_size, sync_dist=True) + if batch_idx < self.log_num_samples: + self.validation_step_outputs.append( + self._detach_sample((source, target, pred)) + ) + + def test_step(self, batch: Sample, batch_idx: int): + source = batch["source"] + target = batch["target"][:, 0] + if self.test_evaluate_cellpose: + pred = target + else: + pred = self.forward(source)[:, 0] + # FIXME: Only works for batch size 1 and the first channel + self._log_regression_metrics(pred, target) + img_names, ts, zs = batch["index"] + position = float(img_names[0].split("/")[-2]) + self.log_dict( + { + "position": position, + "time": float(ts[0]), + "slice": float(zs[0]), + }, + on_step=True, + on_epoch=False, + ) + if "labels" in batch: + pred_labels = self._cellpose_predict( + pred, f"p{int(position)}_t{ts[0]}_z{zs[0]}" + ) + self._log_segmentation_metrics(pred_labels, batch["labels"][0]) + else: + self._log_segmentation_metrics(None, None) + + def _log_regression_metrics(self, pred: torch.Tensor, target: torch.Tensor): + # paired image translation metrics + self.log_dict( + { + # regression + "test_metrics/MAE": mean_absolute_error(pred, target), + "test_metrics/MSE": mean_squared_error(pred, target), + "test_metrics/cosine": cosine_similarity( + pred, target, reduction="mean" + ), + "test_metrics/pearson": pearson_corrcoef( + pred.flatten() * 1e4, target.flatten() * 1e4 + ), + "test_metrics/r2": r2_score(pred.flatten(), target.flatten()), + # image perception + "test_metrics/SSIM": structural_similarity_index_measure( + pred, target, gaussian_kernel=False, kernel_size=21 + ), + }, + on_step=True, + on_epoch=True, + ) + + def _cellpose_predict(self, pred: torch.Tensor, name: str) -> torch.ShortTensor: + pred_labels_np = self.cellpose_model.eval( + pred.cpu().numpy(), channels=[0, 0], diameter=self.test_cellpose_diameter + )[0].astype(np.int16) + imwrite(os.path.join(self.logger.log_dir, f"{name}.png"), pred_labels_np) + return torch.from_numpy(pred_labels_np).to(self.device) + + def _log_segmentation_metrics( + self, pred_labels: torch.ShortTensor, target_labels: torch.ShortTensor + ): + compute = pred_labels is not None + if compute: + pred_binary = pred_labels > 0 + target_binary = target_labels > 0 + coco_metrics = mean_average_precision(pred_labels, target_labels) + logging.debug(coco_metrics) + self.log_dict( + { + # semantic segmentation + "test_metrics/accuracy": accuracy( + pred_binary, target_binary, task="binary" + ) + if compute + else -1, + "test_metrics/dice": dice(pred_binary, target_binary) + if compute + else -1, + "test_metrics/jaccard": jaccard_index( + pred_binary, target_binary, task="binary" + ) + if compute + else -1, + "test_metrics/mAP": coco_metrics["map"] if compute else -1, + "test_metrics/mAP_50": coco_metrics["map_50"] if compute else -1, + "test_metrics/mAP_75": coco_metrics["map_75"] if compute else -1, + "test_metrics/mAR_100": coco_metrics["mar_100"] if compute else -1, + }, + on_step=True, + on_epoch=False, + ) + + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source = self._predict_pad(batch["source"]) + return self._predict_pad.inverse(self.forward(source)) + + def on_train_epoch_end(self): + self._log_samples("train_samples", self.training_step_outputs) + self.training_step_outputs = [] + + def on_validation_epoch_end(self): + self._log_samples("val_samples", self.validation_step_outputs) + self.validation_step_outputs = [] + + def on_test_start(self): + """Load CellPose model for segmentation.""" + if self.test_cellpose_model_path is not None: + self.cellpose_model = CellposeModel( + model_type=self.test_cellpose_model_path, device=self.device + ) + + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr) + if self.schedule == "WarmupCosine": + scheduler = WarmupCosineSchedule( + optimizer, + warmup_steps=3, + t_total=self.trainer.max_epochs, + warmup_multiplier=1e-3, + ) + elif self.schedule == "Constant": + scheduler = ConstantLR( + optimizer, factor=1, total_iters=self.trainer.max_epochs + ) + return [optimizer], [scheduler] + + @staticmethod + def _detach_sample(imgs: Sequence[torch.Tensor]): + return [np.squeeze(img[0].detach().cpu().numpy().max(axis=(1))) for img in imgs] + + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + images_grid = [] + for sample_images in imgs: + images_row = [] + for i, image in enumerate(sample_images): + cm_name = "gray" if i == 0 else "inferno" + if image.ndim == 2: + image = image[np.newaxis] + for channel in image: + channel = rescale_intensity(channel, out_range=(0, 1)) + render = get_cmap(cm_name)(channel, bytes=True)[..., :3] + images_row.append(render) + images_grid.append(np.concatenate(images_row, axis=1)) + grid = np.concatenate(images_grid, axis=0) + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py new file mode 100644 index 00000000..5e830aa0 --- /dev/null +++ b/viscy/light/predict_writer.py @@ -0,0 +1,132 @@ +import logging +import os +from typing import Literal, Sequence + +import torch +from iohub.ngff import ImageArray, _pad_shape, open_ome_zarr +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.callbacks import BasePredictionWriter +from numpy.typing import DTypeLike + +from viscy.light.data import Sample + + +def _resize_image(image: ImageArray, t_index: int, z_index: int): + """Resize image array if incoming T and Z index is not within bounds.""" + if image.shape[0] <= t_index or image.shape[2] <= z_index: + logging.debug( + f"Resizing image '{image.name}' {image.shape} for T={t_index}, Z={z_index}." + ) + image.resize( + max(t_index + 1, image.shape[0]), + image.channels, + max(z_index + 1, image.shape[1]), + *image.shape[-2:], + ) + + +class HCSPredictionWriter(BasePredictionWriter): + """Callback to store virtual staining predictions as HCS OME-Zarr. + + :param str output_store: Path to the zarr store to store output + :param bool write_input: Write the source and target channels too + (must be writing to a new store), + defaults to False + :param Literal['batch', 'epoch', 'batch_and_epoch'] write_interval: + When to write, defaults to "batch" + """ + + def __init__( + self, + output_store: str, + write_input: bool = False, + write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "batch", + ) -> None: + super().__init__(write_interval) + self.output_store = output_store + self.write_input = write_input + + def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + source_channel: list[str] = trainer.datamodule.source_channel + target_channel: list[str] = trainer.datamodule.target_channel + prediction_channel = [ch + "_prediction" for ch in target_channel] + if os.path.exists(self.output_store): + if self.write_input: + raise FileExistsError( + "Cannot write input to an existing store. Aborting." + ) + else: + self.plate = open_ome_zarr(self.output_store, mode="r+") + for _, pos in self.plate.positions(): + for ch in prediction_channel: + pos.append_channel(ch, resize_arrays=True) + else: + channel_names = prediction_channel + if self.write_input: + channel_names = source_channel + target_channel + channel_names + self.plate = open_ome_zarr( + self.output_store, layout="hcs", mode="a", channel_names=channel_names + ) + logging.info(f"Writing prediction to: '{self.plate.zgroup.store.path}'.") + if self.write_input: + self.source_index = self._get_channel_indices(source_channel) + self.target_index = self._get_channel_indices(target_channel) + self.prediction_index = self._get_channel_indices(prediction_channel) + + def _get_channel_indices(self, channel_names: list[str]) -> list[int]: + return [self.plate.get_channel_index(ch) for ch in channel_names] + + def write_on_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + prediction: torch.Tensor, + batch_indices: Sequence[int] | None, + batch: Sample, + batch_idx: int, + dataloader_idx: int, + ) -> None: + logging.debug(f"Writing batch {batch_idx}.") + for sample_index, _ in enumerate(batch["index"][0]): + self.write_sample(batch, prediction[sample_index], sample_index) + + def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + self.plate.close() + + def write_sample( + self, batch: Sample, sample_prediction: torch.Tensor, sample_index: int + ) -> None: + logging.debug(f"Writing sample {sample_index}.") + sample_prediction = sample_prediction.cpu().numpy() + img_name, t_index, z_index = [batch["index"][i][sample_index] for i in range(3)] + t_index = int(t_index) + z_index = int(z_index) + image = self._create_image( + img_name, sample_prediction.shape, sample_prediction.dtype + ) + _resize_image(image, t_index, z_index) + if self.write_input: + # FIXME: should write center sclice of source + image[t_index, self.source_index, z_index] = batch["source"][ + sample_index + ].cpu()[:, 0] + image[t_index, self.target_index, z_index] = batch["target"][ + sample_index + ].cpu()[:, 0] + # write C1YX + image.oindex[t_index, self.prediction_index, z_index] = sample_prediction[:, 0] + + def _create_image(self, img_name: str, shape: tuple[int], dtype: DTypeLike): + if img_name in self.plate.zgroup: + return self.plate[img_name] + logging.debug(f"Creating image '{img_name}'") + _, row_name, col_name, pos_name, arr_name = img_name.split("/") + position = self.plate.create_position(row_name, col_name, pos_name) + shape = [1] + list(shape) + shape[1] = len(position.channel_names) + return position.create_zeros( + arr_name, + shape=shape, + dtype=dtype, + chunks=_pad_shape(tuple(shape[-2:]), 5), + ) diff --git a/viscy/preprocessing/generate_masks.py b/viscy/preprocessing/generate_masks.py new file mode 100644 index 00000000..f88f8fbe --- /dev/null +++ b/viscy/preprocessing/generate_masks.py @@ -0,0 +1,115 @@ +"""Generate masks from sum of flurophore channels""" +import iohub.ngff as ngff + +import viscy.utils.aux_utils as aux_utils +from viscy.utils.mp_utils import mp_create_and_write_mask + + +class MaskProcessor: + """ + Appends Masks to zarr directories + """ + + def __init__( + self, + zarr_dir, + channel_ids, + time_ids=-1, + pos_ids=-1, + num_workers=4, + mask_type="otsu", + overwrite_ok=False, + ): + """ + :param str zarr_dir: directory of HCS zarr store to pull data from. + Note: data in store is assumed to be stored in + (time, channel, z, y, x) format. + :param list[int] channel_ids: Channel indices to be masked (typically + just one) + :param int/list channel_ids: generate mask from the sum of these + (flurophore) channel indices + :param list/int time_ids: timepoints to consider + :param int pos_ids: Position (FOV) indices to use + :param int num_workers: number of workers for multiprocessing + :param str mask_type: method to use for generating mask. Needed for + mapping to the masking function. One of: + {'otsu', 'unimodal', 'borders_weight_loss_map'} + """ + self.zarr_dir = zarr_dir + self.num_workers = num_workers + + # Validate that given indices are available. + metadata_ids = aux_utils.validate_metadata_indices( + zarr_dir=zarr_dir, + time_ids=time_ids, + channel_ids=channel_ids, + pos_ids=pos_ids, + ) + self.time_ids = metadata_ids["time_ids"] + self.channel_ids = metadata_ids["channel_ids"] + self.position_ids = metadata_ids["pos_ids"] + + assert mask_type in [ + "otsu", + "unimodal", + "mem_detection", + "borders_weight_loss_map", + ], ( + "Masking method invalid, 'otsu', 'unimodal', 'mem_detection', " + "'borders_weight_loss_map' are supported" + ) + self.mask_type = mask_type + self.ints_metadata = None + self.channel_thr_df = None + + plate = ngff.open_ome_zarr(store_path=zarr_dir, mode="r") + + # deal with output channel selection/overwriting messages + if overwrite_ok: + mask_name = "_".join(["mask", self.mask_type]) + if mask_name in plate.channel_names: + print(f"Mask found in channel {mask_name}. Overwriting with this mask.") + plate.close() + + def generate_masks(self, structure_elem_radius=5): + """ + The sum of flurophore channels is thresholded to generate a foreground + mask. + + Masks are saved as an additional channel in each data array for each + specified position. If certain channels are not specified, gaps are + filled with arrays of zeros. + + Masks are also saved as an additional untracked array named "mask" and + tracked in the "mask" metadata field. + + :param int structure_elem_radius: Radius of structuring element for + morphological operations + """ + + # Gather function arguments for each index pair at each position + plate = ngff.open_ome_zarr(store_path=self.zarr_dir, mode="r+") + + mp_mask_creator_args = [] + + for i, (_, position) in enumerate(plate.positions()): + # TODO: make a better progress bar for mask generation + verbose = i % 4 == 0 + mp_mask_creator_args.append( + tuple( + [ + position, + self.time_ids, + self.channel_ids, + structure_elem_radius, + self.mask_type, + "_".join(["mask", self.mask_type]), + verbose, + ] + ) + ) + + # create and write masks and metadata using multiprocessing + mp_create_and_write_mask(mp_mask_creator_args, workers=self.num_workers) + + plate.close() diff --git a/viscy/preprocessing/preprocessing.md b/viscy/preprocessing/preprocessing.md new file mode 100644 index 00000000..76d508c5 --- /dev/null +++ b/viscy/preprocessing/preprocessing.md @@ -0,0 +1,109 @@ +## Preprocessing output format + +The preprocessing step performs the following steps. + +* Segments the target images using selected segmentation algorithm in the configuration file(otsu, mem_detection, unimodal). + +* Stores the mask output as an extra channel in the zarr store with the name of the segmented channel with added subscript '_mask'. For instance, if the user segments the channel named 'Deconvolved-Nuc', the mask channel added will be called 'Deconvolved-Nuc_mask'. The datatype is the same as the channel which is segmented (most possibly float32). + +* Stores the information related to normalization of all input channels mentioned in the configuration. The dataset statistics are stored in the plate level .zattrs file, while the information specific to the position is added in the .zattrs file at each position. The details are explained below. + +Here is the structure of a 0.4 NGFF version HCS format zarr store wriiten using [iohub](https://github.com/czbiohub/iohub) for a dataset with a single condition and multiple imaging FOVs. + +```text +. # Root folder +│ +└── my_zarr_name.zarr # Zarr folder name + ├── .zgroup + ├── .zattrs # Implements "plate" specification + ├── FOVs # Named 'FOVs' to indicate different FOVs inside + │ ├── .zgroup + │ │ + │ ├── 000 # First FOV + │ │ ├── .zgroup + │ │ ├── .zattrs # Implements "well" specification + │ │ │ + │ │ ├── 0 + │ │ │ │ + │ │ │ ├── .zgroup + │ │ │ ├── .zattrs # Implements "multiscales", "omero" + │ │ │ ├── 0 # (T, C, Z, Y, X) float32 + │ │ │ │ ... # Resolution levels + │ │ │ ├── n + │ │ │ └── labels # Labels (optional) + | | + | ├── 001 # Second FOV + + ``` + +Here the dataset statistics is stored inside the 'plate' folder and the position statistics is stored in '.zattrs' inside plate/A/1/0 folder. + +If the dataset contains multiple conditions from different wells the structure can be as follows. + +```text +. # Root folder +│ +└── my_zarr_name.zarr # Zarr folder level + ├── .zgroup + ├── .zattrs # Implements "plate" specification + ├── A # First row of the plate + │ ├── .zgroup + │ │ + │ ├── 1 # First column (well A1 in plate) + │ │ ├── .zgroup + │ │ ├── .zattrs # Implements "well" specification + │ │ │ + │ │ ├── 0 # First field of view of well A1 + │ │ │ │ + │ │ │ ├── .zgroup + │ │ │ ├── .zattrs # Implements "multiscales", "omero" + │ │ │ ├── 0 # (T, C, Z, Y, X) float32 + │ │ │ │ ... # Resolution levels + │ │ │ ├── n + │ │ │ └── labels # Labels (optional) + + ``` + +The statistics are added as dictionaries into the .zattrs file. An example of plate level metadata is here: + +```json + "normalization": { + "Deconvolved-Nuc": { + "dataset_statistics": { + "iqr": 149.7620086669922, + "mean": 262.2070617675781, + "median": 65.5246353149414, + "std": 890.0471801757812 + } + }, + "Phase3D": { + "dataset_statistics": { + "iqr": 0.0011349652777425945, + "mean": -1.9603044165705796e-06, + "median": 3.388232289580628e-05, + "std": 0.005480962339788675 + } + } + } +``` + +FOV level statistics added to every position: + +```json + "normalization": { + "Deconvolved-Nuc": { + "fov_statistics": { + "iqr": 450.4745788574219, + "mean": 486.3854064941406, + "median": 83.43557739257812, + "std": 976.02392578125 + } + }, + "Phase3D": { + "fov_statistics": { + "iqr": 0.006403466919437051, + "mean": 0.0010083537781611085, + "median": 0.00022060875198803842, + "std": 0.007864165119826794 + } + } diff --git a/viscy/preprocessing/readme.md b/viscy/preprocessing/readme.md new file mode 100644 index 00000000..4422c16e --- /dev/null +++ b/viscy/preprocessing/readme.md @@ -0,0 +1,60 @@ +## Preprocessing + +The main command for preprocessing is: + +```buildoutcfg +python viscy/cli/preprocess_script.py --config +``` + +The following settings can be adjusted in preprocessing using a config file (see example in preprocess_config.yml): + +* input_dir: (str) Directory where data to be preprocessed is located +* output_dir: (str) folder name where all processed data will be written +* channel_ids: (list of ints) specify channel numbers (default is -1 for all indices) +* num_workers: (int) Number of workers for multiprocessing +* slice_ids: (int/list) Value(s) of z-indices to be processed (default is -1 for all indices) +* time_ids: (int/list) Value(s) of timepoints to be processed (default is -1 for all indices) +* pos_ids: (int/list) Value(s) of FOVs/positions to be processed (default is -1 for all indices) +* verbose: (int) Logging verbosity levels: NOTSET:0, DEBUG:10, INFO:20, WARNING:30, ERROR:40, CRITICAL:50 +* resize: + * scale_factor(float/list): Scale factor for resizing 2D frames, e.g. to match resolution in z or resizing volumes + * num_slices_subvolume (int): number of slices to be included in each subvolume, default=-1, includes all slices in slice_ids + +* masks: + * channels: (list of ints) which channels should be used to generate masks from + * str_elem_radius: (int) structuring element radius for morphological operations on masks + * normalize_im (bool): Whether to normalize image before generating masks + * mask_dir (str): As an alternative to channels/str_element_radius, you can specify a directory + containing already generated masks (e.g. manual annotations). Masks must match input images in + terms of shape and indices. + * csv_name (str): If specifying mask_dir, the directory must contain a csv file matching mask names + with image names. If left out, the script will look for first a frames_meta.csv, + second one csv file containing mask names in one column and matched image names in a + second column. +* do_tiling: (bool) do tiling (recommended) +* tile: + * tile_size: (list of ints) tile size in pixels for each dimension + * step_size: (list of ints) step size in pixels for each dimension + * depths: (list of ints) tile z depth for all the channels specified + * mask_depth: (int) z depth of mask + * image_format (str): 'zyx' (default) or 'xyz'. Order of tile dimensions + * train_fraction (float): If specified in range (0, 1), will randomly select that fraction + of training data in each epoch. It will update steps_per_epoch in fit_generator accordingly. + * min_fraction: (float) minimum fraction of image occupied by foreground in masks + * hist_clip_limits: (list of ints) lower and upper intensity percentiles for histogram clipping + +The tiling class will take the 2D image files, assemble them to stacks in case 3D tiles are required, +and store them as tiles based on input tile size, step size, and depth. + +All data will be stored in the specified output dir, where a 'preprocessing_info.json' file + +During preprocessing, a csv file named frames_csv.csv will be generated, which +will be used for further processing. The csv contains the following fields for each image tile: + +* 'time_idx': the timepoint it came from +* 'channel_idx': its channel +* 'slice_idx': the z index in case of 3D data +* 'pos_idx': the field of view index +* 'file_name': file name +* 'row_start': starting row for tile (add tile_size for endpoint) +* 'col_start': start column (add tile_size for endpoint) diff --git a/viscy/scripts/README.md b/viscy/scripts/README.md new file mode 100644 index 00000000..2a6e432c --- /dev/null +++ b/viscy/scripts/README.md @@ -0,0 +1,4 @@ +# Utility scripts + +- [Monitor GPU usage](gpu_usage_monitor.py) +- [Log feature map](log_feature_map.py) \ No newline at end of file diff --git a/viscy/scripts/export_from_tensorboard.py b/viscy/scripts/export_from_tensorboard.py new file mode 100644 index 00000000..428fc275 --- /dev/null +++ b/viscy/scripts/export_from_tensorboard.py @@ -0,0 +1,57 @@ +# %% This script allows us to export visualizations from tensorboard logs +# written by lightning training CLI. + +import matplotlib.pyplot as plt +import torch + +# Get the path to the tensorboard event file +event_file_path = ( + "/hpc/projects/CompMicro/projects/virtual_staining/models/" + "phase2nuclei018/lightning_logs/20230514-003340/" + "events.out.tfevents.1684049623.gpu-c-1.2407720.0" +) + +# Read the event file +with open(event_file_path, "rb") as f: + event_data = torch.utils.tensorboard.load_event_file(f) + +# Seeing this error: +# AttributeError: module 'torch.utils' has no attribute 'tensorboard'. +# tensorboard 2.12.0 is installed in the environment. Not sure why this error. + +# Get the scalars +loss_train = [] +loss_val = [] +loss_train_step = [] +for event in event_data: + for value in event.summary.value: + if value.tag == "loss/train_epoch": + loss_train.append(value.simple_value) + elif value.tag == "loss/val": + loss_val.append(value.simple_value) + elif value.tag == "loss/train_step": + loss_train_step.append(value.simple_value) + +# %% Plot the scalars +plt.plot(loss_train, label="train") +plt.plot(loss_val, label="val") +plt.plot(loss_train_step, label="train_step") +plt.legend() +plt.show() + +# Export the images with tag 'val_samples' as mp4 +images = [] +for event in event_data: + for image in event.summary.value: + if image.tag == "val_samples": + images.append(image.image.encoded_image_string) + +# Create a video writer +writer = torch.utils.tensorboard.writer.FFmpegWriter("./val_samples.mp4") + +# Write the images to the video writer +for image in images: + writer.write(image) + +# Close the video writer +writer.close() diff --git a/viscy/scripts/log_feature_map.py b/viscy/scripts/log_feature_map.py new file mode 100644 index 00000000..007e52f9 --- /dev/null +++ b/viscy/scripts/log_feature_map.py @@ -0,0 +1,31 @@ +# %% +import time + +import numpy as np +import torch + +from viscy.utils import cli_utils + + +def main(): + # create test data + data = np.random.random((5, 5, 5, 128, 128)) + data_tensor = torch.Tensor(data) + print(data_tensor.shape) + + # run feature logger on test data + cli_utils.log_feature_map( + data_tensor, + "/home/christian.foley/virtual_staining/example_log/", + dim_names=["batch", "channels"], + spatial_dims=3, + vmax=None, + ) + + +if __name__ == "main": + start = time.time() + try: + main() + except Exception as e: + print(f"Error after {time.time() - start} seconds:\n{e}") diff --git a/viscy/scripts/profiling.py b/viscy/scripts/profiling.py new file mode 100644 index 00000000..1353fa65 --- /dev/null +++ b/viscy/scripts/profiling.py @@ -0,0 +1,34 @@ +# script to profile dataloading + +from profilehooks import profile + +from viscy.light.data import HCSDataModule + +dataset = "/path/to/dataset.zarr" + + +dm = HCSDataModule( + dataset, + "Phase3D", + "Deconvolved-Nuc", + 5, + 0.8, + batch_size=32, + num_workers=32, + augment=True, + caching=False, +) + +dm.setup("fit") + + +@profile(immediate=True, sort="time", dirs=True) +def load_batch(n=1): + for i, batch in enumerate(dm.train_dataloader()): + print(batch["source"].shape) + print(dm.on_before_batch_transfer(batch, 0)["target"].shape) + if i == n - 1: + break + + +load_batch(3) diff --git a/viscy/unet/__init__.py b/viscy/unet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/viscy/unet/networks/Unet25D.py b/viscy/unet/networks/Unet25D.py new file mode 100644 index 00000000..72418538 --- /dev/null +++ b/viscy/unet/networks/Unet25D.py @@ -0,0 +1,255 @@ +import torch +import torch.nn as nn + +from viscy.unet.networks.layers.ConvBlock3D import ConvBlock3D + + +class Unet25d(nn.Module): + def __name__(self): + return "Unet25d" + + def __init__( + self, + in_channels=1, + out_channels=1, + in_stack_depth=5, + out_stack_depth=1, + xy_kernel_size=(3, 3), + residual=False, + dropout=0.2, + num_blocks=4, + num_block_layers=2, + num_filters=[], + task="seg", + ): + """ + Instance of 2.5D Unet. + 1.) https://elifesciences.org/articles/55502 + + Architecture takes in stack of 2d inputs given as a 3d tensor + and returns a 2d interpretation. + Learns 3d information based upon input stack, + but speeds up training by compressing 3d information before the decoding path. + Uses interruption conv layers in the Unet skip paths to + compress information with z-channel convolution. + + :param int in_channels: number of feature channels in (1 or more) + :param int out_channels: number of feature channels out (1 or more) + :param int input_stack_depth: depth of input stack in z + :param int output_stack_depth: depth of output stack + :param int/tuple(int, int) xy_kernel_size: size of x and y dimensions + of conv kernels in blocks + :param bool residual: see name + :param float dropout: probability of dropout, between 0 and 0.5 + :param int num_blocks: number of convolutional blocks + on encoder and decoder paths + :param int num_block_layers: number of layer sequences repeated per block + :param list[int] num_filters: list of filters/feature levels + at each conv block depth + :param str task: network task (for virtual staining this is regression), + one of 'seg','reg' + :param str debug_mode: if true logs features at each step of architecture, + must be manually set + """ + super(Unet25d, self).__init__() + self.in_channels = in_channels + self.num_blocks = num_blocks + self.kernel_size = xy_kernel_size + self.residual = residual + assert ( + dropout >= 0 and dropout <= 0.5 + ), f"Dropout {dropout} not in allowed range: [0, 0.5]" + self.dropout = dropout + self.task = task + self.debug_mode = False + + # ----- set static parameters ----- # + self.block_padding = "same" + down_mode = "avgpool" # TODO set static avgpool + up_mode = "trilinear" # TODO set static trilinear + activation = "relu" # TODO set static relu + self.bottom_block_spatial = False # TODO set static + # TODO set conv_block layer order variable + + # ----- Standardize Filter Sequence ----- # + if len(num_filters) != 0: + assert len(num_filters) == num_blocks + 1, ( + "Length of num_filters must be equal to num_" + "blocks + 1 (number of convolutional blocks per path)." + ) + self.num_filters = num_filters + else: + self.num_filters = [pow(2, i) * 16 for i in range(num_blocks + 1)] + downsampling_filters = [in_channels] + self.num_filters + upsampling_filters = [ + self.num_filters[-(i + 1)] + self.num_filters[-(i + 2)] + for i in range(len(self.num_filters)) + if i < len(self.num_filters) - 1 + ] + [out_channels] + + # ----- Downsampling steps -----# + self.down_list = [] + if down_mode == "maxpool": + for i in range(num_blocks): + self.down_list.append( + nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) + ) + elif down_mode == "avgpool": + for i in range(num_blocks): + self.down_list.append( + nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) + ) + elif down_mode == "conv": + raise NotImplementedError("Not yet implemented!") + # TODO: implement. + self.register_modules(self.down_list, "down_samp") + + # ----- Upsampling steps ----- # + self.up_list = [] + for i in range(num_blocks): + self.up_list.append( + nn.Upsample(scale_factor=(1, 2, 2), mode=up_mode, align_corners=False) + ) + + # ----- Convolutional blocks ----- # + self.down_conv_blocks = [] + for i in range(num_blocks): + self.down_conv_blocks.append( + ConvBlock3D( + downsampling_filters[i], + downsampling_filters[i + 1], + dropout=self.dropout, + residual=self.residual, + activation=activation, + kernel_size=(3, self.kernel_size[0], self.kernel_size[1]), + num_repeats=num_block_layers, + ) + ) + self.register_modules(self.down_conv_blocks, "down_conv_block") + + if self.bottom_block_spatial: + # TODO: residual must be false or dimensionality breaks. Fix later + self.bottom_transition_block = ConvBlock3D( + self.num_filters[-2], + self.num_filters[-1], + num_repeats=1, + residual=False, + kernel_size=( + 1 + in_stack_depth - out_stack_depth, + self.kernel_size[0], + self.kernel_size[1], + ), + padding=(0, 1, 1), + ) + else: + self.bottom_transition_block = nn.Conv3d( + self.num_filters[-2], + self.num_filters[-1], + kernel_size=(1 + in_stack_depth - out_stack_depth, 1, 1), + padding=0, + ) + + self.up_conv_blocks = [] + for i in range(num_blocks): + self.up_conv_blocks.append( + ConvBlock3D( + upsampling_filters[i], + downsampling_filters[-(i + 2)], + dropout=self.dropout, + residual=self.residual, + activation=activation, + kernel_size=(1, self.kernel_size[0], self.kernel_size[1]), + num_repeats=num_block_layers, + ) + ) + self.register_modules(self.up_conv_blocks, "up_conv_block") + + # ----- Skip Interruption Conv Blocks ----- # + self.skip_conv_layers = [] + for i in range(num_blocks): + self.skip_conv_layers.append( + nn.Conv3d( + downsampling_filters[i + 1], + downsampling_filters[i + 1], + kernel_size=(1 + in_stack_depth - out_stack_depth, 1, 1), + ) + ) + self.register_modules(self.skip_conv_layers, "skip_conv_layer") + + # ----- Terminal Block and Activation Layer ----- # + if self.task == "reg": + self.terminal_block = ConvBlock3D( + downsampling_filters[1], + out_channels, + dropout=False, + residual=False, + activation="linear", + kernel_size=(1, 3, 3), + norm="none", + num_repeats=1, + ) + else: + self.terminal_block = ConvBlock3D( + downsampling_filters[1], + out_channels, + dropout=self.dropout, + residual=False, + activation=activation, + kernel_size=(1, 3, 3), + num_repeats=1, + ) + + # ----- Feature Logging ----- # + self.log_save_folder = None + + def forward(self, x): + """ + Forward call of network. + + Call order: + => num_block 3D convolutional blocks, with downsampling in between (encoder) + => skip connections between corresponding blocks in encoder and decoder + => num_block 2D (3d with 1 z-channel) convolutional blocks, with upsampling + between them (decoder) + => terminal block collapses to output dimensions + + :param torch.tensor x: input image + """ + + # encoder + skip_tensors = [] + for i in range(self.num_blocks): + x = self.down_conv_blocks[i](x) + skip_tensors.append(x) + x = self.down_list[i](x) + + # transition block + x = self.bottom_transition_block(x) + + # skip interruptions + for i in range(self.num_blocks): + skip_tensors[i] = self.skip_conv_layers[i](skip_tensors[i]) + + # decoder + for i in range(self.num_blocks): + x = self.up_list[i](x) + x = torch.cat([x, skip_tensors[-1 * (i + 1)]], 1) + x = self.up_conv_blocks[i](x) + + # output channel collapsing layer + x = self.terminal_block(x) + return x + + def register_modules(self, module_list, name): + """ + Helper function that registers modules stored in a list to the model object + so that the can be seen by PyTorch optimizer. + + Used to enable model graph creation with + non-sequential model types and dynamic layer numbers + + :param list(torch.nn.module) module_list: list of modules to register + :param str name: name of module type + """ + for i, module in enumerate(module_list): + self.add_module(f"{name}_{str(i)}", module) diff --git a/viscy/unet/networks/Unet2D.py b/viscy/unet/networks/Unet2D.py new file mode 100644 index 00000000..36abb26b --- /dev/null +++ b/viscy/unet/networks/Unet2D.py @@ -0,0 +1,226 @@ +import torch +import torch.nn as nn + +from viscy.unet.networks.layers import ConvBlock2D + + +class Unet2d(nn.Module): + def __name__(self): + return "Unet2d" + + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=(3, 3), + residual=False, + dropout=0.2, + num_blocks=4, + num_block_layers=2, + num_filters=[], + task="seg", + ): + """ + 2D Unet with variable input/output channels and depth (block numbers). + Follows 2D UNet Architecture: + 1) Unet: https://arxiv.org/pdf/1505.04597.pdf + 2) residual Unet: https://arxiv.org/pdf/1711.10684.pdf + + :param int in_channels: number of feature channels in + :param int out_channels: number of feature channels out + :param int/tuple(int,int) kernel_size: size of x and y dimensions + of conv kernels in blocks + :param bool residual: see name + :param float dropout: probability of dropout, between 0 and 0.5 + :param int num_blocks: number of convolutional blocks on encoder and decoder + :param int num_block_layers: number of layers per block + :param list[int] num_filters: list of filters/feature levels + at each conv block depth + :param str task: network task (for virtual staining this is regression), + one of 'seg','reg' + """ + + super(Unet2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.residual = residual + self.dropout = dropout + self.num_blocks = num_blocks + self.num_block_layers = num_block_layers + self.task = task + + # ----- set static parameters -----# + self.block_padding = "same" + down_mode = "avgpool" # TODO set static avgpool + up_mode = "bilinear" # TODO set static bilinear + activation = "relu" # TODO set static relu + self.bottom_block_spatial = False # TODO set static + + # ----- Standardize Filter Sequence -----# + if len(num_filters) != 0: + assert len(num_filters) == num_blocks, ( + "Length of num_filters must be equal to num_blo" + "cks + 1 (number of convolutional blocks per path)." + ) + self.num_filters = num_filters + else: + self.num_filters = [pow(2, i) * 16 for i in range(num_blocks + 1)] + downsampling_filters = [in_channels] + self.num_filters + upsampling_filters = [ + self.num_filters[-(i + 1)] + self.num_filters[-(i + 2)] + for i in range(len(self.num_filters)) + if i < len(self.num_filters) - 1 + ] + [out_channels] + + # ----- Downsampling steps -----# + self.down_list = [] + if down_mode == "maxpool": + for i in range(num_blocks): + self.down_list.append(nn.MaxPool2d(kernel_size=2)) + elif down_mode == "avgpool": + for i in range(num_blocks): + self.down_list.append(nn.AvgPool2d(kernel_size=2)) + elif down_mode == "conv": + raise NotImplementedError("Not yet implemented!") + # TODO: implement. + self.register_modules(self.down_list, "down_samp") + + # ----- Upsampling steps -----# + self.up_list = [] + if up_mode == "bilinear": + for i in range(num_blocks): + self.up_list.append( + nn.Upsample(mode=up_mode, scale_factor=2, align_corners=False) + ) + elif up_mode == "conv": + raise NotImplementedError("Not yet implemented!") + # TODO: implement + elif up_mode == "tconv": + raise NotImplementedError("Not yet implemented!") + # TODO: implement + else: + raise NotImplementedError(f"Upsampling mode '{up_mode}' not supported.") + + # ----- Convolutional blocks -----# + self.down_conv_blocks = [] + for i in range(num_blocks): + self.down_conv_blocks.append( + ConvBlock2D( + downsampling_filters[i], + downsampling_filters[i + 1], + dropout=self.dropout, + residual=self.residual, + activation=activation, + kernel_size=self.kernel_size, + num_repeats=self.num_block_layers, + ) + ) + self.register_modules(self.down_conv_blocks, "down_conv_block") + + self.bottom_transition_block = ConvBlock2D( + self.num_filters[-2], + self.num_filters[-1], + dropout=self.dropout, + residual=self.residual, + activation=activation, + kernel_size=self.kernel_size, + num_repeats=self.num_block_layers, + ) + + self.up_conv_blocks = [] + for i in range(num_blocks): + self.up_conv_blocks.append( + ConvBlock2D( + upsampling_filters[i], + downsampling_filters[-(i + 2)], + dropout=self.dropout, + residual=self.residual, + activation=activation, + kernel_size=self.kernel_size, + num_repeats=self.num_block_layers, + ) + ) + self.register_modules(self.up_conv_blocks, "up_conv_block") + + # ----- Terminal Block and Activation Layer -----# + if self.task == "reg": + self.terminal_block = ConvBlock2D( + downsampling_filters[1], + out_channels, + dropout=self.dropout, + residual=False, + activation="linear", + num_repeats=1, + norm="none", + kernel_size=self.kernel_size, + ) + else: + self.terminal_block = ConvBlock2D( + downsampling_filters[1], + out_channels, + dropout=self.dropout, + residual=False, + activation=activation, + num_repeats=1, + norm="none", + kernel_size=self.kernel_size, + ) + + def forward(self, x, validate_input=False): + """ + Forward call of network + - x -> Torch.tensor: input image stack + + Call order: + => num_block 2D convolutional blocks, with downsampling in between (encoder) + => num_block 2D convolutional blocks, with upsampling between them (decoder) + => skip connections between corresponding blocks on encoder and decoder + => terminal block collapses to output dimensions + + :param torch.tensor x: input image + :param bool validate_input: Deactivates assertions which are redundant + if forward pass is being traced by tensorboard writer. + """ + # handle input exceptions + if validate_input: + assert x.shape[-1] == x.shape[-2], "Input must be square in xy" + assert x.shape[-3] == self.in_channels, ( + f"Input channels must equal network" + f" input channels: {self.in_channels}" + ) + + # encoder + skip_tensors = [] + for i in range(self.num_blocks): + x = self.down_conv_blocks[i](x, validate_input=validate_input) + skip_tensors.append(x) + x = self.down_list[i](x) + + # transition block + x = self.bottom_transition_block(x) + + # decoder + for i in range(self.num_blocks): + x = self.up_list[i](x) + x = torch.cat([x, skip_tensors[-1 * (i + 1)]], 1) + x = self.up_conv_blocks[i](x, validate_input=validate_input) + + # output channel collapsing layer + x = self.terminal_block(x) + + return x + + def register_modules(self, module_list, name): + """ + Helper function that registers modules stored in a list to the model object + so that they can be seen by PyTorch optimizer. + + Used to enable model graph creation with + non-sequential model types and dynamic layer numbers + + :param list(torch.nn.module) module_list: list of modules to register + :param str name: name of module type + """ + for i, module in enumerate(module_list): + self.add_module(f"{name}_{str(i)}", module) diff --git a/viscy/unet/networks/__init__.py b/viscy/unet/networks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/viscy/unet/networks/layers/ConvBlock2D.py b/viscy/unet/networks/layers/ConvBlock2D.py new file mode 100644 index 00000000..114777a7 --- /dev/null +++ b/viscy/unet/networks/layers/ConvBlock2D.py @@ -0,0 +1,377 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConvBlock2D(nn.Module): + def __init__( + self, + in_filters, + out_filters, + dropout=False, + norm="batch", + residual=True, + activation="relu", + transpose=False, + kernel_size=3, + num_repeats=3, + filter_steps="first", + layer_order="can", + ): + """ + Convolutional block for lateral layers in Unet + + Format for layer initialization is as follows: + if layer type specified + => for number of layers + => add layer to list of that layer type + => register elements of list + This is done to allow for dynamic layer number specification in the conv blocks, + which allows us to change the parameter numbers of the network. + + :param int in_filters: number of images in in stack + :param int out_filters: number of images in out stack + :param float dropout: dropout probability (False => 0) + :param str norm: normalization type: 'batch', 'instance' + :param bool residual: as name + :param str activation: activation function: 'relu', 'leakyrelu', 'elu', 'selu' + :param bool transpose: as name + :param int/tuple kernel_size: convolutional kernel size + :param int num_repeats: number of times the layer_order layer sequence + is repeated in the block + :param str filter_steps: determines where in the block + the filters inflate channels (learn abstraction information): + 'linear','first','last' + :param str layer_order: order of conv, norm, and act layers in block: + 'can', 'cna', 'nca', etc + """ + + super(ConvBlock2D, self).__init__() + self.in_filters = in_filters + self.out_filters = out_filters + self.dropout = dropout + self.norm = norm + self.residual = residual + self.activation = activation + self.transpose = transpose + self.num_repeats = num_repeats + self.filter_steps = filter_steps + self.layer_order = layer_order + + # ---- Handle Kernel ----# + ks = kernel_size + if isinstance(ks, int): + assert ks % 2 == 1, "Kernel dims must be odd" + elif isinstance(ks, tuple): + for i in range(len(ks)): + assert ks[i] % 2 == 1, "Kernel dims must be odd" + assert i == 1, "kernel_size length must be 2" + else: + raise AttributeError("'kernel_size' must be either int or tuple") + self.kernel_size = kernel_size + + # ----- Init Dropout -----# + if self.dropout: + self.drop_list = [] + for i in range(self.num_repeats): + self.drop_list.append(nn.Dropout2d(int(self.dropout))) + + # ---- Init linear filter steps ----# + steps = np.linspace(in_filters, out_filters, num_repeats + 1).astype(int) + + # ----- Init Normalization Layers -----# + # The parameters governing the initiation logic flow are: + # self.norm + # self.num_repeats + # self.filter_steps + self.norm_list = [None for i in range(num_repeats)] + if self.norm == "batch": + for i in range(self.num_repeats): + if self.filter_steps == "linear": + self.norm_list[i] = nn.BatchNorm2d(steps[i + 1]) + elif self.filter_steps == "first": + self.norm_list[i] = nn.BatchNorm2d(steps[-1]) + elif self.filter_steps == "last": + if i < self.num_repeats - 1: + self.norm_list[i] = nn.BatchNorm2d(steps[0]) + else: + self.norm_list[i] = nn.BatchNorm2d(steps[-1]) + elif self.norm == "instance": + for i in range(self.num_repeats): + if self.filter_steps == "linear": + self.norm_list[i] = nn.InstanceNorm2d(steps[i + 1]) + elif self.filter_steps == "first": + self.norm_list[i] = nn.InstanceNorm2d(steps[-1]) + elif self.filter_steps == "last": + if i < self.num_repeats - 1: + self.norm_list[i] = nn.InstanceNorm2d(steps[0]) + else: + self.norm_list[i] = nn.InstanceNorm2d(steps[-1]) + self.register_modules(self.norm_list, f"{norm}_norm") + + # ----- Init Conv Layers -----# + # init conv layers and determine transposition during convolution + # The parameters governing the initiation logic flow are: + # self.transpose + # self.num_repeats + # self.filter_steps + # See above for definitions. + # -------# + + self.conv_list = [] + if self.filter_steps == "linear": # learn progressively over steps + for i in range(self.num_repeats): + depth_pair = ( + (steps[i], steps[i + 1]) + if i + 1 < num_repeats + else (steps[i], steps[-1]) + ) + if self.transpose: + self.conv_list.append( + nn.ConvTranspose2d( + depth_pair[0], + depth_pair[1], + kernel_size=kernel_size, + padding="same", + ) + ) + else: + self.conv_list.append( + nn.Conv2d( + depth_pair[0], + depth_pair[1], + kernel_size=kernel_size, + padding="same", + ) + ) + + elif self.filter_steps == "first": # learn in the first convolution + if self.transpose: + raise NotImplementedError( + "PyTorch-side problem with 'same' padding in ConvTranspose2d." + ) + for i in range(self.num_repeats): + if i == 0: + self.conv_list.append( + nn.ConvTranspose2d( + in_filters, + out_filters, + kernel_size=kernel_size, + padding="same", + ) + ) + else: + self.conv_list.append( + nn.ConvTranspose2d( + out_filters, + out_filters, + kernel_size=kernel_size, + padding="same", + ) + ) + else: + for i in range(self.num_repeats): + if i == 0: + self.conv_list.append( + nn.Conv2d( + in_filters, + out_filters, + kernel_size=kernel_size, + padding="same", + ) + ) + else: + self.conv_list.append( + nn.Conv2d( + out_filters, + out_filters, + kernel_size=kernel_size, + padding="same", + ) + ) + + elif self.filter_steps == "last": # learn in the last convolution + if self.transpose: + raise NotImplementedError( + "Problem with 'same' padding in ConvTranspose2d." + ) + for i in range(self.num_repeats): + if i == self.num_repeats - 1: + self.conv_list.append( + nn.ConvTranspose2d( + in_filters, + out_filters, + kernel_size=kernel_size, + padding="same", + ) + ) + else: + self.conv_list.append( + nn.ConvTranspose2d( + out_filters, + out_filters, + kernel_size=kernel_size, + padding="same", + ) + ) + else: + for i in range(self.num_repeats): + if i == self.num_repeats - 1: + self.conv_list.append( + nn.Conv2d( + in_filters, + out_filters, + kernel_size=kernel_size, + padding="same", + ) + ) + else: + self.conv_list.append( + nn.Conv2d( + in_filters, + in_filters, + kernel_size=kernel_size, + padding="same", + ) + ) + self.register_modules(self.conv_list, "Conv2d") + + # ----- Init Residual Layer -----# + self.resid_conv = nn.Conv2d( + self.in_filters, self.out_filters, kernel_size=1, padding=0 + ) + + # ----- Init Activation Layers -----# + self.act_list = [] + if self.activation == "relu": + for i in range(self.num_repeats): + self.act_list.append(nn.ReLU()) + elif self.activation == "leakyrelu": + for i in range(self.num_repeats): + self.act_list.append(nn.LeakyReLU()) + elif self.activation == "elu": + for i in range(self.num_repeats): + self.act_list.append(nn.ELU()) + elif self.activation == "selu": + for i in range(self.num_repeats): + self.act_list.append(nn.SELU()) + elif self.activation != "linear": + raise NotImplementedError( + f"Activation type {self.activation} not supported." + ) + self.register_modules(self.act_list, f"{self.activation}_act") + + def forward(self, x, validate_input=False): + """ + Forward call of convolutional block + + Order of layers within the block is defined by the 'layer_order' parameter, + which is a string of 'c's, 'a's and 'n's + in reference to convolution, activation, and normalization layers. + This sequence is repeated num_repeats times. + + Recommended layer order: convolution -> activation -> normalization + + Regardless of layer order, + the final layer sequence in the block always ends in activation. + This allows for usage of passthrough layers + or a final output activation function determined separately. + + Residual blocks: + if input channels are greater than output channels, + we use a 1x1 convolution on input to get desired feature channels; + if input channels are less than output channels, + we zero-pad input channels to output channel size. + + :param torch.tensor x: input tensor + :param bool validate_input: Deactivates assertions + which are redundant if forward pass is being traced by tensorboard writer. + """ + if validate_input: + if isinstance(self.kernel_size, int): + assert ( + x.shape[-1] > self.kernel_size and x.shape[-2] > self.kernel_size + ), ( + f"Input size" + f" {x.shape} too small for kernel of size {self.kernel_size}" + ) + elif isinstance(self.kernel_size, tuple): + assert ( + x.shape[-1] > self.kernel_size[-1] + and x.shape[-2] > self.kernel_size[-2] + ), ( + f"Input size" + f" {x.shape} too small for kernel of size {self.kernel_size}" + ) + + x_0 = x + for i in range(self.num_repeats): + order = list(self.layer_order) + while len(order) > 0: + layer = order.pop(0) + if layer == "c": + x = self.conv_list[i](x) + if self.dropout: + x = self.drop_list[i](x) + elif layer == "a": + if i < self.num_repeats - 1 or self.activation != "linear": + x = self.act_list[i](x) + elif layer == "n" and self.norm_list[i]: + x = self.norm_list[i](x) + + # residual summation after final activation/normalization + if self.residual: + if self.in_filters > self.out_filters: + x_0 = self.resid_conv(x_0) + elif self.in_filters < self.out_filters: + x_0 = F.pad( + x_0, + (*[0] * 4, self.out_filters - self.in_filters, *[0] * 3), + mode="constant", + value=0, + ) + x = torch.add(x_0, x) + + return x + + def model(self): + """ + Allows calling of parameters inside ConvBlock object: + 'ConvBlock.model().parameters()'' + + Layer order: convolution -> normalization -> activation + + We can make a list of layer modules and unpack them into nn.Sequential. + Note: this is distinct from the forward call + because we want to use the forward call with addition, + since this is a residual block. + The forward call performs the residial calculation, + and all the parameters can be seen by the optimizer when given this model. + """ + layers = [] + + for i in range(self.num_repeats): + layers.append(self.conv_list[i]) + if self.dropout: + layers.append(self.drop_list[i]) + if self.norm[i]: + layers.append(self.norm_list[i]) + if i < len(self.act_list): + layers.append(self.act_list[i]) + + return nn.Sequential(*layers) + + def register_modules(self, module_list, name): + """ + Helper function that registers modules stored in a list to the model object + so that they can be seen by PyTorch optimizer. + + Used to enable model graph creation + with non-sequential model types and dynamic layer numbers + + :param list(torch.nn.module) module_list: list of modules to register + :param str name: name of module type + """ + for i, module in enumerate(module_list): + self.add_module(f"{name}_{str(i)}", module) diff --git a/viscy/unet/networks/layers/ConvBlock3D.py b/viscy/unet/networks/layers/ConvBlock3D.py new file mode 100644 index 00000000..893c612e --- /dev/null +++ b/viscy/unet/networks/layers/ConvBlock3D.py @@ -0,0 +1,352 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConvBlock3D(nn.Module): + def __init__( + self, + in_filters, + out_filters, + dropout=False, + norm="batch", + residual=True, + activation="relu", + transpose=False, + kernel_size=(3, 3, 3), + num_repeats=3, + filter_steps="first", + layer_order="can", + padding=None, + ): + """ + Convolutional block for lateral layers in Unet. + This block only accepts tensors of dimensions in + order [...,z,x,y] or [...,z,y,x] + + Format for layer initialization is as follows: + if layer type specified + => for number of layers + => add layer to list of that layer type + This is done to allow for dynamic layer number specification in the conv blocks, + which allows us to change the parameter numbers of the network. + + Only 'same' convolutional padding is recommended, + as the conv blocks are intended for deep Unets. + However padding can be specified as follows: + padding -> token{'same', 'valid', 'valid_stack'} or tuple(int) or int: + -> 'same': pads with same convolution + -> 'valid': pads for valid convolution on all dimensions + -> 'valid_stack': pads for valid convolution on xy dims (-1, -2), + same on z dim (-3). + -> tuple (int): pads above and below corresponding dimensions + -> int: pads above and below all dimensions + + :param int in_filters: number of images in in stack + :param int out_filters: number of images in out stack + :param float dropout: dropout probability (False => 0) + :param str norm: normalization type: 'batch', 'instance' + :param bool residual: as name + :param str activation: activation function: 'relu', 'leakyrelu', 'elu', 'selu' + :param bool transpose: as name + :param int/tuple kernel_size: convolutional kernel size + :param int num_repeats: as name + :param str filter_steps: determines where in the block + the filters inflate channels + (learn abstraction information): 'linear','first','last' + :param str layer_order: order of conv, norm, and act layers in block: + 'can', 'cna', etc. + NOTE: for now conv must always come first as required by norm feature counts + :paramn str/tuple(int)/tuple/None padding: convolutional padding, + see docstring for details + """ + + super(ConvBlock3D, self).__init__() + self.in_filters = in_filters + self.out_filters = out_filters + self.dropout = dropout + self.norm = norm + self.residual = residual + self.activation = activation + self.transpose = transpose + self.num_repeats = num_repeats + self.filter_steps = filter_steps + self.layer_order = layer_order + + # ---- Handle Kernel ----# + ks = kernel_size + if isinstance(ks, int): + assert ks % 2 == 1, "Kernel dims must be odd" + elif isinstance(ks, tuple): + for i in range(len(ks)): + assert ks[i] % 2 == 1, "Kernel dims must be odd" + assert i == 2, "kernel_size length must be 3" + else: + raise AttributeError("'kernel_size' must be either int or tuple") + self.kernel_size = kernel_size + + # ---- Handle Padding ----# + self.pad_type = "same" + self.padding = (ks[2] // 2, ks[1] // 2, ks[0] // 2) + if padding == "valid": + self.padding = (0, 0, 0) + elif self.padding == "valid_stack": # note: deprecated + ks = kernel_size + self.padding = (0, 0, ks[0] // 2) + elif isinstance(padding, tuple): + self.padding = padding + self.padding = tuple(self.padding[i // 2] for i in range(6)) + (0,) * 4 + + # ----- Init Dropout -----# + if self.dropout: + self.drop_list = [] + for i in range(self.num_repeats): + self.drop_list.append(nn.Dropout3d(self.dropout)) + self.register_modules(self.drop_list, "dropout") + + # ---- Init linear filter steps ----# + steps = np.linspace(in_filters, out_filters, num_repeats + 1).astype(int) + + # ----- Init Normalization Layers -----# + self.norm_list = [None for i in range(num_repeats)] + if self.norm == "batch": + for i in range(self.num_repeats): + if self.filter_steps == "linear": + self.norm_list[i] = nn.BatchNorm3d(steps[i + 1]) + elif self.filter_steps == "first": + self.norm_list[i] = nn.BatchNorm3d(steps[-1]) + elif self.filter_steps == "last": + if i < self.num_repeats - 1: + self.norm_list[i] = nn.BatchNorm3d(steps[0]) + else: + self.norm_list[i] = nn.BatchNorm3d(steps[-1]) + elif self.norm == "instance": + for i in range(self.num_repeats): + if self.filter_steps == "linear": + self.norm_list[i] = nn.InstanceNorm3d(steps[i + 1]) + elif self.filter_steps == "first": + self.norm_list[i] = nn.InstanceNorm3d(steps[-1]) + elif self.filter_steps == "last": + if i < self.num_repeats - 1: + self.norm_list[i] = nn.InstanceNorm3d(steps[0]) + else: + self.norm_list[i] = nn.InstanceNorm3d(steps[-1]) + self.register_modules(self.norm_list, f"{norm}_norm") + + # ----- Init Conv Layers -----# + # + # init conv layers and determine transposition during convolution + # The parameters governing the initiation logic flow are: + # self.filter_steps + # self.transpose + # self.num_repeats + # See above for definitions. + # -------# + + self.conv_list = [] + if self.filter_steps == "linear": + for i in range(self.num_repeats): + depth_pair = ( + (steps[i], steps[i + 1]) + if i + 1 < num_repeats + else (steps[i], steps[-1]) + ) + if self.transpose: + self.conv_list.append( + nn.ConvTranspose3d( + depth_pair[0], depth_pair[1], kernel_size=kernel_size + ) + ) + else: + self.conv_list.append( + nn.Conv3d(depth_pair[0], depth_pair[1], kernel_size=kernel_size) + ) + + elif self.filter_steps == "first": + if self.transpose: + for i in range(self.num_repeats): + if i == 0: + self.conv_list.append( + nn.ConvTranspose3d( + in_filters, out_filters, kernel_size=kernel_size + ) + ) + else: + self.conv_list.append( + nn.ConvTranspose3d( + out_filters, out_filters, kernel_size=kernel_size + ) + ) + else: + for i in range(self.num_repeats): + if i == 0: + self.conv_list.append( + nn.Conv3d(in_filters, out_filters, kernel_size=kernel_size) + ) + else: + self.conv_list.append( + nn.Conv3d(out_filters, out_filters, kernel_size=kernel_size) + ) + + elif self.filter_steps == "last": + if self.transpose: + for i in range(self.num_repeats): + if i == self.num_repeats - 1: + self.conv_list.append( + nn.ConvTranspose3d( + in_filters, out_filters, kernel_size=kernel_size + ) + ) + else: + self.conv_list.append( + nn.ConvTranspose3d( + in_filters, in_filters, kernel_size=kernel_size + ) + ) + else: + for i in range(self.num_repeats): + if i == self.num_repeats - 1: + self.conv_list.append( + nn.Conv3d(in_filters, out_filters, kernel_size=kernel_size) + ) + else: + self.conv_list.append( + nn.Conv3d(in_filters, in_filters, kernel_size=kernel_size) + ) + self.register_modules(self.conv_list, "Conv3d") + + # ----- Init Residual Layer -----# + # Note that convolution is only used in residual layer + # when block is shrinking feature space + # Unregistered -- Not a learnable parameter + self.resid_conv = nn.Conv3d( + self.in_filters, self.out_filters, kernel_size=1, padding=0 + ) + + # ----- Init Activation Layers -----# + self.act_list = [] + if self.activation == "relu": + for i in range(self.num_repeats): + self.act_list.append(nn.ReLU()) + elif self.activation == "leakyrelu": + for i in range(self.num_repeats): + self.act_list.append(nn.LeakyReLU()) + elif self.activation == "elu": + for i in range(self.num_repeats): + self.act_list.append(nn.ELU()) + elif self.activation == "selu": + for i in range(self.num_repeats): + self.act_list.append(nn.SELU()) + elif self.activation != "linear": + raise NotImplementedError( + f"Activation type {self.activation} not supported." + ) + self.register_modules(self.act_list, f"{self.activation}_act") + + def forward(self, x): + """ + Forward call of convolutional block + + Order of layers within the block is defined by the 'layer_order' parameter, + which is a string of 'c's, 'a's and 'n's in reference to + convolution, activation, and normalization layers. + This sequence is repeated num_repeats times. + + Recommended layer order: convolution -> activation -> normalization + + Regardless of layer order, + the final layer sequence in the block always ends in activation. + This allows for usage of passthrough layers + or a final output activation function determined separately. + + Residual blocks: + if input channels are greater than output channels, + we use a 1x1 convolution on input to get desired feature channels + if input channels are less than output channels, + we zero-pad input channels to output channel size + + :param torch.tensor x: input tensor + """ + x_0 = x + for i in range(self.num_repeats): + order = list(self.layer_order) + while len(order) > 0: + layer = order.pop(0) + if layer == "c": + x = F.pad(x, self.padding, "constant", 0) + x = self.conv_list[i](x) + if self.dropout: + x = self.drop_list[i](x) + elif layer == "a": + if i < self.num_repeats - 1 or self.activation != "linear": + x = self.act_list[i](x) + elif layer == "n" and self.norm_list[i]: + x = self.norm_list[i](x) + + # residual summation comes after final activation/normalization + if self.residual: + # pad/collapse feature dimension + if self.in_filters > self.out_filters: + x_0 = self.resid_conv(x_0) + elif self.in_filters < self.out_filters: + x_0 = F.pad( + x_0, + (*[0] * 6, self.out_filters - self.in_filters, *[0] * 3), + mode="constant", + value=0, + ) + + # fit xy dimensions + if self.pad_type == "valid_stack": + lost = [dim // 2 * self.num_repeats for dim in self.kernel_size[1:]] + x_0 = x_0[ + ..., + lost[0] : x_0.shape[-2] - lost[0], + lost[1] : x_0.shape[-1] - lost[1], + ] + + x = torch.add(x, x_0) + + return x + + def model(self): + """ + Allows calling of parameters inside ConvBlock object: + 'ConvBlock.model().parameters()'' + + Layer order: convolution -> normalization -> activation + + We can make a list of layer modules and unpack them into nn.Sequential. + Note: this is distinct from the forward call + because we want to use the forward call with addition, + since this is a residual block. + The forward call performs the residual calculation, + and all the parameters can be seen by the optimizer when given this model. + """ + layers = [] + + for i in range(self.num_repeats): + layers.append(self.conv_list[i]) + if self.dropout: + layers.append(self.drop_list[i]) + if self.norm[i]: + layers.append(self.norm_list[i]) + if i < len(self.act_list): + layers.append(self.act_list[i]) + + return nn.Sequential(*layers) + + def register_modules(self, module_list, name): + """ + Helper function that registers modules stored in a list to the model object + so that the can be seen by PyTorch optimizer. + + Used to enable model graph creation + with non-sequential model types and dynamic layer numbers + + :param list(torch.nn.module) module_list: list of modules to register + :param str name: name of module type + """ + for i, module in enumerate(module_list): + self.add_module(f"{name}_{str(i)}", module) diff --git a/viscy/unet/networks/layers/__init__.py b/viscy/unet/networks/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/viscy/unet/utils/logging.py b/viscy/unet/utils/logging.py new file mode 100644 index 00000000..33c66f9d --- /dev/null +++ b/viscy/unet/utils/logging.py @@ -0,0 +1,284 @@ +import datetime +import os +import time + +import torch + +from viscy.utils.cli_utils import save_figure +from viscy.utils.normalize import hist_clipping + + +def log_feature(feature_map, name, log_save_folder, debug_mode): + """ + If self.debug_mode, creates a visual of the given feature map, and saves it at + 'log_save_folder' + If no log_save_folder specified, saves relative to working directory with timestamp. + + Currently only saving in working directory is supported. + This is meant to be an analysis tool, + and results should not be saved permanently. + + :param torch.tensor feature_map: feature map to create visualization log of + :param str name: string + :param str log_save_folder + """ + try: + if debug_mode: + now = datetime.datetime.now() + log_save_folder = ( + f"feature_map_{now.year}_{now.month}_" + f"{now.day}_{now.hour}_{now.minute}/" + ) + logger = FeatureLogger( + save_folder=log_save_folder, + spatial_dims=3, + grid_width=8, + ) + logger.log_feature_map( + feature_map, + name, + dim_names=["batch", "channels"], + ) + except Exception: + print( + "Features of one input logged. Results saved at:" + f"\n\t {log_save_folder}. Will not log to avoid overwrite. \n" + ) + + +class FeatureLogger: + def __init__( + self, + save_folder, + spatial_dims=3, + full_batch=False, + save_as_grid=True, + grid_width=0, + normalize_by_grid=False, + ): + """ + Logger object for handling logging feature maps inside network architectures. + + Saves each 2d slice of a feature map in either a single grid per feature map + stack or a directory tree of labeled slices. + + By default saves images into grid. + + :param str save_folder: output directory + :param bool full_batch: if true, log all sample in batch (warning slow!), + defaults to False + :param bool save_as_grid: if true feature maps are to be saved as a grid + containing all channels, else saved individually, + defaults to True + :param int grid_width: desired width of grid if save_as_grid, by default + 1/4 the number of channels, defaults to 0 + :param bool normalize_by_grid: if true, images saved in grid are normalized + to brightest pixel in entire grid, defaults to False + + """ + self.save_folder = save_folder + self.spatial_dims = spatial_dims + self.full_batch = full_batch + self.save_as_grid = save_as_grid + self.grid_width = grid_width + self.normalize_by_grid = normalize_by_grid + + print("--- Initializing Logger ---") + + def log_feature_map( + self, + feature_map, + feature_name, + dim_names=[], + vmax=0, + ): + """ + Creates a log of figures the given feature map tensor at 'save_folder'. + Log is saved as images of feature maps in nested directory tree. + + By default _assumes that batch dimension is the first dimension_, and + only logs the first sample in the batch, for performance reasons. + + Feature map logs cannot overwrite. + + :param torch.Tensor feature_map: feature map to log (typically 5d tensor) + :parapm str feature_name: name of feature (will be used as dir name) + :param list dim_names: names of each dimension, by default just numbers + :param int spatial_dims: number of spatial dims, defaults to 3 + :param float vmax: maximum intensity to normalize figures by, by default + (if given 0) does relative normalization + """ + # take tensor off of gpu and detach gradient + feature_map = feature_map.detach().cpu() + + # handle dim names + num_dims = len(feature_map.shape) + if len(dim_names) == 0: + dim_names = ["dim_" + str(i) for i in range(len(num_dims))] + else: + assert len(dim_names) + self.spatial_dims == num_dims, ( + "dim_names must be " "same length as nonspatial tensor dim length" + ) + self.dim_names = dim_names + + # handle current feature_name + feature_name = " " + feature_name if len(feature_name) > 0 else "" + print(f"Logging{feature_name} feature map...", end="") + self.feature_save_folder = os.path.join(self.save_folder, feature_name) + + start = time.time() + self.map_feature_dims(feature_map, self.save_as_grid, vmax=vmax) + + print(f"done. Took {time.time() - start:.2f} seconds") + + def map_feature_dims( + self, + feature_map, + save_as_grid, + vmax=0, + depth=0, + ): + """ + Recursive directory creation for organizing feature map logs + + If save_as_grid, will compile 'channels' (assumed to be last + non-spatial dimension) into a single large image grid before saving. + + :param numpy.ndarray feature_map: see name + :param str save_dir: see name + :param bool save_as_grid: if true, saves images as channel grid + :param float vmax: maximum intensity to normalize figures by + :param int depth: recursion counter. depth in dimensions + """ + + for i in range(feature_map.shape[0]): + if len(feature_map.shape) == 3: + # individual saving + z_slice = feature_map[i] + save_figure( + z_slice.unsqueeze(0), + self.feature_save_folder, + f"z_slice_{i}", + vmax=vmax, + ) + + elif len(feature_map.shape) == 4 and save_as_grid: + if feature_map.shape[0] == 1: + # if a single channel, can't save as grid + self.map_feature_dims( + feature_map, + save_as_grid=False, + depth=depth, + ) + else: + # grid saving + for z_depth in range(feature_map.shape[1]): + # set grid_width + if self.grid_width == 0: + if feature_map.shape[0] % 4 != 0: + raise AttributeError( + f"number of channels ({feature_map.shape[0]}) " + "must be divisible by 4 if grid_width unspecified" + ) + self.grid_width = feature_map.shape[0] // 4 + else: + if feature_map.shape[0] % self.grid_width != 0: + raise AttributeError( + f"Grid width {self.grid_width} must be a divisor " + f"of the number of channels {feature_map.shape[0]}" + ) + # build grid by rows + # interleaving bars for ease of visualization + feature_map_grid = [] + current_grid_row = [] + + for channel_num in range(feature_map.shape[0]): + # build rows by item in col + col_num = channel_num % self.grid_width + if col_num == 0 and channel_num != 0: + feature_map_grid.append( + torch.cat( + self.interleave_bars(current_grid_row, axis=1), + dim=1, + ) + ) + current_grid_row = [] + + # get 2d slice + map_slice = feature_map[channel_num, z_depth] + + # norm slice to (0,1) unless normalize_by_grid + # which is done later + if not self.normalize_by_grid: + map_slice = torch.tensor( + hist_clipping( + map_slice.numpy(), + min_percentile=0, + max_percentile=100, + ) + ) + map_slice = ( + map_slice - torch.min(map_slice) + ) / torch.max(map_slice) + + current_grid_row.append(map_slice) + feature_map_grid.append( + torch.cat( + self.interleave_bars(current_grid_row, axis=1), + dim=1, + ) + ) + feature_map_grid = torch.cat( + self.interleave_bars(feature_map_grid, axis=0), dim=0 + ) + save_figure( + torch.unsqueeze(feature_map_grid, 0), + self.feature_save_folder, + f"z_slice_{z_depth}_channels_0-{feature_map.shape[0]}", + vmax=vmax, + ) + break + else: + # tree recursion + try: + name = os.path.join( + self.feature_save_folder, self.dim_names[depth] + f"_{i}" + ) + except Exception: + raise AttributeError("error in recursion") + os.makedirs(name, exist_ok=False) + self.map_feature_dims( + feature_map[i], + name, + save_as_grid, + depth=depth + 1, + ) + + if depth == 0 and not self.full_batch: + break + return + + def interleave_bars(self, arrays, axis, pixel_width=3, value=0): + """ + Takes list of 2d torch tensors and interleaves bars to improve + grid visualization quality. + Assumes arrays are all of the same shape. + + :param list grid_arrays: list of tensors to place bars between + :param int axis: axis on which to interleave bars (0 or 1) + :param int pixel_width: width of bar, defaults to 3 + :param int value: value of bar pixels, defaults to 0 + """ + shape_match_axis = abs(axis - 1) + length = arrays[0].shape[shape_match_axis] + + if axis == 0: + bar = torch.ones((pixel_width, length)) * value + elif axis == 1: + bar = torch.ones((length, pixel_width)) * value + else: + raise AttributeError("axis must be 0 or 1") + + for i in range(1, len(arrays) * 2 - 1, 2): + arrays.insert(i, bar) + return arrays diff --git a/viscy/unet/utils/model.py b/viscy/unet/utils/model.py new file mode 100644 index 00000000..cf888ea8 --- /dev/null +++ b/viscy/unet/utils/model.py @@ -0,0 +1,123 @@ +import torch + +import viscy.unet.networks.Unet2D as Unet2D +import viscy.unet.networks.Unet25D as Unet25D + + +def model_init(network_config, device=torch.device("cuda"), debug_mode=False): + """ + Initializes network model from a configuration dictionary. + + :param dict network_config: dict containing the configuration parameters for + the model + :param torch.device device: device to store model parameters on (must be same + as data) + """ + if device == "gpu": + device = "cuda" + + assert ( + "architecture" in network_config + ), "Must specify network architecture: 2D, 2.5D" + + if network_config["architecture"] == "2.5D": + default_model = ModelDefaults25D() + model_class = Unet25D.Unet25d + model = define_model( + model_class, + default_model, + network_config, + ) + elif network_config["architecture"] == "2D": + default_model = ModelDefaults2D() + model_class = Unet2D.Unet2d + model = define_model( + model_class, + default_model, + network_config, + ) + else: + raise NotImplementedError("Only 2.5D and 2D architectures available.") + + model.debug_mode = debug_mode + + model.to(device) + + return model + + +def define_model(model_class, model_defaults, config): + """ + Returns an instance of the model given the parameter config and specified + defaults. The model weights are not on cpu at this point. + + :param nn.Module model_class: actual model class to pass defaults into + :param ModelDefaults model_defaults: default model corresponding to config + :param dict config: _description_ + """ + kwargs = {} + for param_name in vars(model_defaults): + if param_name in config: + kwargs[param_name] = config[param_name] + else: + kwargs[param_name] = model_defaults.get(param_name) + + return model_class(**kwargs) + + +class ModelDefaults: + def __init__(self): + """ + Parent class of the model defaults objects. + """ + + def get(self, varname): + """ + Logic for getting an attribute of the default parameters class + + :param str varname: name of attribute + """ + return getattr(self, varname) + + +class ModelDefaults2D(ModelDefaults): + def __init__(self): + """ + Instance of model defaults class, containing all of the default + hyper-parameters for the 2D unet + + All parameters in this default model CAN be accessed by name through + the model config + """ + super(ModelDefaults, self).__init__() + self.in_channels = 1 + self.out_channels = 1 + self.kernel_size = (3, 3) + self.residual = False + self.dropout = 0.2 + self.num_blocks = 4 + self.num_block_layers = 2 + self.num_filters = [] + self.task = "reg" + + +class ModelDefaults25D(ModelDefaults): + def __init__(self): + """ + Instance of default model class, containing all of the default + hyper-parameters for the 2D unet. + + All parameters in this default model CAN be accessed by name through + the model config + """ + self.in_channels = 1 + self.out_channels = 1 + self.in_stack_depth = 5 + self.out_stack_depth = 1 + self.xy_kernel_size = (3, 3) + self.residual = False + self.dropout = 0.2 + self.num_blocks = 4 + self.num_block_layers = 2 + self.num_filters = [] + self.task = "reg" diff --git a/viscy/utils/__init__.py b/viscy/utils/__init__.py new file mode 100644 index 00000000..5e2c7e1e --- /dev/null +++ b/viscy/utils/__init__.py @@ -0,0 +1 @@ +"""Module for utility functions""" diff --git a/viscy/utils/aux_utils.py b/viscy/utils/aux_utils.py new file mode 100644 index 00000000..f49137be --- /dev/null +++ b/viscy/utils/aux_utils.py @@ -0,0 +1,99 @@ +"""Auxiliary utility functions""" + +import iohub.ngff as ngff +import yaml + + +def _assert_unique_subset(subset, superset, name): + """ + Helper function to allow for clean code: + Throws error if unique elements of subset are not a subset of + unique elements of superset. + + Returns unique elements of subset if given a list. If subset is -1, + returns all unique elements of superset + """ + if subset == -1: + subset = superset + if not (isinstance(subset, list) or isinstance(subset, tuple)): + subset = list(subset) + unique_subset = set(subset) + unique_superset = set(superset) + assert unique_subset.issubset(unique_superset), ( + f"{name} in requested {name}: {unique_subset}" + f" not in available {name}: {unique_superset}" + ) + return unique_subset + + +def validate_metadata_indices( + zarr_dir, + time_ids=[], + channel_ids=[], + slice_ids=[], + pos_ids=[], +): + """ + Check the availability of indices provided timepoints, channels, positions + and slices for all data, and returns only the available of the specified + indices. + + If input ids are None, the indices for that parameter will not be + evaluated. If input ids are -1, all indices for that parameter will + be returned. + + Assumes uniform structure, as such structure is required for HCS compatibility + + :param str zarr_dir: HCS-compatible zarr directory to validate indices against + :param list time_ids: check availability of these timepoints in image + metadata + :param list channel_ids: check availability of these channels in image + metadata + :param list pos_ids: Check availability of positions in zarr_dir + :param list slice_ids: Check availability of z slices in image metadata + + :return dict indices_metadata: All indices found given input + :raise AssertionError: If not all channels, timepoints, positions + or slices are present + """ + plate = ngff.open_ome_zarr(zarr_dir, layout="hcs", mode="r") + position_path, position = next(plate.positions()) + + # read available channel indices from zarr store + available_time_ids = range(position.data.shape[0]) + if isinstance(channel_ids, int): + available_channel_ids = range(len(plate.channel_names)) + elif isinstance(channel_ids[0], int): + available_channel_ids = range(len(plate.channel_names)) + else: + available_channel_ids = len(plate.channel_names) + available_slice_ids = range(position.data.shape[-3]) + available_pos_ids = [x[0] for x in list(plate.positions())] + + # enforce that requested indices are subsets of available indices + time_ids = _assert_unique_subset(time_ids, available_time_ids, "slices") + channel_ids = _assert_unique_subset(channel_ids, available_channel_ids, "channels") + slice_ids = _assert_unique_subset(slice_ids, available_slice_ids, "slices") + pos_ids = _assert_unique_subset(pos_ids, available_pos_ids, "positions") + + indices_metadata = { + "time_ids": list(time_ids), + "channel_ids": list(channel_ids), + "slice_ids": list(slice_ids), + "pos_ids": list(pos_ids), + } + plate.close() + return indices_metadata + + +def read_config(config_fname): + """Read the config file in yml format + + :param str config_fname: fname of config yaml with its full path + :return: dict config: Configuration parameters + """ + + with open(config_fname, "r") as f: + config = yaml.safe_load(f) + + return config diff --git a/viscy/utils/cli_utils.py b/viscy/utils/cli_utils.py new file mode 100644 index 00000000..c3a003e2 --- /dev/null +++ b/viscy/utils/cli_utils.py @@ -0,0 +1,119 @@ +import collections +import os +import re + +import numpy as np +import torch +from PIL import Image + + +def unique_tags(directory): + """ + Returns list of unique nume tags from data directory + + :param str directory: directory containing '.tif' files + TODO: Remove, unused and poorly written + """ + files = [ + f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) + ] + + tags = collections.defaultdict(lambda: 0) + for f in files: + f_name, f_type = f.split(".")[0], f.split(".")[1] + if f_type == "tif": + suffixes = re.split("_", f_name) + + unique_tag = suffixes[2] + "_" + suffixes[3] + "_" + suffixes[4] + tags[unique_tag + "." + f_type] += 1 + return tags + + +class MultiProcessProgressBar(object): + """ + Provides the ability to create & update a single progress bar for multi-depth + multi-processed tasks by calling updates on a single object + """ + + def __init__(self, total_updates): + self.dataloader = list(range(total_updates)) + self.current = 0 + + def tick(self, process): + self.current += 1 + show_progress_bar(self.dataloader, self.current, process) + + +def show_progress_bar(dataloader, current, process="training", interval=1): + """ + Utility function to print tensorflow-like progress bar. + + Written instead of using tqdm to allow for custom progress bar readouts. + + :param iterable dataloader: dataloader currently being processed + :param int current: current index in dataloader + :param str proces: current process being performed + :param int interval: interval at which to update progress bar + """ + current += 1 + bar_length = 50 + fraction_computed = current / dataloader.__len__() + + if current % interval != 0 and fraction_computed < 1: + return + + # pointer = ">" if fraction_computed < 1 else "=" + loading_string = ( + "=" * int(bar_length * fraction_computed) + + ">" + + "_" * int(bar_length * (1 - fraction_computed)) + ) + output_string = ( + f"\t {process} {current}/{dataloader.__len__()} " + f"[{loading_string}] ({int(fraction_computed * 100)}%)" + ) + + if fraction_computed <= (dataloader.__len__() - interval) / dataloader.__len__(): + print(" " * (bar_length + len(process) + 5), end="\r") + print(output_string, end="\r") + else: + print(output_string) + + +def save_figure(data, save_folder, name, title=None, vmax=0, ext=".png"): + """ + Saves .png or .jpeg figure of data to folder save_folder under 'name'. + 'data' must be a 3d tensor or numpy array, in channels_first format + + :param numpy.ndarray/torch.tensor data: input image/stack data to save + :param str save_folder: global path to folder where data is saved. + :param str name: name of data, no extension specified + :param str/None title: image title, if none specified, defaults used + :param float vmax: value to normalize figure to, by default uses data max + :param str ext: image save file extension + """ + assert len(data.shape) == 3, f"'{len(data.shape)}d' data must be 3-dimensional" + + if isinstance(data, torch.Tensor): + data = data.detach().cpu().numpy() + elif not isinstance(data, np.ndarray): + raise AttributeError( + f"'data' of type {type(data)} must be torch tensor" " or numpy array." + ) + if vmax == 0: + vmax = np.max(data) + + # normalize and convert to uint8 + data = np.array(((data - np.min(data)) / float(vmax)) * 255, dtype=np.uint8) + + # save + if data.shape[-3] > 1: + data = np.mean(data, 0) + im = Image.fromarray(data).convert("L") + im.info["size"] = data.shape + im.save(os.path.join(save_folder, name + ext)) + else: + data = data[0] + im = Image.fromarray(data).convert("L") + im.info["size"] = data.shape + im.save(os.path.join(save_folder, name + ext)) diff --git a/viscy/utils/image_utils.py b/viscy/utils/image_utils.py new file mode 100644 index 00000000..f9020dc9 --- /dev/null +++ b/viscy/utils/image_utils.py @@ -0,0 +1,106 @@ +"""Utility functions for processing images""" + +import itertools +import sys + +import numpy as np + +import viscy.utils.normalize as normalize + + +def im_bit_convert(im, bit=16, norm=False, limit=[]): + im = im.astype( + np.float32, copy=False + ) # convert to float32 without making a copy to save memory + if norm: + if not limit: + # scale each image individually based on its min and max + limit = [np.nanmin(im[:]), np.nanmax(im[:])] + im = ( + (im - limit[0]) + / (limit[1] - limit[0] + sys.float_info.epsilon) + * (2**bit - 1) + ) + im = np.clip( + im, 0, 2**bit - 1 + ) # clip the values to avoid wrap-around by np.astype + if bit == 8: + im = im.astype(np.uint8, copy=False) # convert to 8 bit + else: + im = im.astype(np.uint16, copy=False) # convert to 16 bit + return im + + +def im_adjust(img, tol=1, bit=8): + """ + Stretches contrast of the image and converts to 'bit'-bit. + Useful for weight-maps in masking + """ + limit = np.percentile(img, [tol, 100 - tol]) + im_adjusted = im_bit_convert(img, bit=bit, norm=True, limit=limit.tolist()) + return im_adjusted + + +def grid_sample_pixel_values(im, grid_spacing): + """Sample pixel values in the input image at the grid. Any incomplete + grids (remainders of modulus operation) will be ignored. + + :param np.array im: 2D image + :param int grid_spacing: spacing of the grid + :return int row_ids: row indices of the grids + :return int col_ids: column indices of the grids + :return np.array sample_values: sampled pixel values + """ + + im_shape = im.shape + assert grid_spacing < im_shape[0], "grid spacing larger than image height" + assert grid_spacing < im_shape[1], "grid spacing larger than image width" + # leave out the grid points on the edges + sample_coords = np.array( + list( + itertools.product( + np.arange(grid_spacing, im_shape[0], grid_spacing), + np.arange(grid_spacing, im_shape[1], grid_spacing), + ) + ) + ) + row_ids = sample_coords[:, 0] + col_ids = sample_coords[:, 1] + sample_values = im[row_ids, col_ids] + return row_ids, col_ids, sample_values + + +def preprocess_image( + im, + hist_clip_limits=None, + is_mask=False, + normalize_im=None, + zscore_mean=None, + zscore_std=None, +): + """ + Do histogram clipping, z score normalization, and potentially binarization. + + :param np.array im: Image (stack) + :param tuple hist_clip_limits: Percentile histogram clipping limits + :param bool is_mask: True if mask + :param str/None normalize_im: Normalization, if any + :param float/None zscore_mean: Data mean + :param float/None zscore_std: Data std + """ + # remove singular dimension for 3D images + if len(im.shape) > 3: + im = np.squeeze(im) + if not is_mask: + if hist_clip_limits is not None: + im = normalize.hist_clipping(im, hist_clip_limits[0], hist_clip_limits[1]) + if normalize_im is not None: + im = normalize.zscore( + im, + im_mean=zscore_mean, + im_std=zscore_std, + ) + else: + if im.dtype != bool: + im = im > 0 + return im diff --git a/viscy/utils/masks.py b/viscy/utils/masks.py new file mode 100644 index 00000000..a0881fa0 --- /dev/null +++ b/viscy/utils/masks.py @@ -0,0 +1,212 @@ +import numpy as np +import scipy.ndimage as ndimage +from scipy.ndimage import binary_fill_holes +from skimage.filters import gaussian, laplace, threshold_otsu +from skimage.morphology import ( + ball, + binary_dilation, + binary_opening, + disk, + remove_small_objects, +) + + +def create_otsu_mask(input_image, sigma=0.6): + """Create a binary mask using morphological operations + :param np.array input_image: generate masks from this 3D image + :param float sigma: Gaussian blur standard deviation, + increase in value increases blur + :return: volume mask of input_image, 3D np.array + """ + + input_sz = input_image.shape + mid_slice_id = input_sz[0] // 2 + + thresh = threshold_otsu(input_image[mid_slice_id, :, :]) + mask = input_image >= thresh + + return mask + + +def create_membrane_mask(input_image, str_elem_size=23, sigma=0.4, k_size=3, msize=120): + """Create a binary mask using Laplacian of Gaussian (LOG) feature detection + + :param np.array input_image: generate masks from this image + :param int str_elem_size: size of the laplacian filter + used for contarst enhancement, odd number. + Increase in value increases sensitivity of contrast enhancement + :param float sigma: Gaussian blur standard deviation + :param int k_size: disk/ball size for mask dilation, + ball for 3D and disk for 2D data + :param int msize: size of small objects removed to clean segmentation + :return: mask of input_image, np.array + """ + + input_image_blur = gaussian(input_image, sigma=sigma) + + input_Lapl = laplace(input_image_blur, ksize=str_elem_size) + + thresh = threshold_otsu(input_Lapl) + mask_bin = input_Lapl >= thresh + + if len(input_image.shape) == 2: + str_elem = disk(k_size) + else: + str_elem = ball(k_size) + + mask_dilated = binary_dilation(mask_bin, str_elem) + + mask = remove_small_objects(mask_dilated, min_size=msize) + + return mask + + +def get_unimodal_threshold(input_image): + """Determines optimal unimodal threshold + + https://users.cs.cf.ac.uk/Paul.Rosin/resources/papers/unimodal2.pdf + https://www.mathworks.com/matlabcentral/fileexchange/45443-rosin-thresholding + + :param np.array input_image: generate mask for this image + :return float best_threshold: optimal lower threshold for the foreground + hist + """ + + hist_counts, bin_edges = np.histogram( + input_image, + bins=256, + range=(input_image.min(), np.percentile(input_image, 99.5)), + ) + bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 + + # assuming that background has the max count + max_idx = np.argmax(hist_counts) + int_with_max_count = bin_centers[max_idx] + p1 = [int_with_max_count, hist_counts[max_idx]] + + # find last non-empty bin + pos_counts_idx = np.where(hist_counts > 0)[0] + last_binedge = pos_counts_idx[-1] + p2 = [bin_centers[last_binedge], hist_counts[last_binedge]] + + best_threshold = -np.inf + max_dist = -np.inf + for idx in range(max_idx, last_binedge, 1): + x0 = bin_centers[idx] + y0 = hist_counts[idx] + a = [p1[0] - p2[0], p1[1] - p2[1]] + b = [x0 - p2[0], y0 - p2[1]] + cross_ab = a[0] * b[1] - b[0] * a[1] + per_dist = np.linalg.norm(cross_ab) / np.linalg.norm(a) + if per_dist > max_dist: + best_threshold = x0 + max_dist = per_dist + assert best_threshold > -np.inf, "Error in unimodal thresholding" + return best_threshold + + +def create_unimodal_mask(input_image, str_elem_size=3, sigma=0.6): + """ + Create a mask with unimodal thresholding and morphological operations. + Unimodal thresholding seems to oversegment, erode it by a fraction + + :param np.array input_image: generate masks from this image + :param int str_elem_size: size of the structuring element. typically 3, 5 + :param float sigma: gaussian blur standard deviation + :return mask of input_image, np.array + """ + + input_image = gaussian(input_image, sigma=sigma) + + if np.min(input_image) == np.max(input_image): + thr = np.unique(input_image) + else: + thr = get_unimodal_threshold(input_image) + if len(input_image.shape) == 2: + str_elem = disk(str_elem_size) + else: + str_elem = ball(str_elem_size) + # remove small objects in mask + mask = input_image >= thr + mask = binary_opening(mask, str_elem) + mask = binary_fill_holes(mask) + return mask + + +def get_unet_border_weight_map(annotation, w0=10, sigma=5): + """ + Return weight map for borders as specified in UNet paper + :param annotation A 2D array of shape (image_height, image_width) + contains annotation with each class labeled as an integer. + :param w0 multiplier to the exponential distance loss + default 10 as mentioned in UNet paper + :param sigma standard deviation in the exponential distance term + e^(-d1 + d2) ** 2 / 2 (sigma ^ 2) + default 5 as mentioned in UNet paper + :return weight mapt for borders as specified in UNet + + TODO: Calculate boundaries directly and calculate distance + from boundary of cells to another + Note: The below method only works for UNet Segmentation only + """ + # if there is only one label, zero return the array as is + if np.sum(annotation) == 0: + return annotation + + # Masks could be saved as .npy bools, if so convert to uint8 and generate + # labels from binary + if annotation.dtype == bool: + annotation = annotation.astype(np.uint8) + assert annotation.dtype in [ + np.uint8, + np.uint16, + ], "Expected data type uint, it is {}".format(annotation.dtype) + + # cells instances for distance computation + # 4 connected i.e default (cross-shaped) + # structuring element to measure connectivy + # If cells are 8 connected/touching they are labeled as one single object + # Loss metric on such borders is not useful + labeled_array, _ = ndimage.measurements.label(annotation) + # class balance weights w_c(x) + unique_values = np.unique(labeled_array).tolist() + weight_map = [0] * len(unique_values) + for index, unique_value in enumerate(unique_values): + mask = np.zeros((annotation.shape[0], annotation.shape[1]), dtype=np.float64) + mask[annotation == unique_value] = 1 + weight_map[index] = 1 / mask.sum() + + # this normalization is important - foreground pixels must have weight 1 + weight_map = [i / max(weight_map) for i in weight_map] + + wc = np.zeros((annotation.shape[0], annotation.shape[1]), dtype=np.float64) + for index, unique_value in enumerate(unique_values): + wc[annotation == unique_value] = weight_map[index] + + # cells distance map + border_loss_map = np.zeros( + (annotation.shape[0], annotation.shape[1]), dtype=np.float64 + ) + distance_maps = np.zeros( + (annotation.shape[0], annotation.shape[1], np.max(labeled_array)), + dtype=np.float64, + ) + + if np.max(labeled_array) >= 2: + for index in range(np.max(labeled_array)): + mask = np.ones_like(labeled_array) + mask[labeled_array == index + 1] = 0 + distance_maps[:, :, index] = ndimage.distance_transform_edt(mask) + distance_maps = np.sort(distance_maps, 2) + d1 = distance_maps[:, :, 0] + d2 = distance_maps[:, :, 1] + border_loss_map = w0 * np.exp((-1 * (d1 + d2) ** 2) / (2 * (sigma**2))) + + zero_label = np.zeros((annotation.shape[0], annotation.shape[1]), dtype=np.float64) + zero_label[labeled_array == 0] = 1 + border_loss_map = np.multiply(border_loss_map, zero_label) + + # unet weight map mask + weight_map_mask = wc + border_loss_map + + return weight_map_mask diff --git a/viscy/utils/meta_utils.py b/viscy/utils/meta_utils.py new file mode 100644 index 00000000..d7b8738b --- /dev/null +++ b/viscy/utils/meta_utils.py @@ -0,0 +1,223 @@ +import os +import sys + +import iohub.ngff as ngff +import numpy as np +import pandas as pd + +import viscy.utils.mp_utils as mp_utils +from viscy.utils.cli_utils import show_progress_bar + + +def write_meta_field(position: ngff.Position, metadata, field_name, subfield_name): + """ + Writes 'metadata' to position's plate-level or FOV level .zattrs metadata by either + creating a new field (field_name) according to 'metadata', or updating the metadata + to an existing field if found, + or concatenating the metadata from different channels. + + Assumes that the zarr store group given follows the OMG-NGFF HCS + format as specified here: + https://ngff.openmicroscopy.org/latest/#hcs-layout + + Warning: Dangerous. Writing metadata fields above the image-level of + an HCS hierarchy can break HCS compatibility + + :param Position zarr_dir: NGFF position node object + :param dict metadata: metadata dictionary to write to JSON .zattrs + :param str subfield_name: name of subfield inside the the main field + (values for different channels) + """ + if field_name in position.zattrs: + if subfield_name in position.zattrs[field_name]: + position.zattrs[field_name][subfield_name].update(metadata) + else: + D1 = position.zattrs[field_name] + field_metadata = { + subfield_name: metadata, + } + # position.zattrs[field_name][subfield_name] = metadata + position.zattrs[field_name] = {**D1, **field_metadata} + else: + field_metadata = { + subfield_name: metadata, + } + position.zattrs[field_name] = field_metadata + + +def generate_normalization_metadata( + zarr_dir, + num_workers=4, + channel_ids=-1, + grid_spacing=32, +): + """ + Generate pixel intensity metadata to be later used in on-the-fly normalization + during training and inference. Sampling is used for efficient estimation of median + and interquartile range for intensity values on both a dataset and field-of-view + level. + + Normalization values are recorded in the image-level metadata in the corresponding + position of each zarr_dir store. Format of metadata is as follows: + { + channel_idx : { + dataset_statistics: dataset level normalization values (positive float), + fov_statistics: field-of-view level normalization values (positive float) + }, + . + . + . + } + + :param str zarr_dir: path to zarr store directory containing dataset. + :param int num_workers: number of cpu workers for multiprocessing, defaults to 4 + :param list/int channel_ids: indices of channels to process in dataset arrays, + by default calculates all + :param int grid_spacing: distance between points in sampling grid + """ + plate = ngff.open_ome_zarr(zarr_dir, mode="r+") + position_map = list(plate.positions()) + + if channel_ids == -1: + channel_ids = range(len(plate.channel_names)) + elif isinstance(channel_ids, int): + channel_ids = [channel_ids] + + # get arguments for multiprocessed grid sampling + mp_grid_sampler_args = [] + for _, position in position_map: + mp_grid_sampler_args.append([position, grid_spacing]) + + # sample values and use them to get normalization statistics + for i, channel in enumerate(channel_ids): + show_progress_bar( + dataloader=channel_ids, + current=i, + process="sampling channel values", + ) + + channel_name = plate.channel_names[channel] + this_channels_args = tuple([args + [channel] for args in mp_grid_sampler_args]) + + # NOTE: Doing sequential mp with pool execution creates synchronization + # points between each step. This could be detrimental to performance + positions, fov_sample_values = mp_utils.mp_sample_im_pixels( + this_channels_args, num_workers + ) + dataset_sample_values = np.stack(fov_sample_values, 0) + + fov_level_statistics = mp_utils.mp_get_val_stats(fov_sample_values, num_workers) + dataset_level_statistics = mp_utils.get_val_stats(dataset_sample_values) + + dataset_statistics = { + "dataset_statistics": dataset_level_statistics, + } + + write_meta_field( + position=plate, + metadata=dataset_statistics, + field_name="normalization", + subfield_name=channel_name, + ) + + for j, pos in enumerate(positions): + show_progress_bar( + dataloader=position_map, + current=j, + process=f"calculating channel statistics {channel}/{list(channel_ids)}", + ) + position_statistics = { + "fov_statistics": fov_level_statistics[j], + } + + write_meta_field( + position=pos, + metadata=position_statistics, + field_name="normalization", + subfield_name=channel_name, + ) + plate.close() + + +def compute_zscore_params( + frames_meta, ints_meta, input_dir, normalize_im, min_fraction=0.99 +): + """ + Get zscore median and interquartile range + + :param pd.DataFrame frames_meta: Dataframe containing all metadata + :param pd.DataFrame ints_meta: Metadata containing intensity statistics + each z-slice and foreground fraction for masks + :param str input_dir: Directory containing images + :param None or str normalize_im: normalization scheme for input images + :param float min_fraction: Minimum foreground fraction (in case of masks) + for computing intensity statistics. + + :return pd.DataFrame frames_meta: Dataframe containing all metadata + :return pd.DataFrame ints_meta: Metadata containing intensity statistics + each z-slice + """ + + assert normalize_im in [ + None, + "slice", + "volume", + "dataset", + ], 'normalize_im must be None or "slice" or "volume" or "dataset"' + + if normalize_im is None: + # No normalization + frames_meta["zscore_median"] = 0 + frames_meta["zscore_iqr"] = 1 + return frames_meta + elif normalize_im == "dataset": + agg_cols = ["time_idx", "channel_idx", "dir_name"] + elif normalize_im == "volume": + agg_cols = ["time_idx", "channel_idx", "dir_name", "pos_idx"] + else: + agg_cols = ["time_idx", "channel_idx", "dir_name", "pos_idx", "slice_idx"] + # median and inter-quartile range are more robust than mean and std + ints_meta_sub = ints_meta[ints_meta["fg_frac"] >= min_fraction] + ints_agg_median = ints_meta_sub[agg_cols + ["intensity"]].groupby(agg_cols).median() + ints_agg_hq = ( + ints_meta_sub[agg_cols + ["intensity"]].groupby(agg_cols).quantile(0.75) + ) + ints_agg_lq = ( + ints_meta_sub[agg_cols + ["intensity"]].groupby(agg_cols).quantile(0.25) + ) + ints_agg = ints_agg_median + ints_agg.columns = ["zscore_median"] + ints_agg["zscore_iqr"] = ints_agg_hq["intensity"] - ints_agg_lq["intensity"] + ints_agg.reset_index(inplace=True) + + cols_to_merge = frames_meta.columns[ + [col not in ["zscore_median", "zscore_iqr"] for col in frames_meta.columns] + ] + frames_meta = pd.merge( + frames_meta[cols_to_merge], + ints_agg, + how="left", + on=agg_cols, + ) + if frames_meta["zscore_median"].isnull().values.any(): + raise ValueError( + "Found NaN in normalization parameters. \ + min_fraction might be too low or images might be corrupted." + ) + frames_meta_filename = os.path.join(input_dir, "frames_meta.csv") + frames_meta.to_csv(frames_meta_filename, sep=",") + + cols_to_merge = ints_meta.columns[ + [col not in ["zscore_median", "zscore_iqr"] for col in ints_meta.columns] + ] + ints_meta = pd.merge( + ints_meta[cols_to_merge], + ints_agg, + how="left", + on=agg_cols, + ) + ints_meta["intensity_norm"] = ( + ints_meta["intensity"] - ints_meta["zscore_median"] + ) / (ints_meta["zscore_iqr"] + sys.float_info.epsilon) + + return frames_meta, ints_meta diff --git a/viscy/utils/mp_utils.py b/viscy/utils/mp_utils.py new file mode 100644 index 00000000..27a0d8ef --- /dev/null +++ b/viscy/utils/mp_utils.py @@ -0,0 +1,322 @@ +from concurrent.futures import ProcessPoolExecutor + +import iohub.ngff as ngff +import numpy as np +import scipy.stats + +import viscy.utils.image_utils as image_utils +import viscy.utils.masks as mask_utils + + +def mp_wrapper(fn, fn_args, workers): + """Create and save masks with multiprocessing + + :param list of tuple fn_args: list with tuples of function arguments + :param int workers: max number of workers + :return: list of returned dicts from create_save_mask + """ + with ProcessPoolExecutor(workers) as ex: + # can't use map directly as it works only with single arg functions + res = ex.map(fn, *zip(*fn_args)) + return list(res) + + +def mp_create_and_write_mask(fn_args, workers): + """Create and save masks with multiprocessing. For argument parameters + see mp_utils.create_and_write_mask. + + :param list of tuple fn_args: list with tuples of function arguments + :param int workers: max number of workers + :return: list of returned dicts from create_save_mask + """ + with ProcessPoolExecutor(workers) as ex: + # can't use map directly as it works only with single arg functions + res = ex.map(create_and_write_mask, *zip(*fn_args)) + return list(res) + + +def add_channel( + position: ngff.Position, + new_channel_array, + new_channel_name, + overwrite_ok=False, +): + """ + Adds a channels to the data array at position "position". Note that there is + only one 'tracked' data array in current HCS spec at each position. Also + updates the 'omero' channel-tracking metadata to track the new channel. + + The 'new_channel_array' must match the dimensions of the current array in + all positions but the channel position (1) and have the same datatype + + Note: to maintain HCS compatibility of the zarr store, all positions (wells) + must maintain arrays with congruent channels. That is, if you add a channel + to one position of an HCS compatible zarr store, an additional channel must + be added to every position in that store to maintain HCS compatibility. + + :param Position zarr_dir: NGFF position node object + :param np.ndarray new_channel_array: array to add as new channel with matching + dimensions (except channel dim) and dtype + :param str new_channel_name: name of new channel + :param bool overwrite_ok: if true, if a channel with the same name as + 'new_channel_name' is found, will overwrite + """ + assert len(new_channel_array.shape) == len(position.data.shape) - 1, ( + "New channel array must match all dimensions of the position array, " + "except in the inferred channel dimension: " + f"array shape: {position.data.shape}, " + "expected channel shape: " + f"{(position.data.shape[0], ) + position.data.shape[2:]}, " + f"received channel shape: {new_channel_array.shape}" + ) + # determine whether to overwrite or append + if new_channel_name in position.channel_names and overwrite_ok: + new_channel_index = list(position.channel_names).index(new_channel_name) + else: + new_channel_index = len(position.channel_names) + position.append_channel(new_channel_name, resize_arrays=True) + + # replace or append channel + position["0"][:, new_channel_index] = new_channel_array + + +def create_and_write_mask( + position: ngff.Position, + time_indices, + channel_indices, + structure_elem_radius, + mask_type, + mask_name, + verbose=False, +): + # TODO: rewrite docstring + """ + Create mask *for all depth slices* at each time and channel index specified + in this position, and save them both as an additional channel in the data array + of the given zarr store and a separate 'untracked' array with specified name. + If output_channel_index is specified as an existing channel index, will overwrite + this channel instead. + + Saves custom metadata related to the mask creation in the well-level + .zattrs in the 'mask' field. + + When >1 channel are used to generate the mask, mask of each channel is + generated then added together. Foreground fraction is calculated on + a timepoint-position basis. That is, it will be recorded as an average + foreground fraction over all slices in any given timepoint. + + + :param str zarr_dir: directory to HCS compatible zarr store for usage + :param str position_path: path within store to position to generate masks for + :param list time_indices: list of time indices for mask generation, + if an index is skipped over, will populate with + zeros + :param list channel_indices: list of channel indices for mask generation, + if more than 1 channel specified, masks from all + channels are aggregated + :param int structure_elem_radius: size of structuring element used for binary + opening. str_elem: disk or ball + :param str mask_type: thresholding type used for masking or str to map to + masking function + :param str mask_name: name under which to save untracked copy of mask in + position + :param bool verbose: whether this process should send updates to stdout + """ + + shape = position.data.shape + position_masks_shape = tuple([shape[0], len(channel_indices), *shape[2:]]) + + # calculate masks over every time index and channel slice + position_masks = np.zeros(position_masks_shape) + position_foreground_fractions = {} + + for time_index in range(shape[0]): + timepoint_foreground_fraction = {} + + for channel_index in channel_indices: + channel_name = position.channel_names[channel_index] + mask_array_chan_idx = channel_indices.index(channel_index) + + if "mask" in channel_name: + print("\n") + if mask_type in channel_name: + print(f"Found existing channel: '{channel_name}'.") + print("You are likely creating duplicates, which is bad practice.") + print(f"Skipping mask channel '{channel_name}' for thresholding") + else: + # print progress update + if verbose: + time_progress = f"time {time_index+1}/{shape[0]}" + channel_progress = f"chan {channel_index}/{channel_indices}" + position_progress = f"pos {position.zgroup.name}" + p = ( + f"Computing masks slice [{position_progress}, {time_progress}," + f" {channel_progress}]\n" + ) + print(p) + + # get mask for image slice or populate with zeros + if time_index in time_indices: + mask = get_mask_slice( + position_zarr=position.data, + time_index=time_index, + channel_index=channel_index, + mask_type=mask_type, + structure_elem_radius=structure_elem_radius, + ) + else: + mask = np.zeros(shape[-2:]) + + position_masks[time_index, mask_array_chan_idx] = mask + + # compute & record channel-wise foreground fractions + frame_foreground_fraction = float( + np.mean(position_masks[time_index, mask_array_chan_idx]).item() + ) + timepoint_foreground_fraction[channel_name] = frame_foreground_fraction + position_foreground_fractions[time_index] = timepoint_foreground_fraction + + # combine masks along channels and compute & record combined foreground fraction + position_masks = np.sum(position_masks, axis=1) + position_masks = np.where(position_masks > 0.5, 1, 0) + for time_index in time_indices: + frame_foreground_fraction = float(np.mean(position_masks[time_index]).item()) + timepoint_foreground_fraction["combined_fraction"] = frame_foreground_fraction + + # save masks as additional channel + position_masks = position_masks.astype(position.data.dtype) + new_channel_name = channel_name + "_mask" + add_channel( + position=position, + new_channel_array=position_masks, + new_channel_name=new_channel_name, + overwrite_ok=True, + ) + + +def get_mask_slice( + position_zarr, + time_index, + channel_index, + mask_type, + structure_elem_radius, +): + """ + Given a set of indices, mask type, and structuring element, + pulls an image slice from the given zarr array, computes the + requested mask and returns. + + :param zarr.Array position_zarr: zarr array of the desired position + :param time_index: see name + :param channel_index: see name + :param mask_type: see name, + options are {otsu, unimodal, mem_detection, borders_weight_loss_map} + :param int structure_elem_radius: creation radius for the structuring + element + :return np.ndarray mask: 2d mask for this slice + """ + # read and correct/preprocess slice + im = position_zarr[time_index, channel_index] + im = image_utils.preprocess_image(im, hist_clip_limits=(1, 99)) + # generate mask for slice + if mask_type == "otsu": + mask = mask_utils.create_otsu_mask(im.astype("float32")) + elif mask_type == "unimodal": + mask = mask_utils.create_unimodal_mask( + im.astype("float32"), structure_elem_radius + ) + elif mask_type == "mem_detection": + mask = mask_utils.create_membrane_mask( + im.astype("float32"), + structure_elem_radius, + ) + elif mask_type == "borders_weight_loss_map": + mask = mask_utils.get_unet_border_weight_map(im) + mask = image_utils.im_adjust(mask).astype(position_zarr.dtype) + + return mask + + +def mp_get_val_stats(fn_args, workers): + """ + Computes statistics of numpy arrays with multiprocessing + + :param list of tuple fn_args: list with tuples of function arguments + :param int workers: max number of workers + :return: list of returned df from get_im_stats + """ + with ProcessPoolExecutor(workers) as ex: + # can't use map directly as it works only with single arg functions + res = ex.map(get_val_stats, fn_args) + return list(res) + + +def get_val_stats(sample_values): + """ + Computes the statistics of a numpy array and returns a dictionary + of metadata corresponding to input sample values. + + :param list(float) sample_values: List of sample values at respective + indices + :return dict meta_row: Dict with intensity data for image + """ + + meta_row = { + "mean": float(np.nanmean(sample_values)), + "std": float(np.nanstd(sample_values)), + "median": float(np.nanmedian(sample_values)), + "iqr": float(scipy.stats.iqr(sample_values)), + } + return meta_row + + +def mp_sample_im_pixels(fn_args, workers): + """Read and computes statistics of images with multiprocessing + + :param list of tuple fn_args: list with tuples of function arguments + :param int workers: max number of workers + :return: list of paths and corresponding returned df from get_im_stats + """ + + with ProcessPoolExecutor(workers) as ex: + # can't use map directly as it works only with single arg functions + res = ex.map(sample_im_pixels, *zip(*fn_args)) + return list(map(list, zip(*list(res)))) + + +def sample_im_pixels( + position: ngff.Position, + grid_spacing, + channel, +): + # TODO move out of mp utils into normalization utils + """ + Read and computes statistics of images for each point in a grid. + Grid spacing determines distance in pixels between grid points + for rows and cols. + By default, samples from every time position and every z-depth, and + assumes that the data in the zarr store is stored in [T,C,Z,Y,X] format, + for time, channel, z, y, x. + + :param Position zarr_dir: NGFF position node object + :param int grid_spacing: spacing of sampling grid in x and y + :param int channel: channel to sample from + + :return list meta_rows: Dicts with intensity data for each grid point + """ + image_zarr = position.data + + all_sample_values = [] + all_time_indices = list(range(image_zarr.shape[0])) + all_z_indices = list(range(image_zarr.shape[2])) + + for time_index in all_time_indices: + for z_index in all_z_indices: + image_slice = image_zarr[time_index, channel, z_index, :, :] + _, _, sample_values = image_utils.grid_sample_pixel_values( + image_slice, grid_spacing + ) + all_sample_values.append(sample_values) + sample_values = np.stack(all_sample_values, 0).flatten() + + return position, sample_values diff --git a/viscy/utils/normalize.py b/viscy/utils/normalize.py new file mode 100644 index 00000000..93c11713 --- /dev/null +++ b/viscy/utils/normalize.py @@ -0,0 +1,87 @@ +"""Image normalization related functions""" +import sys + +import numpy as np +from skimage.exposure import equalize_adapthist + + +def zscore(input_image, im_mean=None, im_std=None): + """ + Performs z-score normalization. Adds epsilon in denominator for robustness + + :param np.array input_image: input image for intensity normalization + :param float/None im_mean: Image mean + :param float/None im_std: Image std + :return np.array norm_img: z score normalized image + """ + if not im_mean: + im_mean = np.nanmean(input_image) + if not im_std: + im_std = np.nanstd(input_image) + norm_img = (input_image - im_mean) / (im_std + sys.float_info.epsilon) + return norm_img + + +def unzscore(im_norm, zscore_median, zscore_iqr): + """ + Revert z-score normalization applied during preprocessing. Necessary + before computing SSIM + + :param im_norm: Normalized image for un-zscore + :param zscore_median: Image median + :param zscore_iqr: Image interquartile range + :return im: image at its original scale + """ + im = im_norm * (zscore_iqr + sys.float_info.epsilon) + zscore_median + return im + + +def hist_clipping(input_image, min_percentile=2, max_percentile=98): + """Clips and rescales histogram from min to max intensity percentiles + + rescale_intensity with input check + + :param np.array input_image: input image for intensity normalization + :param int/float min_percentile: min intensity percentile + :param int/flaot max_percentile: max intensity percentile + :return: np.float, intensity clipped and rescaled image + """ + + assert (min_percentile < max_percentile) and max_percentile <= 100 + pmin, pmax = np.percentile(input_image, (min_percentile, max_percentile)) + hist_clipped_image = np.clip(input_image, pmin, pmax) + return hist_clipped_image + + +def hist_adapteq_2D(input_image, kernel_size=None, clip_limit=None): + """CLAHE on 2D images + + skimage.exposure.equalize_adapthist works only for 2D. Extend to 3D or use + openCV? Not ideal, as it enhances noise in homogeneous areas + + :param np.array input_image: input image for intensity normalization + :param int/list kernel_size: Neighbourhood to be used for histogram + equalization. If none, use default of 1/8th image size. + :param float clip_limit: Clipping limit, normalized between 0 and 1 + (higher values give more contrast, ~ max percent of voxels in any + histogram bin, if > this limit, the voxel intensities are redistributed). + if None, default=0.01 + """ + nrows, ncols = input_image.shape + if kernel_size is not None: + if isinstance(kernel_size, int): + assert kernel_size < min(nrows, ncols) + elif isinstance(kernel_size, (list, tuple)): + assert len(kernel_size) == len(input_image.shape) + else: + raise ValueError("kernel size invalid: not an int / list / tuple") + + if clip_limit is not None: + assert 0 <= clip_limit <= 1, "Clip limit {} is out of range [0, 1]".format( + clip_limit + ) + + adapt_eq_image = equalize_adapthist( + input_image, kernel_size=kernel_size, clip_limit=clip_limit + ) + return adapt_eq_image