Skip to content

Commit

Permalink
initial commit adding resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
edyoshikun committed Aug 8, 2024
1 parent baa4ee3 commit b07f0ad
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 3 deletions.
3 changes: 2 additions & 1 deletion examples/configs/fit_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ data:
batch_size: 32
num_workers: 16
yx_patch_size: [256, 256]
pyramid_resolution: 0
normalizations:
- class_path: viscy.transforms.NormalizeSampled
init_args:
Expand Down Expand Up @@ -87,4 +88,4 @@ data:
sigma_z: [0.25, 1.5]
sigma_y: [0.25, 1.5]
sigma_x: [0.25, 1.5]
caching: false
caching: false
1 change: 1 addition & 0 deletions examples/configs/predict_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,6 @@ predict:
- 256
caching: false
predict_scale_source: null
pyramid_resolution: 0
return_predictions: false
ckpt_path: null
1 change: 1 addition & 0 deletions examples/configs/test_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,6 @@ data:
- 256
caching: false
ground_truth_masks: null
pyramid_resolution: 0
ckpt_path: null
verbose: true
14 changes: 12 additions & 2 deletions viscy/data/hcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class SlidingWindowDataset(Dataset):
:param ChannelMap channels: source and target channel names,
e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}``
:param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D
:param int pyramid_resolution: pyramid level.
defaults to 0 (full resolution)
:param DictTransform | None transform:
a callable that transforms data, defaults to None
"""
Expand All @@ -113,16 +115,21 @@ def __init__(
positions: list[Position],
channels: ChannelMap,
z_window_size: int,
pyramid_resolution: int = 0,
transform: DictTransform | None = None,
) -> None:
super().__init__()
self.positions = positions
self.channels = {k: _ensure_channel_list(v) for k, v in channels.items()}
self.source_ch_idx = [
positions[0].get_channel_index(c) for c in channels["source"]
positions[pyramid_resolution].get_channel_index(c)
for c in channels["source"]
]
self.target_ch_idx = (
[positions[0].get_channel_index(c) for c in channels["target"]]
[
positions[pyramid_resolution].get_channel_index(c)
for c in channels["target"]
]
if "target" in channels
else None
)
Expand Down Expand Up @@ -301,6 +308,8 @@ class HCSDataModule(LightningDataModule):
:param Path | None ground_truth_masks: path to the ground truth masks,
used in the test stage to compute segmentation metrics,
defaults to None
:param int pyramid_resolution: pyramid resolution level.
defaults to 0 (full resolution)
"""

def __init__(
Expand All @@ -318,6 +327,7 @@ def __init__(
augmentations: list[MapTransform] = [],
caching: bool = False,
ground_truth_masks: Path | None = None,
pyramid_resolution: int = 0,
):
super().__init__()
self.data_path = Path(data_path)
Expand Down

0 comments on commit b07f0ad

Please sign in to comment.