Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a mapping function in transforms.io #7769

Merged
merged 55 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
afae503
Fixes #7557
May 14, 2024
542a77d
Fixes #7557
May 14, 2024
e9f7565
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2024
7969d21
Fixes #7557
May 14, 2024
3ce5f30
Fixes #7557
May 14, 2024
0699eeb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2024
274cd04
fix-issue-7557
Jun 1, 2024
d4fb0b7
fix-issue-7557
Jun 1, 2024
bfb6d58
Fixes #7557
Jun 1, 2024
5ab2521
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2024
894854d
Fixes #7557
Jun 1, 2024
c372225
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2024
682379b
Fixes #7557
Jun 3, 2024
56d8df5
Fixes #7557
Jun 3, 2024
8bab11b
Fixes #7557
Jun 3, 2024
ca48fec
fix-issue-7557
Jun 3, 2024
117dd78
fix-issue-7557
Jun 13, 2024
3908cdd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2024
36e5af0
fix-issue-7557
Jun 13, 2024
1a3da38
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2024
37d19ed
fix-issue-7557
Jun 13, 2024
cff2926
fix-issue-7557
Jul 15, 2024
33c078b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2024
40b3e21
fix-issue-7557
Jul 21, 2024
36047a2
fix-issue-7557
Jul 21, 2024
b385288
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 21, 2024
3937448
fix-issue-7557
Jul 21, 2024
dc8645f
Merge branch 'dev' into fix-issue-7557
staydelight Jul 21, 2024
44307fc
fix-issue-7557
Jul 22, 2024
cdf4a1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 22, 2024
5dd268e
fix-issue-7557
Jul 25, 2024
401557a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2024
476e15e
Merge branch 'dev' into fix-issue-7557
staydelight Jul 31, 2024
4adb87d
fix-issue-7557
Aug 6, 2024
f073c64
Merge branch 'dev' into fix-issue-7557
staydelight Aug 6, 2024
0c14f4f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
8fafc05
fix-issue-7557
Aug 6, 2024
b238987
fix-issue-7557
Aug 6, 2024
5c75990
fix-issue-7557
Aug 7, 2024
8a5e7f1
Merge branch 'Project-MONAI:dev' into fix-issue-7557
staydelight Aug 19, 2024
f7deb86
fix-issue-7557
Aug 19, 2024
3607b18
fix-issue-7557
Aug 19, 2024
130eaa1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2024
b1475be
fix-issue-7557
Aug 19, 2024
05a00fc
Merge branch 'Project-MONAI:dev' into fix-issue-7557
staydelight Aug 27, 2024
ca15156
fix-issue-7557
Aug 27, 2024
3dc9f49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
3ea0df2
fix-issue-7557
Aug 27, 2024
8ad9808
fix-issue-7557
Aug 27, 2024
802e554
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
60f5b79
fix-issue-7557
Aug 27, 2024
b7957b6
fix-issue-7557
Aug 27, 2024
773a218
fix-issue-7557
Aug 27, 2024
b28b184
fix-issue-7557
Aug 28, 2024
f5eab6c
Merge branch 'dev' into fix-issue-7557
KumoLiu Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,12 @@ IO
:members:
:special-members: __call__

`WriteFileMapping`
""""""""""""""""""
.. autoclass:: WriteFileMapping
:members:
:special-members: __call__


NVIDIA Tool Extension (NVTX)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -1642,6 +1648,12 @@ IO (Dict)
:members:
:special-members: __call__

`WriteFileMappingd`
"""""""""""""""""""
.. autoclass:: WriteFileMappingd
:members:
:special-members: __call__

Post-processing (Dict)
^^^^^^^^^^^^^^^^^^^^^^

Expand Down
14 changes: 12 additions & 2 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,18 @@
)
from .inverse import InvertibleTransform, TraceableTransform
from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping
from .io.dictionary import (
LoadImaged,
LoadImageD,
LoadImageDict,
SaveImaged,
SaveImageD,
SaveImageDict,
WriteFileMappingd,
WriteFileMappingD,
WriteFileMappingDict,
)
from .lazy.array import ApplyPending
from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict
from .lazy.functional import apply_pending
Expand Down
60 changes: 58 additions & 2 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import inspect
import json
import logging
import sys
import traceback
Expand Down Expand Up @@ -45,11 +46,19 @@
from monai.transforms.utility.array import EnsureChannelFirst
from monai.utils import GridSamplePadMode
from monai.utils import ImageMetaKey as Key
from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import
from monai.utils import (
MetaKeys,
OptionalImportError,
convert_to_dst_type,
ensure_tuple,
look_up_option,
optional_import,
)

nib, _ = optional_import("nibabel")
Image, _ = optional_import("PIL.Image")
nrrd, _ = optional_import("nrrd")
FileLock, has_filelock = optional_import("filelock", name="FileLock")

__all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"]

Expand Down Expand Up @@ -505,7 +514,7 @@ def __call__(
else:
self._data_index += 1
if self.savepath_in_metadict and meta_data is not None:
meta_data["saved_to"] = filename
meta_data[MetaKeys.SAVED_TO] = filename
return img
msg = "\n".join([f"{e}" for e in err])
raise RuntimeError(
Expand All @@ -514,3 +523,50 @@ def __call__(
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}"
)


class WriteFileMapping(Transform):
"""
Writes a JSON file that logs the mapping between input image paths and their corresponding output paths.
This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment.

Args:
mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved.
"""

def __init__(self, mapping_file_path: Path | str = "mapping.json"):
self.mapping_file_path = Path(mapping_file_path)

def __call__(self, img: NdarrayOrTensor):
"""
Args:
img: The input image with metadata.
"""
if isinstance(img, MetaTensor):
meta_data = img.meta

if MetaKeys.SAVED_TO not in meta_data:
raise KeyError(
"Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True."
)

input_path = meta_data[Key.FILENAME_OR_OBJ]
output_path = meta_data[MetaKeys.SAVED_TO]
log_data = {"input": input_path, "output": output_path}

if has_filelock:
with FileLock(str(self.mapping_file_path) + ".lock"):
self._write_to_file(log_data)
else:
self._write_to_file(log_data)
return img

def _write_to_file(self, log_data):
try:
with self.mapping_file_path.open("r") as f:
existing_log_data = json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
existing_log_data = []
existing_log_data.append(log_data)
with self.mapping_file_path.open("w") as f:
json.dump(existing_log_data, f, indent=4)
31 changes: 29 additions & 2 deletions monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@

from __future__ import annotations

from collections.abc import Hashable, Mapping
from pathlib import Path
from typing import Callable

import numpy as np

import monai
from monai.config import DtypeLike, KeysCollection
from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor
from monai.data import image_writer
from monai.data.image_reader import ImageReader
from monai.transforms.io.array import LoadImage, SaveImage
from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping
from monai.transforms.transform import MapTransform, Transform
from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep
from monai.utils.enums import PostFix
Expand Down Expand Up @@ -320,5 +321,31 @@ def __call__(self, data):
return d


class WriteFileMappingd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.WriteFileMapping`.

Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
mapping_file_path: Path to the JSON file where the mappings will be saved.
Defaults to "mapping.json".
allow_missing_keys: don't raise exception if key is missing.
"""

def __init__(
self, keys: KeysCollection, mapping_file_path: Path | str = "mapping.json", allow_missing_keys: bool = False
) -> None:
super().__init__(keys, allow_missing_keys)
self.mapping = WriteFileMapping(mapping_file_path)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.mapping(d[key])
return d


LoadImageD = LoadImageDict = LoadImaged
SaveImageD = SaveImageDict = SaveImaged
WriteFileMappingD = WriteFileMappingDict = WriteFileMappingd
1 change: 1 addition & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ class MetaKeys(StrEnum):
SPATIAL_SHAPE = "spatial_shape" # optional key for the length in each spatial dimension
SPACE = "space" # possible values of space type are defined in `SpaceKeys`
ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or float("nan")
SAVED_TO = "saved_to"


class ColorOrder(StrEnum):
Expand Down
117 changes: 117 additions & 0 deletions tests/test_mapping_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import json
import os
import shutil
import tempfile
import unittest

import numpy as np
from parameterized import parameterized

from monai.data import DataLoader, Dataset
from monai.transforms import Compose, LoadImage, SaveImage, WriteFileMapping
from monai.utils import optional_import

nib, has_nib = optional_import("nibabel")


def create_input_file(temp_dir, name):
test_image = np.random.rand(128, 128, 128)
output_ext = ".nii.gz"
input_file = os.path.join(temp_dir, name + output_ext)
nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file)
return input_file


def create_transform(temp_dir, mapping_file_path, savepath_in_metadict=True):
return Compose(
[
LoadImage(image_only=True),
SaveImage(output_dir=temp_dir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict),
WriteFileMapping(mapping_file_path=mapping_file_path),
]
)


@unittest.skipUnless(has_nib, "nibabel required")
class TestWriteFileMapping(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()

def tearDown(self):
shutil.rmtree(self.temp_dir)

@parameterized.expand([(True,), (False,)])
def test_mapping_file(self, savepath_in_metadict):
mapping_file_path = os.path.join(self.temp_dir, "mapping.json")
name = "test_image"
input_file = create_input_file(self.temp_dir, name)
output_file = os.path.join(self.temp_dir, name, name + "_trans.nii.gz")

transform = create_transform(self.temp_dir, mapping_file_path, savepath_in_metadict)

if savepath_in_metadict:
transform(input_file)
self.assertTrue(os.path.exists(mapping_file_path))
with open(mapping_file_path) as f:
mapping_data = json.load(f)
self.assertEqual(len(mapping_data), 1)
self.assertEqual(mapping_data[0]["input"], input_file)
self.assertEqual(mapping_data[0]["output"], output_file)
else:
with self.assertRaises(RuntimeError) as cm:
transform(input_file)
cause_exception = cm.exception.__cause__
self.assertIsInstance(cause_exception, KeyError)
self.assertIn(
"Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.",
str(cause_exception),
)

def test_multiprocess_mapping_file(self):
num_images = 50

single_mapping_file = os.path.join(self.temp_dir, "single_mapping.json")
multi_mapping_file = os.path.join(self.temp_dir, "multi_mapping.json")

data = [create_input_file(self.temp_dir, f"test_image_{i}") for i in range(num_images)]

# single process
single_transform = create_transform(self.temp_dir, single_mapping_file)
single_dataset = Dataset(data=data, transform=single_transform)
single_loader = DataLoader(single_dataset, batch_size=1, num_workers=0, shuffle=True)
for _ in single_loader:
pass

# multiple processes
multi_transform = create_transform(self.temp_dir, multi_mapping_file)
multi_dataset = Dataset(data=data, transform=multi_transform)
multi_loader = DataLoader(multi_dataset, batch_size=4, num_workers=3, shuffle=True)
for _ in multi_loader:
pass

with open(single_mapping_file) as f:
single_mapping_data = json.load(f)
with open(multi_mapping_file) as f:
multi_mapping_data = json.load(f)

single_set = {(entry["input"], entry["output"]) for entry in single_mapping_data}
multi_set = {(entry["input"], entry["output"]) for entry in multi_mapping_data}

self.assertEqual(single_set, multi_set)


if __name__ == "__main__":
unittest.main()
Loading
Loading