diff --git a/.deepsource.toml b/.deepsource.toml new file mode 100644 index 0000000000..f090928ffa --- /dev/null +++ b/.deepsource.toml @@ -0,0 +1,27 @@ +version = 1 + +test_patterns = ["tests/**"] + +exclude_patterns = [ + "monai/_version.py", + "versioneer.py" +] + +[[analyzers]] +name = "python" +enabled = true + + [analyzers.meta] + runtime_version = "3.x.x" + +[[analyzers]] +name = "test-coverage" +enabled = true + +[[analyzers]] +name = "docker" +enabled = true + +[[analyzers]] +name = "shell" +enabled = true \ No newline at end of file diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000..549e63bad5 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,12 @@ +# Ignore the following files/folders during docker build + +__pycache__/ +docs/ + +.coverage +.readthedocs.yml +*.md +*.toml + +!README.md + diff --git a/.github/workflows/cleanup.yml b/.github/workflows/cleanup.yml new file mode 100644 index 0000000000..f3d297286e --- /dev/null +++ b/.github/workflows/cleanup.yml @@ -0,0 +1,20 @@ +name: cleanup-workflow + +on: + workflow_run: + workflows: + - "build" + types: ["requested"] + +jobs: + cancel-duplicated-workflow: + name: "Cancel duplicated workflow" + runs-on: ubuntu-latest + steps: + - uses: potiuk/cancel-workflow-runs@953e057dc81d3458935a18d1184c386b0f6b5738 # tested + name: "Cancel duplicate workflows" + with: + cancelMode: allDuplicates + token: ${{ secrets.GITHUB_TOKEN }} + sourceRunId: ${{ github.event.workflow_run.id }} + skipEventTypes: '["schedule"]' diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index fa68be8c68..d9ffdb7f5e 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -56,11 +56,12 @@ jobs: cron-docker: if: github.repository == 'Project-MONAI/MONAI' container: - image: docker://projectmonai/monai:latest + image: localhost:5000/local_monai:dockerhub # use the currently latest, locally available dockerhub image options: "--gpus all" runs-on: [self-hosted, linux, x64, common] steps: - name: Run tests report coverage + # The docker image build has already done the compilation, so BUILD_MONAI=1 may not be necessary. run: | cd /opt/monai nvidia-smi diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index ed5d560861..7656eb4828 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -149,11 +149,17 @@ jobs: ref: master - name: docker_build run: | + # build and push the original docker image to the local registry docker build -t localhost:5000/local_monai:latest -f Dockerfile . docker push localhost:5000/local_monai:latest + # remove the flake packages (not needed on hub.docker.com) and build once more w/ tag "latest" sed -i '/flake/d' requirements-dev.txt docker build -t projectmonai/monai:latest -f Dockerfile .
- docker login -u projectmonai -p ${{ secrets.DOCKER_PW }} + # also push as tag "dockerhub" to local registry + docker image tag projectmonai/monai:latest localhost:5000/local_monai:dockerhub + docker push localhost:5000/local_monai:dockerhub + # distribute as always w/ tag "latest" to hub.docker.com + echo "${{ secrets.DOCKER_PW }}" | docker login -u projectmonai --password-stdin docker push projectmonai/monai:latest docker logout diff --git a/.github/workflows/weekly-preview.yml b/.github/workflows/weekly-preview.yml index 54e43d6968..bb68a0801d 100644 --- a/.github/workflows/weekly-preview.yml +++ b/.github/workflows/weekly-preview.yml @@ -6,6 +6,7 @@ on: jobs: packaging: + if: github.repository == 'Project-MONAI/MONAI' runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 @@ -24,12 +25,16 @@ jobs: sed -i 's/name\ =\ monai$/name\ =\ monai-weekly/g' setup.cfg echo "__commit_id__ = \"$HEAD_COMMIT_ID\"" >> monai/__init__.py git diff setup.cfg monai/__init__.py - # build tar.gz and wheel git config user.name "CI Builder" - git config user.email "monai.miccai2019@gmail.com" + git config user.email "monai.contact@gmail.com" git add setup.cfg monai/__init__.py git commit -m "Weekly build at $HEAD_COMMIT_ID" - git tag 0.5.dev$(date +'%y%U') + export YEAR_WEEK=$(date +'%y%U') + echo "Year week for tag is ${YEAR_WEEK}" + if ! [[ $YEAR_WEEK =~ ^[0-9]{4}$ ]] ; then echo "Wrong 'year week' format. Should be 4 digits."; exit 1 ; fi + git tag "0.5.dev${YEAR_WEEK}" + git log -1 + git tag --list python setup.py sdist bdist_wheel - name: Publish to PyPI diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..c6f6fda20a --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,76 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. 
+ +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at monai.contact@gmail.com. All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d9b610ee64..01a4773b5a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -32,6 +32,8 @@ _Pull request early_ We encourage you to create pull requests early. It helps us track the contributions under development, whether they are ready to be merged or not. Change your pull request's title to begin with `[WIP]` until it is ready for formal review. +Please note that, as per PyTorch, MONAI uses American English spelling. This means classes and variables should be: normali**z**e, visuali**z**e, colo~~u~~r, etc. + ### Preparing pull requests To ensure the code quality, MONAI relies on several linting tools ([flake8 and its plugins](https://gitlab.com/pycqa/flake8), [black](https://github.com/psf/black), [isort](https://github.com/timothycrosley/isort)), static type analysis tools ([mypy](https://github.com/python/mypy), [pytype](https://github.com/google/pytype)), as well as a set of unit/integration tests. @@ -58,7 +60,7 @@ pip install -U -r requirements-dev.txt # install the latest tools License information: all source code files should start with this paragraph: ``` -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -71,6 +73,13 @@ License information: all source code files should start with this paragraph: ``` +##### Exporting modules + +If you intend any variables/functions/classes to be available outside of the file in which the functionality was added or edited, then: + +- Create or append to the `__all__` variable (in the file in which functionality has been added), and +- Add to the `__init__.py` file. + #### Unit testing MONAI tests are located under `tests/`. diff --git a/Dockerfile b/Dockerfile index 671cfb9377..a600f9de84 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,10 +11,11 @@ ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:20.10-py3 -FROM ${PYTORCH_IMAGE} as base +FROM ${PYTORCH_IMAGE} + +LABEL maintainer="monai.contact@gmail.com" WORKDIR /opt/monai -ENV PATH=/opt/tools:$PATH # install full deps COPY requirements.txt requirements-min.txt requirements-dev.txt /tmp/ @@ -23,13 +24,24 @@ RUN cp /tmp/requirements.txt /tmp/req.bak \ && python -m pip install --no-cache-dir --use-feature=2020-resolver -r /tmp/requirements-dev.txt # compile ext and remove temp files -COPY . . +# TODO: remark for issue [revise the dockerfile #1276](https://github.com/Project-MONAI/MONAI/issues/1276) +# please specify the exact files and folders to be copied -- otherwise any change in the build context invalidates +# the cache for this layer and for every layer below it, so a single file change forces a full rebuild from here on + +COPY LICENSE setup.py setup.cfg versioneer.py runtests.sh .gitignore .gitattributes README.md MANIFEST.in ./ +COPY tests ./tests +COPY monai ./monai +COPY .git ./.git RUN BUILD_MONAI=1 FORCE_CUDA=1 python setup.py develop \ && rm -rf build __pycache__ # NGC Client WORKDIR /opt/tools -RUN wget -q https://ngc.nvidia.com/downloads/ngccli_cat_linux.zip && \ +ARG NGC_CLI_URI="https://ngc.nvidia.com/downloads/ngccli_cat_linux.zip" +RUN wget -q ${NGC_CLI_URI} && \ unzip ngccli_cat_linux.zip && chmod u+x ngc && \ + md5sum -c ngc.md5 && \ rm -rf ngccli_cat_linux.zip ngc.md5 +# append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations +ENV PATH=${PATH}:/opt/tools WORKDIR /opt/monai diff --git a/README.md b/README.md index 1741f2c518..f06a2d146f 100644 --- a/README.md +++ b/README.md @@ -29,23 +29,14 @@ Its ambitions are: ## Installation -To install [the current release](https://pypi.org/project/monai/): ```bash pip install monai ``` To install from the source code repository: +To install [the current release](https://pypi.org/project/monai/), you can simply run: + ```bash -pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI +pip install monai ``` -Alternatively, pre-built Docker image is available via [DockerHub](https://hub.docker.com/r/projectmonai/monai): - ```bash - # with docker v19.03+ - docker run --gpus all --rm -ti --ipc=host projectmonai/monai:latest - ``` - -For more details, please refer to [the installation guide](https://docs.monai.io/en/latest/installation.html). +For other installation methods (using the master branch, using Docker, etc.), please refer to [the installation guide](https://docs.monai.io/en/latest/installation.html).
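After the `pip install`, a quick sanity check confirms the package is importable and shows which optional dependencies are present. A minimal sketch, assuming only a working Python environment with MONAI installed:

```python
import monai

# print the MONAI, PyTorch and optional dependency versions;
# print_config is exported from monai.config (see the monai/config/__init__.py change below)
monai.config.print_config()
```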
## Getting Started diff --git a/docs/source/conf.py b/docs/source/conf.py index 534193c936..a2f1b3af5c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,7 +22,7 @@ # -- Project information ----------------------------------------------------- project = "MONAI" -copyright = "2020 MONAI Consortium" +copyright = "2020 - 2021 MONAI Consortium" author = "MONAI Contributors" # The full version, including alpha/beta/rc tags diff --git a/docs/source/data.rst b/docs/source/data.rst index f6ed71c266..11609964c3 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -63,6 +63,11 @@ Generic Interfaces :members: :special-members: __getitem__ +`ImageDataset` +~~~~~~~~~~~~~~ +.. autoclass:: ImageDataset + :members: + :special-members: __getitem__ Patch-based dataset ------------------- @@ -104,11 +109,6 @@ PILReader Nifti format handling --------------------- -Reading -~~~~~~~ -.. autoclass:: monai.data.NiftiDataset - :members: - Writing Nifti ~~~~~~~~~~~~~ .. autoclass:: monai.data.NiftiSaver diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 2962f725d8..81d28fb4ac 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -16,12 +16,25 @@ Model checkpoint saver .. autoclass:: CheckpointSaver :members: + +Metrics saver +------------- +.. autoclass:: MetricsSaver + :members: + + CSV saver --------- .. autoclass:: ClassificationSaver :members: +Iteration Metric +---------------- +.. autoclass:: IterationMetric + :members: + + Mean Dice metrics handler ------------------------- .. autoclass:: MeanDice diff --git a/docs/source/highlights.md b/docs/source/highlights.md index d8fe5c2ff9..29302bda77 100644 --- a/docs/source/highlights.md +++ b/docs/source/highlights.md @@ -39,7 +39,7 @@ There is a rich set of transforms in six categories: Crop & Pad, Intensity, IO, ### 2. Medical specific transforms MONAI aims at providing comprehensive medical image specific transformations. These currently include, for example: -- `LoadNifti`: Load Nifti format file from provided path +- `LoadImage`: Load a medical image file of any supported format from the provided path - `Spacing`: Resample input image into the specified `pixdim` - `Orientation`: Change the image's orientation into the specified `axcodes` - `RandGaussianNoise`: Perturb image intensities by adding statistical noises diff --git a/docs/source/installation.md b/docs/source/installation.md index e02e38cb8f..cb540b1559 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -1,17 +1,24 @@ # Installation guide +## Table of Contents +1. [From PyPI](#from-pypi) + 1. [Milestone release](#milestone-release) + 2. [Weekly preview release](#weekly-preview-release) +2. [From GitHub](#from-github) + 1. [System-wide](#system-wide) + 2. [Editable](#editable) +3. [Validating the install](#validating-the-install) +4. [MONAI version string](#monai-version-string) +5. [From DockerHub](#from-dockerhub) +6. [Installing the recommended dependencies](#installing-the-recommended-dependencies) + +--- + MONAI's core functionality is written in Python 3 (>= 3.6) and only requires [Numpy](https://numpy.org/) and [Pytorch](https://pytorch.org/). The package is currently distributed via Github as the primary source code repository, and the Python package index (PyPI). The pre-built Docker images are made available on DockerHub.
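To make the medical-specific transforms named in the `highlights.md` change above concrete, here is a minimal, hypothetical sketch of composing them into a preprocessing chain (the `"image"` key and the file path are illustrative placeholders, not part of this diff):

```python
from monai.transforms import (
    AddChanneld,
    Compose,
    LoadImaged,
    Orientationd,
    RandGaussianNoised,
    Spacingd,
)

# dictionary-based counterparts of the transforms listed in the highlights:
# load the file, add a channel dimension, resample to a target pixdim,
# reorient to RAS, then perturb intensities with random Gaussian noise
preprocess = Compose(
    [
        LoadImaged(keys="image"),
        AddChanneld(keys="image"),
        Spacingd(keys="image", pixdim=(1.0, 1.0, 1.0)),
        Orientationd(keys="image", axcodes="RAS"),
        RandGaussianNoised(keys="image", prob=0.5),
    ]
)
data = preprocess({"image": "/path/to/image.nii.gz"})
```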
-This page provides steps to: -- [Install MONAI from PyPI](#from-pypi) -- [Install MONAI from GitHub](#from-github) -- [Validate the install](#validating-the-install) -- [Understand MONAI version string](#monai-version-string) -- [Run MONAI From DockerHub](#from-dockerhub) - To install optional features such as handling the NIfTI files using [Nibabel](https://nipy.org/nibabel/), or building workflows using [Pytorch Ignite](https://pytorch.org/ignite/), please follow the instructions: diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 3f87f172d5..a6aa4d566d 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -43,6 +43,11 @@ Segmentation Losses .. autoclass:: generalized_wasserstein_dice :members: +`DiceCELoss` +~~~~~~~~~~~~ +.. autoclass:: DiceCELoss + :members: + `FocalLoss` ~~~~~~~~~~~ .. autoclass:: FocalLoss @@ -52,3 +57,21 @@ Segmentation Losses ~~~~~~~~~~~~~ .. autoclass:: TverskyLoss :members: + +Registration Losses +------------------- + +`BendingEnergyLoss` +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: BendingEnergyLoss + :members: + +`LocalNormalizedCrossCorrelationLoss` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LocalNormalizedCrossCorrelationLoss + :members: + +`GlobalMutualInformationLoss` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: GlobalMutualInformationLoss + :members: \ No newline at end of file diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 0bcfbd4240..32a3faf380 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -37,7 +37,3 @@ Metrics .. autoclass:: SurfaceDistanceMetric :members: - -`Occlusion sensitivity` ------------------------ -.. autofunction:: compute_occlusion_sensitivity \ No newline at end of file diff --git a/docs/source/networks.rst b/docs/source/networks.rst index fc16e8c86e..7c22964835 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -119,7 +119,25 @@ Blocks .. autoclass:: Subpixelupsample .. autoclass:: SubpixelUpSample +`LocalNet DownSample Block` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LocalNetDownSampleBlock + :members: + +`LocalNet UpSample Block` +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LocalNetUpSampleBlock + :members: + +`LocalNet Feature Extractor Block` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LocalNetFeatureExtractorBlock + :members: +`Warp` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: Warp + :members: Layers ------ @@ -183,6 +201,20 @@ Layers ~~~~~~~~~~~~~~~~ .. autoclass:: GaussianFilter :members: + +`BilateralFilter` +~~~~~~~~~~~~~~~~~ +.. autoclass:: BilateralFilter + :members: + +`PHLFilter` +~~~~~~~~~~~~~~~~~ +.. autoclass:: PHLFilter + +`SavitzkyGolayFilter` +~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: SavitzkyGolayFilter + :members: `HilbertTransform` ~~~~~~~~~~~~~~~~~~ @@ -293,6 +325,11 @@ Nets .. autoclass:: VNet :members: +`LocalNet` +~~~~~~~~~~~ +.. autoclass:: LocalNet + :members: + `AutoEncoder` ~~~~~~~~~~~~~ .. autoclass:: AutoEncoder diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index f7e075f376..90d960a6b9 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -186,6 +186,12 @@ Intensity :members: :special-members: __call__ +`SavitzkyGolaySmooth` +""""""""""""""""""""" +.. autoclass:: SavitzkyGolaySmooth + :members: + :special-members: __call__ + `GaussianSmooth` """""""""""""""" .. autoclass:: GaussianSmooth @@ -231,24 +237,6 @@ IO :members: :special-members: __call__ -`LoadNifti` -""""""""""" -.. 
autoclass:: LoadNifti - :members: - :special-members: __call__ - -`LoadPNG` -""""""""" -.. autoclass:: LoadPNG - :members: - :special-members: __call__ - -`LoadNumpy` -""""""""""" -.. autoclass:: LoadNumpy - :members: - :special-members: __call__ - Post-processing ^^^^^^^^^^^^^^^ @@ -504,6 +492,24 @@ Utility :members: :special-members: __call__ +`ConvertToMultiChannelBasedOnBratsClasses` +"""""""""""""""""""""""""""""""""""""""""" +.. autoclass:: ConvertToMultiChannelBasedOnBratsClasses + :members: + :special-members: __call__ + +`AddExtremePointsChannel` +""""""""""""""""""""""""" +.. autoclass:: AddExtremePointsChannel + :members: + :special-members: __call__ + +`TorchVision` +""""""""""""" +.. autoclass:: TorchVision + :members: + :special-members: __call__ + Dictionary Transforms --------------------- @@ -690,36 +696,12 @@ Instensity (Dict) IO (Dict) ^^^^^^^^^ -`LoadDatad` -""""""""""" -.. autoclass:: LoadDatad - :members: - :special-members: __call__ - `LoadImaged` """""""""""" .. autoclass:: LoadImaged :members: :special-members: __call__ -`LoadNiftid` -"""""""""""" -.. autoclass:: LoadNiftid - :members: - :special-members: __call__ - -`LoadPNGd` -"""""""""" -.. autoclass:: LoadPNGd - :members: - :special-members: __call__ - -`LoadNumpyd` -"""""""""""" -.. autoclass:: LoadNumpyd - :members: - :special-members: __call__ - Post-processing (Dict) ^^^^^^^^^^^^^^^^^^^^^^ @@ -969,6 +951,24 @@ Utility (Dict) :members: :special-members: __call__ +`ConvertToMultiChannelBasedOnBratsClassesd` +""""""""""""""""""""""""""""""""""""""""""" +.. autoclass:: ConvertToMultiChannelBasedOnBratsClassesd + :members: + :special-members: __call__ + +`AddExtremePointsChanneld` +"""""""""""""""""""""""""" +.. autoclass:: AddExtremePointsChanneld + :members: + :special-members: __call__ + +`TorchVisiond` +"""""""""""""" +.. autoclass:: TorchVisiond + :members: + :special-members: __call__ + Transform Adaptors ------------------ .. automodule:: monai.transforms.adaptors diff --git a/docs/source/visualize.rst b/docs/source/visualize.rst index 9668d48114..850fd51770 100644 --- a/docs/source/visualize.rst +++ b/docs/source/visualize.rst @@ -17,4 +17,10 @@ Class activation map -------------------- .. automodule:: monai.visualize.class_activation_maps - :members: \ No newline at end of file + :members: + +Occlusion sensitivity +--------------------- + +.. automodule:: monai.visualize.occlusion_sensitivity + :members: diff --git a/monai/__init__.py b/monai/__init__.py index a6f5c75309..910698ee14 100644 --- a/monai/__init__.py +++ b/monai/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,7 +22,7 @@ __revision_id__ = version_dict.get("full-revisionid", None) del get_versions, version_dict -__copyright__ = "(c) 2020 MONAI Consortium" +__copyright__ = "(c) 2020 - 2021 MONAI Consortium" __basedir__ = os.path.dirname(__file__) diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 5b142cbb43..59f38cbb6f 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -9,5 +9,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .datasets import * -from .utils import * +from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset +from .utils import check_hash, download_and_extract, download_url, extractall diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index 6272b50b4c..d8fd815ce9 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -37,9 +37,7 @@ class MedNISTDataset(Randomizable, CacheDataset): Args: root_dir: target directory to download and load MedNIST dataset. section: expected data section, can be: `training`, `validation` or `test`. - transform: transforms to execute operations on input data. the default transform is `LoadPNGd`, - which can load data into numpy array with [H, W] shape. for further usage, use `AddChanneld` - to convert the shape to [C, H, W, D]. + transform: transforms to execute operations on input data. download: whether to download and extract the MedNIST from resource link, default is False. if expected file already exists, skip downloading even set it to True. user can manually copy `MedNIST.tar.gz` file or `MedNIST` folder to root directory. @@ -85,6 +83,7 @@ def __init__( self.set_random_state(seed=seed) tarfile_name = os.path.join(root_dir, self.compressed_file_name) dataset_dir = os.path.join(root_dir, self.dataset_folder_name) + self.num_class = 0 if download: download_and_extract(self.resource, tarfile_name, root_dir, self.md5) @@ -100,6 +99,10 @@ def __init__( def randomize(self, data: Optional[Any] = None) -> None: self.rann = self.R.random() + def get_num_classes(self) -> int: + """Get number of classes.""" + return self.num_class + def _generate_data_list(self, dataset_dir: str) -> List[Dict]: """ Raises: @@ -107,20 +110,22 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: """ class_names = sorted((x for x in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, x)))) - num_class = len(class_names) + self.num_class = len(class_names) image_files = [ [ os.path.join(dataset_dir, class_names[i], x) for x in os.listdir(os.path.join(dataset_dir, class_names[i])) ] - for i in range(num_class) + for i in range(self.num_class) ] - num_each = [len(image_files[i]) for i in range(num_class)] + num_each = [len(image_files[i]) for i in range(self.num_class)] image_files_list = [] image_class = [] - for i in range(num_class): + class_name = [] + for i in range(self.num_class): image_files_list.extend(image_files[i]) image_class.extend([i] * num_each[i]) + class_name.extend([class_names[i]] * num_each[i]) num_total = len(image_class) data = [] @@ -140,7 +145,7 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: raise ValueError( f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].' 
) - data.append({"image": image_files_list[i], "label": image_class[i]}) + data.append({"image": image_files_list[i], "label": image_class[i], "class_name": class_name[i]}) return data @@ -158,8 +163,7 @@ class DecathlonDataset(Randomizable, CacheDataset): "Task03_Liver", "Task04_Hippocampus", "Task05_Prostate", "Task06_Lung", "Task07_Pancreas", "Task08_HepaticVessel", "Task09_Spleen", "Task10_Colon"). section: expected data section, can be: `training`, `validation` or `test`. - transform: transforms to execute operations on input data. the default transform is `LoadNiftid`, - which can load Nifti format data into numpy array with [H, W, D] or [H, W, D, C] shape. + transform: transforms to execute operations on input data. for further usage, use `AddChanneld` or `AsChannelFirstd` to convert the shape to [C, H, W, D]. download: whether to download and extract the Decathlon from resource link, default is False. if expected file already exists, skip downloading even set it to True. @@ -185,7 +189,7 @@ class DecathlonDataset(Randomizable, CacheDataset): transform = Compose( [ - LoadNiftid(keys=["image", "label"]), + LoadImaged(keys=["image", "label"]), AddChanneld(keys=["image", "label"]), ScaleIntensityd(keys="image"), ToTensord(keys=["image", "label"]), @@ -291,10 +295,9 @@ def get_properties(self, keys: Optional[Union[Sequence[str], str]] = None): """ if keys is None: return self._properties - elif self._properties is not None: + if self._properties is not None: return {key: self._properties[key] for key in ensure_tuple(keys)} - else: - return {} + return {} def _generate_data_list(self, dataset_dir: str) -> List[Dict]: section = "training" if self.section in ["training", "validation"] else "test" diff --git a/monai/apps/utils.py b/monai/apps/utils.py index e48dfb63f2..e2970b4a3d 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,15 +10,13 @@ # limitations under the License. import hashlib -import logging import os -import shutil import tarfile import warnings import zipfile from typing import TYPE_CHECKING, Optional from urllib.error import ContentTooShortError, HTTPError, URLError -from urllib.request import Request, urlopen, urlretrieve +from urllib.request import urlretrieve from monai.utils import min_version, optional_import @@ -112,50 +110,10 @@ def download_url(url: str, filepath: str, hash_val: Optional[str] = None, hash_t raise RuntimeError( f"Download of file from {url} to {filepath} failed due to network issue or denied permission." 
) - elif url.startswith("https://msd-for-monai.s3-us-west-2.amazonaws.com"): - block_size = 1024 * 1024 - tmp_file_path = filepath + ".part" - first_byte = os.path.getsize(tmp_file_path) if os.path.exists(tmp_file_path) else 0 - file_size = -1 - - try: - file_size = int(urlopen(url).info().get("Content-Length", -1)) - if has_tqdm: - pbar = tqdm( - unit="B", - unit_scale=True, - unit_divisor=1024, - miniters=1, - desc=filepath.split(os.sep)[-1], - total=file_size, - ) - else: - warnings.warn("tqdm is not installed, will not show the downloading progress bar.") - - while first_byte < file_size: - last_byte = first_byte + block_size if first_byte + block_size < file_size else file_size - 1 - - req = Request(url) - req.headers["Range"] = "bytes=%s-%s" % (first_byte, last_byte) - data_chunk = urlopen(req, timeout=10).read() - with open(tmp_file_path, "ab") as f: - f.write(data_chunk) - if has_tqdm: - pbar.update(len(data_chunk)) - first_byte = last_byte + 1 - if has_tqdm: - pbar.close() - except IOError as e: - logging.debug("IO Error - %s" % e) - finally: - if file_size == os.path.getsize(tmp_file_path): - if hash_val and not check_hash(tmp_file_path, hash_val, hash_type): - raise Exception(f"Error validating the file against its {hash_type} hash") - shutil.move(tmp_file_path, filepath) - elif file_size == -1: - raise Exception("Error getting Content-Length from server: %s" % url) else: - os.makedirs(os.path.dirname(filepath), exist_ok=True) + path = os.path.dirname(filepath) + if path: + os.makedirs(path, exist_ok=True) try: if has_tqdm: diff --git a/monai/config/__init__.py b/monai/config/__init__.py index a6e05e7044..251be002f2 100644 --- a/monai/config/__init__.py +++ b/monai/config/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,5 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .deviceconfig import * -from .type_definitions import * +from .deviceconfig import ( + USE_COMPILED, + get_gpu_info, + get_system_info, + print_config, + print_debug_info, + print_gpu_info, + print_system_info, +) +from .type_definitions import IndexSelection, KeysCollection diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index c70d495555..9e448a9ac3 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -38,6 +38,16 @@ psutil, has_psutil = optional_import("psutil") psutil_version = psutil.__version__ if has_psutil else "NOT INSTALLED or UNKNOWN VERSION." 
+__all__ = [ + "print_config", + "get_system_info", + "print_system_info", + "get_gpu_info", + "print_gpu_info", + "print_debug_info", + "USE_COMPILED", +] + def get_config_values(): """ @@ -113,11 +123,11 @@ def get_system_info() -> OrderedDict: """ output: OrderedDict = OrderedDict() - _dict_append(output, "System", lambda: platform.system()) + _dict_append(output, "System", platform.system) if output["System"] == "Windows": - _dict_append(output, "Win32 version", lambda: platform.win32_ver()) + _dict_append(output, "Win32 version", platform.win32_ver) if hasattr(platform, "win32_edition"): - _dict_append(output, "Win32 edition", lambda: platform.win32_edition()) # type:ignore[attr-defined] + _dict_append(output, "Win32 edition", platform.win32_edition) # type:ignore[attr-defined] elif output["System"] == "Darwin": _dict_append(output, "Mac version", lambda: platform.mac_ver()[0]) else: @@ -125,19 +135,19 @@ def get_system_info() -> OrderedDict: if linux_ver: _dict_append(output, "Linux version", lambda: linux_ver.group(1)) - _dict_append(output, "Platform", lambda: platform.platform()) - _dict_append(output, "Processor", lambda: platform.processor()) - _dict_append(output, "Machine", lambda: platform.machine()) - _dict_append(output, "Python version", lambda: platform.python_version()) + _dict_append(output, "Platform", platform.platform) + _dict_append(output, "Processor", platform.processor) + _dict_append(output, "Machine", platform.machine) + _dict_append(output, "Python version", platform.python_version) if not has_psutil: _dict_append(output, "`psutil` missing", lambda: "run `pip install monai[psutil]`") else: p = psutil.Process() with p.oneshot(): - _dict_append(output, "Process name", lambda: p.name()) - _dict_append(output, "Command", lambda: p.cmdline()) - _dict_append(output, "Open files", lambda: p.open_files()) + _dict_append(output, "Process name", p.name) + _dict_append(output, "Command", p.cmdline) + _dict_append(output, "Open files", p.open_files) _dict_append(output, "Num physical CPUs", lambda: psutil.cpu_count(logical=False)) _dict_append(output, "Num logical CPUs", lambda: psutil.cpu_count(logical=True)) _dict_append(output, "Num usable CPUs", lambda: len(psutil.Process().cpu_affinity())) @@ -186,27 +196,34 @@ def get_gpu_info() -> OrderedDict: _dict_append(output, "Num GPUs", lambda: num_gpus) _dict_append(output, "Has CUDA", lambda: bool(torch.cuda.is_available())) + if output["Has CUDA"]: _dict_append(output, "CUDA version", lambda: torch.version.cuda) cudnn_ver = torch.backends.cudnn.version() _dict_append(output, "cuDNN enabled", lambda: bool(cudnn_ver)) + if cudnn_ver: _dict_append(output, "cuDNN version", lambda: cudnn_ver) if num_gpus > 0: - _dict_append(output, "Current device", lambda: torch.cuda.current_device()) - _dict_append(output, "Library compiled for CUDA architectures", lambda: torch.cuda.get_arch_list()) + _dict_append(output, "Current device", torch.cuda.current_device) + if hasattr(torch.cuda, "get_arch_list"): # get_arch_list is new in torch 1.7.1 + _dict_append(output, "Library compiled for CUDA architectures", torch.cuda.get_arch_list) + for gpu in range(num_gpus): - _dict_append(output, "Info for GPU", gpu) gpu_info = torch.cuda.get_device_properties(gpu) - _dict_append(output, "\tName", lambda: gpu_info.name) - _dict_append(output, "\tIs integrated", lambda: bool(gpu_info.is_integrated)) - _dict_append(output, "\tIs multi GPU board", lambda: bool(gpu_info.is_multi_gpu_board)) - _dict_append(output, "\tMulti processor count", lambda: 
gpu_info.multi_processor_count) - _dict_append(output, "\tTotal memory (GB)", lambda: round(gpu_info.total_memory / 1024 ** 3, 1)) - _dict_append(output, "\tCached memory (GB)", lambda: round(torch.cuda.memory_reserved(gpu) / 1024 ** 3, 1)) - _dict_append(output, "\tAllocated memory (GB)", lambda: round(torch.cuda.memory_allocated(gpu) / 1024 ** 3, 1)) - _dict_append(output, "\tCUDA capability (maj.min)", lambda: f"{gpu_info.major}.{gpu_info.minor}") + _dict_append(output, f"GPU {gpu} Name", lambda: gpu_info.name) + _dict_append(output, f"GPU {gpu} Is integrated", lambda: bool(gpu_info.is_integrated)) + _dict_append(output, f"GPU {gpu} Is multi GPU board", lambda: bool(gpu_info.is_multi_gpu_board)) + _dict_append(output, f"GPU {gpu} Multi processor count", lambda: gpu_info.multi_processor_count) + _dict_append(output, f"GPU {gpu} Total memory (GB)", lambda: round(gpu_info.total_memory / 1024 ** 3, 1)) + _dict_append( + output, f"GPU {gpu} Cached memory (GB)", lambda: round(torch.cuda.memory_reserved(gpu) / 1024 ** 3, 1) + ) + _dict_append( + output, f"GPU {gpu} Allocated memory (GB)", lambda: round(torch.cuda.memory_allocated(gpu) / 1024 ** 3, 1) + ) + _dict_append(output, f"GPU {gpu} CUDA capability (maj.min)", lambda: f"{gpu_info.major}.{gpu_info.minor}") return output diff --git a/monai/config/type_definitions.py b/monai/config/type_definitions.py index ecf08af107..ea0c72576c 100644 --- a/monai/config/type_definitions.py +++ b/monai/config/type_definitions.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,6 +11,8 @@ from typing import Collection, Hashable, Iterable, Union +__all__ = ["KeysCollection", "IndexSelection"] + """Commonly used concepts This module provides naming and type specifications for commonly used concepts within the MONAI package. The intent is to explicitly identify information diff --git a/monai/csrc/ext.cpp b/monai/csrc/ext.cpp index 5aaa2e70c9..2e0644bc78 100644 --- a/monai/csrc/ext.cpp +++ b/monai/csrc/ext.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 MONAI Consortium +Copyright 2020 - 2021 MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -12,11 +12,17 @@ limitations under the License. */ #include + +#include "filtering/filtering.h" #include "lltm/lltm.h" #include "resample/pushpull.h" #include "utils/resample_utils.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // filtering + m.def("bilateral_filter", &BilateralFilter, "Bilateral Filter"); + m.def("phl_filter", &PermutohedralFilter, "Permutohedral Filter"); + // lltm m.def("lltm_forward", &lltm_forward, "LLTM forward"); m.def("lltm_backward", &lltm_backward, "LLTM backward"); diff --git a/monai/csrc/filtering/bilateral/bilateral.h b/monai/csrc/filtering/bilateral/bilateral.h new file mode 100644 index 0000000000..1c16373fa9 --- /dev/null +++ b/monai/csrc/filtering/bilateral/bilateral.h @@ -0,0 +1,42 @@ +/* +Copyright 2020 - 2021 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +#include +#include "utils/common_utils.h" + +torch::Tensor BilateralFilterCpu(torch::Tensor input, float spatial_sigma, float color_sigma); +torch::Tensor BilateralFilterPHLCpu(torch::Tensor input, float spatial_sigma, float color_sigma); + +#ifdef WITH_CUDA +torch::Tensor BilateralFilterCuda(torch::Tensor input, float spatial_sigma, float color_sigma); +torch::Tensor BilateralFilterPHLCuda(torch::Tensor input, float spatial_sigma, float color_sigma); +#endif + +torch::Tensor BilateralFilter(torch::Tensor input, float spatial_sigma, float color_sigma, bool usePHL) { + torch::Tensor (*filterFunction)(torch::Tensor, float, float); + +#ifdef WITH_CUDA + if (torch::cuda::is_available() && input.is_cuda()) { + CHECK_CONTIGUOUS_CUDA(input); + filterFunction = usePHL ? &BilateralFilterPHLCuda : &BilateralFilterCuda; + } else { + filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu; + } +#else + filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu; +#endif + + return filterFunction(input, spatial_sigma, color_sigma); +} diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp new file mode 100644 index 0000000000..474d24b4fa --- /dev/null +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp @@ -0,0 +1,167 @@ +/* +Copyright 2020 - 2021 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include + +#include "utils/tensor_description.h" + +struct Indexer { + public: + Indexer(int dimensions, int* sizes) { + m_dimensions = dimensions; + m_sizes = sizes; + m_index = new int[dimensions]{0}; + } + + bool operator++(int) { + for (int i = 0; i < m_dimensions; i++) { + m_index[i] += 1; + + if (m_index[i] < m_sizes[i]) { + return true; + } else { + m_index[i] = 0; + } + } + + return false; + } + + int& operator[](int dimensionIndex) { + return m_index[dimensionIndex]; + } + + private: + int m_dimensions; + int* m_sizes; + int* m_index; +}; + +template +void BilateralFilterCpu(torch::Tensor inputTensor, torch::Tensor outputTensor, float spatialSigma, float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Raw tensor data pointers. 
+ scalar_t* inputTensorData = inputTensor.data_ptr(); + scalar_t* outputTensorData = outputTensor.data_ptr(); + + // Pre-calculate common values + int windowSize = (int)ceil(5.0f * spatialSigma) | 1; // ORing last bit to ensure odd window size + int halfWindowSize = floor(0.5f * windowSize); + scalar_t spatialExpConstant = -1.0f / (2 * spatialSigma * spatialSigma); + scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + // Kernel sizes. + int* kernelSizes = new int[desc.dimensions]; + + for (int i = 0; i < desc.dimensions; i++) { + kernelSizes[i] = windowSize; + } + + // Pre-calculate gaussian kernel in 1D. + scalar_t* gaussianKernel = new scalar_t[windowSize]; + + for (int i = 0; i < windowSize; i++) { + int distance = i - halfWindowSize; + gaussianKernel[i] = exp(distance * distance * spatialExpConstant); + } + + // Kernel aggregates used to calculate + // the output value. + scalar_t* valueSum = new scalar_t[desc.channelCount]; + scalar_t weightSum = 0; + + // Looping over the batches + for (int b = 0; b < desc.batchCount; b++) { + int batchOffset = b * desc.batchStride; + + // Looping over all dimensions for the home element + Indexer homeIndex = Indexer(desc.dimensions, desc.sizes); + do // while(homeIndex++) + { + // Calculating indexing offset for the home element + int homeOffset = batchOffset; + + for (int i = 0; i < desc.dimensions; i++) { + homeOffset += homeIndex[i] * desc.strides[i]; + } + + // Zero kernel aggregates. + for (int i = 0; i < desc.channelCount; i++) { + valueSum[i] = 0; + } + + weightSum = 0.0f; + + // Looping over all dimensions for the neighbour element + Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); + do // while(kernelIndex++) + { + // Calculating buffer offset for the neighbour element + // Index is clamped to the border in each dimension. + int neighbourOffset = batchOffset; + + for (int i = 0; i < desc.dimensions; i++) { + int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize; + int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex)); + neighbourOffset += neighbourIndexClamped * desc.strides[i]; + } + + // Euclidean color distance. + scalar_t colorDistanceSquared = 0; + + for (int i = 0; i < desc.channelCount; i++) { + scalar_t diff = inputTensorData[homeOffset + i * desc.channelStride] - + inputTensorData[neighbourOffset + i * desc.channelStride]; + colorDistanceSquared += diff * diff; + } + + // Calculating and combining the spatial + // and color weights. + scalar_t spatialWeight = 1; + + for (int i = 0; i < desc.dimensions; i++) { + spatialWeight *= gaussianKernel[kernelIndex[i]]; + } + + scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. + for (int i = 0; i < desc.channelCount; i++) { + valueSum[i] += inputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight; + } + + weightSum += totalWeight; + } while (kernelIndex++); + + for (int i = 0; i < desc.channelCount; i++) { + outputTensorData[homeOffset + i * desc.channelStride] = valueSum[i] / weightSum; + } + } while (homeIndex++); + } +} + +torch::Tensor BilateralFilterCpu(torch::Tensor inputTensor, float spatialSigma, float colorSigma) { + // Preparing output tensor. 
+ torch::Tensor outputTensor = torch::zeros_like(inputTensor); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.scalar_type(), "BilateralFilterCpu", ([&] { + BilateralFilterCpu( + inputTensor, outputTensor, spatialSigma, colorSigma); + })); + + return outputTensor; +} \ No newline at end of file diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp new file mode 100644 index 0000000000..1fb48cb6c9 --- /dev/null +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp @@ -0,0 +1,88 @@ +/* +Copyright 2020 - 2021 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include + +#include "filtering/permutohedral/permutohedral.h" +#include "utils/tensor_description.h" + +template +void BilateralFilterPHLCpu( + torch::Tensor inputTensor, + torch::Tensor outputTensor, + float spatialSigma, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + int featureChannels = desc.channelCount + desc.dimensions; + + // Preparing memory + scalar_t* inputTensorData = inputTensor.data_ptr(); + scalar_t* outputTensorData = outputTensor.data_ptr(); + scalar_t* data = new scalar_t[desc.channelStride * desc.channelCount]; + scalar_t* features = new scalar_t[desc.channelStride * featureChannels]; + + // Precalculating inverse sigmas + float invSpatialSigma = 1.0f / spatialSigma; + float invColorSigma = 1.0f / colorSigma; + + // Looping over batches + for (int b = 0; b < desc.batchCount; b++) { + int batchOffset = b * desc.batchStride; + + // Creating features (also permuting input data to be channel last. Permutohedral + // implementation should be changed to channel first to avoid this) + for (int i = 0; i < desc.channelStride; i++) { + // Color features (and permutation) + for (int c = 0; c < desc.channelCount; c++) { + features[i * featureChannels + c] = invColorSigma * inputTensorData[batchOffset + i + c * desc.channelStride]; + data[i * desc.channelCount + c] = inputTensorData[batchOffset + i + c * desc.channelStride]; + } + + // Spatial features + int offsetRemanider = i; + + for (int d = 0; d < desc.dimensions; d++) { + int coord = offsetRemanider / desc.strides[d]; + offsetRemanider -= coord * desc.strides[d]; + + features[i * featureChannels + desc.channelCount + d] = invSpatialSigma * coord; + } + } + + // Filtering data with respect to the features. + PermutohedralCPU(data, features, desc.channelCount, featureChannels, desc.channelStride); + + // Writing output tensor. 
+ for (int i = 0; i < desc.channelStride; i++) { + for (int c = 0; c < desc.channelCount; c++) { + outputTensorData[batchOffset + i + c * desc.channelStride] = data[i * desc.channelCount + c]; + } + } + } + + delete[] data; + delete[] features; +} + +// Function to choose template implementation based on dynamic, channels and dimensions +torch::Tensor BilateralFilterPHLCpu(torch::Tensor inputTensor, float spatialSigma, float colorSigma) { + torch::Tensor outputTensor = torch::zeros_like(inputTensor); + + AT_DISPATCH_FLOATING_TYPES(inputTensor.scalar_type(), "BilateralFilterPhlCpu", ([&] { + BilateralFilterPHLCpu(inputTensor, outputTensor, spatialSigma, colorSigma); + })); + + return outputTensor; +} diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu new file mode 100644 index 0000000000..4477ce5845 --- /dev/null +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu @@ -0,0 +1,259 @@ +/* +Copyright 2020 - 2021 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include + +#include "utils/meta_macros.h" +#include "utils/tensor_description.h" + +__constant__ int cBatchStride; +__constant__ int cColorStride; + +__constant__ int cSizes[3]; +__constant__ int cStrides[3]; + +__constant__ int cKernelSize; +__constant__ float cKernel[256]; + +__constant__ float cColorExponentFactor; + +template +__global__ void BilateralFilterCudaKernel1D(scalar_t* input, scalar_t* output) { + int kernelHalfSize = cKernelSize / 2; + + int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; + int batchOffset = blockIdx.y * cBatchStride; + + if (homeOffset >= cColorStride) + return; + + scalar_t weightSum = 0; + + for (int kernelOffset = 0; kernelOffset < cKernelSize; kernelOffset++) { + int neighbourOffset = max(0, min(homeOffset + (kernelOffset - kernelHalfSize), cSizes[0] - 1)); + scalar_t gaussian = cKernel[kernelOffset]; + + scalar_t distanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + homeOffset + c * cColorStride]; + scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride]; + scalar_t diff = a - b; + distanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussian; + scalar_t colorWeight = exp(cColorExponentFactor * distanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + neighbourOffset + c * cColorStride]; + + output[batchOffset + homeOffset + c * cColorStride] += a * totalWeight; + } + + weightSum += totalWeight; + } + +#pragma unroll + for (int c = 0; c < C; c++) { + output[batchOffset + homeOffset + c * cColorStride] /= weightSum; + } +} + +template +__global__ void BilateralFilterCudaKernel2D(scalar_t* input, scalar_t* output) { + int kernelHalfSize = cKernelSize / 2; + + int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; + int batchOffset = blockIdx.y * cBatchStride; + + if (homeOffset >= cColorStride) + return; + + int homeX = 
homeOffset / cStrides[0]; + int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1]; + + scalar_t weightSum = 0; + + for (int kernelX = 0; kernelX < cKernelSize; kernelX++) { + int neighbourX = max(0, min(homeX + (kernelX - kernelHalfSize), cSizes[0] - 1)); + scalar_t gaussianX = cKernel[kernelX]; + + for (int kernelY = 0; kernelY < cKernelSize; kernelY++) { + int neighbourY = max(0, min(homeY + (kernelY - kernelHalfSize), cSizes[1] - 1)); + scalar_t gaussianY = cKernel[kernelY]; + + int neighbourOffset = neighbourX * cStrides[0] + neighbourY; + + scalar_t distanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + homeOffset + c * cColorStride]; + scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride]; + scalar_t diff = a - b; + distanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussianX * gaussianY; + scalar_t colorWeight = exp(cColorExponentFactor * distanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + neighbourOffset + c * cColorStride]; + + output[batchOffset + homeOffset + c * cColorStride] += a * totalWeight; + } + + weightSum += totalWeight; + } + } + +#pragma unroll + for (int c = 0; c < C; c++) { + output[batchOffset + homeOffset + c * cColorStride] /= weightSum; + } +} + +template +__global__ void BilateralFilterCudaKernel3D(scalar_t* input, scalar_t* output) { + int kernelHalfSize = cKernelSize / 2; + + int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; + int batchOffset = blockIdx.y * cBatchStride; + + if (homeOffset >= cColorStride) + return; + + int homeX = homeOffset / cStrides[0]; + int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1]; + int homeZ = (homeOffset - homeX * cStrides[0] - homeY * cStrides[1]) / cStrides[2]; + + scalar_t weightSum = 0; + + for (int kernelX = 0; kernelX < cKernelSize; kernelX++) { + int neighbourX = max(0, min(homeX + (kernelX - kernelHalfSize), cSizes[0] - 1)); + scalar_t gaussianX = cKernel[kernelX]; + + for (int kernelY = 0; kernelY < cKernelSize; kernelY++) { + int neighbourY = max(0, min(homeY + (kernelY - kernelHalfSize), cSizes[1] - 1)); + scalar_t gaussianY = cKernel[kernelY]; + + for (int kernelZ = 0; kernelZ < cKernelSize; kernelZ++) { + int neighbourZ = max(0, min(homeZ + (kernelZ - kernelHalfSize), cSizes[2] - 1)); + scalar_t gaussianZ = cKernel[kernelZ]; + + int neighbourOffset = neighbourX * cStrides[0] + neighbourY * cStrides[1] + neighbourZ; + + scalar_t distanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + homeOffset + c * cColorStride]; + scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride]; + scalar_t diff = a - b; + distanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ; + scalar_t colorWeight = exp(cColorExponentFactor * distanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + neighbourOffset + c * cColorStride]; + output[batchOffset + homeOffset + c * cColorStride] += a * totalWeight; + } + + weightSum += totalWeight; + } + } + } + +#pragma unroll + for (int c = 0; c < C; c++) { + output[batchOffset + homeOffset + c * cColorStride] /= weightSum; + } +} + +template +void BilateralFilterCuda(torch::Tensor inputTensor, torch::Tensor outputTensor, float spatialSigma, float colorSigma) { + // Getting tensor description. 
+ TensorDescription desc = TensorDescription(inputTensor); + + // Pre-calculating exponent factors. + float spatialExponentFactor = -1.0f / (2 * spatialSigma * spatialSigma); + float colorExponentFactor = -1.0f / (2 * colorSigma * colorSigma); + + // Pre-calculating gaussian kernel. + int kernelSize = (int)ceil(5.0f * spatialSigma) | 1; // ORing last bit to ensure odd window size + int kernelHalfSize = floor(0.5f * kernelSize); + float* kernel = new float[kernelSize]; + + for (int i = 0; i < kernelSize; i++) { + int distance = i - kernelHalfSize; + kernel[i] = exp(distance * distance * spatialExponentFactor); + } + + // Writing constant memory. + cudaMemcpyToSymbol(cBatchStride, &desc.batchStride, sizeof(int)); + cudaMemcpyToSymbol(cColorStride, &desc.channelStride, sizeof(int)); + cudaMemcpyToSymbol(cSizes, desc.sizes, sizeof(int) * D); + cudaMemcpyToSymbol(cStrides, desc.strides, sizeof(int) * D); + cudaMemcpyToSymbol(cKernelSize, &kernelSize, sizeof(int)); + cudaMemcpyToSymbol(cKernel, kernel, sizeof(float) * kernelSize); + cudaMemcpyToSymbol(cColorExponentFactor, &colorExponentFactor, sizeof(float)); + +#define BLOCK_SIZE 32 + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + inputTensor.scalar_type(), "BilateralFilterCudaKernel", ([&] { + // Dispatch kernel. (Partial template function specialisation not supported at present so using this switch + // instead) + switch (D) { + case (1): + BilateralFilterCudaKernel1D + <<>>( + inputTensor.data_ptr(), outputTensor.data_ptr()); + break; + case (2): + BilateralFilterCudaKernel2D + <<>>( + inputTensor.data_ptr(), outputTensor.data_ptr()); + break; + case (3): + BilateralFilterCudaKernel3D + <<>>( + inputTensor.data_ptr(), outputTensor.data_ptr()); + break; + } + })); + + delete[] kernel; +} + +// Function to choose template implementation based on dynamic, channels and dimensions +torch::Tensor BilateralFilterCuda(torch::Tensor inputTensor, float spatialSigma, float colorSigma) { + torch::Tensor outputTensor = torch::zeros_like(inputTensor); + +#define CASE(c, d) BilateralFilterCuda(inputTensor, outputTensor, spatialSigma, colorSigma); + SWITCH_AB(CASE, 16, 3, inputTensor.size(1), inputTensor.dim() - 2); + + return outputTensor; +} diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu new file mode 100644 index 0000000000..603ab689cf --- /dev/null +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu @@ -0,0 +1,141 @@ +/* +Copyright 2020 - 2021 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <torch/extension.h>
+
+#include "filtering/permutohedral/permutohedral.h"
+#include "utils/meta_macros.h"
+#include "utils/tensor_description.h"
+
+__constant__ int cBatchStride;
+__constant__ int cChannelStride;
+__constant__ int cSpatialStrides[3];
+__constant__ float cInvSpatialSigma;
+__constant__ float cInvColorSigma;
+
+template <typename scalar_t, int C, int D>
+__global__ void FeatureCreation(const scalar_t* inputTensor, scalar_t* outputData, scalar_t* outputFeatures) {
+  int elementIndex = blockIdx.x * blockDim.x + threadIdx.x;
+  int batchIndex = blockIdx.y;
+
+  if (elementIndex >= cChannelStride)
+    return;
+
+  int dataBatchOffset = batchIndex * cBatchStride;
+  int featureBatchOffset = batchIndex * (D + C) * cChannelStride;
+
+#pragma unroll
+  for (int i = 0; i < C; i++) {
+    outputData[dataBatchOffset + elementIndex * C + i] =
+        inputTensor[dataBatchOffset + elementIndex + i * cChannelStride];
+    outputFeatures[featureBatchOffset + elementIndex * (C + D) + i] =
+        inputTensor[dataBatchOffset + elementIndex + i * cChannelStride] * cInvColorSigma;
+  }
+
+  int remainder = elementIndex;
+
+#pragma unroll
+  for (int i = 0; i < D; i++) {
+    int coord = remainder / cSpatialStrides[i];
+    remainder -= coord * cSpatialStrides[i];
+
+    outputFeatures[featureBatchOffset + elementIndex * (C + D) + C + i] = coord * cInvSpatialSigma;
+  }
+}
+
+template <typename scalar_t, int C>
+__global__ void WriteOutput(const scalar_t* data, scalar_t* outputTensor) {
+  int elementIndex = blockIdx.x * blockDim.x + threadIdx.x;
+  int batchIndex = blockIdx.y;
+
+  if (elementIndex >= cChannelStride)
+    return;
+
+  int batchOffset = batchIndex * cBatchStride;
+
+#pragma unroll
+  for (int i = 0; i < C; i++) {
+    outputTensor[batchOffset + elementIndex + i * cChannelStride] = data[batchOffset + elementIndex * C + i];
+  }
+}
+
+template <typename scalar_t, int C, int D>
+void BilateralFilterPHLCuda(
+    torch::Tensor inputTensor,
+    torch::Tensor outputTensor,
+    float spatialSigma,
+    float colorSigma) {
+  // Getting tensor description.
+  TensorDescription desc = TensorDescription(inputTensor);
+
+  int featureChannelCount = desc.channelCount + desc.dimensions;
+
+  // Pre-calculating inverse sigmas.
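+  // Scaling each feature by 1/sigma means the single (roughly unit-variance) Gaussian applied in the
+  // permutohedral lattice realises both the spatial and the color kernel at once.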
+  float invSpatialSigma = 1.0f / spatialSigma;
+  float invColorSigma = 1.0f / colorSigma;
+
+  // Preparing global memory
+  scalar_t* inputTensorData = inputTensor.data_ptr<scalar_t>();
+  scalar_t* outputTensorData = outputTensor.data_ptr<scalar_t>();
+
+  scalar_t* data;
+  scalar_t* features;
+  cudaMalloc(&data, desc.batchCount * desc.channelStride * desc.channelCount * sizeof(scalar_t));
+  cudaMalloc(&features, desc.batchCount * desc.channelStride * featureChannelCount * sizeof(scalar_t));
+
+  // Preparing constant memory
+  cudaMemcpyToSymbol(cBatchStride, &desc.batchStride, sizeof(int));
+  cudaMemcpyToSymbol(cChannelStride, &desc.channelStride, sizeof(int));
+  cudaMemcpyToSymbol(cSpatialStrides, desc.strides, sizeof(int) * desc.dimensions);
+  cudaMemcpyToSymbol(cInvSpatialSigma, &invSpatialSigma, sizeof(float));
+  cudaMemcpyToSymbol(cInvColorSigma, &invColorSigma, sizeof(float));
+
+#define BLOCK_SIZE 32
+
+  // Creating features
+  FeatureCreation<scalar_t, C, D>
+      <<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(
+          inputTensorData, data, features);
+
+  // Filtering data with respect to the features for each sample in batch
+  for (int batchIndex = 0; batchIndex < desc.batchCount; batchIndex++) {
+    scalar_t* offsetData = data + batchIndex * desc.batchStride;
+    scalar_t* offsetFeatures = features + batchIndex * featureChannelCount * desc.channelStride;
+
+    PermutohedralCuda<scalar_t, C + D, C>(offsetData, offsetFeatures, desc.channelStride, true);
+  }
+
+  // Writing output
+  WriteOutput<scalar_t, C><<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(
+      data, outputTensorData);
+
+  cudaFree(data);
+  cudaFree(features);
+}
+
+// Function to choose template implementation based on dynamic channel and dimension counts
+torch::Tensor BilateralFilterPHLCuda(torch::Tensor inputTensor, float spatialSigma, float colorSigma) {
+  torch::Tensor outputTensor = torch::zeros_like(inputTensor);
+
+#define CASE(c, d)                                                                          \
+  AT_DISPATCH_FLOATING_TYPES(inputTensor.scalar_type(), "BilateralFilterCudaPHL", ([&] {   \
+                               BilateralFilterPHLCuda<scalar_t, c, d>(                     \
+                                   inputTensor, outputTensor, spatialSigma, colorSigma);   \
+                             }));
+
+  SWITCH_AB(CASE, 16, 3, inputTensor.size(1), inputTensor.dim() - 2);
+
+  return outputTensor;
}
diff --git a/monai/csrc/filtering/filtering.h b/monai/csrc/filtering/filtering.h
new file mode 100644
index 0000000000..25186b182a
--- /dev/null
+++ b/monai/csrc/filtering/filtering.h
@@ -0,0 +1,17 @@
+/*
+Copyright 2020 - 2021 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#pragma once
+
+#include "bilateral/bilateral.h"
+#include "permutohedral/permutohedral.h"
\ No newline at end of file
diff --git a/monai/csrc/filtering/permutohedral/hash_table.cuh b/monai/csrc/filtering/permutohedral/hash_table.cuh
new file mode 100644
index 0000000000..7d9d7eb163
--- /dev/null
+++ b/monai/csrc/filtering/permutohedral/hash_table.cuh
@@ -0,0 +1,260 @@
+/*
+Copyright 2020 - 2021 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include + +//#define USE_ADDITIVE_HASH + +// turn this on if you want to get slighly less memory consumption and slightly longer run times. +//#define LINEAR_D_MEMORY + +#define USE_CUSTOM_MODULO + +__device__ __constant__ signed short* table_keys; +__device__ __constant__ int* table_entries; +__device__ __constant__ unsigned int table_capacity; +__device__ __constant__ signed short* table_zeros; +__device__ __constant__ char* table_rank; + +/*************************************************************/ +/* Fast computation of modulo operator with constant divisor */ +/*************************************************************/ +__device__ __constant__ unsigned int __div_m; +__device__ __constant__ unsigned int __div_l; +__device__ __constant__ unsigned int __div_c; + +#ifdef USE_CUSTOM_MODULO +__device__ inline unsigned int modHash(unsigned int n) { + unsigned int t1 = __umulhi(__div_m, n); + return n - ((t1 + ((n - t1) >> 1)) >> (__div_l - 1)) * __div_c; +} + +#else +#define modHash(n) ((n) % (2 * table_capacity)); +#endif + +/*************************************************************/ +/* End modulo */ +/*************************************************************/ + +__device__ __constant__ static unsigned int hOffset[64]; + +template +static scalar_t* createHashTable(int capacity) { + scalar_t* values; + cudaMalloc(&values, capacity * vd * sizeof(scalar_t)); + cudaMemset(values, 0, capacity * vd * sizeof(scalar_t)); + + int* entries; + cudaMalloc(&entries, capacity * 2 * sizeof(int)); + cudaMemset(entries, -1, capacity * 2 * sizeof(int)); + + cudaMemcpyToSymbol(table_capacity, &capacity, sizeof(int)); + + cudaMemcpyToSymbol(table_entries, &entries, sizeof(int*)); + +#ifdef LINEAR_D_MEMORY + + char* ranks; + cudaMalloc(&ranks, capacity * sizeof(char)); + + signed short* zeros; + cudaMalloc(&zeros, capacity * sizeof(signed short)); + + cudaMemcpyToSymbol(table_rank, &ranks, sizeof(char*)); + cudaMemcpyToSymbol(table_zeros, &zeros, sizeof(char*)); + +#else + + signed short* keys; + cudaMalloc(&keys, capacity * kd * sizeof(signed short)); + cudaMemset(keys, 0, capacity * kd * sizeof(signed short)); + + cudaMemcpyToSymbol(table_keys, &keys, sizeof(unsigned int*)); + +#endif + + return values; +} + +template +static void destroyHashTable() { +#ifndef LINEAR_D_MEMORY + signed short* keys; + cudaMemcpyFromSymbol(&keys, table_keys, sizeof(unsigned int*)); + cudaFree(keys); +#endif + + int* entries; + cudaMemcpyFromSymbol(&entries, table_entries, sizeof(int*)); + cudaFree(entries); +} + +template +__device__ __host__ static unsigned int hash(signed short* key) { + unsigned int k = 0; + for (int i = 0; i < kd; i++) { + k += key[i]; + k = k * 2531011; + } + return k; +} + +template +__device__ __host__ static unsigned int hash(int* key) { + unsigned int k = 0; + for (int i = 0; i < kd; i++) { + k += key[i]; + k = k * 2531011; + } + return k; +} + +template +__device__ static bool matchKey(int idx, signed short* key) { + bool match = true; + int slot = idx / (d + 1), color = idx - slot * (d + 1); + char* rank = table_rank + slot * (d + 1); + signed short* zero = table_zeros + slot * (d + 1); 
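+  // Reconstruct each coordinate of the stored key from its zero-remainder point, rank and color,
+  // then compare it against the query key.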
+ + for (int i = 0; i < d && match; i++) { + match = (key[i] == zero[i] + color - (rank[i] > d - color ? (d + 1) : 0)); + } + + return match; +} + +template +__device__ static void generateKey(int idx, signed short* key) { + int slot = idx / (d + 1), color = idx - slot * (d + 1); + char* rank = table_rank + slot * (d + 1); + signed short* zero = table_zeros + slot * (d + 1); + + for (int i = 0; i < d; i++) { + key[i] = zero[i] + color - (rank[i] > d - color ? (d + 1) : 0); + } +} + +template +__device__ static int hashTableInsert(unsigned int fh, signed short* key, unsigned int slot) { + int h = modHash(fh); + while (1) { + int* e = &table_entries[h]; + + // If the cell is empty (-1), lock it (-2) + int contents = atomicCAS(e, -1, -2); + + if (contents == -2) { + // If it was locked already, move on to the next cell + } else if (contents == -1) { + // If it was empty, we successfully locked it. Write our key. + +#ifndef LINEAR_D_MEMORY + for (int i = 0; i < kd; i++) { + table_keys[slot * kd + i] = key[i]; + } +#endif + + // Unlock + atomicExch(e, slot); + + return h; + } else { +// The cell is unlocked and has a key in it, check if it matches +#ifdef LINEAR_D_MEMORY + if (matchKey(contents, key)) + return h; +#else + bool match = true; + + for (int i = 0; i < kd && match; i++) { + match = (table_keys[contents * kd + i] == key[i]); + } + + if (match) + return h; +#endif + } + // increment the bucket with wraparound + h++; + + if (h == table_capacity * 2) + h = 0; + } +} + +template +__device__ static int hashTableInsert(signed short* key, unsigned int slot) { + unsigned int myHash = hash(key); + return hashTableInsert(myHash, key, slot); +} + +template +__device__ static int hashTableRetrieveWithHash(unsigned int fh, signed short* key) { + int h = modHash(fh); + while (1) { + int* e = table_entries + h; + + if (*e == -1) + return -1; + +#ifdef LINEAR_D_MEMORY + if (matchKey((*e), key)) + return *e; +#else + bool match = true; + + for (int i = 0; i < kd && match; i++) { + match = (table_keys[(*e) * kd + i] == key[i]); + } + + if (match) + return *e; +#endif + + h++; + + if (h == table_capacity * 2) + h = 0; + } +} + +template +__device__ static int hashTableRetrieve(signed short* key) { + int h = modHash(hash(key)); + while (1) { + int* e = table_entries + h; + + if (*e == -1) + return -1; + +#ifdef LINEAR_D_MEMORY + if (matchKey((*e), key)) + return *e; +#else + bool match = true; + + for (int i = 0; i < kd && match; i++) { + match = (table_keys[(*e) * kd + i] == key[i]); + } + + if (match) + return *e; +#endif + + h++; + + if (h == table_capacity * 2) + h = 0; + } +} \ No newline at end of file diff --git a/monai/csrc/filtering/permutohedral/permutohedral.cpp b/monai/csrc/filtering/permutohedral/permutohedral.cpp new file mode 100644 index 0000000000..5d6916b8f4 --- /dev/null +++ b/monai/csrc/filtering/permutohedral/permutohedral.cpp @@ -0,0 +1,71 @@ +#include "utils/common_utils.h" +#include "utils/meta_macros.h" + +#include "permutohedral.h" + +torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) { + input = input.contiguous(); + + int batchCount = input.size(0); + int batchStride = input.stride(0); + int elementCount = input.stride(1); + int channelCount = input.size(1); + int featureCount = features.size(1); + +// movedim not support in torch < 1.7.1 +#if MONAI_TORCH_VERSION >= 10701 + torch::Tensor data = input.clone().movedim(1, -1).contiguous(); + features = features.movedim(1, -1).contiguous(); +#else + torch::Tensor data = input.clone(); + features = 
features;
+
+  for (int i = 1; i < input.dim() - 1; i++) {
+    data = data.transpose(i, i + 1);
+    features = features.transpose(i, i + 1);
+  }
+
+  data = data.contiguous();
+  features = features.contiguous();
+#endif
+
+#ifdef WITH_CUDA
+  if (torch::cuda::is_available() && data.is_cuda()) {
+    CHECK_CONTIGUOUS_CUDA(data);
+
+#define CASE(dc, fc)                                                                             \
+  AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCuda", ([&] {                    \
+                               for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) { \
+                                 scalar_t* offsetData = data.data_ptr<scalar_t>() + batchIndex * batchStride; \
+                                 scalar_t* offsetFeatures =                                      \
+                                     features.data_ptr<scalar_t>() + batchIndex * fc * elementCount; \
+                                 PermutohedralCuda<scalar_t, fc, dc>(offsetData, offsetFeatures, elementCount, true); \
+                               }                                                                 \
+                             }));
+    SWITCH_AB(CASE, 16, 19, channelCount, featureCount);
+
+  } else {
+#endif
+    AT_DISPATCH_FLOATING_TYPES(
+        data.scalar_type(), "PermutohedralCPU", ([&] {
+          for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) {
+            scalar_t* offsetData = data.data_ptr<scalar_t>() + batchIndex * batchStride;
+            scalar_t* offsetFeatures = features.data_ptr<scalar_t>() + batchIndex * featureCount * elementCount;
+            PermutohedralCPU<scalar_t>(offsetData, offsetFeatures, channelCount, featureCount, elementCount);
+          }
+        }));
+#ifdef WITH_CUDA
+  }
+#endif
+
+// movedim not supported in torch < 1.7.1
+#if MONAI_TORCH_VERSION >= 10701
+  data = data.movedim(-1, 1);
+#else
+  for (int i = input.dim() - 1; i > 1; i--) {
+    data = data.transpose(i - 1, i);
+  }
+#endif
+
+  return data;
+}
diff --git a/monai/csrc/filtering/permutohedral/permutohedral.h b/monai/csrc/filtering/permutohedral/permutohedral.h
new file mode 100644
index 0000000000..27b0ff4859
--- /dev/null
+++ b/monai/csrc/filtering/permutohedral/permutohedral.h
@@ -0,0 +1,24 @@
+/*
+Copyright 2020 - 2021 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <torch/extension.h>
+
+#pragma once
+template <typename scalar_t>
+void PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount);
+#ifdef WITH_CUDA
+template <typename scalar_t, int pd, int vd>
+void PermutohedralCuda(scalar_t* data, scalar_t* features, int elementCount, bool accurate);
+#endif
+
+torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features);
\ No newline at end of file
diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp
new file mode 100644
index 0000000000..0876997448
--- /dev/null
+++ b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp
@@ -0,0 +1,502 @@
+/*
+Copyright 2020 - 2021 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +/* +Adapted from https://github.com/abadams/permutohedral +which has the following license... + +MIT License + +Copyright (c) 2020 Andrew Adams + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +#include +#include + +#include + +using namespace std; + +/***************************************************************/ +/* Hash table implementation for permutohedral lattice + * + * The lattice points are stored sparsely using a hash table. + * The key for each point is its spatial location in the (d+1)- + * dimensional space. + */ +/***************************************************************/ +template +class HashTablePermutohedral { + public: + /* Constructor + * kd_: the dimensionality of the position vectors on the hyperplane. + * vd_: the dimensionality of the value vectors + */ + HashTablePermutohedral(int kd_, int vd_) : kd(kd_), vd(vd_) { + capacity = 1 << 15; + filled = 0; + entries = new Entry[capacity]; + keys = new short[kd * capacity / 2]; + values = new scalar_t[vd * capacity / 2]; + memset(values, 0, sizeof(scalar_t) * vd * capacity / 2); + } + + // Returns the number of vectors stored. + int size() { + return filled; + } + + // Returns a pointer to the keys array. + short* getKeys() { + return keys; + } + + // Returns a pointer to the values array. + scalar_t* getValues() { + return values; + } + + /* Returns the index into the hash table for a given key. + * key: a pointer to the position vector. + * h: hash of the position vector. + * create: a flag specifying whether an entry should be created, + * should an entry with the given key not found. + */ + int lookupOffset(short* key, size_t h, bool create = true) { + // Double hash table size if necessary + if (filled >= (capacity / 2) - 1) { + grow(); + } + + // Find the entry with the given key + while (1) { + Entry e = entries[h]; + // check if the cell is empty + if (e.keyIdx == -1) { + if (!create) + return -1; // Return not found. + // need to create an entry. Store the given key. + for (int i = 0; i < kd; i++) + keys[filled * kd + i] = key[i]; + e.keyIdx = filled * kd; + e.valueIdx = filled * vd; + entries[h] = e; + filled++; + return e.valueIdx; + } + + // check if the cell has a matching key + bool match = true; + for (int i = 0; i < kd && match; i++) + match = keys[e.keyIdx + i] == key[i]; + if (match) + return e.valueIdx; + + // increment the bucket with wraparound + h++; + if (h == capacity) + h = 0; + } + } + + /* Looks up the value vector associated with a given key vector. + * k : pointer to the key vector to be looked up. 
+ * create : true if a non-existing key should be created. + */ + scalar_t* lookup(short* k, bool create = true) { + size_t h = hash(k) % capacity; + int offset = lookupOffset(k, h, create); + if (offset < 0) + return NULL; + else + return values + offset; + }; + + /* Hash function used in this implementation. A simple base conversion. */ + size_t hash(const short* key) { + size_t k = 0; + for (int i = 0; i < kd; i++) { + k += key[i]; + k *= 2531011; + } + return k; + } + + private: + /* Grows the size of the hash table */ + void grow() { + size_t oldCapacity = capacity; + capacity *= 2; + + // Migrate the value vectors. + scalar_t* newValues = new scalar_t[vd * capacity / 2]; + memset(newValues, 0, sizeof(scalar_t) * vd * capacity / 2); + memcpy(newValues, values, sizeof(scalar_t) * vd * filled); + delete[] values; + values = newValues; + + // Migrate the key vectors. + short* newKeys = new short[kd * capacity / 2]; + memcpy(newKeys, keys, sizeof(short) * kd * filled); + delete[] keys; + keys = newKeys; + + Entry* newEntries = new Entry[capacity]; + + // Migrate the table of indices. + for (size_t i = 0; i < oldCapacity; i++) { + if (entries[i].keyIdx == -1) + continue; + size_t h = hash(keys + entries[i].keyIdx) % capacity; + while (newEntries[h].keyIdx != -1) { + h++; + if (h == capacity) + h = 0; + } + newEntries[h] = entries[i]; + } + delete[] entries; + entries = newEntries; + } + + // Private struct for the hash table entries. + struct Entry { + Entry() : keyIdx(-1), valueIdx(-1) {} + int keyIdx; + int valueIdx; + }; + + short* keys; + scalar_t* values; + Entry* entries; + size_t capacity, filled; + int kd, vd; +}; + +/***************************************************************/ +/* The algorithm class that performs the filter + * + * PermutohedralLattice::filter(...) does all the work. + * + */ +/***************************************************************/ +template +class PermutohedralLattice { + public: + /* Filters given image against a reference image. + * im : image to be bilateral-filtered. + * ref : reference image whose edges are to be respected. 
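+   * (In this adaptation 'data' plays the role of im and 'features' the role of ref.)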
+ */ + static void filter(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount) { + // Create lattice + PermutohedralLattice lattice(featureChannels, dataChannels + 1, elementCount); + + // Splat into the lattice + scalar_t* col = new scalar_t[dataChannels + 1]; + col[dataChannels] = 1; // homogeneous coordinate + + for (int i = 0, e = 0; e < elementCount; e++) { + for (int c = 0; c < dataChannels; c++, i++) { + col[c] = data[i]; + } + + scalar_t* featureVec = features + e * featureChannels; + lattice.splat(featureVec, col); + } + + // Blur the lattice + lattice.blur(); + + // Slice from the lattice + lattice.beginSlice(); + + for (int i = 0, e = 0; e < elementCount; e++) { + lattice.slice(col); + + scalar_t scale = 1.0f / col[dataChannels]; + for (int c = 0; c < dataChannels; c++, i++) { + data[i] = col[c] * scale; + } + } + } + + /* Constructor + * d_ : dimensionality of key vectors + * vd_ : dimensionality of value vectors + * nData_ : number of points in the input + */ + PermutohedralLattice(int d_, int vd_, int nData_) : d(d_), vd(vd_), nData(nData_), hashTable(d_, vd_) { + // Allocate storage for various arrays + elevated = new scalar_t[d + 1]; + scaleFactor = new scalar_t[d]; + + greedy = new short[d + 1]; + rank = new char[d + 1]; + barycentric = new scalar_t[d + 2]; + replay = new ReplayEntry[nData * (d + 1)]; + nReplay = 0; + canonical = new short[(d + 1) * (d + 1)]; + key = new short[d + 1]; + + // compute the coordinates of the canonical simplex, in which + // the difference between a contained point and the zero + // remainder vertex is always in ascending order. (See pg.4 of paper.) + for (int i = 0; i <= d; i++) { + for (int j = 0; j <= d - i; j++) + canonical[i * (d + 1) + j] = i; + for (int j = d - i + 1; j <= d; j++) + canonical[i * (d + 1) + j] = i - (d + 1); + } + + // Compute parts of the rotation matrix E. (See pg.4-5 of paper.) + for (int i = 0; i < d; i++) { + // the diagonal entries for normalization + scaleFactor[i] = 1.0f / (sqrtf((scalar_t)(i + 1) * (i + 2))); + + /* We presume that the user would like to do a Gaussian blur of standard deviation + * 1 in each dimension (or a total variance of d, summed over dimensions.) + * Because the total variance of the blur performed by this algorithm is not d, + * we must scale the space to offset this. + * + * The total variance of the algorithm is (See pg.6 and 10 of paper): + * [variance of splatting] + [variance of blurring] + [variance of splatting] + * = d(d+1)(d+1)/12 + d(d+1)(d+1)/2 + d(d+1)(d+1)/12 + * = 2d(d+1)(d+1)/3. + * + * So we need to scale the space by (d+1)sqrt(2/3). 
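+       * Combined with the diagonal normalisation above, this gives
+       * scaleFactor[i] = (d+1) * sqrt(2/3) / sqrt((i+1)*(i+2)) after the multiplication below.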
+ */ + scaleFactor[i] *= (d + 1) * sqrtf(2.0 / 3); + } + } + + /* Performs splatting with given position and value vectors */ + void splat(scalar_t* position, scalar_t* value) { + // first rotate position into the (d+1)-dimensional hyperplane + elevated[d] = -d * position[d - 1] * scaleFactor[d - 1]; + for (int i = d - 1; i > 0; i--) + elevated[i] = + (elevated[i + 1] - i * position[i - 1] * scaleFactor[i - 1] + (i + 2) * position[i] * scaleFactor[i]); + elevated[0] = elevated[1] + 2 * position[0] * scaleFactor[0]; + + // prepare to find the closest lattice points + scalar_t scale = 1.0f / (d + 1); + char* myrank = rank; + short* mygreedy = greedy; + + // greedily search for the closest zero-colored lattice point + int sum = 0; + for (int i = 0; i <= d; i++) { + scalar_t v = elevated[i] * scale; + scalar_t up = ceilf(v) * (d + 1); + scalar_t down = floorf(v) * (d + 1); + + if (up - elevated[i] < elevated[i] - down) + mygreedy[i] = (short)up; + else + mygreedy[i] = (short)down; + + sum += mygreedy[i]; + } + sum /= d + 1; + + // rank differential to find the permutation between this simplex and the canonical one. + // (See pg. 3-4 in paper.) + memset(myrank, 0, sizeof(char) * (d + 1)); + for (int i = 0; i < d; i++) + for (int j = i + 1; j <= d; j++) + if (elevated[i] - mygreedy[i] < elevated[j] - mygreedy[j]) + myrank[i]++; + else + myrank[j]++; + + if (sum > 0) { + // sum too large - the point is off the hyperplane. + // need to bring down the ones with the smallest differential + for (int i = 0; i <= d; i++) { + if (myrank[i] >= d + 1 - sum) { + mygreedy[i] -= d + 1; + myrank[i] += sum - (d + 1); + } else + myrank[i] += sum; + } + } else if (sum < 0) { + // sum too small - the point is off the hyperplane + // need to bring up the ones with largest differential + for (int i = 0; i <= d; i++) { + if (myrank[i] < -sum) { + mygreedy[i] += d + 1; + myrank[i] += (d + 1) + sum; + } else + myrank[i] += sum; + } + } + + // Compute barycentric coordinates (See pg.10 of paper.) + memset(barycentric, 0, sizeof(scalar_t) * (d + 2)); + for (int i = 0; i <= d; i++) { + barycentric[d - myrank[i]] += (elevated[i] - mygreedy[i]) * scale; + barycentric[d + 1 - myrank[i]] -= (elevated[i] - mygreedy[i]) * scale; + } + barycentric[0] += 1.0f + barycentric[d + 1]; + + // Splat the value into each vertex of the simplex, with barycentric weights. + for (int remainder = 0; remainder <= d; remainder++) { + // Compute the location of the lattice point explicitly (all but the last coordinate - it's redundant because they + // sum to zero) + for (int i = 0; i < d; i++) + key[i] = mygreedy[i] + canonical[remainder * (d + 1) + myrank[i]]; + + // Retrieve pointer to the value at this vertex. + scalar_t* val = hashTable.lookup(key, true); + + // Accumulate values with barycentric weight. + for (int i = 0; i < vd; i++) + val[i] += barycentric[remainder] * value[i]; + + // Record this interaction to use later when slicing + replay[nReplay].offset = val - hashTable.getValues(); + replay[nReplay].weight = barycentric[remainder]; + nReplay++; + } + } + + // Prepare for slicing + void beginSlice() { + nReplay = 0; + } + + /* Performs slicing out of position vectors. Note that the barycentric weights and the simplex + * containing each position vector were calculated and stored in the splatting step. + * We may reuse this to accelerate the algorithm. (See pg. 6 in paper.) 
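+   * beginSlice() resets the shared replay cursor, so slice() must be called once per input element,
+   * in the same order the elements were splatted.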
+ */ + void slice(scalar_t* col) { + scalar_t* base = hashTable.getValues(); + for (int j = 0; j < vd; j++) + col[j] = 0; + for (int i = 0; i <= d; i++) { + ReplayEntry r = replay[nReplay++]; + for (int j = 0; j < vd; j++) { + col[j] += r.weight * base[r.offset + j]; + } + } + } + + /* Performs a Gaussian blur along each projected axis in the hyperplane. */ + void blur() { + // Prepare arrays + short* neighbor1 = new short[d + 1]; + short* neighbor2 = new short[d + 1]; + scalar_t* newValue = new scalar_t[vd * hashTable.size()]; + scalar_t* oldValue = hashTable.getValues(); + scalar_t* hashTableBase = oldValue; + + scalar_t* zero = new scalar_t[vd]; + for (int k = 0; k < vd; k++) + zero[k] = 0; + + // For each of d+1 axes, + for (int j = 0; j <= d; j++) { + // For each vertex in the lattice, + for (int i = 0; i < hashTable.size(); i++) { // blur point i in dimension j + short* key = hashTable.getKeys() + i * (d); // keys to current vertex + for (int k = 0; k < d; k++) { + neighbor1[k] = key[k] + 1; + neighbor2[k] = key[k] - 1; + } + neighbor1[j] = key[j] - d; + neighbor2[j] = key[j] + d; // keys to the neighbors along the given axis. + + scalar_t* oldVal = oldValue + i * vd; + scalar_t* newVal = newValue + i * vd; + + scalar_t *vm1, *vp1; + + vm1 = hashTable.lookup(neighbor1, false); // look up first neighbor + if (vm1) + vm1 = vm1 - hashTableBase + oldValue; + else + vm1 = zero; + + vp1 = hashTable.lookup(neighbor2, false); // look up second neighbor + if (vp1) + vp1 = vp1 - hashTableBase + oldValue; + else + vp1 = zero; + + // Mix values of the three vertices + for (int k = 0; k < vd; k++) + newVal[k] = (0.25f * vm1[k] + 0.5f * oldVal[k] + 0.25f * vp1[k]); + } + scalar_t* tmp = newValue; + newValue = oldValue; + oldValue = tmp; + // the freshest data is now in oldValue, and newValue is ready to be written over + } + + // depending where we ended up, we may have to copy data + if (oldValue != hashTableBase) { + memcpy(hashTableBase, oldValue, hashTable.size() * vd * sizeof(scalar_t)); + delete oldValue; + } else { + delete newValue; + } + + delete zero; + delete neighbor1; + delete neighbor2; + } + + private: + int d, vd, nData; + scalar_t *elevated, *scaleFactor, *barycentric; + short* canonical; + short* key; + + // slicing is done by replaying splatting (ie storing the sparse matrix) + struct ReplayEntry { + int offset; + scalar_t weight; + } * replay; + int nReplay, nReplaySub; + + public: + char* rank; + short* greedy; + HashTablePermutohedral hashTable; +}; + +template +void PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount) { + PermutohedralLattice::filter(data, features, dataChannels, featureChannels, elementCount); +} + +template void PermutohedralCPU(float* data, float* features, int dataChannels, int featureChannels, int elementCount); +template void PermutohedralCPU(double* data, double* features, int dataChannels, int featureChannels, int elementCount); \ No newline at end of file diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu new file mode 100644 index 0000000000..b87a88a84f --- /dev/null +++ b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu @@ -0,0 +1,537 @@ +/* +Copyright 2020 - 2021 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* +Adapted from https://github.com/abadams/permutohedral +which has the following license... + +MIT License + +Copyright (c) 2020 Andrew Adams + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +#define BLOCK_SIZE 64 + +#include +#include +#include +#include +#include + +#include "hash_table.cuh" +#include "utils/meta_macros.h" + +template +struct MatrixEntry { + int index; + scalar_t weight; +}; + +template +__global__ static void createMatrix( + const int elementCount, + const scalar_t* positions, + const scalar_t* values, + const scalar_t* scaleFactor, + MatrixEntry* matrix) { + const int threadId = threadIdx.x; + const int idx = threadIdx.x + blockIdx.x * BLOCK_SIZE; + const bool outOfBounds = idx >= elementCount; + + scalar_t myElevated[pd + 1]; + const scalar_t* myPosition = positions + idx * pd; + + int myGreedy[pd + 1]; + int myRank[pd + 1]; + + scalar_t myBarycentric[pd + 2]; + __shared__ short keys[pd * BLOCK_SIZE]; + short* myKey = keys + threadId * pd; + + if (!outOfBounds) { + myElevated[pd] = -pd * myPosition[pd - 1] * scaleFactor[pd - 1]; + + for (int i = pd - 1; i > 0; i--) { + myElevated[i] = + myElevated[i + 1] - i * (myPosition[i - 1]) * scaleFactor[i - 1] + (i + 2) * myPosition[i] * scaleFactor[i]; + } + + myElevated[0] = myElevated[1] + 2 * myPosition[0] * scaleFactor[0]; + + // find the closest zero-colored lattice point + + // greedily search for the closest zero-colored lattice point + signed short sum = 0; + + for (int i = 0; i <= pd; i++) { + scalar_t v = myElevated[i] * (1.0f / (pd + 1)); + scalar_t up = ceilf(v) * (pd + 1); + scalar_t down = floorf(v) * (pd + 1); + + myGreedy[i] = (signed short)(up - myElevated[i] < myElevated[i] - down ? 
up : down); + sum += myGreedy[i]; + } + + sum /= pd + 1; + + // sort differential to find the permutation between this simplex and the canonical one + for (int i = 0; i <= pd; i++) { + myRank[i] = 0; + + for (int j = 0; j <= pd; j++) { + scalar_t iDiff = myElevated[i] - myGreedy[i]; + scalar_t jDiff = myElevated[j] - myGreedy[j]; + + if (iDiff < jDiff || (iDiff == jDiff && i > j)) { + myRank[i]++; + } + } + } + + if (sum > 0) // sum too large, need to bring down the ones with the smallest differential + { + for (int i = 0; i <= pd; i++) { + if (myRank[i] >= pd + 1 - sum) { + myGreedy[i] -= (pd + 1); + myRank[i] += sum - (pd + 1); + } else { + myRank[i] += sum; + } + } + } else if (sum < 0) // sum too small, need to bring up the ones with largest differential + { + for (int i = 0; i <= pd; i++) { + if (myRank[i] < -sum) { + myGreedy[i] += (pd + 1); + myRank[i] += sum + (pd + 1); + } else { + myRank[i] += sum; + } + } + } + +#ifdef LINEAR_D_MEMORY + for (int i = 0; i <= pd; i++) { + table_zeros[idx * (pd + 1) + i] = myGreedy[i]; + table_rank[idx * (pd + 1) + i] = myRank[i]; + } +#endif + + // turn delta into barycentric coords + for (int i = 0; i <= pd + 1; i++) { + myBarycentric[i] = 0; + } + + for (int i = 0; i <= pd; i++) { + scalar_t delta = (myElevated[i] - myGreedy[i]) * (1.0f / (pd + 1)); + myBarycentric[pd - myRank[i]] += delta; + myBarycentric[pd + 1 - myRank[i]] -= delta; + } + + myBarycentric[0] += 1.0f + myBarycentric[pd + 1]; + } + +#ifdef USE_ADDITIVE_HASH + unsigned int cumulative_hash = hash(myGreedy); +#endif + + for (int color = 0; color <= pd; color++) { + // Compute the location of the lattice point explicitly (all but + // the last coordinate - it's redundant because they sum to zero) + if (!outOfBounds) { + for (int i = 0; i < pd; i++) { + myKey[i] = myGreedy[i] + color; + + if (myRank[i] > pd - color) { + myKey[i] -= (pd + 1); + } + } + } + +#ifdef USE_ADDITIVE_HASH + for (int i = 0; i < pd; i++) { + if (myRank[i] == pd - color) { + cumulative_hash += hOffset[i]; + } + } +#endif + + if (!outOfBounds) { + MatrixEntry r; + +#ifdef USE_ADDITIVE_HASH + r.index = hashTableInsert(cumulative_hash, myKey, idx * (pd + 1) + color); +#else + r.index = hashTableInsert(myKey, idx * (pd + 1) + color); +#endif + + r.weight = myBarycentric[color]; + matrix[idx * (pd + 1) + color] = r; + } + } +} + +template +__global__ static void cleanHashTable(const int elementCount, MatrixEntry* matrix) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + + if (idx >= elementCount) + return; + + // find my hash table entry + int* e = table_entries + idx; + + // Check if I created my own key in the previous phase + if (*e >= 0) { + // Rehash my key and reset the pointer in order to merge with + // any other pixel that created a different entry under the + // same key. If the computation was serial this would never + // happen, but sometimes race conditions can make the same key + // be inserted twice. hashTableRetrieve always returns the + // earlier, so it's no problem as long as we rehash now. 
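+    // (All duplicates then resolve to one canonical entry, and their contributions are merged by the
+    // atomic adds in the splat step.)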
+ +#ifdef LINEAR_D_MEMORY + // Get my key + short myKey[kd]; + generateKey(*e, myKey); + *e = hashTableRetrieve(myKey); +#else + *e = hashTableRetrieve(table_keys + *e * kd); +#endif + } +} + +template +__global__ static void splat( + const int elementCount, + scalar_t* values, + MatrixEntry* matrix, + scalar_t* table_values) { + const int color = threadIdx.y; + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + + const bool outOfBounds = idx >= elementCount; + + if (outOfBounds) { + return; + } + + scalar_t* myValue = values + idx * vd; + + MatrixEntry r = matrix[idx * (pd + 1) + color]; + + matrix[idx * (pd + 1) + color].index = r.index = table_entries[r.index]; + scalar_t* val = table_values + r.index * (vd + 1); + + for (int j = 0; j < vd; j++) { + gpuAtomicAdd(val + j, myValue[j] * r.weight); + } + + gpuAtomicAdd(val + vd, r.weight); +} + +// splat splits by color, so extend the y coordinate to our blocks to represent that +// dim3 oldblocks((w-1)/8+1, (h-1)/8+1, 1); +// dim3 oldblockSize(8, 8, 1); +// oldblocks.y *= pd+1; +// splatCache<<>>(w, h, values, matrix); + +// int blockCount = (elementCount + 1) / BLOCK_SIZE + 1; +// int blockSize = BLOCK_SIZE; + +// splatCache<<>>(elementCount, values, matrix); + +template +__global__ static void splatCache( + const int elementCount, + scalar_t* values, + MatrixEntry* matrix, + scalar_t* table_values) { + // const int x = threadIdx.x + blockIdx.x * blockDim.x; + // const int y = threadIdx.y + (blockIdx.y/(pd+1)) * blockDim.y; + + // const int threadId = threadIdx.y*blockDim.x + threadIdx.x; + // const int color = blockIdx.y % (pd+1); + // const int idx = y*w + x; + + const int threadId = threadIdx.x; + const int color = threadIdx.y; + const int idx = threadIdx.x + blockIdx.x * BLOCK_SIZE; + + const bool outOfBounds = idx >= elementCount; + + __shared__ int sharedOffsets[BLOCK_SIZE]; + __shared__ scalar_t sharedValues[BLOCK_SIZE * (vd + 1)]; + + int myOffset = -1; + scalar_t* myValue = sharedValues + threadId * (vd + 1); + + if (!outOfBounds) { + scalar_t* value = values + idx * vd; + + MatrixEntry r = matrix[idx * (pd + 1) + color]; + + // convert the matrix entry from a pointer into the entries array to a pointer into the keys/values array + matrix[idx * (pd + 1) + color].index = r.index = table_entries[r.index]; + // record the offset into the keys/values array in shared space + myOffset = sharedOffsets[threadId] = r.index * (vd + 1); + + for (int j = 0; j < vd; j++) { + myValue[j] = value[j] * r.weight; + } + myValue[vd] = r.weight; + + } else { + sharedOffsets[threadId] = -1; + } + + __syncthreads(); + + // am I the first thread in this block to care about this key? 
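+  // The lowest-indexed thread owning a given offset accumulates the contributions of all higher-indexed
+  // threads sharing that offset, then writes once to global memory; the others simply return.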
+ + if (outOfBounds) + return; + + for (int i = 0; i < BLOCK_SIZE; i++) { + if (i < threadId) { + if (myOffset == sharedOffsets[i]) { + // somebody else with higher priority cares about this key + return; + } + } else if (i > threadId) { + if (myOffset == sharedOffsets[i]) { + // someone else with lower priority cares about this key, accumulate it into mine + for (int j = 0; j <= vd; j++) { + sharedValues[threadId * (vd + 1) + j] += sharedValues[i * (vd + 1) + j]; + } + } + } + } + + // only the threads with something to write to main memory are still going + scalar_t* val = table_values + myOffset; + for (int j = 0; j <= vd; j++) { + gpuAtomicAdd(val + j, myValue[j]); + } +} + +template +__global__ static void blur( + int n, + scalar_t* newValues, + MatrixEntry* matrix, + int color, + scalar_t* table_values) { + const int idx = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x * blockDim.y + threadIdx.x; + + if (idx >= n) + return; + + // Check if I'm valid + if (matrix[idx].index != idx) + return; + + // find my key and the keys of my neighbours + short myKey[pd + 1]; + short np[pd + 1]; + short nm[pd + 1]; + +#ifdef LINEAR_D_MEMORY + generateKey(idx, myKey); + for (int i = 0; i < pd; i++) { + np[i] = myKey[i] + 1; + nm[i] = myKey[i] - 1; + } +#else + for (int i = 0; i < pd; i++) { + myKey[i] = table_keys[idx * pd + i]; + np[i] = myKey[i] + 1; + nm[i] = myKey[i] - 1; + } +#endif + + np[color] -= pd + 1; + nm[color] += pd + 1; + +#ifdef USE_ADDITIVE_HASH + unsigned int hCurrent = hash(myKey); + int offNp = hashTableRetrieveWithHash(hCurrent + hOffset[color], np); + int offNm = hashTableRetrieveWithHash(hCurrent - hOffset[color], nm); +#else + int offNp = hashTableRetrieve(np); + int offNm = hashTableRetrieve(nm); +#endif + + scalar_t* valMe = table_values + (vd + 1) * idx; + scalar_t* valNp = table_values + (vd + 1) * offNp; + scalar_t* valNm = table_values + (vd + 1) * offNm; + scalar_t* valOut = newValues + (vd + 1) * idx; + + if (offNp >= 0 && offNm >= 0) { + for (int i = 0; i <= vd; i++) { + valOut[i] = (valNp[i] + (valMe[i] * 2) + valNm[i]) / 4; + } + } else if (offNp >= 0) { + for (int i = 0; i <= vd; i++) { + valOut[i] = (valNp[i] + (valMe[i] * 2)) / 4; + } + } else if (offNm >= 0) { + for (int i = 0; i <= vd; i++) { + valOut[i] = (valNm[i] + (valMe[i] * 2)) / 4; + } + } else { + for (int i = 0; i <= vd; i++) { + valOut[i] = valMe[i] * 2; + } + } +} + +template +__global__ static void slice( + const int elementCount, + scalar_t* values, + MatrixEntry* matrix, + scalar_t* table_values) { + const int threadId = threadIdx.x; + const int idx = threadIdx.x + blockIdx.x * BLOCK_SIZE; + const bool outOfBounds = idx >= elementCount; + + if (outOfBounds) + return; + + __shared__ scalar_t localValue[BLOCK_SIZE * vd]; + + scalar_t* myValue = localValue + threadId * vd; + scalar_t myWeight = 0; + + for (int i = 0; i < vd; i++) { + myValue[i] = 0; + } + + for (int i = 0; i <= pd; i++) { + MatrixEntry r = matrix[idx * (pd + 1) + i]; + scalar_t* val = table_values + r.index * (vd + 1); + + for (int j = 0; j < vd; j++) { + myValue[j] += r.weight * val[j]; + } + + myWeight += r.weight * val[vd]; + } + + myWeight = 1.0f / myWeight; + + for (int j = 0; j < vd; j++) { + values[idx * vd + j] = myValue[j] * myWeight; + } +} + +template +void PermutohedralCuda(scalar_t* values, scalar_t* positions, int elementCount, bool accurate) { + scalar_t blurVariance = accurate ? 
0.5 : 0; + + scalar_t* scaleFactor; + cudaMalloc(&scaleFactor, pd * sizeof(scalar_t)); + + scalar_t scaleFactorHost[pd]; + for (int i = 0; i < pd; i++) { + scaleFactorHost[i] = (pd + 1) * sqrtf((1.0 / 6 + blurVariance) / ((i + 1) * (i + 2))); + } + + cudaMemcpy(scaleFactor, scaleFactorHost, pd * sizeof(scalar_t), cudaMemcpyHostToDevice); + + MatrixEntry* matrix; + cudaMalloc(&matrix, elementCount * (pd + 1) * sizeof(MatrixEntry)); + + scalar_t* table_values = createHashTable(elementCount * (pd + 1)); + + // Populate constant memory for hash helpers + unsigned long long int __host_two32 = ((unsigned long long int)1) << 32; + unsigned int __host_div_c = 2 * (elementCount * (pd + 1)); + unsigned int __host_div_l = ceilf(logf((float)__host_div_c) / logf(2.0f)); + unsigned int __host_div_m = (__host_two32 << __host_div_l) / __host_div_c - __host_two32 + 1; + cudaMemcpyToSymbol(__div_c, &__host_div_c, sizeof(unsigned int)); + cudaMemcpyToSymbol(__div_l, &__host_div_l, sizeof(unsigned int)); + cudaMemcpyToSymbol(__div_m, &__host_div_m, sizeof(unsigned int)); + + // Populate constant memory with hash of offset vectors + unsigned int hOffset_host[pd + 1]; + signed short offset[pd + 1]; + for (int i = 0; i < pd; offset[i] = 1, i++) + ; + for (int i = 0; i <= pd; i++) { + offset[i] -= pd + 1; + hOffset_host[i] = hash(offset); + offset[i] += pd + 1; + } + cudaMemcpyToSymbol(hOffset, &hOffset_host, sizeof(unsigned int) * (pd + 1)); + + int blockCount = (elementCount + 1) / BLOCK_SIZE + 1; + int blockSize = BLOCK_SIZE; + + createMatrix<<>>(elementCount, positions, values, scaleFactor, matrix); + + // fix duplicate hash table entries + int tableSize = elementCount * 2 * (pd + 1); + int cleanBlockSize = 32; + int cleanBlocks = (tableSize - 1) / cleanBlockSize + 1; + + cleanHashTable<<>>(tableSize, matrix); + + splat<<>>(elementCount, values, matrix, table_values); + + if (accurate) { + scalar_t* newValues; + cudaMalloc(&newValues, elementCount * (pd + 1) * (vd + 1) * sizeof(scalar_t)); + cudaMemset(newValues, 0, elementCount * (pd + 1) * (vd + 1) * sizeof(scalar_t)); + + for (int color = 0; color <= pd; color++) { + blur + <<>>(elementCount * (pd + 1), newValues, matrix, color, table_values); + + scalar_t* swap = newValues; + newValues = table_values; + table_values = swap; + } + + cudaFree(newValues); + } + + slice<<>>(elementCount, values, matrix, table_values); + + destroyHashTable(); + cudaFree(table_values); +} + +#define DECLARATION(dc, fc) \ + template void PermutohedralCuda(float* values, float* positions, int elementCount, bool accurate); \ + template void PermutohedralCuda(double* values, double* positions, int elementCount, bool accurate); +DO_FOR_AB(DECLARATION, 16, 19) diff --git a/monai/csrc/lltm/lltm.h b/monai/csrc/lltm/lltm.h index 29ccf2de77..33e17416f8 100644 --- a/monai/csrc/lltm/lltm.h +++ b/monai/csrc/lltm/lltm.h @@ -1,5 +1,5 @@ /* -Copyright 2020 MONAI Consortium +Copyright 2020 - 2021 MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/lltm/lltm_cpu.cpp b/monai/csrc/lltm/lltm_cpu.cpp index 4eb2ed6aae..295c592d00 100644 --- a/monai/csrc/lltm/lltm_cpu.cpp +++ b/monai/csrc/lltm/lltm_cpu.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 MONAI Consortium +Copyright 2020 - 2021 MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at diff --git a/monai/csrc/lltm/lltm_cuda.cu b/monai/csrc/lltm/lltm_cuda.cu index 2667167bdb..4633348477 100644 --- a/monai/csrc/lltm/lltm_cuda.cu +++ b/monai/csrc/lltm/lltm_cuda.cu @@ -1,5 +1,5 @@ /* -Copyright 2020 MONAI Consortium +Copyright 2020 - 2021 MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/resample/bounds_common.h b/monai/csrc/resample/bounds_common.h index d6e2a43089..4997c7d968 100644 --- a/monai/csrc/resample/bounds_common.h +++ b/monai/csrc/resample/bounds_common.h @@ -1,5 +1,5 @@ /* -Copyright 2020 MONAI Consortium +Copyright 2020 - 2021 MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/resample/interpolation_common.h b/monai/csrc/resample/interpolation_common.h index 86ee00ef18..35899298bf 100644 --- a/monai/csrc/resample/interpolation_common.h +++ b/monai/csrc/resample/interpolation_common.h @@ -1,5 +1,5 @@ /* -Copyright 2020 MONAI Consortium +Copyright 2020 - 2021 MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/resample/pushpull.h b/monai/csrc/resample/pushpull.h index 33af30742c..45fd5ce564 100644 --- a/monai/csrc/resample/pushpull.h +++ b/monai/csrc/resample/pushpull.h @@ -1,5 +1,5 @@ /* -Copyright 2020 MONAI Consortium +Copyright 2020 - 2021 MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/resample/pushpull_cpu.cpp b/monai/csrc/resample/pushpull_cpu.cpp index ff440acb7a..40743a6cf1 100644 --- a/monai/csrc/resample/pushpull_cpu.cpp +++ b/monai/csrc/resample/pushpull_cpu.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 MONAI Consortium +Copyright 2020 - 2021 MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -97,19 +97,25 @@ MONAI_NAMESPACE_DEVICE { // cpu bool do_sgrad) : dim(dim), bound0(bound.size() > 0 ? bound[0] : BoundType::Replicate), - bound1(bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate), + bound1( + bound.size() > 1 ? bound[1] + : bound.size() > 0 ? bound[0] + : BoundType::Replicate), bound2( - bound.size() > 2 ? bound[2] - : bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate), + bound.size() > 2 ? bound[2] + : bound.size() > 1 ? bound[1] + : bound.size() > 0 ? bound[0] + : BoundType::Replicate), interpolation0(interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear), interpolation1( - interpolation.size() > 1 ? interpolation[1] - : interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear), + interpolation.size() > 1 ? interpolation[1] + : interpolation.size() > 0 ? interpolation[0] + : InterpolationType::Linear), interpolation2( - interpolation.size() > 2 - ? interpolation[2] + interpolation.size() > 2 ? interpolation[2] : interpolation.size() > 1 ? interpolation[1] - : interpolation.size() > 0 ? 
interpolation[0] : InterpolationType::Linear), + : interpolation.size() > 0 ? interpolation[0] + : InterpolationType::Linear), extrapolate(extrapolate), do_pull(do_pull), do_push(do_push), @@ -136,13 +142,14 @@ MONAI_NAMESPACE_DEVICE { // cpu bound2(bound), interpolation0(interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear), interpolation1( - interpolation.size() > 1 ? interpolation[1] - : interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear), + interpolation.size() > 1 ? interpolation[1] + : interpolation.size() > 0 ? interpolation[0] + : InterpolationType::Linear), interpolation2( - interpolation.size() > 2 - ? interpolation[2] + interpolation.size() > 2 ? interpolation[2] : interpolation.size() > 1 ? interpolation[1] - : interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear), + : interpolation.size() > 0 ? interpolation[0] + : InterpolationType::Linear), extrapolate(extrapolate), do_pull(do_pull), do_push(do_push), @@ -165,10 +172,15 @@ MONAI_NAMESPACE_DEVICE { // cpu bool do_sgrad) : dim(dim), bound0(bound.size() > 0 ? bound[0] : BoundType::Replicate), - bound1(bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate), + bound1( + bound.size() > 1 ? bound[1] + : bound.size() > 0 ? bound[0] + : BoundType::Replicate), bound2( - bound.size() > 2 ? bound[2] - : bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate), + bound.size() > 2 ? bound[2] + : bound.size() > 1 ? bound[1] + : bound.size() > 0 ? bound[0] + : BoundType::Replicate), interpolation0(interpolation), interpolation1(interpolation), interpolation2(interpolation), diff --git a/monai/csrc/resample/pushpull_cuda.cu b/monai/csrc/resample/pushpull_cuda.cu index ffa7cc35d6..ecfeb562ab 100644 --- a/monai/csrc/resample/pushpull_cuda.cu +++ b/monai/csrc/resample/pushpull_cuda.cu @@ -1,5 +1,5 @@ /* -Copyright 2020 MONAI Consortium +Copyright 2020 - 2021 MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -94,19 +94,25 @@ MONAI_NAMESPACE_DEVICE { // cuda bool do_sgrad) : dim(dim), bound0(bound.size() > 0 ? bound[0] : BoundType::Replicate), - bound1(bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate), + bound1( + bound.size() > 1 ? bound[1] + : bound.size() > 0 ? bound[0] + : BoundType::Replicate), bound2( - bound.size() > 2 ? bound[2] - : bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate), + bound.size() > 2 ? bound[2] + : bound.size() > 1 ? bound[1] + : bound.size() > 0 ? bound[0] + : BoundType::Replicate), interpolation0(interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear), interpolation1( - interpolation.size() > 1 ? interpolation[1] - : interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear), + interpolation.size() > 1 ? interpolation[1] + : interpolation.size() > 0 ? interpolation[0] + : InterpolationType::Linear), interpolation2( - interpolation.size() > 2 - ? interpolation[2] + interpolation.size() > 2 ? interpolation[2] : interpolation.size() > 1 ? interpolation[1] - : interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear), + : interpolation.size() > 0 ? interpolation[0] + : InterpolationType::Linear), extrapolate(extrapolate), do_pull(do_pull), do_push(do_push), @@ -133,13 +139,14 @@ MONAI_NAMESPACE_DEVICE { // cuda bound2(bound), interpolation0(interpolation.size() > 0 ? 
interpolation[0] : InterpolationType::Linear), interpolation1( - interpolation.size() > 1 ? interpolation[1] - : interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear), + interpolation.size() > 1 ? interpolation[1] + : interpolation.size() > 0 ? interpolation[0] + : InterpolationType::Linear), interpolation2( - interpolation.size() > 2 - ? interpolation[2] + interpolation.size() > 2 ? interpolation[2] : interpolation.size() > 1 ? interpolation[1] - : interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear), + : interpolation.size() > 0 ? interpolation[0] + : InterpolationType::Linear), extrapolate(extrapolate), do_pull(do_pull), do_push(do_push), @@ -162,10 +169,15 @@ MONAI_NAMESPACE_DEVICE { // cuda bool do_sgrad) : dim(dim), bound0(bound.size() > 0 ? bound[0] : BoundType::Replicate), - bound1(bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate), + bound1( + bound.size() > 1 ? bound[1] + : bound.size() > 0 ? bound[0] + : BoundType::Replicate), bound2( - bound.size() > 2 ? bound[2] - : bound.size() > 1 ? bound[1] : bound.size() > 0 ? bound[0] : BoundType::Replicate), + bound.size() > 2 ? bound[2] + : bound.size() > 1 ? bound[1] + : bound.size() > 0 ? bound[0] + : BoundType::Replicate), interpolation0(interpolation), interpolation1(interpolation), interpolation2(interpolation), diff --git a/monai/csrc/utils/common_utils.h b/monai/csrc/utils/common_utils.h index af160d52a2..882312acb3 100644 --- a/monai/csrc/utils/common_utils.h +++ b/monai/csrc/utils/common_utils.h @@ -1,5 +1,5 @@ /* -Copyright 2020 MONAI Consortium +Copyright 2020 - 2021 MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/utils/meta_macros.h b/monai/csrc/utils/meta_macros.h new file mode 100644 index 0000000000..980b253bbe --- /dev/null +++ b/monai/csrc/utils/meta_macros.h @@ -0,0 +1,131 @@ +/* +Copyright 2020 - 2021 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#pragma once + +// Helper Macros: for internal use (see below) +#define _DO_1(TARGET) TARGET(1) +#define _DO_2(TARGET) TARGET(2) _DO_1(TARGET) +#define _DO_3(TARGET) TARGET(3) _DO_2(TARGET) +#define _DO_4(TARGET) TARGET(4) _DO_3(TARGET) +#define _DO_5(TARGET) TARGET(5) _DO_4(TARGET) +#define _DO_6(TARGET) TARGET(6) _DO_5(TARGET) +#define _DO_7(TARGET) TARGET(7) _DO_6(TARGET) +#define _DO_8(TARGET) TARGET(8) _DO_7(TARGET) +#define _DO_9(TARGET) TARGET(9) _DO_8(TARGET) +#define _DO_10(TARGET) TARGET(10) _DO_9(TARGET) +#define _DO_11(TARGET) TARGET(11) _DO_10(TARGET) +#define _DO_12(TARGET) TARGET(12) _DO_11(TARGET) +#define _DO_13(TARGET) TARGET(13) _DO_12(TARGET) +#define _DO_14(TARGET) TARGET(14) _DO_13(TARGET) +#define _DO_15(TARGET) TARGET(15) _DO_14(TARGET) +#define _DO_16(TARGET) TARGET(16) _DO_15(TARGET) +#define _DO_17(TARGET) TARGET(17) _DO_16(TARGET) +#define _DO_18(TARGET) TARGET(18) _DO_17(TARGET) +#define _DO_19(TARGET) TARGET(19) _DO_18(TARGET) +#define _DO_20(TARGET) TARGET(20) _DO_19(TARGET) +#define _DO_21(TARGET) TARGET(21) _DO_20(TARGET) +#define _DO_22(TARGET) TARGET(22) _DO_21(TARGET) +#define _DO_23(TARGET) TARGET(23) _DO_22(TARGET) +#define _DO_24(TARGET) TARGET(24) _DO_23(TARGET) +#define _DO_25(TARGET) TARGET(25) _DO_24(TARGET) +#define _DO_26(TARGET) TARGET(26) _DO_25(TARGET) +#define _DO_27(TARGET) TARGET(27) _DO_26(TARGET) +#define _DO_28(TARGET) TARGET(28) _DO_27(TARGET) +#define _DO_29(TARGET) TARGET(29) _DO_28(TARGET) +#define _DO_30(TARGET) TARGET(30) _DO_29(TARGET) +#define _DO_31(TARGET) TARGET(31) _DO_30(TARGET) +#define _DO_32(TARGET) TARGET(32) _DO_31(TARGET) + +#define _DO_A_1(TARGET, A) TARGET(A, 1) +#define _DO_A_2(TARGET, A) TARGET(A, 2) _DO_A_1(TARGET, A) +#define _DO_A_3(TARGET, A) TARGET(A, 3) _DO_A_2(TARGET, A) +#define _DO_A_4(TARGET, A) TARGET(A, 4) _DO_A_3(TARGET, A) +#define _DO_A_5(TARGET, A) TARGET(A, 5) _DO_A_4(TARGET, A) +#define _DO_A_6(TARGET, A) TARGET(A, 6) _DO_A_5(TARGET, A) +#define _DO_A_7(TARGET, A) TARGET(A, 7) _DO_A_6(TARGET, A) +#define _DO_A_8(TARGET, A) TARGET(A, 8) _DO_A_7(TARGET, A) +#define _DO_A_9(TARGET, A) TARGET(A, 9) _DO_A_8(TARGET, A) +#define _DO_A_10(TARGET, A) TARGET(A, 10) _DO_A_9(TARGET, A) +#define _DO_A_11(TARGET, A) TARGET(A, 11) _DO_A_10(TARGET, A) +#define _DO_A_12(TARGET, A) TARGET(A, 12) _DO_A_11(TARGET, A) +#define _DO_A_13(TARGET, A) TARGET(A, 13) _DO_A_12(TARGET, A) +#define _DO_A_14(TARGET, A) TARGET(A, 14) _DO_A_13(TARGET, A) +#define _DO_A_15(TARGET, A) TARGET(A, 15) _DO_A_14(TARGET, A) +#define _DO_A_16(TARGET, A) TARGET(A, 16) _DO_A_15(TARGET, A) +#define _DO_A_17(TARGET, A) TARGET(A, 17) _DO_A_16(TARGET, A) +#define _DO_A_18(TARGET, A) TARGET(A, 18) _DO_A_17(TARGET, A) +#define _DO_A_19(TARGET, A) TARGET(A, 19) _DO_A_18(TARGET, A) +#define _DO_A_20(TARGET, A) TARGET(A, 20) _DO_A_19(TARGET, A) +#define _DO_A_21(TARGET, A) TARGET(A, 21) _DO_A_20(TARGET, A) +#define _DO_A_22(TARGET, A) TARGET(A, 22) _DO_A_21(TARGET, A) +#define _DO_A_23(TARGET, A) TARGET(A, 23) _DO_A_22(TARGET, A) +#define _DO_A_24(TARGET, A) TARGET(A, 24) _DO_A_23(TARGET, A) +#define _DO_A_25(TARGET, A) TARGET(A, 25) _DO_A_24(TARGET, A) +#define _DO_A_26(TARGET, A) TARGET(A, 26) _DO_A_25(TARGET, A) +#define _DO_A_27(TARGET, A) TARGET(A, 27) _DO_A_26(TARGET, A) +#define _DO_A_28(TARGET, A) TARGET(A, 28) _DO_A_27(TARGET, A) +#define _DO_A_29(TARGET, A) TARGET(A, 29) _DO_A_28(TARGET, A) +#define _DO_A_30(TARGET, A) TARGET(A, 30) _DO_A_29(TARGET, A) +#define _DO_A_31(TARGET, A) TARGET(A, 31) _DO_A_30(TARGET, A) +#define 
_DO_A_32(TARGET, A) TARGET(A, 32) _DO_A_31(TARGET, A)
+
+#define _DO_1_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 1)
+#define _DO_2_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 2) _DO_1_B(TARGET, B_RANGE)
+#define _DO_3_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 3) _DO_2_B(TARGET, B_RANGE)
+#define _DO_4_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 4) _DO_3_B(TARGET, B_RANGE)
+#define _DO_5_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 5) _DO_4_B(TARGET, B_RANGE)
+#define _DO_6_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 6) _DO_5_B(TARGET, B_RANGE)
+#define _DO_7_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 7) _DO_6_B(TARGET, B_RANGE)
+#define _DO_8_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 8) _DO_7_B(TARGET, B_RANGE)
+#define _DO_9_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 9) _DO_8_B(TARGET, B_RANGE)
+#define _DO_10_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 10) _DO_9_B(TARGET, B_RANGE)
+#define _DO_11_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 11) _DO_10_B(TARGET, B_RANGE)
+#define _DO_12_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 12) _DO_11_B(TARGET, B_RANGE)
+#define _DO_13_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 13) _DO_12_B(TARGET, B_RANGE)
+#define _DO_14_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 14) _DO_13_B(TARGET, B_RANGE)
+#define _DO_15_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 15) _DO_14_B(TARGET, B_RANGE)
+#define _DO_16_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 16) _DO_15_B(TARGET, B_RANGE)
+#define _DO_17_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 17) _DO_16_B(TARGET, B_RANGE)
+#define _DO_18_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 18) _DO_17_B(TARGET, B_RANGE)
+#define _DO_19_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 19) _DO_18_B(TARGET, B_RANGE)
+#define _DO_20_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 20) _DO_19_B(TARGET, B_RANGE)
+#define _DO_21_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 21) _DO_20_B(TARGET, B_RANGE)
+#define _DO_22_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 22) _DO_21_B(TARGET, B_RANGE)
+#define _DO_23_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 23) _DO_22_B(TARGET, B_RANGE)
+#define _DO_24_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 24) _DO_23_B(TARGET, B_RANGE)
+#define _DO_25_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 25) _DO_24_B(TARGET, B_RANGE)
+#define _DO_26_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 26) _DO_25_B(TARGET, B_RANGE)
+#define _DO_27_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 27) _DO_26_B(TARGET, B_RANGE)
+#define _DO_28_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 28) _DO_27_B(TARGET, B_RANGE)
+#define _DO_29_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 29) _DO_28_B(TARGET, B_RANGE)
+#define _DO_30_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 30) _DO_29_B(TARGET, B_RANGE)
+#define _DO_31_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 31) _DO_30_B(TARGET, B_RANGE)
+#define _DO_32_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 32) _DO_31_B(TARGET, B_RANGE)
+
+#define _CASE_A(A) \
+  case (A):        \
+    CASE(A) break;
+#define _CASE_AB(A, B) \
+  case (A * 100 + B):  \
+    CASE(A, B) break;
+
+// Preprocessor For Loops
+#define DO_FOR_A(TARGET, A_RANGE) _DO_##A_RANGE(TARGET)
+#define DO_FOR_AB(TARGET, A_RANGE, B_RANGE) _DO_##A_RANGE##_B(TARGET, B_RANGE)
+
+// Preprocessor Switch Statement Generators
+#define SWITCH_A(CASE, A_RANGE, A) \
+  switch (A) { DO_FOR_A(_CASE_A, A_RANGE) }
+#define SWITCH_AB(CALL, A_RANGE, B_RANGE, A, B) \
+  switch (A * 100 + B) { DO_FOR_AB(_CASE_AB, A_RANGE, B_RANGE) }
diff --git a/monai/csrc/utils/resample_utils.h b/monai/csrc/utils/resample_utils.h
index 48fe643292..4735d13ca1 100644
--- a/monai/csrc/utils/resample_utils.h
+++ 
b/monai/csrc/utils/resample_utils.h
@@ -1,5 +1,5 @@
 /*
-Copyright 2020 MONAI Consortium
+Copyright 2020 - 2021 MONAI Consortium
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
diff --git a/monai/csrc/utils/tensor_description.h b/monai/csrc/utils/tensor_description.h
new file mode 100644
index 0000000000..dadd26c5f5
--- /dev/null
+++ b/monai/csrc/utils/tensor_description.h
@@ -0,0 +1,52 @@
+/*
+Copyright 2020 - 2021 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <torch/extension.h>
+
+// Struct to easily cache descriptive information about a tensor.
+// This is helpful as regular calls to the size and stride member
+// functions of tensors appear to cause memory issues.
+struct TensorDescription {
+ public:
+  TensorDescription(torch::Tensor tensor) {
+    batchCount = tensor.size(0);
+    batchStride = tensor.stride(0);
+
+    channelCount = tensor.size(1);
+    channelStride = tensor.stride(1);
+
+    dimensions = tensor.dim() - 2;
+    sizes = new int[dimensions];
+    strides = new int[dimensions];
+
+    for (int i = 0; i < dimensions; i++) {
+      sizes[i] = tensor.size(i + 2);
+      strides[i] = tensor.stride(i + 2);
+    }
+  }
+
+  ~TensorDescription() {
+    delete[] sizes;
+    delete[] strides;
+  }
+
+  int batchCount;
+  int batchStride;
+
+  int channelCount;
+  int channelStride;
+
+  int dimensions;
+  int* sizes;
+  int* strides;
+};
diff --git a/monai/data/__init__.py b/monai/data/__init__.py
index 2e73545b9e..e0db1e17ae 100644
--- a/monai/data/__init__.py
+++ b/monai/data/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at @@ -22,14 +22,38 @@ ZipDataset, ) from .decathlon_datalist import load_decathlon_datalist, load_decathlon_properties -from .grid_dataset import * -from .image_reader import * +from .grid_dataset import GridPatchDataset, PatchDataset +from .image_dataset import ImageDataset +from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from .iterable_dataset import IterableDataset -from .nifti_reader import NiftiDataset from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti from .png_saver import PNGSaver from .png_writer import write_png -from .synthetic import * +from .synthetic import create_test_image_2d, create_test_image_3d from .thread_buffer import ThreadBuffer -from .utils import * +from .utils import ( + DistributedSampler, + compute_importance_map, + compute_shape_offset, + correct_nifti_header_if_necessary, + create_file_basename, + dense_patch_slices, + get_random_patch, + get_valid_patch_size, + is_supported_format, + iter_patch, + iter_patch_slices, + json_hashing, + list_data_collate, + partition_dataset, + partition_dataset_classes, + pickle_hashing, + rectify_header_sform_qform, + select_cross_validation_folds, + set_rnd, + sorted_dict, + to_affine_nd, + worker_init_fn, + zoom_affine, +) diff --git a/monai/data/csv_saver.py b/monai/data/csv_saver.py index 7654bfdeb3..5f5e415055 100644 --- a/monai/data/csv_saver.py +++ b/monai/data/csv_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -37,7 +37,8 @@ def __init__(self, output_dir: str = "./", filename: str = "predictions.csv", ov """ self.output_dir = output_dir self._cache_dict: OrderedDict = OrderedDict() - assert isinstance(filename, str) and filename[-4:] == ".csv", "filename must be a string with CSV format." + if not (isinstance(filename, str) and filename[-4:] == ".csv"): + raise AssertionError("filename must be a string with CSV format.") self._filepath = os.path.join(output_dir, filename) self.overwrite = overwrite self._data_index = 0 @@ -76,7 +77,8 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] self._data_index += 1 if torch.is_tensor(data): data = data.detach().cpu().numpy() - assert isinstance(data, np.ndarray) + if not isinstance(data, np.ndarray): + raise AssertionError self._cache_dict[save_key] = data.astype(np.float32) def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: diff --git a/monai/data/dataloader.py b/monai/data/dataloader.py index 262e29223a..65935d36cc 100644 --- a/monai/data/dataloader.py +++ b/monai/data/dataloader.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 892546b2a4..e67c7a2954 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
@@ -89,7 +89,7 @@ class PersistentDataset(Dataset):
.. code-block:: python
- [ LoadNiftid(keys=['image', 'label']),
+ [ LoadImaged(keys=['image', 'label']),
Orientationd(keys=['image', 'label'], axcodes='RAS'),
ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', spatial_size=(96, 96, 96),
@@ -97,7 +97,7 @@ class PersistentDataset(Dataset):
ToTensord(keys=['image', 'label'])]
Upon first use a filename based dataset will be processed by the transform for the
- [LoadNiftid, Orientationd, ScaleIntensityRanged] and the resulting tensor written to
+ [LoadImaged, Orientationd, ScaleIntensityRanged] and the resulting tensor written to
the `cache_dir` before applying the remaining random dependent transforms
[RandCropByPosNegLabeld, ToTensord] elements for use in the analysis.
@@ -446,7 +446,7 @@ class CacheDataset(Dataset):
For example, if the transform is a `Compose` of::
transforms = Compose([
- LoadNiftid(),
+ LoadImaged(),
AddChanneld(),
Spacingd(),
Orientationd(),
@@ -457,7 +457,7 @@ class CacheDataset(Dataset):
when `transforms` is used in a multi-epoch training pipeline, before the first training epoch,
this dataset will cache the results up to ``ScaleIntensityRanged``, as
- all non-random transforms `LoadNiftid`, `AddChanneld`, `Spacingd`, `Orientationd`, `ScaleIntensityRanged`
+ all non-random transforms `LoadImaged`, `AddChanneld`, `Spacingd`, `Orientationd`, `ScaleIntensityRanged`
can be cached. During training, the dataset will load the cached results and run
``RandCropByPosNegLabeld``  and ``ToTensord``, as ``RandCropByPosNegLabeld`` is a randomized transform
and the outcome is not cached.
@@ -498,7 +498,13 @@ def _fill_cache(self) -> List:
warnings.warn("tqdm is not installed, will not show the caching progress bar.")
with ThreadPool(self.num_workers) as p:
if has_tqdm:
- return list(tqdm(p.imap(self._load_cache_item, range(self.cache_num)), total=self.cache_num))
+ return list(
+ tqdm(
+ p.imap(self._load_cache_item, range(self.cache_num)),
+ total=self.cache_num,
+ desc="Loading dataset",
+ )
+ )
return list(p.imap(self._load_cache_item, range(self.cache_num)))
def _load_cache_item(self, idx: int):
@@ -699,8 +705,7 @@ def _try_shutdown(self):
self._round = 0
self._replace_done = False
return True
- else:
- return False
+ return False
def shutdown(self):
"""
@@ -807,7 +812,7 @@ def __getitem__(self, index: int):
def to_list(x):
return list(x) if isinstance(x, (tuple, list)) else [x]
- data = list()
+ data = []
for dataset in self.data:
data.extend(to_list(dataset[index]))
if self.transform is not None:
@@ -826,7 +831,7 @@ class ArrayDataset(Randomizable, _TorchDataset):
img_transform = Compose(
[
- LoadNifti(image_only=True),
+ LoadImage(image_only=True),
AddChannel(),
RandAdjustContrast()
]
@@ -835,7 +840,7 @@
If training is based on images and the metadata, the array transforms cannot be composed
because several transforms receive multiple parameters or return multiple values.
Then Users need - to define their own callable method to parse metadata from `LoadNifti` or set `affine` matrix + to define their own callable method to parse metadata from `LoadImage` or set `affine` matrix to `Spacing` transform:: class TestCompose(Compose): @@ -846,7 +851,7 @@ def __call__(self, input_): return self.transforms[3](img), metadata img_transform = TestCompose( [ - LoadNifti(image_only=False), + LoadImage(image_only=False), AddChannel(), Spacing(pixdim=(1.5, 1.5, 3.0)), RandAdjustContrast() diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 96f4fe5324..6167e83e47 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -39,13 +39,12 @@ def _compute_path(base_dir, element): """ if isinstance(element, str): return os.path.normpath(os.path.join(base_dir, element)) - elif isinstance(element, list): + if isinstance(element, list): for e in element: if not isinstance(e, str): raise TypeError(f"Every file path in element must be a str but got {type(element).__name__}.") return [os.path.normpath(os.path.join(base_dir, e)) for e in element] - else: - raise TypeError(f"element must be one of (str, list) but is {type(element).__name__}.") + raise TypeError(f"element must be one of (str, list) but is {type(element).__name__}.") def _append_paths(base_dir: str, is_segmentation: bool, items: List[Dict]) -> List[Dict]: @@ -136,7 +135,7 @@ def load_decathlon_properties( with open(data_property_file_path) as json_file: json_data = json.load(json_file) - properties = dict() + properties = {} for key in ensure_tuple(property_keys): if key not in json_data: raise KeyError(f"key {key} is not in the data property file.") diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 63dcc069a2..f85569d88a 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/data/nifti_reader.py b/monai/data/image_dataset.py similarity index 69% rename from monai/data/nifti_reader.py rename to monai/data/image_dataset.py index 3c37448957..7dd55431af 100644 --- a/monai/data/nifti_reader.py +++ b/monai/data/image_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,19 +9,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
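Before the `ImageDataset` rework that follows, a brief aside tying together the `ArrayDataset` docstring from the hunks above. A minimal usage sketch (the file names are illustrative and not part of this diff):

.. code-block:: python

    from monai.data import ArrayDataset
    from monai.transforms import AddChannel, Compose, LoadImage, RandAdjustContrast

    images = ["img0.nii.gz", "img1.nii.gz"]  # hypothetical paths
    segs = ["seg0.nii.gz", "seg1.nii.gz"]

    img_transform = Compose([LoadImage(image_only=True), AddChannel(), RandAdjustContrast()])
    seg_transform = Compose([LoadImage(image_only=True), AddChannel()])

    ds = ArrayDataset(images, img_transform, segs, seg_transform)
    img, seg = ds[0]  # the image and seg pipelines are seeded identically per item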
-from typing import Any, Callable, Optional, Sequence
+from typing import Any, Callable, Optional, Sequence, Union
import numpy as np
from torch.utils.data import Dataset
-from monai.transforms import LoadNifti, Randomizable, apply_transform
+from monai.data.image_reader import ImageReader
+from monai.transforms import LoadImage, Randomizable, apply_transform
from monai.utils import MAX_SEED, get_seed
-class NiftiDataset(Dataset, Randomizable):
+class ImageDataset(Dataset, Randomizable):
"""
- Loads image/segmentation pairs of Nifti files from the given filename lists. Transformations can be specified
+ Loads image/segmentation pairs of files from the given filename lists. Transformations can be specified
for the image and segmentation arrays separately.
+ The difference between this dataset and `ArrayDataset` is that this dataset can apply a transform chain to the
+ images and segmentations, return both the images and the metadata, and needs no extra transform to load the
+ images from files.
+
"""
def __init__(
@@ -29,11 +33,13 @@ def __init__(
image_files: Sequence[str],
seg_files: Optional[Sequence[str]] = None,
labels: Optional[Sequence[float]] = None,
- as_closest_canonical: bool = False,
transform: Optional[Callable] = None,
seg_transform: Optional[Callable] = None,
image_only: bool = True,
dtype: Optional[np.dtype] = np.float32,
+ reader: Optional[Union[ImageReader, str]] = None,
+ *args,
+ **kwargs,
) -> None:
"""
Initializes the dataset with the image and segmentation filename lists. The transform `transform` is applied
@@ -43,14 +49,18 @@
image_files: list of image filenames
seg_files: if in segmentation task, list of segmentation filenames
labels: if in classification task, list of classification labels
- as_closest_canonical: if True, load the image as closest to canonical orientation
transform: transform to apply to image arrays
seg_transform: transform to apply to segmentation arrays
- image_only: if True return only the image volume, other return image volume and header dict
+ image_only: if True, return only the image volume; otherwise, return the image volume and the metadata
dtype: if not None convert the loaded image to this data type
+ reader: reader to load the image files and metadata; if None, the default readers will be used.
+ If a reader name string is provided, a reader object will be constructed with the `*args` and `**kwargs`
+ parameters; supported reader names: "NibabelReader", "PILReader", "ITKReader", "NumpyReader"
+ args: additional parameters for the reader if a reader name is provided
+ kwargs: additional parameters for the reader if a reader name is provided
Raises:
- ValueError: When ``seg_files`` length differs from ``image_files``.
+ ValueError: When ``seg_files`` length differs from ``image_files`` """ @@ -63,13 +73,11 @@ def __init__( self.image_files = image_files self.seg_files = seg_files self.labels = labels - self.as_closest_canonical = as_closest_canonical self.transform = transform self.seg_transform = seg_transform self.image_only = image_only - self.dtype = dtype + self.loader = LoadImage(reader, image_only, dtype, *args, **kwargs) self.set_random_state(seed=get_seed()) - self._seed = 0 # transform synchronization seed def __len__(self) -> int: @@ -81,18 +89,18 @@ def randomize(self, data: Optional[Any] = None) -> None: def __getitem__(self, index: int): self.randomize() meta_data = None - img_loader = LoadNifti( - as_closest_canonical=self.as_closest_canonical, image_only=self.image_only, dtype=self.dtype - ) - if self.image_only: - img = img_loader(self.image_files[index]) - else: - img, meta_data = img_loader(self.image_files[index]) seg = None - if self.seg_files is not None: - seg_loader = LoadNifti(image_only=True) - seg = seg_loader(self.seg_files[index]) label = None + + if self.image_only: + img = self.loader(self.image_files[index]) + if self.seg_files is not None: + seg = self.loader(self.seg_files[index]) + else: + img, meta_data = self.loader(self.image_files[index]) + if self.seg_files is not None: + seg, _ = self.loader(self.seg_files[index]) + if self.labels is not None: label = self.labels[index] diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 32d03115ed..0fd784af05 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -154,7 +154,7 @@ def read(self, data: Union[Sequence[str], str], **kwargs): https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itkExtras.py """ - img_: List[Image] = list() + img_: List[Image] = [] filenames: Sequence[str] = ensure_tuple(data) kwargs_ = self.kwargs.copy() @@ -191,7 +191,7 @@ def get_data(self, img): img: a ITK image object loaded from a image file or a list of ITK image objects. """ - img_array: List[np.ndarray] = list() + img_array: List[np.ndarray] = [] compatible_meta: Dict = {} for i in ensure_tuple(img): @@ -325,7 +325,7 @@ def read(self, data: Union[Sequence[str], str], **kwargs): https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py """ - img_: List[Nifti1Image] = list() + img_: List[Nifti1Image] = [] filenames: Sequence[str] = ensure_tuple(data) kwargs_ = self.kwargs.copy() @@ -348,7 +348,7 @@ def get_data(self, img): img: a Nibabel image object loaded from a image file or a list of Nibabel image objects. """ - img_array: List[np.ndarray] = list() + img_array: List[np.ndarray] = [] compatible_meta: Dict = {} for i in ensure_tuple(img): @@ -456,7 +456,7 @@ def read(self, data: Union[Sequence[str], str], **kwargs): https://numpy.org/doc/stable/reference/generated/numpy.load.html """ - img_: List[Nifti1Image] = list() + img_: List[Nifti1Image] = [] filenames: Sequence[str] = ensure_tuple(data) kwargs_ = self.kwargs.copy() @@ -485,13 +485,13 @@ def get_data(self, img): img: a Numpy array loaded from a file or a list of Numpy arrays. 
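Stepping back from the reader internals: the `ImageDataset` refactor above swaps the hard-coded `LoadNifti` for a single configurable `LoadImage` loader. A sketch of the new `reader` parameter (paths are illustrative; with these settings each item is expected to be an `(image, seg, metadata)` tuple):

.. code-block:: python

    from monai.data import ImageDataset
    from monai.transforms import ScaleIntensity

    ds = ImageDataset(
        image_files=["img0.nii.gz", "img1.nii.gz"],  # hypothetical paths
        seg_files=["seg0.nii.gz", "seg1.nii.gz"],
        transform=ScaleIntensity(),
        image_only=False,        # also return the image metadata
        reader="NibabelReader",  # a reader object is constructed from the name string
    )
    img, seg, meta = ds[0]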
""" - img_array: List[np.ndarray] = list() + img_array: List[np.ndarray] = [] compatible_meta: Dict = {} if isinstance(img, np.ndarray): img = (img,) for i in ensure_tuple(img): - header = dict() + header = {} if isinstance(i, np.ndarray): header["spatial_shape"] = i.shape img_array.append(i) @@ -525,7 +525,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: filename: file name or a list of file names to read. if a list of files, verify all the suffixes. """ - suffixes: Sequence[str] = ["png", "jpg", "bmp"] + suffixes: Sequence[str] = ["png", "jpg", "jpeg", "bmp"] return has_pil and is_supported_format(filename, suffixes) def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): @@ -540,7 +540,7 @@ def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open """ - img_: List[PILImage.Image] = list() + img_: List[PILImage.Image] = [] filenames: Sequence[str] = ensure_tuple(data) kwargs_ = self.kwargs.copy() @@ -565,7 +565,7 @@ def get_data(self, img): img: a PIL Image object loaded from a file or a list of PIL Image objects. """ - img_array: List[np.ndarray] = list() + img_array: List[np.ndarray] = [] compatible_meta: Dict = {} for i in ensure_tuple(img): diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py index c8ee006b12..7f0a0986dd 100644 --- a/monai/data/iterable_dataset.py +++ b/monai/data/iterable_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index 6575d31251..f4781f82fd 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -37,6 +37,7 @@ def __init__( padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, dtype: Optional[np.dtype] = np.float64, + output_dtype: Optional[np.dtype] = np.float32, ) -> None: """ Args: @@ -57,6 +58,7 @@ def __init__( dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. + output_dtype: data type for saving data. Defaults to ``np.float32``. 
""" self.output_dir = output_dir self.output_postfix = output_postfix @@ -66,6 +68,7 @@ def __init__( self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) self.align_corners = align_corners self.dtype = dtype + self.output_dtype = output_dtype self._data_index = 0 def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: @@ -118,6 +121,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] padding_mode=self.padding_mode, align_corners=self.align_corners, dtype=self.dtype, + output_dtype=self.output_dtype, ) def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index 120bba0e4d..6837ebeb90 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -89,7 +89,8 @@ def write_nifti( the output data type is always ``np.float32``. output_dtype: data type for saving data. Defaults to ``np.float32``. """ - assert isinstance(data, np.ndarray), "input data must be numpy array." + if not isinstance(data, np.ndarray): + raise AssertionError("input data must be numpy array.") dtype = dtype or data.dtype sr = min(data.ndim, 3) if affine is None: diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index dea7387bd4..450e327d6b 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index d35a530a86..d7baa6ea79 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -47,7 +47,8 @@ def write_png( ValueError: When ``scale`` is not one of [255, 65535]. """ - assert isinstance(data, np.ndarray), "input data must be numpy array." + if not isinstance(data, np.ndarray): + raise AssertionError("input data must be numpy array.") if len(data.shape) == 3 and data.shape[2] == 1: # PIL Image can't save image with 1 channel data = data.squeeze(2) if output_spatial_shape is not None: diff --git a/monai/data/synthetic.py b/monai/data/synthetic.py index cdbb660566..90cbe13c2d 100644 --- a/monai/data/synthetic.py +++ b/monai/data/synthetic.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -68,7 +68,8 @@ def create_test_image_2d( noisyimage = rescale_array(np.maximum(image, norm)) if channel_dim is not None: - assert isinstance(channel_dim, int) and channel_dim in (-1, 0, 2), "invalid channel dim." 
+ if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 2)): + raise AssertionError("invalid channel dim.") if channel_dim == 0: noisyimage = noisyimage[None] labels = labels[None] @@ -131,7 +132,8 @@ def create_test_image_3d( noisyimage = rescale_array(np.maximum(image, norm)) if channel_dim is not None: - assert isinstance(channel_dim, int) and channel_dim in (-1, 0, 3), "invalid channel dim." + if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 3)): + raise AssertionError("invalid channel dim.") noisyimage, labels = ( (noisyimage[None], labels[None]) if channel_dim == 0 else (noisyimage[..., None], labels[..., None]) ) diff --git a/monai/data/thread_buffer.py b/monai/data/thread_buffer.py index 9832a7c164..252fdd6a21 100644 --- a/monai/data/thread_buffer.py +++ b/monai/data/thread_buffer.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/data/utils.py b/monai/data/utils.py index c5fcbf3c86..ca8f3b1017 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -39,6 +39,32 @@ nib, _ = optional_import("nibabel") +__all__ = [ + "get_random_patch", + "iter_patch_slices", + "dense_patch_slices", + "iter_patch", + "get_valid_patch_size", + "list_data_collate", + "worker_init_fn", + "set_rnd", + "correct_nifti_header_if_necessary", + "rectify_header_sform_qform", + "zoom_affine", + "compute_shape_offset", + "to_affine_nd", + "create_file_basename", + "compute_importance_map", + "is_supported_format", + "partition_dataset", + "partition_dataset_classes", + "select_cross_validation_folds", + "DistributedSampler", + "json_hashing", + "pickle_hashing", + "sorted_dict", +] + def get_random_patch( dims: Sequence[int], patch_size: Sequence[int], rand_state: Optional[np.random.RandomState] = None @@ -686,7 +712,7 @@ def partition_dataset_classes( for i, c in enumerate(classes): class_indices[c].append(i) - class_partition_indices: List[Sequence] = list() + class_partition_indices: List[Sequence] = [] for _, per_class_indices in sorted(class_indices.items()): per_class_partition_indices = partition_dataset( data=per_class_indices, diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index 13835f915b..8256680735 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
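Before the engines exports below, a note on the `partition_dataset_classes` hunk in `monai/data/utils.py` above: it builds per-class index lists and partitions each class separately, which yields a stratified split. A toy sketch:

.. code-block:: python

    from monai.data import partition_dataset_classes

    data = list(range(10))
    classes = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    parts = partition_dataset_classes(data=data, classes=classes, ratios=[2, 1], shuffle=False)
    # parts[0] receives roughly 2/3 of each class and parts[1] the remaining 1/3,
    # so both partitions keep the original 4:6 class balance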
-from .evaluator import * -from .multi_gpu_supervised_trainer import * -from .trainer import * +from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator +from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer +from .trainer import GanTrainer, SupervisedTrainer, Trainer +from .utils import CommonKeys, GanKeys, IterationEvents, default_make_latent, default_prepare_batch, get_devices_spec diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 306be5f2db..0b7167fb3a 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,9 +15,10 @@ from torch.utils.data import DataLoader from monai.engines.utils import CommonKeys as Keys -from monai.engines.utils import default_prepare_batch +from monai.engines.utils import IterationEvents, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer +from monai.networks.utils import eval_mode from monai.transforms import Transform from monai.utils import ensure_tuple, exact_version, optional_import @@ -163,6 +164,10 @@ def __init__( self.network = network self.inferer = SimpleInferer() if inferer is None else inferer + def _register_additional_events(self): + super()._register_additional_events() + self.register_events(*IterationEvents) + def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. @@ -184,21 +189,23 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) if len(batch) == 2: inputs, targets = batch - args: Tuple = tuple() - kwargs: Dict = dict() + args: Tuple = () + kwargs: Dict = {} else: inputs, targets, args, kwargs = batch + # put iteration outputs into engine.state + engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # execute forward computation - self.network.eval() - with torch.no_grad(): + with eval_mode(self.network): if self.amp: with torch.cuda.amp.autocast(): - predictions = self.inferer(inputs, self.network, *args, **kwargs) + output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) else: - predictions = self.inferer(inputs, self.network, *args, **kwargs) + output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) + engine.fire_event(IterationEvents.FORWARD_COMPLETED) - return {Keys.IMAGE: inputs, Keys.LABEL: targets, Keys.PRED: predictions} + return output class EnsembleEvaluator(Evaluator): @@ -266,6 +273,10 @@ def __init__( self.pred_keys = ensure_tuple(pred_keys) self.inferer = SimpleInferer() if inferer is None else inferer + def _register_additional_events(self): + super()._register_additional_events() + self.register_events(*IterationEvents) + def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. 
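Because `_iteration` now publishes to `engine.state.output` before firing `IterationEvents.FORWARD_COMPLETED`, a handler can inspect predictions mid-iteration. A minimal sketch (assumes `evaluator` is an already constructed `SupervisedEvaluator`):

.. code-block:: python

    from monai.engines import IterationEvents

    @evaluator.on(IterationEvents.FORWARD_COMPLETED)
    def log_pred_shape(engine):
        # "image", "label" and "pred" are already present in the output dict here
        print(engine.state.output["pred"].shape)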
@@ -290,20 +301,20 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) if len(batch) == 2: inputs, targets = batch - args: Tuple = tuple() - kwargs: Dict = dict() + args: Tuple = () + kwargs: Dict = {} else: inputs, targets, args, kwargs = batch - # execute forward computation - predictions = {Keys.IMAGE: inputs, Keys.LABEL: targets} + # put iteration outputs into engine.state + engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets} for idx, network in enumerate(self.networks): - network.eval() - with torch.no_grad(): + with eval_mode(network): if self.amp: with torch.cuda.amp.autocast(): - predictions.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}) + output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}) else: - predictions.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}) + output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}) + engine.fire_event(IterationEvents.FORWARD_COMPLETED) - return predictions + return output diff --git a/monai/engines/multi_gpu_supervised_trainer.py b/monai/engines/multi_gpu_supervised_trainer.py index 33268308e5..d12e012a56 100644 --- a/monai/engines/multi_gpu_supervised_trainer.py +++ b/monai/engines/multi_gpu_supervised_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 64b38e2646..efb2ab12fa 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from monai.engines.utils import CommonKeys as Keys -from monai.engines.utils import GanKeys, default_make_latent, default_prepare_batch +from monai.engines.utils import GanKeys, IterationEvents, default_make_latent, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform @@ -121,6 +121,10 @@ def __init__( self.loss_function = loss_function self.inferer = SimpleInferer() if inferer is None else inferer + def _register_additional_events(self): + super()._register_additional_events() + self.register_events(*IterationEvents) + def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine. 
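The hunk that follows wires the same event mechanism into training, firing `LOSS_COMPLETED`, `BACKWARD_COMPLETED` and `OPTIMIZER_COMPLETED` around the loss/backward/step calls. A sketch of hooking the loss as soon as it is computed (assumes `trainer` is an already constructed `SupervisedTrainer`):

.. code-block:: python

    from monai.engines import IterationEvents

    @trainer.on(IterationEvents.LOSS_COMPLETED)
    def log_loss(engine):
        # the loss tensor is stored in the output dict before backward() runs
        print(float(engine.state.output["loss"]))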
@@ -143,27 +147,36 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
if len(batch) == 2:
inputs, targets = batch
- args: Tuple = tuple()
- kwargs: Dict = dict()
+ args: Tuple = ()
+ kwargs: Dict = {}
else:
inputs, targets, args, kwargs = batch
+ # put iteration outputs into engine.state
+ engine.state.output = output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
+
+ def _compute_pred_loss():
+ output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
+ engine.fire_event(IterationEvents.FORWARD_COMPLETED)
+ output[Keys.LOSS] = self.loss_function(output[Keys.PRED], targets).mean()
+ engine.fire_event(IterationEvents.LOSS_COMPLETED)
self.network.train()
self.optimizer.zero_grad()
if self.amp and self.scaler is not None:
with torch.cuda.amp.autocast():
- predictions = self.inferer(inputs, self.network, *args, **kwargs)
- loss = self.loss_function(predictions, targets).mean()
- self.scaler.scale(loss).backward()
+ _compute_pred_loss()
+ self.scaler.scale(output[Keys.LOSS]).backward()
+ engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
- predictions = self.inferer(inputs, self.network, *args, **kwargs)
- loss = self.loss_function(predictions, targets).mean()
- loss.backward()
+ _compute_pred_loss()
+ output[Keys.LOSS].backward()
+ engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
self.optimizer.step()
+ engine.fire_event(IterationEvents.OPTIMIZER_COMPLETED)
- return {Keys.IMAGE: inputs, Keys.LABEL: targets, Keys.PRED: predictions, Keys.LOSS: loss.item()}
+ return output
class GanTrainer(Trainer):
@@ -282,7 +295,7 @@ def _iteration(
if batchdata is None:
raise ValueError("must provide batch data for current iteration.")
- d_input = self.prepare_batch(batchdata, engine.state.device)
+ d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
batch_size = self.data_loader.batch_size
g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking)
g_output = self.g_inferer(g_input, self.g_network)
diff --git a/monai/engines/utils.py b/monai/engines/utils.py
index 74d08ce41f..f603338097 100644
--- a/monai/engines/utils.py
+++ b/monai/engines/utils.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -9,10 +9,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
import torch
+from monai.utils import exact_version, optional_import
+
+if TYPE_CHECKING:
+ from ignite.engine import EventEnum
+else:
+ EventEnum, _ = optional_import("ignite.engine", "0.4.2", exact_version, "EventEnum")
+
+__all__ = [
+ "IterationEvents",
+ "CommonKeys",
+ "GanKeys",
+ "get_devices_spec",
+ "default_prepare_batch",
+ "default_make_latent",
+]
+
+
+class IterationEvents(EventEnum):
+ """
+ Additional Events that the engine can register and trigger during the iteration process.
+ Refer to the example in ignite: https://github.com/pytorch/ignite/blob/master/ignite/engine/events.py#L146
+ These Events can be triggered during a training iteration:
+ `FORWARD_COMPLETED` is the Event when `network(image, label)` completed.
+ `LOSS_COMPLETED` is the Event when `loss(pred, label)` completed.
+ `BACKWARD_COMPLETED` is the Event when `loss.backward()` completed.
+ `OPTIMIZER_COMPLETED` is the Event when `optimizer.step()` completed.
+
+ """
+
+ FORWARD_COMPLETED = "forward_completed"
+ LOSS_COMPLETED = "loss_completed"
+ BACKWARD_COMPLETED = "backward_completed"
+ OPTIMIZER_COMPLETED = "optimizer_completed"
+
class CommonKeys:
"""
@@ -34,6 +67,7 @@ class GanKeys:
"""
A set of common keys for generative adversarial networks.
+
"""
REALS = "reals"
@@ -87,16 +121,16 @@ def default_prepare_batch(
image, label(optional).
"""
- assert isinstance(batchdata, dict), "default prepare_batch expects dictionary input data."
+ if not isinstance(batchdata, dict):
+ raise AssertionError("default prepare_batch expects dictionary input data.")
if CommonKeys.LABEL in batchdata:
return (
batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking),
batchdata[CommonKeys.LABEL].to(device=device, non_blocking=non_blocking),
)
- elif GanKeys.REALS in batchdata:
+ if GanKeys.REALS in batchdata:
return batchdata[GanKeys.REALS].to(device=device, non_blocking=non_blocking)
- else:
- return batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking), None
+ return batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking), None
def default_make_latent(
diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py
index 076ed2289b..d6415c1966 100644
--- a/monai/engines/workflow.py
+++ b/monai/engines/workflow.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -110,6 +110,7 @@ def set_sampler_epoch(engine: Engine):
output=None,
batch=None,
metrics={},
+ metric_details={},
dataloader=None,
device=device,
key_metric_name=None, # we can set many metrics, only use key_metric to compare and save the best model
@@ -119,43 +120,68 @@
self.data_loader = data_loader
self.non_blocking = non_blocking
self.prepare_batch = prepare_batch
+ self.amp = amp
+
self._register_additional_events()
if post_transform is not None:
+ self._register_post_transforms(post_transform)
+ if key_metric is not None:
+ self._register_metrics(key_metric, additional_metrics)
+ if handlers is not None:
+ self._register_handlers(handlers)
- @self.on(Events.ITERATION_COMPLETED)
- def run_post_transform(engine: Engine) -> None:
- assert post_transform is not None
- engine.state.output = apply_transform(post_transform, engine.state.output)
+ def _register_additional_events(self):
+ """
+ Register more ignite Events to the engine.
- if key_metric is not None:
+
+ """
+ pass
-
- if not isinstance(key_metric, dict):
- raise TypeError(f"key_metric must be None or a dict but is {type(key_metric).__name__}.")
- self.state.key_metric_name = list(key_metric.keys())[0]
- metrics = key_metric
- if additional_metrics is not None and len(additional_metrics) > 0:
- if not isinstance(additional_metrics, dict):
- raise TypeError(
- f"additional_metrics must be None or a dict but is {type(additional_metrics).__name__}."
- )
- metrics.update(additional_metrics)
- for name, metric in metrics.items():
- metric.attach(self, name)
-
- @self.on(Events.EPOCH_COMPLETED)
- def _compare_metrics(engine: Engine) -> None:
- if engine.state.key_metric_name is not None:
- current_val_metric = engine.state.metrics[engine.state.key_metric_name]
- if current_val_metric > engine.state.best_metric:
- self.logger.info(f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}")
- engine.state.best_metric = current_val_metric
- engine.state.best_metric_epoch = engine.state.epoch
+ def _register_post_transforms(self, posttrans):
+ """
+ Register the post transforms to the engine; they will be executed as a chain when an iteration completes.
- if handlers is not None:
- handlers_ = ensure_tuple(handlers)
- for handler in handlers_:
- handler.attach(self)
- self.amp = amp
+
+ """
+
+ @self.on(Events.ITERATION_COMPLETED)
+ def run_post_transform(engine: Engine) -> None:
+ if posttrans is None:
+ raise AssertionError
+ engine.state.output = apply_transform(posttrans, engine.state.output)
+
+ def _register_metrics(self, k_metric, add_metrics):
+ """
+ Register the key metric and additional metrics to the engine; supports ignite Metrics.
+
+ """
+ if not isinstance(k_metric, dict):
+ raise TypeError(f"key_metric must be None or a dict but is {type(k_metric).__name__}.")
+ self.state.key_metric_name = list(k_metric.keys())[0]
+ metrics = k_metric
+ if add_metrics is not None and len(add_metrics) > 0:
+ if not isinstance(add_metrics, dict):
+ raise TypeError(f"additional metrics must be None or a dict but is {type(add_metrics).__name__}.")
+ metrics.update(add_metrics)
+ for name, metric in metrics.items():
+ metric.attach(self, name)
+
+ @self.on(Events.EPOCH_COMPLETED)
+ def _compare_metrics(engine: Engine) -> None:
+ if engine.state.key_metric_name is not None:
+ current_val_metric = engine.state.metrics[engine.state.key_metric_name]
+ if current_val_metric > engine.state.best_metric:
+ self.logger.info(f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}")
+ engine.state.best_metric = current_val_metric
+ engine.state.best_metric_epoch = engine.state.epoch
+
+ def _register_handlers(self, handlers):
+ """
+ Register the handlers to the engine; supports ignite Handlers with the `attach` API.
+
+ """
+ handlers_ = ensure_tuple(handlers)
+ for handler in handlers_:
+ handler.attach(self)
def run(self) -> None:
"""
diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py
index 37715cad52..6b190518fb 100644
--- a/monai/handlers/__init__.py
+++ b/monai/handlers/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at @@ -14,14 +14,16 @@ from .classification_saver import ClassificationSaver from .confusion_matrix import ConfusionMatrix from .hausdorff_distance import HausdorffDistance +from .iteration_metric import IterationMetric from .lr_schedule_handler import LrScheduleHandler from .mean_dice import MeanDice from .metric_logger import MetricLogger +from .metrics_saver import MetricsSaver from .roc_auc import ROCAUC from .segmentation_saver import SegmentationSaver from .smartcache_handler import SmartCacheHandler from .stats_handler import StatsHandler from .surface_distance import SurfaceDistance from .tensorboard_handlers import TensorBoardImageHandler, TensorBoardStatsHandler -from .utils import * +from .utils import evenly_divisible_all_gather, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports from .validation_handler import ValidationHandler diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index 5e8fe741be..648cc8360a 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -54,9 +54,11 @@ def __init__( name: Optional[str] = None, map_location: Optional[Dict] = None, ) -> None: - assert load_path is not None, "must provide clear path to load checkpoint." + if load_path is None: + raise AssertionError("must provide clear path to load checkpoint.") self.load_path = load_path - assert load_dict is not None and len(load_dict) > 0, "must provide target objects to load." + if not (load_dict is not None and len(load_dict) > 0): + raise AssertionError("must provide target objects to load.") self.logger = logging.getLogger(name) self.load_dict = load_dict self._name = name diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index 0cc05b2dc4..8052e21cb6 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -88,9 +88,11 @@ def __init__( save_interval: int = 0, n_saved: Optional[int] = None, ) -> None: - assert save_dir is not None, "must provide directory to save the checkpoints." + if save_dir is None: + raise AssertionError("must provide directory to save the checkpoints.") self.save_dir = save_dir - assert save_dict is not None and len(save_dict) > 0, "must provide source objects to save." + if not (save_dict is not None and len(save_dict) > 0): + raise AssertionError("must provide source objects to save.") self.save_dict = save_dict self.logger = logging.getLogger(name) self.epoch_level = epoch_level @@ -202,12 +204,15 @@ def completed(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - assert callable(self._final_checkpoint), "Error: _final_checkpoint function not specified." 
+ if not callable(self._final_checkpoint): + raise AssertionError("Error: _final_checkpoint function not specified.") # delete previous saved final checkpoint if existing self._delete_previous_final_ckpt() self._final_checkpoint(engine) - assert self.logger is not None - assert hasattr(self.logger, "info"), "Error, provided logger has not info attribute." + if self.logger is None: + raise AssertionError + if not hasattr(self.logger, "info"): + raise AssertionError("Error, provided logger has not info attribute.") self.logger.info(f"Train completed, saved final checkpoint: {self._final_checkpoint.last_checkpoint}") def exception_raised(self, engine: Engine, e: Exception) -> None: @@ -219,12 +224,15 @@ def exception_raised(self, engine: Engine, e: Exception) -> None: engine: Ignite Engine, it can be a trainer, validator or evaluator. e: the exception caught in Ignite during engine.run(). """ - assert callable(self._final_checkpoint), "Error: _final_checkpoint function not specified." + if not callable(self._final_checkpoint): + raise AssertionError("Error: _final_checkpoint function not specified.") # delete previous saved final checkpoint if existing self._delete_previous_final_ckpt() self._final_checkpoint(engine) - assert self.logger is not None - assert hasattr(self.logger, "info"), "Error, provided logger has not info attribute." + if self.logger is None: + raise AssertionError + if not hasattr(self.logger, "info"): + raise AssertionError("Error, provided logger has not info attribute.") self.logger.info(f"Exception_raised, saved exception checkpoint: {self._final_checkpoint.last_checkpoint}") raise e @@ -234,7 +242,8 @@ def metrics_completed(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - assert callable(self._key_metric_checkpoint), "Error: _key_metric_checkpoint function not specified." + if not callable(self._key_metric_checkpoint): + raise AssertionError("Error: _key_metric_checkpoint function not specified.") self._key_metric_checkpoint(engine) def interval_completed(self, engine: Engine) -> None: @@ -244,10 +253,13 @@ def interval_completed(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - assert callable(self._interval_checkpoint), "Error: _interval_checkpoint function not specified." + if not callable(self._interval_checkpoint): + raise AssertionError("Error: _interval_checkpoint function not specified.") self._interval_checkpoint(engine) - assert self.logger is not None - assert hasattr(self.logger, "info"), "Error, provided logger has not info attribute." + if self.logger is None: + raise AssertionError + if not hasattr(self.logger, "info"): + raise AssertionError("Error, provided logger has not info attribute.") if self.epoch_level: self.logger.info(f"Saved checkpoint at epoch: {engine.state.epoch}") else: diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index e446773144..6753cafcb0 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/monai/handlers/confusion_matrix.py b/monai/handlers/confusion_matrix.py index 7bb68a25fd..1741aa305a 100644 --- a/monai/handlers/confusion_matrix.py +++ b/monai/handlers/confusion_matrix.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,21 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Sequence +from typing import Any, Callable, Optional import torch +from monai.handlers.iteration_metric import IterationMetric from monai.metrics import ConfusionMatrixMetric, compute_confusion_matrix_metric from monai.metrics.utils import MetricReduction, do_metric_reduction -from monai.utils import exact_version, optional_import -NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") -Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") -reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") -sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce") - -class ConfusionMatrix(Metric): # type: ignore[valid-type, misc] # due to optional_import +class ConfusionMatrix(IterationMetric): """ Compute confusion matrix related metrics from full size Tensor and collects average over batch, class-channels, iterations. """ @@ -32,9 +27,9 @@ def __init__( self, include_background: bool = True, metric_name: str = "hit_rate", - compute_sample: bool = False, output_transform: Callable = lambda x: x, device: Optional[torch.device] = None, + save_details: bool = True, ) -> None: """ @@ -48,80 +43,28 @@ def __init__( ``"informedness"``, ``"markedness"``] Some of the metrics have multiple aliases (as shown in the wikipedia page aforementioned), and you can also input those names instead. - compute_sample: if ``True``, each sample's metric will be computed first. - If ``False``, the confusion matrix for all samples will be accumulated first. Defaults to ``False``. output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. device: device specification in case of distributed computation usage. + save_details: whether to save metric computation details per image, for example: TP/TN/FP/FN of every image. + default to True, will save to `engine.state.metric_details` dict with the metric name as key. 
See also: :py:meth:`monai.metrics.confusion_matrix` """ - super().__init__(output_transform, device=device) - self.confusion_matrix = ConfusionMatrixMetric( + metric_fn = ConfusionMatrixMetric( include_background=include_background, metric_name=metric_name, - compute_sample=compute_sample, - reduction=MetricReduction.MEAN, + compute_sample=False, + reduction=MetricReduction.NONE, ) - self._sum = 0.0 - self._num_examples = 0 - self.compute_sample = compute_sample self.metric_name = metric_name - self._total_tp = 0.0 - self._total_fp = 0.0 - self._total_tn = 0.0 - self._total_fn = 0.0 - - @reinit__is_reduced - def reset(self) -> None: - self._sum = 0.0 - self._num_examples = 0 - self._total_tp = 0.0 - self._total_fp = 0.0 - self._total_tn = 0.0 - self._total_fn = 0.0 - - @reinit__is_reduced - def update(self, output: Sequence[torch.Tensor]) -> None: - """ - Args: - output: sequence with contents [y_pred, y]. - - Raises: - ValueError: When ``output`` length is not 2. This metric can only support y_pred and y. - - """ - if len(output) != 2: - raise ValueError(f"output must have length 2, got {len(output)}.") - y_pred, y = output - if self.compute_sample is True: - score, not_nans = self.confusion_matrix(y_pred, y) - not_nans = int(not_nans.item()) - - # add all items in current batch - self._sum += score.item() * not_nans - self._num_examples += not_nans - else: - confusion_matrix = self.confusion_matrix(y_pred, y) - confusion_matrix, _ = do_metric_reduction(confusion_matrix, MetricReduction.SUM) - self._total_tp += confusion_matrix[0].item() - self._total_fp += confusion_matrix[1].item() - self._total_tn += confusion_matrix[2].item() - self._total_fn += confusion_matrix[3].item() - - @sync_all_reduce("_sum", "_num_examples", "_total_tp", "_total_fp", "_total_tn", "_total_fn") - def compute(self): - """ - Raises: - NotComputableError: When ``compute`` is called before an ``update`` occurs. + super().__init__( + metric_fn=metric_fn, + output_transform=output_transform, + device=device, + save_details=save_details, + ) - """ - if self.compute_sample is True: - if self._num_examples == 0: - raise NotComputableError( - "ConfusionMatrix metric must have at least one example before it can be computed." - ) - return self._sum / self._num_examples - else: - confusion_matrix = torch.tensor([self._total_tp, self._total_fp, self._total_tn, self._total_fn]) - return compute_confusion_matrix_metric(self.metric_name, confusion_matrix) + def _reduce(self, scores) -> Any: + confusion_matrix, _ = do_metric_reduction(scores, MetricReduction.MEAN) + return compute_confusion_matrix_metric(self.metric_name, confusion_matrix) diff --git a/monai/handlers/hausdorff_distance.py b/monai/handlers/hausdorff_distance.py index 56b8b341ff..7ac52d642a 100644 --- a/monai/handlers/hausdorff_distance.py +++ b/monai/handlers/hausdorff_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,20 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
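With the handler now built on `IterationMetric` and `reduction=NONE`, per-image values survive to the end of the run when `save_details=True`. A sketch of registering it as the key metric and reading the details afterwards (`device`, `val_loader` and `net` are assumed to exist):

.. code-block:: python

    from monai.engines import SupervisedEvaluator
    from monai.handlers import ConfusionMatrix

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        key_val_metric={
            "val_hit_rate": ConfusionMatrix(
                metric_name="hit_rate",
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
    )
    evaluator.run()
    print(evaluator.state.metrics["val_hit_rate"])         # reduced value
    print(evaluator.state.metric_details["val_hit_rate"])  # raw per-image scores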
-from typing import Callable, Optional, Sequence +from typing import Callable, Optional import torch +from monai.handlers.iteration_metric import IterationMetric from monai.metrics import HausdorffDistanceMetric -from monai.utils import MetricReduction, exact_version, optional_import +from monai.utils import MetricReduction -NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") -Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") -reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") -sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce") - -class HausdorffDistance(Metric): # type: ignore[valid-type, misc] # due to optional_import +class HausdorffDistance(IterationMetric): """ Computes the Hausdorff distance from a full-size Tensor and collects the average over batch, class-channels, iterations. """ @@ -35,6 +31,7 @@ def __init__( directed: bool = False, output_transform: Callable = lambda x: x, device: Optional[torch.device] = None, + save_details: bool = True, ) -> None: """ @@ -49,51 +46,21 @@ directed: whether to calculate directed Hausdorff distance. Defaults to ``False``. output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. device: device specification in case of distributed computation usage. + save_details: whether to save metric computation details per image, for example: the Hausdorff distance + of every image. Defaults to True; the details are saved to the `engine.state.metric_details` dict with the metric name as key. """ - super().__init__(output_transform, device=device) - self.hd = HausdorffDistanceMetric( + metric_fn = HausdorffDistanceMetric( include_background=include_background, distance_metric=distance_metric, percentile=percentile, directed=directed, - reduction=MetricReduction.MEAN, + reduction=MetricReduction.NONE, + ) + super().__init__( + metric_fn=metric_fn, + output_transform=output_transform, + device=device, + save_details=save_details, ) - self._sum = 0.0 - self._num_examples = 0 - - @reinit__is_reduced - def reset(self) -> None: - self._sum = 0.0 - self._num_examples = 0 - - @reinit__is_reduced - def update(self, output: Sequence[torch.Tensor]) -> None: - """ - Args: - output: sequence with contents [y_pred, y]. - - Raises: - ValueError: When ``output`` length is not 2. The metric can only support y_pred and y. - - """ - if len(output) != 2: - raise ValueError(f"output must have length 2, got {len(output)}.") - y_pred, y = output - score, not_nans = self.hd(y_pred, y) - not_nans = int(not_nans.item()) - - # add all items in current batch - self._sum += score.item() * not_nans - self._num_examples += not_nans - - @sync_all_reduce("_sum", "_num_examples") - def compute(self) -> float: - """ - Raises: - NotComputableError: When ``compute`` is called before an ``update`` occurs. - - """ - if self._num_examples == 0: - raise NotComputableError("HausdorffDistance must have at least one example before it can be computed.") - return self._sum / self._num_examples diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py new file mode 100644 index 0000000000..bfc7252b2f --- /dev/null +++ b/monai/handlers/iteration_metric.py @@ -0,0 +1,131 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence + +import torch + +from monai.handlers.utils import evenly_divisible_all_gather +from monai.metrics import do_metric_reduction +from monai.utils import MetricReduction, exact_version, optional_import + +idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") +Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") +reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + + +class IterationMetric(Metric): # type: ignore[valid-type, misc] # due to optional_import + """ + Base class for metrics that are computed on every iteration and compute the final result when the epoch completes. + Similar to the `EpochMetric` in ignite: + https://github.com/pytorch/ignite/blob/v0.4.2/ignite/metrics/epoch_metric.py#L13. + + Args: + metric_fn: callable function or class to compute raw metric results after every iteration. + expected to return a Tensor with shape (batch, channel, ...) or a tuple (Tensor, not_nans). + output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. + device: device specification in case of distributed computation usage. + save_details: whether to save metric computation details per image, for example: mean_dice of every image. + Defaults to True; the details are saved to the `engine.state.metric_details` dict with the metric name as key. + + """ + + def __init__( + self, + metric_fn: Callable, + output_transform: Callable = lambda x: x, + device: Optional[torch.device] = None, + save_details: bool = True, + ) -> None: + self._is_reduced: bool = False + self.metric_fn = metric_fn + self.save_details = save_details + self._scores: List = [] + self._engine: Optional[Engine] = None + self._name: Optional[str] = None + super().__init__(output_transform, device=device) + + @reinit__is_reduced + def reset(self) -> None: + self._scores = [] + + @reinit__is_reduced + def update(self, output: Sequence[torch.Tensor]) -> None: + """ + Args: + output: sequence with contents [y_pred, y]. + + Raises: + ValueError: When ``output`` length is not 2. metric_fn can only support y_pred and y. + + """ + if len(output) != 2: + raise ValueError(f"output must have length 2, got {len(output)}.") + y_pred, y = output + score = self.metric_fn(y_pred, y) + if isinstance(score, (tuple, list)): + score = score[0] + self._scores.append(score) + + def compute(self) -> Any: + """ + Raises: + RuntimeError: When ``save_details`` is True but the handler has not been attached to an engine via ``attach()``.
+ + """ + _scores = torch.cat(self._scores, dim=0) + + ws = idist.get_world_size() + if ws > 1 and not self._is_reduced: + # all gather across all processes + _scores = evenly_divisible_all_gather(data=_scores) + self._is_reduced = True + + # save score of every image into engine.state for other components + if self.save_details: + if self._engine is None or self._name is None: + raise RuntimeError("plesae call the attach() function to connect expected engine first.") + self._engine.state.metric_details[self._name] = _scores + + result: torch.Tensor = torch.zeros(1) + if idist.get_rank() == 0: + # run compute_fn on zero rank only + result = self._reduce(_scores) + + if ws > 1: + # broadcast result to all processes + result = idist.broadcast(result, src=0) + + return result.item() if torch.is_tensor(result) else result + + def _reduce(self, scores) -> Any: + return do_metric_reduction(scores, MetricReduction.MEAN)[0] + + def attach(self, engine: Engine, name: str) -> None: + """ + Attaches current metric to provided engine. On the end of engine's run, + `engine.state.metrics` dictionary will contain computed metric's value under provided name. + + Args: + engine: the engine to which the metric must be attached. + name: the name of the metric to attach. + + """ + super().attach(engine=engine, name=name) + # FIXME: record engine for communication, ignite will support it in the future version soon + self._engine = engine + self._name = name + if self.save_details and not hasattr(engine.state, "metric_details"): + engine.state.metric_details = {} diff --git a/monai/handlers/lr_schedule_handler.py b/monai/handlers/lr_schedule_handler.py index 9fd2f64885..e5593f07ff 100644 --- a/monai/handlers/lr_schedule_handler.py +++ b/monai/handlers/lr_schedule_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py index 71b7bf5503..7decc3ab9b 100644 --- a/monai/handlers/mean_dice.py +++ b/monai/handlers/mean_dice.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,20 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Optional, Sequence +from typing import Callable, Optional import torch +from monai.handlers.iteration_metric import IterationMetric from monai.metrics import DiceMetric -from monai.utils import MetricReduction, exact_version, optional_import +from monai.utils import MetricReduction -NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") -Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") -reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") -sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce") - -class MeanDice(Metric): # type: ignore[valid-type, misc] # due to optional_import +class MeanDice(IterationMetric): """ Computes the Dice score metric from a full-size Tensor and collects the average over batch, class-channels, iterations. """ @@ -32,6 +28,7 @@ def __init__( include_background: bool = True, output_transform: Callable = lambda x: x, device: Optional[torch.device] = None, + save_details: bool = True, ) -> None: """ @@ -40,50 +37,19 @@ Defaults to True. output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. device: device specification in case of distributed computation usage. + save_details: whether to save metric computation details per image, for example: mean dice of every image. + Defaults to True; the details are saved to the `engine.state.metric_details` dict with the metric name as key. See also: :py:meth:`monai.metrics.meandice.compute_meandice` """ - super().__init__(output_transform, device=device) - self.dice = DiceMetric( + metric_fn = DiceMetric( include_background=include_background, - reduction=MetricReduction.MEAN, + reduction=MetricReduction.NONE, + ) + super().__init__( + metric_fn=metric_fn, + output_transform=output_transform, + device=device, + save_details=save_details, ) - self._sum = 0.0 - self._num_examples = 0 - - @reinit__is_reduced - def reset(self) -> None: - self._sum = 0.0 - self._num_examples = 0 - - @reinit__is_reduced - def update(self, output: Sequence[torch.Tensor]) -> None: - """ - Args: - output: sequence with contents [y_pred, y]. - - Raises: - ValueError: When ``output`` length is not 2. MeanDice metric can only support y_pred and y. - - """ - if len(output) != 2: - raise ValueError(f"output must have length 2, got {len(output)}.") - y_pred, y = output - score, not_nans = self.dice(y_pred, y) - not_nans = int(not_nans.item()) - - # add all items in current batch - self._sum += score.item() * not_nans - self._num_examples += not_nans - - @sync_all_reduce("_sum", "_num_examples") - def compute(self) -> float: - """ - Raises: - NotComputableError: When ``compute`` is called before an ``update`` occurs. - - """ - if self._num_examples == 0: - raise NotComputableError("MeanDice must have at least one example before it can be computed.") - return self._sum / self._num_examples diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py index 3198fcce6a..fdd60da57c 100644 --- a/monai/handlers/metric_logger.py +++ b/monai/handlers/metric_logger.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py new file mode 100644 index 0000000000..f9deea35df --- /dev/null +++ b/monai/handlers/metrics_saver.py @@ -0,0 +1,137 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union + +from monai.handlers.utils import write_metrics_reports +from monai.utils import ensure_tuple, exact_version, optional_import +from monai.utils.module import get_torch_version_tuple + +Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") +idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + + +class MetricsSaver: + """ + Ignite handler to save metric values and details into the expected files. + + Args: + save_dir: directory to save the metrics and metric details. + metrics: expected final metrics to save into files, can be: None, "*" or list of strings. + None - don't save any metrics into files. + "*" - save all the existing metrics in the `engine.state.metrics` dict into `metrics.csv`. + list of strings - specify the expected metrics to save. + defaults to "*" to save all the metrics into `metrics.csv`. + metric_details: expected metric details to save into files, for example: mean dice + of every channel of every image in the validation dataset. + the data in `engine.state.metric_details` must contain at least 2 dims: (batch, classes, ...), + if not, it will be unsqueezed to 2 dims. + this arg can be: None, "*" or list of strings. + None - don't save any metric details into files. + "*" - save all the existing metric details in the `engine.state.metric_details` dict into separate files. + list of strings - specify the expected metric details to save. + if not None, every metric will save a separate `{metric name}_raw.csv` file. + batch_transform: callable function to extract the meta_dict from input batch data if saving metric details. + Used to extract filenames from input dict data. + summary_ops: expected computation operations to generate the summary report. + it can be: None, "*" or list of strings. + None - don't generate summary report for every expected metric_details. + "*" - generate summary report for every metric_details with all the supported operations. + list of strings - generate summary report for every metric_details with specified operations, they + should be within this list: [`mean`, `median`, `max`, `min`, `90percent`, `std`]. + defaults to None. + save_rank: only the handler on the specified rank will save to files in multi-GPU validation, defaults to 0. + delimiter: the delimiter character in the CSV file, defaults to "\t". + output_type: expected output file type, supported types: ["csv"], defaults to "csv".
+ + """ + + def __init__( + self, + save_dir: str, + metrics: Optional[Union[str, Sequence[str]]] = "*", + metric_details: Optional[Union[str, Sequence[str]]] = None, + batch_transform: Callable = lambda x: x, + summary_ops: Optional[Union[str, Sequence[str]]] = None, + save_rank: int = 0, + delimiter: str = "\t", + output_type: str = "csv", + ) -> None: + self.save_dir = save_dir + self.metrics = ensure_tuple(metrics) if metrics is not None else None + self.metric_details = ensure_tuple(metric_details) if metric_details is not None else None + self.batch_transform = batch_transform + self.summary_ops = ensure_tuple(summary_ops) if summary_ops is not None else None + self.save_rank = save_rank + self.deli = delimiter + self.output_type = output_type + self._filenames: List[str] = [] + + def attach(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + engine.add_event_handler(Events.STARTED, self._started) + engine.add_event_handler(Events.ITERATION_COMPLETED, self._get_filenames) + engine.add_event_handler(Events.EPOCH_COMPLETED, self) + + def _started(self, engine: Engine) -> None: + self._filenames = [] + + def _get_filenames(self, engine: Engine) -> None: + if self.metric_details is not None: + _filenames = list(ensure_tuple(self.batch_transform(engine.state.batch)["filename_or_obj"])) + self._filenames += _filenames + + def __call__(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + ws = idist.get_world_size() + if self.save_rank >= ws: + raise ValueError("target rank is greater than the distributed group size.") + + _images = self._filenames + if ws > 1: + _filenames = self.deli.join(_images) + if get_torch_version_tuple() > (1, 6, 0): + # all gather across all processes + _filenames = self.deli.join(idist.all_gather(_filenames)) + else: + raise RuntimeError("MetricsSaver can not save metric details in distributed mode with PyTorch < 1.7.0.") + _images = _filenames.split(self.deli) + + # only save metrics to file in specified rank + if idist.get_rank() == self.save_rank: + _metrics = {} + if self.metrics is not None and len(engine.state.metrics) > 0: + _metrics = {k: v for k, v in engine.state.metrics.items() if k in self.metrics or "*" in self.metrics} + _metric_details = {} + if self.metric_details is not None and len(engine.state.metric_details) > 0: + for k, v in engine.state.metric_details.items(): + if k in self.metric_details or "*" in self.metric_details: + _metric_details[k] = v + + write_metrics_reports( + save_dir=self.save_dir, + images=_images, + metrics=_metrics, + metric_details=_metric_details, + summary_ops=self.summary_ops, + deli=self.deli, + output_type=self.output_type, + ) diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py index 5e24e645bb..dbca70bf25 100644 --- a/monai/handlers/roc_auc.py +++ b/monai/handlers/roc_auc.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index 444768d555..c712ce9a9e 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -38,7 +38,8 @@ def __init__( mode: Union[GridSampleMode, InterpolateMode, str] = "nearest", padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, scale: Optional[int] = None, - dtype: Optional[np.dtype] = None, + dtype: Optional[np.dtype] = np.float64, + output_dtype: Optional[np.dtype] = np.float32, batch_transform: Callable = lambda x: x, output_transform: Callable = lambda x: x, name: Optional[str] = None, @@ -69,8 +70,10 @@ scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. It's used for PNG format only. - dtype: convert the image data to save to this data type. - If None, keep the original type of data. It's used for Nifti format only. + dtype: data type used for the resampling computation. Defaults to ``np.float64`` for best precision. + If None, use the data type of the input data. To be compatible with other modules, + the output data type is always ``np.float32``; it's used for the NIfTI format only. + output_dtype: data type for saving data. Defaults to ``np.float32``; it's used for the NIfTI format only. batch_transform: a callable that is used to transform the ignite.engine.batch into expected format to extract the meta_data dictionary. output_transform: a callable that is used to transform the @@ -90,6 +93,7 @@ mode=GridSampleMode(mode), padding_mode=padding_mode, dtype=dtype, + output_dtype=output_dtype, ) elif output_ext == ".png": self.saver = PNGSaver( diff --git a/monai/handlers/smartcache_handler.py b/monai/handlers/smartcache_handler.py index 2c96f00316..423d87c22a 100644 --- a/monai/handlers/smartcache_handler.py +++ b/monai/handlers/smartcache_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index b38d5ade9e..007fbed413 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -27,7 +27,7 @@ DEFAULT_TAG = "Loss" -class StatsHandler(object): +class StatsHandler: """ StatsHandler defines a set of Ignite Event-handlers for all the log printing logics. It can be used for any Ignite Engine (trainer, validator and evaluator).
diff --git a/monai/handlers/surface_distance.py b/monai/handlers/surface_distance.py index b35089423c..d3fa69bfce 100644 --- a/monai/handlers/surface_distance.py +++ b/monai/handlers/surface_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,20 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Sequence +from typing import Callable, Optional import torch +from monai.handlers.iteration_metric import IterationMetric from monai.metrics import SurfaceDistanceMetric -from monai.utils import MetricReduction, exact_version, optional_import +from monai.utils import MetricReduction -NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") -Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") -reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") -sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce") - -class SurfaceDistance(Metric): # type: ignore[valid-type, misc] # due to optional_import +class SurfaceDistance(IterationMetric): """ Computes the surface distance from a full-size Tensor and collects the average over batch, class-channels, iterations. """ @@ -34,6 +30,7 @@ def __init__( distance_metric: str = "euclidean", output_transform: Callable = lambda x: x, device: Optional[torch.device] = None, + save_details: bool = True, ) -> None: """ @@ -46,50 +43,19 @@ the metric used to compute surface distance. Defaults to ``"euclidean"``. output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. device: device specification in case of distributed computation usage. + save_details: whether to save metric computation details per image, for example: the surface distance + of every image. Defaults to True; the details are saved to the `engine.state.metric_details` dict with the metric name as key. """ - super().__init__(output_transform, device=device) - self.hd = SurfaceDistanceMetric( + metric_fn = SurfaceDistanceMetric( include_background=include_background, symmetric=symmetric, distance_metric=distance_metric, - reduction=MetricReduction.MEAN, + reduction=MetricReduction.NONE, + ) + super().__init__( + metric_fn=metric_fn, + output_transform=output_transform, + device=device, + save_details=save_details, ) - self._sum = 0.0 - self._num_examples = 0 - - @reinit__is_reduced - def reset(self) -> None: - self._sum = 0.0 - self._num_examples = 0 - - @reinit__is_reduced - def update(self, output: Sequence[torch.Tensor]) -> None: - """ - Args: - output: sequence with contents [y_pred, y]. - - Raises: - ValueError: When ``output`` length is not 2. The metric can only support y_pred and y. - - """ - if len(output) != 2: - raise ValueError(f"output must have length 2, got {len(output)}.") - y_pred, y = output - score, not_nans = self.hd(y_pred, y) - not_nans = int(not_nans.item()) - - # add all items in current batch - self._sum += score.item() * not_nans - self._num_examples += not_nans - - @sync_all_reduce("_sum", "_num_examples") - def compute(self) -> float: - """ - Raises: - NotComputableError: When ``compute`` is called before an ``update`` occurs.
- - """ - if self._num_examples == 0: - raise NotComputableError("SurfaceDistance must have at least one example before it can be computed.") - return self._sum / self._num_examples diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index a9d7d661ec..15fa6a5eed 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,7 +29,7 @@ DEFAULT_TAG = "Loss" -class TensorBoardStatsHandler(object): +class TensorBoardStatsHandler: """ TensorBoardStatsHandler defines a set of Ignite Event-handlers for all the TensorBoard logics. It's can be used for any Ignite Engine(trainer, validator and evaluator). @@ -172,9 +172,9 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No writer.flush() -class TensorBoardImageHandler(object): +class TensorBoardImageHandler: """ - TensorBoardImageHandler is an Ignite Event handler that can visualise images, labels and outputs as 2D/3D images. + TensorBoardImageHandler is an Ignite Event handler that can visualize images, labels and outputs as 2D/3D images. 2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch, for 3D to ND output (shape in Batch, channel, H, W, D) input, each of ``self.max_channels`` number of images' last three dimensions will be shown as animated GIF along the last axis (typically Depth). diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index e96521f47e..ef652efe0a 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,19 +9,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Callable +import os +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Union +import numpy as np import torch -import torch.distributed as dist -from monai.utils import exact_version, optional_import +from monai.utils import ensure_tuple, exact_version, optional_import +idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") if TYPE_CHECKING: from ignite.engine import Engine else: Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") -__all__ = ["stopping_fn_from_metric", "stopping_fn_from_loss", "all_gather"] +__all__ = [ + "stopping_fn_from_metric", + "stopping_fn_from_loss", + "evenly_divisible_all_gather", + "write_metrics_reports", +] def stopping_fn_from_metric(metric_name: str) -> Callable[[Engine], Any]: @@ -46,13 +54,113 @@ def stopping_fn(engine: Engine): return stopping_fn -def all_gather(tensor): +def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor: """ - All gather the data of tensor value in distributed data parallel. + Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather. + + Args: + data: source tensor to pad and execute all_gather in distributed data parallel. 
+ """ - if not dist.is_available() or not dist.is_initialized(): - raise RuntimeError("should not execute all_gather operation before torch.distributed is ready.") - # create placeholder to collect the data from all processes - output = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())] - dist.all_gather(output, tensor) - return torch.cat(output, dim=0) + if not torch.is_tensor(data): + raise ValueError("input data must be PyTorch Tensor.") + + if idist.get_world_size() <= 1: + return data + + # make sure the data is evenly-divisible on multi-GPUs + length = data.shape[0] + all_lens = idist.all_gather(length) + max_len = max(all_lens).item() + if length < max_len: + size = [max_len - length] + list(data.shape[1:]) + data = torch.cat([data, data.new_full(size, 0)], dim=0) + # all gather across all processes + data = idist.all_gather(data) + # delete the padding NaN items + return torch.cat([data[i * max_len : i * max_len + l, ...] for i, l in enumerate(all_lens)], dim=0) + + +def write_metrics_reports( + save_dir: str, + images: Optional[Sequence[str]], + metrics: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], + metric_details: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], + summary_ops: Optional[Union[str, Sequence[str]]], + deli: str = "\t", + output_type: str = "csv", +): + """ + Utility function to write the metrics into files, contains 3 parts: + 1. if `metrics` dict is not None, write overall metrics into file, every line is a metric name and value pair. + 2. if `metric_details` dict is not None, write raw metric data of every image into file, every line for 1 image. + 3. if `summary_ops` is not None, compute summary based on operations on `metric_details` and write to file. + + Args: + save_dir: directory to save all the metrics reports. + images: name or path of every input image corresponding to the metric_details data. + if None, will use index number as the filename of every input image. + metrics: a dictionary of (metric name, metric value) pairs. + metric_details: a dictionary of (metric name, metric raw values) pairs, + for example, the raw value can be the mean_dice of every channel of every input image. + summary_ops: expected computation operations to generate the summary report. + it can be: None, "*" or list of strings. + None - don't generate summary report for every expected metric_details + "*" - generate summary report for every metric_details with all the supported operations. + list of strings - generate summary report for every metric_details with specified operations, they + should be within this list: [`mean`, `median`, `max`, `min`, `90percent`, `std`]. + default to None. + deli: the delimiter charactor in the file, default to "\t". + output_type: expected output file type, supported types: ["csv"], default to "csv". 
+ + """ + if output_type.lower() != "csv": + raise ValueError(f"unsupported output type: {output_type}.") + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + if metrics is not None and len(metrics) > 0: + with open(os.path.join(save_dir, "metrics.csv"), "w") as f: + for k, v in metrics.items(): + f.write(f"{k}{deli}{str(v)}\n") + + if metric_details is not None and len(metric_details) > 0: + for k, v in metric_details.items(): + if torch.is_tensor(v): + v = v.cpu().numpy() + if v.ndim == 0: + # reshape to [1, 1] if no batch and class dims + v = v.reshape((1, 1)) + elif v.ndim == 1: + # reshape to [N, 1] if no class dim + v = v.reshape((-1, 1)) + + # add the average value of all classes to v + class_labels = ["class" + str(i) for i in range(v.shape[1])] + ["mean"] + v = np.concatenate([v, np.nanmean(v, axis=1, keepdims=True)], axis=1) + + with open(os.path.join(save_dir, f"{k}_raw.csv"), "w") as f: + f.write(f"filename{deli}{deli.join(class_labels)}\n") + for i, b in enumerate(v): + f.write(f"{images[i] if images is not None else str(i)}{deli}{deli.join([str(c) for c in b])}\n") + + if summary_ops is not None: + supported_ops = OrderedDict( + { + "mean": np.nanmean, + "median": np.nanmedian, + "max": np.nanmax, + "min": np.nanmin, + "90percent": lambda x: np.nanpercentile(x, 10), + "std": np.nanstd, + } + ) + ops = ensure_tuple(summary_ops) + if "*" in ops: + ops = tuple(supported_ops.keys()) + + with open(os.path.join(save_dir, f"{k}_summary.csv"), "w") as f: + f.write(f"class{deli}{deli.join(ops)}\n") + for i, c in enumerate(v.transpose()): + f.write(f"{class_labels[i]}{deli}{deli.join([f'{supported_ops[k](c):.4f}' for k in ops])}\n") diff --git a/monai/handlers/validation_handler.py b/monai/handlers/validation_handler.py index 45261c1548..9cc2e926f4 100644 --- a/monai/handlers/validation_handler.py +++ b/monai/handlers/validation_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index 4d9bca6ce4..1cdea77b0f 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,5 +9,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .inferer import * +from .inferer import Inferer, SimpleInferer, SlidingWindowInferer from .utils import sliding_window_inference diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 36cc3de478..b17afb4e1d 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,6 +24,19 @@ class Inferer(ABC): """ A base class for model inference. Extend this class to support operations during inference, e.g. a sliding window method. 
+ + Example code:: + + device = torch.device("cuda:0") + data = ToTensor()(LoadImage()(filename=img_path)).to(device) + model = UNet(...).to(device) + inferer = SlidingWindowInferer(...) + + model.eval() + with torch.no_grad(): + pred = inferer(inputs=data, network=model) + ... + """ @abstractmethod @@ -53,6 +66,7 @@ def __call__( class SimpleInferer(Inferer): """ SimpleInferer is the normal inference method that run model forward() directly. + Usage example can be found in the :py:class:`monai.inferers.Inferer` base class. """ @@ -83,6 +97,7 @@ class SlidingWindowInferer(Inferer): """ Sliding window method for model inference, with `sw_batch_size` windows for every model.forward(). + Usage example can be found in the :py:class:`monai.inferers.Inferer` base class. Args: roi_size: the window size to execute SlidingWindow evaluation. diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index c7db520cb2..85779fc6d1 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -82,7 +82,8 @@ def sliding_window_inference( """ num_spatial_dims = len(inputs.shape) - 2 - assert 0 <= overlap < 1, "overlap must be >= 0 and < 1." + if overlap < 0 or overlap >= 1: + raise AssertionError("overlap must be >= 0 and < 1.") # determine image spatial size and batch size # Note: all input images must have the same image size and batch size diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 7c3ca0cfe1..591fb08f7b 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .deform import BendingEnergyLoss from .dice import ( Dice, + DiceCELoss, DiceLoss, GeneralizedDiceLoss, GeneralizedWassersteinDiceLoss, @@ -20,4 +22,5 @@ generalized_wasserstein_dice, ) from .focal_loss import FocalLoss +from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss from .tversky import TverskyLoss diff --git a/monai/losses/deform.py b/monai/losses/deform.py new file mode 100644 index 0000000000..acba229121 --- /dev/null +++ b/monai/losses/deform.py @@ -0,0 +1,104 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
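The `monai.losses` package hunk above also exports `DiceCELoss`, defined later in this diff. A quick sketch of the intended call pattern on random logits, assuming a 3-class 2D segmentation; shapes and values are illustrative, not taken from the source::

    import torch

    from monai.losses import DiceCELoss

    logits = torch.randn(2, 3, 32, 32)  # (batch, classes, H, W) raw network outputs
    target = torch.nn.functional.one_hot(
        torch.randint(0, 3, (2, 32, 32)), num_classes=3
    ).permute(0, 3, 1, 2).float()       # one-hot target with the same shape
    loss = DiceCELoss(softmax=True)(logits, target)  # dice part + cross entropy part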
+ +from typing import Union + +import torch +from torch.nn.modules.loss import _Loss + +from monai.utils import LossReduction + + +def spatial_gradient(x: torch.Tensor, dim: int) -> torch.Tensor: + """ + Calculate gradients on a single dimension of a tensor using central finite difference. + It moves the tensor along the dimension to calculate the approximate gradient + dx[i] = (x[i+1] - x[i-1]) / 2. + Adapted from: + DeepReg (https://github.com/DeepRegNet/DeepReg) + + Args: + x: the shape should be BCH(WD). + dim: dimension to calculate gradient along. + Returns: + gradient_dx: the shape should be BCH(WD) + """ + slice_1 = slice(1, -1) + slice_2_s = slice(2, None) + slice_2_e = slice(None, -2) + slice_all = slice(None) + slicing_s, slicing_e = [slice_all, slice_all], [slice_all, slice_all] + while len(slicing_s) < x.ndim: + slicing_s = slicing_s + [slice_1] + slicing_e = slicing_e + [slice_1] + slicing_s[dim] = slice_2_s + slicing_e[dim] = slice_2_e + return (x[slicing_s] - x[slicing_e]) / 2.0 + + +class BendingEnergyLoss(_Loss): + """ + Calculate the bending energy based on second-order differentiation of pred using central finite difference. + + Adapted from: + DeepReg (https://github.com/DeepRegNet/DeepReg) + """ + + def __init__( + self, + reduction: Union[LossReduction, str] = LossReduction.MEAN, + ) -> None: + """ + Args: + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + """ + super(BendingEnergyLoss, self).__init__(reduction=LossReduction(reduction).value) + + def forward(self, pred: torch.Tensor) -> torch.Tensor: + """ + Args: + pred: the shape should be BCH(WD) + + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + + """ + if pred.ndim not in [3, 4, 5]: + raise ValueError(f"expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") + for i in range(pred.ndim - 2): + if pred.shape[-i - 1] <= 4: + raise ValueError(f"all spatial dimensions must be > 4, got pred of shape {pred.shape}") + + # first order gradient + first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)] + + energy = torch.tensor(0) + for dim_1, g in enumerate(first_order_gradient): + dim_1 += 2 + energy = spatial_gradient(g, dim_1) ** 2 + energy + for dim_2 in range(dim_1 + 1, pred.ndim): + energy = 2 * spatial_gradient(g, dim_2) ** 2 + energy + + if self.reduction == LossReduction.MEAN.value: + energy = torch.mean(energy) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + energy = torch.sum(energy) # sum over the batch and channel dims + elif self.reduction == LossReduction.NONE.value: + pass # returns the unreduced energy map + else: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + return energy diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 998ac38a76..f14aa6955f 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
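Before moving on to dice.py: the central finite difference in `spatial_gradient` above trims one element at each end of the chosen axis, so an input of length N yields N - 2 gradient values. A quick check on a ramp, where the expected gradient is 1 everywhere::

    import torch

    from monai.losses.deform import spatial_gradient

    x = torch.arange(5.0).reshape(1, 1, 5)  # (batch, channel, 5) ramp 0..4
    print(spatial_gradient(x, dim=2))       # tensor([[[1., 1., 1.]]]) via (x[i+1] - x[i-1]) / 2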
# You may obtain a copy of the License at @@ -14,6 +14,7 @@ import numpy as np import torch +import torch.nn as nn import torch.nn.functional as F from torch.nn.modules.loss import _Loss @@ -134,9 +135,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: target = target[:, 1:] input = input[:, 1:] - assert ( - target.shape == input.shape - ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})" + if target.shape != input.shape: + raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") # reducing only spatial dimensions (not batch nor channels) reduce_axis = list(range(2, len(input.shape))) @@ -191,16 +191,16 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: Optional[torc """ if mask is not None: # checking if mask is of proper shape - assert input.dim() == mask.dim(), f"dim of input ({input.shape}) is different from mask ({mask.shape})" - assert ( - input.shape[0] == mask.shape[0] or mask.shape[0] == 1 - ), f" batch size of mask ({mask.shape}) must be 1 or equal to input ({input.shape})" + if input.dim() != mask.dim(): + raise AssertionError(f"dim of input ({input.shape}) is different from mask ({mask.shape})") + if not (input.shape[0] == mask.shape[0] or mask.shape[0] == 1): + raise AssertionError(f"batch size of mask ({mask.shape}) must be 1 or equal to input ({input.shape})") if target.dim() > 1: - assert mask.shape[1] == 1, f"mask ({mask.shape}) must have only 1 channel" - assert ( - input.shape[2:] == mask.shape[2:] - ), f"spatial size of input ({input.shape}) is different from mask ({mask.shape})" + if mask.shape[1] != 1: + raise AssertionError(f"mask ({mask.shape}) must have only 1 channel") + if input.shape[2:] != mask.shape[2:]: + raise AssertionError(f"spatial size of input ({input.shape}) is different from mask ({mask.shape})") input = input * mask target = target * mask @@ -321,9 +321,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: target = target[:, 1:] input = input[:, 1:] - assert ( - target.shape == input.shape - ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})" + if target.shape != input.shape: + raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") # reducing only spatial dimensions (not batch nor channels) reduce_axis = list(range(2, len(input.shape))) @@ -594,6 +593,113 @@ def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) - return alpha +class DiceCELoss(_Loss): + """ + Compute both Dice loss and Cross Entropy Loss, and return the sum of these two losses. + Input logits `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). + Axis N of `input` is expected to have logit predictions for each class rather than being image channels, + while the same axis of `target` can be 1 or N (one-hot format). The `smooth_nr` and `smooth_dr` parameters are + values added to the intersection and union components of the intersection-over-union calculation in the dice + loss part to smooth results respectively; these values should be small. The `include_background` class attribute can be + set to False for an instance of the loss to exclude the first category (channel index 0) which is by convention + assumed to be background.
If the non-background segmentations are small compared to the total image size, they can get + overwhelmed by the signal from the background, so excluding it in such cases helps convergence. + """ + + def __init__( + self, + include_background: bool = True, + to_onehot_y: bool = False, + sigmoid: bool = False, + softmax: bool = False, + other_act: Optional[Callable] = None, + squared_pred: bool = False, + jaccard: bool = False, + reduction: str = "mean", + smooth_nr: float = 1e-5, + smooth_dr: float = 1e-5, + batch: bool = False, + ce_weight: Optional[torch.Tensor] = None, + ) -> None: + """ + Args: + ``ce_weight`` is only used for the cross entropy loss, ``reduction`` is used for both losses, and the other + parameters are only used for the dice loss. + + include_background: if False channel index 0 (background category) is excluded from the calculation. + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + sigmoid: if True, apply a sigmoid function to the prediction. + softmax: if True, apply a softmax function to the prediction. + other_act: if you don't want to use `sigmoid` or `softmax`, use another callable function to execute + other activation layers. Defaults to ``None``, for example: + `other_act = torch.tanh`. + squared_pred: use squared versions of targets and predictions in the denominator or not. + jaccard: compute Jaccard Index (soft IoU) instead of dice or not. + reduction: {``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. The dice loss should + at least reduce the spatial dimensions, which is different from cross entropy loss, thus here + the ``none`` option cannot be used. + + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + + smooth_nr: a small constant added to the numerator to avoid zero. + smooth_dr: a small constant added to the denominator to avoid nan. + batch: whether to sum the intersection and union areas over the batch dimension before dividing. + Defaults to False, a Dice loss value is computed independently from each item in the batch + before any `reduction`. + ce_weight: a rescaling weight given to each class for cross entropy loss. + See ``torch.nn.CrossEntropyLoss()`` for more information. + + """ + super().__init__() + self.dice = DiceLoss( + include_background=include_background, + to_onehot_y=to_onehot_y, + sigmoid=sigmoid, + softmax=softmax, + other_act=other_act, + squared_pred=squared_pred, + jaccard=jaccard, + reduction=reduction, + smooth_nr=smooth_nr, + smooth_dr=smooth_dr, + batch=batch, + ) + self.cross_entropy = nn.CrossEntropyLoss( + weight=ce_weight, + reduction=reduction, + ) + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNH[WD]. + target: the shape should be BNH[WD] or B1H[WD]. + + Raises: + ValueError: When number of dimensions for input and target are different. + ValueError: When the number of channels for target is neither 1 nor the same as input.
+ + """ + if len(input.shape) != len(target.shape): + raise ValueError("the number of dimensions for input and target should be the same.") + + dice_loss = self.dice(input, target) + + n_pred_ch, n_target_ch = input.shape[1], target.shape[1] + if n_pred_ch == n_target_ch: + # target is in the one-hot format, convert to BH[WD] format to calculate ce loss + target = torch.argmax(target, dim=1) + else: + target = torch.squeeze(target, dim=1) + target = target.long() + ce_loss = self.cross_entropy(input, target) + total_loss: torch.Tensor = dice_loss + ce_loss + return total_loss + + dice = Dice = DiceLoss +dice_ce = DiceCELoss generalized_dice = GeneralizedDiceLoss generalized_wasserstein_dice = GeneralizedWassersteinDiceLoss diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index 96fee17201..da7c63e571 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py new file mode 100644 index 0000000000..b229a0c08f --- /dev/null +++ b/monai/losses/image_dissimilarity.py @@ -0,0 +1,244 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Union + +import torch +from torch.nn import functional as F +from torch.nn.modules.loss import _Loss + +from monai.networks.layers import gaussian_1d, separable_filtering +from monai.utils import LossReduction + + +def make_rectangular_kernel(kernel_size: int) -> torch.Tensor: + return torch.ones(kernel_size) + + +def make_triangular_kernel(kernel_size: int) -> torch.Tensor: + fsize = (kernel_size + 1) // 2 + if fsize % 2 == 0: + fsize -= 1 + f = torch.ones((1, 1, fsize), dtype=torch.float).div(fsize) + padding = (kernel_size - fsize) // 2 + fsize // 2 + return F.conv1d(f, f, padding=padding).reshape(-1) + + +def make_gaussian_kernel(kernel_size: int) -> torch.Tensor: + sigma = torch.tensor(kernel_size / 3.0) + kernel = gaussian_1d(sigma=sigma, truncated=kernel_size // 2, approx="sampled", normalize=False) * ( + 2.5066282 * sigma + ) + return kernel[:kernel_size] + + +kernel_dict = { + "rectangular": make_rectangular_kernel, + "triangular": make_triangular_kernel, + "gaussian": make_gaussian_kernel, +} + + +class LocalNormalizedCrossCorrelationLoss(_Loss): + """ + Local squared zero-normalized cross-correlation. + The loss is based on a moving kernel/window over the y_true/y_pred, + within the window the square of zncc is calculated. + The kernel can be a rectangular / triangular / gaussian window. + The final loss is the averaged loss over all windows. 
+ + Adapted from: + https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py + DeepReg (https://github.com/DeepRegNet/DeepReg) + """ + + def __init__( + self, + in_channels: int, + ndim: int = 3, + kernel_size: int = 3, + kernel_type: str = "rectangular", + reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_nr: float = 1e-7, + smooth_dr: float = 1e-7, + ) -> None: + """ + Args: + in_channels: number of input channels + ndim: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3. + kernel_size: kernel spatial size, must be odd. + kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + smooth_nr: a small constant added to the numerator to avoid nan. + smooth_dr: a small constant added to the denominator to avoid nan. + """ + super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) + self.in_channels = in_channels + + self.ndim = ndim + if self.ndim not in [1, 2, 3]: + raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported") + + self.kernel_size = kernel_size + if self.kernel_size % 2 == 0: + raise ValueError(f"kernel_size must be odd, got {self.kernel_size}") + + if kernel_type not in kernel_dict.keys(): + raise ValueError( + f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' + ) + self.kernel = kernel_dict[kernel_type](self.kernel_size) + self.kernel_vol = self.get_kernel_vol() + + self.smooth_nr = float(smooth_nr) + self.smooth_dr = float(smooth_dr) + + def get_kernel_vol(self): + vol = self.kernel + for _ in range(self.ndim - 1): + vol = torch.matmul(vol.unsqueeze(-1), self.kernel.unsqueeze(0)) + return torch.sum(vol) + + def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + pred: the shape should be BNH[WD]. + target: the shape should be BNH[WD]. + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
+ """ + if pred.shape[1] != self.in_channels: + raise ValueError(f"expecting pred with {self.in_channels} channels, got pred of shape {pred.shape}") + if pred.ndim - 2 != self.ndim: + raise ValueError(f"expecting pred with {self.ndim} spatial dimensions, got pred of shape {pred.shape}") + if target.shape != pred.shape: + raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})") + + t2, p2, tp = target ** 2, pred ** 2, target * pred + kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred) + # sum over kernel + t_sum = separable_filtering(target, kernels=[kernel] * self.ndim) + p_sum = separable_filtering(pred, kernels=[kernel] * self.ndim) + t2_sum = separable_filtering(t2, kernels=[kernel] * self.ndim) + p2_sum = separable_filtering(p2, kernels=[kernel] * self.ndim) + tp_sum = separable_filtering(tp, kernels=[kernel] * self.ndim) + + # average over kernel + t_avg = t_sum / kernel_vol + p_avg = p_sum / kernel_vol + + # normalized cross correlation between t and p + # sum[(t - mean[t]) * (p - mean[p])] / std[t] / std[p] + # denoted by num / denom + # assume we sum over N values + # num = sum[t * p - mean[t] * p - t * mean[p] + mean[t] * mean[p]] + # = sum[t*p] - sum[t] * sum[p] / N * 2 + sum[t] * sum[p] / N + # = sum[t*p] - sum[t] * sum[p] / N + # = sum[t*p] - sum[t] * mean[p] = cross + # the following is actually squared ncc + cross = tp_sum - p_avg * t_sum + t_var = t2_sum - t_avg * t_sum # std[t] ** 2 + p_var = p2_sum - p_avg * p_sum # std[p] ** 2 + ncc: torch.Tensor = (cross * cross + self.smooth_nr) / (t_var * p_var + self.smooth_dr) + # shape = (batch, 1, D, H, W) + + if self.reduction == LossReduction.SUM.value: + return torch.sum(ncc).neg() # sum over the batch, channel and spatial ndims + if self.reduction == LossReduction.NONE.value: + return ncc.neg() + if self.reduction == LossReduction.MEAN.value: + return torch.mean(ncc).neg() # average over the batch, channel and spatial ndims + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + +class GlobalMutualInformationLoss(_Loss): + """ + Differentiable global mutual information loss via Parzen windowing method. + + Reference: + https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1 + """ + + def __init__( + self, + num_bins: int = 23, + sigma_ratio: float = 0.5, + reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_nr: float = 1e-7, + smooth_dr: float = 1e-7, + ) -> None: + """ + Args: + num_bins: number of bins for intensity + sigma_ratio: a hyper param for gaussian function + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + smooth_nr: a small constant added to the numerator to avoid nan. + smooth_dr: a small constant added to the denominator to avoid nan. + """ + super(GlobalMutualInformationLoss, self).__init__(reduction=LossReduction(reduction).value) + if num_bins <= 0: + raise ValueError("num_bins must > 0, got {num_bins}") + bin_centers = torch.linspace(0.0, 1.0, num_bins) # (num_bins,) + sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio + self.preterm = 1 / (2 * sigma ** 2) + self.bin_centers = bin_centers[None, None, ...] 
+        self.smooth_nr = float(smooth_nr)
+        self.smooth_dr = float(smooth_dr)
+
+    def parzen_windowing(self, pred: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            pred: the shape should be B[NDHW].
+        """
+        pred = torch.clamp(pred, 0, 1)
+        pred = pred.reshape(pred.shape[0], -1, 1)  # (batch, num_sample, 1)
+        weight = torch.exp(
+            -self.preterm.to(pred) * (pred - self.bin_centers.to(pred)) ** 2
+        )  # (batch, num_sample, num_bin)
+        weight = weight / torch.sum(weight, dim=-1, keepdim=True)  # (batch, num_sample, num_bin)
+        probability = torch.mean(weight, dim=-2, keepdim=True)  # (batch, 1, num_bin)
+        return weight, probability
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            pred: the shape should be B[NDHW].
+            target: the shape should be the same as the pred shape.
+
+        Raises:
+            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
+        """
+        if target.shape != pred.shape:
+            raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})")
+        wa, pa = self.parzen_windowing(pred)  # (batch, num_sample, num_bin), (batch, 1, num_bin)
+        wb, pb = self.parzen_windowing(target)  # (batch, num_sample, num_bin), (batch, 1, num_bin)
+        pab = torch.bmm(wa.permute(0, 2, 1), wb).div(wa.shape[1])  # (batch, num_bins, num_bins)
+
+        papb = torch.bmm(pa.permute(0, 2, 1), pb)  # (batch, num_bins, num_bins)
+        mi = torch.sum(
+            pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2)
+        )  # (batch)
+
+        if self.reduction == LossReduction.SUM.value:
+            return torch.sum(mi).neg()  # sum over the batch and channel ndims
+        if self.reduction == LossReduction.NONE.value:
+            return mi.neg()
+        if self.reduction == LossReduction.MEAN.value:
+            return torch.mean(mi).neg()  # average over the batch and channel ndims
+        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py
index 62b937d680..b1c45a74a2 100644
--- a/monai/losses/tversky.py
+++ b/monai/losses/tversky.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -130,9 +130,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             target = target[:, 1:]
             input = input[:, 1:]
 
-        assert (
-            target.shape == input.shape
-        ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})"
+        if target.shape != input.shape:
+            raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")
 
         p0 = input
         p1 = 1 - p0
diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py
index a0d626f45b..818413c30d 100644
--- a/monai/metrics/__init__.py
+++ b/monai/metrics/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -10,9 +10,8 @@
 # limitations under the License.
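A quick sanity check for the two registration losses added above; this sketch assumes they are exported from `monai.losses` like the package's other losses. Both return a negated similarity, so more similar inputs give lower values:

    import torch
    from monai.losses import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss

    pred = torch.rand(2, 1, 32, 32, 32)  # moving image, intensities in [0, 1]
    target = pred.clone()                # identical images: maximal similarity

    lncc = LocalNormalizedCrossCorrelationLoss(in_channels=1, ndim=3, kernel_size=3)
    gmi = GlobalMutualInformationLoss(num_bins=23)

    print(lncc(pred, target))  # close to -1.0: squared NCC of identical images is 1
    print(gmi(pred, target))   # negative: mutual information of identical images is positive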
 from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
-from .hausdorff_distance import *
+from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance
 from .meandice import DiceMetric, compute_meandice
-from .occlusion_sensitivity import compute_occlusion_sensitivity
 from .rocauc import compute_roc_auc
 from .surface_distance import SurfaceDistanceMetric, compute_average_surface_distance
-from .utils import *
+from .utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background
diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py
index 916a07439f..a0c840d45a 100644
--- a/monai/metrics/confusion_matrix.py
+++ b/monai/metrics/confusion_matrix.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -14,7 +14,7 @@
 
 import torch
 
-from monai.metrics.utils import *
+from monai.metrics.utils import do_metric_reduction, ignore_background
 from monai.utils import MetricReduction
 
 
@@ -87,7 +87,7 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor):
         dims = y_pred.ndimension()
         if dims < 2:
             raise ValueError("y_pred should have at least two dimensions.")
-        elif dims == 2 or (dims == 3 and y_pred.shape[-1] == 1):
+        if dims == 2 or (dims == 3 and y_pred.shape[-1] == 1):
             if self.compute_sample:
                 warnings.warn("As for classification task, compute_sample should be False.")
                 self.compute_sample = False
@@ -103,16 +103,15 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor):
                 confusion_matrix = compute_confusion_matrix_metric(self.metric_name, confusion_matrix)
                 f, not_nans = do_metric_reduction(confusion_matrix, self.reduction)
                 return f, not_nans
-            else:
-                if len(self.metric_name) < 1:
-                    raise ValueError("the sequence should at least has on metric name.")
-                results = []
-                for metric_name in self.metric_name:
-                    sub_confusion_matrix = compute_confusion_matrix_metric(metric_name, confusion_matrix)
-                    f, not_nans = do_metric_reduction(sub_confusion_matrix, self.reduction)
-                    results.append(f)
-                    results.append(not_nans)
-                return results
+            if len(self.metric_name) < 1:
+                raise ValueError("the sequence should have at least one metric name.")
+            results = []
+            for metric_name in self.metric_name:
+                sub_confusion_matrix = compute_confusion_matrix_metric(metric_name, confusion_matrix)
+                f, not_nans = do_metric_reduction(sub_confusion_matrix, self.reduction)
+                results.append(f)
+                results.append(not_nans)
+            return results
         else:
             return confusion_matrix
 
@@ -264,8 +263,7 @@ def compute_confusion_matrix_metric(metric_name: str, confusion_matrix: torch.Te
 
     if isinstance(denominator, torch.Tensor):
         return torch.where(denominator != 0, numerator / denominator, nan_tensor)
-    else:
-        return numerator / denominator
+    return numerator / denominator
 
 
 def check_confusion_matrix_metric_name(metric_name: str):
@@ -284,37 +282,36 @@ def check_confusion_matrix_metric_name(metric_name: str):
     metric_name = metric_name.lower()
     if metric_name in ["sensitivity", "recall", "hit_rate", "true_positive_rate", "tpr"]:
         return "tpr"
-    elif metric_name in ["specificity", "selectivity", "true_negative_rate", "tnr"]:
+    if metric_name in ["specificity", "selectivity", "true_negative_rate", "tnr"]:
         return "tnr"
-    elif metric_name in ["precision", "positive_predictive_value", "ppv"]:
+    if metric_name in ["precision", "positive_predictive_value", "ppv"]:
        return "ppv"
-    elif metric_name in ["negative_predictive_value", "npv"]:
+    if metric_name in ["negative_predictive_value", "npv"]:
        return "npv"
-    elif metric_name in ["miss_rate", "false_negative_rate", "fnr"]:
+    if metric_name in ["miss_rate", "false_negative_rate", "fnr"]:
        return "fnr"
-    elif metric_name in ["fall_out", "false_positive_rate", "fpr"]:
+    if metric_name in ["fall_out", "false_positive_rate", "fpr"]:
        return "fpr"
-    elif metric_name in ["false_discovery_rate", "fdr"]:
+    if metric_name in ["false_discovery_rate", "fdr"]:
        return "fdr"
-    elif metric_name in ["false_omission_rate", "for"]:
+    if metric_name in ["false_omission_rate", "for"]:
        return "for"
-    elif metric_name in ["prevalence_threshold", "pt"]:
+    if metric_name in ["prevalence_threshold", "pt"]:
        return "pt"
-    elif metric_name in ["threat_score", "critical_success_index", "ts", "csi"]:
+    if metric_name in ["threat_score", "critical_success_index", "ts", "csi"]:
        return "ts"
-    elif metric_name in ["accuracy", "acc"]:
+    if metric_name in ["accuracy", "acc"]:
        return "acc"
-    elif metric_name in ["balanced_accuracy", "ba"]:
+    if metric_name in ["balanced_accuracy", "ba"]:
        return "ba"
-    elif metric_name in ["f1_score", "f1"]:
+    if metric_name in ["f1_score", "f1"]:
        return "f1"
-    elif metric_name in ["matthews_correlation_coefficient", "mcc"]:
+    if metric_name in ["matthews_correlation_coefficient", "mcc"]:
        return "mcc"
-    elif metric_name in ["fowlkes_mallows_index", "fm"]:
+    if metric_name in ["fowlkes_mallows_index", "fm"]:
        return "fm"
-    elif metric_name in ["informedness", "bookmaker_informedness", "bm"]:
+    if metric_name in ["informedness", "bookmaker_informedness", "bm"]:
        return "bm"
-    elif metric_name in ["markedness", "deltap", "mk"]:
+    if metric_name in ["markedness", "deltap", "mk"]:
        return "mk"
-    else:
-        raise NotImplementedError("the metric is not implemented.")
+    raise NotImplementedError("the metric is not implemented.")
diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py
index c649cd3a04..8ecc19ec46 100644
--- a/monai/metrics/hausdorff_distance.py
+++ b/monai/metrics/hausdorff_distance.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -15,7 +15,7 @@
 import numpy as np
 import torch
 
-from monai.metrics.utils import *
+from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background
 from monai.utils import MetricReduction
 
 __all__ = ["HausdorffDistanceMetric", "compute_hausdorff_distance", "compute_percent_hausdorff_distance"]
@@ -166,7 +166,6 @@ def compute_percent_hausdorff_distance(
 
     if not percentile:
         return surface_distance.max()
-    elif 0 <= percentile <= 100:
+    if 0 <= percentile <= 100:
         return np.percentile(surface_distance, percentile)
-    else:
-        raise ValueError(f"percentile should be a value between 0 and 100, get {percentile}.")
+    raise ValueError(f"percentile should be a value between 0 and 100, got {percentile}.")
diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py
index 53716909fe..9d27fff56f 100644
--- a/monai/metrics/meandice.py
+++ b/monai/metrics/meandice.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -14,7 +14,7 @@
 
 import torch
 
-from monai.metrics.utils import *
+from monai.metrics.utils import do_metric_reduction, ignore_background
 from monai.utils import MetricReduction
 
diff --git a/monai/metrics/occlusion_sensitivity.py b/monai/metrics/occlusion_sensitivity.py
deleted file mode 100644
index 9879f472a9..0000000000
--- a/monai/metrics/occlusion_sensitivity.py
+++ /dev/null
@@ -1,225 +0,0 @@
-# Copyright 2020 MONAI Consortium
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -from collections.abc import Sequence -from functools import partial -from typing import Optional, Union - -import numpy as np -import torch -import torch.nn as nn - -try: - from tqdm import trange - - trange = partial(trange, desc="Computing occlusion sensitivity") -except (ImportError, AttributeError): - trange = range - - -def _check_input_image(image): - """Check that the input image is as expected.""" - # Only accept batch size of 1 - if image.shape[0] > 1: - raise RuntimeError("Expected batch size of 1.") - return image - - -def _check_input_label(label, image): - """Check that the input label is as expected.""" - # If necessary turn the label into a 1-element tensor - if isinstance(label, int): - label = torch.tensor([[label]], dtype=torch.int64).to(image.device) - # If the label is a tensor, make sure there's only 1 element - elif label.numel() != image.shape[0]: - raise RuntimeError("Expected as many labels as batches.") - return label - - -def _check_input_bounding_box(b_box, im_shape): - """Check that the bounding box (if supplied) is as expected.""" - # If no bounding box has been supplied, set min and max to None - if b_box is None: - b_box_min = b_box_max = None - - # Bounding box has been supplied - else: - # Should be twice as many elements in `b_box` as `im_shape` - if len(b_box) != 2 * len(im_shape): - raise ValueError("Bounding box should contain upper and lower for all dimensions (except batch number)") - - # If any min's or max's are -ve, set them to 0 and im_shape-1, respectively. - b_box_min = np.array(b_box[::2]) - b_box_max = np.array(b_box[1::2]) - b_box_min[b_box_min < 0] = 0 - b_box_max[b_box_max < 0] = im_shape[b_box_max < 0] - 1 - # Check all max's are < im_shape - if np.any(b_box_max >= im_shape): - raise ValueError("Max bounding box should be < image size for all values") - # Check all min's are <= max's - if np.any(b_box_min > b_box_max): - raise ValueError("Min bounding box should be <= max for all values") - - return b_box_min, b_box_max - - -def _append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im): - """For given number of images, get probability of predicting - a given label. Append to previous evaluations.""" - batch_images = torch.cat(batch_images, dim=0) - batch_ids = torch.LongTensor(batch_ids).unsqueeze(1).to(sensitivity_im.device) - scores = model(batch_images).detach().gather(1, batch_ids) - return torch.cat((sensitivity_im, scores)) - - -def compute_occlusion_sensitivity( - model: nn.Module, - image: torch.Tensor, - label: Union[int, torch.Tensor], - pad_val: float = 0.0, - margin: Union[int, Sequence] = 2, - n_batch: int = 128, - b_box: Optional[Sequence] = None, - stride: Union[int, Sequence] = 1, - upsample_mode: str = "nearest", -) -> np.ndarray: - """ - This function computes the occlusion sensitivity for a model's prediction - of a given image. By occlusion sensitivity, we mean how the probability of a given - prediction changes as the occluded section of an image changes. This can - be useful to understand why a network is making certain decisions. - - The result is given as ``baseline`` (the probability of - a certain output) minus the probability of the output with the occluded - area. - - Therefore, higher values in the output image mean there was a - greater the drop in certainty, indicating the occluded region was more - important in the decision process. - - See: R. R. Selvaraju et al. Grad-CAM: Visual Explanations from Deep Networks via - Gradient-based Localization. 
https://doi.org/10.1109/ICCV.2017.74 - - Args: - model: classification model to use for inference - image: image to test. Should be tensor consisting of 1 batch, can be 2- or 3D. - label: classification label to check for changes (normally the true - label, but doesn't have to be) - pad_val: when occluding part of the image, which values should we put - in the image? - margin: we'll create a cuboid/cube around the voxel to be occluded. if - ``margin==2``, then we'll create a cube that is +/- 2 voxels in - all directions (i.e., a cube of 5 x 5 x 5 voxels). A ``Sequence`` - can be supplied to have a margin of different sizes (i.e., create - a cuboid). - n_batch: number of images in a batch before inference. - b_box: Bounding box on which to perform the analysis. The output image - will also match in size. There should be a minimum and maximum for - all dimensions except batch: ``[min1, max1, min2, max2,...]``. - * By default, the whole image will be used. Decreasing the size will - speed the analysis up, which might be useful for larger images. - * Min and max are inclusive, so [0, 63, ...] will have size (64, ...). - * Use -ve to use 0 for min values and im.shape[x]-1 for xth dimension. - stride: Stride for performing occlusions. Can be single value or sequence - (for varying stride in the different directions). Should be >= 1. - upsample_mode: If stride != 1 is used, we'll upsample such that the size - of the voxels in the output image match the input. Upsampling is done with - ``torch.nn.Upsample``, and mode can be set to: - * ``nearest``, ``linear``, ``bilinear``, ``bicubic`` and ``trilinear`` - * default is ``nearest``. - Returns: - Numpy array. If no bounding box is supplied, this will be the same size - as the input image. If a bounding box is used, the output image will be - cropped to this size. - """ - - # Check input arguments - image = _check_input_image(image) - label = _check_input_label(label, image) - im_shape = np.array(image.shape[1:]) - b_box_min, b_box_max = _check_input_bounding_box(b_box, im_shape) - - # Get baseline probability - baseline = model(image).detach()[0, label].item() - - # Create some lists - batch_images = [] - batch_ids = [] - - sensitivity_im = torch.empty(0, dtype=torch.float32, device=image.device) - - # If no bounding box supplied, output shape is same as input shape. 
- # If bounding box is present, shape is max - min + 1 - output_im_shape = im_shape if b_box is None else b_box_max - b_box_min + 1 - - # Calculate the downsampled shape - if not isinstance(stride, Sequence): - stride_np = np.full_like(im_shape, stride, dtype=np.int32) - stride_np[0] = 1 # always do stride 1 in channel dimension - else: - # Convert to numpy array and check dimensions match - stride_np = np.array(stride, dtype=np.int32) - if stride_np.size != im_shape.size: - raise ValueError("Sizes of image shape and stride should match.") - - # Obviously if stride = 1, downsampled_im_shape == output_im_shape - downsampled_im_shape = np.floor(output_im_shape / stride_np).astype(np.int32) - downsampled_im_shape[downsampled_im_shape == 0] = 1 # make sure dimension sizes are >= 1 - num_required_predictions = np.prod(downsampled_im_shape) - - # Loop 1D over image - for i in trange(num_required_predictions): - # Get corresponding ND index - idx = np.unravel_index(i, downsampled_im_shape) - # Multiply by stride - idx *= stride_np - # If a bounding box is being used, we need to add on - # the min to shift to start of region of interest - if b_box_min is not None: - idx += b_box_min - - # Get min and max index of box to occlude - min_idx = [max(0, i - margin) for i in idx] - max_idx = [min(j, i + margin) for i, j in zip(idx, im_shape)] - - # Clone and replace target area with `pad_val` - occlu_im = image.clone() - occlu_im[(...,) + tuple(slice(i, j) for i, j in zip(min_idx, max_idx))] = pad_val - - # Add to list - batch_images.append(occlu_im) - batch_ids.append(label) - - # Once the batch is complete (or on last iteration) - if len(batch_images) == n_batch or i == num_required_predictions - 1: - # Do the predictions and append to sensitivity map - sensitivity_im = _append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im) - # Clear lists - batch_images = [] - batch_ids = [] - - # Subtract from baseline - sensitivity_im = baseline - sensitivity_im - - # Reshape to match downsampled image - sensitivity_im = sensitivity_im.reshape(tuple(downsampled_im_shape)) - - # If necessary, upsample - if np.any(stride_np != 1): - output_im_shape = tuple(output_im_shape[1:]) # needs to be given as 3D tuple - upsampler = nn.Upsample(size=output_im_shape, mode=upsample_mode) - sensitivity_im = upsampler(sensitivity_im.unsqueeze(0)) - - # Convert tensor to numpy - sensitivity_im = sensitivity_im.cpu().numpy() - - # Squeeze and return - return np.squeeze(sensitivity_im) diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index 7b26560d57..9f081d1698 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,12 +20,10 @@ def _calculate(y: torch.Tensor, y_pred: torch.Tensor) -> float: - assert y.ndimension() == y_pred.ndimension() == 1 and len(y) == len( - y_pred - ), "y and y_pred must be 1 dimension data with same length." - assert y.unique().equal( - torch.tensor([0, 1], dtype=y.dtype, device=y.device) - ), "y values must be 0 or 1, can not be all 0 or all 1." 
+ if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)): + raise AssertionError("y and y_pred must be 1 dimension data with same length.") + if not y.unique().equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)): + raise AssertionError("y values must be 0 or 1, can not be all 0 or all 1.") n = len(y) indices = y_pred.argsort() y = y[indices].cpu().numpy() @@ -114,33 +112,31 @@ def compute_roc_auc( if softmax: warnings.warn("y_pred has only one channel, softmax=True ignored.") return _calculate(y, y_pred) - else: - n_classes = y_pred.shape[1] - if to_onehot_y: - y = one_hot(y, n_classes) - if softmax and other_act is not None: - raise ValueError("Incompatible values: softmax=True and other_act is not None.") - if softmax: - y_pred = y_pred.float().softmax(dim=1) - if other_act is not None: - if not callable(other_act): - raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") - y_pred = other_act(y_pred) - - assert y.shape == y_pred.shape, "data shapes of y_pred and y do not match." - - average = Average(average) - if average == Average.MICRO: - return _calculate(y.flatten(), y_pred.flatten()) - y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1) - auc_values = [_calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)] - if average == Average.NONE: - return auc_values - if average == Average.MACRO: - return np.mean(auc_values) - if average == Average.WEIGHTED: - weights = [sum(y_) for y_ in y] - return np.average(auc_values, weights=weights) - raise ValueError( - f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].' - ) + n_classes = y_pred.shape[1] + if to_onehot_y: + y = one_hot(y, n_classes) + if softmax and other_act is not None: + raise ValueError("Incompatible values: softmax=True and other_act is not None.") + if softmax: + y_pred = y_pred.float().softmax(dim=1) + if other_act is not None: + if not callable(other_act): + raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") + y_pred = other_act(y_pred) + + if y.shape != y_pred.shape: + raise AssertionError("data shapes of y_pred and y do not match.") + + average = Average(average) + if average == Average.MICRO: + return _calculate(y.flatten(), y_pred.flatten()) + y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1) + auc_values = [_calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)] + if average == Average.NONE: + return auc_values + if average == Average.MACRO: + return np.mean(auc_values) + if average == Average.WEIGHTED: + weights = [sum(y_) for y_ in y] + return np.average(auc_values, weights=weights) + raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].') diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 8dcbe4d9f6..9e2f130bd2 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
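A small sketch of the `average` branches preserved by the `compute_roc_auc` refactor above, assuming the public signature `compute_roc_auc(y_pred, y, ...)` as in the function body shown: "macro" averages per-class AUCs, "micro" flattens all classes first, and "none" returns the per-class list:

    import torch
    from monai.metrics import compute_roc_auc

    y_pred = torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.15], [0.5, 0.05]])
    y = torch.tensor([[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0]])

    print(compute_roc_auc(y_pred, y, softmax=True, average="macro"))
    print(compute_roc_auc(y_pred, y, softmax=True, average="micro"))
    print(compute_roc_auc(y_pred, y, softmax=True, average="none"))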
# You may obtain a copy of the License at @@ -15,7 +15,7 @@ import numpy as np import torch -from monai.metrics.utils import * +from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background from monai.utils import MetricReduction diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index ffe6093621..cc7049ff81 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -53,7 +53,7 @@ def do_metric_reduction( f: a tensor that contains the calculated metric scores per batch and per class. The first two dims should be batch and class. reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} + ``"mean_channel"``, ``"sum_channel"``}, if "none", return the input f tensor and not_nans. Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. Raises: @@ -65,11 +65,13 @@ def do_metric_reduction( # we need to account for it nans = torch.isnan(f) not_nans = (~nans).float() - f[nans] = 0 t_zero = torch.zeros(1, device=f.device, dtype=f.dtype) reduction = MetricReduction(reduction) + if reduction == MetricReduction.NONE: + return f, not_nans + f[nans] = 0 if reduction == MetricReduction.MEAN: # 2 steps, first, mean by channel (accounting for nans), then by batch not_nans = not_nans.sum(dim=1) diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 2858595b09..3c0a68def2 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,4 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .utils import * +from .utils import ( + eval_mode, + icnr_init, + normal_init, + normalize_transform, + one_hot, + pixelshuffle, + predict_segmentation, + slice_channels, + to_norm_affine, + train_mode, +) diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index 0710d40145..8ac06f8776 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
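The `do_metric_reduction` change above makes ``"none"`` short-circuit before NaNs are zeroed; a small sketch of the difference:

    import torch
    from monai.metrics import do_metric_reduction

    scores = torch.tensor([[0.9, float("nan")], [0.8, 0.6]])  # (batch, class)

    f, not_nans = do_metric_reduction(scores, "none")          # returned untouched, NaN preserved
    m, not_nans = do_metric_reduction(scores.clone(), "mean")  # NaNs zeroed and excluded from counts

Note the ``.clone()`` in the second call: the reduction modes other than ``"none"`` still write ``f[nans] = 0`` in place.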
# You may obtain a copy of the License at @@ -16,6 +16,7 @@ from .downsample import MaxAvgPool from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding from .fcn import FCN, GCN, MCFCN, Refine +from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock from .segresnet_block import ResBlock from .squeeze_and_excitation import ( ChannelSELayer, @@ -25,4 +26,5 @@ SEResNetBottleneck, SEResNeXtBottleneck, ) -from .upsample import * +from .upsample import SubpixelUpsample, Subpixelupsample, SubpixelUpSample, Upsample, UpSample +from .warp import Warp diff --git a/monai/networks/blocks/acti_norm.py b/monai/networks/blocks/acti_norm.py index 585726edf2..53ef212209 100644 --- a/monai/networks/blocks/acti_norm.py +++ b/monai/networks/blocks/acti_norm.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -80,7 +80,7 @@ def __init__( super().__init__() op_dict = {"A": None, "D": None, "N": None} - # define the normalisation type and the arguments to the constructor + # define the normalization type and the arguments to the constructor if norm is not None: if norm_dim is None and dropout_dim is None: raise ValueError("norm_dim or dropout_dim needs to be specified.") diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index 359105f8b8..ef6c74f282 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/blocks/aspp.py b/monai/networks/blocks/aspp.py index 041ecd94b1..d995d64796 100644 --- a/monai/networks/blocks/aspp.py +++ b/monai/networks/blocks/aspp.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/blocks/convolutions.py b/monai/networks/blocks/convolutions.py index eafe028a06..7bfb3b47e4 100644 --- a/monai/networks/blocks/convolutions.py +++ b/monai/networks/blocks/convolutions.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -94,6 +94,7 @@ def __init__( padding = same_padding(kernel_size, dilation) conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, dimensions] + conv: nn.Module if is_transposed: if output_padding is None: output_padding = stride_minus_kernel_padding(1, strides) diff --git a/monai/networks/blocks/downsample.py b/monai/networks/blocks/downsample.py index adcbec2850..975c2e15bb 100644 --- a/monai/networks/blocks/downsample.py +++ b/monai/networks/blocks/downsample.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/blocks/dynunet_block.py b/monai/networks/blocks/dynunet_block.py index ba9d71b610..577fd4d71d 100644 --- a/monai/networks/blocks/dynunet_block.py +++ b/monai/networks/blocks/dynunet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -228,13 +228,13 @@ def get_acti_layer(act: Union[Tuple[str, Dict], str]): def get_norm_layer(spatial_dims: int, out_channels: int, norm_name: str, num_groups: int = 16): if norm_name not in ["batch", "instance", "group"]: raise ValueError(f"Unsupported normalization mode: {norm_name}") + if norm_name == "group": + if out_channels % num_groups != 0: + raise AssertionError("out_channels should be divisible by num_groups.") + norm = Norm[norm_name](num_groups=num_groups, num_channels=out_channels, affine=True) else: - if norm_name == "group": - assert out_channels % num_groups == 0, "out_channels should be divisible by num_groups." - norm = Norm[norm_name](num_groups=num_groups, num_channels=out_channels, affine=True) - else: - norm = Norm[norm_name, spatial_dims](out_channels, affine=True) - return norm + norm = Norm[norm_name, spatial_dims](out_channels, affine=True) + return norm def get_conv_layer( @@ -277,8 +277,8 @@ def get_padding( kernel_size_np = np.atleast_1d(kernel_size) stride_np = np.atleast_1d(stride) padding_np = (kernel_size_np - stride_np + 1) / 2 - error_msg = "padding value should not be negative, please change the kernel size and/or stride." - assert np.min(padding_np) >= 0, error_msg + if np.min(padding_np) < 0: + raise AssertionError("padding value should not be negative, please change the kernel size and/or stride.") padding = tuple(int(p) for p in padding_np) return padding if len(padding) > 1 else padding[0] @@ -294,8 +294,8 @@ def get_output_padding( padding_np = np.atleast_1d(padding) out_padding_np = 2 * padding_np + stride_np - kernel_size_np - error_msg = "out_padding value should not be negative, please change the kernel size and/or stride." 
- assert np.min(out_padding_np) >= 0, error_msg + if np.min(out_padding_np) < 0: + raise AssertionError("out_padding value should not be negative, please change the kernel size and/or stride.") out_padding = tuple(int(p) for p in out_padding_np) return out_padding if len(out_padding) > 1 else out_padding[0] diff --git a/monai/networks/blocks/fcn.py b/monai/networks/blocks/fcn.py index 9587519d3d..c7cd7cca30 100644 --- a/monai/networks/blocks/fcn.py +++ b/monai/networks/blocks/fcn.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py new file mode 100644 index 0000000000..4166c08774 --- /dev/null +++ b/monai/networks/blocks/localnet_block.py @@ -0,0 +1,320 @@ +from typing import Optional, Sequence, Tuple, Type, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from monai.networks.blocks import Convolution +from monai.networks.layers import same_padding +from monai.networks.layers.factories import Conv, Norm, Pool + + +def get_conv_block( + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int] = 3, + act: Optional[Union[Tuple, str]] = "RELU", + norm: Optional[Union[Tuple, str]] = "BATCH", +) -> nn.Module: + padding = same_padding(kernel_size) + return Convolution( + spatial_dims, + in_channels, + out_channels, + kernel_size=kernel_size, + act=act, + norm=norm, + bias=False, + conv_only=False, + padding=padding, + ) + + +def get_conv_layer( + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int] = 3, +) -> nn.Module: + padding = same_padding(kernel_size) + return Convolution( + spatial_dims, + in_channels, + out_channels, + kernel_size=kernel_size, + bias=False, + conv_only=True, + padding=padding, + ) + + +def get_deconv_block( + spatial_dims: int, + in_channels: int, + out_channels: int, +) -> nn.Module: + return Convolution( + dimensions=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + strides=2, + act="RELU", + norm="BATCH", + bias=False, + is_transposed=True, + padding=1, + output_padding=1, + ) + + +class ResidualBlock(nn.Module): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + ) -> None: + super(ResidualBlock, self).__init__() + if in_channels != out_channels: + raise ValueError( + f"expecting in_channels == out_channels, " f"got in_channels={in_channels}, out_channels={out_channels}" + ) + self.conv_block = get_conv_block( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + ) + self.conv = get_conv_layer( + spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size + ) + self.norm = Norm[Norm.BATCH, spatial_dims](out_channels) + self.relu = nn.ReLU() + + def forward(self, x) -> torch.Tensor: + out: torch.Tensor = self.relu(self.norm(self.conv(self.conv_block(x))) + x) + return out + + +class LocalNetResidualBlock(nn.Module): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + ) -> None: + super(LocalNetResidualBlock, self).__init__() + if in_channels != out_channels: + raise ValueError( + f"expecting 
in_channels == out_channels, " f"got in_channels={in_channels}, out_channels={out_channels}"
+            )
+        self.conv_layer = get_conv_layer(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=out_channels,
+        )
+        self.norm = Norm[Norm.BATCH, spatial_dims](out_channels)
+        self.relu = nn.ReLU()
+
+    def forward(self, x, mid) -> torch.Tensor:
+        out: torch.Tensor = self.relu(self.norm(self.conv_layer(x)) + mid)
+        return out
+
+
+class LocalNetDownSampleBlock(nn.Module):
+    """
+    A down-sample module that can be used for LocalNet, based on:
+    `Weakly-supervised convolutional neural networks for multimodal image registration`_.
+    `Label-driven weakly-supervised learning for multimodal deformable image registration`_.
+
+    Adapted from:
+        DeepReg (https://github.com/DeepRegNet/DeepReg)
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[Sequence[int], int],
+    ) -> None:
+        """
+        Args:
+            spatial_dims: number of spatial dimensions.
+            in_channels: number of input channels.
+            out_channels: number of output channels.
+            kernel_size: convolution kernel size.
+
+        Raises:
+            NotImplementedError: when ``kernel_size`` is even
+        """
+        super(LocalNetDownSampleBlock, self).__init__()
+        self.conv_block = get_conv_block(
+            spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size
+        )
+        self.residual_block = ResidualBlock(
+            spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size
+        )
+        self.max_pool = Pool[Pool.MAX, spatial_dims](
+            kernel_size=2,
+        )
+
+    def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Halves the spatial dimensions.
+        A tuple of (x, mid) is returned:
+
+        -  x is the downsample result, in shape (batch, ``out_channels``, insize_1 / 2, insize_2 / 2, [insize_3 / 2]),
+        -  mid is the mid-level feature, in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3])
+
+        Args:
+            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
+
+        Raises:
+            ValueError: when input spatial dimensions are not even.
+        """
+        for i in x.shape[2:]:
+            if i % 2 != 0:
+                raise ValueError(f"expecting x spatial dimensions to be even, got x of shape {x.shape}")
+        x = self.conv_block(x)
+        mid = self.residual_block(x)
+        x = self.max_pool(mid)
+        return x, mid
+
+
+class LocalNetUpSampleBlock(nn.Module):
+    """
+    An up-sample module that can be used for LocalNet, based on:
+    `Weakly-supervised convolutional neural networks for multimodal image registration`_.
+    `Label-driven weakly-supervised learning for multimodal deformable image registration`_.
+
+    Adapted from:
+        DeepReg (https://github.com/DeepRegNet/DeepReg)
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+    ) -> None:
+        """
+        Args:
+            spatial_dims: number of spatial dimensions.
+            in_channels: number of input channels.
+            out_channels: number of output channels.
+
+        Raises:
+            ValueError: when ``in_channels != 2 * out_channels``
+        """
+        super(LocalNetUpSampleBlock, self).__init__()
+        self.deconv_block = get_deconv_block(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=out_channels,
+        )
+        self.conv_block = get_conv_block(
+            spatial_dims=spatial_dims,
+            in_channels=out_channels,
+            out_channels=out_channels,
+        )
+        self.residual_block = LocalNetResidualBlock(
+            spatial_dims=spatial_dims,
+            in_channels=out_channels,
+            out_channels=out_channels,
+        )
+        if in_channels / out_channels != 2:
+            raise ValueError(
+                f"expecting in_channels == 2 * out_channels, "
+                f"got in_channels={in_channels}, out_channels={out_channels}"
+            )
+        self.out_channels = out_channels
+
+    def additive_upsampling(self, x, mid) -> torch.Tensor:
+        x = F.interpolate(x, mid.shape[2:])
+        # [(batch, out_channels, ...), (batch, out_channels, ...)]
+        x = x.split(split_size=int(self.out_channels), dim=1)
+        # (batch, out_channels, ...)
+        out: torch.Tensor = torch.sum(torch.stack(x, dim=-1), dim=-1)
+        return out
+
+    def forward(self, x, mid) -> torch.Tensor:
+        """
+        Halves the channels and doubles the spatial dimensions.
+
+        Args:
+            x: feature to be up-sampled, in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
+            mid: mid-level feature saved during down-sampling,
+                in shape (batch, ``out_channels``, midsize_1, midsize_2, [midsize_3])
+
+        Raises:
+            ValueError: when ``midsize != insize * 2``
+        """
+        for i, j in zip(x.shape[2:], mid.shape[2:]):
+            if j != 2 * i:
+                raise ValueError(
+                    "expecting mid spatial dimensions to be exactly twice the x spatial dimensions, "
+                    f"got x of shape {x.shape}, mid of shape {mid.shape}"
+                )
+        h0 = self.deconv_block(x) + self.additive_upsampling(x, mid)
+        r1 = h0 + mid
+        r2 = self.conv_block(h0)
+        out: torch.Tensor = self.residual_block(r2, r1)
+        return out
+
+
+class LocalNetFeatureExtractorBlock(nn.Module):
+    """
+    A feature-extraction module that can be used for LocalNet, based on:
+    `Weakly-supervised convolutional neural networks for multimodal image registration`_.
+    `Label-driven weakly-supervised learning for multimodal deformable image registration`_.
+
+    Adapted from:
+        DeepReg (https://github.com/DeepRegNet/DeepReg)
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        act: Optional[Union[Tuple, str]] = "RELU",
+        initializer: str = "kaiming_uniform",
+    ) -> None:
+        """
+        Args:
+            spatial_dims: number of spatial dimensions.
+            in_channels: number of input channels.
+            out_channels: number of output channels.
+            act: activation type and arguments. Defaults to ReLU.
+            initializer: weight initializer, ``"kaiming_uniform"`` or ``"zeros"``. Defaults to ``"kaiming_uniform"``.
+        """
+        super(LocalNetFeatureExtractorBlock, self).__init__()
+        self.conv_block = get_conv_block(
+            spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None
+        )
+        conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims]
+        for m in self.conv_block.modules():
+            if isinstance(m, conv_type):
+                if initializer == "kaiming_uniform":
+                    nn.init.kaiming_uniform_(torch.as_tensor(m.weight))
+                elif initializer == "zeros":
+                    nn.init.zeros_(torch.as_tensor(m.weight))
+                else:
+                    raise ValueError(
+                        f"initializer {initializer} is not supported, currently supporting kaiming_uniform and zeros"
+                    )
+
+    def forward(self, x) -> torch.Tensor:
+        """
+        Args:
+            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
+        """
+        out: torch.Tensor = self.conv_block(x)
+        return out
diff --git a/monai/networks/blocks/segresnet_block.py b/monai/networks/blocks/segresnet_block.py
index b7c1e68b75..e95466ca7e 100644
--- a/monai/networks/blocks/segresnet_block.py
+++ b/monai/networks/blocks/segresnet_block.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -22,16 +22,15 @@
 def get_norm_layer(spatial_dims: int, in_channels: int, norm_name: str, num_groups: int = 8):
     if norm_name not in ["batch", "instance", "group"]:
         raise ValueError(f"Unsupported normalization mode: {norm_name}")
+    if norm_name == "group":
+        norm = Norm[norm_name](num_groups=num_groups, num_channels=in_channels)
     else:
-        if norm_name == "group":
-            norm = Norm[norm_name](num_groups=num_groups, num_channels=in_channels)
-        else:
-            norm = Norm[norm_name, spatial_dims](in_channels)
-        if norm.bias is not None:
-            nn.init.zeros_(norm.bias)
-        if norm.weight is not None:
-            nn.init.ones_(norm.weight)
-        return norm
+        norm = Norm[norm_name, spatial_dims](in_channels)
+    if norm.bias is not None:
+        nn.init.zeros_(norm.bias)
+    if norm.weight is not None:
+        nn.init.ones_(norm.weight)
+    return norm
 
 
 def get_conv_layer(
@@ -91,8 +90,10 @@ def __init__(
 
         super().__init__()
 
-        assert kernel_size % 2 == 1, "kernel_size should be an odd number."
-        assert in_channels % num_groups == 0, "in_channels should be divisible by num_groups."
+        if kernel_size % 2 != 1:
+            raise AssertionError("kernel_size should be an odd number.")
+        if in_channels % num_groups != 0:
+            raise AssertionError("in_channels should be divisible by num_groups.")
 
         self.norm1 = get_norm_layer(spatial_dims, in_channels, norm_name, num_groups=num_groups)
         self.norm2 = get_norm_layer(spatial_dims, in_channels, norm_name, num_groups=num_groups)
diff --git a/monai/networks/blocks/squeeze_and_excitation.py b/monai/networks/blocks/squeeze_and_excitation.py
index e1533d454d..4db6dc30f7 100644
--- a/monai/networks/blocks/squeeze_and_excitation.py
+++ b/monai/networks/blocks/squeeze_and_excitation.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
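A shape-level sketch of the new LocalNet blocks; the `skip` tensor below is a stand-in for a matching mid-level feature, which in LocalNet proper comes from the down-sampling path:

    import torch
    from monai.networks.blocks import LocalNetDownSampleBlock, LocalNetUpSampleBlock

    down = LocalNetDownSampleBlock(spatial_dims=3, in_channels=2, out_channels=4, kernel_size=3)
    up = LocalNetUpSampleBlock(spatial_dims=3, in_channels=4, out_channels=2)

    x = torch.rand(1, 2, 32, 32, 32)  # spatial sizes must be even for the down block
    half, mid = down(x)               # half: (1, 4, 16, 16, 16), mid: (1, 4, 32, 32, 32)

    skip = torch.rand(1, 2, 32, 32, 32)
    restored = up(half, skip)         # (1, 2, 32, 32, 32): channels halved, spatial doubled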
 # You may obtain a copy of the License at
diff --git a/monai/networks/blocks/upsample.py b/monai/networks/blocks/upsample.py
index c2093086fd..db85b8bd27 100644
--- a/monai/networks/blocks/upsample.py
+++ b/monai/networks/blocks/upsample.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py
new file mode 100644
index 0000000000..60e23f6750
--- /dev/null
+++ b/monai/networks/blocks/warp.py
@@ -0,0 +1,113 @@
+from typing import List, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from monai.utils import GridSamplePadMode
+
+
+class Warp(nn.Module):
+    """
+    Warp an image with a given dense displacement field (DDF).
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        mode: int = 1,
+        padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.ZEROS,
+    ):
+        """
+        Args:
+            spatial_dims: {2, 3}. number of spatial dimensions
+            mode: interpolation mode to calculate output values, defaults to 1.
+                Possible values are::
+
+                    - 0 or 'nearest' or InterpolationType.nearest
+                    - 1 or 'linear' or InterpolationType.linear
+                    - 2 or 'quadratic' or InterpolationType.quadratic
+                    - 3 or 'cubic' or InterpolationType.cubic
+                    - 4 or 'fourth' or InterpolationType.fourth
+                    - etc.
+            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
+                Padding mode for outside grid values. Defaults to ``"zeros"``.
+                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
+        """
+        super(Warp, self).__init__()
+        if spatial_dims not in [2, 3]:
+            raise ValueError(f"got unsupported spatial_dims={spatial_dims}, only support 2-d and 3-d input")
+        self.spatial_dims = spatial_dims
+        if mode < 0:
+            raise ValueError(f"do not support negative mode, got mode={mode}")
+        self.mode = mode
+        self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode)
+
+    @staticmethod
+    def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor:
+        mesh_points = [torch.arange(0, dim) for dim in ddf.shape[2:]]
+        grid = torch.stack(torch.meshgrid(*mesh_points), dim=0)  # (spatial_dims, ...)
+        grid = torch.stack([grid] * ddf.shape[0], dim=0)  # (batch, spatial_dims, ...)
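+        # cast the reference grid to ddf's dtype and device so it can be offset by the
+        # displacement field in forward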
+ grid = grid.to(ddf) + return grid + + @staticmethod + def normalize_grid(grid: torch.Tensor) -> torch.Tensor: + # (batch, ..., self.spatial_dims) + for i, dim in enumerate(grid.shape[1:-1]): + grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1 + return grid + + def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Tensor in shape (batch, num_channels, H, W[, D]) + ddf: Tensor in the same spatial size as image, in shape (batch, spatial_dims, H, W[, D]) + + Returns: + warped_image in the same shape as image (batch, num_channels, H, W[, D]) + """ + if len(image.shape) != 2 + self.spatial_dims: + raise ValueError(f"expecting {self.spatial_dims + 2}-d input, " f"got input in shape {image.shape}") + if len(ddf.shape) != 2 + self.spatial_dims or ddf.shape[1] != self.spatial_dims: + raise ValueError( + f"expecting {self.spatial_dims + 2}-d ddf with {self.spatial_dims} channels, " + f"got ddf in shape {ddf.shape}" + ) + if image.shape[0] != ddf.shape[0] or image.shape[2:] != ddf.shape[2:]: + raise ValueError( + "expecting image and ddf of same batch size and spatial size, " + f"got image of shape {image.shape}, ddf of shape {ddf.shape}" + ) + + grid = self.get_reference_grid(ddf) + ddf + grid = grid.permute([0] + list(range(2, 2 + self.spatial_dims)) + [1]) # (batch, ..., self.spatial_dims) + + if self.mode > 1: + raise ValueError(f"{self.mode}-order interpolation not yet implemented.") + # if not USE_COMPILED: + # raise ValueError(f"cannot perform {self.mode}-order interpolation without C compile.") + # _padding_mode = self.padding_mode.value + # if _padding_mode == "zeros": + # bound = 7 + # elif _padding_mode == "border": + # bound = 0 + # else: + # bound = 1 + # warped_image: torch.Tensor = grid_pull( + # image, + # grid, + # bound=bound, + # extrapolate=True, + # interpolation=self.mode, + # ) + else: + grid = self.normalize_grid(grid) + index_ordering: List[int] = list(range(self.spatial_dims - 1, -1, -1)) + grid = grid[..., index_ordering] # z, y, x -> x, y, z + _interp_mode = "bilinear" if self.mode == 1 else "nearest" + warped_image = F.grid_sample( + image, grid, mode=_interp_mode, padding_mode=self.padding_mode.value, align_corners=True + ) + + return warped_image diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 9125dc38cf..ba61774a96 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
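A minimal check of `Warp`: a zero displacement field should reproduce the input, since bilinear sampling at the exact grid points with ``align_corners=True`` is an identity:

    import torch
    from monai.networks.blocks import Warp

    warp = Warp(spatial_dims=2, mode=1, padding_mode="zeros")
    image = torch.arange(16.0).reshape(1, 1, 4, 4)
    ddf = torch.zeros(1, 2, 4, 4)  # no displacement
    moved = warp(image, ddf)
    print(torch.allclose(moved, image, atol=1e-5))  # True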
-from .convutils import * -from .factories import * -from .simplelayers import * -from .spatial_transforms import * +from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding +from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args +from .filtering import BilateralFilter, PHLFilter +from .simplelayers import ( + LLTM, + ChannelPad, + Flatten, + GaussianFilter, + HilbertTransform, + Reshape, + SavitzkyGolayFilter, + SkipConnection, + separable_filtering, +) +from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push diff --git a/monai/networks/layers/convutils.py b/monai/networks/layers/convutils.py index b73a26bdca..c4f798699c 100644 --- a/monai/networks/layers/convutils.py +++ b/monai/networks/layers/convutils.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 1bb33ed9d7..ec36b2ed95 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ is typically a type but can be any callable producing a layer object. The factory objects contain functions keyed to names converted to upper case, these names can be referred to as members -of the factory so that they can function as constant identifiers. eg. instance normalisation is named `Norm.INSTANCE`. +of the factory so that they can function as constant identifiers. eg. instance normalization is named `Norm.INSTANCE`. For example, to get a transpose convolution layer the name is needed and then a dimension argument is provided which is passed to the factory function: @@ -178,14 +178,13 @@ def split_args(args): if isinstance(args, str): return args, {} - else: - name_obj, name_args = args + name_obj, name_args = args - if not isinstance(name_obj, (str, Callable)) or not isinstance(name_args, dict): - msg = "Layer specifiers must be single strings or pairs of the form (name/object-types, argument dict)" - raise TypeError(msg) + if not isinstance(name_obj, (str, Callable)) or not isinstance(name_args, dict): + msg = "Layer specifiers must be single strings or pairs of the form (name/object-types, argument dict)" + raise TypeError(msg) - return name_obj, name_args + return name_obj, name_args # Define factories for these layer types diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py new file mode 100644 index 0000000000..83a33bc609 --- /dev/null +++ b/monai/networks/layers/filtering.py @@ -0,0 +1,98 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
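`split_args` (tidied above) is the factory helper that accepts either a bare layer name or a ``(name, kwargs)`` pair; a short example:

    from monai.networks.layers import Act, split_args

    name, args = split_args("PRELU")                           # ("PRELU", {})
    name, args = split_args(("PRELU", {"num_parameters": 1}))
    layer = Act[name](**args)                                  # nn.PReLU(num_parameters=1)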
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from monai.utils.module import optional_import
+
+_C, _ = optional_import("monai._C")
+
+__all__ = ["BilateralFilter", "PHLFilter"]
+
+
+class BilateralFilter(torch.autograd.Function):
+    """
+    Blurs the input tensor spatially whilst preserving edges. Can run on 1D, 2D, or 3D
+    tensors (on top of Batch and Channel dimensions). Two implementations are provided,
+    an exact solution and a much faster approximation which uses a permutohedral lattice.
+
+    See:
+        https://en.wikipedia.org/wiki/Bilateral_filter
+        https://graphics.stanford.edu/papers/permutohedral/
+
+    Args:
+        input: input tensor.
+
+        spatial_sigma: the standard deviation of the spatial blur. Higher values can
+            hurt performance when not using the approximate method (see fast_approx).
+
+        color_sigma: the standard deviation of the color blur. Lower values preserve
+            edges better whilst higher values tend to a simple gaussian spatial blur.
+
+        fast_approx: this flag chooses between two implementations. The approximate method may
+            produce artifacts in some scenarios whereas the exact solution may be intolerably
+            slow for high spatial standard deviations.
+
+    Returns:
+        output (torch.Tensor): output tensor.
+    """
+
+    @staticmethod
+    def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True):
+        # the sigmas are plain python scalars, so keep them on the context directly
+        # (save_for_backward only accepts tensors)
+        ctx.ss, ctx.cs, ctx.fa = spatial_sigma, color_sigma, fast_approx
+        output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx)
+        return output_data
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        spatial_sigma, color_sigma, fast_approx = ctx.ss, ctx.cs, ctx.fa
+        grad_input = _C.bilateral_filter(grad_output, spatial_sigma, color_sigma, fast_approx)
+        # no gradients for the scalar filter parameters
+        return grad_input, None, None, None
+
+
+class PHLFilter(torch.autograd.Function):
+    """
+    Filters input based on arbitrary feature vectors. Uses a permutohedral
+    lattice data structure to efficiently approximate n-dimensional gaussian
+    filtering. Complexity is broadly independent of kernel size. Most applicable
+    to higher filter dimensions and larger kernel sizes.
+
+    See:
+        https://graphics.stanford.edu/papers/permutohedral/
+
+    Args:
+        input: input tensor to be filtered.
+
+        features: feature tensor used to filter the input.
+
+        sigmas: the standard deviations of each feature in the filter.
+
+    Returns:
+        output (torch.Tensor): output tensor.
+    """
+
+    @staticmethod
+    def forward(ctx, input, features, sigmas=None):
+
+        scaled_features = features
+        if sigmas is not None:
+            for i in range(features.size(1)):
+                scaled_features[:, i, ...] /= sigmas[i]
+
+        ctx.save_for_backward(scaled_features)
+        output_data = _C.phl_filter(input, scaled_features)
+        return output_data
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (scaled_features,) = ctx.saved_variables
+        # the filter is (approximately) symmetric, so filtering the incoming gradient
+        # with the same features approximates the transposed operation
+        grad_input = _C.phl_filter(grad_output, scaled_features)
+        # no gradients for the feature tensor or the sigmas
+        return grad_input, None, None
diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py
index 48012dfb1c..285b0d629f 100644
--- a/monai/networks/layers/simplelayers.py
+++ b/monai/networks/layers/simplelayers.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
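A usage sketch for the new filters; this assumes MONAI was built with its C++/CUDA extension (``monai._C``, deferred by the ``optional_import`` above), and invokes the autograd ``Function`` through ``apply``:

    import torch
    from monai.networks.layers import BilateralFilter

    img = torch.rand(1, 1, 64, 64, requires_grad=True)
    # positional arguments after the input: spatial_sigma, color_sigma, fast_approx
    smoothed = BilateralFilter.apply(img, 5.0, 0.5, True)
    smoothed.sum().backward()  # the backward pass re-runs the filter on the gradient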
# You may obtain a copy of the License at @@ -39,6 +39,7 @@ "LLTM", "Reshape", "separable_filtering", + "SavitzkyGolayFilter", "HilbertTransform", "ChannelPad", ] @@ -163,7 +164,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x.reshape(shape) -def separable_filtering(x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Tensor: +def separable_filtering( + x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], torch.Tensor], mode: str = "zeros" +) -> torch.Tensor: """ Apply 1-D convolutions along each spatial dimension of `x`. @@ -171,10 +174,14 @@ def separable_filtering(x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], x: the input image. must have shape (batch, channels, H[, W, ...]). kernels: kernel along each spatial dimension. could be a single kernel (duplicated for all dimension), or `spatial_dims` number of kernels. + mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` + or ``'circular'``. Default: ``'zeros'``. Modes other than ``'zeros'`` require PyTorch version >= 1.5.1. See + torch.nn.Conv1d() for more information. Raises: TypeError: When ``x`` is not a ``torch.Tensor``. """ + if not torch.is_tensor(x): raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.") @@ -184,7 +191,7 @@ def separable_filtering(x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], for s in ensure_tuple_rep(kernels, spatial_dims) ] _paddings = [cast(int, (same_padding(k.shape[0]))) for k in _kernels] - n_chns = x.shape[1] + n_chs = x.shape[1] def _conv(input_: torch.Tensor, d: int) -> torch.Tensor: if d < 0: @@ -192,15 +199,95 @@ def _conv(input_: torch.Tensor, d: int) -> torch.Tensor: s = [1] * len(input_.shape) s[d + 2] = -1 _kernel = kernels[d].reshape(s) - _kernel = _kernel.repeat([n_chns, 1] + [1] * spatial_dims) + # if filter kernel is unity, don't convolve + if _kernel.numel() == 1 and _kernel[0] == 1: + return _conv(input_, d - 1) + _kernel = _kernel.repeat([n_chs, 1] + [1] * spatial_dims) _padding = [0] * spatial_dims _padding[d] = _paddings[d] conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1] - return conv_type(input=_conv(input_, d - 1), weight=_kernel, padding=_padding, groups=n_chns) + # translate padding for input to torch.nn.functional.pad + _reversed_padding_repeated_twice = [p for p in reversed(_padding) for _ in range(2)] + pad_mode = "constant" if mode == "zeros" else mode + return conv_type( + input=_conv(F.pad(input_, _reversed_padding_repeated_twice, mode=pad_mode), d - 1), + weight=_kernel, + groups=n_chs, + ) return _conv(x, spatial_dims - 1) +class SavitzkyGolayFilter(nn.Module): + """ + Convolve a Tensor along a particular axis with a Savitzky-Golay kernel. + + Args: + window_length: Length of the filter window, must be a positive odd integer. + order: Order of the polynomial to fit to each window, must be less than ``window_length``. + axis (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension). + mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or + ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information. 
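+
+    Example (a minimal sketch; the signal shape and filter settings below are
+    illustrative only):
+
+        .. code-block:: python
+
+            import torch
+
+            # a noisy 1-D signal in (batch, channel, length) layout
+            signal = torch.sin(torch.linspace(0, 6.28, 100)) + 0.1 * torch.randn(100)
+            signal = signal.reshape(1, 1, -1)
+
+            # fit cubic polynomials over 7-sample windows along the spatial axis
+            smoother = SavitzkyGolayFilter(window_length=7, order=3)
+            smoothed = smoother(signal)  # same shape as the input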
+ """ + + def __init__(self, window_length: int, order: int, axis: int = 2, mode: str = "zeros"): + + super().__init__() + if order >= window_length: + raise ValueError("order must be less than window_length.") + + self.axis = axis + self.mode = mode + self.coeffs = self._make_coeffs(window_length, order) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Tensor or array-like to filter. Must be real, in shape ``[Batch, chns, spatial1, spatial2, ...]`` and + have a device type of ``'cpu'``. + Returns: + torch.Tensor: ``x`` filtered by Savitzky-Golay kernel with window length ``self.window_length`` using + polynomials of order ``self.order``, along axis specified in ``self.axis``. + """ + + # Make input a real tensor on the CPU + x = torch.as_tensor(x, device=x.device if torch.is_tensor(x) else None) + if torch.is_complex(x): + raise ValueError("x must be real.") + else: + x = x.to(dtype=torch.float) + + if (self.axis < 0) or (self.axis > len(x.shape) - 1): + raise ValueError("Invalid axis for shape of x.") + + # Create list of filter kernels (1 per spatial dimension). The kernel for self.axis will be the savgol coeffs, + # while the other kernels will be set to [1]. + n_spatial_dims = len(x.shape) - 2 + spatial_processing_axis = self.axis - 2 + new_dims_before = spatial_processing_axis + new_dims_after = n_spatial_dims - spatial_processing_axis - 1 + kernel_list = [self.coeffs.to(device=x.device, dtype=x.dtype)] + for _ in range(new_dims_before): + kernel_list.insert(0, torch.ones(1, device=x.device, dtype=x.dtype)) + for _ in range(new_dims_after): + kernel_list.append(torch.ones(1, device=x.device, dtype=x.dtype)) + + return separable_filtering(x, kernel_list, mode=self.mode) + + @staticmethod + def _make_coeffs(window_length, order): + + half_length, rem = divmod(window_length, 2) + if rem == 0: + raise ValueError("window_length must be odd.") + + idx = torch.arange(window_length - half_length - 1, -half_length - 1, -1, dtype=torch.float, device="cpu") + a = idx ** torch.arange(order + 1, dtype=torch.float, device="cpu").reshape(-1, 1) + y = torch.zeros(order + 1, dtype=torch.float, device="cpu") + y[0] = 1.0 + return torch.lstsq(y, a).solution.squeeze() + + class HilbertTransform(nn.Module): """ Determine the analytical signal of a Tensor along a particular axis. @@ -233,8 +320,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = torch.as_tensor(x, device=x.device if torch.is_tensor(x) else None) if torch.is_complex(x): raise ValueError("x must be real.") - else: - x = x.to(dtype=torch.float) + x = x.to(dtype=torch.float) if (self.axis < 0) or (self.axis > len(x.shape) - 1): raise ValueError("Invalid axis for shape of x.") diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index a64b6d2d0a..c0f22502c8 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -518,7 +518,7 @@ def forward( if spatial_size is not None: dst_size = src_size[:2] + ensure_tuple(spatial_size) - # reverse and normalise theta if needed + # reverse and normalize theta if needed if not self.normalized: theta = to_norm_affine( affine=theta, src_size=src_size[2:], dst_size=dst_size[2:], align_corners=self.align_corners diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index aa31360c4e..a9308de9d7 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,16 +11,17 @@ from .ahnet import AHNet from .autoencoder import AutoEncoder -from .basic_unet import * -from .classifier import * +from .basic_unet import BasicUNet, BasicUnet, Basicunet +from .classifier import Classifier, Critic, Discriminator from .densenet import DenseNet, densenet121, densenet169, densenet201, densenet264 -from .dynunet import * -from .fullyconnectednet import * +from .dynunet import DynUNet, DynUnet, Dynunet +from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet from .generator import Generator from .highresnet import HighResBlock, HighResNet +from .localnet import LocalNet from .regressor import Regressor from .segresnet import SegResNet, SegResNetVAE from .senet import SENet, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d, senet154 -from .unet import * +from .unet import UNet, Unet, unet from .varautoencoder import VarAutoEncoder from .vnet import VNet diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index 0bf98385a0..847993bd44 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -371,9 +371,12 @@ def __init__( self.pool_type = pool_type self.spatial_dims = spatial_dims self.psp_block_num = psp_block_num + self.psp = None - assert spatial_dims in [2, 3], "spatial_dims can only be 2 or 3." - assert psp_block_num in [0, 1, 2, 3, 4], "psp_block_num should be an integer that belongs to [0, 4]." + if spatial_dims not in [2, 3]: + raise AssertionError("spatial_dims can only be 2 or 3.") + if psp_block_num not in [0, 1, 2, 3, 4]: + raise AssertionError("psp_block_num should be an integer that belongs to [0, 4].") self.conv1 = conv_type( in_channels, @@ -508,7 +511,7 @@ def forward(self, x): sum4 = self.up3(d3) + conv_x d4 = self.dense4(sum4) - if self.psp_block_num > 0: + if self.psp_block_num > 0 and self.psp is not None: psp = self.psp(d4) x = torch.cat((psp, d4), dim=1) else: diff --git a/monai/networks/nets/autoencoder.py b/monai/networks/nets/autoencoder.py index 8d0aadafd6..53e96b0841 100644 --- a/monai/networks/nets/autoencoder.py +++ b/monai/networks/nets/autoencoder.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -50,7 +50,7 @@ def __init__( self.norm = norm self.dropout = dropout self.num_inter_units = num_inter_units - self.inter_channels = inter_channels if inter_channels is not None else list() + self.inter_channels = inter_channels if inter_channels is not None else [] self.inter_dilations = list(inter_dilations or [1] * len(self.inter_channels)) # The number of channels and strides should match @@ -148,18 +148,17 @@ def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, i dropout=self.dropout, last_conv_only=is_last, ) - else: - return Convolution( - dimensions=self.dimensions, - in_channels=in_channels, - out_channels=out_channels, - strides=strides, - kernel_size=self.kernel_size, - act=self.act, - norm=self.norm, - dropout=self.dropout, - conv_only=is_last, - ) + return Convolution( + dimensions=self.dimensions, + in_channels=in_channels, + out_channels=out_channels, + strides=strides, + kernel_size=self.kernel_size, + act=self.act, + norm=self.norm, + dropout=self.dropout, + conv_only=is_last, + ) def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Sequential: diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index ebac0273a9..7a4b0bb8f1 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/nets/classifier.py b/monai/networks/nets/classifier.py index 6c42cbe96a..92fee4f566 100644 --- a/monai/networks/nets/classifier.py +++ b/monai/networks/nets/classifier.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,6 +17,8 @@ from monai.networks.layers.factories import Act, Norm, split_args from monai.networks.nets.regressor import Regressor +__all__ = ["Classifier", "Discriminator", "Critic"] + class Classifier(Regressor): """ diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index 9fe33d7ccd..ad1d1d6e5f 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 0915785db6..ba88c35f8d 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -12,9 +12,10 @@ from typing import List, Optional, Sequence, Union +import torch import torch.nn as nn -from monai.networks.blocks.dynunet_block import * +from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock __all__ = ["DynUNet", "DynUnet", "Dynunet"] @@ -79,10 +80,6 @@ class DynUNet(nn.Module): upsample_kernel_size: convolution kernel size for transposed convolution layers. norm_name: [``"batch"``, ``"instance"``, ``"group"``] feature normalization type and arguments. - deep_supervision: whether to add deep supervision head before output. Defaults to ``True``. - If added, in training mode, the network will output not only the last feature maps - (after being converted via output block), but also the previous feature maps that come - from the intermediate up sample layers. deep_supr_num: number of feature maps that will output during deep supervision head. The value should be less than the number of up sample layers. Defaults to 1. res_block: whether to use residual connection based convolution blocks during the network. @@ -98,7 +95,6 @@ def __init__( strides: Sequence[Union[Sequence[int], int]], upsample_kernel_size: Sequence[Union[Sequence[int], int]], norm_name: str = "instance", - deep_supervision: bool = True, deep_supr_num: int = 1, res_block: bool = False, ): @@ -110,7 +106,6 @@ def __init__( self.strides = strides self.upsample_kernel_size = upsample_kernel_size self.norm_name = norm_name - self.deep_supervision = deep_supervision self.conv_block = UnetResBlock if res_block else UnetBasicBlock self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] self.input_block = self.get_input_block() @@ -136,15 +131,15 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck): shouldn't be associated with a supervision head. """ - assert len(downsamples) == len(upsamples), f"{len(downsamples)} != {len(upsamples)}" - assert (len(downsamples) - len(superheads)) in (1, 0), f"{len(downsamples)}-(0,1) != {len(superheads)}" + if len(downsamples) != len(upsamples): + raise AssertionError(f"{len(downsamples)} != {len(upsamples)}") + if (len(downsamples) - len(superheads)) not in (1, 0): + raise AssertionError(f"{len(downsamples)}-(0,1) != {len(superheads)}") if len(downsamples) == 0: # bottom of the network, pass the bottleneck block return bottleneck - elif index == 0: # don't associate a supervision head with self.input_block + if index == 0: # don't associate a supervision head with self.input_block current_head, rest_heads = nn.Identity(), superheads - elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one - current_head, rest_heads = nn.Identity(), superheads[1:] else: current_head, rest_heads = superheads[0], superheads[1:] @@ -164,31 +159,36 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck): def check_kernel_stride(self): kernels, strides = self.kernel_size, self.strides error_msg = "length of kernel_size and strides should be the same, and no less than 3." 
- assert len(kernels) == len(strides) and len(kernels) >= 3, error_msg + if not (len(kernels) == len(strides) and len(kernels) >= 3): + raise AssertionError(error_msg) for idx in range(len(kernels)): kernel, stride = kernels[idx], strides[idx] if not isinstance(kernel, int): error_msg = "length of kernel_size in block {} should be the same as spatial_dims.".format(idx) - assert len(kernel) == self.spatial_dims, error_msg + if len(kernel) != self.spatial_dims: + raise AssertionError(error_msg) if not isinstance(stride, int): error_msg = "length of stride in block {} should be the same as spatial_dims.".format(idx) - assert len(stride) == self.spatial_dims, error_msg + if len(stride) != self.spatial_dims: + raise AssertionError(error_msg) def check_deep_supr_num(self): deep_supr_num, strides = self.deep_supr_num, self.strides num_up_layers = len(strides) - 1 - error_msg = "deep_supr_num should be less than the number of up sample layers." - assert 1 <= deep_supr_num < num_up_layers, error_msg + if deep_supr_num < 1 or deep_supr_num >= num_up_layers: + raise AssertionError("deep_supr_num should be less than the number of up sample layers.") def forward(self, x): out = self.skip_layers(x) - out = self.output_block(out) + return self.output_block(out) - if self.training and self.deep_supervision: - return [out] + self.heads[1 : self.deep_supr_num + 1] + def get_feature_maps(self): + """ + Return the feature maps. - return [out] + """ + return self.heads[1 : self.deep_supr_num + 1] def get_input_block(self): return self.conv_block( diff --git a/monai/networks/nets/fullyconnectednet.py b/monai/networks/nets/fullyconnectednet.py index c792ec4c22..b906bab015 100644 --- a/monai/networks/nets/fullyconnectednet.py +++ b/monai/networks/nets/fullyconnectednet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,14 +17,15 @@ from monai.networks.blocks import ADN from monai.networks.layers.factories import Act +__all__ = ["FullyConnectedNet", "VarFullyConnectedNet"] + def _get_adn_layer( act: Optional[Union[Tuple, str]], dropout: Optional[Union[Tuple, str, float]], ordering: Optional[str] ) -> ADN: if ordering: return ADN(act=act, dropout=dropout, dropout_dim=1, ordering=ordering) - else: - return ADN(act=act, dropout=dropout, dropout_dim=1) + return ADN(act=act, dropout=dropout, dropout_dim=1) class FullyConnectedNet(nn.Sequential): diff --git a/monai/networks/nets/generator.py b/monai/networks/nets/generator.py index a8b0a1390d..1f24944a63 100644 --- a/monai/networks/nets/generator.py +++ b/monai/networks/nets/generator.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/nets/highresnet.py b/monai/networks/nets/highresnet.py index 918b5b5349..5d9c3d1df6 100644 --- a/monai/networks/nets/highresnet.py +++ b/monai/networks/nets/highresnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
 # You may obtain a copy of the License at
diff --git a/monai/networks/nets/localnet.py b/monai/networks/nets/localnet.py
new file mode 100644
index 0000000000..ea8abca185
--- /dev/null
+++ b/monai/networks/nets/localnet.py
@@ -0,0 +1,129 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from monai.networks.blocks.localnet_block import (
+    LocalNetDownSampleBlock,
+    LocalNetFeatureExtractorBlock,
+    LocalNetUpSampleBlock,
+    get_conv_block,
+)
+
+
+class LocalNet(nn.Module):
+    """
+    Reimplementation of LocalNet, based on:
+    `Weakly-supervised convolutional neural networks for multimodal image registration
+    <https://doi.org/10.1016/j.media.2018.07.002>`_.
+    `Label-driven weakly-supervised learning for multimodal deformable image registration
+    <https://arxiv.org/abs/1711.01666>`_.
+
+    Adapted from:
+        DeepReg (https://github.com/DeepRegNet/DeepReg)
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        num_channel_initial: int,
+        extract_levels: List[int],
+        out_activation: Optional[Union[Tuple, str]],
+        out_initializer: str = "kaiming_uniform",
+    ) -> None:
+        """
+        Args:
+            spatial_dims: number of spatial dimensions.
+            in_channels: number of input channels.
+            out_channels: number of output channels.
+            num_channel_initial: number of initial channels.
+            extract_levels: number of extraction levels.
+            out_activation: activation to use at end layer.
+            out_initializer: initializer for extraction layers.
+        """
+        super(LocalNet, self).__init__()
+        self.extract_levels = extract_levels
+        self.extract_max_level = max(self.extract_levels)  # E
+        self.extract_min_level = min(self.extract_levels)  # D
+
+        num_channels = [
+            num_channel_initial * (2 ** level) for level in range(self.extract_max_level + 1)
+        ]  # level 0 to E
+
+        self.downsample_blocks = nn.ModuleList(
+            [
+                LocalNetDownSampleBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=in_channels if i == 0 else num_channels[i - 1],
+                    out_channels=num_channels[i],
+                    kernel_size=7 if i == 0 else 3,
+                )
+                for i in range(self.extract_max_level)
+            ]
+        )  # level 0 to self.extract_max_level - 1
+        self.conv3d_block = get_conv_block(
+            spatial_dims=spatial_dims, in_channels=num_channels[-2], out_channels=num_channels[-1]
+        )  # self.extract_max_level
+
+        self.upsample_blocks = nn.ModuleList(
+            [
+                LocalNetUpSampleBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=num_channels[level + 1],
+                    out_channels=num_channels[level],
+                )
+                for level in range(self.extract_max_level - 1, self.extract_min_level - 1, -1)
+            ]
+        )  # self.extract_max_level - 1 to self.extract_min_level
+
+        self.extract_layers = nn.ModuleList(
+            [
+                # if kernels are not initialized by zeros, with init NN, extract may be too large
+                LocalNetFeatureExtractorBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=num_channels[level],
+                    out_channels=out_channels,
+                    act=out_activation,
+                    initializer=out_initializer,
+                )
+                for level in self.extract_levels
+            ]
+        )
+
+    def forward(self, x) -> torch.Tensor:
+        image_size = x.shape[2:]
+        for size in image_size:
+            if size % (2 ** self.extract_max_level) != 0:
+                raise ValueError(
+                    f"given extract_max_level {self.extract_max_level}, "
+                    f"all input spatial dimensions must be divisible by {2 ** self.extract_max_level}, "
+                    f"got input of size {image_size}"
+                )
+        mid_features = []  # 0 -> self.extract_max_level - 1
+        for downsample_block in self.downsample_blocks:
+            x, mid = downsample_block(x)
+            mid_features.append(mid)
+        x = self.conv3d_block(x)  # self.extract_max_level
+
+        decoded_features = [x]
+        for idx, upsample_block in
enumerate(self.upsample_blocks): + x = upsample_block(x, mid_features[-idx - 1]) + decoded_features.append(x) # self.extract_max_level -> self.extract_min_level + + output = torch.mean( + torch.stack( + [ + F.interpolate( + extract_layer(decoded_features[self.extract_max_level - self.extract_levels[idx]]), + size=image_size, + ) + for idx, extract_layer in enumerate(self.extract_layers) + ], + dim=-1, + ), + dim=-1, + ) + return output diff --git a/monai/networks/nets/regressor.py b/monai/networks/nets/regressor.py index e049a56923..a1abadb6ba 100644 --- a/monai/networks/nets/regressor.py +++ b/monai/networks/nets/regressor.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/nets/segresnet.py b/monai/networks/nets/segresnet.py index d7869563cd..c7a085b569 100644 --- a/monai/networks/nets/segresnet.py +++ b/monai/networks/nets/segresnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,9 +13,10 @@ import numpy as np import torch +import torch.nn as nn import torch.nn.functional as F -from monai.networks.blocks.segresnet_block import * +from monai.networks.blocks.segresnet_block import ResBlock, get_conv_layer, get_norm_layer, get_upsample_layer from monai.networks.layers.factories import Act, Dropout from monai.utils import UpsampleMode @@ -65,7 +66,8 @@ def __init__( ): super().__init__() - assert spatial_dims == 2 or spatial_dims == 3, "spatial_dims can only be 2 or 3." + if spatial_dims not in (2, 3): + raise AssertionError("spatial_dims can only be 2 or 3.") self.spatial_dims = spatial_dims self.init_filters = init_filters diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py index bdf926dfc7..655ff203c7 100644 --- a/monai/networks/nets/senet.py +++ b/monai/networks/nets/senet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index 55901430d2..f3742d05b5 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -19,6 +19,8 @@ from monai.networks.layers.simplelayers import SkipConnection from monai.utils import alias, export +__all__ = ["UNet", "Unet", "unet"] + @export("monai.networks.nets") @alias("Unet") @@ -124,17 +126,16 @@ def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_ norm=self.norm, dropout=self.dropout, ) - else: - return Convolution( - self.dimensions, - in_channels, - out_channels, - strides=strides, - kernel_size=self.kernel_size, - act=self.act, - norm=self.norm, - dropout=self.dropout, - ) + return Convolution( + self.dimensions, + in_channels, + out_channels, + strides=strides, + kernel_size=self.kernel_size, + act=self.act, + norm=self.norm, + dropout=self.dropout, + ) def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module: """ diff --git a/monai/networks/nets/varautoencoder.py b/monai/networks/nets/varautoencoder.py index f586f31995..b68350e8b1 100644 --- a/monai/networks/nets/varautoencoder.py +++ b/monai/networks/nets/varautoencoder.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/nets/vnet.py b/monai/networks/nets/vnet.py index f13b0ede7f..63acb5cafb 100644 --- a/monai/networks/nets/vnet.py +++ b/monai/networks/nets/vnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -212,7 +212,8 @@ def __init__( ): super().__init__() - assert spatial_dims == 2 or spatial_dims == 3, "spatial_dims can only be 2 or 3." + if spatial_dims not in (2, 3): + raise AssertionError("spatial_dims can only be 2 or 3.") self.in_tr = InputTransition(spatial_dims, in_channels, 16, act) self.down_tr32 = DownTransition(spatial_dims, 16, 1, act) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 1bcccd084c..175d3d8b73 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,6 +13,7 @@ """ import warnings +from contextlib import contextmanager from typing import Any, Callable, Optional, Sequence, cast import torch @@ -29,6 +30,8 @@ "normal_init", "icnr_init", "pixelshuffle", + "eval_mode", + "train_mode", ] @@ -42,7 +45,8 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f For every value v = labels[b,1,h,w], the value in the result at [b,v,h,w] will be 1 and all others 0. Note that this will include the background label, thus a binary mask should be treated as having 2 classes. """ - assert labels.dim() > 0, "labels should have dim of 1 or more." + if labels.dim() <= 0: + raise AssertionError("labels should have dim of 1 or more.") # if `dim` is bigger, add singleton dim at the end if labels.ndim < dim + 1: @@ -51,7 +55,8 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f sh = list(labels.shape) - assert sh[dim] == 1, "labels should have a channel with length equals to one." 
+    if sh[dim] != 1:
+        raise AssertionError("labels should have a channel with length equal to one.")
     sh[dim] = num_classes
 
     o = torch.zeros(size=sh, dtype=dtype, device=labels.device)
@@ -241,3 +246,70 @@ def pixelshuffle(x: torch.Tensor, dimensions: int, scale_factor: int) -> torch.T
     x = x.reshape(batch_size, org_channels, *([factor] * dim + input_size[2:]))
     x = x.permute(permute_indices).reshape(output_size)
     return x
+
+
+@contextmanager
+def eval_mode(*nets: nn.Module):
+    """
+    Set network(s) to eval mode and then return to original state at the end.
+
+    Args:
+        nets: Input network(s)
+
+    Examples
+
+    .. code-block:: python
+
+        t = torch.rand(1, 1, 16, 16)
+        p = torch.nn.Conv2d(1, 1, 3)
+        print(p.training)  # True
+        with eval_mode(p):
+            print(p.training)  # False
+            print(p(t).sum().backward())  # will correctly raise an exception as gradient calculation is disabled
+    """
+
+    # Get original state of network(s)
+    training = [n for n in nets if n.training]
+
+    try:
+        # set to eval mode
+        with torch.no_grad():
+            yield [n.eval() for n in nets]
+    finally:
+        # Return required networks to training
+        for n in training:
+            n.train()
+
+
+@contextmanager
+def train_mode(*nets: nn.Module):
+    """
+    Set network(s) to train mode and then return to original state at the end.
+
+    Args:
+        nets: Input network(s)
+
+    Examples
+
+    .. code-block:: python
+
+        t = torch.rand(1, 1, 16, 16)
+        p = torch.nn.Conv2d(1, 1, 3)
+        p.eval()
+        print(p.training)  # False
+        with train_mode(p):
+            print(p.training)  # True
+            print(p(t).sum().backward())  # No exception
+    """
+
+    # Get original state of network(s)
+    eval_list = [n for n in nets if not n.training]
+
+    try:
+        # set to train mode
+        with torch.set_grad_enabled(True):
+            yield [n.train() for n in nets]
+    finally:
+        # Return networks that were originally in eval mode back to eval mode
+        for n in eval_list:
+            n.eval()
diff --git a/monai/optimizers/__init__.py b/monai/optimizers/__init__.py
index 120c0ff25f..e53aa8d468 100644
--- a/monai/optimizers/__init__.py
+++ b/monai/optimizers/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -9,5 +9,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from .lr_finder import LearningRateFinder
 from .novograd import Novograd
-from .utils import *
+from .utils import generate_param_groups
diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py
new file mode 100644
index 0000000000..6ad4132dd0
--- /dev/null
+++ b/monai/optimizers/lr_finder.py
@@ -0,0 +1,531 @@
+import warnings
+from functools import partial
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+
+from monai.networks.utils import eval_mode
+from monai.optimizers.lr_scheduler import ExponentialLR, LinearLR
+from monai.utils import StateCacher, copy_to_device, optional_import
+
+if TYPE_CHECKING:
+    import matplotlib.pyplot as plt
+
+    has_matplotlib = True
+    import tqdm
+
+    has_tqdm = True
+else:
+    plt, has_matplotlib = optional_import("matplotlib.pyplot")
+    tqdm, has_tqdm = optional_import("tqdm")
+
+__all__ = ["LearningRateFinder"]
+
+
+class DataLoaderIter:
+    def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None:
+        if not isinstance(data_loader, DataLoader):
+            raise ValueError(
+                f"Loader has unsupported type: {type(data_loader)}. Expected type was `torch.utils.data.DataLoader`"
+            )
+        self.data_loader = data_loader
+        self._iterator = iter(data_loader)
+        self.image_extractor = image_extractor
+        self.label_extractor = label_extractor
+
+    @property
+    def dataset(self):
+        return self.data_loader.dataset
+
+    def inputs_labels_from_batch(self, batch_data):
+        images = self.image_extractor(batch_data)
+        labels = self.label_extractor(batch_data)
+        return images, labels
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        batch = next(self._iterator)
+        return self.inputs_labels_from_batch(batch)
+
+
+class TrainDataLoaderIter(DataLoaderIter):
+    def __init__(
+        self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable, auto_reset: bool = True
+    ) -> None:
+        super().__init__(data_loader, image_extractor, label_extractor)
+        self.auto_reset = auto_reset
+
+    def __next__(self):
+        try:
+            batch = next(self._iterator)
+            inputs, labels = self.inputs_labels_from_batch(batch)
+        except StopIteration:
+            if not self.auto_reset:
+                raise
+            self._iterator = iter(self.data_loader)
+            batch = next(self._iterator)
+            inputs, labels = self.inputs_labels_from_batch(batch)
+
+        return inputs, labels
+
+
+class ValDataLoaderIter(DataLoaderIter):
+    """This iterator resets itself **only** when it is acquired by
+    the syntax of a normal `iterator`. That is, this iterator just works
+    like a `torch.data.DataLoader`. If you want to restart it, you
+    should use it like:
+
+    ```
+    loader_iter = ValDataLoaderIter(data_loader)
+    for batch in loader_iter:
+        ...
+
+    # `loader_iter` should run out of values now, you can restart it by:
+    # 1. the way we use a `torch.data.DataLoader`
+    for batch in loader_iter:  # __iter__ is called implicitly
+        ...
+
+    # 2. passing it into `iter()` manually
+    loader_iter = iter(loader_iter)  # __iter__ is called by `iter()`
+    ```
+    """
+
+    def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None:
+        super().__init__(data_loader, image_extractor, label_extractor)
+        self.run_limit = len(self.data_loader)
+        self.run_counter = 0
+
+    def __iter__(self):
+        if self.run_counter >= self.run_limit:
+            self._iterator = iter(self.data_loader)
+            self.run_counter = 0
+        return self
+
+    def __next__(self):
+        self.run_counter += 1
+        return super(ValDataLoaderIter, self).__next__()
+
+
+def default_image_extractor(x: Any) -> torch.Tensor:
+    """Default callable for getting image from batch data."""
+    out: torch.Tensor = x["image"] if isinstance(x, dict) else x[0]
+    return out
+
+
+def default_label_extractor(x: Any) -> torch.Tensor:
+    """Default callable for getting label from batch data."""
+    out: torch.Tensor = x["label"] if isinstance(x, dict) else x[1]
+    return out
+
+
+class LearningRateFinder:
+    """Learning rate range test.
+
+    The learning rate range test increases the learning rate in a pre-training run
+    between two boundaries in a linear or exponential manner. It provides valuable
+    information on how well the network can be trained over a range of learning rates
+    and what the optimal learning rate is.
+
+    Example (fastai approach):
+    >>> lr_finder = LearningRateFinder(net, optimizer, criterion)
+    >>> lr_finder.range_test(data_loader, end_lr=100, num_iter=100)
+    >>> lr_finder.get_steepest_gradient()
+    >>> lr_finder.plot()  # to inspect the loss-learning rate graph
+
+    Example (Leslie Smith's approach):
+    >>> lr_finder = LearningRateFinder(net, optimizer, criterion)
+    >>> lr_finder.range_test(train_loader, val_loader=val_loader, end_lr=1, num_iter=100, step_mode="linear")
+
+    Gradient accumulation is supported; example:
+    >>> train_data = ...  # prepared dataset
+    >>> desired_bs, real_bs = 32, 4  # batch size
+    >>> accumulation_steps = desired_bs // real_bs  # required steps for accumulation
+    >>> data_loader = torch.utils.data.DataLoader(train_data, batch_size=real_bs, shuffle=True)
+    >>> acc_lr_finder = LearningRateFinder(net, optimizer, criterion)
+    >>> acc_lr_finder.range_test(data_loader, end_lr=10, num_iter=100, accumulation_steps=accumulation_steps)
+
+    By default, the image will be extracted from the data loader with x["image"] and x[0], depending on whether
+    batch data is a dictionary or not (and similar behaviour for extracting the label). If your data loader
+    returns something other than this, pass a callable function to extract it, e.g.:
+    >>> image_extractor = lambda x: x["input"]
+    >>> label_extractor = lambda x: x[100]
+    >>> lr_finder = LearningRateFinder(net, optimizer, criterion)
+    >>> lr_finder.range_test(train_loader, val_loader, image_extractor, label_extractor)
+
+    References:
+    Modified from: https://github.com/davidtvs/pytorch-lr-finder.
+    Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        optimizer: Optimizer,
+        criterion: torch.nn.Module,
+        device: Optional[Union[str, torch.device]] = None,
+        memory_cache: bool = True,
+        cache_dir: Optional[str] = None,
+        amp: bool = False,
+        verbose: bool = True,
+    ) -> None:
+        """Constructor.
+
+        Args:
+            model: wrapped model.
+            optimizer: wrapped optimizer.
+            criterion: wrapped loss function.
+            device: device on which to test. Can be a string ("cpu" or "cuda") with an
+                optional ordinal for the device type (e.g. "cuda:X", where X is the ordinal).
"cuda:X", where is the ordinal). + Alternatively, can be an object representing the device on which the + computation will take place. Default: None, uses the same device as `model`. + memory_cache: if this flag is set to True, `state_dict` of + model and optimizer will be cached in memory. Otherwise, they will be saved + to files under the `cache_dir`. + cache_dir: path for storing temporary files. If no path is + specified, system-wide temporary directory is used. Notice that this + parameter will be ignored if `memory_cache` is True. + amp: use Automatic Mixed Precision + verbose: verbose output + Returns: + None + """ + # Check if the optimizer is already attached to a scheduler + self.optimizer = optimizer + self._check_for_scheduler() + + self.model = model + self.criterion = criterion + self.history: Dict[str, list] = {"lr": [], "loss": []} + self.memory_cache = memory_cache + self.cache_dir = cache_dir + self.amp = amp + self.verbose = verbose + + # Save the original state of the model and optimizer so they can be restored if + # needed + self.model_device = next(self.model.parameters()).device + self.state_cacher = StateCacher(memory_cache, cache_dir=cache_dir) + self.state_cacher.store("model", self.model.state_dict()) + self.state_cacher.store("optimizer", self.optimizer.state_dict()) + + # If device is None, use the same as the model + self.device = device if device else self.model_device + + def reset(self) -> None: + """Restores the model and optimizer to their initial states.""" + + self.model.load_state_dict(self.state_cacher.retrieve("model")) + self.optimizer.load_state_dict(self.state_cacher.retrieve("optimizer")) + self.model.to(self.model_device) + + def range_test( + self, + train_loader: DataLoader, + val_loader: Optional[DataLoader] = None, + image_extractor: Callable = default_image_extractor, + label_extractor: Callable = default_label_extractor, + start_lr: Optional[float] = None, + end_lr: int = 10, + num_iter: int = 100, + step_mode: str = "exp", + smooth_f: float = 0.05, + diverge_th: int = 5, + accumulation_steps: int = 1, + non_blocking_transfer: bool = True, + auto_reset: bool = True, + ) -> None: + """Performs the learning rate range test. + + Args: + train_loader: training set data loader. + val_loader: validation data loader (if desired). + image_extractor: callable function to get the image from a batch of data. + Default: `x["image"] if isinstance(x, dict) else x[0]`. + label_extractor: callable function to get the label from a batch of data. + Default: `x["label"] if isinstance(x, dict) else x[1]`. + start_lr : the starting learning rate for the range test. + The default is the optimizer's learning rate. + end_lr: the maximum learning rate to test. The test may stop earlier than + this if the result starts diverging. + num_iter: the max number of iterations for test. + step_mode: schedule for increasing learning rate: (`linear` or `exp`). + smooth_f: the loss smoothing factor within the `[0, 1[` interval. Disabled + if set to `0`, otherwise loss is smoothed using exponential smoothing. + diverge_th: test is stopped when loss surpasses threshold: + `diverge_th * best_loss`. + accumulation_steps: steps for gradient accumulation. If set to `1`, + gradients are not accumulated. + non_blocking_transfer: when `True`, moves data to device asynchronously if + possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. + auto_reset: if `True`, returns model and optimizer to original states at end + of test. 
+        Returns:
+            None
+        """
+
+        # Reset test results
+        self.history = {"lr": [], "loss": []}
+        best_loss = -float("inf")
+
+        # Move the model to the proper device
+        self.model.to(self.device)
+
+        # Check if the optimizer is already attached to a scheduler
+        self._check_for_scheduler()
+
+        # Set the starting learning rate
+        if start_lr:
+            self._set_learning_rate(start_lr)
+
+        # Check number of iterations
+        if num_iter <= 1:
+            raise ValueError("`num_iter` must be larger than 1")
+
+        # Initialize the proper learning rate policy
+        lr_schedule: Union[ExponentialLR, LinearLR]
+        if step_mode.lower() == "exp":
+            lr_schedule = ExponentialLR(self.optimizer, end_lr, num_iter)
+        elif step_mode.lower() == "linear":
+            lr_schedule = LinearLR(self.optimizer, end_lr, num_iter)
+        else:
+            raise ValueError(f"expected one of (exp, linear), got {step_mode}")
+
+        if smooth_f < 0 or smooth_f >= 1:
+            raise ValueError("smooth_f is outside the range [0, 1[")
+
+        # Create an iterator to get data batch by batch
+        train_iter = TrainDataLoaderIter(train_loader, image_extractor, label_extractor)
+        if val_loader:
+            val_iter = ValDataLoaderIter(val_loader, image_extractor, label_extractor)
+
+        trange: Union[partial[tqdm.trange], Type[range]]
+        if self.verbose and has_tqdm:
+            trange = partial(tqdm.trange, desc="Computing optimal learning rate")
+            tprint = tqdm.tqdm.write
+        else:
+            trange = range
+            tprint = print
+
+        for iteration in trange(num_iter):
+            if self.verbose and not has_tqdm:
+                print(f"Computing optimal learning rate, iteration {iteration + 1}/{num_iter}")
+
+            # Train on batch and retrieve loss
+            loss = self._train_batch(
+                train_iter,
+                accumulation_steps,
+                non_blocking_transfer=non_blocking_transfer,
+            )
+            if val_loader:
+                loss = self._validate(val_iter, non_blocking_transfer=non_blocking_transfer)
+
+            # Update the learning rate
+            self.history["lr"].append(lr_schedule.get_lr()[0])
+            lr_schedule.step()
+
+            # Track the best loss and smooth it if smooth_f is specified
+            if iteration == 0:
+                best_loss = loss
+            else:
+                if smooth_f > 0:
+                    loss = smooth_f * loss + (1 - smooth_f) * self.history["loss"][-1]
+                if loss < best_loss:
+                    best_loss = loss
+
+            # Check if the loss has diverged; if it has, stop the test
+            self.history["loss"].append(loss)
+            if loss > diverge_th * best_loss:
+                if self.verbose:
+                    tprint("Stopping early, the loss has diverged")
+                break
+
+        if auto_reset:
+            if self.verbose:
+                print("Resetting model and optimizer")
+            self.reset()
+
+    def _set_learning_rate(self, new_lrs: Union[float, list]) -> None:
+        """Set learning rate(s) for optimizer."""
+        if not isinstance(new_lrs, list):
+            new_lrs = [new_lrs] * len(self.optimizer.param_groups)
+        if len(new_lrs) != len(self.optimizer.param_groups):
+            raise ValueError(
+                "Length of `new_lrs` is not equal to the number of parameter groups " "in the given optimizer"
+            )
+
+        for param_group, new_lr in zip(self.optimizer.param_groups, new_lrs):
+            param_group["lr"] = new_lr
+
+    def _check_for_scheduler(self) -> None:
+        """Check optimizer doesn't already have scheduler."""
+        for param_group in self.optimizer.param_groups:
+            if "initial_lr" in param_group:
+                raise RuntimeError("Optimizer already has a scheduler attached to it")
+
+    def _train_batch(self, train_iter, accumulation_steps: int, non_blocking_transfer: bool = True) -> float:
+        self.model.train()
+        total_loss = 0
+
+        self.optimizer.zero_grad()
+        for i in range(accumulation_steps):
+            inputs, labels = next(train_iter)
+            inputs, labels = copy_to_device([inputs, labels], device=self.device, non_blocking=non_blocking_transfer)
+
+            # Forward pass
+            outputs = self.model(inputs)
+            loss = self.criterion(outputs, labels)
+
+            # Loss should be averaged in each step
+            loss /= accumulation_steps
+
+            # Backward pass
+            if self.amp and hasattr(self.optimizer, "_amp_stash"):
+                # `_amp_stash` implies the optimizer was initialized by NVIDIA apex,
+                # whose `amp.scale_loss` supports delayed unscaling for minor
+                # performance optimization, see also:
+                # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
+                delay_unscale = ((i + 1) % accumulation_steps) != 0
+
+                from apex import amp  # type: ignore
+
+                with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                loss.backward()
+
+            total_loss += loss.item()
+
+        self.optimizer.step()
+
+        return total_loss
+
+    def _validate(self, val_iter: ValDataLoaderIter, non_blocking_transfer: bool = True) -> float:
+        # Set model to evaluation mode and disable gradient computation
+        running_loss = 0
+        with eval_mode(self.model):
+            for inputs, labels in val_iter:
+                # Copy data to the correct device
+                inputs, labels = copy_to_device(
+                    [inputs, labels], device=self.device, non_blocking=non_blocking_transfer
+                )
+
+                # Forward pass and loss computation
+                outputs = self.model(inputs)
+                loss = self.criterion(outputs, labels)
+                running_loss += loss.item() * len(labels)
+
+        return running_loss / len(val_iter.dataset)
+
+    def get_lrs_and_losses(
+        self,
+        skip_start: int = 0,
+        skip_end: int = 0,
+    ) -> Tuple[list, list]:
+        """Get learning rates and their corresponding losses
+
+        Args:
+            skip_start: number of batches to trim from the start.
+            skip_end: number of batches to trim from the end.
+        """
+        if skip_start < 0:
+            raise ValueError("skip_start cannot be negative")
+        if skip_end < 0:
+            raise ValueError("skip_end cannot be negative")
+
+        lrs = self.history["lr"]
+        losses = self.history["loss"]
+        end_idx = len(lrs) - skip_end - 1
+        lrs = lrs[skip_start:end_idx]
+        losses = losses[skip_start:end_idx]
+
+        return lrs, losses
+
+    def get_steepest_gradient(
+        self,
+        skip_start: int = 0,
+        skip_end: int = 0,
+    ) -> Union[Tuple[float, float], Tuple[None, None]]:
+        """Get learning rate which has steepest gradient and its corresponding loss
+
+        Args:
+            skip_start: number of batches to trim from the start.
+            skip_end: number of batches to trim from the end.
+
+        Returns:
+            Learning rate which has steepest gradient and its corresponding loss
+        """
+        lrs, losses = self.get_lrs_and_losses(skip_start, skip_end)
+
+        try:
+            min_grad_idx = np.gradient(np.array(losses)).argmin()
+            return lrs[min_grad_idx], losses[min_grad_idx]
+        except ValueError:
+            print("Failed to compute the gradients, there might not be enough points.")
+            return None, None
+
+    def plot(
+        self,
+        skip_start: int = 0,
+        skip_end: int = 0,
+        log_lr: bool = True,
+        ax=None,
+        steepest_lr: bool = True,
+    ):
+        """Plots the learning rate range test.
+
+        Args:
+            skip_start: number of batches to trim from the start.
+            skip_end: number of batches to trim from the end.
+            log_lr: True to plot the learning rate in a logarithmic
+                scale; otherwise, plotted in a linear scale.
+            ax: the plot is created in the specified matplotlib axes object and the
+                figure is not shown. If `None`, then the figure and axes object are
+                created in this method and the figure is shown.
+            steepest_lr: plot the learning rate which had the steepest gradient.
+
+        Returns:
+            The `matplotlib.axes.Axes` object that contains the plot. Returns `None` if
+            `matplotlib` is not installed.
+ """ + if not has_matplotlib: + warnings.warn("Matplotlib is missing, can't plot result") + return None + + lrs, losses = self.get_lrs_and_losses(skip_start, skip_end) + + # Create the figure and axes object if axes was not already given + fig = None + if ax is None: + fig, ax = plt.subplots() + + # Plot loss as a function of the learning rate + ax.plot(lrs, losses) + + # Plot the LR with steepest gradient + if steepest_lr: + lr_at_steepest_grad, loss_at_steepest_grad = self.get_steepest_gradient(skip_start, skip_end) + if lr_at_steepest_grad is not None: + ax.scatter( + lr_at_steepest_grad, + loss_at_steepest_grad, + s=75, + marker="o", + color="red", + zorder=3, + label="steepest gradient", + ) + ax.legend() + + if log_lr: + ax.set_xscale("log") + ax.set_xlabel("Learning rate") + ax.set_ylabel("Loss") + + # Show only if the figure was created internally + if fig is not None: + plt.show() + + return ax diff --git a/monai/optimizers/lr_scheduler.py b/monai/optimizers/lr_scheduler.py new file mode 100644 index 0000000000..aa9bf2a89b --- /dev/null +++ b/monai/optimizers/lr_scheduler.py @@ -0,0 +1,43 @@ +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + +__all__ = ["LinearLR", "ExponentialLR"] + + +class _LRSchedulerMONAI(_LRScheduler): + """Base class for increasing the learning rate between two boundaries over a number + of iterations""" + + def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None: + """ + Args: + optimizer: wrapped optimizer. + end_lr: the final learning rate. + num_iter: the number of iterations over which the test occurs. + last_epoch: the index of last epoch. + Returns: + None + """ + self.end_lr = end_lr + self.num_iter = num_iter + super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch) + + +class LinearLR(_LRSchedulerMONAI): + """Linearly increases the learning rate between two boundaries over a number of + iterations. + """ + + def get_lr(self): + r = self.last_epoch / (self.num_iter - 1) + return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] + + +class ExponentialLR(_LRSchedulerMONAI): + """Exponentially increases the learning rate between two boundaries over a number of + iterations. + """ + + def get_lr(self): + r = self.last_epoch / (self.num_iter - 1) + return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] diff --git a/monai/optimizers/novograd.py b/monai/optimizers/novograd.py index dd5004afcc..62e42cc9ab 100644 --- a/monai/optimizers/novograd.py +++ b/monai/optimizers/novograd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -44,15 +44,15 @@ def __init__( grad_averaging: bool = False, amsgrad: bool = False, ): - if not 0.0 <= lr: + if 0.0 > lr: raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: + if 0.0 > eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - if not 0.0 <= weight_decay: + if 0.0 > weight_decay: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) defaults = dict( lr=lr, diff --git a/monai/optimizers/utils.py b/monai/optimizers/utils.py index 4cafa45749..9c4bfcf6ee 100644 --- a/monai/optimizers/utils.py +++ b/monai/optimizers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -75,8 +75,8 @@ def _filter(): return _filter - params = list() - _layers = list() + params = [] + _layers = [] for func, ty, lr in zip(layer_matches, match_types, lr_values): if ty.lower() == "select": layer_params = _get_select(func) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 77114828f6..9eaedd6b15 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,18 +9,337 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .adaptors import * -from .compose import * -from .croppad.array import * -from .croppad.dictionary import * -from .intensity.array import * -from .intensity.dictionary import * -from .io.array import * -from .io.dictionary import * -from .post.array import * -from .post.dictionary import * -from .spatial.array import * -from .spatial.dictionary import * -from .utility.array import * -from .utility.dictionary import * -from .utils import * +from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs +from .compose import Compose, MapTransform, Randomizable, Transform +from .croppad.array import ( + BorderPad, + BoundingRect, + CenterSpatialCrop, + CropForeground, + DivisiblePad, + RandCropByPosNegLabel, + RandSpatialCrop, + RandSpatialCropSamples, + RandWeightedCrop, + ResizeWithPadOrCrop, + SpatialCrop, + SpatialPad, +) +from .croppad.dictionary import ( + BorderPadd, + BorderPadD, + BorderPadDict, + BoundingRectd, + BoundingRectD, + BoundingRectDict, + CenterSpatialCropd, + CenterSpatialCropD, + CenterSpatialCropDict, + CropForegroundd, + CropForegroundD, + CropForegroundDict, + DivisiblePadd, + DivisiblePadD, + DivisiblePadDict, + NumpyPadModeSequence, + RandCropByPosNegLabeld, + RandCropByPosNegLabelD, + RandCropByPosNegLabelDict, + RandSpatialCropd, + RandSpatialCropD, + RandSpatialCropDict, + RandSpatialCropSamplesd, + RandSpatialCropSamplesD, + RandSpatialCropSamplesDict, + RandWeightedCropd, + RandWeightedCropD, + RandWeightedCropDict, + ResizeWithPadOrCropd, + ResizeWithPadOrCropD, + ResizeWithPadOrCropDict, + SpatialCropd, + SpatialCropD, + SpatialCropDict, + SpatialPadd, + SpatialPadD, + SpatialPadDict, +) +from .intensity.array import ( + AdjustContrast, + DetectEnvelope, + GaussianSharpen, + GaussianSmooth, + MaskIntensity, + NormalizeIntensity, + RandAdjustContrast, + RandGaussianNoise, + RandGaussianSharpen, + RandGaussianSmooth, + RandHistogramShift, + RandScaleIntensity, + RandShiftIntensity, + SavitzkyGolaySmooth, + ScaleIntensity, + ScaleIntensityRange, + ScaleIntensityRangePercentiles, + ShiftIntensity, + ThresholdIntensity, +) +from .intensity.dictionary import ( + AdjustContrastd, + AdjustContrastD, + AdjustContrastDict, + GaussianSharpend, + GaussianSharpenD, + GaussianSharpenDict, + GaussianSmoothd, + GaussianSmoothD, + GaussianSmoothDict, + MaskIntensityd, + MaskIntensityD, + MaskIntensityDict, + NormalizeIntensityd, + NormalizeIntensityD, + NormalizeIntensityDict, + RandAdjustContrastd, + RandAdjustContrastD, + RandAdjustContrastDict, + RandGaussianNoised, + RandGaussianNoiseD, + RandGaussianNoiseDict, + RandGaussianSharpend, + RandGaussianSharpenD, + RandGaussianSharpenDict, + RandGaussianSmoothd, + RandGaussianSmoothD, + RandGaussianSmoothDict, + RandHistogramShiftd, + RandHistogramShiftD, + RandHistogramShiftDict, + RandScaleIntensityd, + RandScaleIntensityD, + RandScaleIntensityDict, + RandShiftIntensityd, + RandShiftIntensityD, + RandShiftIntensityDict, + ScaleIntensityd, + ScaleIntensityD, + ScaleIntensityDict, + ScaleIntensityRanged, + ScaleIntensityRangeD, + ScaleIntensityRangeDict, + ScaleIntensityRangePercentilesd, + ScaleIntensityRangePercentilesD, + ScaleIntensityRangePercentilesDict, + ShiftIntensityd, + ShiftIntensityD, + ShiftIntensityDict, + ThresholdIntensityd, + ThresholdIntensityD, + ThresholdIntensityDict, +) +from .io.array import LoadImage +from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict +from .post.array import ( + Activations, + AsDiscrete, + KeepLargestConnectedComponent, + LabelToContour, + MeanEnsemble, 
+ VoteEnsemble, +) +from .post.dictionary import ( + Activationsd, + ActivationsD, + ActivationsDict, + AsDiscreted, + AsDiscreteD, + AsDiscreteDict, + Ensembled, + KeepLargestConnectedComponentd, + KeepLargestConnectedComponentD, + KeepLargestConnectedComponentDict, + LabelToContourd, + LabelToContourD, + LabelToContourDict, + MeanEnsembled, + MeanEnsembleD, + MeanEnsembleDict, + VoteEnsembled, + VoteEnsembleD, + VoteEnsembleDict, +) +from .spatial.array import ( + Affine, + AffineGrid, + Flip, + Orientation, + Rand2DElastic, + Rand3DElastic, + RandAffine, + RandAffineGrid, + RandDeformGrid, + RandFlip, + RandRotate, + RandRotate90, + RandZoom, + Resample, + Resize, + Rotate, + Rotate90, + Spacing, + Zoom, +) +from .spatial.dictionary import ( + Flipd, + FlipD, + FlipDict, + Orientationd, + OrientationD, + OrientationDict, + Rand2DElasticd, + Rand2DElasticD, + Rand2DElasticDict, + Rand3DElasticd, + Rand3DElasticD, + Rand3DElasticDict, + RandAffined, + RandAffineD, + RandAffineDict, + RandFlipd, + RandFlipD, + RandFlipDict, + RandRotate90d, + RandRotate90D, + RandRotate90Dict, + RandRotated, + RandRotateD, + RandRotateDict, + RandZoomd, + RandZoomD, + RandZoomDict, + Resized, + ResizeD, + ResizeDict, + Rotate90d, + Rotate90D, + Rotate90Dict, + Rotated, + RotateD, + RotateDict, + Spacingd, + SpacingD, + SpacingDict, + Zoomd, + ZoomD, + ZoomDict, +) +from .utility.array import ( + AddChannel, + AddExtremePointsChannel, + AsChannelFirst, + AsChannelLast, + CastToType, + ConvertToMultiChannelBasedOnBratsClasses, + DataStats, + FgBgToIndices, + Identity, + LabelToMask, + Lambda, + RepeatChannel, + SimulateDelay, + SplitChannel, + SqueezeDim, + ToNumpy, + TorchVision, + ToTensor, + Transpose, +) +from .utility.dictionary import ( + AddChanneld, + AddChannelD, + AddChannelDict, + AddExtremePointsChanneld, + AddExtremePointsChannelD, + AddExtremePointsChannelDict, + AsChannelFirstd, + AsChannelFirstD, + AsChannelFirstDict, + AsChannelLastd, + AsChannelLastD, + AsChannelLastDict, + CastToTyped, + CastToTypeD, + CastToTypeDict, + ConcatItemsd, + ConcatItemsD, + ConcatItemsDict, + ConvertToMultiChannelBasedOnBratsClassesd, + ConvertToMultiChannelBasedOnBratsClassesD, + ConvertToMultiChannelBasedOnBratsClassesDict, + CopyItemsd, + CopyItemsD, + CopyItemsDict, + DataStatsd, + DataStatsD, + DataStatsDict, + DeleteItemsd, + DeleteItemsD, + DeleteItemsDict, + FgBgToIndicesd, + FgBgToIndicesD, + FgBgToIndicesDict, + Identityd, + IdentityD, + IdentityDict, + LabelToMaskd, + LabelToMaskD, + LabelToMaskDict, + Lambdad, + LambdaD, + LambdaDict, + RepeatChanneld, + RepeatChannelD, + RepeatChannelDict, + SelectItemsd, + SimulateDelayd, + SimulateDelayD, + SimulateDelayDict, + SplitChanneld, + SplitChannelD, + SplitChannelDict, + SqueezeDimd, + SqueezeDimD, + SqueezeDimDict, + ToNumpyd, + TorchVisiond, + ToTensord, + ToTensorD, + ToTensorDict, +) +from .utils import ( + apply_transform, + copypaste_arrays, + create_control_grid, + create_grid, + create_rotate, + create_scale, + create_shear, + create_translate, + extreme_points_to_image, + generate_pos_neg_label_crop_centers, + generate_spatial_bounding_box, + get_extreme_points, + get_largest_connected_component_mask, + img_bounds, + in_bounds, + is_empty, + map_binary_to_indices, + rand_choice, + rescale_array, + rescale_array_int_max, + rescale_instance_array, + resize_center, + weighted_patch_samples, + zero_margins, +) diff --git a/monai/transforms/adaptors.py b/monai/transforms/adaptors.py index 3aa7f70339..434d1f1c05 100644 --- 
a/monai/transforms/adaptors.py +++ b/monai/transforms/adaptors.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -126,6 +126,8 @@ def __call__(self, img, seg): from monai.utils import export as _monai_export +__all__ = ["adaptor", "apply_alias", "to_kwargs", "FunctionSignature"] + @_monai_export("monai.transforms") def adaptor(function, outputs, inputs=None): @@ -194,7 +196,7 @@ def _inner(ditems): if len(ret) != len(outputs): raise ValueError("'outputs' must have the same length as the number of elements that were returned") - ret = {k: v for k, v in zip(op, ret)} + ret = dict(zip(op, ret)) else: must_be_types("outputs", op, (str, list, tuple)) if isinstance(op, (list, tuple)): diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 20e72f1df0..3e23377b36 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -55,7 +55,7 @@ def __call__(self, data: Any): - ``data`` is a Numpy ndarray, PyTorch Tensor or string - the data shape can be: - #. string data without shape, `LoadNifti` and `LoadPNG` transforms expect file paths + #. string data without shape, `LoadImage` transform expects file paths #. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, except that `AddChannel` expects (spatial_dim_1[, spatial_dim_2, ...]) and `AsChannelFirst` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels) @@ -194,7 +194,7 @@ class Compose(Randomizable, Transform): set of functions must be called as if it were a sequence. Example: images and labels - Images typically require some kind of normalisation that labels do not. + Images typically require some kind of normalization that labels do not. Both are then typically augmented through the use of random rotations, flips, and deformations. Compose can be used with a series of transforms that take a dictionary @@ -282,7 +282,7 @@ def __call__(self, data): - ``data[key]`` is a Numpy ndarray, PyTorch Tensor or string, where ``key`` is an element of ``self.keys``, the data shape can be: - #. string data without shape, `LoadNiftid` and `LoadPNGd` transforms expect file paths + #. string data without shape, `LoadImaged` transform expects file paths #. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, except that `AddChanneld` expects (spatial_dim_1[, spatial_dim_2, ...]) and `AsChannelFirstd` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels) diff --git a/monai/transforms/croppad/__init__.py b/monai/transforms/croppad/__init__.py index d0044e3563..14ae193634 100644 --- a/monai/transforms/croppad/__init__.py +++ b/monai/transforms/croppad/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
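The module-level __all__ lists added throughout this change (starting with adaptors.py above) pin down exactly what `from monai.transforms import *` re-exports. A toy illustration of the mechanism, using a hypothetical module name:

    # toy_module.py (hypothetical, for illustration only)
    __all__ = ["public_fn"]  # a star-import binds exactly these names

    def public_fn():
        return "exported"

    def helper_fn():  # not in __all__: hidden from star-imports
        return "still importable explicitly"

    # client code:
    # from toy_module import *          -> binds public_fn only
    # from toy_module import helper_fn  -> works as usual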
# You may obtain a copy of the License at diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 4c69a61b15..e59eb89ac7 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -28,6 +28,21 @@ ) from monai.utils import Method, NumpyPadMode, ensure_tuple, fall_back_tuple +__all__ = [ + "SpatialPad", + "BorderPad", + "DivisiblePad", + "SpatialCrop", + "CenterSpatialCrop", + "RandSpatialCrop", + "RandSpatialCropSamples", + "CropForeground", + "RandWeightedCrop", + "RandCropByPosNegLabel", + "ResizeWithPadOrCrop", + "BoundingRect", +] + class SpatialPad(Transform): """ @@ -59,13 +74,12 @@ def __init__( def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int, int]]: self.spatial_size = fall_back_tuple(self.spatial_size, data_shape) if self.method == Method.SYMMETRIC: - pad_width = list() + pad_width = [] for i in range(len(self.spatial_size)): width = max(self.spatial_size[i] - data_shape[i], 0) pad_width.append((width // 2, width - (width // 2))) return pad_width - else: - return [(0, max(self.spatial_size[i] - data_shape[i], 0)) for i in range(len(self.spatial_size))] + return [(0, max(self.spatial_size[i] - data_shape[i], 0)) for i in range(len(self.spatial_size))] def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray: """ @@ -82,9 +96,8 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N if not np.asarray(all_pad_width).any(): # all zeros, skip padding return img - else: - img = np.pad(img, all_pad_width, mode=self.mode.value if mode is None else NumpyPadMode(mode).value) - return img + img = np.pad(img, all_pad_width, mode=self.mode.value if mode is None else NumpyPadMode(mode).value) + return img class BorderPad(Transform): @@ -299,12 +312,12 @@ def __call__(self, img: np.ndarray) -> np.ndarray: slicing doesn't apply to the channel dim. """ self.randomize(img.shape[1:]) - assert self._size is not None + if self._size is None: + raise AssertionError if self.random_center: return img[self._slices] - else: - cropper = CenterSpatialCrop(self._size) - return cropper(img) + cropper = CenterSpatialCrop(self._size) + return cropper(img) class RandSpatialCropSamples(Randomizable, Transform): @@ -586,7 +599,7 @@ def __call__( else: fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) self.randomize(label, fg_indices, bg_indices, image) - results: List[np.ndarray] = list() + results: List[np.ndarray] = [] if self.centers is not None: for center in self.centers: cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) @@ -665,7 +678,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: """ See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. 
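The assert-to-raise rewrites in this file (and in the files below) follow from a CPython detail: `assert` statements are stripped entirely when the interpreter runs with optimizations (`python -O`), so checks that guard runtime state must be explicit raises. A minimal sketch of the pattern, with a hypothetical helper name:

    # hypothetical sketch: validation that survives `python -O`
    def checked_size(size):
        # assert size is not None  # would vanish under -O
        if size is None:  # always executes
            raise AssertionError("size must be set by randomize() before cropping")
        return size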
""" - bbox = list() + bbox = [] for channel in range(img.shape[0]): start_, end_ = generate_spatial_bounding_box(img, select_fn=self.select_fn, channel_indices=channel) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 8e927eb605..8bf33dd632 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -39,6 +39,46 @@ ) from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple +__all__ = [ + "NumpyPadModeSequence", + "SpatialPadd", + "BorderPadd", + "DivisiblePadd", + "SpatialCropd", + "CenterSpatialCropd", + "RandSpatialCropd", + "RandSpatialCropSamplesd", + "CropForegroundd", + "RandWeightedCropd", + "RandCropByPosNegLabeld", + "ResizeWithPadOrCropd", + "BoundingRectd", + "SpatialPadD", + "SpatialPadDict", + "BorderPadD", + "BorderPadDict", + "DivisiblePadD", + "DivisiblePadDict", + "SpatialCropD", + "SpatialCropDict", + "CenterSpatialCropD", + "CenterSpatialCropDict", + "RandSpatialCropD", + "RandSpatialCropDict", + "RandSpatialCropSamplesD", + "RandSpatialCropSamplesDict", + "CropForegroundD", + "CropForegroundDict", + "RandWeightedCropD", + "RandWeightedCropDict", + "RandCropByPosNegLabelD", + "RandCropByPosNegLabelDict", + "ResizeWithPadOrCropD", + "ResizeWithPadOrCropDict", + "BoundingRectD", + "BoundingRectDict", +] + NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] @@ -261,7 +301,8 @@ def randomize(self, img_size: Sequence[int]) -> None: def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) self.randomize(d[self.keys[0]].shape[1:]) # image shape from the first data key - assert self._size is not None + if self._size is None: + raise AssertionError for key in self.keys: if self.random_center: d[key] = d[key][self._slices] @@ -422,7 +463,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n self.randomize(d[self.w_key]) _spatial_size = fall_back_tuple(self.spatial_size, d[self.w_key].shape[1:]) - results: List[Dict[Hashable, np.ndarray]] = [dict() for _ in range(self.num_samples)] + results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)] for key in data.keys(): if key in self.keys: img = d[key] @@ -533,9 +574,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n bg_indices = d.get(self.bg_indices_key, None) if self.bg_indices_key is not None else None self.randomize(label, fg_indices, bg_indices, image) - assert isinstance(self.spatial_size, tuple) - assert self.centers is not None - results: List[Dict[Hashable, np.ndarray]] = [dict() for _ in range(self.num_samples)] + if not isinstance(self.spatial_size, tuple): + raise AssertionError + if self.centers is None: + raise AssertionError + results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)] for key in data.keys(): if key in self.keys: img = d[key] diff --git a/monai/transforms/intensity/__init__.py b/monai/transforms/intensity/__init__.py index d0044e3563..14ae193634 100644 --- a/monai/transforms/intensity/__init__.py +++ b/monai/transforms/intensity/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI 
Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index ac2d1e46fd..205b719246 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,11 +20,33 @@ import numpy as np import torch -from monai.networks.layers import GaussianFilter, HilbertTransform +from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter from monai.transforms.compose import Randomizable, Transform from monai.transforms.utils import rescale_array from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, dtype_torch_to_numpy, ensure_tuple_size +__all__ = [ + "RandGaussianNoise", + "ShiftIntensity", + "RandShiftIntensity", + "ScaleIntensity", + "RandScaleIntensity", + "NormalizeIntensity", + "ThresholdIntensity", + "ScaleIntensityRange", + "AdjustContrast", + "RandAdjustContrast", + "ScaleIntensityRangePercentiles", + "MaskIntensity", + "DetectEnvelope", + "SavitzkyGolaySmooth", + "GaussianSmooth", + "RandGaussianSmooth", + "GaussianSharpen", + "RandGaussianSharpen", + "RandHistogramShift", +] + class RandGaussianNoise(Randomizable, Transform): """ @@ -52,7 +74,8 @@ def __call__(self, img: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, Apply the transform to `img`. """ self.randomize(img.shape) - assert self._noise is not None + if self._noise is None: + raise AssertionError if not self._do_transform: return img dtype = dtype_torch_to_numpy(img.dtype) if isinstance(img, torch.Tensor) else img.dtype @@ -92,7 +115,8 @@ def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1 if isinstance(offsets, (int, float)): self.offsets = (min(-offsets, offsets), max(-offsets, offsets)) else: - assert len(offsets) == 2, "offsets should be a number or pair of numbers." + if len(offsets) != 2: + raise AssertionError("offsets should be a number or pair of numbers.") self.offsets = (min(offsets), max(offsets)) self.prob = prob @@ -142,10 +166,9 @@ def __call__(self, img: np.ndarray) -> np.ndarray: """ if self.minv is not None and self.maxv is not None: return rescale_array(img, self.minv, self.maxv, img.dtype) - elif self.factor is not None: + if self.factor is not None: return (img * (1 + self.factor)).astype(img.dtype) - else: - raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.") + raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.") class RandScaleIntensity(Randomizable, Transform): @@ -165,7 +188,8 @@ def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1 if isinstance(factors, (int, float)): self.factors = (min(-factors, factors), max(-factors, factors)) else: - assert len(factors) == 2, "factors should be a number or pair of numbers." 
+ if len(factors) != 2: + raise AssertionError("factors should be a number or pair of numbers.") self.factors = (min(factors), max(factors)) self.prob = prob @@ -270,7 +294,8 @@ class ThresholdIntensity(Transform): """ def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> None: - assert isinstance(threshold, (int, float)), "threshold must be a float or int number." + if not isinstance(threshold, (int, float)): + raise AssertionError("threshold must be a float or int number.") self.threshold = threshold self.above = above self.cval = cval @@ -329,7 +354,8 @@ class AdjustContrast(Transform): """ def __init__(self, gamma: float) -> None: - assert isinstance(gamma, (int, float)), "gamma must be a float or int number." + if not isinstance(gamma, (int, float)): + raise AssertionError("gamma must be a float or int number.") self.gamma = gamma def __call__(self, img: np.ndarray) -> np.ndarray: @@ -358,10 +384,14 @@ def __init__(self, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0. self.prob = prob if isinstance(gamma, (int, float)): - assert gamma > 0.5, "if gamma is single number, must greater than 0.5 and value is picked from (0.5, gamma)" + if gamma <= 0.5: + raise AssertionError( + "if gamma is a single number, it must be greater than 0.5, and the value is picked from (0.5, gamma)" + ) self.gamma = (0.5, gamma) else: - assert len(gamma) == 2, "gamma should be a number or pair of numbers." + if len(gamma) != 2: + raise AssertionError("gamma should be a number or pair of numbers.") self.gamma = (min(gamma), max(gamma)) self._do_transform = False @@ -376,7 +406,8 @@ def __call__(self, img: np.ndarray) -> np.ndarray: Apply the transform to `img`. """ self.randomize() - assert self.gamma_value is not None + if self.gamma_value is None: + raise AssertionError if not self._do_transform: return img adjuster = AdjustContrast(self.gamma_value) @@ -441,8 +472,10 @@ class ScaleIntensityRangePercentiles(Transform): def __init__( self, lower: float, upper: float, b_min: float, b_max: float, clip: bool = False, relative: bool = False ) -> None: - assert 0.0 <= lower <= 100.0, "Percentiles must be in the range [0, 100]" - assert 0.0 <= upper <= 100.0, "Percentiles must be in the range [0, 100]" + if lower < 0.0 or lower > 100.0: + raise AssertionError("Percentiles must be in the range [0, 100]") + if upper < 0.0 or upper > 100.0: + raise AssertionError("Percentiles must be in the range [0, 100]") self.lower = lower self.upper = upper self.b_min = b_min @@ -512,6 +545,44 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n return img * mask_data_ +class SavitzkyGolaySmooth(Transform): + """ + Smooth the input data along the given axis using a Savitzky-Golay filter. + + Args: + window_length: Length of the filter window, must be a positive odd integer. + order: Order of the polynomial to fit to each window, must be less than ``window_length``. + axis: Optional axis along which to apply the filter kernel. Default 1 (first spatial dimension). + mode: Optional padding mode, passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` + or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information.
+ """ + + def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = "zeros"): + + if axis < 0: + raise ValueError("axis must be zero or positive.") + + self.window_length = window_length + self.order = order + self.axis = axis + self.mode = mode + + def __call__(self, img: np.ndarray) -> np.ndarray: + """ + Args: + img: numpy.ndarray containing input data. Must be real and in shape [channels, spatial1, spatial2, ...]. + + Returns: + np.ndarray containing smoothed result. + + """ + # add one to transform axis because a batch axis will be added at dimension 0 + savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode) + # convert to Tensor and add Batch axis expected by HilbertTransform + input_data = torch.as_tensor(np.ascontiguousarray(img)).unsqueeze(0) + return savgol_filter(input_data).squeeze(0).numpy() + + class DetectEnvelope(Transform): """ Find the envelope of the input data along the requested axis using a Hilbert transform. @@ -748,11 +819,14 @@ class RandHistogramShift(Randomizable, Transform): def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1) -> None: if isinstance(num_control_points, int): - assert num_control_points > 2, "num_control_points should be greater than or equal to 3" + if num_control_points <= 2: + raise AssertionError("num_control_points should be greater than or equal to 3") self.num_control_points = (num_control_points, num_control_points) else: - assert len(num_control_points) == 2, "num_control points should be a number or a pair of numbers" - assert min(num_control_points) > 2, "num_control_points should be greater than or equal to 3" + if len(num_control_points) != 2: + raise AssertionError("num_control points should be a number or a pair of numbers") + if min(num_control_points) <= 2: + raise AssertionError("num_control_points should be greater than or equal to 3") self.num_control_points = (min(num_control_points), max(num_control_points)) self.prob = prob self._do_transform = False diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 64f641ecd1..18e2250084 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -37,6 +37,60 @@ ) from monai.utils import dtype_torch_to_numpy, ensure_tuple_size +__all__ = [ + "RandGaussianNoised", + "ShiftIntensityd", + "RandShiftIntensityd", + "ScaleIntensityd", + "RandScaleIntensityd", + "NormalizeIntensityd", + "ThresholdIntensityd", + "ScaleIntensityRanged", + "AdjustContrastd", + "RandAdjustContrastd", + "ScaleIntensityRangePercentilesd", + "MaskIntensityd", + "GaussianSmoothd", + "RandGaussianSmoothd", + "GaussianSharpend", + "RandGaussianSharpend", + "RandHistogramShiftd", + "RandGaussianNoiseD", + "RandGaussianNoiseDict", + "ShiftIntensityD", + "ShiftIntensityDict", + "RandShiftIntensityD", + "RandShiftIntensityDict", + "ScaleIntensityD", + "ScaleIntensityDict", + "RandScaleIntensityD", + "RandScaleIntensityDict", + "NormalizeIntensityD", + "NormalizeIntensityDict", + "ThresholdIntensityD", + "ThresholdIntensityDict", + "ScaleIntensityRangeD", + "ScaleIntensityRangeDict", + "AdjustContrastD", + "AdjustContrastDict", + "RandAdjustContrastD", + "RandAdjustContrastDict", + "ScaleIntensityRangePercentilesD", + "ScaleIntensityRangePercentilesDict", + "MaskIntensityD", + "MaskIntensityDict", + "GaussianSmoothD", + "GaussianSmoothDict", + "RandGaussianSmoothD", + "RandGaussianSmoothDict", + "GaussianSharpenD", + "GaussianSharpenDict", + "RandGaussianSharpenD", + "RandGaussianSharpenDict", + "RandHistogramShiftD", + "RandHistogramShiftDict", +] + class RandGaussianNoised(Randomizable, MapTransform): """ @@ -70,7 +124,8 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda image_shape = d[self.keys[0]].shape # image shape from the first data key self.randomize(image_shape) - assert self._noise is not None + if self._noise is None: + raise AssertionError if not self._do_transform: return d for key in self.keys: @@ -121,7 +176,8 @@ def __init__(self, keys: KeysCollection, offsets: Union[Tuple[float, float], flo if isinstance(offsets, (int, float)): self.offsets = (min(-offsets, offsets), max(-offsets, offsets)) else: - assert len(offsets) == 2, "offsets should be a number or pair of numbers." + if len(offsets) != 2: + raise AssertionError("offsets should be a number or pair of numbers.") self.offsets = (min(offsets), max(offsets)) self.prob = prob @@ -192,7 +248,8 @@ def __init__(self, keys: KeysCollection, factors: Union[Tuple[float, float], flo if isinstance(factors, (int, float)): self.factors = (min(-factors, factors), max(-factors, factors)) else: - assert len(factors) == 2, "factors should be a number or pair of numbers." + if len(factors) != 2: + raise AssertionError("factors should be a number or pair of numbers.") self.factors = (min(factors), max(factors)) self.prob = prob @@ -345,10 +402,14 @@ def __init__( self.prob: float = prob if isinstance(gamma, (int, float)): - assert gamma > 0.5, "if gamma is single number, must greater than 0.5 and value is picked from (0.5, gamma)" + if gamma <= 0.5: + raise AssertionError( + "if gamma is a single number, it must be greater than 0.5, and the value is picked from (0.5, gamma)" + ) self.gamma = (0.5, gamma) else: - assert len(gamma) == 2, "gamma should be a number or pair of numbers."
+ if len(gamma) != 2: + raise AssertionError("gamma should be a number or pair of numbers.") self.gamma = (min(gamma), max(gamma)) self._do_transform = False @@ -361,7 +422,8 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) self.randomize() - assert self.gamma_value is not None + if self.gamma_value is None: + raise AssertionError if not self._do_transform: return d adjuster = AdjustContrast(self.gamma_value) @@ -416,17 +478,26 @@ class MaskIntensityd(MapTransform): of input image. if multiple channels, the channel number must match input data. mask_data will be converted to `bool` values by `mask_data > 0` before applying transform to input image. + if None, will extract the mask data from input data based on `mask_key`. + mask_key: the key to extract mask data from input dictionary, only works + when `mask_data` is None. """ - def __init__(self, keys: KeysCollection, mask_data: np.ndarray) -> None: + def __init__( + self, + keys: KeysCollection, + mask_data: Optional[np.ndarray] = None, + mask_key: Optional[str] = None, + ) -> None: super().__init__(keys) self.converter = MaskIntensity(mask_data) + self.mask_key = mask_key if mask_data is None else None def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: - d[key] = self.converter(d[key]) + d[key] = self.converter(d[key], d[self.mask_key]) if self.mask_key is not None else self.converter(d[key]) return d @@ -635,11 +706,14 @@ def __init__( ) -> None: super().__init__(keys) if isinstance(num_control_points, int): - assert num_control_points > 2, "num_control_points should be greater than or equal to 3" + if num_control_points <= 2: + raise AssertionError("num_control_points should be greater than or equal to 3") self.num_control_points = (num_control_points, num_control_points) else: - assert len(num_control_points) == 2, "num_control points should be a number or a pair of numbers" - assert min(num_control_points) > 2, "num_control_points should be greater than or equal to 3" + if len(num_control_points) != 2: + raise AssertionError("num_control_points should be a number or a pair of numbers") + if min(num_control_points) <= 2: + raise AssertionError("num_control_points should be greater than or equal to 3") self.num_control_points = (min(num_control_points), max(num_control_points)) self.prob = prob self._do_transform = False diff --git a/monai/transforms/io/__init__.py b/monai/transforms/io/__init__.py index d0044e3563..14ae193634 100644 --- a/monai/transforms/io/__init__.py +++ b/monai/transforms/io/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index fd44555fa7..3b359cc460 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
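The MaskIntensityd extension above lets the mask come from the input dictionary instead of being fixed at construction time; a short usage sketch with illustrative shapes:

    import numpy as np
    from monai.transforms import MaskIntensityd

    data = {
        "image": np.random.rand(1, 4, 4).astype(np.float32),
        "mask": (np.random.rand(1, 4, 4) > 0.5).astype(np.float32),
    }
    # mask_data is None, so the mask is looked up under mask_key at call time
    masked = MaskIntensityd(keys=["image"], mask_key="mask")(data)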
# You may obtain a copy of the License at @@ -13,22 +13,19 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ -import warnings -from pathlib import Path -from typing import Dict, List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Union import numpy as np -from torch.utils.data._utils.collate import np_str_obj_array_pattern -from monai.config import KeysCollection from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader -from monai.data.utils import correct_nifti_header_if_necessary from monai.transforms.compose import Transform from monai.utils import ensure_tuple, optional_import nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") +__all__ = ["LoadImage"] + class LoadImage(Transform): """ @@ -130,214 +127,3 @@ def __call__( return img_array meta_data["filename_or_obj"] = ensure_tuple(filename)[0] return img_array, meta_data - - -class LoadNifti(Transform): - """ - Load Nifti format file or files from provided path. If loading a list of - files, stack them together and add a new dimension as first dimension, and - use the meta data of the first image to represent the stacked result. Note - that the affine transform of all the images should be same if ``image_only=False``. - """ - - def __init__( - self, as_closest_canonical: bool = False, image_only: bool = False, dtype: Optional[np.dtype] = np.float32 - ) -> None: - """ - Args: - as_closest_canonical: if True, load the image as closest to canonical axis format. - image_only: if True return only the image volume, otherwise return image data array and header dict. - dtype: if not None convert the loaded image to this data type. - - Note: - The transform returns image data array if `image_only` is True, - or a tuple of two elements containing the data array, and the Nifti - header in a dict format otherwise. - if a dictionary header is returned: - - - header['affine'] stores the affine of the image. - - header['original_affine'] will be additionally created to store the original affine. - """ - warnings.warn("LoadNifti will be deprecated in v0.5, please use LoadImage instead.", DeprecationWarning) - self.as_closest_canonical = as_closest_canonical - self.image_only = image_only - self.dtype = dtype - - def __call__(self, filename: Union[Sequence[Union[Path, str]], Path, str]): - """ - Args: - filename: path file or file-like object or a list of files. 
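The deprecation warnings in the removed loaders point at LoadImage; a minimal migration sketch (the image_only flag and automatic reader selection are assumptions based on this API's docstrings):

    from monai.transforms import LoadImage

    # one loader replaces LoadNifti / LoadPNG / LoadNumpy
    loader = LoadImage(image_only=False)
    img, meta = loader("volume.nii.gz")  # hypothetical path
    print(meta["filename_or_obj"])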
- """ - filename = ensure_tuple(filename) - img_array = list() - compatible_meta: Dict = dict() - for name in filename: - img = nib.load(name) - img = correct_nifti_header_if_necessary(img) - header = dict(img.header) - header["filename_or_obj"] = name - header["affine"] = img.affine - header["original_affine"] = img.affine.copy() - header["as_closest_canonical"] = self.as_closest_canonical - ndim = img.header["dim"][0] - spatial_rank = min(ndim, 3) - header["spatial_shape"] = img.header["dim"][1 : spatial_rank + 1] - - if self.as_closest_canonical: - img = nib.as_closest_canonical(img) - header["affine"] = img.affine - - img_array.append(np.array(img.get_fdata(dtype=self.dtype))) - img.uncache() - - if self.image_only: - continue - - if not compatible_meta: - for meta_key in header: - meta_datum = header[meta_key] - if ( - isinstance(meta_datum, np.ndarray) - and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None - ): - continue - compatible_meta[meta_key] = meta_datum - else: - assert np.allclose( - header["affine"], compatible_meta["affine"] - ), "affine data of all images should be same." - - img_array = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - if self.image_only: - return img_array - return img_array, compatible_meta - - -class LoadPNG(Transform): - """ - Load common 2D image format (PNG, JPG, etc. using PIL) file or files from provided path. - If loading a list of files, stack them together and add a new dimension as first dimension, - and use the meta data of the first image to represent the stacked result. - It's based on the Image module in PIL library: - https://pillow.readthedocs.io/en/stable/reference/Image.html - """ - - def __init__(self, image_only: bool = False, dtype: Optional[np.dtype] = np.float32) -> None: - """ - Args: - image_only: if True return only the image volume, otherwise return image data array and metadata. - dtype: if not None convert the loaded image to this data type. - """ - warnings.warn("LoadPNG will be deprecated in v0.5, please use LoadImage instead.", DeprecationWarning) - self.image_only = image_only - self.dtype = dtype - - def __call__(self, filename: Union[Sequence[Union[Path, str]], Path, str]): - """ - Args: - filename: path file or file-like object or a list of files. - """ - filename = ensure_tuple(filename) - img_array = list() - compatible_meta = None - for name in filename: - img = Image.open(name) - data = np.asarray(img) - if self.dtype: - data = data.astype(self.dtype) - img_array.append(data) - - if self.image_only: - continue - - meta = dict() - meta["filename_or_obj"] = name - meta["spatial_shape"] = data.shape[:2] - meta["format"] = img.format - meta["mode"] = img.mode - meta["width"] = img.width - meta["height"] = img.height - if not compatible_meta: - compatible_meta = meta - else: - assert np.allclose( - meta["spatial_shape"], compatible_meta["spatial_shape"] - ), "all the images in the list should have same spatial shape." - - img_array = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - return img_array if self.image_only else (img_array, compatible_meta) - - -class LoadNumpy(Transform): - """ - Load arrays or pickled objects from .npy, .npz or pickled files, file or files are from provided path. - A typical usage is to load the `mask` data for classification task. - If loading a list of files or loading npz file, stack results together and add a new dimension as first dimension, - and use the meta data of the first file to represent the stacked result. 
- It can load part of the npz file with specified `npz_keys`. - It's based on the Numpy load/read API: - https://numpy.org/doc/stable/reference/generated/numpy.load.html - - """ - - def __init__( - self, data_only: bool = False, dtype: Optional[np.dtype] = np.float32, npz_keys: Optional[KeysCollection] = None - ) -> None: - """ - Args: - data_only: if True return only the data array, otherwise return data array and metadata. - dtype: if not None convert the loaded data to this data type. - npz_keys: if loading npz file, only load the specified keys, if None, load all the items. - stack the loaded items together to construct a new first dimension. - - """ - warnings.warn("LoadNumpy will be deprecated in v0.5, please use LoadImage instead.", DeprecationWarning) - self.data_only = data_only - self.dtype = dtype - if npz_keys is not None: - npz_keys = ensure_tuple(npz_keys) - self.npz_keys = npz_keys - - def __call__(self, filename: Union[Sequence[Union[Path, str]], Path, str]): - """ - Args: - filename: path file or file-like object or a list of files. - - Raises: - ValueError: When ``filename`` is a sequence and contains a "npz" file extension. - - """ - if isinstance(filename, (tuple, list)): - for name in filename: - if name.endswith(".npz"): - raise ValueError("Cannot load a sequence of npz files.") - filename = ensure_tuple(filename) - data_array: List = list() - compatible_meta = None - - def _save_data_meta(data_array, name, data, compatible_meta): - data_array.append(data if self.dtype is None else data.astype(self.dtype)) - if not self.data_only: - meta = dict() - meta["filename_or_obj"] = name - meta["spatial_shape"] = data.shape - if not compatible_meta: - compatible_meta = meta - else: - assert np.allclose( - meta["spatial_shape"], compatible_meta["spatial_shape"] - ), "all the data in the list should have same shape." - return compatible_meta - - for name in filename: - data = np.load(name, allow_pickle=True) - if name.endswith(".npz"): - # load expected items from NPZ file - npz_keys = [f"arr_{i}" for i in range(len(data))] if self.npz_keys is None else self.npz_keys - for k in npz_keys: - compatible_meta = _save_data_meta(data_array, name, data[k], compatible_meta) - else: - compatible_meta = _save_data_meta(data_array, name, data, compatible_meta) - - data_array = np.stack(data_array, axis=0) if len(data_array) > 1 else data_array[0] - return data_array if self.data_only else (data_array, compatible_meta) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index ff8c439d3b..62ac4c8562 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,14 +15,20 @@ Class names are ended with 'd' to denote dictionary-based transforms. 
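The dictionary-based loaders collapse to LoadImaged in the same way; a sketch assuming the default meta_key_postfix="meta_dict" described below:

    from monai.transforms import LoadImaged

    loader = LoadImaged(keys=["image", "label"])
    sample = loader({"image": "img.nii.gz", "label": "seg.nii.gz"})  # hypothetical paths
    # arrays land in sample["image"] / sample["label"]; metadata in
    # sample["image_meta_dict"] / sample["label_meta_dict"]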
""" -from typing import Callable, Optional, Union +from typing import Optional, Union import numpy as np from monai.config import KeysCollection from monai.data.image_reader import ImageReader from monai.transforms.compose import MapTransform -from monai.transforms.io.array import LoadImage, LoadNifti, LoadNumpy, LoadPNG +from monai.transforms.io.array import LoadImage + +__all__ = [ + "LoadImaged", + "LoadImageD", + "LoadImageDict", +] class LoadImaged(MapTransform): @@ -100,161 +106,4 @@ def __call__(self, data, reader: Optional[ImageReader] = None): return d -class LoadDatad(MapTransform): - """ - Base class for dictionary-based wrapper of IO loader transforms. - It must load image and metadata together. If loading a list of files in one key, - stack them together and add a new dimension as the first dimension, and use the - meta data of the first image to represent the stacked result. Note that the affine - transform of all the stacked images should be same. The output metadata field will - be created as ``key_{meta_key_postfix}``. - """ - - def __init__( - self, - keys: KeysCollection, - loader: Callable, - meta_key_postfix: str = "meta_dict", - overwriting: bool = False, - ) -> None: - """ - Args: - keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` - loader: callable function to load data from expected source. - typically, it's array level transform, for example: `LoadNifti`, - `LoadPNG` and `LoadNumpy`, etc. - meta_key_postfix: use `key_{postfix}` to store the metadata of the loaded data, - default is `meta_dict`. The meta data is a dictionary object. - For example, load Nifti file for `image`, store the metadata into `image_meta_dict`. - overwriting: whether allow to overwrite existing meta data of same key. - default is False, which will raise exception if encountering existing key. - - Raises: - TypeError: When ``loader`` is not ``callable``. - TypeError: When ``meta_key_postfix`` is not a ``str``. - - """ - super().__init__(keys) - if not callable(loader): - raise TypeError(f"loader must be callable but is {type(loader).__name__}.") - self.loader = loader - if not isinstance(meta_key_postfix, str): - raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.") - self.meta_key_postfix = meta_key_postfix - self.overwriting = overwriting - - def __call__(self, data): - """ - Raises: - KeyError: When not ``self.overwriting`` and key already exists in ``data``. - - """ - d = dict(data) - for key in self.keys: - data = self.loader(d[key]) - assert isinstance(data, (tuple, list)), "loader must return a tuple or list." - d[key] = data[0] - assert isinstance(data[1], dict), "metadata must be a dict." - key_to_add = f"{key}_{self.meta_key_postfix}" - if key_to_add in d and not self.overwriting: - raise KeyError(f"Meta data with key {key_to_add} already exists and overwriting=False.") - d[key_to_add] = data[1] - return d - - -class LoadNiftid(LoadDatad): - """ - Dictionary-based wrapper of :py:class:`monai.transforms.LoadNifti`, - must load image and metadata together. If loading a list of files in one key, - stack them together and add a new dimension as the first dimension, and use the - meta data of the first image to represent the stacked result. Note that the affine - transform of all the stacked images should be same. The output metadata field will - be created as ``key_{meta_key_postfix}``. 
- """ - - def __init__( - self, - keys: KeysCollection, - as_closest_canonical: bool = False, - dtype: Optional[np.dtype] = np.float32, - meta_key_postfix: str = "meta_dict", - overwriting: bool = False, - ) -> None: - """ - Args: - keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` - as_closest_canonical: if True, load the image as closest to canonical axis format. - dtype: if not None convert the loaded image data to this data type. - meta_key_postfix: use `key_{postfix}` to store the metadata of the nifti image, - default is `meta_dict`. The meta data is a dictionary object. - For example, load nifti file for `image`, store the metadata into `image_meta_dict`. - overwriting: whether allow to overwrite existing meta data of same key. - default is False, which will raise exception if encountering existing key. - """ - loader = LoadNifti(as_closest_canonical, False, dtype) - super().__init__(keys, loader, meta_key_postfix, overwriting) - - -class LoadPNGd(LoadDatad): - """ - Dictionary-based wrapper of :py:class:`monai.transforms.LoadPNG`. - """ - - def __init__( - self, - keys: KeysCollection, - dtype: Optional[np.dtype] = np.float32, - meta_key_postfix: str = "meta_dict", - overwriting: bool = False, - ) -> None: - """ - Args: - keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` - dtype: if not None convert the loaded image data to this data type. - meta_key_postfix: use `key_{postfix}` to store the metadata of the PNG image, - default is `meta_dict`. The meta data is a dictionary object. - For example, load PNG file for `image`, store the metadata into `image_meta_dict`. - overwriting: whether allow to overwrite existing meta data of same key. - default is False, which will raise exception if encountering existing key. - """ - loader = LoadPNG(False, dtype) - super().__init__(keys, loader, meta_key_postfix, overwriting) - - -class LoadNumpyd(LoadDatad): - """ - Dictionary-based wrapper of :py:class:`monai.transforms.LoadNumpy`. - """ - - def __init__( - self, - keys: KeysCollection, - dtype: Optional[np.dtype] = np.float32, - npz_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", - overwriting: bool = False, - ) -> None: - """ - Args: - keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` - dtype: if not None convert the loaded data to this data type. - npz_keys: if loading npz file, only load the specified keys, if None, load all the items. - stack the loaded items together to construct a new first dimension. - meta_key_postfix: use `key_{postfix}` to store the metadata of the Numpy data, - default is `meta_dict`. The meta data is a dictionary object. - For example, load Numpy file for `mask`, store the metadata into `mask_meta_dict`. - overwriting: whether allow to overwrite existing meta data of same key. - default is False, which will raise exception if encountering existing key. 
- """ - loader = LoadNumpy(data_only=False, dtype=dtype, npz_keys=npz_keys) - super().__init__(keys, loader, meta_key_postfix, overwriting) - - LoadImageD = LoadImageDict = LoadImaged -LoadNiftiD = LoadNiftiDict = LoadNiftid -LoadPNGD = LoadPNGDict = LoadPNGd -LoadNumpyD = LoadNumpyDict = LoadNumpyd diff --git a/monai/transforms/post/__init__.py b/monai/transforms/post/__init__.py index d0044e3563..14ae193634 100644 --- a/monai/transforms/post/__init__.py +++ b/monai/transforms/post/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 292aba799d..0c60b0cc89 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,6 +25,15 @@ from monai.transforms.utils import get_largest_connected_component_mask from monai.utils import ensure_tuple +__all__ = [ + "Activations", + "AsDiscrete", + "KeepLargestConnectedComponent", + "LabelToContour", + "MeanEnsemble", + "VoteEnsemble", +] + class Activations(Transform): """ @@ -154,7 +163,8 @@ def __call__( if to_onehot or self.to_onehot: _nclasses = self.n_classes if n_classes is None else n_classes - assert isinstance(_nclasses, int), "One of self.n_classes or n_classes must be an integer" + if not isinstance(_nclasses, int): + raise AssertionError("One of self.n_classes or n_classes must be an integer") img = one_hot(img, _nclasses) if threshold_values or self.threshold_values: @@ -405,6 +415,5 @@ def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Te if self.num_classes is not None: # if not One-Hot, use "argmax" to vote the most common class return torch.argmax(img_, dim=1, keepdim=has_ch_dim) - else: - # for One-Hot data, round the float number to 0 or 1 - return torch.round(img_) + # for One-Hot data, round the float number to 0 or 1 + return torch.round(img_) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 0bf01aa541..60cda11a91 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -32,6 +32,28 @@ ) from monai.utils import ensure_tuple_rep +__all__ = [ + "Activationsd", + "AsDiscreted", + "KeepLargestConnectedComponentd", + "LabelToContourd", + "Ensembled", + "MeanEnsembled", + "VoteEnsembled", + "ActivationsD", + "ActivationsDict", + "AsDiscreteD", + "AsDiscreteDict", + "KeepLargestConnectedComponentD", + "KeepLargestConnectedComponentDict", + "LabelToContourD", + "LabelToContourDict", + "MeanEnsembleD", + "MeanEnsembleDict", + "VoteEnsembleD", + "VoteEnsembleDict", +] + class Activationsd(MapTransform): """ diff --git a/monai/transforms/spatial/__init__.py b/monai/transforms/spatial/__init__.py index d0044e3563..14ae193634 100644 --- a/monai/transforms/spatial/__init__.py +++ b/monai/transforms/spatial/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 15e22fa8f6..3e1ded4e94 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -46,6 +46,28 @@ nib, _ = optional_import("nibabel") +__all__ = [ + "Spacing", + "Orientation", + "Flip", + "Resize", + "Rotate", + "Zoom", + "Rotate90", + "RandRotate90", + "RandRotate", + "RandFlip", + "RandZoom", + "AffineGrid", + "RandAffineGrid", + "RandDeformGrid", + "Resample", + "Affine", + "RandAffine", + "Rand2DElastic", + "Rand3DElastic", +] + class Spacing(Transform): """ @@ -242,7 +264,8 @@ def __call__( if self.as_closest_canonical: spatial_ornt = src else: - assert self.axcodes is not None + if self.axcodes is None: + raise AssertionError dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels) if len(dst) < sr: raise ValueError( @@ -277,7 +300,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - flipped = list() + flipped = [] for channel in img: flipped.append(np.flip(channel, self.spatial_axis)) return np.stack(flipped).astype(img.dtype) @@ -555,14 +578,17 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: Default: (0, 1), this is the first two axis in spatial dimensions. """ self.k = k - self.spatial_axes = spatial_axes + spatial_axes_ = ensure_tuple(spatial_axes) + if len(spatial_axes_) != 2: + raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") + self.spatial_axes = spatial_axes_ def __call__(self, img: np.ndarray) -> np.ndarray: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - rotated = list() + rotated = [] for channel in img: rotated.append(np.rot90(channel, self.k, self.spatial_axes)) return np.stack(rotated).astype(img.dtype) @@ -782,7 +808,8 @@ def __init__( ) -> None: self.min_zoom = ensure_tuple(min_zoom) self.max_zoom = ensure_tuple(max_zoom) - assert len(self.min_zoom) == len(self.max_zoom), "min_zoom and max_zoom must have same length." 
+ if len(self.min_zoom) != len(self.max_zoom): + raise AssertionError("min_zoom and max_zoom must have same length.") self.prob = prob self.mode: InterpolateMode = InterpolateMode(mode) self.padding_mode: NumpyPadMode = NumpyPadMode(padding_mode) @@ -1104,7 +1131,8 @@ def __call__( if not torch.is_tensor(img): img = torch.as_tensor(np.ascontiguousarray(img)) - assert grid is not None, "Error, grid argument must be supplied as an ndarray or tensor " + if grid is None: + raise AssertionError("grid argument must be supplied as an ndarray or tensor.") grid = torch.tensor(grid) if not torch.is_tensor(grid) else grid.detach().clone() if self.device: img = img.to(self.device) @@ -1599,7 +1627,8 @@ def __call__( self.randomize(grid_size=sp_size) grid = create_grid(spatial_size=sp_size) if self.do_transform: - assert self.rand_offset is not None + if self.rand_offset is None: + raise AssertionError grid = torch.as_tensor(np.ascontiguousarray(grid), device=self.device) gaussian = GaussianFilter(3, self.sigma, 3.0).to(device=self.device) offset = torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 5344928807..615a327d90 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -47,6 +47,51 @@ fall_back_tuple, ) +__all__ = [ + "Spacingd", + "Orientationd", + "Rotate90d", + "RandRotate90d", + "Resized", + "RandAffined", + "Rand2DElasticd", + "Rand3DElasticd", + "Flipd", + "RandFlipd", + "Rotated", + "RandRotated", + "Zoomd", + "RandZoomd", + "SpacingD", + "SpacingDict", + "OrientationD", + "OrientationDict", + "Rotate90D", + "Rotate90Dict", + "RandRotate90D", + "RandRotate90Dict", + "ResizeD", + "ResizeDict", + "RandAffineD", + "RandAffineDict", + "Rand2DElasticD", + "Rand2DElasticDict", + "Rand3DElasticD", + "Rand3DElasticDict", + "FlipD", + "FlipDict", + "RandFlipD", + "RandFlipDict", + "RotateD", + "RotateDict", + "RandRotateD", + "RandRotateDict", + "ZoomD", + "ZoomDict", + "RandZoomD", + "RandZoomDict", +] + GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] InterpolateModeSequence = Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str] @@ -967,7 +1012,8 @@ def __init__( super().__init__(keys) self.min_zoom = ensure_tuple(min_zoom) self.max_zoom = ensure_tuple(max_zoom) - assert len(self.min_zoom) == len(self.max_zoom), "min_zoom and max_zoom must have same length." + if len(self.min_zoom) != len(self.max_zoom): + raise AssertionError("min_zoom and max_zoom must have same length.") self.prob = prob self.mode = ensure_tuple_rep(mode, len(self.keys)) diff --git a/monai/transforms/utility/__init__.py b/monai/transforms/utility/__init__.py index d0044e3563..14ae193634 100644 --- a/monai/transforms/utility/__init__.py +++ b/monai/transforms/utility/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
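The Rotate90 change above normalizes spatial_axes with ensure_tuple and rejects anything but an axis pair up front; a quick sketch of the transform on a channel-first array:

    import numpy as np
    from monai.transforms import Rotate90

    img = np.zeros((1, 4, 8), dtype=np.float32)  # (channels, H, W)
    out = Rotate90(k=1, spatial_axes=(0, 1))(img)
    print(out.shape)  # (1, 8, 4): H and W swap after one 90-degree turn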
# You may obtain a copy of the License at diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 8daad86dd2..5476e800f4 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,7 +22,29 @@ from monai.transforms.compose import Randomizable, Transform from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices -from monai.utils import ensure_tuple +from monai.utils import ensure_tuple, min_version, optional_import + +__all__ = [ + "Identity", + "AsChannelFirst", + "AsChannelLast", + "AddChannel", + "RepeatChannel", + "SplitChannel", + "CastToType", + "ToTensor", + "ToNumpy", + "Transpose", + "SqueezeDim", + "DataStats", + "SimulateDelay", + "Lambda", + "LabelToMask", + "FgBgToIndices", + "ConvertToMultiChannelBasedOnBratsClasses", + "AddExtremePointsChannel", + "TorchVision", +] # Generic type which can represent either a numpy.ndarray or a torch.Tensor # Unlike Union can create a dependence between parameter(s) / return(s) @@ -61,7 +83,8 @@ class AsChannelFirst(Transform): """ def __init__(self, channel_dim: int = -1) -> None: - assert isinstance(channel_dim, int) and channel_dim >= -1, "invalid channel dimension." + if not (isinstance(channel_dim, int) and channel_dim >= -1): + raise AssertionError("invalid channel dimension.") self.channel_dim = channel_dim def __call__(self, img: np.ndarray) -> np.ndarray: @@ -87,7 +110,8 @@ class AsChannelLast(Transform): """ def __init__(self, channel_dim: int = 0) -> None: - assert isinstance(channel_dim, int) and channel_dim >= -1, "invalid channel dimension." + if not (isinstance(channel_dim, int) and channel_dim >= -1): + raise AssertionError("invalid channel dimension.") self.channel_dim = channel_dim def __call__(self, img: np.ndarray) -> np.ndarray: @@ -129,7 +153,8 @@ class RepeatChannel(Transform): """ def __init__(self, repeats: int) -> None: - assert repeats > 0, "repeats count must be greater than 0." + if repeats <= 0: + raise AssertionError("repeats count must be greater than 0.") self.repeats = repeats def __call__(self, img: np.ndarray) -> np.ndarray: @@ -169,7 +194,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> List[Union[np.ndarra if n_classes <= 1: raise RuntimeError("input image does not contain multiple channels.") - outputs = list() + outputs = [] slices = [slice(None)] * len(img.shape) for i in range(n_classes): slices[channel_dim] = slice(i, i + 1) @@ -206,10 +231,9 @@ def __call__( """ if isinstance(img, np.ndarray): return img.astype(self.dtype if dtype is None else dtype) - elif torch.is_tensor(img): + if torch.is_tensor(img): return torch.as_tensor(img, dtype=self.dtype if dtype is None else dtype) - else: - raise TypeError(f"img must be one of (numpy.ndarray, torch.Tensor) but is {type(img).__name__}.") + raise TypeError(f"img must be one of (numpy.ndarray, torch.Tensor) but is {type(img).__name__}.") class ToTensor(Transform): @@ -314,7 +338,8 @@ def __init__( TypeError: When ``additional_info`` is not an ``Optional[Callable]``. """ - assert isinstance(prefix, str), "prefix must be a string." 
+ if not isinstance(prefix, str): + raise AssertionError("prefix must be a string.") self.prefix = prefix self.data_shape = data_shape self.value_range = value_range @@ -440,8 +465,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable return func(img) if self.func is not None: return self.func(img) - else: - raise ValueError("Incompatible values: func=None and self.func=None.") + raise ValueError("Incompatible values: func=None and self.func=None.") class LabelToMask(Transform): @@ -537,6 +561,27 @@ def __call__( return fg_indices, bg_indices +class ConvertToMultiChannelBasedOnBratsClasses(Transform): + """ + Convert labels to multi channels based on brats18 classes: + label 1 is the necrotic and non-enhancing tumor core + label 2 is the peritumoral edema + label 4 is the GD-enhancing tumor + The possible classes are TC (Tumor core), WT (Whole tumor) + and ET (Enhancing tumor). + """ + + def __call__(self, img: np.ndarray) -> np.ndarray: + result = [] + # merge labels 1 (tumor non-enh) and 4 (tumor enh) to TC + result.append(np.logical_or(img == 1, img == 4)) + # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT + result.append(np.logical_or(np.logical_or(img == 1, img == 4), img == 2)) + # label 4 is ET + result.append(img == 4) + return np.stack(result, axis=0).astype(np.float32) + + class AddExtremePointsChannel(Transform, Randomizable): """ Add extreme points of label to the image as a new channel. This transform generates extreme @@ -597,3 +642,32 @@ def __call__( ) return np.concatenate([img, points_image], axis=0) + + +class TorchVision: + """ + This is a wrapper for a PyTorch TorchVision transform, created from the specified transform name and args. + As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input + data to be a PyTorch Tensor; users can easily call the `ToTensor` transform to convert a Numpy array to a Tensor. + + """ + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in the TorchVision package. + args: parameters for the TorchVision transform. + kwargs: parameters for the TorchVision transform. + + """ + super().__init__() + transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name) + self.trans = transform(*args, **kwargs) + + def __call__(self, img: torch.Tensor): + """ + Args: + img: PyTorch Tensor data for the TorchVision transform. + + """ + return self.trans(img) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 28d7452e77..ef89dbe32d 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
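For illustration, a minimal usage sketch of the two new array transforms above; the import path follows the file being patched, the toy label values are illustrative, and the TorchVision call assumes torchvision >= 0.8.0 is installed:

import numpy as np
import torch
from monai.transforms.utility.array import ConvertToMultiChannelBasedOnBratsClasses, TorchVision

# BraTS-style label map containing background (0), tumor core (1), edema (2), enhancing tumor (4)
label = np.array([[0, 1], [2, 4]])
multi = ConvertToMultiChannelBasedOnBratsClasses()(label)
print(multi.shape)  # (3, 2, 2): channel 0 = TC (1 or 4), channel 1 = WT (1, 2 or 4), channel 2 = ET (4)

# wrap any torchvision transform by name, here torchvision.transforms.Normalize
normalize = TorchVision("Normalize", mean=[0.5], std=[0.5])
print(normalize(torch.rand(1, 8, 8)).shape)  # torch.Size([1, 8, 8])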
# You may obtain a copy of the License at @@ -23,13 +23,13 @@ import torch from monai.config import KeysCollection -from monai.transforms import extreme_points_to_image, get_extreme_points from monai.transforms.compose import MapTransform, Randomizable from monai.transforms.utility.array import ( AddChannel, AsChannelFirst, AsChannelLast, CastToType, + ConvertToMultiChannelBasedOnBratsClasses, DataStats, FgBgToIndices, Identity, @@ -40,10 +40,77 @@ SplitChannel, SqueezeDim, ToNumpy, + TorchVision, ToTensor, ) +from monai.transforms.utils import extreme_points_to_image, get_extreme_points from monai.utils import ensure_tuple, ensure_tuple_rep +__all__ = [ + "Identityd", + "AsChannelFirstd", + "AsChannelLastd", + "AddChanneld", + "RepeatChanneld", + "SplitChanneld", + "CastToTyped", + "ToTensord", + "ToNumpyd", + "DeleteItemsd", + "SelectItemsd", + "SqueezeDimd", + "DataStatsd", + "SimulateDelayd", + "CopyItemsd", + "ConcatItemsd", + "Lambdad", + "LabelToMaskd", + "FgBgToIndicesd", + "ConvertToMultiChannelBasedOnBratsClassesd", + "AddExtremePointsChanneld", + "IdentityD", + "IdentityDict", + "AsChannelFirstD", + "AsChannelFirstDict", + "AsChannelLastD", + "AsChannelLastDict", + "AddChannelD", + "AddChannelDict", + "RepeatChannelD", + "RepeatChannelDict", + "SplitChannelD", + "SplitChannelDict", + "CastToTypeD", + "CastToTypeDict", + "ToTensorD", + "ToTensorDict", + "DeleteItemsD", + "DeleteItemsDict", + "SqueezeDimD", + "SqueezeDimDict", + "DataStatsD", + "DataStatsDict", + "SimulateDelayD", + "SimulateDelayDict", + "CopyItemsD", + "CopyItemsDict", + "ConcatItemsD", + "ConcatItemsDict", + "LambdaD", + "LambdaDict", + "LabelToMaskD", + "LabelToMaskDict", + "FgBgToIndicesD", + "FgBgToIndicesDict", + "ConvertToMultiChannelBasedOnBratsClassesD", + "ConvertToMultiChannelBasedOnBratsClassesDict", + "AddExtremePointsChannelD", + "AddExtremePointsChannelDict", + "TorchVisiond", + "TorchVisionD", + "TorchVisionDict", +] + class Identityd(MapTransform): """ @@ -194,7 +261,8 @@ def __call__( for key in self.keys: rets = self.splitter(d[key]) postfixes: Sequence = list(range(len(rets))) if self.output_postfixes is None else self.output_postfixes - assert len(postfixes) == len(rets), "count of split results must match output_postfixes." + if len(postfixes) != len(rets): + raise AssertionError("count of split results must match output_postfixes.") for i, r in enumerate(rets): split_key = f"{key}_{postfixes[i]}" if split_key in d: @@ -284,14 +352,6 @@ class DeleteItemsd(MapTransform): It will remove the key-values and copy the others to construct a new dictionary. """ - def __init__(self, keys: KeysCollection) -> None: - """ - Args: - keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` - """ - super().__init__(keys) - def __call__(self, data): return {key: val for key, val in data.items() if key not in self.keys} @@ -302,14 +362,6 @@ class SelectItemsd(MapTransform): It will copy the selected key-values and construct a new dictionary. """ - def __init__(self, keys): - """ - Args: - keys: keys of the corresponding items to be transformed.
- See also: :py:class:`monai.transforms.compose.MapTransform` - """ - super().__init__(keys) - def __call__(self, data): result = {key: val for key, val in data.items() if key in self.keys} return result @@ -511,7 +563,7 @@ def __call__(self, data): """ d = dict(data) - output = list() + output = [] data_type = None for key in self.keys: if data_type is None: @@ -547,18 +599,27 @@ class Lambdad(MapTransform): See also: :py:class:`monai.transforms.compose.MapTransform` func: Lambda/function to be applied. It also can be a sequence of Callable, each element corresponds to a key in ``keys``. + overwrite: whether to overwrite the original data in the input dictionary with the lambda function output. + defaults to True. It also can be a sequence of bool, each element corresponds to a key in ``keys``. """ - def __init__(self, keys: KeysCollection, func: Union[Sequence[Callable], Callable]) -> None: + def __init__( + self, + keys: KeysCollection, + func: Union[Sequence[Callable], Callable], + overwrite: Union[Sequence[bool], bool] = True, + ) -> None: super().__init__(keys) self.func = ensure_tuple_rep(func, len(self.keys)) - self.lambd = Lambda() + self.overwrite = ensure_tuple_rep(overwrite, len(self.keys)) + self._lambd = Lambda() def __call__(self, data): d = dict(data) for idx, key in enumerate(self.keys): - d[key] = self.lambd(d[key], func=self.func[idx]) - + ret = self._lambd(d[key], func=self.func[idx]) + if self.overwrite[idx]: + d[key] = ret return d @@ -639,6 +700,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): """ + Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBratsClasses`. Convert labels to multi channels based on brats18 classes: label 1 is the necrotic and non-enhancing tumor core label 2 is the peritumoral edema @@ -647,17 +709,14 @@ class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): and ET (Enhancing tumor). """ + def __init__(self, keys: KeysCollection): + super().__init__(keys) + self.converter = ConvertToMultiChannelBasedOnBratsClasses() + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: - result = list() - # merge labels 1 (tumor non-enh) and 4 (tumor enh) to TC - result.append(np.logical_or(d[key] == 1, d[key] == 4)) - # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT - result.append(np.logical_or(np.logical_or(d[key] == 1, d[key] == 4), d[key] == 2)) - # label 4 is ET - result.append(d[key] == 4) - d[key] = np.stack(result, axis=0).astype(np.float32) + d[key] = self.converter(d[key]) return d @@ -724,6 +783,33 @@ def __call__(self, data): return d + +class TorchVisiond(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.TorchVision`. + As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input + data to be a dict of PyTorch Tensors; users can easily call the `ToTensord` transform to convert Numpy arrays to Tensors. + """ + + def __init__(self, keys: KeysCollection, name: str, *args, **kwargs) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: The transform name in the TorchVision package. + args: parameters for the TorchVision transform. + kwargs: parameters for the TorchVision transform.
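For illustration, a short sketch of the new ``overwrite`` option and the dictionary-based TorchVision wrapper defined here; the key names are illustrative, and torchvision >= 0.8.0 is assumed:

import torch
from monai.transforms.utility.dictionary import Lambdad, TorchVisiond

data = {"img": torch.rand(1, 8, 8)}

# with overwrite=False the lambda runs but d["img"] is left unchanged
no_op = Lambdad(keys="img", func=lambda x: x * 0, overwrite=False)
out = no_op(data)
assert torch.equal(out["img"], data["img"])

# dictionary wrapper around torchvision.transforms.Normalize
norm_d = TorchVisiond(keys="img", name="Normalize", mean=[0.5], std=[0.5])
out = norm_d(data)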
+ + """ + super().__init__(keys) + self.trans = TorchVision(name, *args, **kwargs) + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.keys: + d[key] = self.trans(d[key]) + return d + + IdentityD = IdentityDict = Identityd AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd @@ -745,3 +831,4 @@ def __call__(self, data): ConvertToMultiChannelBasedOnBratsClassesDict ) = ConvertToMultiChannelBasedOnBratsClassesd AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld +TorchVisionD = TorchVisionDict = TorchVisiond diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 3b552f543c..23c6bd100a 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,6 +23,33 @@ measure, _ = optional_import("skimage.measure", "0.14.2", min_version) +__all__ = [ + "rand_choice", + "img_bounds", + "in_bounds", + "is_empty", + "zero_margins", + "rescale_array", + "rescale_instance_array", + "rescale_array_int_max", + "copypaste_arrays", + "resize_center", + "map_binary_to_indices", + "weighted_patch_samples", + "generate_pos_neg_label_crop_centers", + "apply_transform", + "create_grid", + "create_control_grid", + "create_rotate", + "create_shear", + "create_scale", + "create_translate", + "generate_spatial_bounding_box", + "get_largest_connected_component_mask", + "get_extreme_points", + "extreme_points_to_image", +] + def rand_choice(prob: float = 0.5) -> bool: """ @@ -317,15 +344,16 @@ def _correct_centers( return center_ori centers = [] + fg_indices, bg_indices = np.asarray(fg_indices), np.asarray(bg_indices) + if fg_indices.size == 0 and bg_indices.size == 0: + raise ValueError("No sampling location available.") - if not len(fg_indices) or not len(bg_indices): - if not len(fg_indices) and not len(bg_indices): - raise ValueError("No sampling location available.") + if fg_indices.size == 0 or bg_indices.size == 0: warnings.warn( f"N foreground {len(fg_indices)}, N background {len(bg_indices)}, " "unable to generate class balanced samples." ) - pos_ratio = 0 if not len(fg_indices) else 1 + pos_ratio = 0 if fg_indices.size == 0 else 1 for _ in range(num_samples): indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices @@ -424,7 +452,7 @@ def create_rotate(spatial_dims: int, radians: Union[Sequence[float], float]) -> return np.array([[cos_, -sin_, 0.0], [sin_, cos_, 0.0], [0.0, 0.0, 1.0]]) raise ValueError("radians must be non-empty.") - elif spatial_dims == 3: + if spatial_dims == 3: affine = None if len(radians) >= 1: sin_, cos_ = np.sin(radians[0]), np.cos(radians[0]) @@ -463,7 +491,7 @@ def create_shear(spatial_dims: int, coefs: Union[Sequence[float], float]) -> np.
if spatial_dims == 2: coefs = ensure_tuple_size(coefs, dim=2, pad_val=0.0) return np.array([[1, coefs[0], 0.0], [coefs[1], 1.0, 0.0], [0.0, 0.0, 1.0]]) - elif spatial_dims == 3: + if spatial_dims == 3: coefs = ensure_tuple_size(coefs, dim=6, pad_val=0.0) return np.array( [ diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index d2d3e41d67..e5567f9f16 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,9 +10,59 @@ # limitations under the License. # have to explicitly bring these in here to resolve circular import issues -from .aliases import * -from .decorators import * -from .enums import * -from .misc import * -from .module import * -from .profiling import * +from .aliases import alias, resolve_name +from .decorators import MethodReplacer, RestartGenerator +from .enums import ( + Activation, + Average, + BlendMode, + ChannelMatching, + GridSampleMode, + GridSamplePadMode, + InterpolateMode, + LossReduction, + Method, + MetricReduction, + Normalization, + NumpyPadMode, + PytorchPadMode, + SkipMode, + UpsampleMode, + Weight, +) +from .misc import ( + MAX_SEED, + copy_to_device, + dtype_numpy_to_torch, + dtype_torch_to_numpy, + ensure_tuple, + ensure_tuple_rep, + ensure_tuple_size, + fall_back_tuple, + first, + get_seed, + is_scalar, + is_scalar_tensor, + issequenceiterable, + list_to_dict, + progress_bar, + set_determinism, + star_zip_with, + zip_with, +) +from .module import ( + PT_BEFORE_1_7, + InvalidPyTorchVersionError, + OptionalImportError, + exact_version, + export, + get_full_type_name, + get_package_version, + get_torch_version_tuple, + has_option, + load_submodules, + min_version, + optional_import, +) +from .profiling import PerfContext, torch_profiler_full, torch_profiler_time_cpu_gpu, torch_profiler_time_end_to_end +from .state_cacher import StateCacher diff --git a/monai/utils/aliases.py b/monai/utils/aliases.py index 224cfe5c4b..e8192897b8 100644 --- a/monai/utils/aliases.py +++ b/monai/utils/aliases.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,6 +21,8 @@ alias_lock = threading.RLock() GlobalAliases = {} +__all__ = ["alias", "resolve_name"] + def alias(*names): """ @@ -58,7 +60,8 @@ def resolve_name(name): with alias_lock: obj = GlobalAliases.get(name, None) - assert name not in GlobalAliases or obj is not None + if name in GlobalAliases and obj is None: + raise AssertionError # attempt to resolve a qualified name if obj is None and "." in name: @@ -89,8 +92,7 @@ def resolve_name(name): modnames = [m.__name__ for m in foundmods] msg = f"Multiple modules ({modnames!r}) with declaration name {name!r} found, resolution is ambiguous." 
raise ValueError(msg) - else: - mods = list(foundmods) + mods = list(foundmods) obj = getattr(mods[0], name) diff --git a/monai/utils/decorators.py b/monai/utils/decorators.py index 35a594d077..1931d703c9 100644 --- a/monai/utils/decorators.py +++ b/monai/utils/decorators.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,6 +11,8 @@ from functools import wraps +__all__ = ["RestartGenerator", "MethodReplacer"] + class RestartGenerator: """ @@ -25,7 +27,7 @@ def __iter__(self): return self.create_gen() -class MethodReplacer(object): +class MethodReplacer: """ Base class for method decorators which can be used to replace methods passed to replace_method() with wrapped versions. """ diff --git a/monai/utils/enums.py b/monai/utils/enums.py index dbebbe364f..d1d2d3bcce 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,6 +11,25 @@ from enum import Enum +__all__ = [ + "NumpyPadMode", + "GridSampleMode", + "InterpolateMode", + "UpsampleMode", + "BlendMode", + "PytorchPadMode", + "GridSamplePadMode", + "Average", + "MetricReduction", + "LossReduction", + "Weight", + "Normalization", + "Activation", + "ChannelMatching", + "SkipMode", + "Method", +] + class NumpyPadMode(Enum): """ @@ -144,7 +163,7 @@ class Weight(Enum): UNIFORM = "uniform" -class Normalisation(Enum): +class Normalization(Enum): """ See also: - :py:class:`monai.networks.nets.ConvNormActi` diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 020884bbcc..2b31392a46 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,15 +10,39 @@ # limitations under the License.
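For downstream code, the ``Normalisation`` to ``Normalization`` rename above is a plain find-and-replace; a hedged sketch (the member names are assumed from the surrounding MONAI source, since this hunk does not show them):

from monai.utils import Normalization  # formerly Normalisation
# e.g. Normalization("batch") — assuming a BATCH member, as consumed by ConvNormActi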
import collections.abc +import inspect import itertools import random +import types +import warnings from ast import literal_eval from distutils.util import strtobool -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Sequence, Tuple, Union, cast import numpy as np import torch +__all__ = [ + "zip_with", + "star_zip_with", + "first", + "issequenceiterable", + "ensure_tuple", + "ensure_tuple_size", + "ensure_tuple_rep", + "fall_back_tuple", + "is_scalar_tensor", + "is_scalar", + "progress_bar", + "get_seed", + "set_determinism", + "list_to_dict", + "dtype_torch_to_numpy", + "dtype_numpy_to_torch", + "MAX_SEED", + "copy_to_device", +] + _seed = None _flag_deterministic = torch.backends.cudnn.deterministic _flag_cudnn_benchmark = torch.backends.cudnn.benchmark @@ -100,7 +124,7 @@ def ensure_tuple_rep(tup: Any, dim: int) -> Tuple[Any, ...]: """ if not issequenceiterable(tup): return (tup,) * dim - elif len(tup) == dim: + if len(tup) == dim: return tuple(tup) raise ValueError(f"Sequence must have length {dim}, got {len(tup)}.") @@ -245,7 +269,7 @@ def _parse_var(s): value = items[1].strip(" \n\r\t'") return key, value - d = dict() + d = {} if items: for item in items: key, value = _parse_var(item) @@ -286,3 +310,40 @@ def dtype_torch_to_numpy(dtype): def dtype_numpy_to_torch(dtype): """Convert a numpy dtype to its torch equivalent.""" return _np_to_torch_dtype[dtype] + + +def copy_to_device( + obj: Any, + device: Optional[Union[str, torch.device]], + non_blocking: bool = True, + verbose: bool = False, +) -> Any: + """ + Copy object or tuple/list/dictionary of objects to ``device``. + + Args: + obj: object or tuple/list/dictionary of objects to move to ``device``. + device: move ``obj`` to this device. Can be a string (e.g., ``cpu``, ``cuda``, + ``cuda:0``, etc.) or of type ``torch.device``. + non_blocking: when `True`, moves data to device asynchronously if + possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. + verbose: when `True`, will print a warning for any elements of incompatible type + not copied to ``device``. + Returns: + Same as input, copied to ``device`` where possible. The original input will be + unchanged. + """ + + if hasattr(obj, "to"): + return obj.to(device, non_blocking=non_blocking) + elif isinstance(obj, tuple): + return tuple(copy_to_device(o, device, non_blocking) for o in obj) + elif isinstance(obj, list): + return [copy_to_device(o, device, non_blocking) for o in obj] + elif isinstance(obj, dict): + return {k: copy_to_device(o, device, non_blocking) for k, o in obj.items()} + elif verbose: + fn_name = cast(types.FrameType, inspect.currentframe()).f_code.co_name + warnings.warn(f"{fn_name} called with incompatible type: " + f"{type(obj)}. Data will be returned unchanged.") + + return obj diff --git a/monai/utils/module.py b/monai/utils/module.py index dfd5fb7d7b..0e11a6531d 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
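For illustration, a minimal sketch of ``copy_to_device`` on a nested batch (the dictionary layout is illustrative; the function is exported from ``monai.utils`` per the ``__init__`` changes above):

import torch
from monai.utils import copy_to_device

batch = {"image": torch.ones(2, 3), "meta": [torch.zeros(1), "subject_01"]}
device = "cuda:0" if torch.cuda.is_available() else "cpu"
moved = copy_to_device(batch, device)
# tensors are moved recursively; the plain string is returned unchanged
# (pass verbose=True to get a warning about such incompatible elements)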
# You may obtain a copy of the License at @@ -73,7 +73,7 @@ def load_submodules(basemod, load_all: bool = True, exclude_pattern: str = "(.*[ if (is_pkg or load_all) and name not in sys.modules and match(exclude_pattern, name) is None: try: mod = import_module(name) - importer.find_module(name).load_module(name) + importer.find_module(name).load_module(name) # type: ignore submodules.append(mod) except OptionalImportError: pass # could not import the optional deps., they are ignored @@ -85,8 +85,7 @@ def get_full_type_name(typeobj): module = typeobj.__module__ if module is None or module == str.__class__.__module__: return typeobj.__name__ # Avoid reporting __builtin__ - else: - return module + "." + typeobj.__name__ + return module + "." + typeobj.__name__ def min_version(the_module, min_version_str: str = "") -> bool: @@ -189,7 +188,8 @@ def optional_import( the_module = import_module(module) if not allow_namespace_pkg: is_namespace = getattr(the_module, "__file__", None) is None and hasattr(the_module, "__path__") - assert not is_namespace + if is_namespace: + raise AssertionError if name: # user specified to load class/function/... from the module the_module = getattr(the_module, name) except Exception as import_exception: # any exceptions during import diff --git a/monai/utils/profiling.py b/monai/utils/profiling.py index bcdc0357c4..695653e897 100644 --- a/monai/utils/profiling.py +++ b/monai/utils/profiling.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py new file mode 100644 index 0000000000..66e9080724 --- /dev/null +++ b/monai/utils/state_cacher.py @@ -0,0 +1,92 @@ +import copy +import os +import tempfile +from typing import Dict, Optional + +import torch + +__all__ = ["StateCacher"] + + +class StateCacher: + """Class to cache and retrieve the state of an object. + + Objects can either be stored in memory or on disk. If stored on disk, they can be + stored in a given directory, or alternatively a temporary location will be used. + + If necessary/possible, restored objects will be returned to their original device. + + Example: + + >>> state_cacher = StateCacher(memory_cache, cache_dir=cache_dir) + >>> state_cacher.store("model", model.state_dict()) + >>> model.load_state_dict(state_cacher.retrieve("model")) + """ + + def __init__( + self, + in_memory: bool, + cache_dir: Optional[str] = None, + allow_overwrite: bool = True, + ) -> None: + """Constructor. + + Args: + in_memory: boolean to determine if the object will be cached in memory or on + disk. + cache_dir: directory for data to be cached if `in_memory==False`. Defaults + to using a temporary directory. Any created files will be deleted during + the `StateCacher`'s destructor. + allow_overwrite: allow the cache to be overwritten. If set to `False`, an + error will be thrown if a matching key already exists in the list of cached + objects.
+ """ + self.in_memory = in_memory + self.cache_dir = cache_dir + self.allow_overwrite = allow_overwrite + + if self.cache_dir is None: + self.cache_dir = tempfile.gettempdir() + else: + if not os.path.isdir(self.cache_dir): + raise ValueError("Given `cache_dir` is not a valid directory.") + + self.cached: Dict[str, Dict] = {} + + def store(self, key, data_obj): + """Store a given object with the given key name.""" + if key in self.cached and not self.allow_overwrite: + raise RuntimeError("Cached key already exists and overwriting is disabled.") + if self.in_memory: + self.cached.update({key: {"obj": copy.deepcopy(data_obj)}}) + else: + fn = os.path.join(self.cache_dir, f"state_{key}_{id(self)}.pt") + self.cached.update({key: {"obj": fn}}) + torch.save(data_obj, fn) + # store object's device if relevant + if hasattr(data_obj, "device"): + self.cached[key]["device"] = data_obj.device + + def retrieve(self, key): + """Retrieve the object stored under a given key name.""" + if key not in self.cached: + raise KeyError(f"Target {key} was not cached.") + + if self.in_memory: + return self.cached[key]["obj"] + + fn = self.cached[key]["obj"] # pytype: disable=attribute-error + if not os.path.exists(fn): # pytype: disable=wrong-arg-types + raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.") + data_obj = torch.load(fn, map_location=lambda storage, location: storage) + # copy back to device if necessary + if "device" in self.cached[key]: + data_obj = data_obj.to(self.cached[key]["device"]) + return data_obj + + def __del__(self): + """If necessary, delete any cached files existing in `cache_dir`.""" + if not self.in_memory: + for k in self.cached: + if os.path.exists(self.cached[k]["obj"]): + os.remove(self.cached[k]["obj"]) diff --git a/monai/visualize/__init__.py b/monai/visualize/__init__.py index 2fbd1dcf66..9ad61fa3f2 100644 --- a/monai/visualize/__init__.py +++ b/monai/visualize/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,5 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .class_activation_maps import * -from .img2tensorboard import * +from .class_activation_maps import CAM, GradCAM, GradCAMpp, ModelWithHooks, default_normalizer +from .img2tensorboard import ( + add_animated_gif, + add_animated_gif_no_channels, + make_animated_gif_summary, + plot_2d_or_3d_image, +) +from .occlusion_sensitivity import OcclusionSensitivity +from .visualizer import default_upsampler diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index 6fd29d1c96..a917bcf800 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
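A sketch of the disk-backed mode of the ``StateCacher`` defined above, complementing the in-memory example in its docstring; the model is a stand-in:

import torch.nn as nn
from monai.utils import StateCacher

model = nn.Linear(2, 2)
cacher = StateCacher(in_memory=False)  # files go to tempfile.gettempdir() by default
cacher.store("init", model.state_dict())
# ... training steps mutate the model ...
model.load_state_dict(cacher.retrieve("init"))  # roll the weights back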
# You may obtain a copy of the License at @@ -14,12 +14,28 @@ import numpy as np import torch +import torch.nn as nn import torch.nn.functional as F +from monai.networks.utils import eval_mode, train_mode from monai.transforms import ScaleIntensity -from monai.utils import InterpolateMode, ensure_tuple +from monai.utils import ensure_tuple +from monai.visualize.visualizer import default_upsampler -__all__ = ["ModelWithHooks", "default_upsampler", "default_normalizer", "CAM", "GradCAM", "GradCAMpp"] +__all__ = ["CAM", "GradCAM", "GradCAMpp", "ModelWithHooks", "default_normalizer"] + + +def default_normalizer(x) -> np.ndarray: + """ + A linear intensity scaling by mapping the (min, max) to (1, 0). + + N.B.: This will flip magnitudes (i.e., smallest will become biggest and vice versa). + """ + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + scaler = ScaleIntensity(minv=1.0, maxv=0.0) + x = [scaler(x) for x in x] + return np.stack(x, axis=0) class ModelWithHooks: @@ -101,46 +117,78 @@ def class_score(self, logits, class_idx=None): return logits[:, class_idx].squeeze(), class_idx def __call__(self, x, class_idx=None, retain_graph=False): - logits = self.model(x) - acti, grad = None, None - if self.register_forward: - acti = tuple(self.activations[layer] for layer in self.target_layers) - if self.register_backward: - score, class_idx = self.class_score(logits, class_idx) - self.model.zero_grad() - self.score, self.class_idx = score, class_idx - score.sum().backward(retain_graph=retain_graph) - grad = tuple(self.gradients[layer] for layer in self.target_layers) + # Use train_mode if grad is required, else eval_mode + mode = train_mode if self.register_backward else eval_mode + with mode(self.model): + logits = self.model(x) + acti, grad = None, None + if self.register_forward: + acti = tuple(self.activations[layer] for layer in self.target_layers) + if self.register_backward: + score, class_idx = self.class_score(logits, class_idx) + self.model.zero_grad() + self.score, self.class_idx = score, class_idx + score.sum().backward(retain_graph=retain_graph) + grad = tuple(self.gradients[layer] for layer in self.target_layers) return logits, acti, grad + def get_wrapped_net(self): + return self.model + -def default_upsampler(spatial_size) -> Callable[[torch.Tensor], torch.Tensor]: +class CAMBase: """ - A linear interpolation method for upsampling the feature map. - The output of this function is a callable `func`, - such that `func(activation_map)` returns an upsampled tensor. + Base class for CAM methods. """ - def up(acti_map): - linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR] - interp_mode = linear_mode[len(spatial_size) - 1] - return F.interpolate(acti_map, size=spatial_size, mode=str(interp_mode.value), align_corners=False) + def __init__( + self, + nn_module: nn.Module, + target_layers: str, + upsampler: Callable = default_upsampler, + postprocessing: Callable = default_normalizer, + register_backward: bool = True, + ) -> None: + # Convert to model with hooks if necessary + if not isinstance(nn_module, ModelWithHooks): + self.nn_module = ModelWithHooks( + nn_module, target_layers, register_forward=True, register_backward=register_backward + ) + else: + self.nn_module = nn_module + + self.upsampler = upsampler + self.postprocessing = postprocessing + + def feature_map_size(self, input_size, device="cpu", layer_idx=-1): + """ + Computes the actual feature map size given `nn_module` and the target_layer name. 
+ Args: + input_size: shape of the input tensor + device: the device used to initialise the input tensor + layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. + Returns: + shape of the actual feature map. + """ + return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape - return up + def compute_map(self, x, class_idx=None, layer_idx=-1): + raise NotImplementedError() + def _upsample_and_post_process(self, acti_map, x): + # upsampling and postprocessing + if self.upsampler: + img_spatial = x.shape[2:] + acti_map = self.upsampler(img_spatial)(acti_map) + if self.postprocessing: + acti_map = self.postprocessing(acti_map) + return acti_map -def default_normalizer(acti_map) -> np.ndarray: - """ - A linear intensity scaling by mapping the (min, max) to (1, 0). - """ - if isinstance(acti_map, torch.Tensor): - acti_map = acti_map.detach().cpu().numpy() - scaler = ScaleIntensity(minv=1.0, maxv=0.0) - acti_map = [scaler(x) for x in acti_map] - return np.stack(acti_map, axis=0) + def __call__(self): + raise NotImplementedError() -class CAM: +class CAM(CAMBase): """ Compute class activation map from the last fully-connected layers before the spatial pooling. @@ -172,83 +220,66 @@ class CAM: def __init__( self, - nn_module, + nn_module: nn.Module, target_layers: str, fc_layers: Union[str, Callable] = "fc", - upsampler=default_upsampler, + upsampler: Callable = default_upsampler, postprocessing: Callable = default_normalizer, - ): + ) -> None: """ - Args: - nn_module: the model to be visualised + nn_module: the model to be visualized target_layers: name of the model layer to generate the feature map. fc_layers: a string or a callable used to get fully-connected weights to compute activation map from the target_layers (without pooling). and evaluate it at every spatial location. - upsampler: an upsampling method to upsample the feature map. - postprocessing: a callable that applies on the upsampled feature map. + upsampler: An upsampling method to upsample the output image. Default is + N dimensional linear (bilinear, trilinear, etc.) depending on num spatial + dimensions of input. + postprocessing: a callable that applies on the upsampled output image. + Default is normalizing between min=1 and max=0 (i.e., largest input will become 0 and + smallest input will become 1). """ - if not isinstance(nn_module, ModelWithHooks): - self.net = ModelWithHooks(nn_module, target_layers, register_forward=True) - else: - self.net = nn_module - self.upsampler = upsampler - self.postprocessing = postprocessing + super().__init__( + nn_module=nn_module, + target_layers=target_layers, + upsampler=upsampler, + postprocessing=postprocessing, + register_backward=False, + ) self.fc_layers = fc_layers def compute_map(self, x, class_idx=None, layer_idx=-1): """ Compute the actual feature map with input tensor `x`. 
""" - logits, acti, _ = self.net(x) + logits, acti, _ = self.nn_module(x) acti = acti[layer_idx] if class_idx is None: class_idx = logits.max(1)[-1] b, c, *spatial = acti.shape acti = torch.split(acti.reshape(b, c, -1), 1, dim=2) # make the spatial dims 1D - fc_layers = self.net.get_layer(self.fc_layers) + fc_layers = self.nn_module.get_layer(self.fc_layers) output = torch.stack([fc_layers(a[..., 0]) for a in acti], dim=2) output = torch.stack([output[i, b : b + 1] for i, b in enumerate(class_idx)], dim=0) return output.reshape(b, 1, *spatial) # resume the spatial dims on the selected class - def feature_map_size(self, input_size, device="cpu", layer_idx=-1): - """ - Computes the actual feature map size given `nn_module` and the target_layer name. - - Args: - input_size: shape of the input tensor - device: the device used to initialise the input tensor - layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. - - Returns: - shape of the actual feature map. - """ - return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape - def __call__(self, x, class_idx=None, layer_idx=-1): """ Compute the activation map with upsampling and postprocessing. Args: x: input tensor, shape must be compatible with `nn_module`. - class_idx: index of the class to be visualised. Default to argmax(logits) + class_idx: index of the class to be visualized. Default to argmax(logits) layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. Returns: activation maps """ acti_map = self.compute_map(x, class_idx, layer_idx) - - # upsampling and postprocessing - if self.upsampler: - img_spatial = x.shape[2:] - acti_map = self.upsampler(img_spatial)(acti_map) - if self.postprocessing: - acti_map = self.postprocessing(acti_map) - return acti_map + return self._upsample_and_post_process(acti_map, x) -class GradCAM: +class GradCAM(CAMBase): """ Computes Gradient-weighted Class Activation Mapping (Grad-CAM). This implementation is based on: @@ -282,54 +313,24 @@ class GradCAM: """ - def __init__(self, nn_module, target_layers: str, upsampler=default_upsampler, postprocessing=default_normalizer): - """ - - Args: - nn_module: the model to be used to generate the visualisations. - target_layers: name of the model layer to generate the feature map. - upsampler: an upsampling method to upsample the feature map. - postprocessing: a callable that applies on the upsampled feature map. - """ - if not isinstance(nn_module, ModelWithHooks): - self.net = ModelWithHooks(nn_module, target_layers, register_forward=True, register_backward=True) - else: - self.net = nn_module - self.upsampler = upsampler - self.postprocessing = postprocessing - def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1): """ Compute the actual feature map with input tensor `x`. """ - logits, acti, grad = self.net(x, class_idx=class_idx, retain_graph=retain_graph) + _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph) acti, grad = acti[layer_idx], grad[layer_idx] b, c, *spatial = grad.shape weights = grad.view(b, c, -1).mean(2).view(b, c, *[1] * len(spatial)) acti_map = (weights * acti).sum(1, keepdim=True) return F.relu(acti_map) - def feature_map_size(self, input_size, device="cpu", layer_idx=-1): - """ - Computes the actual feature map size given `nn_module` and the target_layer name. 
- - Args: - input_size: shape of the input tensor - device: the device used to initialise the input tensor - layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. - - Returns: - shape of the actual feature map. - """ - return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape - def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False): """ Compute the activation map with upsampling and postprocessing. Args: x: input tensor, shape must be compatible with `nn_module`. - class_idx: index of the class to be visualised. Default to argmax(logits) + class_idx: index of the class to be visualized. Default to argmax(logits) layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. retain_graph: whether to retain_graph for torch module backward call. @@ -337,14 +338,7 @@ def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False): activation maps """ acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx) - - # upsampling and postprocessing - if self.upsampler: - img_spatial = x.shape[2:] - acti_map = self.upsampler(img_spatial)(acti_map) - if self.postprocessing: - acti_map = self.postprocessing(acti_map) - return acti_map + return self._upsample_and_post_process(acti_map, x) class GradCAMpp(GradCAM): @@ -365,14 +359,14 @@ def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1): """ Compute the actual feature map with input tensor `x`. """ - logits, acti, grad = self.net(x, class_idx=class_idx, retain_graph=retain_graph) + _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph) acti, grad = acti[layer_idx], grad[layer_idx] b, c, *spatial = grad.shape alpha_nr = grad.pow(2) alpha_dr = alpha_nr.mul(2) + acti.mul(grad.pow(3)).view(b, c, -1).sum(-1).view(b, c, *[1] * len(spatial)) alpha_dr = torch.where(alpha_dr != 0.0, alpha_dr, torch.ones_like(alpha_dr)) alpha = alpha_nr.div(alpha_dr + 1e-7) - relu_grad = F.relu(self.net.score.exp() * grad) + relu_grad = F.relu(self.nn_module.score.exp() * grad) weights = (alpha * relu_grad).view(b, c, -1).sum(-1).view(b, c, *[1] * len(spatial)) acti_map = (weights * acti).sum(1, keepdim=True) return F.relu(acti_map) diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index c11bfcfc99..8f6eca5482 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -40,7 +40,8 @@ def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale scale_factor: amount to multiply values by. 
if the image data is between 0 and 1, using 255 for this value will scale it to displayable range """ - assert len(image.shape) == 3, "3D image tensors expected to be in `HWD` format, len(image.shape) != 3" + if len(image.shape) != 3: + raise AssertionError("3D image tensors expected to be in `HWD` format, len(image.shape) != 3") ims = [(np.asarray((image[:, :, i])) * scale_factor).astype(np.uint8) for i in range(image.shape[2])] ims = [GifImage.fromarray(im) for im in ims] diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py new file mode 100644 index 0000000000..5863614965 --- /dev/null +++ b/monai/visualize/occlusion_sensitivity.py @@ -0,0 +1,318 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +from functools import partial +from typing import Callable, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from monai.networks.utils import eval_mode +from monai.visualize.visualizer import default_upsampler + +try: + from tqdm import trange + + trange = partial(trange, desc="Computing occlusion sensitivity") +except (ImportError, AttributeError): + trange = range + +# For stride two (for example), +# if input array is: |0|1|2|3|4|5|6|7| +# downsampled output is: | 0 | 1 | 2 | 3 | +# So the upsampling should do it by the corners of the image, not their centres +default_upsampler = partial(default_upsampler, align_corners=True) + + +def _check_input_image(image): + """Check that the input image is as expected.""" + # Only accept batch size of 1 + if image.shape[0] > 1: + raise RuntimeError("Expected batch size of 1.") + + +def _check_input_bounding_box(b_box, im_shape): + """Check that the bounding box (if supplied) is as expected.""" + # If no bounding box has been supplied, set min and max to None + if b_box is None: + b_box_min = b_box_max = None + + # Bounding box has been supplied + else: + # Should be twice as many elements in `b_box` as `im_shape` + if len(b_box) != 2 * len(im_shape): + raise ValueError("Bounding box should contain upper and lower for all dimensions (except batch number)") + + # If any min's or max's are -ve, set them to 0 and im_shape-1, respectively. + b_box_min = np.array(b_box[::2]) + b_box_max = np.array(b_box[1::2]) + b_box_min[b_box_min < 0] = 0 + b_box_max[b_box_max < 0] = im_shape[b_box_max < 0] - 1 + # Check all max's are < im_shape + if np.any(b_box_max >= im_shape): + raise ValueError("Max bounding box should be < image size for all values") + # Check all min's are <= max's + if np.any(b_box_min > b_box_max): + raise ValueError("Min bounding box should be <= max for all values") + + return b_box_min, b_box_max + + +def _append_to_sensitivity_ims(model, batch_images, sensitivity_ims): + """Infer given images. Append to previous evaluations. 
Store each class separately.""" + batch_images = torch.cat(batch_images, dim=0) + scores = model(batch_images).detach() + for i in range(scores.shape[1]): + sensitivity_ims[i] = torch.cat((sensitivity_ims[i], scores[:, i])) + return sensitivity_ims + + +def _get_as_np_array(val, numel): + # If not a sequence, then convert scalar to numpy array + if not isinstance(val, Sequence): + out = np.full(numel, val, dtype=np.int32) + out[0] = 1 # mask_size and stride always 1 in channel dimension + else: + # Convert to numpy array and check dimensions match + out = np.array(val, dtype=np.int32) + # Add stride of 1 to the channel direction (since user input was only for spatial dimensions) + out = np.insert(out, 0, 1) + if out.size != numel: + raise ValueError( + "If supplying stride/mask_size as sequence, number of elements should match number of spatial dimensions." + ) + return out + + +class OcclusionSensitivity: + """ + This class computes the occlusion sensitivity for a model's prediction of a given image. By occlusion sensitivity, + we mean how the probability of a given prediction changes as the occluded section of an image changes. This can be + useful to understand why a network is making certain decisions. + + As important parts of the image are occluded, the probability of classifying the image correctly will decrease. + Hence, more negative values imply the corresponding occluded volume was more important in the decision process. + + Two ``torch.Tensor``s will be returned by the ``__call__`` method: an occlusion map and an image of the most probable + class. Both images will be cropped if a bounding box is used, but voxel sizes will always match the input. + + The occlusion map shows the inference probabilities when the corresponding part of the image is occluded. Hence, + more -ve values imply that region was important in the decision process. The map will have shape ``BCHW(D)N``, + where ``N`` is the number of classes to be inferred by the network. Hence, the occlusion for class ``i`` can + be seen with ``map[...,i]``. + + The most probable class is an image of the probable class when the corresponding part of the image is occluded + (equivalent to ``occ_map.argmax(dim=-1)``). + + See: R. R. Selvaraju et al. Grad-CAM: Visual Explanations from Deep Networks via + Gradient-based Localization. https://doi.org/10.1109/ICCV.2017.74. + + Examples: + + ..
code-block:: python + + # densenet 2d + from monai.networks.nets import densenet121 + from monai.visualize import OcclusionSensitivity + + model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3) + occ_sens = OcclusionSensitivity(nn_module=model_2d) + occ_map, most_probable_class = occ_sens(x=torch.rand((1, 1, 48, 64)), b_box=[-1, -1, 2, 40, 1, 62]) + + # densenet 3d + from monai.networks.nets import DenseNet + from monai.visualize import OcclusionSensitivity + + model_3d = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)) + occ_sens = OcclusionSensitivity(nn_module=model_3d, n_batch=10, stride=2) + occ_map, most_probable_class = occ_sens(torch.rand(1, 1, 6, 6, 6), b_box=[-1, -1, 2, 3, -1, -1, -1, -1]) + + See Also: + + - :py:class:`monai.visualize.occlusion_sensitivity.OcclusionSensitivity` + """ + + def __init__( + self, + nn_module: nn.Module, + pad_val: Optional[float] = None, + mask_size: Union[int, Sequence] = 15, + n_batch: int = 128, + stride: Union[int, Sequence] = 1, + upsampler: Optional[Callable] = default_upsampler, + verbose: bool = True, + ) -> None: + """Occlusion sensitivity constructor. + + Args: + nn_module: Classification model to use for inference + pad_val: When occluding part of the image, which values should we put + in the image? If ``None`` is used, then the average of the image will be used. + mask_size: Size of box to be occluded, centred on the central voxel. To ensure that the occluded area + is correctly centred, ``mask_size`` and ``stride`` should both be odd or even. + n_batch: Number of images in a batch for inference. + stride: Stride in spatial directions for performing occlusions. Can be single + value or sequence (for varying stride in the different directions). + Should be >= 1. Striding in the channel direction will always be 1. + upsampler: An upsampling method to upsample the output image. Default is + N-dimensional linear (bilinear, trilinear, etc.) depending on num spatial + dimensions of input. + verbose: Use ``tqdm.trange`` output (if available). + """ + + self.nn_module = nn_module + self.upsampler = upsampler + self.pad_val = pad_val + self.mask_size = mask_size + self.n_batch = n_batch + self.stride = stride + self.verbose = verbose + + def _compute_occlusion_sensitivity(self, x, b_box): + + # Get bounding box + im_shape = np.array(x.shape[1:]) + b_box_min, b_box_max = _check_input_bounding_box(b_box, im_shape) + + # Get the number of prediction classes + num_classes = self.nn_module(x).numel() + + # If pad_val is not supplied, use the mean of the image + pad_val = x.mean() if self.pad_val is None else self.pad_val + + # List containing a batch of images to be inferred + batch_images = [] + + # List of sensitivity images, one for each inferred class + sensitivity_ims = num_classes * [torch.empty(0, dtype=torch.float32, device=x.device)] + + # If no bounding box supplied, output shape is same as input shape. + # If bounding box is present, shape is max - min + 1 + output_im_shape = im_shape if b_box is None else b_box_max - b_box_min + 1 + + # Get the stride and mask_size as numpy arrays + self.stride = _get_as_np_array(self.stride, len(im_shape)) + self.mask_size = _get_as_np_array(self.mask_size, len(im_shape)) + + # For each dimension, ...
+ for o, s in zip(output_im_shape, self.stride): + # if the size is > 1, then check that the stride is a factor of the output image shape + if o > 1 and o % s != 0: + raise ValueError( + "Stride should be a factor of the image shape. Im shape " + + f"(taking bounding box into account): {output_im_shape}, stride: {self.stride}" + ) + + # to keep the occluded area centred, mask_size and stride must be both odd or both even (element-wise) + if np.any(self.mask_size % 2 != self.stride % 2): + raise ValueError( + "Stride and mask size should both be odd or even (element-wise). " + + f"``stride={self.stride}``, ``mask_size={self.mask_size}``" + ) + + downsampled_im_shape = (output_im_shape / self.stride).astype(np.int32) + downsampled_im_shape[downsampled_im_shape == 0] = 1 # make sure dimension sizes are >= 1 + num_required_predictions = np.prod(downsampled_im_shape) + + # Get bottom left and top right corners of occluded region + lower_corner = (self.stride - self.mask_size) // 2 + upper_corner = (self.stride + self.mask_size) // 2 + + # Loop 1D over image + verbose_range = trange if self.verbose else range + for i in verbose_range(num_required_predictions): + # Get corresponding ND index + idx = np.unravel_index(i, downsampled_im_shape) + # Multiply by stride + idx *= self.stride + # If a bounding box is being used, we need to add on + # the min to shift to start of region of interest + if b_box_min is not None: + idx += b_box_min + + # Get min and max index of box to occlude (and make sure it's in bounds) + min_idx = np.maximum(idx + lower_corner, 0) + max_idx = np.minimum(idx + upper_corner, im_shape) + + # Clone and replace target area with `pad_val` + occlu_im = x.detach().clone() + occlu_im[(...,) + tuple(slice(i, j) for i, j in zip(min_idx, max_idx))] = pad_val + + # Add to list + batch_images.append(occlu_im) + + # Once the batch is complete (or on last iteration) + if len(batch_images) == self.n_batch or i == num_required_predictions - 1: + # Do the predictions and append to sensitivity maps + sensitivity_ims = _append_to_sensitivity_ims(self.nn_module, batch_images, sensitivity_ims) + # Clear lists + batch_images = [] + + # Reshape to match downsampled image, and unsqueeze to add batch dimension back in + for i in range(num_classes): + sensitivity_ims[i] = sensitivity_ims[i].reshape(tuple(downsampled_im_shape)).unsqueeze(0) + + return sensitivity_ims, output_im_shape + + def __call__( # type: ignore + self, + x: torch.Tensor, + b_box: Optional[Sequence] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: Image to use for inference. Should be a tensor consisting of 1 batch. + b_box: Bounding box on which to perform the analysis. The output image will be limited to this size. + There should be a minimum and maximum for all dimensions except batch: ``[min1, max1, min2, max2,...]``. + * By default, the whole image will be used. Decreasing the size will speed the analysis up, which might + be useful for larger images. + * Min and max are inclusive, so ``[0, 63, ...]`` will have size ``(64, ...)``. + * Use -ve to use ``min=0`` and ``max=im.shape[x]-1`` for the xth dimension. + + Returns: + * Occlusion map: + * Shows the inference probabilities when the corresponding part of the image is occluded. + Hence, more -ve values imply that region was important in the decision process. + * The map will have shape ``BCHW(D)N``, where N is the number of classes to be inferred by the + network. Hence, the occlusion for class ``i`` can be seen with ``map[...,i]``.
+ * Most probable class: + * The most probable class when the corresponding part of the image is occluded (``argmax(dim=-1)``). + Both images will be cropped if a bounding box is used, but voxel sizes will always match the input. + """ + + with eval_mode(self.nn_module): + + # Check input arguments + _check_input_image(x) + + # Generate sensitivity images + sensitivity_ims_list, output_im_shape = self._compute_occlusion_sensitivity(x, b_box) + + # Loop over image for each classification + for i in range(len(sensitivity_ims_list)): + + # upsample + if self.upsampler is not None: + if len(sensitivity_ims_list[i].shape) != len(x.shape): + raise AssertionError + if np.any(sensitivity_ims_list[i].shape != x.shape): + img_spatial = tuple(output_im_shape[1:]) + sensitivity_ims_list[i] = self.upsampler(img_spatial)(sensitivity_ims_list[i]) + + # Convert list of tensors to tensor + sensitivity_ims = torch.stack(sensitivity_ims_list, dim=-1) + + # The most probable class is the max in the classification dimension (last) + most_probable_class = sensitivity_ims.argmax(dim=-1) + + return sensitivity_ims, most_probable_class diff --git a/monai/visualize/visualizer.py b/monai/visualize/visualizer.py new file mode 100644 index 0000000000..bbb01f5c5e --- /dev/null +++ b/monai/visualize/visualizer.py @@ -0,0 +1,36 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable + +import torch +import torch.nn.functional as F + +from monai.utils import InterpolateMode + +__all__ = ["default_upsampler"] + + +def default_upsampler(spatial_size, align_corners=False) -> Callable[[torch.Tensor], torch.Tensor]: + """ + A linear interpolation method for upsampling the feature map. + The output of this function is a callable `func`, + such that `func(x)` returns an upsampled tensor. + """ + + def up(x): + + linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR] + interp_mode = linear_mode[len(spatial_size) - 1] + return F.interpolate(x, size=spatial_size, mode=str(interp_mode.value), align_corners=align_corners) + + return up diff --git a/runtests.sh b/runtests.sh index 33d5c73a90..76692e731b 100755 --- a/runtests.sh +++ b/runtests.sh @@ -1,6 +1,6 @@ #! /bin/bash -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
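A quick sketch of the ``default_upsampler`` factory added above; shapes are illustrative. With two spatial dims it resolves to bilinear interpolation, and occlusion sensitivity re-wraps it with ``align_corners=True`` as explained earlier:

import torch
from monai.visualize.visualizer import default_upsampler

feature_map = torch.rand(1, 1, 4, 4)  # B, C, H, W
up = default_upsampler((64, 64))      # returns a callable for the target spatial size
print(up(feature_map).shape)          # torch.Size([1, 1, 64, 64])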
# You may obtain a copy of the License at diff --git a/setup.cfg b/setup.cfg index 78cf8db6ca..aff62045e1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,7 @@ [metadata] name = monai author = MONAI Consortium -author_email = monai.miccai2019@gmail.com +author_email = monai.contact@gmail.com url = https://monai.io/ description = AI Toolkit for Healthcare Imaging long_description = file:README.md @@ -61,7 +61,7 @@ max-line-length = 120 # C408 ignored because we like the dict keyword argument syntax # E501 is not flexible enough, we're using B950 instead ignore = - E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, + E203,E305,E402,E501,E721,E741,F821,F841,F999,W503,W504,C408,E302,W291,E303, # N812 lowercase 'torch.nn.functional' imported as non lowercase 'F' N812 per-file-ignores = __init__.py: F401 diff --git a/setup.py b/setup.py index b9a1caa5ba..9b20df845a 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -38,7 +38,8 @@ BUILD_CUDA = (torch.cuda.is_available() and (CUDA_HOME is not None)) or FORCE_CUDA _pt_version = pkg_resources.parse_version(torch.__version__).release # type: ignore[attr-defined] - assert _pt_version is not None and len(_pt_version) >= 3, "unknown torch version" + if _pt_version is None or len(_pt_version) < 3: + raise AssertionError("unknown torch version") TORCH_VERSION = int(_pt_version[0]) * 10000 + int(_pt_version[1]) * 100 + int(_pt_version[2]) except (ImportError, TypeError, AssertionError, AttributeError) as e: TORCH_VERSION = 0 @@ -62,9 +63,9 @@ def torch_parallel_backend(): backend = match.group("backend") if backend == "OpenMP": return "AT_PARALLEL_OPENMP" - elif backend == "native thread pool": + if backend == "native thread pool": return "AT_PARALLEL_NATIVE" - elif backend == "native thread pool and TBB": + if backend == "native thread pool and TBB": return "AT_PARALLEL_NATIVE_TBB" except (NameError, AttributeError): # no torch or no binaries warnings.warn("Could not determine torch parallel_info.") diff --git a/tests/__init__.py b/tests/__init__.py index dbe9fe114c..5093d1f72d 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/clang_format_utils.py b/tests/clang_format_utils.py index 71db38fde8..41902eb272 100644 --- a/tests/clang_format_utils.py +++ b/tests/clang_format_utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,8 +34,8 @@ # This dictionary maps each platform to a relative path to a file containing its reference hash. 
# github/pytorch/pytorch/tree/63d62d3e44a0a4ec09d94f30381d49b78cc5b095/tools/clang_format_hash PLATFORM_TO_HASH = { - "Darwin": "020c7f38f14665c2ed82f3e8976c9074c2cfac0a", - "Linux": "d1365110da598d148d8143a7f2ccfd8bac7df499", + "Darwin": "b24cc8972344c4e01afbbae78d6a414f7638ff6f", + "Linux": "9073602de1c4e1748f2feea5a0782417b20e3043", } # Directory and file paths for the clang-format binary. diff --git a/tests/min_tests.py b/tests/min_tests.py index ccfc789992..665ead6cc6 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -67,13 +67,9 @@ def run_testsuit(): "test_lmdbdataset", "test_load_image", "test_load_imaged", - "test_load_nifti", - "test_load_niftid", - "test_load_png", - "test_load_pngd", "test_load_spacing_orientation", "test_mednistdataset", - "test_nifti_dataset", + "test_image_dataset", "test_nifti_header_revise", "test_nifti_rw", "test_nifti_saver", @@ -101,7 +97,12 @@ def run_testsuit(): "test_zoom", "test_zoom_affine", "test_zoomd", - "test_compute_occlusion_sensitivity", + "test_occlusion_sensitivity", + "test_torchvision", + "test_torchvisiond", + "test_handler_metrics_saver", + "test_handler_metrics_saver_dist", + "test_evenly_divisible_all_gather_dist", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" @@ -134,4 +135,4 @@ def run_testsuit(): # testing all modules test_runner = unittest.TextTestRunner(stream=sys.stdout, verbosity=2) result = test_runner.run(run_testsuit()) - exit(int(not result.wasSuccessful())) + sys.exit(int(not result.wasSuccessful())) diff --git a/tests/runner.py b/tests/runner.py index 4c249535bf..b5d1de5fc1 100644 --- a/tests/runner.py +++ b/tests/runner.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,7 +18,7 @@ from monai.utils import PerfContext -results: dict = dict() +results: dict = {} class TimeLoggingTestResult(unittest.TextTestResult): @@ -26,7 +26,7 @@ class TimeLoggingTestResult(unittest.TextTestResult): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.timed_tests = dict() + self.timed_tests = {} def startTest(self, test): # noqa: N802 """Start timer, print test name, do normal test.""" diff --git a/tests/test_activations.py b/tests/test_activations.py index 1bcc73e15d..1614642d6d 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -32,7 +32,7 @@ ] TEST_CASE_3 = [ - {"sigmoid": False, "softmax": False, "other": lambda x: torch.tanh(x)}, + {"sigmoid": False, "softmax": False, "other": torch.tanh}, torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]]), torch.tensor([[[[0.0000, 0.7616], [0.9640, 0.9951]]]]), (1, 1, 2, 2), diff --git a/tests/test_activationsd.py b/tests/test_activationsd.py index 9285ee2b1c..f186c17716 100644 --- a/tests/test_activationsd.py +++ b/tests/test_activationsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -27,7 +27,7 @@ ] TEST_CASE_2 = [ - {"keys": ["pred", "label"], "sigmoid": False, "softmax": False, "other": [lambda x: torch.tanh(x), None]}, + {"keys": ["pred", "label"], "sigmoid": False, "softmax": False, "other": [torch.tanh, None]}, {"pred": torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]]), "label": torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]])}, { "pred": torch.tensor([[[[0.0000, 0.7616], [0.9640, 0.9951]]]]), @@ -37,7 +37,7 @@ ] TEST_CASE_3 = [ - {"keys": "pred", "sigmoid": False, "softmax": False, "other": lambda x: torch.tanh(x)}, + {"keys": "pred", "sigmoid": False, "softmax": False, "other": torch.tanh}, {"pred": torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]])}, {"pred": torch.tensor([[[[0.0000, 0.7616], [0.9640, 0.9951]]]])}, (1, 1, 2, 2), diff --git a/tests/test_adaptors.py b/tests/test_adaptors.py index 68fe6b687f..9bcd01feb7 100644 --- a/tests/test_adaptors.py +++ b/tests/test_adaptors.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_add_channeld.py b/tests/test_add_channeld.py index 101d9635cb..ca4af37271 100644 --- a/tests/test_add_channeld.py +++ b/tests/test_add_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_add_extreme_points_channel.py b/tests/test_add_extreme_points_channel.py index f4f3fa6d02..ecf2c83d3c 100644 --- a/tests/test_add_extreme_points_channel.py +++ b/tests/test_add_extreme_points_channel.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_add_extreme_points_channeld.py b/tests/test_add_extreme_points_channeld.py index 4fee176b20..e33bb0838c 100644 --- a/tests/test_add_extreme_points_channeld.py +++ b/tests/test_add_extreme_points_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/tests/test_adjust_contrast.py b/tests/test_adjust_contrast.py index b84f379153..8e78698360 100644 --- a/tests/test_adjust_contrast.py +++ b/tests/test_adjust_contrast.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_adjust_contrastd.py b/tests/test_adjust_contrastd.py index 0b6d59b71d..65647607e4 100644 --- a/tests/test_adjust_contrastd.py +++ b/tests/test_adjust_contrastd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_adn.py b/tests/test_adn.py index 71ac286b03..2130ebc005 100644 --- a/tests/test_adn.py +++ b/tests/test_adn.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_affine.py b/tests/test_affine.py index 755e0cf0c0..fbda818437 100644 --- a/tests/test_affine.py +++ b/tests/test_affine.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_affine_grid.py b/tests/test_affine_grid.py index 8a24501f22..c7caae29b4 100644 --- a/tests/test_affine_grid.py +++ b/tests/test_affine_grid.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_affine_transform.py b/tests/test_affine_transform.py index ded37dce18..c3dc9cc6ef 100644 --- a/tests/test_affine_transform.py +++ b/tests/test_affine_transform.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_ahnet.py b/tests/test_ahnet.py index 78d2cebac3..777e2637a7 100644 --- a/tests/test_ahnet.py +++ b/tests/test_ahnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks import FCN, MCFCN from monai.networks.nets import AHNet from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save @@ -127,8 +128,7 @@ class TestFCN(unittest.TestCase): @skip_if_quick def test_fcn_shape(self, input_param, input_shape, expected_shape): net = FCN(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -138,8 +138,7 @@ class TestFCNWithPretrain(unittest.TestCase): @skip_if_quick def test_fcn_shape(self, input_param, input_shape, expected_shape): net = test_pretrained_networks(FCN, input_param, device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -148,8 +147,7 @@ class TestMCFCN(unittest.TestCase): @parameterized.expand([TEST_CASE_MCFCN_1, TEST_CASE_MCFCN_2, TEST_CASE_MCFCN_3]) def test_mcfcn_shape(self, input_param, input_shape, expected_shape): net = MCFCN(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -158,8 +156,7 @@ class TestMCFCNWithPretrain(unittest.TestCase): @parameterized.expand([TEST_CASE_MCFCN_WITH_PRETRAIN_1, TEST_CASE_MCFCN_WITH_PRETRAIN_2]) def test_mcfcn_shape(self, input_param, input_shape, expected_shape): net = test_pretrained_networks(MCFCN, input_param, device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -174,8 +171,7 @@ class TestAHNET(unittest.TestCase): ) def test_ahnet_shape_2d(self, input_param, input_shape, expected_shape): net = AHNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -189,16 +185,20 @@ def test_ahnet_shape_2d(self, input_param, input_shape, expected_shape): @skip_if_quick def test_ahnet_shape_3d(self, input_param, input_shape, expected_shape): net = AHNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @skip_if_quick def test_script(self): + # test 2D network net = AHNet(spatial_dims=2, out_channels=2) test_data = torch.randn(1, 1, 128, 64) test_script_save(net, test_data) + # test 3D network + net = AHNet(spatial_dims=3, out_channels=2, psp_block_num=0, upsample_mode="nearest") + test_data = torch.randn(1, 1, 32, 32, 64) + test_script_save(net, test_data) class TestAHNETWithPretrain(unittest.TestCase): @@ -213,8 +213,7 @@ def test_ahnet_shape(self, input_param, input_shape, expected_shape, fcn_input_p net = AHNet(**input_param).to(device) net2d = FCN(**fcn_input_param).to(device) net.copy_from(net2d) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -230,7 +229,7 @@ def test_initialize_pretrained(self): progress=True, ).to(device) input_data = torch.randn(2, 2, 32, 32, 64).to(device) - with torch.no_grad(): + with 
eval_mode(net): result = net.forward(input_data) self.assertEqual(result.shape, (2, 3, 32, 32, 64)) diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py index d5112d4200..f6459cc88c 100644 --- a/tests/test_arraydataset.py +++ b/tests/test_arraydataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index 6fa2df30fe..e7d9866ae1 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_as_channel_firstd.py b/tests/test_as_channel_firstd.py index 584a0021ed..e70c2e1b47 100644 --- a/tests/test_as_channel_firstd.py +++ b/tests/test_as_channel_firstd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_as_channel_last.py b/tests/test_as_channel_last.py index 35fafa29d7..6ec6c8d6e6 100644 --- a/tests/test_as_channel_last.py +++ b/tests/test_as_channel_last.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_as_channel_lastd.py b/tests/test_as_channel_lastd.py index 198d72cfac..2ef4dd4da1 100644 --- a/tests/test_as_channel_lastd.py +++ b/tests/test_as_channel_lastd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index 46cef21e46..7e3b586cc9 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index f98cb05e8d..0b4c483ac6 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/tests/test_autoencoder.py b/tests/test_autoencoder.py index 86b31e0361..36c04bb94f 100644 --- a/tests/test_autoencoder.py +++ b/tests/test_autoencoder.py @@ -1,8 +1,20 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.layers import Act from monai.networks.nets import AutoEncoder from tests.utils import test_script_save @@ -75,8 +87,7 @@ class TestAutoEncoder(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): net = AutoEncoder(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_basic_unet.py b/tests/test_basic_unet.py index c2494dc2d3..e09e368f7b 100644 --- a/tests/test_basic_unet.py +++ b/tests/test_basic_unet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import BasicUNet from tests.utils import test_script_save @@ -95,8 +96,7 @@ def test_shape(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" print(input_param) net = BasicUNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_bending_energy.py b/tests/test_bending_energy.py new file mode 100644 index 0000000000..f2b9a41cae --- /dev/null +++ b/tests/test_bending_energy.py @@ -0,0 +1,79 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
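A recurring change in the test diffs above swaps the `net.eval()` plus `torch.no_grad()` pair for the `eval_mode` context manager from `monai.networks`. A rough sketch of the behaviour it bundles, assuming (as the replaced code implies) that it also restores the previous training state on exit:

    from contextlib import contextmanager

    import torch

    @contextmanager
    def eval_mode_sketch(net: torch.nn.Module):
        was_training = net.training
        try:
            net.eval()  # disable dropout, use running batch-norm statistics
            with torch.no_grad():  # skip autograd bookkeeping during inference
                yield net
        finally:
            if was_training:
                net.train()  # restore the caller's training state

    # usage mirrors the updated tests:
    net = torch.nn.Linear(4, 2)
    with eval_mode_sketch(net):
        out = net(torch.randn(1, 4))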
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses.deform import BendingEnergyLoss + +TEST_CASES = [ + [ + {}, + {"pred": torch.ones((1, 3, 5, 5, 5))}, + 0.0, + ], + [ + {}, + {"pred": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, + 0.0, + ], + [ + {}, + {"pred": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 4.0, + ], + [ + {}, + {"pred": torch.arange(0, 5)[None, None, None, :].expand(1, 3, 5, 5) ** 2}, + 4.0, + ], + [ + {}, + {"pred": torch.arange(0, 5)[None, None, :].expand(1, 3, 5) ** 2}, + 4.0, + ], +] + + +class TestBendingEnergy(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = BendingEnergyLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) + + def test_ill_shape(self): + loss = BendingEnergyLoss() + # not in 3-d, 4-d, 5-d + with self.assertRaisesRegex(ValueError, ""): + loss.forward(torch.ones((1, 3))) + with self.assertRaisesRegex(ValueError, ""): + loss.forward(torch.ones((1, 3, 5, 5, 5, 5))) + # spatial_dim < 5 + with self.assertRaisesRegex(ValueError, ""): + loss.forward(torch.ones((1, 3, 4, 5, 5))) + with self.assertRaisesRegex(ValueError, ""): + loss.forward(torch.ones((1, 3, 5, 4, 5))) + with self.assertRaisesRegex(ValueError, ""): + loss.forward(torch.ones((1, 3, 5, 5, 4))) + + def test_ill_opts(self): + pred = torch.rand(1, 3, 5, 5, 5) + with self.assertRaisesRegex(ValueError, ""): + BendingEnergyLoss(reduction="unknown")(pred) + with self.assertRaisesRegex(ValueError, ""): + BendingEnergyLoss(reduction=None)(pred) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_bilateral_approx_cpu.py b/tests/test_bilateral_approx_cpu.py new file mode 100644 index 0000000000..71cf53519c --- /dev/null +++ b/tests/test_bilateral_approx_cpu.py @@ -0,0 +1,381 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
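The expected values in test_bending_energy.py above can be verified by hand: for `pred` built from `t ** 2` sampled at t = 0..4, the discrete second derivative along that axis is 2 at every interior point, so the mean of its square is 4.0; constant and linear inputs have zero second derivative, hence the expected 0.0. A small numeric check of that reasoning (finite differences only, not the full loss):

    import torch

    pred = torch.arange(0, 5, dtype=torch.float32) ** 2  # [0, 1, 4, 9, 16]
    second_derivative = pred[2:] - 2 * pred[1:-1] + pred[:-2]  # central difference
    print(second_derivative)                 # tensor([2., 2., 2.])
    print((second_derivative ** 2).mean())   # tensor(4.), matching the expected 4.0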
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers.filtering import BilateralFilter +from tests.utils import skip_if_no_cpp_extention + +TEST_CASES = [ + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, low color sigma", + # Spatial and Color Sigmas + (1, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [1.000000, 0.000000, 0.000000, 0.000000, 1.000000] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000000, 1.000000, 0.000000, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.631360, 0.099349, 0.070177, 0.164534, 0.649869] + ], + # Batch 1 + [ + # Channel 0 + [0.052271, 0.173599, 0.481337, 0.183721, 0.045619] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, low color sigma", + # Spatial and Color Sigmas + (4, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [1.000000, 0.000000, 0.000000, 0.000000, 1.000000] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000000, 1.000000, 0.000000, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.497667, 0.268683, 0.265026, 0.261467, 0.495981] + ], + # Batch 1 + [ + # Channel 0 + [0.145959, 0.142282, 0.315710, 0.135609, 0.132572] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 4 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 0], + # Channel 1 + [1, 0, 1, 0, 0], + # Channel 2 + [0, 0, 1, 0, 1], + # Channel 3 + [0, 0, 0, 0, 1], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.960843, 0.073540, 0.027689, 0.002676, 0.000000], + # Channel 1 + [0.960843, 0.073540, 0.951248, 0.003033, 0.000750], + # Channel 2 + [0.000000, 0.000000, 0.923559, 0.000357, 0.981324], + # Channel 3 + [0.000000, 0.000000, 0.000000, 0.000000, 0.980574], + ] + ], + ], + [ + # Case Descirption + "2 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]] + ], + # Batch 1 + [ + # Channel 0 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.213684, 0.094356, 0.092973, 0.091650, 0.216281], + [0.094085, 0.092654, 0.091395, 0.090186, 0.089302], + [0.092436, 0.091150, 0.090008, 0.088896, 0.088897], + [0.090849, 0.089717, 0.088759, 0.087751, 0.088501], + [0.211458, 0.088334, 0.087495, 0.087049, 0.212173], + ] + ], + # Batch 1 + [ + # Channel 0 + [ + [0.033341, 0.031314, 0.029367, 0.027494, 0.025692], + [0.031869, 0.030632, 0.028820, 0.027074, 0.025454], + [0.030455, 0.029628, 0.084257, 
0.026704, 0.025372], + [0.029095, 0.028391, 0.027790, 0.026375, 0.025292], + [0.027786, 0.027197, 0.026692, 0.026181, 0.025213], + ] + ], + ], + ], + [ + # Case Descirption + "2 dimension, 4 channel, high spatial sigma, high color sigma", + # Spatial and Color Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]], + # Channel 1 + [[1, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 1]], + # Channel 2 + [[0, 0, 1, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 0]], + # Channel 3 + [[0, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.244373, 0.014488, 0.036589, 0.014226, 0.024329], + [0.014108, 0.014228, 0.014096, 0.013961, 0.013823], + [0.013574, 0.013757, 0.013836, 0.013699, 0.013558], + [0.013008, 0.013211, 0.013404, 0.013438, 0.013295], + [0.025179, 0.012634, 0.034555, 0.013050, 0.237582], + ], + # Channel 1 + [ + [0.271496, 0.015547, 0.439432, 0.015700, 0.089579], + [0.015252, 0.015702, 0.015779, 0.015859, 0.015940], + [0.015020, 0.015556, 0.015935, 0.016015, 0.016098], + [0.014774, 0.015331, 0.015860, 0.016171, 0.016255], + [0.107384, 0.015094, 0.462471, 0.016166, 0.263480], + ], + # Channel 2 + [ + [0.027123, 0.003527, 0.467273, 0.004912, 0.645776], + [0.003810, 0.004908, 0.005605, 0.006319, 0.007050], + [0.004816, 0.005991, 0.006989, 0.007716, 0.008459], + [0.005880, 0.007060, 0.008179, 0.009101, 0.009858], + [0.633398, 0.008191, 0.496893, 0.010376, 0.025898], + ], + # Channel 3 + [ + [0.000000, 0.002468, 0.064430, 0.003437, 0.580526], + [0.002666, 0.003434, 0.003922, 0.004422, 0.004933], + [0.003370, 0.004192, 0.004890, 0.005399, 0.005919], + [0.004115, 0.004940, 0.005723, 0.006368, 0.006898], + [0.551194, 0.005731, 0.068977, 0.007260, 0.000000], + ], + ] + ], + ], + [ + # Case Descirption + "3 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + # Frame 1 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 2 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 3 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 4 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + ] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0.086801, 0.036670, 0.035971, 0.035304, 0.088456], + [0.036639, 0.035652, 0.035009, 0.034394, 0.033803], + [0.035899, 0.034897, 0.034136, 0.033566, 0.033129], + [0.035180, 0.034238, 0.033413, 0.032811, 0.032577], + [0.088290, 0.033597, 0.032821, 0.032134, 0.088786], + ], + # Frame 1 + [ + [0.036286, 0.035269, 0.034632, 0.034021, 0.033435], + [0.035398, 0.034485, 0.033922, 0.033381, 0.033177], + [0.034688, 0.033822, 0.033169, 0.032664, 0.032780], + [0.034024, 0.033234, 0.032533, 0.032005, 0.032388], + [0.033564, 0.032797, 0.032118, 0.031525, 0.032105], + ], + # Frame 2 + [ + [0.035225, 0.034169, 0.033404, 0.032843, 0.032766], + [0.034383, 0.033487, 0.032908, 0.032415, 0.032650], + [0.033691, 0.032921, 0.032353, 0.031900, 0.032384], + [0.033080, 0.032390, 0.031786, 0.031432, 0.032008], + [0.033099, 0.032373, 0.031737, 0.031479, 0.032054], + ], 
+ # Frame 3
+ [
+ [0.034216, 0.033231, 0.032337, 0.031758, 0.032101],
+ [0.033456, 0.032669, 0.031913, 0.031455, 0.032034],
+ [0.032788, 0.032140, 0.031618, 0.031413, 0.031977],
+ [0.032221, 0.031650, 0.031145, 0.031130, 0.031652],
+ [0.032642, 0.031968, 0.031378, 0.031433, 0.032003],
+ ],
+ # Frame 4
+ [
+ [0.086207, 0.032335, 0.031499, 0.030832, 0.087498],
+ [0.032570, 0.031884, 0.031155, 0.030858, 0.031401],
+ [0.031967, 0.031417, 0.030876, 0.030881, 0.031388],
+ [0.031602, 0.031103, 0.030696, 0.030960, 0.031455],
+ [0.090599, 0.031546, 0.031127, 0.031386, 0.083483],
+ ],
+ ]
+ ]
+ ],
+ ],
+]
+
+
+@skip_if_no_cpp_extention
+class BilateralFilterTestCaseCpuApprox(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_cpu_approx(self, test_case_description, sigmas, input, expected):
+
+ # Params to determine the implementation to test
+ device = torch.device("cpu")
+ fast_approx = True
+
+ # Create input tensor and apply filter
+ input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)
+ output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy()
+
+ # Ensure results are as expected
+ np.testing.assert_allclose(output, expected, atol=1e-5)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_bilateral_approx_cuda.py b/tests/test_bilateral_approx_cuda.py
new file mode 100644
index 0000000000..d0515d60e5
--- /dev/null
+++ b/tests/test_bilateral_approx_cuda.py
@@ -0,0 +1,386 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
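The bilateral-filter tests above and below all drive the layer through the same entry point: `BilateralFilter.apply(tensor, spatial_sigma, color_sigma, fast_approx)` on a batch-and-channel-first tensor. A minimal sketch of that call (the compiled C++/CUDA extension must be available, which is what `skip_if_no_cpp_extention` guards):

    import torch

    from monai.networks.layers.filtering import BilateralFilter

    image = torch.tensor([[[1.0, 0.0, 0.0, 0.0, 1.0]]])  # (batch=1, channel=1, width=5)

    # fast_approx=True selects the approximate implementation exercised above
    smoothed = BilateralFilter.apply(image, 1.0, 0.2, True)
    print(smoothed.shape)  # torch.Size([1, 1, 5])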
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers.filtering import BilateralFilter +from tests.utils import skip_if_no_cpp_extention, skip_if_no_cuda + +TEST_CASES = [ + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, low color sigma", + # Spatial and Color Sigmas + (1, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [1.000000, 0.000000, 0.000000, 0.000000, 1.000000] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000000, 1.000000, 0.000000, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.880626, 0.306148, 0.158734, 0.164534, 0.754386] + ], + # Batch 1 + [ + # Channel 0 + [0.019010, 0.104507, 0.605634, 0.183721, 0.045619] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, low color sigma", + # Spatial and Color Sigmas + (4, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [1.000000, 0.000000, 0.000000, 0.000000, 1.000000] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000000, 1.000000, 0.000000, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.497667, 0.268683, 0.265026, 0.261467, 0.495981] + ], + # Batch 1 + [ + # Channel 0 + [0.149889, 0.148226, 0.367978, 0.144023, 0.141317] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 4 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 0], + # Channel 1 + [1, 0, 1, 0, 0], + # Channel 2 + [0, 0, 1, 0, 1], + # Channel 3 + [0, 0, 0, 0, 1], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.988107, 0.061340, 0.001565, 0.000011, 0.000000], + # Channel 1 + [0.988107, 0.061340, 0.998000, 0.000016, 0.000123], + # Channel 2 + [0.000000, 0.000000, 0.996435, 0.000006, 0.999236], + # Channel 3 + [0.000000, 0.000000, 0.000000, 0.000000, 0.999113], + ] + ], + ], + [ + # Case Descirption + "2 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]] + ], + # Batch 1 + [ + # Channel 0 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.211469, 0.094356, 0.092973, 0.091650, 0.211894], + [0.093755, 0.091753, 0.090524, 0.089343, 0.088384], + [0.091803, 0.089783, 0.088409, 0.087346, 0.086927], + [0.089938, 0.088126, 0.086613, 0.085601, 0.085535], + [0.208359, 0.086535, 0.085179, 0.084210, 0.205858], + ] + ], + # Batch 1 + [ + # Channel 0 + [ + [0.032760, 0.030146, 0.027442, 0.024643, 0.021744], + [0.030955, 0.029416, 0.026574, 0.023629, 0.020841], + [0.028915, 
0.027834, 0.115442, 0.022515, 0.020442], + [0.026589, 0.025447, 0.024319, 0.021286, 0.019964], + [0.023913, 0.022704, 0.021510, 0.020388, 0.019379], + ] + ], + ], + ], + [ + # Case Descirption + "2 dimension, 4 channel, high spatial sigma, high color sigma", + # Spatial and Color Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]], + # Channel 1 + [[1, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 1]], + # Channel 2 + [[0, 0, 1, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 0]], + # Channel 3 + [[0, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.557349, 0.011031, 0.001800, 0.011265, 0.000631], + [0.009824, 0.010361, 0.010429, 0.010506, 0.010595], + [0.008709, 0.009252, 0.009688, 0.009714, 0.009744], + [0.007589, 0.008042, 0.008576, 0.008887, 0.008852], + [0.000420, 0.006827, 0.001048, 0.007763, 0.190722], + ], + # Channel 1 + [ + [0.614072, 0.011045, 0.925766, 0.011287, 0.007548], + [0.009838, 0.010382, 0.010454, 0.010536, 0.010630], + [0.008727, 0.009277, 0.009720, 0.009751, 0.009787], + [0.007611, 0.008071, 0.008613, 0.008932, 0.008904], + [0.027088, 0.006859, 0.950749, 0.007815, 0.230270], + ], + # Channel 2 + [ + [0.056723, 0.000150, 0.973790, 0.000233, 0.990814], + [0.000151, 0.000214, 0.000257, 0.000307, 0.000364], + [0.000186, 0.000257, 0.000328, 0.000384, 0.000449], + [0.000221, 0.000295, 0.000382, 0.000465, 0.000538], + [0.993884, 0.000333, 0.984743, 0.000532, 0.039548], + ], + # Channel 3 + [ + [0.000000, 0.000136, 0.049824, 0.000210, 0.983897], + [0.000136, 0.000193, 0.000232, 0.000277, 0.000329], + [0.000168, 0.000232, 0.000297, 0.000347, 0.000405], + [0.000200, 0.000266, 0.000345, 0.000420, 0.000485], + [0.967217, 0.000301, 0.035041, 0.000481, 0.000000], + ], + ] + ], + ], + [ + # Case Descirption + "3 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + # Frame 1 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 2 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 3 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 4 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + ] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0.085451, 0.037820, 0.036880, 0.035978, 0.084296], + [0.037939, 0.036953, 0.036155, 0.035385, 0.034640], + [0.037167, 0.036302, 0.035603, 0.034931, 0.034465], + [0.036469, 0.035724, 0.035137, 0.034572, 0.034480], + [0.088942, 0.035193, 0.034682, 0.034266, 0.090568], + ], + # Frame 1 + [ + [0.037125, 0.035944, 0.035103, 0.033429, 0.033498], + [0.033380, 0.032653, 0.033748, 0.033073, 0.032549], + [0.034834, 0.034001, 0.033500, 0.032902, 0.032560], + [0.033972, 0.033554, 0.033220, 0.032765, 0.032570], + [0.033590, 0.033222, 0.032927, 0.032689, 0.032629], + ], + # Frame 2 + [ + [0.035635, 0.034468, 0.033551, 0.032818, 0.032302], + [0.034523, 0.032830, 0.032146, 0.031536, 0.031149], + [0.033612, 0.032011, 0.031664, 0.031128, 0.030839], + [0.032801, 0.031668, 0.031529, 0.031198, 0.030978], + [0.032337, 0.031550, 0.031419, 
0.031383, 0.031211],
+ ],
+ # Frame 3
+ [
+ [0.034300, 0.033236, 0.032239, 0.031517, 0.031133],
+ [0.033357, 0.031842, 0.031035, 0.030471, 0.030126],
+ [0.032563, 0.031094, 0.030156, 0.029703, 0.029324],
+ [0.031850, 0.030505, 0.030027, 0.029802, 0.029461],
+ [0.031555, 0.030121, 0.029943, 0.030000, 0.029700],
+ ],
+ # Frame 4
+ [
+ [0.083156, 0.032122, 0.031204, 0.030380, 0.080582],
+ [0.032296, 0.030936, 0.030170, 0.029557, 0.029124],
+ [0.031617, 0.030293, 0.029377, 0.028886, 0.028431],
+ [0.031084, 0.029859, 0.028839, 0.028439, 0.027973],
+ [0.164616, 0.029457, 0.028484, 0.028532, 0.211082],
+ ],
+ ]
+ ]
+ ],
+ ],
+]
+
+
+@skip_if_no_cuda
+@skip_if_no_cpp_extention
+class BilateralFilterTestCaseCudaApprox(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_cuda_approx(self, test_case_description, sigmas, input, expected):
+
+ # Skip when CUDA is unavailable (redundant with @skip_if_no_cuda, kept as a safeguard)
+ if not torch.cuda.is_available():
+ return
+
+ # Params to determine the implementation to test
+ device = torch.device("cuda")
+ fast_approx = True
+
+ # Create input tensor and apply filter
+ input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)
+ output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy()
+
+ # Ensure results are as expected
+ np.testing.assert_allclose(output, expected, atol=1e-2)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_bilateral_precise.py b/tests/test_bilateral_precise.py
new file mode 100644
index 0000000000..b02f3f04df
--- /dev/null
+++ b/tests/test_bilateral_precise.py
@@ -0,0 +1,403 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
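Note the tolerances: the approximate CPU tests assert to `atol=1e-5`, the approximate CUDA tests relax to `atol=1e-2`, and the precise tests that follow use `atol=1e-5` on both devices. A short sketch for comparing the two code paths on one input (the tolerance value here is illustrative, not from the test suite):

    import torch

    from monai.networks.layers.filtering import BilateralFilter

    image = torch.rand(1, 1, 5, 5)

    precise = BilateralFilter.apply(image, 4.0, 0.9, False)  # exact path
    approx = BilateralFilter.apply(image, 4.0, 0.9, True)    # fast approximation
    print(torch.allclose(precise, approx, atol=1e-1))  # approximation error budget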
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers.filtering import BilateralFilter +from tests.utils import skip_if_no_cpp_extention, skip_if_no_cuda + +TEST_CASES = [ + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, low color sigma", + # Spatial and Color Sigmas + (1, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.999998, 0.000002, 0.000000, 0.000002, 0.999998] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000001, 0.999995, 0.000001, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.813183, 0.186817, 0.061890, 0.186817, 0.813183] + ], + # Batch 1 + [ + # Channel 0 + [0.030148, 0.148418, 0.555452, 0.148418, 0.030148] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, low color sigma", + # Spatial and Color Sigmas + (4, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.999999, 0.000009, 0.000009, 0.000009, 0.999999] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000000, 0.999967, 0.000000, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.839145, 0.572834, 0.562460, 0.572834, 0.839145] + ], + # Batch 1 + [ + # Channel 0 + [0.049925, 0.055062, 0.171732, 0.055062, 0.049925] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 4 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 0], + # Channel 1 + [1, 0, 1, 0, 0], + # Channel 2 + [0, 0, 1, 0, 1], + # Channel 3 + [0, 0, 0, 0, 1], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.889742, 0.141296, 0.027504, 0.000000, 0.000000], + # Channel 1 + [0.909856, 0.256817, 0.725970, 0.115520, 0.020114], + # Channel 2 + [0.020114, 0.115520, 0.725970, 0.256817, 0.909856], + # Channel 3 + [0.000000, 0.000000, 0.027504, 0.141296, 0.889742], + ] + ], + ], + [ + # Case Descirption + "2 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]] + ], + # Batch 1 + [ + # Channel 0 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.688943, 0.374599, 0.368574, 0.374599, 0.688943], + [0.374599, 0.358248, 0.352546, 0.358248, 0.374599], + [0.368574, 0.352546, 0.346955, 0.352546, 0.368574], + [0.374599, 0.358248, 0.352546, 0.358248, 0.374599], + [0.688943, 0.374599, 0.368574, 0.374599, 0.688943], + ] + ], + # Batch 1 + [ + # Channel 0 + [ + [0.004266, 0.004687, 0.004836, 0.004687, 0.004266], + [0.004687, 0.005150, 0.005314, 0.005150, 0.004687], + [0.004836, 
0.005314, 0.018598, 0.005314, 0.004836], + [0.004687, 0.005150, 0.005314, 0.005150, 0.004687], + [0.004266, 0.004687, 0.004836, 0.004687, 0.004266], + ] + ], + ], + ], + [ + # Case Descirption + "2 dimension, 4 channel, high spatial sigma, high color sigma", + # Spatial and Color Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]], + # Channel 1 + [[1, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 1]], + # Channel 2 + [[0, 0, 1, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 0]], + # Channel 3 + [[0, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.692549, 0.149979, 0.220063, 0.115840, 0.035799], + [0.148403, 0.133935, 0.123253, 0.116828, 0.114623], + [0.128773, 0.122804, 0.120731, 0.122804, 0.128773], + [0.114623, 0.116828, 0.123253, 0.133935, 0.148403], + [0.035799, 0.115840, 0.220063, 0.149979, 0.692549], + ], + # Channel 1 + [ + [0.731597, 0.186319, 0.436069, 0.152181, 0.074847], + [0.180049, 0.168217, 0.158453, 0.151110, 0.146269], + [0.159760, 0.156381, 0.155211, 0.156381, 0.159760], + [0.146269, 0.151110, 0.158453, 0.168217, 0.180049], + [0.074847, 0.152181, 0.436068, 0.186319, 0.731597], + ], + # Channel 2 + [ + [0.074847, 0.152181, 0.436068, 0.186319, 0.731597], + [0.146269, 0.151110, 0.158453, 0.168217, 0.180049], + [0.159760, 0.156381, 0.155211, 0.156381, 0.159760], + [0.180049, 0.168217, 0.158453, 0.151110, 0.146269], + [0.731597, 0.186319, 0.436069, 0.152181, 0.074847], + ], + # Channel 3 + [ + [0.035799, 0.115840, 0.220063, 0.149979, 0.692549], + [0.114623, 0.116828, 0.123253, 0.133935, 0.148403], + [0.128773, 0.122804, 0.120731, 0.122804, 0.128773], + [0.148403, 0.133935, 0.123253, 0.116828, 0.114623], + [0.692549, 0.149979, 0.220063, 0.115840, 0.035799], + ], + ] + ], + ], + [ + # Case Descirption + "3 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + # Frame 1 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 2 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 3 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 4 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + ] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0.554430, 0.254995, 0.251207, 0.254996, 0.554430], + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + [0.251207, 0.241082, 0.237534, 0.241082, 0.251207], + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + [0.554430, 0.254995, 0.251207, 0.254996, 0.554430], + ], + # Frame 1 + [ + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + [0.244692, 0.234873, 0.231432, 0.234873, 0.244692], + [0.241082, 0.231431, 0.228049, 0.231432, 0.241082], + [0.244692, 0.234873, 0.231432, 0.234873, 0.244692], + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + ], + # Frame 2 + [ + [0.251207, 0.241081, 0.237534, 0.241082, 0.251207], + [0.241082, 0.231431, 0.228049, 0.231432, 0.241082], + [0.237534, 0.228048, 0.224724, 0.228049, 0.237534], + [0.241082, 0.231431, 0.228049, 0.231432, 0.241082], + [0.251207, 0.241081, 0.237534, 
0.241082, 0.251207],
+ ],
+ # Frame 3
+ [
+ [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],
+ [0.244692, 0.234873, 0.231432, 0.234873, 0.244692],
+ [0.241082, 0.231431, 0.228049, 0.231432, 0.241082],
+ [0.244692, 0.234873, 0.231432, 0.234873, 0.244692],
+ [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],
+ ],
+ # Frame 4
+ [
+ [0.554430, 0.254995, 0.251207, 0.254996, 0.554430],
+ [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],
+ [0.251207, 0.241082, 0.237534, 0.241082, 0.251207],
+ [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],
+ [0.554430, 0.254995, 0.251207, 0.254996, 0.554430],
+ ],
+ ]
+ ]
+ ],
+ ],
+]
+
+
+@skip_if_no_cpp_extention
+class BilateralFilterTestCaseCpuPrecised(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_cpu_precised(self, test_case_description, sigmas, input, expected):
+
+ # Params to determine the implementation to test
+ device = torch.device("cpu")
+ fast_approx = False
+
+ # Create input tensor and apply filter
+ input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)
+ output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy()
+
+ # Ensure results are as expected
+ np.testing.assert_allclose(output, expected, atol=1e-5)
+
+
+@skip_if_no_cuda
+@skip_if_no_cpp_extention
+class BilateralFilterTestCaseCudaPrecised(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_cuda_precised(self, test_case_description, sigmas, input, expected):
+
+ # Skip when CUDA is unavailable (redundant with @skip_if_no_cuda, kept as a safeguard)
+ if not torch.cuda.is_available():
+ return
+
+ # Params to determine the implementation to test
+ device = torch.device("cuda")
+ fast_approx = False
+
+ # Create input tensor and apply filter
+ input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)
+ output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy()
+
+ # Ensure results are as expected
+ np.testing.assert_allclose(output, expected, atol=1e-5)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_border_pad.py b/tests/test_border_pad.py
index 665656fcb3..14d93aae4e 100644
--- a/tests/test_border_pad.py
+++ b/tests/test_border_pad.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
diff --git a/tests/test_border_padd.py b/tests/test_border_padd.py
index 511e280fe2..b48629fc98 100644
--- a/tests/test_border_padd.py
+++ b/tests/test_border_padd.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
diff --git a/tests/test_bounding_rect.py b/tests/test_bounding_rect.py
index 69476479a3..bcd89fabc9 100644
--- a/tests/test_bounding_rect.py
+++ b/tests/test_bounding_rect.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at diff --git a/tests/test_bounding_rectd.py b/tests/test_bounding_rectd.py index c33a3c371d..3019fe994a 100644 --- a/tests/test_bounding_rectd.py +++ b/tests/test_bounding_rectd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index 8e9350255c..2b8931704a 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_cachedataset_parallel.py b/tests/test_cachedataset_parallel.py index 0f8453b041..0be3ba085b 100644 --- a/tests/test_cachedataset_parallel.py +++ b/tests/test_cachedataset_parallel.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_cachentransdataset.py b/tests/test_cachentransdataset.py index c9617d64db..492db8b16f 100644 --- a/tests/test_cachentransdataset.py +++ b/tests/test_cachentransdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_cast_to_type.py b/tests/test_cast_to_type.py index 20a6e6c461..5e81b41650 100644 --- a/tests/test_cast_to_type.py +++ b/tests/test_cast_to_type.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_cast_to_typed.py b/tests/test_cast_to_typed.py index 3a38496c27..be495564fb 100644 --- a/tests/test_cast_to_typed.py +++ b/tests/test_cast_to_typed.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_center_spatial_crop.py b/tests/test_center_spatial_crop.py index d710e62432..c03ec24e18 100644 --- a/tests/test_center_spatial_crop.py +++ b/tests/test_center_spatial_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/tests/test_center_spatial_cropd.py b/tests/test_center_spatial_cropd.py index 5220162dcf..349253ab56 100644 --- a/tests/test_center_spatial_cropd.py +++ b/tests/test_center_spatial_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_channel_pad.py b/tests/test_channel_pad.py index 00d0eab65a..ebc731c321 100644 --- a/tests/test_channel_pad.py +++ b/tests/test_channel_pad.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.layers import ChannelPad TEST_CASES_3D = [] @@ -34,8 +35,7 @@ class TestChannelPad(unittest.TestCase): @parameterized.expand(TEST_CASES_3D) def test_shape(self, input_param, input_shape, expected_shape): net = ChannelPad(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(list(result.shape), list(expected_shape)) diff --git a/tests/test_check_hash.py b/tests/test_check_hash.py index df3f0a0174..0126b3c1a3 100644 --- a/tests/test_check_hash.py +++ b/tests/test_check_hash.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_compose.py b/tests/test_compose.py index 6c85835d51..3585b3453c 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_compute_confusion_matrix.py b/tests/test_compute_confusion_matrix.py index 6c322dba9b..56ca5371ab 100644 --- a/tests/test_compute_confusion_matrix.py +++ b/tests/test_compute_confusion_matrix.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -220,7 +220,7 @@ def test_value(self, input_data, expected_value): @parameterized.expand(TEST_CASES_COMPUTE_SAMPLE) def test_compute_sample(self, input_data, expected_value): params = input_data.copy() - vals = dict() + vals = {} vals["y_pred"] = params.pop("y_pred") vals["y"] = params.pop("y") metric = ConfusionMatrixMetric(**params) @@ -230,7 +230,7 @@ def test_compute_sample(self, input_data, expected_value): @parameterized.expand(TEST_CASES_COMPUTE_SAMPLE_MULTI_METRICS) def test_compute_sample_multiple_metrics(self, input_data, expected_values): params = input_data.copy() - vals = dict() + vals = {} vals["y_pred"] = params.pop("y_pred") vals["y"] = params.pop("y") metric = ConfusionMatrixMetric(**params) @@ -243,7 +243,7 @@ def test_compute_sample_multiple_metrics(self, input_data, expected_values): @parameterized.expand(TEST_CASES_COMPUTE_SAMPLE_NAN) def test_compute_sample_with_nan(self, input_data, expected_value, expected_not_nans): params = input_data.copy() - vals = dict() + vals = {} vals["y_pred"] = params.pop("y_pred") vals["y"] = params.pop("y") metric = ConfusionMatrixMetric(**params) @@ -254,7 +254,7 @@ def test_compute_sample_with_nan(self, input_data, expected_value, expected_not_ @parameterized.expand([TEST_CASES_CLF]) def test_clf_with_nan(self, input_data, expected_value): params = input_data.copy() - vals = dict() + vals = {} vals["y_pred"] = params.pop("y_pred") vals["y"] = params.pop("y") metric = ConfusionMatrixMetric(**params) diff --git a/tests/test_compute_meandice.py b/tests/test_compute_meandice.py index 4e4e02622c..64f38dcdb8 100644 --- a/tests/test_compute_meandice.py +++ b/tests/test_compute_meandice.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -184,7 +184,7 @@ def test_nans(self, input_data, expected_value): def test_value_class(self, input_data, expected_value): # same test as for compute_meandice - vals = dict() + vals = {} vals["y_pred"] = input_data.pop("y_pred") vals["y"] = input_data.pop("y") dice_metric = DiceMetric(**input_data, reduction="none") diff --git a/tests/test_compute_occlusion_sensitivity.py b/tests/test_compute_occlusion_sensitivity.py deleted file mode 100644 index 9f30162c47..0000000000 --- a/tests/test_compute_occlusion_sensitivity.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
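The diff above begins removing tests/test_compute_occlusion_sensitivity.py (its body follows below): the functional `compute_occlusion_sensitivity` is superseded by the class-based occlusion sensitivity whose `__call__` appears at the top of this patch, returning the stacked per-class sensitivity maps plus their argmax. A sketch of the replacement usage, assuming the class is exposed as `monai.visualize.OcclusionSensitivity` with the network passed as `nn_module` and the same `b_box` convention as the deleted test:

    import torch

    from monai.networks.nets import densenet121
    from monai.visualize import OcclusionSensitivity

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = densenet121(spatial_dims=2, in_channels=1, out_channels=3).to(device)
    model.eval()

    occ_sens = OcclusionSensitivity(nn_module=model)
    image = torch.rand(1, 1, 48, 64).to(device)

    # sensitivity maps stacked over classes, plus the most probable class map
    sens_maps, most_probable = occ_sens(image, b_box=[-1, -1, 2, 40, 1, 62])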
- -import unittest - -import torch -from parameterized import parameterized - -from monai.metrics import compute_occlusion_sensitivity -from monai.networks.nets import DenseNet, densenet121 - -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3).to(device) -model_3d = DenseNet( - spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,) -).to(device) -model_2d.eval() -model_3d.eval() - -# 2D w/ bounding box -TEST_CASE_0 = [ - { - "model": model_2d, - "image": torch.rand(1, 1, 48, 64).to(device), - "label": torch.tensor([[0]], dtype=torch.int64).to(device), - "b_box": [-1, -1, 2, 40, 1, 62], - }, - (39, 62), -] -# 3D w/ bounding box -TEST_CASE_1 = [ - { - "model": model_3d, - "image": torch.rand(1, 1, 6, 6, 6).to(device), - "label": 0, - "b_box": [-1, -1, 2, 3, -1, -1, -1, -1], - "n_batch": 10, - "stride": 2, - }, - (2, 6, 6), -] - - -class TestComputeOcclusionSensitivity(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) - def test_shape(self, input_data, expected_shape): - result = compute_occlusion_sensitivity(**input_data) - self.assertTupleEqual(result.shape, expected_shape) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_compute_roc_auc.py b/tests/test_compute_roc_auc.py index 8ff31e92ec..612bd375ac 100644 --- a/tests/test_compute_roc_auc.py +++ b/tests/test_compute_roc_auc.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_concat_itemsd.py b/tests/test_concat_itemsd.py index 36a1dbe5e2..520833fc88 100644 --- a/tests/test_concat_itemsd.py +++ b/tests/test_concat_itemsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_convert_to_multi_channel.py b/tests/test_convert_to_multi_channel.py new file mode 100644 index 0000000000..ea27371ac7 --- /dev/null +++ b/tests/test_convert_to_multi_channel.py @@ -0,0 +1,33 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
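+
+# The expected three-channel output below can be read off the BraTS label
+# convention (tumor sub-regions are labelled 1, 2 and 4): channel 0 marks the
+# tumor core (labels 1 and 4), channel 1 the whole tumor (labels 1, 2 and 4),
+# and channel 2 the enhancing tumor (label 4 only).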
+ +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import ConvertToMultiChannelBasedOnBratsClasses + +TEST_CASE = [ + np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]]), + np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]), +] + + +class TestConvertToMultiChannel(unittest.TestCase): + @parameterized.expand([TEST_CASE]) + def test_type_shape(self, data, expected_result): + result = ConvertToMultiChannelBasedOnBratsClasses()(data) + np.testing.assert_equal(result, expected_result) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_convert_to_multi_channeld.py b/tests/test_convert_to_multi_channeld.py index 2de3ee7394..945e07e1cd 100644 --- a/tests/test_convert_to_multi_channeld.py +++ b/tests/test_convert_to_multi_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_convolutions.py b/tests/test_convolutions.py index bb6ea45e62..97c01dd659 100644 --- a/tests/test_convolutions.py +++ b/tests/test_convolutions.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_copy_itemsd.py b/tests/test_copy_itemsd.py index 436cb5430b..e3133ae4f8 100644 --- a/tests/test_copy_itemsd.py +++ b/tests/test_copy_itemsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.transforms import CopyItemsd from monai.utils import ensure_tuple @@ -61,8 +62,7 @@ def test_array_values(self): def test_graph_tensor_values(self): device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0") net = torch.nn.PReLU().to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): pred = net(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device)) input_data = {"pred": pred, "seg": torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device)} result = CopyItemsd(keys="pred", times=1, names="pred_1")(input_data) diff --git a/tests/test_create_grid_and_affine.py b/tests/test_create_grid_and_affine.py index 930558042d..0c0e52e04a 100644 --- a/tests/test_create_grid_and_affine.py +++ b/tests/test_create_grid_and_affine.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index dc87cca0a8..f50c7f11ff 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index d6f7a33251..f4283514de 100644 --- a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_cross_validation.py b/tests/test_cross_validation.py index 21d5b7edf7..33e10a6a40 100644 --- a/tests/test_cross_validation.py +++ b/tests/test_cross_validation.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_csv_saver.py b/tests/test_csv_saver.py index d1ff1975ed..6dd0159322 100644 --- a/tests/test_csv_saver.py +++ b/tests/test_csv_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_data_stats.py b/tests/test_data_stats.py index 419957e22f..e7334eb52c 100644 --- a/tests/test_data_stats.py +++ b/tests/test_data_stats.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -78,7 +78,7 @@ "data_shape": True, "value_range": True, "data_value": True, - "additional_info": lambda x: np.mean(x), + "additional_info": np.mean, "logger_handler": None, }, np.array([[0, 1], [1, 2]]), @@ -124,7 +124,7 @@ def test_file(self, input_data, expected_print): "data_shape": True, "value_range": True, "data_value": True, - "additional_info": lambda x: np.mean(x), + "additional_info": np.mean, "logger_handler": handler, } transform = DataStats(**input_param) diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py index 487d529952..a5fae3d66d 100644 --- a/tests/test_data_statsd.py +++ b/tests/test_data_statsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -79,7 +79,7 @@ "data_shape": True, "value_range": True, "data_value": True, - "additional_info": lambda x: np.mean(x), + "additional_info": np.mean, }, {"img": np.array([[0, 1], [1, 2]])}, "test data statistics:\nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]\nAdditional info: 1.0", @@ -108,7 +108,7 @@ "data_shape": True, "value_range": (True, False), "data_value": (False, True), - "additional_info": (lambda x: np.mean(x), None), + "additional_info": (np.mean, None), }, {"img": np.array([[0, 1], [1, 2]]), "affine": np.eye(2, 2)}, "affine statistics:\nShape: (2, 2)\nValue: [[1. 0.]\n [0. 1.]]", @@ -138,7 +138,7 @@ def test_file(self, input_data, expected_print): "data_shape": True, "value_range": True, "data_value": True, - "additional_info": lambda x: np.mean(x), + "additional_info": np.mean, "logger_handler": handler, } transform = DataStatsd(**input_param) diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index aed6c70c80..072a4a01c0 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_dataset.py b/tests/test_dataset.py index b03a9a9552..2e92b15977 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_decathlondataset.py b/tests/test_decathlondataset.py index e7b9678f4d..15dbceb8ad 100644 --- a/tests/test_decathlondataset.py +++ b/tests/test_decathlondataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_delete_itemsd.py b/tests/test_delete_itemsd.py index c3dbbcbf1c..7426e39ff0 100644 --- a/tests/test_delete_itemsd.py +++ b/tests/test_delete_itemsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,7 +23,7 @@ class TestDeleteItemsd(unittest.TestCase): @parameterized.expand([TEST_CASE_1]) def test_memory(self, input_param, expected_key_size): - input_data = dict() + input_data = {} for i in range(50): input_data[str(i)] = [time.time()] * 100000 result = DeleteItemsd(**input_param)(input_data) diff --git a/tests/test_densenet.py b/tests/test_densenet.py index 183c5443b2..876689314a 100644 --- a/tests/test_densenet.py +++ b/tests/test_densenet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import densenet121, densenet169, densenet201, densenet264 from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save @@ -66,8 +67,7 @@ class TestPretrainedDENSENET(unittest.TestCase): @skip_if_quick def test_121_3d_shape_pretrain(self, model, input_param, input_shape, expected_shape): net = test_pretrained_networks(model, input_param, device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -76,8 +76,7 @@ class TestDENSENET(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_densenet_shape(self, model, input_param, input_shape, expected_shape): net = model(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py index cbd281f6e8..08c699c84f 100644 --- a/tests/test_detect_envelope.py +++ b/tests/test_detect_envelope.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py new file mode 100644 index 0000000000..443d9a9baf --- /dev/null +++ b/tests/test_dice_ce_loss.py @@ -0,0 +1,69 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
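+
+# A note on the 0.3133 values below: in those cases every voxel's correct-class
+# logit exceeds the competing logit by exactly 1, so the per-voxel cross
+# entropy is log(1 + exp(-1)) = -1 + log(1 + exp(1)) ~= 0.3133, which is also
+# the mean over all voxels.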
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.losses import DiceCELoss
+
+TEST_CASES = [
+    [  # shape: (2, 2, 3), (2, 1, 3)
+        {"to_onehot_y": True},
+        {
+            "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
+            "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
+        },
+        0.3133,  # the result equals -1 + np.log(1 + np.exp(1))
+    ],
+    [  # shape: (2, 2, 3), (2, 2, 3), one-hot target
+        {"to_onehot_y": False},
+        {
+            "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
+            "target": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
+        },
+        0.3133,
+    ],
+    [  # shape: (2, 2, 3), (2, 1, 3)
+        {"include_background": False, "to_onehot_y": True, "ce_weight": torch.tensor([1.0, 1.0])},
+        {
+            "input": torch.tensor([[[100.0, 100.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
+            "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
+        },
+        0.2088,
+    ],
+    [  # shape: (2, 2, 3), (2, 1, 3), do not include class 0
+        {"include_background": False, "to_onehot_y": True, "ce_weight": torch.tensor([0.0, 1.0])},
+        {
+            "input": torch.tensor([[[100.0, 100.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
+            "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
+        },
+        0.3133,
+    ],
+]
+
+
+class TestDiceCELoss(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_result(self, input_param, input_data, expected_val):
+        result = DiceCELoss(**input_param)(**input_data)
+        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
+
+    def test_ill_shape(self):
+        loss = DiceCELoss()
+        with self.assertRaisesRegex(ValueError, ""):
+            loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py
index e2361df7a6..aa4a7cbc34 100644
--- a/tests/test_dice_loss.py
+++ b/tests/test_dice_loss.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
diff --git a/tests/test_discriminator.py b/tests/test_discriminator.py
index 2123737d05..52b9a10dd5 100644
--- a/tests/test_discriminator.py
+++ b/tests/test_discriminator.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import Discriminator from tests.utils import test_script_save @@ -42,8 +43,7 @@ class TestDiscriminator(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_data, expected_shape): net = Discriminator(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_distributed_sampler.py b/tests/test_distributed_sampler.py index 8c182dd9e6..d0054885eb 100644 --- a/tests/test_distributed_sampler.py +++ b/tests/test_distributed_sampler.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,39 +9,38 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest + import numpy as np -import torch import torch.distributed as dist from monai.data import DistributedSampler +from tests.utils import DistCall, DistTestCase -def test(expected, **kwargs): - dist.init_process_group(backend="nccl", init_method="env://") - - torch.cuda.set_device(dist.get_rank()) - data = [1, 2, 3, 4, 5] - sampler = DistributedSampler(dataset=data, **kwargs) - samples = np.array([data[i] for i in list(sampler)]) - if dist.get_rank() == 0: - np.testing.assert_allclose(samples, np.array(expected[0])) - - if dist.get_rank() == 1: - np.testing.assert_allclose(samples, np.array(expected[1])) - - dist.destroy_process_group() +class DistributedSamplerTest(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) + def test_even(self): + data = [1, 2, 3, 4, 5] + sampler = DistributedSampler(dataset=data, shuffle=False) + samples = np.array([data[i] for i in list(sampler)]) + if dist.get_rank() == 0: + np.testing.assert_allclose(samples, np.array([1, 3, 5])) + if dist.get_rank() == 1: + np.testing.assert_allclose(samples, np.array([2, 4, 1])) -def main(): - test(shuffle=False, expected=[[1, 3, 5], [2, 4, 1]]) - test(shuffle=False, even_divisible=False, expected=[[1, 3, 5], [2, 4]]) + @DistCall(nnodes=1, nproc_per_node=2) + def test_uneven(self): + data = [1, 2, 3, 4, 5] + sampler = DistributedSampler(dataset=data, shuffle=False, even_divisible=False) + samples = np.array([data[i] for i in list(sampler)]) + if dist.get_rank() == 0: + np.testing.assert_allclose(samples, np.array([1, 3, 5])) + if dist.get_rank() == 1: + np.testing.assert_allclose(samples, np.array([2, 4])) -# suppose to execute on 2 rank processes -# python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE -# --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE -# --master_addr="localhost" --master_port=1234 -# test_distributed_sampler.py if __name__ == "__main__": - main() + unittest.main() diff --git a/tests/test_divisible_pad.py b/tests/test_divisible_pad.py index c054a5eba4..27965b51d9 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/tests/test_divisible_padd.py b/tests/test_divisible_padd.py index 8195f3a6cf..d894a9f42e 100644 --- a/tests/test_divisible_padd.py +++ b/tests/test_divisible_padd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_download_and_extract.py b/tests/test_download_and_extract.py index b3db084d49..66bf19b442 100644 --- a/tests/test_download_and_extract.py +++ b/tests/test_download_and_extract.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_downsample_block.py b/tests/test_downsample_block.py index c2da0f9a43..f4ae30198f 100644 --- a/tests/test_downsample_block.py +++ b/tests/test_downsample_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks import MaxAvgPool TEST_CASES = [ @@ -41,8 +42,7 @@ class TestMaxAvgPool(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_shape, expected_shape): net = MaxAvgPool(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 6b89c8c4fd..d72c1fc48a 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -15,6 +15,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import DynUNet from tests.utils import test_script_save @@ -42,7 +43,6 @@ "strides": strides, "upsample_kernel_size": strides[1:], "norm_name": "batch", - "deep_supervision": False, "res_block": res_block, }, (1, in_channels, in_size, in_size), @@ -65,7 +65,6 @@ "strides": ((1, 2, 1), 2, 2, 1), "upsample_kernel_size": (2, 2, 1), "norm_name": "instance", - "deep_supervision": False, "res_block": res_block, }, (1, in_channels, in_size, in_size, in_size), @@ -87,7 +86,6 @@ "strides": strides, "upsample_kernel_size": strides[1:], "norm_name": "group", - "deep_supervision": True, "deep_supr_num": deep_supr_num, "res_block": res_block, }, @@ -107,10 +105,9 @@ class TestDynUNet(unittest.TestCase): @parameterized.expand(TEST_CASE_DYNUNET_2D + TEST_CASE_DYNUNET_3D) def test_shape(self, input_param, input_shape, expected_shape): net = DynUNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape).to(device)) - self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result.shape, expected_shape) def test_script(self): input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] @@ -124,7 +121,7 @@ class TestDynUNetDeepSupervision(unittest.TestCase): def test_shape(self, input_param, input_shape, expected_shape): net = DynUNet(**input_param).to(device) with torch.no_grad(): - results = net(torch.randn(input_shape).to(device)) + results = [net(torch.randn(input_shape).to(device))] + net.get_feature_maps() self.assertEqual(len(results), len(expected_shape)) for idx in range(len(results)): result, sub_expected_shape = results[idx], expected_shape[idx] diff --git a/tests/test_dynunet_block.py b/tests/test_dynunet_block.py index 3a0bfa5e7e..c156b7b423 100644 --- a/tests/test_dynunet_block.py +++ b/tests/test_dynunet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, UnetUpBlock, get_padding from tests.utils import test_script_save @@ -70,8 +71,7 @@ class TestResBasicBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_RES_BASIC_BLOCK) def test_shape(self, input_param, input_shape, expected_shape): for net in [UnetResBlock(**input_param), UnetBasicBlock(**input_param)]: - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) @@ -94,8 +94,7 @@ class TestUpBlock(unittest.TestCase): @parameterized.expand(TEST_UP_BLOCK) def test_shape(self, input_param, input_shape, expected_shape, skip_shape): net = UnetUpBlock(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape), torch.randn(skip_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_ensemble_evaluator.py b/tests/test_ensemble_evaluator.py index b7d84ca6f2..fdb9695476 100644 --- a/tests/test_ensemble_evaluator.py +++ b/tests/test_ensemble_evaluator.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_eval_mode.py b/tests/test_eval_mode.py new file mode 100644 index 0000000000..45c551c209 --- /dev/null +++ b/tests/test_eval_mode.py @@ -0,0 +1,31 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.networks.utils import eval_mode + + +class TestEvalMode(unittest.TestCase): + def test_eval_mode(self): + t = torch.rand(1, 1, 4, 4) + p = torch.nn.Conv2d(1, 1, 3) + self.assertTrue(p.training) # True + with eval_mode(p): + self.assertFalse(p.training) # False + with self.assertRaises(RuntimeError): + p(t).sum().backward() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_evenly_divisible_all_gather_dist.py b/tests/test_evenly_divisible_all_gather_dist.py new file mode 100644 index 0000000000..70dcd7ca6a --- /dev/null +++ b/tests/test_evenly_divisible_all_gather_dist.py @@ -0,0 +1,42 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
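+
+# The two ranks below contribute tensors of different lengths in the leading
+# dimension (2 rows vs. 1, and 1 row vs. 2); the test asserts that
+# evenly_divisible_all_gather still returns the full concatenation in rank
+# order, presumably by padding to a common length for all_gather and trimming
+# the padding afterwards.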
+ +import unittest + +import torch +import torch.distributed as dist + +from monai.handlers.utils import evenly_divisible_all_gather +from tests.utils import DistCall, DistTestCase + + +class DistributedEvenlyDivisibleAllGather(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) + def test_data(self): + self._run() + + def _run(self): + if dist.get_rank() == 0: + data1 = torch.tensor([[1, 2], [3, 4]]) + data2 = torch.tensor([[1.0, 2.0]]) + + if dist.get_rank() == 1: + data1 = torch.tensor([[5, 6]]) + data2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]]) + + result1 = evenly_divisible_all_gather(data=data1) + torch.testing.assert_allclose(result1, torch.tensor([[1, 2], [3, 4], [5, 6]])) + result2 = evenly_divisible_all_gather(data=data2) + torch.testing.assert_allclose(result2, torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_fg_bg_to_indices.py b/tests/test_fg_bg_to_indices.py index 4401818294..98626c7028 100644 --- a/tests/test_fg_bg_to_indices.py +++ b/tests/test_fg_bg_to_indices.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_fg_bg_to_indicesd.py b/tests/test_fg_bg_to_indicesd.py index e1f255815d..ce6ca30f1b 100644 --- a/tests/test_fg_bg_to_indicesd.py +++ b/tests/test_fg_bg_to_indicesd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_file_basename.py b/tests/test_file_basename.py index 98c4f6cf8a..21039d3d15 100644 --- a/tests/test_file_basename.py +++ b/tests/test_file_basename.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_flip.py b/tests/test_flip.py index 1f17c36e7a..7a2af02585 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -32,7 +32,7 @@ def test_invalid_inputs(self, _, spatial_axis, raises): @parameterized.expand(VALID_CASES) def test_correct_results(self, _, spatial_axis): flip = Flip(spatial_axis=spatial_axis) - expected = list() + expected = [] for channel in self.imt[0]: expected.append(np.flip(channel, spatial_axis)) expected = np.stack(expected) diff --git a/tests/test_flipd.py b/tests/test_flipd.py index ec81b78fcc..b8996dee42 100644 --- a/tests/test_flipd.py +++ b/tests/test_flipd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -32,7 +32,7 @@ def test_invalid_cases(self, _, spatial_axis, raises): @parameterized.expand(VALID_CASES) def test_correct_results(self, _, spatial_axis): flip = Flipd(keys="img", spatial_axis=spatial_axis) - expected = list() + expected = [] for channel in self.imt[0]: expected.append(np.flip(channel, spatial_axis)) expected = np.stack(expected) diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index 98807bc282..d06e2b4c36 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_fullyconnectednet.py b/tests/test_fullyconnectednet.py index 1819c4cdb9..ec91a99c3e 100644 --- a/tests/test_fullyconnectednet.py +++ b/tests/test_fullyconnectednet.py @@ -1,8 +1,20 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import FullyConnectedNet, VarFullyConnectedNet device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -45,8 +57,7 @@ def test_fc_shape(self, dropout): @parameterized.expand(VFC_CASES) def test_vfc_shape(self, input_param, input_shape, expected_shape): net = VarFullyConnectedNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device))[0] self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_gaussian.py b/tests/test_gaussian.py index f1a5afcd19..e2659abb0c 100644 --- a/tests/test_gaussian.py +++ b/tests/test_gaussian.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_gaussian_filter.py b/tests/test_gaussian_filter.py index b273bd3d23..e056c961c9 100644 --- a/tests/test_gaussian_filter.py +++ b/tests/test_gaussian_filter.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_gaussian_sharpen.py b/tests/test_gaussian_sharpen.py index 2ff5781afc..9d078e65e5 100644 --- a/tests/test_gaussian_sharpen.py +++ b/tests/test_gaussian_sharpen.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/tests/test_gaussian_sharpend.py b/tests/test_gaussian_sharpend.py index d304535d3a..c795b11762 100644 --- a/tests/test_gaussian_sharpend.py +++ b/tests/test_gaussian_sharpend.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_gaussian_smooth.py b/tests/test_gaussian_smooth.py index 4ad2061b64..e51977fbee 100644 --- a/tests/test_gaussian_smooth.py +++ b/tests/test_gaussian_smooth.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_gaussian_smoothd.py b/tests/test_gaussian_smoothd.py index b14c09faba..3d7eb6195e 100644 --- a/tests/test_gaussian_smoothd.py +++ b/tests/test_gaussian_smoothd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index 3ddc68ae4e..e88253ccba 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_generalized_wasserstein_dice_loss.py b/tests/test_generalized_wasserstein_dice_loss.py index ce35b251f1..6865b53027 100644 --- a/tests/test_generalized_wasserstein_dice_loss.py +++ b/tests/test_generalized_wasserstein_dice_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_generate_param_groups.py b/tests/test_generate_param_groups.py index 2130234013..8ccb8b7977 100644 --- a/tests/test_generate_param_groups.py +++ b/tests/test_generate_param_groups.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_generate_pos_neg_label_crop_centers.py b/tests/test_generate_pos_neg_label_crop_centers.py index 138be0d282..40181aa9ea 100644 --- a/tests/test_generate_pos_neg_label_crop_centers.py +++ b/tests/test_generate_pos_neg_label_crop_centers.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
diff --git a/tests/test_generate_spatial_bounding_box.py b/tests/test_generate_spatial_bounding_box.py
index 338b6fe5d4..32a45d8d1c 100644
--- a/tests/test_generate_spatial_bounding_box.py
+++ b/tests/test_generate_spatial_bounding_box.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
diff --git a/tests/test_generator.py b/tests/test_generator.py
index 46b469b111..b5d846febc 100644
--- a/tests/test_generator.py
+++ b/tests/test_generator.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -14,6 +14,7 @@
 import torch
 from parameterized import parameterized
 
+from monai.networks import eval_mode
 from monai.networks.nets import Generator
 from tests.utils import test_script_save
 
@@ -42,8 +43,7 @@ class TestGenerator(unittest.TestCase):
     @parameterized.expand(CASES)
     def test_shape(self, input_param, input_data, expected_shape):
         net = Generator(**input_param)
-        net.eval()
-        with torch.no_grad():
+        with eval_mode(net):
             result = net.forward(input_data)
         self.assertEqual(result.shape, expected_shape)
 
diff --git a/tests/test_get_extreme_points.py b/tests/test_get_extreme_points.py
index dd38af573e..a334c12415 100644
--- a/tests/test_get_extreme_points.py
+++ b/tests/test_get_extreme_points.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py
new file mode 100644
index 0000000000..252a70e85e
--- /dev/null
+++ b/tests/test_global_mutual_information_loss.py
@@ -0,0 +1,100 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
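+
+# A sanity check on the first expected value: pred and target are identical and
+# take three distinct intensities with equal frequency, so the mutual
+# information approaches the entropy log(3) ~= 1.0986 and the loss (its
+# negation) is close to -1.0986; the remaining cases deviate slightly because
+# of the smoothed (Parzen-window style) histogram estimate.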
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses.image_dissimilarity import GlobalMutualInformationLoss + +TEST_CASES = [ + [ + {}, + { + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), + }, + -1.0986018, + ], + [ + {}, + { + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3) + ** 2, + }, + -1.083999, + ], + [ + {}, + { + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3).div(3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3).div(3) ** 2, + }, + -1.083999, + ], + [ + {}, + { + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3).div(3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3).div(3) ** 2, + }, + -1.083999, + ], + [ + {}, + { + "pred": torch.arange(0, 3, dtype=torch.float)[None, :].div(3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :].div(3) ** 2, + }, + -1.083999, + ], + [ + {}, + { + "pred": torch.arange(0, 3, dtype=torch.float).div(3), + "target": torch.arange(0, 3, dtype=torch.float).div(3) ** 2, + }, + -1.1920927e-07, + ], +] + + +class TestGlobalMutualInformationLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = GlobalMutualInformationLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4) + + def test_ill_shape(self): + loss = GlobalMutualInformationLoss() + with self.assertRaisesRegex(ValueError, ""): + loss.forward(torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)) + with self.assertRaisesRegex(ValueError, ""): + loss.forward(torch.ones((1, 3, 3), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)) + + def test_ill_opts(self): + pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + with self.assertRaisesRegex(ValueError, ""): + GlobalMutualInformationLoss(num_bins=0)(pred, target) + with self.assertRaisesRegex(ValueError, ""): + GlobalMutualInformationLoss(num_bins=-1)(pred, target) + with self.assertRaisesRegex(ValueError, ""): + GlobalMutualInformationLoss(reduction="unknown")(pred, target) + with self.assertRaisesRegex(ValueError, ""): + GlobalMutualInformationLoss(reduction=None)(pred, target) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index 438e73bf3a..d299b65e9b 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/tests/test_handler_checkpoint_saver.py b/tests/test_handler_checkpoint_saver.py index 2df36d9720..0ae8be1e73 100644 --- a/tests/test_handler_checkpoint_saver.py +++ b/tests/test_handler_checkpoint_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_classification_saver.py b/tests/test_handler_classification_saver.py index 3b05092adc..20a9f1c95b 100644 --- a/tests/test_handler_classification_saver.py +++ b/tests/test_handler_classification_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_confusion_matrix.py b/tests/test_handler_confusion_matrix.py index c07dd52998..0524676763 100644 --- a/tests/test_handler_confusion_matrix.py +++ b/tests/test_handler_confusion_matrix.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,15 +13,15 @@ from typing import Any, Dict import torch +from ignite.engine import Engine from parameterized import parameterized from monai.handlers import ConfusionMatrix -TEST_CASE_1 = [{"include_background": True, "metric_name": "f1", "compute_sample": False}, 0.75] -TEST_CASE_2 = [{"include_background": False, "metric_name": "ppv", "compute_sample": False}, 1.0] +TEST_CASE_1 = [{"include_background": True, "save_details": False, "metric_name": "f1"}, 0.75] +TEST_CASE_2 = [{"include_background": False, "save_details": False, "metric_name": "ppv"}, 1.0] -TEST_CASE_SEG_1 = [{"include_background": True, "metric_name": "tpr", "compute_sample": True}, 0.8333] -TEST_CASE_SEG_2 = [{"include_background": True, "metric_name": "tpr", "compute_sample": False}, 0.7] +TEST_CASE_SEG_1 = [{"include_background": True, "metric_name": "tpr"}, 0.7] data_1: Dict[Any, Any] = { "y_pred": torch.tensor( @@ -70,10 +70,16 @@ def test_compute(self, input_params, expected_avg): avg_metric = metric.compute() self.assertAlmostEqual(avg_metric, expected_avg, places=4) - @parameterized.expand([TEST_CASE_SEG_1, TEST_CASE_SEG_2]) + @parameterized.expand([TEST_CASE_SEG_1]) def test_compute_seg(self, input_params, expected_avg): metric = ConfusionMatrix(**input_params) + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + metric.attach(engine, "confusion_matrix") + y_pred = data_1["y_pred"] y = data_1["y"] metric.update([y_pred, y]) @@ -83,8 +89,6 @@ def test_compute_seg(self, input_params, expected_avg): metric.update([y_pred, y]) avg_metric = metric.compute() - if input_params["compute_sample"] is False: - avg_metric = avg_metric.item() self.assertAlmostEqual(avg_metric, expected_avg, places=4) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) diff --git a/tests/test_handler_confusion_matrix_dist.py b/tests/test_handler_confusion_matrix_dist.py index b7718e15d2..40245bce2e 100644 --- a/tests/test_handler_confusion_matrix_dist.py +++ b/tests/test_handler_confusion_matrix_dist.py @@ -1,4 +1,4 @@ -# Copyright 
2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,24 +15,26 @@ import numpy as np import torch import torch.distributed as dist +from ignite.engine import Engine from monai.handlers import ConfusionMatrix from tests.utils import DistCall, DistTestCase class DistributedConfusionMatrix(DistTestCase): - @DistCall(nnodes=1, nproc_per_node=2) - def test_compute_sample(self): - self._compute(True) - @DistCall(nnodes=1, nproc_per_node=2) def test_compute(self): - self._compute(False) + self._compute() - def _compute(self, compute_sample=True): + def _compute(self): device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu" - metric = ConfusionMatrix(include_background=True, metric_name="tpr", compute_sample=compute_sample) + metric = ConfusionMatrix(include_background=True, metric_name="tpr") + + def _val_func(engine, batch): + pass + engine = Engine(_val_func) + metric.attach(engine, "confusion_matrix") if dist.get_rank() == 0: y_pred = torch.tensor( [ @@ -62,11 +64,7 @@ def _compute(self, compute_sample=True): metric.update([y_pred, y]) avg_metric = metric.compute() - if compute_sample is False: - avg_metric = avg_metric.item() - np.testing.assert_allclose(avg_metric, 0.7, rtol=1e-04, atol=1e-04) - else: - np.testing.assert_allclose(avg_metric, 0.8333, rtol=1e-04, atol=1e-04) + np.testing.assert_allclose(avg_metric, 0.7, rtol=1e-04, atol=1e-04) if __name__ == "__main__": diff --git a/tests/test_handler_hausdorff_distance.py b/tests/test_handler_hausdorff_distance.py index 67322718b1..c0d2e723ca 100644 --- a/tests/test_handler_hausdorff_distance.py +++ b/tests/test_handler_hausdorff_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import numpy as np import torch +from ignite.engine import Engine from monai.handlers import HausdorffDistance @@ -62,6 +63,13 @@ class TestHandlerHausdorffDistance(unittest.TestCase): def test_compute(self): hd_metric = HausdorffDistance(include_background=True) + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + hd_metric.attach(engine, "hausdorff_distance") + y_pred, y = TEST_SAMPLE_1 hd_metric.update([y_pred, y]) self.assertEqual(hd_metric.compute(), 10) @@ -71,10 +79,9 @@ def test_compute(self): y_pred, y = TEST_SAMPLE_3 hd_metric.update([y_pred, y]) self.assertEqual(hd_metric.compute(), float("inf")) - self.assertEqual(hd_metric._num_examples, 3) y_pred, y = TEST_SAMPLE_4 hd_metric.update([y_pred, y]) - self.assertEqual(hd_metric._num_examples, 3) + self.assertEqual(hd_metric.compute(), float("inf")) def test_shape_mismatch(self): hd_metric = HausdorffDistance(include_background=True) diff --git a/tests/test_handler_lr_scheduler.py b/tests/test_handler_lr_scheduler.py index ffc85e8cd1..82a62dce21 100644 --- a/tests/test_handler_lr_scheduler.py +++ b/tests/test_handler_lr_scheduler.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/tests/test_handler_mean_dice.py b/tests/test_handler_mean_dice.py index bcd1db6cc9..d15b549d86 100644 --- a/tests/test_handler_mean_dice.py +++ b/tests/test_handler_mean_dice.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,20 +12,28 @@ import unittest import torch +from ignite.engine import Engine from parameterized import parameterized from monai.handlers import MeanDice -TEST_CASE_1 = [{"include_background": True}, 0.75] -TEST_CASE_2 = [{"include_background": False}, 0.66666] +TEST_CASE_1 = [{"include_background": True}, 0.75, (4, 2)] +TEST_CASE_2 = [{"include_background": False}, 0.66666, (4, 1)] class TestHandlerMeanDice(unittest.TestCase): # TODO test multi node averaged dice @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_compute(self, input_params, expected_avg): + def test_compute(self, input_params, expected_avg, details_shape): dice_metric = MeanDice(**input_params) + # set up engine + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + dice_metric.attach(engine=engine, name="mean_dice") y_pred = torch.Tensor([[[0], [1]], [[1], [0]]]) y = torch.Tensor([[[0], [1]], [[0], [1]]]) @@ -37,9 +45,10 @@ def test_compute(self, input_params, expected_avg): avg_dice = dice_metric.compute() self.assertAlmostEqual(avg_dice, expected_avg, places=4) + self.assertTupleEqual(tuple(engine.state.metric_details["mean_dice"].shape), details_shape) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_shape_mismatch(self, input_params, _expected): + def test_shape_mismatch(self, input_params, _expected_avg, _details_shape): dice_metric = MeanDice(**input_params) with self.assertRaises((AssertionError, ValueError)): y_pred = torch.Tensor([[0, 1], [1, 0]]) diff --git a/tests/test_handler_metrics_saver.py b/tests/test_handler_metrics_saver.py new file mode 100644 index 0000000000..58a6f10d33 --- /dev/null +++ b/tests/test_handler_metrics_saver.py @@ -0,0 +1,84 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
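+
+# This test exercises the three kinds of CSV report written by MetricsSaver:
+# metrics.csv (final scalar metrics), {metric}_raw.csv (one row per image with
+# per-class values) and {metric}_summary.csv (one row per class aggregated by
+# the requested summary_ops), each checked in turn below.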
+ +import csv +import os +import tempfile +import unittest + +import torch +from ignite.engine import Engine, Events + +from monai.handlers import MetricsSaver + + +class TestHandlerMetricsSaver(unittest.TestCase): + def test_content(self): + with tempfile.TemporaryDirectory() as tempdir: + metrics_saver = MetricsSaver( + save_dir=tempdir, + metrics=["metric1", "metric2"], + metric_details=["metric3", "metric4"], + batch_transform=lambda x: x["image_meta_dict"], + summary_ops=["mean", "median", "max", "90percent"], + ) + # set up engine + data = [ + {"image_meta_dict": {"filename_or_obj": ["filepath1"]}}, + {"image_meta_dict": {"filename_or_obj": ["filepath2"]}}, + ] + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + + @engine.on(Events.EPOCH_COMPLETED) + def _save_metrics(engine): + engine.state.metrics = {"metric1": 1, "metric2": 2} + engine.state.metric_details = { + "metric3": torch.tensor([[1, 2], [2, 3]]), + "metric4": torch.tensor([[5, 6], [7, 8]]), + } + + metrics_saver.attach(engine) + engine.run(data, max_epochs=1) + + # check the metrics.csv and content + self.assertTrue(os.path.exists(os.path.join(tempdir, "metrics.csv"))) + with open(os.path.join(tempdir, "metrics.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + self.assertEqual(row, [f"metric{i + 1}\t{i + 1}"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_raw.csv"))) + # check the metric_raw.csv and content + with open(os.path.join(tempdir, "metric3_raw.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + if i > 0: + self.assertEqual(row, [f"filepath{i}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_summary.csv"))) + # check the metric_summary.csv and content + with open(os.path.join(tempdir, "metric3_summary.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + if i == 1: + self.assertEqual(row, ["class0\t1.5000\t1.5000\t2.0000\t1.1000"]) + elif i == 2: + self.assertEqual(row, ["class1\t2.5000\t2.5000\t3.0000\t2.1000"]) + elif i == 3: + self.assertEqual(row, ["mean\t2.0000\t2.0000\t2.5000\t1.6000"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_raw.csv"))) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_summary.csv"))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py new file mode 100644 index 0000000000..1b17d0adb4 --- /dev/null +++ b/tests/test_handler_metrics_saver_dist.py @@ -0,0 +1,106 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
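+
+# Distributed variant of the MetricsSaver test: the two ranks deliberately feed
+# different numbers of images (1 vs. 2), so per-image metric details must be
+# gathered across processes before rank 0 writes the combined reports; only
+# rank 0 asserts on the resulting files.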
+ + +import csv +import os +import tempfile +import unittest + +import torch +import torch.distributed as dist +from ignite.engine import Engine, Events + +from monai.handlers import MetricsSaver +from tests.utils import DistCall, DistTestCase, SkipIfBeforePyTorchVersion + + +@SkipIfBeforePyTorchVersion((1, 7)) +class DistributedMetricsSaver(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) + def test_content(self): + self._run() + + def _run(self): + with tempfile.TemporaryDirectory() as tempdir: + metrics_saver = MetricsSaver( + save_dir=tempdir, + metrics=["metric1", "metric2"], + metric_details=["metric3", "metric4"], + batch_transform=lambda x: x["image_meta_dict"], + summary_ops="*", + ) + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + + if dist.get_rank() == 0: + data = [{"image_meta_dict": {"filename_or_obj": ["filepath1"]}}] + + @engine.on(Events.EPOCH_COMPLETED) + def _save_metrics0(engine): + engine.state.metrics = {"metric1": 1, "metric2": 2} + engine.state.metric_details = { + "metric3": torch.tensor([[1, 2]]), + "metric4": torch.tensor([[5, 6]]), + } + + if dist.get_rank() == 1: + # different ranks have different data length + data = [ + {"image_meta_dict": {"filename_or_obj": ["filepath2"]}}, + {"image_meta_dict": {"filename_or_obj": ["filepath3"]}}, + ] + + @engine.on(Events.EPOCH_COMPLETED) + def _save_metrics1(engine): + engine.state.metrics = {"metric1": 1, "metric2": 2} + engine.state.metric_details = { + "metric3": torch.tensor([[2, 3], [3, 4]]), + "metric4": torch.tensor([[6, 7], [7, 8]]), + } + + metrics_saver.attach(engine) + engine.run(data, max_epochs=1) + + if dist.get_rank() == 0: + # check the metrics.csv and content + self.assertTrue(os.path.exists(os.path.join(tempdir, "metrics.csv"))) + with open(os.path.join(tempdir, "metrics.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + self.assertEqual(row, [f"metric{i + 1}\t{i + 1}"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_raw.csv"))) + # check the metric_raw.csv and content + with open(os.path.join(tempdir, "metric3_raw.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + if i > 0: + self.assertEqual(row, [f"filepath{i}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_summary.csv"))) + # check the metric_summary.csv and content + with open(os.path.join(tempdir, "metric3_summary.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + if i == 1: + self.assertEqual(row, ["class0\t1.0000\t1.0000\t1.0000\t1.0000\t1.0000\t0.0000"]) + elif i == 2: + self.assertEqual(row, ["class1\t2.0000\t2.0000\t2.0000\t2.0000\t2.0000\t0.0000"]) + elif i == 3: + self.assertEqual(row, ["mean\t1.5000\t1.5000\t1.5000\t1.5000\t1.5000\t0.0000"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_raw.csv"))) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_summary.csv"))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_rocauc.py b/tests/test_handler_rocauc.py index 8b5895dc4c..05f6eebce6 100644 --- a/tests/test_handler_rocauc.py +++ b/tests/test_handler_rocauc.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
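The distributed variant above additionally checks that ranks may contribute different numbers of items and that rank 0 ends up with the aggregated CSV files. The only construction difference is summary_ops="*", which appears to request every supported statistic (six summary columns are asserted above, versus the four named ops in the single-process test):

    from monai.handlers import MetricsSaver

    saver = MetricsSaver(
        save_dir="./eval",                      # illustrative path
        metrics=["metric1", "metric2"],
        metric_details=["metric3", "metric4"],
        batch_transform=lambda x: x["image_meta_dict"],
        summary_ops="*",                        # wildcard: all supported summary ops
    )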
# You may obtain a copy of the License at diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py index 7ff45185a6..825b172064 100644 --- a/tests/test_handler_rocauc_dist.py +++ b/tests/test_handler_rocauc_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_segmentation_saver.py b/tests/test_handler_segmentation_saver.py index 96ca1c27c9..1a2bbb7fbd 100644 --- a/tests/test_handler_segmentation_saver.py +++ b/tests/test_handler_segmentation_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_smartcache.py b/tests/test_handler_smartcache.py index 97571f2d40..95f8e70fa4 100644 --- a/tests/test_handler_smartcache.py +++ b/tests/test_handler_smartcache.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index dab5a0ea14..d1602f802a 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_surface_distance.py b/tests/test_handler_surface_distance.py index 02898769f6..fbd86edb03 100644 --- a/tests/test_handler_surface_distance.py +++ b/tests/test_handler_surface_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -14,6 +14,7 @@ import numpy as np import torch +from ignite.engine import Engine from monai.handlers import SurfaceDistance @@ -62,6 +63,13 @@ class TestHandlerSurfaceDistance(unittest.TestCase): def test_compute(self): sur_metric = SurfaceDistance(include_background=True) + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + sur_metric.attach(engine, "surface_distance") + y_pred, y = TEST_SAMPLE_1 sur_metric.update([y_pred, y]) self.assertAlmostEqual(sur_metric.compute(), 4.17133, places=4) @@ -71,10 +79,9 @@ def test_compute(self): y_pred, y = TEST_SAMPLE_3 sur_metric.update([y_pred, y]) self.assertAlmostEqual(sur_metric.compute(), float("inf")) - self.assertAlmostEqual(sur_metric._num_examples, 3) y_pred, y = TEST_SAMPLE_4 sur_metric.update([y_pred, y]) - self.assertAlmostEqual(sur_metric._num_examples, 3) + self.assertAlmostEqual(sur_metric.compute(), float("inf")) def test_shape_mismatch(self): sur_metric = SurfaceDistance(include_background=True) diff --git a/tests/test_handler_tb_image.py b/tests/test_handler_tb_image.py index e1cda3be65..ed3ba8a32d 100644 --- a/tests/test_handler_tb_image.py +++ b/tests/test_handler_tb_image.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_tb_stats.py b/tests/test_handler_tb_stats.py index ab356e74b4..2d7d18d1f6 100644 --- a/tests/test_handler_tb_stats.py +++ b/tests/test_handler_tb_stats.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_validation.py b/tests/test_handler_validation.py index 822e17fdf8..11a51c7213 100644 --- a/tests/test_handler_validation.py +++ b/tests/test_handler_validation.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_hashing.py b/tests/test_hashing.py index 19442acd98..ca317a72e8 100644 --- a/tests/test_hashing.py +++ b/tests/test_hashing.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_hausdorff_distance.py b/tests/test_hausdorff_distance.py index 96c52cbb68..465900c12a 100644 --- a/tests/test_hausdorff_distance.py +++ b/tests/test_hausdorff_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
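The test_handler_surface_distance.py hunk follows the same attach-to-engine pattern and, instead of poking at the private _num_examples counter, now asserts that compute() stays inf once an infinite distance has entered the accumulator. The wiring mirrors MeanDice; the masks below are hypothetical stand-ins for the TEST_SAMPLE data, which is defined outside this hunk:

    import torch
    from ignite.engine import Engine

    from monai.handlers import SurfaceDistance

    sur_metric = SurfaceDistance(include_background=True)
    engine = Engine(lambda e, b: None)
    sur_metric.attach(engine, "surface_distance")

    y_pred = torch.zeros(1, 1, 4, 4)  # hypothetical binary masks, (batch, channel, H, W)
    y_pred[..., 1:3, 1:3] = 1
    y = y_pred.clone()
    sur_metric.update([y_pred, y])
    print(sur_metric.compute())  # identical masks should give a distance of ~0.0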
# You may obtain a copy of the License at diff --git a/tests/test_header_correct.py b/tests/test_header_correct.py index e9b2a8821b..4a8927fa80 100644 --- a/tests/test_header_correct.py +++ b/tests/test_header_correct.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_highresnet.py b/tests/test_highresnet.py index 10f4f41fea..83248ad85f 100644 --- a/tests/test_highresnet.py +++ b/tests/test_highresnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import HighResNet from tests.utils import DistTestCase, TimedCall, test_script_save @@ -48,8 +49,7 @@ class TestHighResNet(DistTestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_param, input_shape, expected_shape): net = HighResNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py index 1099468102..82454c34d0 100644 --- a/tests/test_hilbert_transform.py +++ b/tests/test_hilbert_transform.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -51,7 +51,7 @@ def create_expected_numpy_output(input_datum, **kwargs): # CPU TEST DATA -cpu_input_data = dict() +cpu_input_data = {} cpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=cpu).unsqueeze(0).unsqueeze(0) cpu_input_data["2D"] = ( torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu).unsqueeze(0).unsqueeze(0) @@ -110,7 +110,7 @@ def create_expected_numpy_output(input_datum, **kwargs): if torch.cuda.is_available(): gpu = torch.device("cuda") - gpu_input_data = dict() + gpu_input_data = {} gpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=gpu).unsqueeze(0).unsqueeze(0) gpu_input_data["2D"] = ( torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=gpu).unsqueeze(0).unsqueeze(0) diff --git a/tests/test_identity.py b/tests/test_identity.py index 7a4b2de291..2dff2bb13d 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
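test_highresnet.py above swaps the manual net.eval() plus torch.no_grad() pair for monai.networks.eval_mode, a substitution repeated throughout the integration tests below. The context manager handles both concerns in one scope and restores the previous training mode on exit; for example (the network arguments here are illustrative):

    import torch

    from monai.networks import eval_mode
    from monai.networks.nets import HighResNet

    net = HighResNet(spatial_dims=2, in_channels=1, out_channels=2)
    with eval_mode(net):  # eval() + no_grad in one scope
        result = net(torch.randn(1, 1, 32, 32))
    print(result.shape)  # spatial size is preserved: (1, 2, 32, 32)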
# You may obtain a copy of the License at diff --git a/tests/test_identityd.py b/tests/test_identityd.py index 481cdd45c4..8796f28da8 100644 --- a/tests/test_identityd.py +++ b/tests/test_identityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,7 +18,7 @@ class TestIdentityd(NumpyImageTestCase2D): def test_identityd(self): img = self.imt - data = dict() + data = {} data["img"] = img identity = Identityd(keys=data.keys()) self.assertEqual(data, identity(data)) diff --git a/tests/test_nifti_dataset.py b/tests/test_image_dataset.py similarity index 87% rename from tests/test_nifti_dataset.py rename to tests/test_image_dataset.py index 801e625453..d79a7d884c 100644 --- a/tests/test_nifti_dataset.py +++ b/tests/test_image_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ import nibabel as nib import numpy as np -from monai.data import NiftiDataset +from monai.data import ImageDataset from monai.transforms import Randomizable FILENAMES = ["test1.nii.gz", "test2.nii", "test3.nii.gz"] @@ -35,7 +35,7 @@ def __call__(self, data): return data + self._a -class TestNiftiDataset(unittest.TestCase): +class TestImageDataset(unittest.TestCase): def test_dataset(self): with tempfile.TemporaryDirectory() as tempdir: full_names, ref_data = [], [] @@ -47,46 +47,46 @@ def test_dataset(self): nib.save(nib.Nifti1Image(test_image, np.eye(4)), save_path) # default loading no meta - dataset = NiftiDataset(full_names) + dataset = ImageDataset(full_names) for d, ref in zip(dataset, ref_data): np.testing.assert_allclose(d, ref, atol=1e-3) # loading no meta, int - dataset = NiftiDataset(full_names, dtype=np.float16) + dataset = ImageDataset(full_names, dtype=np.float16) for d, _ in zip(dataset, ref_data): self.assertEqual(d.dtype, np.float16) # loading with meta, no transform - dataset = NiftiDataset(full_names, image_only=False) + dataset = ImageDataset(full_names, image_only=False) for d_tuple, ref in zip(dataset, ref_data): d, meta = d_tuple np.testing.assert_allclose(d, ref, atol=1e-3) np.testing.assert_allclose(meta["original_affine"], np.eye(4)) # loading image/label, no meta - dataset = NiftiDataset(full_names, seg_files=full_names, image_only=True) + dataset = ImageDataset(full_names, seg_files=full_names, image_only=True) for d_tuple, ref in zip(dataset, ref_data): img, seg = d_tuple np.testing.assert_allclose(img, ref, atol=1e-3) np.testing.assert_allclose(seg, ref, atol=1e-3) # loading image/label, no meta - dataset = NiftiDataset(full_names, transform=lambda x: x + 1, image_only=True) + dataset = ImageDataset(full_names, transform=lambda x: x + 1, image_only=True) for d, ref in zip(dataset, ref_data): np.testing.assert_allclose(d, ref + 1, atol=1e-3) # set seg transform, but no seg_files with self.assertRaises(RuntimeError): - dataset = NiftiDataset(full_names, seg_transform=lambda x: x + 1, image_only=True) + dataset = ImageDataset(full_names, seg_transform=lambda x: x + 1, image_only=True) _ = dataset[0] # set seg transform, but no seg_files with self.assertRaises(RuntimeError): - dataset = NiftiDataset(full_names, seg_transform=lambda 
x: x + 1, image_only=True) + dataset = ImageDataset(full_names, seg_transform=lambda x: x + 1, image_only=True) _ = dataset[0] # loading image/label, with meta - dataset = NiftiDataset( + dataset = ImageDataset( full_names, transform=lambda x: x + 1, seg_files=full_names, @@ -100,7 +100,7 @@ def test_dataset(self): np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3) # loading image/label, with meta - dataset = NiftiDataset( + dataset = ImageDataset( full_names, transform=lambda x: x + 1, seg_files=full_names, labels=[1, 2, 3], image_only=False ) for idx, (d_tuple, ref) in enumerate(zip(dataset, ref_data)): @@ -111,7 +111,7 @@ def test_dataset(self): np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3) # loading image/label, with sync. transform - dataset = NiftiDataset( + dataset = ImageDataset( full_names, transform=RandTest(), seg_files=full_names, seg_transform=RandTest(), image_only=False ) for d_tuple, ref in zip(dataset, ref_data): diff --git a/tests/test_img2tensorboard.py b/tests/test_img2tensorboard.py index 99761b4d11..bd0369868e 100644 --- a/tests/test_img2tensorboard.py +++ b/tests/test_img2tensorboard.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_init_reader.py b/tests/test_init_reader.py index 87777d83a3..d6737c26ca 100644 --- a/tests/test_init_reader.py +++ b/tests/test_init_reader.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index aa0fd57f76..4be59cba41 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
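The rename of tests/test_nifti_dataset.py to tests/test_image_dataset.py above tracks the API change from NiftiDataset to the more general ImageDataset; as the hunk shows, the constructor arguments carry over unchanged. A typical combined image/segmentation setup (file names are illustrative, and the image_files keyword name is an assumption, since the tests pass it positionally):

    from monai.data import ImageDataset
    from monai.transforms import AddChannel

    ds = ImageDataset(
        image_files=["img1.nii.gz"],
        seg_files=["seg1.nii.gz"],
        transform=AddChannel(),       # applied to the image
        seg_transform=AddChannel(),   # applied to the segmentation
        image_only=False,             # also return the meta dict
    )
    img, seg, meta = ds[0]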
# You may obtain a copy of the License at @@ -21,6 +21,7 @@ import monai from monai.apps import download_and_extract from monai.metrics import compute_roc_auc +from monai.networks import eval_mode from monai.networks.nets import densenet121 from monai.transforms import AddChannel, Compose, LoadImage, RandFlip, RandRotate, RandZoom, ScaleIntensity, ToTensor from monai.utils import set_determinism @@ -79,8 +80,8 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", # start training validation best_metric = -1 best_metric_epoch = -1 - epoch_loss_values = list() - metric_values = list() + epoch_loss_values = [] + metric_values = [] model_filename = os.path.join(root_dir, "best_metric_model.pth") for epoch in range(epoch_num): print("-" * 10) @@ -102,8 +103,7 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", print(f"epoch {epoch + 1} average loss:{epoch_loss:0.4f}") if (epoch + 1) % val_interval == 0: - model.eval() - with torch.no_grad(): + with eval_mode(model): y_pred = torch.tensor([], dtype=torch.float32, device=device) y = torch.tensor([], dtype=torch.long, device=device) for val_data in val_loader: @@ -137,10 +137,9 @@ def run_inference_test(root_dir, test_x, test_y, device="cuda:0", num_workers=10 model_filename = os.path.join(root_dir, "best_metric_model.pth") model.load_state_dict(torch.load(model_filename)) - model.eval() - y_true = list() - y_pred = list() - with torch.no_grad(): + y_true = [] + y_pred = [] + with eval_mode(model): for test_data in val_loader: test_images, test_labels = test_data[0].to(device), test_data[1].to(device) pred = model(test_images).argmax(dim=1) diff --git a/tests/test_integration_determinism.py b/tests/test_integration_determinism.py index dbabc96da1..4947610484 100644 --- a/tests/test_integration_determinism.py +++ b/tests/test_integration_determinism.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index 9de7dcf362..af97236eda 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -24,6 +24,7 @@ from monai.data import NiftiSaver, create_test_image_3d from monai.inferers import sliding_window_inference from monai.metrics import DiceMetric +from monai.networks import eval_mode from monai.networks.nets import UNet from monai.transforms import ( Activations, @@ -111,8 +112,8 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0): # start a typical PyTorch training val_interval = 2 best_metric, best_metric_epoch = -1, -1 - epoch_loss_values = list() - metric_values = list() + epoch_loss_values = [] + metric_values = [] writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs")) model_filename = os.path.join(root_dir, "best_metric_model.pth") for epoch in range(6): @@ -138,8 +139,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0): print(f"epoch {epoch +1} average loss:{epoch_loss:0.4f}") if (epoch + 1) % val_interval == 0: - model.eval() - with torch.no_grad(): + with eval_mode(model): metric_sum = 0.0 metric_count = 0 val_images = None @@ -207,8 +207,7 @@ def run_inference_test(root_dir, device="cuda:0"): model_filename = os.path.join(root_dir, "best_metric_model.pth") model.load_state_dict(torch.load(model_filename)) - model.eval() - with torch.no_grad(): + with eval_mode(model): metric_sum = 0.0 metric_count = 0 # resampling with align_corners=True or dtype=float64 will generate diff --git a/tests/test_integration_sliding_window.py b/tests/test_integration_sliding_window.py index 74c6a82350..c4d020276e 100644 --- a/tests/test_integration_sliding_window.py +++ b/tests/test_integration_sliding_window.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -19,10 +19,10 @@ from ignite.engine import Engine from torch.utils.data import DataLoader -from monai.data import NiftiDataset, create_test_image_3d +from monai.data import ImageDataset, create_test_image_3d from monai.handlers import SegmentationSaver from monai.inferers import sliding_window_inference -from monai.networks import predict_segmentation +from monai.networks import eval_mode, predict_segmentation from monai.networks.nets import UNet from monai.transforms import AddChannel from monai.utils import set_determinism @@ -30,7 +30,7 @@ def run_test(batch_size, img_name, seg_name, output_dir, device="cuda:0"): - ds = NiftiDataset([img_name], [seg_name], transform=AddChannel(), seg_transform=AddChannel(), image_only=False) + ds = ImageDataset([img_name], [seg_name], transform=AddChannel(), seg_transform=AddChannel(), image_only=False) loader = DataLoader(ds, batch_size=1, pin_memory=torch.cuda.is_available()) net = UNet( @@ -40,9 +40,8 @@ def run_test(batch_size, img_name, seg_name, output_dir, device="cuda:0"): sw_batch_size = batch_size def _sliding_window_processor(_engine, batch): - net.eval() img, seg, meta_data = batch - with torch.no_grad(): + with eval_mode(net): seg_probs = sliding_window_inference(img.to(device), roi_size, sw_batch_size, net, device=device) return predict_segmentation(seg_probs) diff --git a/tests/test_integration_stn.py b/tests/test_integration_stn.py index c8759e5f42..c1fcfe7a89 100644 --- a/tests/test_integration_stn.py +++ b/tests/test_integration_stn.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_integration_unet_2d.py b/tests/test_integration_unet_2d.py index 435fb3446f..a46a174dc9 100644 --- a/tests/test_integration_unet_2d.py +++ b/tests/test_integration_unet_2d.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 8e96947ccb..aa4ccbb76d 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
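The sliding-window integration test above keeps the sliding_window_inference call itself untouched while moving ImageDataset and eval_mode into place. In isolation the inference step looks like the following; the UNet configuration and roi size are stand-ins, since the test defines its own in lines not shown here:

    import torch

    from monai.inferers import sliding_window_inference
    from monai.networks import eval_mode, predict_segmentation
    from monai.networks.nets import UNet

    net = UNet(dimensions=3, in_channels=1, out_channels=1, channels=(4, 8, 16), strides=(2, 2))
    img = torch.randn(1, 1, 64, 64, 64)
    with eval_mode(net):
        seg_probs = sliding_window_inference(img, roi_size=(32, 32, 32), sw_batch_size=4, predictor=net)
    pred = predict_segmentation(seg_probs)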
# You may obtain a copy of the License at @@ -25,7 +25,7 @@ import monai from monai.data import create_test_image_3d -from monai.engines import SupervisedEvaluator, SupervisedTrainer +from monai.engines import IterationEvents, SupervisedEvaluator, SupervisedTrainer from monai.handlers import ( CheckpointLoader, CheckpointSaver, @@ -113,6 +113,14 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), ] ) + + class _TestEvalIterEvents: + def attach(self, engine): + engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed) + + def _forward_completed(self, engine): + pass + val_handlers = [ StatsHandler(output_transform=lambda x: None), TensorBoardStatsHandler(log_dir=root_dir, output_transform=lambda x: None), @@ -120,6 +128,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): log_dir=root_dir, batch_transform=lambda x: (x["image"], x["label"]), output_transform=lambda x: x["pred"] ), CheckpointSaver(save_dir=root_dir, save_dict={"net": net}, save_key_metric=True), + _TestEvalIterEvents(), ] evaluator = SupervisedEvaluator( @@ -143,12 +152,33 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), ] ) + + class _TestTrainIterEvents: + def attach(self, engine): + engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed) + engine.add_event_handler(IterationEvents.LOSS_COMPLETED, self._loss_completed) + engine.add_event_handler(IterationEvents.BACKWARD_COMPLETED, self._backward_completed) + engine.add_event_handler(IterationEvents.OPTIMIZER_COMPLETED, self._optimizer_completed) + + def _forward_completed(self, engine): + pass + + def _loss_completed(self, engine): + pass + + def _backward_completed(self, engine): + pass + + def _optimizer_completed(self, engine): + pass + train_handlers = [ LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True), ValidationHandler(validator=evaluator, interval=2, epoch_level=True), StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]), TensorBoardStatsHandler(log_dir=root_dir, tag_name="train_loss", output_transform=lambda x: x["loss"]), CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True), + _TestTrainIterEvents(), ] trainer = SupervisedTrainer( diff --git a/tests/test_integration_workflows_gan.py b/tests/test_integration_workflows_gan.py index a4133b788f..73a9e69370 100644 --- a/tests/test_integration_workflows_gan.py +++ b/tests/test_integration_workflows_gan.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_is_supported_format.py b/tests/test_is_supported_format.py index a19752086d..c0af8f4395 100644 --- a/tests/test_is_supported_format.py +++ b/tests/test_is_supported_format.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
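test_integration_workflows.py defines small handler classes purely to verify that the engines emit the new IterationEvents: FORWARD_COMPLETED on the evaluator, plus LOSS_COMPLETED, BACKWARD_COMPLETED and OPTIMIZER_COMPLETED on the trainer. Any object with an attach(engine) method can subscribe the same way; a minimal sketch:

    from monai.engines import IterationEvents

    class ForwardLogger:
        # illustrative handler: fires after each forward pass of a MONAI engine

        def attach(self, engine):
            engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)

        def _forward_completed(self, engine):
            print("forward pass completed at iteration", engine.state.iteration)

    # pass an instance via the engines' handler lists, e.g. train_handlers=[ForwardLogger(), ...]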
# You may obtain a copy of the License at diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 56f118b8a2..7b16eaf594 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -32,7 +32,7 @@ class TestIterableDataset(unittest.TestCase): def test_shape(self): expected_shape = (128, 128, 128) test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) - test_data = list() + test_data = [] with tempfile.TemporaryDirectory() as tempdir: for i in range(6): nib.save(test_image, os.path.join(tempdir, f"test_image{str(i)}.nii.gz")) diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py index 5e6dd716b1..773ca4ad0b 100644 --- a/tests/test_keep_largest_connected_component.py +++ b/tests/test_keep_largest_connected_component.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_keep_largest_connected_componentd.py b/tests/test_keep_largest_connected_componentd.py index 3d3f749426..7298b91e4f 100644 --- a/tests/test_keep_largest_connected_componentd.py +++ b/tests/test_keep_largest_connected_componentd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_label_to_contour.py b/tests/test_label_to_contour.py index d8e27d8136..b118b91999 100644 --- a/tests/test_label_to_contour.py +++ b/tests/test_label_to_contour.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_label_to_contourd.py b/tests/test_label_to_contourd.py index f6cc51e7e8..aa4dffe03e 100644 --- a/tests/test_label_to_contourd.py +++ b/tests/test_label_to_contourd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_label_to_mask.py b/tests/test_label_to_mask.py index 9d5372bd4f..2a84c7bea6 100644 --- a/tests/test_label_to_mask.py +++ b/tests/test_label_to_mask.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/tests/test_label_to_maskd.py b/tests/test_label_to_maskd.py index e2cc9206ed..f046390c19 100644 --- a/tests/test_label_to_maskd.py +++ b/tests/test_label_to_maskd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_lambda.py b/tests/test_lambda.py index 96ebd88705..e71eb3e5b0 100644 --- a/tests/test_lambda.py +++ b/tests/test_lambda.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_lambdad.py b/tests/test_lambdad.py index fbebe081fe..ca28af778b 100644 --- a/tests/test_lambdad.py +++ b/tests/test_lambdad.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,27 +20,26 @@ class TestLambdad(NumpyImageTestCase2D): def test_lambdad_identity(self): img = self.imt - data = dict() - data["img"] = img + data = {"img": img, "prop": 1.0} - def identity_func(x): - return x + def noise_func(x): + return x + 1.0 - lambd = Lambdad(keys=data.keys(), func=identity_func) - expected = data - expected["img"] = identity_func(data["img"]) - self.assertTrue(np.allclose(expected["img"], lambd(data)["img"])) + expected = {"img": noise_func(data["img"]), "prop": 1.0} + ret = Lambdad(keys=["img", "prop"], func=noise_func, overwrite=[True, False])(data) + self.assertTrue(np.allclose(expected["img"], ret["img"])) + self.assertTrue(np.allclose(expected["prop"], ret["prop"])) def test_lambdad_slicing(self): img = self.imt - data = dict() + data = {} data["img"] = img def slice_func(x): return x[:, :, :6, ::-2] lambd = Lambdad(keys=data.keys(), func=slice_func) - expected = dict() + expected = {} expected["img"] = slice_func(data["img"]) self.assertTrue(np.allclose(expected["img"], lambd(data)["img"])) diff --git a/tests/test_list_data_collate.py b/tests/test_list_data_collate.py index 7253a43f4d..eebac69fcf 100644 --- a/tests/test_list_data_collate.py +++ b/tests/test_list_data_collate.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_list_to_dict.py b/tests/test_list_to_dict.py index 56f55a87cd..2f026f3e29 100644 --- a/tests/test_list_to_dict.py +++ b/tests/test_list_to_dict.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
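The Lambdad hunk above is more than a dict() cleanup: it switches the test to exercise the overwrite argument, which controls per key whether the function's result replaces the stored value. Condensed:

    import numpy as np

    from monai.transforms import Lambdad

    data = {"img": np.zeros((1, 4, 4)), "prop": 1.0}
    ret = Lambdad(keys=["img", "prop"], func=lambda x: x + 1.0, overwrite=[True, False])(data)
    print(ret["img"].mean())  # 1.0: the image was overwritten by func
    print(ret["prop"])        # 1.0: 'prop' keeps its original value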
# You may obtain a copy of the License at diff --git a/tests/test_lltm.py b/tests/test_lltm.py index 41a9ea55fd..f1311379bc 100644 --- a/tests/test_lltm.py +++ b/tests/test_lltm.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_lmdbdataset.py b/tests/test_lmdbdataset.py index e4d79ad4bd..90a4b4a0b4 100644 --- a/tests/test_lmdbdataset.py +++ b/tests/test_lmdbdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_load_decathlon_datalist.py b/tests/test_load_decathlon_datalist.py index a64aba6830..90b9d3ab03 100644 --- a/tests/test_load_decathlon_datalist.py +++ b/tests/test_load_decathlon_datalist.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 272a528f81..b7743f86ad 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -124,7 +124,7 @@ def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape): [0.0, 0.0, 0.0, 1.0], ] ), - ), + ) self.assertTupleEqual(result.shape, expected_shape) self.assertTupleEqual(tuple(header["spatial_shape"]), expected_shape) diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py index ef733cac2f..978c3b6551 100644 --- a/tests/test_load_imaged.py +++ b/tests/test_load_imaged.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -32,7 +32,7 @@ class TestLoadImaged(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, expected_shape): test_image = nib.Nifti1Image(np.random.rand(128, 128, 128), np.eye(4)) - test_data = dict() + test_data = {} with tempfile.TemporaryDirectory() as tempdir: for key in KEYS: nib.save(test_image, os.path.join(tempdir, key + ".nii.gz")) diff --git a/tests/test_load_nifti.py b/tests/test_load_nifti.py deleted file mode 100644 index de3cac996e..0000000000 --- a/tests/test_load_nifti.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
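A subtle fix above in test_load_image.py: the assertion line previously ended with `),` after np.testing.assert_allclose(...), which builds and discards a one-element tuple around the call's return value; dropping the comma turns it back into a plain statement. The pitfall in isolation:

    value = (1 + 1),   # trailing comma: value is the tuple (2,), not the int 2
    assert value == (2,)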
-# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest - -import nibabel as nib -import numpy as np -from parameterized import parameterized - -from monai.transforms import LoadNifti - -TEST_CASE_1 = [{"as_closest_canonical": False, "image_only": True}, ["test_image.nii.gz"], (128, 128, 128)] - -TEST_CASE_2 = [{"as_closest_canonical": False, "image_only": False}, ["test_image.nii.gz"], (128, 128, 128)] - -TEST_CASE_3 = [ - {"as_closest_canonical": False, "image_only": True}, - ["test_image1.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], - (3, 128, 128, 128), -] - -TEST_CASE_4 = [ - {"as_closest_canonical": False, "image_only": False}, - ["test_image1.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], - (3, 128, 128, 128), -] - -TEST_CASE_5 = [{"as_closest_canonical": True, "image_only": False}, ["test_image.nii.gz"], (128, 128, 128)] - - -class TestLoadNifti(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_shape(self, input_param, filenames, expected_shape): - test_image = np.random.randint(0, 2, size=[128, 128, 128]) - with tempfile.TemporaryDirectory() as tempdir: - for i, name in enumerate(filenames): - filenames[i] = os.path.join(tempdir, name) - nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) - result = LoadNifti(**input_param)(filenames) - - if isinstance(result, tuple): - result, header = result - self.assertTrue("affine" in header) - np.testing.assert_allclose(header["affine"], np.eye(4)) - if input_param["as_closest_canonical"]: - np.testing.assert_allclose(header["original_affine"], np.eye(4)) - self.assertTupleEqual(result.shape, expected_shape) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_load_niftid.py b/tests/test_load_niftid.py deleted file mode 100644 index 54d816bead..0000000000 --- a/tests/test_load_niftid.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import tempfile -import unittest - -import nibabel as nib -import numpy as np -from parameterized import parameterized - -from monai.transforms import LoadImaged - -KEYS = ["image", "label", "extra"] - -TEST_CASE_1 = [{"keys": KEYS, "as_closest_canonical": False}, (128, 128, 128)] - - -class TestLoadNiftid(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) - def test_shape(self, input_param, expected_shape): - test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) - test_data = dict() - with tempfile.TemporaryDirectory() as tempdir: - for key in KEYS: - nib.save(test_image, os.path.join(tempdir, key + ".nii.gz")) - test_data.update({key: os.path.join(tempdir, key + ".nii.gz")}) - result = LoadImaged(**input_param)(test_data) - - for key in KEYS: - self.assertTupleEqual(result[key].shape, expected_shape) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_load_numpy.py b/tests/test_load_numpy.py deleted file mode 100644 index d65087531b..0000000000 --- a/tests/test_load_numpy.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest - -import numpy as np - -from monai.transforms import LoadNumpy - - -class TestLoadNumpy(unittest.TestCase): - def test_npy(self): - test_data = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data) - - result = LoadNumpy()(filepath) - self.assertTupleEqual(result[1]["spatial_shape"], test_data.shape) - self.assertTupleEqual(result[0].shape, test_data.shape) - np.testing.assert_allclose(result[0], test_data) - - def test_npz1(self): - test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data1) - - result = LoadNumpy()(filepath) - self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result[0].shape, test_data1.shape) - np.testing.assert_allclose(result[0], test_data1) - - def test_npz2(self): - test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) - test_data2 = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npz") - np.savez(filepath, test_data1, test_data2) - - result = LoadNumpy()(filepath) - self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result[0].shape, (2, 3, 4, 4)) - np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2])) - - def test_npz3(self): - test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) - test_data2 = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npz") - np.savez(filepath, test1=test_data1, test2=test_data2) - - result = LoadNumpy(npz_keys=["test1", 
"test2"])(filepath) - self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result[0].shape, (2, 3, 4, 4)) - np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2])) - - def test_npy_pickle(self): - test_data = {"test": np.random.randint(0, 256, size=[3, 4, 4])} - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data, allow_pickle=True) - - result = LoadNumpy(data_only=True, dtype=None)(filepath).item() - self.assertTupleEqual(result["test"].shape, test_data["test"].shape) - np.testing.assert_allclose(result["test"], test_data["test"]) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_load_numpyd.py b/tests/test_load_numpyd.py deleted file mode 100644 index 9abe0b0daf..0000000000 --- a/tests/test_load_numpyd.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest - -import numpy as np - -from monai.transforms import LoadNumpyd - - -class TestLoadNumpyd(unittest.TestCase): - def test_npy(self): - test_data = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data) - - result = LoadNumpyd(keys="mask")({"mask": filepath}) - self.assertTupleEqual(result["mask_meta_dict"]["spatial_shape"], test_data.shape) - self.assertTupleEqual(result["mask"].shape, test_data.shape) - np.testing.assert_allclose(result["mask"], test_data) - - def test_npz1(self): - test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data1) - - result = LoadNumpyd(keys="mask")({"mask": filepath}) - self.assertTupleEqual(result["mask_meta_dict"]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result["mask"].shape, test_data1.shape) - np.testing.assert_allclose(result["mask"], test_data1) - - def test_npz2(self): - test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) - test_data2 = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npz") - np.savez(filepath, test_data1, test_data2) - - result = LoadNumpyd(keys="mask")({"mask": filepath}) - self.assertTupleEqual(result["mask_meta_dict"]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result["mask"].shape, (2, 3, 4, 4)) - np.testing.assert_allclose(result["mask"], np.stack([test_data1, test_data2])) - - def test_npz3(self): - test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) - test_data2 = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npz") - np.savez(filepath, test1=test_data1, test2=test_data2) - - result = LoadNumpyd(keys="mask", npz_keys=["test1", "test2"])({"mask": filepath}) - 
self.assertTupleEqual(result["mask_meta_dict"]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result["mask"].shape, (2, 3, 4, 4)) - np.testing.assert_allclose(result["mask"], np.stack([test_data1, test_data2])) - - def test_npy_pickle(self): - test_data = {"test": np.random.randint(0, 256, size=[3, 4, 4])} - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data, allow_pickle=True) - - result = LoadNumpyd(keys="mask", dtype=None)({"mask": filepath})["mask"].item() - self.assertTupleEqual(result["test"].shape, test_data["test"].shape) - np.testing.assert_allclose(result["test"], test_data["test"]) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_load_png.py b/tests/test_load_png.py deleted file mode 100644 index 929ee1536d..0000000000 --- a/tests/test_load_png.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest - -import numpy as np -from parameterized import parameterized -from PIL import Image - -from monai.transforms import LoadPNG - -TEST_CASE_1 = [(128, 128), ["test_image.png"], (128, 128), (128, 128)] - -TEST_CASE_2 = [(128, 128, 3), ["test_image.png"], (128, 128, 3), (128, 128)] - -TEST_CASE_3 = [(128, 128), ["test_image1.png", "test_image2.png", "test_image3.png"], (3, 128, 128), (128, 128)] - - -class TestLoadPNG(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, data_shape, filenames, expected_shape, meta_shape): - test_image = np.random.randint(0, 256, size=data_shape) - with tempfile.TemporaryDirectory() as tempdir: - for i, name in enumerate(filenames): - filenames[i] = os.path.join(tempdir, name) - Image.fromarray(test_image.astype("uint8")).save(filenames[i]) - result = LoadPNG()(filenames) - self.assertTupleEqual(result[1]["spatial_shape"], meta_shape) - self.assertTupleEqual(result[0].shape, expected_shape) - if result[0].shape == test_image.shape: - np.testing.assert_allclose(result[0], test_image) - else: - np.testing.assert_allclose(result[0], np.tile(test_image, [result[0].shape[0], 1, 1])) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_load_pngd.py b/tests/test_load_pngd.py deleted file mode 100644 index 6be3197d8f..0000000000 --- a/tests/test_load_pngd.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import tempfile -import unittest - -import numpy as np -from parameterized import parameterized -from PIL import Image - -from monai.transforms import LoadPNGd - -KEYS = ["image", "label", "extra"] - -TEST_CASE_1 = [{"keys": KEYS}, (128, 128, 3)] - - -class TestLoadPNGd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) - def test_shape(self, input_param, expected_shape): - test_image = np.random.randint(0, 256, size=[128, 128, 3]) - with tempfile.TemporaryDirectory() as tempdir: - test_data = dict() - for key in KEYS: - Image.fromarray(test_image.astype("uint8")).save(os.path.join(tempdir, key + ".png")) - test_data.update({key: os.path.join(tempdir, key + ".png")}) - result = LoadPNGd(**input_param)(test_data) - for key in KEYS: - self.assertTupleEqual(result[key].shape, expected_shape) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_load_spacing_orientation.py b/tests/test_load_spacing_orientation.py index 5ae9cc4326..48aac7ec56 100644 --- a/tests/test_load_spacing_orientation.py +++ b/tests/test_load_spacing_orientation.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py new file mode 100644 index 0000000000..cf8566a559 --- /dev/null +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -0,0 +1,114 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
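The tests that follow pin down the new LocalNormalizedCrossCorrelationLoss: perfectly correlated inputs score -1 per element, so reduction="sum" over three voxels yields -3.0 while the default mean reduction yields -1.0. Basic usage, taken from the constructor arguments exercised below:

    import torch

    from monai.losses.image_dissimilarity import LocalNormalizedCrossCorrelationLoss

    loss = LocalNormalizedCrossCorrelationLoss(in_channels=1, ndim=1, kernel_type="rectangular")
    pred = torch.arange(0, 3).reshape(1, 1, -1).to(torch.float)
    target = pred.clone()
    print(loss(pred, target))  # ~ -1.0: identical signals are maximally correlated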
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses.image_dissimilarity import LocalNormalizedCrossCorrelationLoss + +TEST_CASES = [ + [ + {"in_channels": 1, "ndim": 1, "kernel_type": "rectangular", "reduction": "sum"}, + { + "pred": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), + }, + -1.0 * 3, + ], + [ + {"in_channels": 1, "ndim": 1, "kernel_type": "rectangular"}, + { + "pred": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), + }, + -1.0, + ], + [ + {"in_channels": 1, "ndim": 2, "kernel_type": "rectangular"}, + { + "pred": torch.arange(0, 3).reshape(1, 1, -1, 1).expand(1, 1, 3, 3).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1, 1).expand(1, 1, 3, 3).to(torch.float), + }, + -1.0, + ], + [ + {"in_channels": 1, "ndim": 3, "kernel_type": "rectangular"}, + { + "pred": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 1, 3, 3, 3).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 1, 3, 3, 3).to(torch.float), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 3, "kernel_type": "rectangular"}, + { + "pred": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float) ** 2, + }, + -0.95801723, + ], + [ + {"in_channels": 3, "ndim": 3, "kernel_type": "triangular", "kernel_size": 5}, + { + "pred": torch.arange(0, 5).reshape(1, 1, -1, 1, 1).expand(1, 3, 5, 5, 5).to(torch.float), + "target": torch.arange(0, 5).reshape(1, 1, -1, 1, 1).expand(1, 3, 5, 5, 5).to(torch.float) ** 2, + }, + -0.918672, + ], + [ + {"in_channels": 3, "ndim": 3, "kernel_type": "gaussian"}, + { + "pred": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float) ** 2, + }, + -0.95406944, + ], +] + + +class TestLocalNormalizedCrossCorrelationLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = LocalNormalizedCrossCorrelationLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) + + def test_ill_shape(self): + loss = LocalNormalizedCrossCorrelationLoss(in_channels=3, ndim=3) + # in_channel unmatch + with self.assertRaisesRegex(ValueError, ""): + loss.forward(torch.ones((1, 2, 3, 3, 3), dtype=torch.float), torch.ones((1, 2, 3, 3, 3), dtype=torch.float)) + # ndim unmatch + with self.assertRaisesRegex(ValueError, ""): + loss.forward(torch.ones((1, 3, 3, 3), dtype=torch.float), torch.ones((1, 3, 3, 3), dtype=torch.float)) + # pred, target shape unmatch + with self.assertRaisesRegex(ValueError, ""): + loss.forward(torch.ones((1, 3, 3, 3, 3), dtype=torch.float), torch.ones((1, 3, 4, 4, 4), dtype=torch.float)) + + def test_ill_opts(self): + pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_type="unknown")(pred, target) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_type=None)(pred, target) + with self.assertRaisesRegex(ValueError, ""): + 
LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_size=4)(pred, target) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction="unknown")(pred, target) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction=None)(pred, target) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_localnet.py b/tests/test_localnet.py new file mode 100644 index 0000000000..97a10d0c83 --- /dev/null +++ b/tests/test_localnet.py @@ -0,0 +1,76 @@ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.localnet import LocalNet +from tests.utils import test_script_save + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +TEST_CASE_LOCALNET_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 2, + "out_channels": 2, + "num_channel_initial": 16, + "extract_levels": [0, 1, 2], + "out_activation": act, + }, + (1, 2, 16, 16), + (1, 2, 16, 16), + ] + for act in ["sigmoid", None] +] + +TEST_CASE_LOCALNET_3D = [] +for in_channels in [2, 3]: + for out_channels in [1, 3]: + for num_channel_initial in [4, 16, 32]: + for extract_levels in [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]: + for out_activation in ["sigmoid", None]: + for out_initializer in ["kaiming_uniform", "zeros"]: + TEST_CASE_LOCALNET_3D.append( + [ + { + "spatial_dims": 3, + "in_channels": in_channels, + "out_channels": out_channels, + "num_channel_initial": num_channel_initial, + "extract_levels": extract_levels, + "out_activation": out_activation, + "out_initializer": out_initializer, + }, + (1, in_channels, 16, 16, 16), + (1, out_channels, 16, 16, 16), + ] + ) + + +class TestLocalNet(unittest.TestCase): + @parameterized.expand(TEST_CASE_LOCALNET_2D + TEST_CASE_LOCALNET_3D) + def test_shape(self, input_param, input_shape, expected_shape): + net = LocalNet(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_shape(self): + with self.assertRaisesRegex(ValueError, ""): + input_param, _, _ = TEST_CASE_LOCALNET_2D[0] + input_shape = (1, input_param["in_channels"], 17, 17) + net = LocalNet(**input_param).to(device) + net.forward(torch.randn(input_shape).to(device)) + + def test_script(self): + input_param, input_shape, _ = TEST_CASE_LOCALNET_2D[0] + net = LocalNet(**input_param) + test_data = torch.randn(input_shape) + test_script_save(net, test_data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py new file mode 100644 index 0000000000..e6171aeae9 --- /dev/null +++ b/tests/test_localnet_block.py @@ -0,0 +1,95 @@ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks.localnet_block import ( + LocalNetDownSampleBlock, + LocalNetFeatureExtractorBlock, + LocalNetUpSampleBlock, +) + +TEST_CASE_DOWN_SAMPLE = [ + [{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 4, "kernel_size": 3}] for spatial_dims in [2, 3] +] + +TEST_CASE_UP_SAMPLE = [[{"spatial_dims": spatial_dims, "in_channels": 4, "out_channels": 2}] for spatial_dims in [2, 3]] + +TEST_CASE_EXTRACT = [ + [{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 3, "act": act, "initializer": initializer}] + for spatial_dims, act, initializer in zip([2, 3], ["sigmoid", None], 
["kaiming_uniform", "zeros"]) +] + +in_size = 4 + + +class TestLocalNetDownSampleBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_DOWN_SAMPLE) + def test_shape(self, input_param): + net = LocalNetDownSampleBlock(**input_param) + input_shape = (1, input_param["in_channels"], *([in_size] * input_param["spatial_dims"])) + expect_mid_shape = (1, input_param["out_channels"], *([in_size] * input_param["spatial_dims"])) + expect_x_shape = (1, input_param["out_channels"], *([in_size / 2] * input_param["spatial_dims"])) + with eval_mode(net): + x, mid = net(torch.randn(input_shape)) + self.assertEqual(x.shape, expect_x_shape) + self.assertEqual(mid.shape, expect_mid_shape) + + def test_ill_arg(self): + # even kernel_size + with self.assertRaises(NotImplementedError): + LocalNetDownSampleBlock(spatial_dims=2, in_channels=2, out_channels=4, kernel_size=4) + + @parameterized.expand(TEST_CASE_DOWN_SAMPLE) + def test_ill_shape(self, input_param): + net = LocalNetDownSampleBlock(**input_param) + input_shape = (1, input_param["in_channels"], *([5] * input_param["spatial_dims"])) + with self.assertRaises(ValueError): + with eval_mode(net): + net(torch.randn(input_shape)) + + +class TestLocalNetUpSampleBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_UP_SAMPLE) + def test_shape(self, input_param): + net = LocalNetUpSampleBlock(**input_param) + input_shape = (1, input_param["in_channels"], *([in_size] * input_param["spatial_dims"])) + mid_shape = (1, input_param["out_channels"], *([in_size * 2] * input_param["spatial_dims"])) + expected_shape = mid_shape + with eval_mode(net): + result = net(torch.randn(input_shape), torch.randn(mid_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + # channel unmatch + with self.assertRaises(ValueError): + LocalNetUpSampleBlock(spatial_dims=2, in_channels=2, out_channels=2) + + @parameterized.expand(TEST_CASE_UP_SAMPLE) + def test_ill_shape(self, input_param): + net = LocalNetUpSampleBlock(**input_param) + input_shape = (1, input_param["in_channels"], *([in_size] * input_param["spatial_dims"])) + mid_shape = (1, input_param["out_channels"], *([in_size] * input_param["spatial_dims"])) + with self.assertRaises(ValueError): + with eval_mode(net): + net(torch.randn(input_shape), torch.randn(mid_shape)) + + +class TestExtractBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_EXTRACT) + def test_shape(self, input_param): + net = LocalNetFeatureExtractorBlock(**input_param) + input_shape = (1, input_param["in_channels"], *([in_size] * input_param["spatial_dims"])) + expected_shape = (1, input_param["out_channels"], *([in_size] * input_param["spatial_dims"])) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + LocalNetFeatureExtractorBlock(spatial_dims=2, in_channels=2, out_channels=2, initializer="none") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py new file mode 100644 index 0000000000..9ee9c8a4d0 --- /dev/null +++ b/tests/test_lr_finder.py @@ -0,0 +1,81 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import sys +import unittest + +import torch +from torch.utils.data import DataLoader + +from monai.apps import MedNISTDataset +from monai.networks.nets import DenseNet +from monai.optimizers import LearningRateFinder +from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord +from monai.utils import optional_import, set_determinism + +PILImage, has_pil = optional_import("PIL.Image") + +RAND_SEED = 42 +random.seed(RAND_SEED) +set_determinism(seed=RAND_SEED) + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +@unittest.skipUnless(sys.platform == "linux", "requires linux") +@unittest.skipUnless(has_pil, "requires PIL") +class TestLRFinder(unittest.TestCase): + def setUp(self): + + self.root_dir = os.environ.get("MONAI_DATA_DIRECTORY") + if not self.root_dir: + self.root_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") + + self.transforms = Compose( + [ + LoadImaged(keys="image"), + AddChanneld(keys="image"), + ScaleIntensityd(keys="image"), + ToTensord(keys="image"), + ] + ) + + def test_lr_finder(self): + # 0.001 gives 54 examples + train_ds = MedNISTDataset( + root_dir=self.root_dir, + transform=self.transforms, + section="validation", + val_frac=0.001, + download=True, + num_workers=10, + ) + train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10) + num_classes = train_ds.get_num_classes() + + model = DenseNet( + spatial_dims=2, in_channels=1, out_channels=num_classes, init_features=2, growth_rate=2, block_config=(2,) + ) + loss_function = torch.nn.CrossEntropyLoss() + learning_rate = 1e-5 + optimizer = torch.optim.Adam(model.parameters(), learning_rate) + + lr_finder = LearningRateFinder(model, optimizer, loss_function, device=device) + lr_finder.range_test(train_loader, val_loader=train_loader, end_lr=10, num_iter=5) + print(lr_finder.get_steepest_gradient(0, 0)[0]) + lr_finder.plot(0, 0) # to inspect the loss-learning rate graph + lr_finder.reset() # to reset the model and optimizer to their initial state + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_map_binary_to_indices.py b/tests/test_map_binary_to_indices.py index 394aa7efec..1fafa6f446 100644 --- a/tests/test_map_binary_to_indices.py +++ b/tests/test_map_binary_to_indices.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_map_transform.py b/tests/test_map_transform.py index 22118202eb..803e699a7d 100644 --- a/tests/test_map_transform.py +++ b/tests/test_map_transform.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
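# The lr_finder.range_test call above performs the "LR range test": the
# learning rate is swept from the optimizer's value up to end_lr over num_iter
# steps while the loss is recorded, and a rate near the steepest loss descent
# is suggested. A rough sketch of that selection rule (hypothetical helper,
# not MONAI's API):
import numpy as np

def steepest_gradient_lr(lrs, losses):
    """Return the learning rate at which the recorded loss drops fastest."""
    grads = np.gradient(np.asarray(losses), np.log10(np.asarray(lrs)))
    return lrs[int(np.argmin(grads))]

print(steepest_gradient_lr([1e-5, 1e-4, 1e-3, 1e-2], [1.0, 0.9, 0.5, 0.8]))  # 0.0001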
# You may obtain a copy of the License at @@ -17,7 +17,7 @@ TEST_CASES = [["item", ("item",)], [None, (None,)], [["item1", "item2"], ("item1", "item2")]] -TEST_ILL_CASES = [[ValueError, list()], [ValueError, tuple()], [TypeError, [list()]]] +TEST_ILL_CASES = [[ValueError, []], [ValueError, ()], [TypeError, [[]]]] class MapTest(MapTransform): diff --git a/tests/test_mask_intensity.py b/tests/test_mask_intensity.py index df9d77aadd..3131abe8bf 100644 --- a/tests/test_mask_intensity.py +++ b/tests/test_mask_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_mask_intensityd.py b/tests/test_mask_intensityd.py index f7527795b6..0d08952db2 100644 --- a/tests/test_mask_intensityd.py +++ b/tests/test_mask_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,9 +34,18 @@ np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), ] +TEST_CASE_4 = [ + {"keys": "img", "mask_key": "mask"}, + { + "img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + "mask": np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]]), + }, + np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), +] + class TestMaskIntensityd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_value(self, argments, image, expected_data): result = MaskIntensityd(**argments)(image) np.testing.assert_allclose(result["img"], expected_data) diff --git a/tests/test_masked_dice_loss.py b/tests/test_masked_dice_loss.py index 3ea3151bd7..b8d69bc8f9 100644 --- a/tests/test_masked_dice_loss.py +++ b/tests/test_masked_dice_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_mean_ensemble.py b/tests/test_mean_ensemble.py index 7eb4b5a13c..32a6856263 100644 --- a/tests/test_mean_ensemble.py +++ b/tests/test_mean_ensemble.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_mean_ensembled.py b/tests/test_mean_ensembled.py index b26449fc85..c7549e5aa4 100644 --- a/tests/test_mean_ensembled.py +++ b/tests/test_mean_ensembled.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
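# TEST_CASE_4 above exercises the mask_key path of MaskIntensityd: the mask is
# read from the data dict at call time instead of being fixed at construction.
# Either way the operation reduces to an element-wise select, roughly (sketch):
import numpy as np

img = np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
mask = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]])
print(img * (mask > 0))  # [[0 0 0] [0 2 0] [0 0 0]]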
# You may obtain a copy of the License at diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index 9b9f8a75b1..0887734a7c 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -52,6 +52,7 @@ def _test_dataset(dataset): # testing from data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False) + data.get_num_classes() _test_dataset(data) data = MedNISTDataset(root_dir=testing_dir, section="test", download=False) self.assertTupleEqual(data[0]["image"].shape, (64, 64)) diff --git a/tests/test_nifti_header_revise.py b/tests/test_nifti_header_revise.py index 5998217614..8d9a1d4f3a 100644 --- a/tests/test_nifti_header_revise.py +++ b/tests/test_nifti_header_revise.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py index 9a2f8ac75c..7bfa10c6c5 100644 --- a/tests/test_nifti_rw.py +++ b/tests/test_nifti_rw.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,7 +18,7 @@ from parameterized import parameterized from monai.data import write_nifti -from monai.transforms import LoadNifti, Orientation, Spacing +from monai.transforms import LoadImage, Orientation, Spacing from tests.utils import make_nifti_image TEST_IMAGE = np.arange(24).reshape((2, 4, 3)) @@ -27,11 +27,16 @@ ) TEST_CASES = [ - [TEST_IMAGE, TEST_AFFINE, dict(as_closest_canonical=True, image_only=False), np.arange(24).reshape((2, 4, 3))], [ TEST_IMAGE, TEST_AFFINE, - dict(as_closest_canonical=True, image_only=True), + dict(reader="NibabelReader", image_only=False, as_closest_canonical=True), + np.arange(24).reshape((2, 4, 3)), + ], + [ + TEST_IMAGE, + TEST_AFFINE, + dict(reader="NibabelReader", image_only=True, as_closest_canonical=True), np.array( [ [[12.0, 15.0, 18.0, 21.0], [13.0, 16.0, 19.0, 22.0], [14.0, 17.0, 20.0, 23.0]], @@ -39,9 +44,24 @@ ] ), ], - [TEST_IMAGE, TEST_AFFINE, dict(as_closest_canonical=False, image_only=True), np.arange(24).reshape((2, 4, 3))], - [TEST_IMAGE, TEST_AFFINE, dict(as_closest_canonical=False, image_only=False), np.arange(24).reshape((2, 4, 3))], - [TEST_IMAGE, None, dict(as_closest_canonical=False, image_only=False), np.arange(24).reshape((2, 4, 3))], + [ + TEST_IMAGE, + TEST_AFFINE, + dict(reader="NibabelReader", image_only=True, as_closest_canonical=False), + np.arange(24).reshape((2, 4, 3)), + ], + [ + TEST_IMAGE, + TEST_AFFINE, + dict(reader="NibabelReader", image_only=False, as_closest_canonical=False), + np.arange(24).reshape((2, 4, 3)), + ], + [ + TEST_IMAGE, + None, + dict(reader="NibabelReader", image_only=False, as_closest_canonical=False), + np.arange(24).reshape((2, 4, 3)), + ], ] @@ -51,7 +71,7 @@ def test_orientation(self, array, affine, reader_param, expected): test_image = make_nifti_image(array, affine) # read test cases - loader = 
LoadNifti(**reader_param) + loader = LoadImage(**reader_param) load_result = loader(test_image) if isinstance(load_result, tuple): data_array, header = load_result @@ -79,7 +99,7 @@ def test_orientation(self, array, affine, reader_param, expected): def test_consistency(self): np.set_printoptions(suppress=True, precision=3) test_image = make_nifti_image(np.arange(64).reshape(1, 8, 8), np.diag([1.5, 1.5, 1.5, 1])) - data, header = LoadNifti(as_closest_canonical=False)(test_image) + data, header = LoadImage(reader="NibabelReader", as_closest_canonical=False)(test_image) data, original_affine, new_affine = Spacing([0.8, 0.8, 0.8])(data[None], header["affine"], mode="nearest") data, _, new_affine = Orientation("ILP")(data, new_affine) if os.path.exists(test_image): diff --git a/tests/test_nifti_saver.py b/tests/test_nifti_saver.py index ef07007fad..2e2bfd4254 100644 --- a/tests/test_nifti_saver.py +++ b/tests/test_nifti_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_normalize_intensity.py b/tests/test_normalize_intensity.py index 06768f77b7..ecf162e12f 100644 --- a/tests/test_normalize_intensity.py +++ b/tests/test_normalize_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -61,7 +61,7 @@ def test_default(self): normalized = normalizer(self.imt) self.assertTrue(normalized.dtype == np.float32) expected = (self.imt - np.mean(self.imt)) / np.std(self.imt) - np.testing.assert_allclose(normalized, expected, rtol=1e-6) + np.testing.assert_allclose(normalized, expected, rtol=1e-5) @parameterized.expand(TEST_CASES) def test_nonzero(self, input_param, input_data, expected_data): diff --git a/tests/test_normalize_intensityd.py b/tests/test_normalize_intensityd.py index cc2241ac5d..a3a1eb518c 100644 --- a/tests/test_normalize_intensityd.py +++ b/tests/test_normalize_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_numpy_reader.py b/tests/test_numpy_reader.py index 6d7589a368..a57a036905 100644 --- a/tests/test_numpy_reader.py +++ b/tests/test_numpy_reader.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_occlusion_sensitivity.py b/tests/test_occlusion_sensitivity.py new file mode 100644 index 0000000000..47a13d01e1 --- /dev/null +++ b/tests/test_occlusion_sensitivity.py @@ -0,0 +1,96 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
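# The test_nifti_rw hunks above migrate from the NIfTI-specific LoadNifti to
# the unified LoadImage transform, selecting the backend explicitly via the
# reader argument; extra keyword arguments such as as_closest_canonical are
# forwarded to that reader. Typical call (sketch; "image.nii.gz" is a
# placeholder path):
from monai.transforms import LoadImage

loader = LoadImage(reader="NibabelReader", image_only=False, as_closest_canonical=False)
data, meta = loader("image.nii.gz")  # ndarray plus a metadata dict
print(meta["affine"])  # consumed by Spacing/Orientation as in test_consistency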
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.nets import DenseNet, densenet121 +from monai.visualize import OcclusionSensitivity + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +out_channels_2d = 4 +out_channels_3d = 3 +model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=out_channels_2d).to(device) +model_3d = DenseNet( + spatial_dims=3, in_channels=1, out_channels=out_channels_3d, init_features=2, growth_rate=2, block_config=(6,) +).to(device) +model_2d.eval() +model_3d.eval() + +# 2D w/ bounding box +TEST_CASE_0 = [ + { + "nn_module": model_2d, + }, + { + "x": torch.rand(1, 1, 48, 64).to(device), + "b_box": [-1, -1, 2, 40, 1, 62], + }, + (1, 1, 39, 62, out_channels_2d), + (1, 1, 39, 62), +] +# 3D w/ bounding box and stride +TEST_CASE_1 = [ + {"nn_module": model_3d, "n_batch": 10, "stride": (2, 1, 2), "mask_size": (16, 15, 14)}, + { + "x": torch.rand(1, 1, 6, 6, 6).to(device), + "b_box": [-1, -1, 2, 3, -1, -1, -1, -1], + }, + (1, 1, 2, 6, 6, out_channels_3d), + (1, 1, 2, 6, 6), +] + +TEST_CASE_FAIL_0 = [ # 2D should fail, since 3 stride values given + { + "nn_module": model_2d, + "n_batch": 10, + "stride": (2, 2, 2), + }, + { + "x": torch.rand(1, 1, 48, 64).to(device), + "b_box": [-1, -1, 2, 3, -1, -1], + }, +] + +TEST_CASE_FAIL_1 = [ # 2D should fail, since stride is not a factor of image size + { + "nn_module": model_2d, + "stride": 3, + }, + { + "x": torch.rand(1, 1, 48, 64).to(device), + }, +] + + +class TestComputeOcclusionSensitivity(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + def test_shape(self, init_data, call_data, map_expected_shape, most_prob_expected_shape): + occ_sens = OcclusionSensitivity(**init_data) + m, most_prob = occ_sens(**call_data) + self.assertTupleEqual(m.shape, map_expected_shape) + self.assertTupleEqual(most_prob.shape, most_prob_expected_shape) + # most probable class should be of type int, and should have min >= 0 and max < the number of classes + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_repeat_channel.py b/tests/test_repeat_channel.py index a89b1375f2..643ebc64de 100644 --- a/tests/test_repeat_channel.py +++ b/tests/test_repeat_channel.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_repeat_channeld.py b/tests/test_repeat_channeld.py index 73d446f00e..7bd58bd1fe 100644 --- a/tests/test_repeat_channeld.py +++ b/tests/test_repeat_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
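# Where (1, 1, 39, 62, out_channels_2d) in TEST_CASE_0 above comes from: b_box
# holds (min, max) pairs per dimension with -1 meaning "use the full extent",
# and, as the expected shapes suggest, the returned sensitivity map spans
# max - min + 1 positions along each cropped spatial dimension (sketch):
b_box = [-1, -1, 2, 40, 1, 62]  # batch/channel pair first, then spatial pairs
spatial_pairs = list(zip(b_box[2::2], b_box[3::2]))
print([hi - lo + 1 for lo, hi in spatial_pairs])  # [39, 62]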
# You may obtain a copy of the License at diff --git a/tests/test_resampler.py b/tests/test_resampler.py index 5f3383aeaa..a4536967fa 100644 --- a/tests/test_resampler.py +++ b/tests/test_resampler.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_resize.py b/tests/test_resize.py index a4cb1a8d85..22a68bcf85 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -39,7 +39,7 @@ def test_correct_results(self, spatial_size, mode): _order = 1 if spatial_size == (32, -1): spatial_size = (32, 64) - expected = list() + expected = [] for channel in self.imt[0]: expected.append( skimage.transform.resize( diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py index d6a9dadfca..53fb0d3002 100644 --- a/tests/test_resize_with_pad_or_crop.py +++ b/tests/test_resize_with_pad_or_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py index 008b1e2d17..8cbb31b5a6 100644 --- a/tests/test_resize_with_pad_or_cropd.py +++ b/tests/test_resize_with_pad_or_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_resized.py b/tests/test_resized.py index 81a37a65a4..d89c866af3 100644 --- a/tests/test_resized.py +++ b/tests/test_resized.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -37,7 +37,7 @@ def test_correct_results(self, spatial_size, mode): _order = 1 if spatial_size == (32, -1): spatial_size = (32, 64) - expected = list() + expected = [] for channel in self.imt[0]: expected.append( skimage.transform.resize( diff --git a/tests/test_rotate.py b/tests/test_rotate.py index 837f11fe7e..6e43ab90e7 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -56,7 +56,7 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne else: _mode = "constant" - expected = list() + expected = [] for channel in self.imt[0]: expected.append( scipy.ndimage.rotate( @@ -88,7 +88,7 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne else: _mode = "constant" - expected = list() + expected = [] for channel in self.imt[0]: expected.append( scipy.ndimage.rotate( diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index e4eafcc88c..a8b4e3f57c 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,7 +21,7 @@ class TestRotate90(NumpyImageTestCase2D): def test_rotate90_default(self): rotate = Rotate90() rotated = rotate(self.imt[0]) - expected = list() + expected = [] for channel in self.imt[0]: expected.append(np.rot90(channel, 1, (0, 1))) expected = np.stack(expected) @@ -30,7 +30,7 @@ def test_rotate90_default(self): def test_k(self): rotate = Rotate90(k=2) rotated = rotate(self.imt[0]) - expected = list() + expected = [] for channel in self.imt[0]: expected.append(np.rot90(channel, 2, (0, 1))) expected = np.stack(expected) @@ -39,7 +39,7 @@ def test_k(self): def test_spatial_axes(self): rotate = Rotate90(spatial_axes=(0, 1)) rotated = rotate(self.imt[0]) - expected = list() + expected = [] for channel in self.imt[0]: expected.append(np.rot90(channel, 1, (0, 1))) expected = np.stack(expected) @@ -48,7 +48,7 @@ def test_spatial_axes(self): def test_prob_k_spatial_axes(self): rotate = Rotate90(k=2, spatial_axes=(0, 1)) rotated = rotate(self.imt[0]) - expected = list() + expected = [] for channel in self.imt[0]: expected.append(np.rot90(channel, 2, (0, 1))) expected = np.stack(expected) diff --git a/tests/test_rotate90d.py b/tests/test_rotate90d.py index 1f85d2fc0a..3d71ead82a 100644 --- a/tests/test_rotate90d.py +++ b/tests/test_rotate90d.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -22,7 +22,7 @@ def test_rotate90_default(self): key = "test" rotate = Rotate90d(keys=key) rotated = rotate({key: self.imt[0]}) - expected = list() + expected = [] for channel in self.imt[0]: expected.append(np.rot90(channel, 1, (0, 1))) expected = np.stack(expected) @@ -32,7 +32,7 @@ def test_k(self): key = None rotate = Rotate90d(keys=key, k=2) rotated = rotate({key: self.imt[0]}) - expected = list() + expected = [] for channel in self.imt[0]: expected.append(np.rot90(channel, 2, (0, 1))) expected = np.stack(expected) @@ -42,7 +42,7 @@ def test_spatial_axes(self): key = "test" rotate = Rotate90d(keys=key, spatial_axes=(0, 1)) rotated = rotate({key: self.imt[0]}) - expected = list() + expected = [] for channel in self.imt[0]: expected.append(np.rot90(channel, 1, (0, 1))) expected = np.stack(expected) @@ -52,7 +52,7 @@ def test_prob_k_spatial_axes(self): key = "test" rotate = Rotate90d(keys=key, k=2, spatial_axes=(0, 1)) rotated = rotate({key: self.imt[0]}) - expected = list() + expected = [] for channel in self.imt[0]: expected.append(np.rot90(channel, 2, (0, 1))) expected = np.stack(expected) diff --git a/tests/test_rotated.py b/tests/test_rotated.py index 14b1d6d8bd..3353ae9fba 100644 --- a/tests/test_rotated.py +++ b/tests/test_rotated.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_savitzky_golay_filter.py b/tests/test_savitzky_golay_filter.py new file mode 100644 index 0000000000..9163204810 --- /dev/null +++ b/tests/test_savitzky_golay_filter.py @@ -0,0 +1,152 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers import SavitzkyGolayFilter +from tests.utils import skip_if_no_cuda + +# Zero-padding trivial tests + +TEST_CASE_SINGLE_VALUE = [ + {"window_length": 3, "order": 1}, + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value + torch.Tensor([1 / 3]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 + # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_1D = [ + {"window_length": 3, "order": 1}, + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data + torch.Tensor([2 / 3, 1.0, 2 / 3]) + .unsqueeze(0) + .unsqueeze(0), # Expected output: zero padded, so linear interpolation + # over length-3 windows will result in output of [2/3, 1, 2/3]. 
+ 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2 = [ + {"window_length": 3, "order": 1}, # along default axis (2, first spatial dim) + torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[2 / 3, 2 / 3], [1.0, 1.0], [2 / 3, 2 / 3]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_3 = [ + {"window_length": 3, "order": 1, "axis": 3}, # along axis 3 (second spatial dim) + torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +# Replicated-padding trivial tests + +TEST_CASE_SINGLE_VALUE_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 + # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_1D_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Expected output: replicate padded, so the linear fit + # over length-3 windows of ones gives [1.0, 1.0, 1.0]. + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, # along default axis (2, first spatial dim) + torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_3_REP = [ + {"window_length": 3, "order": 1, "axis": 3, "mode": "replicate"}, # along axis 3 (second spatial dim) + torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +# Sine smoothing + +TEST_CASE_SINE_SMOOTH = [ + {"window_length": 3, "order": 1}, + # Sine wave with period equal to savgol window length (windowed to reduce edge effects).
+ torch.as_tensor(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100)).unsqueeze(0).unsqueeze(0), + # Should be smoothed out to zeros + torch.zeros(100).unsqueeze(0).unsqueeze(0), + # tolerance chosen by examining output of SciPy.signal.savgol_filter when provided the above input + 2e-2, # absolute tolerance +] + + +class TestSavitzkyGolayCPU(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE, + TEST_CASE_1D, + TEST_CASE_2D_AXIS_2, + TEST_CASE_2D_AXIS_3, + TEST_CASE_SINE_SMOOTH, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +class TestSavitzkyGolayCPUREP(unittest.TestCase): + @parameterized.expand( + [TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +@skip_if_no_cuda +class TestSavitzkyGolayGPU(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE, + TEST_CASE_1D, + TEST_CASE_2D_AXIS_2, + TEST_CASE_2D_AXIS_3, + TEST_CASE_SINE_SMOOTH, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) + np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) + + +@skip_if_no_cuda +class TestSavitzkyGolayGPUREP(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE_REP, + TEST_CASE_1D_REP, + TEST_CASE_2D_AXIS_2_REP, + TEST_CASE_2D_AXIS_3_REP, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) + np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py new file mode 100644 index 0000000000..63dcce1b05 --- /dev/null +++ b/tests/test_savitzky_golay_smooth.py @@ -0,0 +1,70 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
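# The 1/3 and 2/3 values expected in the tests above can be reproduced with a
# plain convolution: for window_length=3 and order=1, the Savitzky-Golay kernel
# is the moving average [1/3, 1/3, 1/3], and zero padding supplies the 0s at
# the borders (sketch):
import numpy as np

signal = np.array([1.0, 1.0, 1.0])
padded = np.pad(signal, 1)  # [0., 1., 1., 1., 0.]
kernel = np.array([1 / 3, 1 / 3, 1 / 3])  # linear fit == mean for this window
print(np.convolve(padded, kernel, mode="valid"))  # [0.667 1.    0.667]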
+ +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import SavitzkyGolaySmooth + +# Zero-padding trivial tests + +TEST_CASE_SINGLE_VALUE = [ + {"window_length": 3, "order": 1}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1 / 3]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2 = [ + {"window_length": 3, "order": 1, "axis": 2}, # along axis 2 (second spatial dim) + np.expand_dims(np.ones((2, 3)), 0), + np.expand_dims(np.array([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]), 0), + 1e-15, # absolute tolerance +] + +# Replicated-padding trivial tests + +TEST_CASE_SINGLE_VALUE_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1.0]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +# Sine smoothing + +TEST_CASE_SINE_SMOOTH = [ + {"window_length": 3, "order": 1}, + # Sine wave with period equal to savgol window length (windowed to reduce edge effects). + np.expand_dims(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100), 0), + # Should be smoothed out to zeros + np.expand_dims(np.zeros(100), 0), + # tolerance chosen by examining output of SciPy.signal.savgol_filter() when provided the above input + 2e-2, # absolute tolerance +] + + +class TestSavitzkyGolaySmooth(unittest.TestCase): + @parameterized.expand([TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH]) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolaySmooth(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +class TestSavitzkyGolaySmoothREP(unittest.TestCase): + @parameterized.expand([TEST_CASE_SINGLE_VALUE_REP]) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolaySmooth(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) diff --git a/tests/test_scale_intensity.py b/tests/test_scale_intensity.py index 7d9c9ea901..61e89191fd 100644 --- a/tests/test_scale_intensity.py +++ b/tests/test_scale_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_scale_intensity_range.py b/tests/test_scale_intensity_range.py index d952d18ce9..cba07d9157 100644 --- a/tests/test_scale_intensity_range.py +++ b/tests/test_scale_intensity_range.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py index ace11fcc8c..8393d7c082 100644 --- a/tests/test_scale_intensity_range_percentiles.py +++ b/tests/test_scale_intensity_range_percentiles.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_scale_intensity_range_percentilesd.py b/tests/test_scale_intensity_range_percentilesd.py index 75e79b7c9b..5057c1e32c 100644 --- a/tests/test_scale_intensity_range_percentilesd.py +++ b/tests/test_scale_intensity_range_percentilesd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,7 +20,7 @@ class TestScaleIntensityRangePercentilesd(NumpyImageTestCase2D): def test_scaling(self): img = self.imt - data = dict() + data = {} data["img"] = img lower = 10 upper = 99 @@ -38,7 +38,7 @@ def test_scaling(self): def test_relative_scaling(self): img = self.imt - data = dict() + data = {} data["img"] = img lower = 10 upper = 99 diff --git a/tests/test_scale_intensity_ranged.py b/tests/test_scale_intensity_ranged.py index c4c12bfacf..a8cac414e8 100644 --- a/tests/test_scale_intensity_ranged.py +++ b/tests/test_scale_intensity_ranged.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_scale_intensityd.py b/tests/test_scale_intensityd.py index 772952aef3..688c99c6af 100644 --- a/tests/test_scale_intensityd.py +++ b/tests/test_scale_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_se_block.py b/tests/test_se_block.py index ce3ffc89c9..1f515a7fb4 100644 --- a/tests/test_se_block.py +++ b/tests/test_se_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks import SEBlock from monai.networks.layers.factories import Act, Norm from tests.utils import test_script_save @@ -63,8 +64,7 @@ class TestSEBlockLayer(unittest.TestCase): @parameterized.expand(TEST_CASES + TEST_CASES_3D) def test_shape(self, input_param, input_shape, expected_shape): net = SEBlock(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_se_blocks.py b/tests/test_se_blocks.py index 654cd1f1bf..e9aed7d9d9 100644 --- a/tests/test_se_blocks.py +++ b/tests/test_se_blocks.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks import ChannelSELayer, ResidualSELayer from tests.utils import test_script_save @@ -41,8 +42,7 @@ class TestChannelSELayer(unittest.TestCase): @parameterized.expand(TEST_CASES + TEST_CASES_3D) def test_shape(self, input_param, input_shape, expected_shape): net = ChannelSELayer(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) @@ -61,8 +61,7 @@ class TestResidualSELayer(unittest.TestCase): @parameterized.expand(TEST_CASES[:1]) def test_shape(self, input_param, input_shape, expected_shape): net = ResidualSELayer(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_seg_loss_integration.py b/tests/test_seg_loss_integration.py index c583939415..2103119342 100644 --- a/tests/test_seg_loss_integration.py +++ b/tests/test_seg_loss_integration.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_segresnet.py b/tests/test_segresnet.py index fc1325d94e..a3fae55a1a 100644 --- a/tests/test_segresnet.py +++ b/tests/test_segresnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
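# Several hunks in this patch replace the two-step "net.eval()" plus
# "with torch.no_grad():" with the single monai.networks.eval_mode context
# manager. A rough equivalent of what such a helper needs to do - switch to
# eval, disable autograd, then restore the previous training state (a sketch,
# not MONAI's source):
from contextlib import contextmanager

import torch

@contextmanager
def eval_mode_sketch(*nets: torch.nn.Module):
    training = [n.training for n in nets]  # remember each module's mode
    try:
        with torch.no_grad():
            yield [n.eval() for n in nets]  # evaluate without autograd
    finally:
        for n, was_training in zip(nets, training):
            if was_training:
                n.train()  # restore the original training mode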
# You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import SegResNet, SegResNetVAE from monai.utils import UpsampleMode from tests.utils import test_script_save @@ -82,8 +83,7 @@ class TestResNet(unittest.TestCase): @parameterized.expand(TEST_CASE_SEGRESNET + TEST_CASE_SEGRESNET_2) def test_shape(self, input_param, input_shape, expected_shape): net = SegResNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -102,7 +102,7 @@ class TestResNetVAE(unittest.TestCase): @parameterized.expand(TEST_CASE_SEGRESNET_VAE) def test_vae_shape(self, input_param, input_shape, expected_shape): net = SegResNetVAE(**input_param).to(device) - with torch.no_grad(): + with eval_mode(net): result, _ = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_segresnet_block.py b/tests/test_segresnet_block.py index 0598362619..2848e2ad04 100644 --- a/tests/test_segresnet_block.py +++ b/tests/test_segresnet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks.segresnet_block import ResBlock TEST_CASE_RESBLOCK = [] @@ -39,8 +40,7 @@ class TestResBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_RESBLOCK) def test_shape(self, input_param, input_shape, expected_shape): net = ResBlock(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_select_cross_validation_folds.py b/tests/test_select_cross_validation_folds.py index 9897971513..6dbd004e71 100644 --- a/tests/test_select_cross_validation_folds.py +++ b/tests/test_select_cross_validation_folds.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_select_itemsd.py b/tests/test_select_itemsd.py index be683ef6a7..bf63864eb0 100644 --- a/tests/test_select_itemsd.py +++ b/tests/test_select_itemsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -23,7 +23,7 @@ class TestSelectItemsd(unittest.TestCase): @parameterized.expand([TEST_CASE_1]) def test_memory(self, input_param, expected_key_size): - input_data = dict() + input_data = {} for i in range(50): input_data[str(i)] = [time.time()] * 100000 result = SelectItemsd(**input_param)(input_data) diff --git a/tests/test_senet.py b/tests/test_senet.py index d4fdbe28a7..883d75d62d 100644 --- a/tests/test_senet.py +++ b/tests/test_senet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import ( se_resnet50, se_resnet101, @@ -42,9 +43,8 @@ class TestSENET(unittest.TestCase): def test_senet_shape(self, net, net_args): input_data = torch.randn(2, 2, 64, 64, 64).to(device) expected_shape = (2, 2) - net = net(**net_args) - net = net.to(device).eval() - with torch.no_grad(): + net = net(**net_args).to(device) + with eval_mode(net): result = net(input_data) self.assertEqual(result.shape, expected_shape) @@ -65,8 +65,8 @@ def test_senet_shape(self, model, input_param): net = test_pretrained_networks(model, input_param, device) input_data = torch.randn(3, 3, 64, 64).to(device) expected_shape = (3, 2) - net = net.to(device).eval() - with torch.no_grad(): + net = net.to(device) + with eval_mode(net): result = net(input_data) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_set_determinism.py b/tests/test_set_determinism.py index 14b908fa35..bc4927007b 100644 --- a/tests/test_set_determinism.py +++ b/tests/test_set_determinism.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_shift_intensity.py b/tests/test_shift_intensity.py index 339acc3f96..b73c18b6a5 100644 --- a/tests/test_shift_intensity.py +++ b/tests/test_shift_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_shift_intensityd.py b/tests/test_shift_intensityd.py index 63befbffee..752cf4b8d2 100644 --- a/tests/test_shift_intensityd.py +++ b/tests/test_shift_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_simple_aspp.py b/tests/test_simple_aspp.py index da3ed3ecb2..89ca589c51 100644 --- a/tests/test_simple_aspp.py +++ b/tests/test_simple_aspp.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks import SimpleASPP TEST_CASES = [ @@ -69,8 +70,7 @@ class TestChannelSELayer(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_shape, expected_shape): net = SimpleASPP(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_simulatedelay.py b/tests/test_simulatedelay.py index a28d3cf88e..3a4686218e 100644 --- a/tests/test_simulatedelay.py +++ b/tests/test_simulatedelay.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_simulatedelayd.py b/tests/test_simulatedelayd.py index ec4660d3f5..58bd3eb6b8 100644 --- a/tests/test_simulatedelayd.py +++ b/tests/test_simulatedelayd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_skip_connection.py b/tests/test_skip_connection.py index aa6b4a35c3..2118842ed0 100644 --- a/tests/test_skip_connection.py +++ b/tests/test_skip_connection.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.layers import SkipConnection TEST_CASES_3D = [] @@ -35,8 +36,7 @@ class TestSkipConnection(unittest.TestCase): @parameterized.expand(TEST_CASES_3D) def test_shape(self, input_param, input_shape, expected_shape): net = SkipConnection(submodule=torch.nn.Softmax(dim=1), **input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py index ce18dbebfe..a22e5990bf 100644 --- a/tests/test_sliding_window_inference.py +++ b/tests/test_sliding_window_inference.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index cc458b281b..3d1a051a83 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/tests/test_spacing.py b/tests/test_spacing.py index db6cd8e082..bc491f2f82 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py index 5380da7cd1..ec32563543 100644 --- a/tests/test_spacingd.py +++ b/tests/test_spacingd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_spatial_crop.py b/tests/test_spatial_crop.py index e8c4c4d00c..f3c904889f 100644 --- a/tests/test_spatial_crop.py +++ b/tests/test_spatial_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_spatial_cropd.py b/tests/test_spatial_cropd.py index 5e7a92fb97..590dc83281 100644 --- a/tests/test_spatial_cropd.py +++ b/tests/test_spatial_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_spatial_pad.py b/tests/test_spatial_pad.py index 270d580aed..4473a23770 100644 --- a/tests/test_spatial_pad.py +++ b/tests/test_spatial_pad.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_spatial_padd.py b/tests/test_spatial_padd.py index 2f6667d9ff..8400bb82cc 100644 --- a/tests/test_spatial_padd.py +++ b/tests/test_spatial_padd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_split_channel.py b/tests/test_split_channel.py index c9c2d79ca3..8eec3c4e70 100644 --- a/tests/test_split_channel.py +++ b/tests/test_split_channel.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_split_channeld.py b/tests/test_split_channeld.py index 6b6c7ab36d..814ef69922 100644 --- a/tests/test_split_channeld.py +++ b/tests/test_split_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
 # You may obtain a copy of the License at
diff --git a/tests/test_squeezedim.py b/tests/test_squeezedim.py
index def940d807..01ea489320 100644
--- a/tests/test_squeezedim.py
+++ b/tests/test_squeezedim.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
diff --git a/tests/test_squeezedimd.py b/tests/test_squeezedimd.py
index f9a13b9890..dcbd9212c7 100644
--- a/tests/test_squeezedimd.py
+++ b/tests/test_squeezedimd.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
diff --git a/tests/test_state_cacher.py b/tests/test_state_cacher.py
new file mode 100644
index 0000000000..139e7b8374
--- /dev/null
+++ b/tests/test_state_cacher.py
@@ -0,0 +1,68 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from os.path import exists, join
+from tempfile import gettempdir
+
+import torch
+from parameterized import parameterized
+
+from monai.utils import StateCacher
+
+DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+TEST_CASE_0 = [
+    torch.Tensor([1]).to(DEVICE),
+    {"in_memory": True},
+]
+TEST_CASE_1 = [
+    torch.Tensor([1]).to(DEVICE),
+    {"in_memory": False, "cache_dir": gettempdir()},
+]
+TEST_CASE_2 = [
+    torch.Tensor([1]).to(DEVICE),
+    {"in_memory": False, "allow_overwrite": False},
+]
+
+TEST_CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]
+
+
+class TestStateCacher(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_state_cacher(self, data_obj, params):
+
+        key = "data_obj"
+
+        state_cacher = StateCacher(**params)
+        # store it
+        state_cacher.store(key, data_obj)
+        # create clone then modify original
+        data_obj_orig = data_obj.clone()
+        data_obj += 1
+        # restore and check the cached copy is unchanged
+        data_obj_restored = state_cacher.retrieve(key)
+        self.assertEqual(data_obj_orig, data_obj_restored)
+
+        # if overwriting is disallowed, a second store with the same key must raise
+        if "allow_overwrite" in params and not params["allow_overwrite"]:
+            with self.assertRaises(RuntimeError):
+                state_cacher.store(key, data_obj)
+
+        # if using a cache dir, check the file has been deleted at the end
+        if "cache_dir" in params:
+            i = id(state_cacher)
+            del state_cacher
+            self.assertFalse(exists(join(params["cache_dir"], f"state_{key}_{i}.pt")))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_subpixel_upsample.py b/tests/test_subpixel_upsample.py
index 92e12ecf6c..07e110d7a7 100644
--- a/tests/test_subpixel_upsample.py
+++ b/tests/test_subpixel_upsample.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ import torch.nn as nn from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks import SubpixelUpsample from monai.networks.layers.factories import Conv @@ -75,8 +76,7 @@ class TestSUBPIXEL(unittest.TestCase): @parameterized.expand(TEST_CASE_SUBPIXEL) def test_subpixel_shape(self, input_param, input_shape, expected_shape): net = SubpixelUpsample(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_surface_distance.py b/tests/test_surface_distance.py index dca3aaec12..db90c87938 100644 --- a/tests/test_surface_distance.py +++ b/tests/test_surface_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_thread_buffer.py b/tests/test_thread_buffer.py index 2688f1ae40..07e5a779ca 100644 --- a/tests/test_thread_buffer.py +++ b/tests/test_thread_buffer.py @@ -1,3 +1,14 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import time import unittest diff --git a/tests/test_threshold_intensity.py b/tests/test_threshold_intensity.py index 6eb471d87c..a6d3895709 100644 --- a/tests/test_threshold_intensity.py +++ b/tests/test_threshold_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_threshold_intensityd.py b/tests/test_threshold_intensityd.py index 305b950131..efcfcfe604 100644 --- a/tests/test_threshold_intensityd.py +++ b/tests/test_threshold_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_timedcall.py b/tests/test_timedcall.py index aa9d170d85..e87d160743 100644 --- a/tests/test_timedcall.py +++ b/tests/test_timedcall.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py index ffc14dd7c4..581731d4b5 100644 --- a/tests/test_to_numpy.py +++ b/tests/test_to_numpy.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_to_numpyd.py b/tests/test_to_numpyd.py index 91fcf42e30..48db52183b 100644 --- a/tests/test_to_numpyd.py +++ b/tests/test_to_numpyd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_to_onehot.py b/tests/test_to_onehot.py index 974b8ab0e0..c3e373955d 100644 --- a/tests/test_to_onehot.py +++ b/tests/test_to_onehot.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_torchvision.py b/tests/test_torchvision.py new file mode 100644 index 0000000000..0846b7f6b6 --- /dev/null +++ b/tests/test_torchvision.py @@ -0,0 +1,86 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
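+#
+# These cases exercise ``monai.transforms.TorchVision``, an adapter that looks
+# up the torchvision transform selected by ``name`` (here "ColorJitter" and
+# "Pad"), constructs it with the remaining keyword arguments, and applies it
+# to the input tensor. ColorJitter is random, so the expected values below
+# assume the seed fixed by ``set_determinism(seed=0)``.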
+ +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import TorchVision +from monai.utils import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion + +TEST_CASE_1 = [ + {"name": "ColorJitter"}, + torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), + torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), +] + +TEST_CASE_2 = [ + {"name": "ColorJitter", "brightness": 0.5, "contrast": 0.5, "saturation": [0.1, 0.8], "hue": 0.5}, + torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), + torch.tensor( + [ + [ + [0.1090, 0.6193], + [0.6193, 0.9164], + ], + [ + [0.1090, 0.6193], + [0.6193, 0.9164], + ], + [ + [0.1090, 0.6193], + [0.6193, 0.9164], + ], + ], + ), +] + +TEST_CASE_3 = [ + {"name": "Pad", "padding": [1, 1, 1, 1]}, + torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), + torch.tensor( + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 2.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 2.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 2.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ), +] + + +@SkipIfBeforePyTorchVersion((1, 7)) +class TestTorchVision(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_value(self, input_param, input_data, expected_value): + set_determinism(seed=0) + result = TorchVision(**input_param)(input_data) + torch.testing.assert_allclose(result, expected_value) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_torchvisiond.py b/tests/test_torchvisiond.py new file mode 100644 index 0000000000..4f42bc95f7 --- /dev/null +++ b/tests/test_torchvisiond.py @@ -0,0 +1,86 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
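+#
+# Dictionary-based counterpart of the ``TorchVision`` tests above:
+# ``TorchVisiond`` applies the named torchvision transform to the entries
+# selected by ``keys`` and returns the updated dictionary, hence the
+# ``result["img"]`` comparison in ``test_value`` below.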
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.transforms import TorchVisiond
+from monai.utils import set_determinism
+from tests.utils import SkipIfBeforePyTorchVersion
+
+TEST_CASE_1 = [
+    {"keys": "img", "name": "ColorJitter"},
+    {"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])},
+    torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),
+]
+
+TEST_CASE_2 = [
+    {"keys": "img", "name": "ColorJitter", "brightness": 0.5, "contrast": 0.5, "saturation": [0.1, 0.8], "hue": 0.5},
+    {"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])},
+    torch.tensor(
+        [
+            [
+                [0.1090, 0.6193],
+                [0.6193, 0.9164],
+            ],
+            [
+                [0.1090, 0.6193],
+                [0.6193, 0.9164],
+            ],
+            [
+                [0.1090, 0.6193],
+                [0.6193, 0.9164],
+            ],
+        ],
+    ),
+]
+
+TEST_CASE_3 = [
+    {"keys": "img", "name": "Pad", "padding": [1, 1, 1, 1]},
+    {"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])},
+    torch.tensor(
+        [
+            [
+                [0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 1.0, 0.0],
+                [0.0, 1.0, 2.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0],
+            ],
+            [
+                [0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 1.0, 0.0],
+                [0.0, 1.0, 2.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0],
+            ],
+            [
+                [0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 1.0, 0.0],
+                [0.0, 1.0, 2.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0],
+            ],
+        ]
+    ),
+]
+
+
+@SkipIfBeforePyTorchVersion((1, 7))
+class TestTorchVisiond(unittest.TestCase):
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
+    def test_value(self, input_param, input_data, expected_value):
+        set_determinism(seed=0)
+        result = TorchVisiond(**input_param)(input_data)
+        torch.testing.assert_allclose(result["img"], expected_value)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_train_mode.py b/tests/test_train_mode.py
new file mode 100644
index 0000000000..1acb443041
--- /dev/null
+++ b/tests/test_train_mode.py
@@ -0,0 +1,31 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from monai.networks.utils import train_mode
+
+
+class TestTrainMode(unittest.TestCase):
+    def test_train_mode(self):
+        t = torch.rand(1, 1, 4, 4)
+        p = torch.nn.Conv2d(1, 1, 3)
+        p.eval()
+        self.assertFalse(p.training)  # eval() disables training mode
+        with train_mode(p):
+            self.assertTrue(p.training)  # train_mode re-enables it
+            p(t).sum().backward()  # and gradients flow inside the context
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py
index bf4b3f8f0a..a1befa062d 100644
--- a/tests/test_tversky_loss.py
+++ b/tests/test_tversky_loss.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at diff --git a/tests/test_unet.py b/tests/test_unet.py index 5d95e66ba4..49b9df343f 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.layers import Act, Norm from monai.networks.nets import UNet from tests.utils import test_script_save @@ -72,7 +73,7 @@ (16, 3, 32, 64, 48), ] -TEST_CASE_4 = [ # 4-channel 3D, batch 16, batch normalisation +TEST_CASE_4 = [ # 4-channel 3D, batch 16, batch normalization { "dimensions": 3, "in_channels": 4, @@ -121,8 +122,7 @@ class TestUNET(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): net = UNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_upsample_block.py b/tests/test_upsample_block.py index aa3a1fb90a..f9d5ea4492 100644 --- a/tests/test_upsample_block.py +++ b/tests/test_upsample_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks import UpSample from monai.utils import UpsampleMode @@ -85,8 +86,7 @@ class TestUpsample(unittest.TestCase): @parameterized.expand(TEST_CASES + TEST_CASES_EQ) def test_shape(self, input_param, input_shape, expected_shape): net = UpSample(**input_param) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_varautoencoder.py b/tests/test_varautoencoder.py index b2bb1c22e9..7a4a546d87 100644 --- a/tests/test_varautoencoder.py +++ b/tests/test_varautoencoder.py @@ -1,8 +1,20 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
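+#
+# ``VarAutoEncoder.forward`` returns a tuple whose first element is the
+# reconstruction, so the shape test below checks ``net.forward(...)[0]`` only.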
+ import unittest import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.layers import Act from monai.networks.nets import VarAutoEncoder from tests.utils import test_script_save @@ -70,8 +82,7 @@ class TestVarAutoEncoder(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): net = VarAutoEncoder(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device))[0] self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_vis_cam.py b/tests/test_vis_cam.py index e2ec119ec8..d400c27f02 100644 --- a/tests/test_vis_cam.py +++ b/tests/test_vis_cam.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index 3fb53b1fda..2a7de0e70c 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_vis_gradcampp.py b/tests/test_vis_gradcampp.py index c6bdef1647..fce68ccde0 100644 --- a/tests/test_vis_gradcampp.py +++ b/tests/test_vis_gradcampp.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_vnet.py b/tests/test_vnet.py index 062e171655..c64b566c42 100644 --- a/tests/test_vnet.py +++ b/tests/test_vnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import torch from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import VNet from tests.utils import test_script_save @@ -64,8 +65,7 @@ class TestVNet(unittest.TestCase): ) def test_vnet_shape(self, input_param, input_shape, expected_shape): net = VNet(**input_param).to(device) - net.eval() - with torch.no_grad(): + with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_vote_ensemble.py b/tests/test_vote_ensemble.py index 43c1bb124e..92039fe103 100644 --- a/tests/test_vote_ensemble.py +++ b/tests/test_vote_ensemble.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
 # You may obtain a copy of the License at
diff --git a/tests/test_vote_ensembled.py b/tests/test_vote_ensembled.py
index 6ed1401bab..f4b93c7887 100644
--- a/tests/test_vote_ensembled.py
+++ b/tests/test_vote_ensembled.py
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
diff --git a/tests/test_warp.py b/tests/test_warp.py
new file mode 100644
index 0000000000..69ae997e38
--- /dev/null
+++ b/tests/test_warp.py
@@ -0,0 +1,80 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.networks.blocks.warp import Warp
+
+LOW_POWER_TEST_CASES = [
+    [
+        {"spatial_dims": 2, "mode": 0, "padding_mode": "zeros"},
+        {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)},
+        torch.arange(4).reshape((1, 1, 2, 2)),
+    ],
+    [
+        {"spatial_dims": 2, "mode": 1, "padding_mode": "zeros"},
+        {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 2, 2, 2)},
+        torch.tensor([[[[3, 0], [0, 0]]]]),
+    ],
+]
+
+HIGH_POWER_TEST_CASES = [
+    [
+        {"spatial_dims": 3, "mode": 2, "padding_mode": "border"},
+        {
+            "image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),
+            "ddf": torch.ones(1, 3, 2, 2, 2) * -1,
+        },
+        torch.tensor([[[[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]]),
+    ],
+    [
+        {"spatial_dims": 3, "mode": 3, "padding_mode": "reflection"},
+        {"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 3, 2, 2, 2)},
+        torch.tensor([[[[[7, 6], [5, 4]], [[3, 2], [1, 0]]]]]),
+    ],
+]
+
+TEST_CASES = LOW_POWER_TEST_CASES
+# if USE_COMPILED:
+#     TEST_CASES += HIGH_POWER_TEST_CASES
+
+
+class TestWarp(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_resample(self, input_param, input_data, expected_val):
+        warp_layer = Warp(**input_param)
+        result = warp_layer(**input_data)
+        np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4)
+
+    def test_ill_shape(self):
+        warp_layer = Warp(spatial_dims=2)
+        with self.assertRaisesRegex(ValueError, ""):
+            warp_layer(
+                image=torch.arange(4).reshape((1, 1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 2, 2)
+            )
+        with self.assertRaisesRegex(ValueError, ""):
+            warp_layer(
+                image=torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 1, 2, 2)
+            )
+        with self.assertRaisesRegex(ValueError, ""):
+            warp_layer(image=torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 3, 3))
+
+    def test_ill_opts(self):
+        with self.assertRaisesRegex(ValueError, ""):
+            Warp(spatial_dims=4)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_write_metrics_reports.py b/tests/test_write_metrics_reports.py
new file mode 100644
index 0000000000..72625ddd9a
--- /dev/null
+++ b/tests/test_write_metrics_reports.py
@@ -0,0 +1,64 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import os +import tempfile +import unittest + +import torch + +from monai.handlers.utils import write_metrics_reports + + +class TestWriteMetricsReports(unittest.TestCase): + def test_content(self): + with tempfile.TemporaryDirectory() as tempdir: + write_metrics_reports( + save_dir=tempdir, + images=["filepath1", "filepath2"], + metrics={"metric1": 1, "metric2": 2}, + metric_details={"metric3": torch.tensor([[1, 2], [2, 3]]), "metric4": torch.tensor([[5, 6], [7, 8]])}, + summary_ops=["mean", "median", "max", "90percent"], + deli="\t", + output_type="csv", + ) + + # check the metrics.csv and content + self.assertTrue(os.path.exists(os.path.join(tempdir, "metrics.csv"))) + with open(os.path.join(tempdir, "metrics.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + self.assertEqual(row, [f"metric{i + 1}\t{i + 1}"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_raw.csv"))) + # check the metric_raw.csv and content + with open(os.path.join(tempdir, "metric3_raw.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + if i > 0: + self.assertEqual(row, [f"filepath{i}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_summary.csv"))) + # check the metric_summary.csv and content + with open(os.path.join(tempdir, "metric3_summary.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + if i == 1: + self.assertEqual(row, ["class0\t1.5000\t1.5000\t2.0000\t1.1000"]) + elif i == 2: + self.assertEqual(row, ["class1\t2.5000\t2.5000\t3.0000\t2.1000"]) + elif i == 3: + self.assertEqual(row, ["mean\t2.0000\t2.0000\t2.5000\t1.6000"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_raw.csv"))) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_summary.csv"))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_zipdataset.py b/tests/test_zipdataset.py index fba0a217b0..1bdb6458d3 100644 --- a/tests/test_zipdataset.py +++ b/tests/test_zipdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -28,8 +28,7 @@ def __len__(self): def __getitem__(self, index): if self.index_only: return index - else: - return 1, 2, index + return 1, 2, index TEST_CASE_1 = [[Dataset_(5), Dataset_(5), Dataset_(5)], None, (0, 0, 0), 5] diff --git a/tests/test_zoom.py b/tests/test_zoom.py index edf4fb988e..dcc401f16c 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -31,7 +31,7 @@ def test_correct_results(self, zoom, mode): _order = 0 if mode.endswith("linear"): _order = 1 - expected = list() + expected = [] for channel in self.imt[0]: expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) expected = np.stack(expected).astype(np.float32) diff --git a/tests/test_zoom_affine.py b/tests/test_zoom_affine.py index f5fb3e871d..49c3c0dcac 100644 --- a/tests/test_zoom_affine.py +++ b/tests/test_zoom_affine.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_zoomd.py b/tests/test_zoomd.py index d9685d259f..b17ecd1bf0 100644 --- a/tests/test_zoomd.py +++ b/tests/test_zoomd.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -37,7 +37,7 @@ def test_correct_results(self, zoom, mode, keep_size): _order = 0 if mode.endswith("linear"): _order = 1 - expected = list() + expected = [] for channel in self.imt[0]: expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) expected = np.stack(expected).astype(np.float32) diff --git a/tests/testing_data/integration_answers.py b/tests/testing_data/integration_answers.py index 6a52cb4ed3..5490cfe2e3 100644 --- a/tests/testing_data/integration_answers.py +++ b/tests/testing_data/integration_answers.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/utils.py b/tests/utils.py index 50c159053e..ebc9bff99f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
 # You may obtain a copy of the License at
@@ -28,6 +28,7 @@
 import torch
 import torch.distributed as dist
 
+from monai.config.deviceconfig import USE_COMPILED
 from monai.data import create_test_image_2d, create_test_image_3d
 from monai.utils import ensure_tuple, optional_import, set_determinism
 from monai.utils.module import get_torch_version_tuple
@@ -56,7 +57,7 @@ def skip_if_quick(obj):
     return unittest.skipIf(is_quick, "Skipping slow tests")(obj)
 
 
-class SkipIfNoModule(object):
+class SkipIfNoModule:
     """Decorator to be used if test should be skipped
     when optional module is not present."""
 
@@ -68,7 +69,7 @@ def __call__(self, obj):
         return unittest.skipIf(self.module_missing, f"optional module not present: {self.module_name}")(obj)
 
 
-class SkipIfModule(object):
+class SkipIfModule:
     """Decorator to be used if test should be skipped
     when optional module is present."""
 
@@ -80,11 +81,18 @@ def __call__(self, obj):
         return unittest.skipIf(self.module_avail, f"Skipping because optional module present: {self.module_name}")(obj)
 
 
+def skip_if_no_cpp_extention(obj):
+    """
+    Skip the unit tests if the C++ extension is not available.
+    """
+    return unittest.skipUnless(USE_COMPILED, "Skipping C++ extension tests")(obj)
+
+
 def skip_if_no_cuda(obj):
     """
     Skip the unit tests if torch.cuda.is_available is False
     """
-    return unittest.skipIf(not torch.cuda.is_available(), "Skipping CUDA-based tests")(obj)
+    return unittest.skipUnless(torch.cuda.is_available(), "Skipping CUDA-based tests")(obj)
 
 
 def skip_if_windows(obj):
@@ -94,7 +102,7 @@ def skip_if_windows(obj):
     return unittest.skipIf(sys.platform == "win32", "Skipping tests on Windows")(obj)
 
 
-class SkipIfBeforePyTorchVersion(object):
+class SkipIfBeforePyTorchVersion:
    """Decorator to be used if test should be skipped
    with PyTorch versions older than that given."""
 
@@ -111,9 +119,9 @@ def __call__(self, obj):
         )(obj)
 
 
-class SkipIfAtLeastPyTorchVersion(object):
+class SkipIfAtLeastPyTorchVersion:
     """Decorator to be used if test should be skipped
-    with PyTorch versions older than that given."""
+    with PyTorch versions newer than that given."""
 
     def __init__(self, pytorch_version_tuple):
         self.max_version = pytorch_version_tuple
@@ -399,8 +407,7 @@ def _wrapper(*args, **kwargs):
             if isinstance(res, Exception):  # other errors from obj
                 if hasattr(res, "traceback"):
                     raise RuntimeError(res.traceback) from res
-                else:
-                    raise res
+                raise res
         if timeout_error:  # no force_quit finished
             raise timeout_error
         return res
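
A pattern repeated throughout this patch replaces the manual `net.eval()` plus `torch.no_grad()` pairing with the `eval_mode` context manager imported from `monai.networks`. As a rough sketch of the idea, assuming the helper only needs to put the given modules into inference mode and restore whichever of them were training on exit (the shipped `monai.networks.utils.eval_mode` may differ in detail):

from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def eval_mode(*nets: nn.Module):
    """Temporarily switch `nets` to eval() with autograd disabled, then restore."""
    training = [n for n in nets if n.training]  # modules currently in training mode
    try:
        # eval() returns the module, so the caller can also use the yielded list
        with torch.no_grad():
            yield [n.eval() for n in nets]
    finally:
        for n in training:
            n.train()  # restore only the modules that were training before

Usage then mirrors the updated tests: `with eval_mode(net): result = net(torch.randn(input_shape))`. The `train_mode` helper covered by the new `tests/test_train_mode.py` is the symmetric counterpart: it temporarily enables training behaviour (with gradients available, as the `backward()` call in that test relies on) and restores the previous state on exit.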