Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vendor pad shape function #189

Merged
merged 3 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tests/translation/test_predict_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from viscy.translation.predict_writer import _pad_shape


def test_pad_shape():
assert _pad_shape((2, 3), 3) == (1, 2, 3)
assert _pad_shape((4, 5), 4) == (1, 1, 4, 5)
full_shape = tuple(range(1, 6))
assert _pad_shape(full_shape, 5) == full_shape
11 changes: 10 additions & 1 deletion viscy/translation/predict_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
import torch
from iohub.ngff import ImageArray, Plate, Position, _pad_shape, open_ome_zarr
from iohub.ngff import ImageArray, Plate, Position, open_ome_zarr
from iohub.ngff_meta import TransformationMeta
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import BasePredictionWriter
Expand All @@ -17,6 +17,15 @@
_logger = logging.getLogger("lightning.pytorch")


def _pad_shape(shape: tuple[int, ...], target: int = 5) -> tuple[int, ...]:
"""
Pad shape tuple to a target length.
Vendored from ``iohub.ngff.nodes._pad_shape()``.
"""
pad = target - len(shape)
return (1,) * pad + shape


def _resize_image(image: ImageArray, t_index: int, z_slice: slice) -> None:
"""Resize image array if incoming (1, C, Z, Y, X) stack is not within bounds."""
if image.shape[0] <= t_index or image.shape[2] < z_slice.stop:
Expand Down
Loading