diff --git a/.azure/gpu-integrations.yml b/.azure/gpu-integrations.yml index 55b05cb369c..9944e7b5df9 100644 --- a/.azure/gpu-integrations.yml +++ b/.azure/gpu-integrations.yml @@ -17,13 +17,13 @@ jobs: - job: integrate_GPU strategy: matrix: - "torch | 1.x": - docker-image: "pytorchlightning/torchmetrics:ubuntu22.04-cuda11.8.0-py3.9-torch1.13" - torch-ver: "1.13" + "torch | 2.0": + docker-image: "pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime" + torch-ver: "2.0" requires: "oldest" "torch | 2.x": - docker-image: "pytorch/pytorch:2.3.0-cuda12.1-cudnn8-runtime" - torch-ver: "2.3" + docker-image: "pytorch/pytorch:2.5.0-cuda12.1-cudnn9-runtime" + torch-ver: "2.5" # how long to run the job before automatically cancelling timeoutInMinutes: "40" # how much time to give 'run always even if cancelled tasks' before stopping them diff --git a/.azure/gpu-nuke-cache.yml b/.azure/gpu-nuke-cache.yml new file mode 100644 index 00000000000..f8f758ce8c0 --- /dev/null +++ b/.azure/gpu-nuke-cache.yml @@ -0,0 +1,56 @@ +trigger: + tags: + include: + - "*" +# run every month to sanitatize dev environment +schedules: + - cron: "0 0 1 * *" + displayName: Monthly nuke caches + branches: + include: + - master +# run on PR changing only this file +pr: + branches: + include: + - master + paths: + include: + - .azure/gpu-nuke-cache.yml + +jobs: + - job: nuke_caches + # how long to run the job before automatically cancelling + timeoutInMinutes: "10" + # how much time to give 'run always even if cancelled tasks' before stopping them + cancelTimeoutInMinutes: "2" + + pool: "lit-rtx-3090" + + variables: + # these two caches assume to run repetitively on the same set of machines + # see: https://github.com/microsoft/azure-pipelines-agent/issues/4113#issuecomment-1439241481 + TORCH_HOME: "/var/tmp/torch" + TRANSFORMERS_CACHE: "/var/tmp/hf/transformers" + HF_HOME: "/var/tmp/hf/home" + HF_HUB_CACHE: "/var/tmp/hf/hub" + PIP_CACHE_DIR: "/var/tmp/pip" + CACHED_REFERENCES: "/var/tmp/cached-references.zip" + + container: + image: "ubuntu:22.04" + options: "-v /var/tmp:/var/tmp" + + steps: + - bash: | + set -ex + rm -rf $(TORCH_HOME) + rm -rf $(TRANSFORMERS_CACHE) + rm -rf $(HF_HOME) + rm -rf $(HF_HUB_CACHE) + rm -rf $(PIP_CACHE_DIR) + rm -rf $(CACHED_REFERENCES) + displayName: "delete all caches" + - bash: | + ls -lh /var/tmp + displayName: "show tmp/ folder" diff --git a/.azure/gpu-unittests.yml b/.azure/gpu-unittests.yml index a810b3a48fe..42a5e84d913 100644 --- a/.azure/gpu-unittests.yml +++ b/.azure/gpu-unittests.yml @@ -9,6 +9,13 @@ trigger: - master - release/* - refs/tags/* +# run every month to populate caches +schedules: + - cron: "0 1 1 * *" + displayName: Monthly re-build caches + branches: + include: + - master pr: - master - release/* @@ -17,19 +24,13 @@ jobs: - job: unitest_GPU strategy: matrix: - "PyTorch | 1.10 oldest": + "PyTorch | 2.0 oldest": # Torch does not have build wheels with old Torch versions for newer CUDA - docker-image: "ubuntu20.04-cuda11.3.1-py3.9-torch1.10" - torch-ver: "1.10" - "PyTorch | 1.X LTS": - docker-image: "ubuntu22.04-cuda11.8.0-py3.9-torch1.13" - torch-ver: "1.13" + docker-image: "ubuntu22.04-cuda11.8.0-py3.10-torch2.0" + torch-ver: "2.0" "PyTorch | 2.X stable": - docker-image: "ubuntu22.04-cuda12.1.1-py3.11-torch2.3" - torch-ver: "2.3" - "PyTorch | 2.X future": - docker-image: "ubuntu22.04-cuda12.1.1-py3.11-torch2.4" - torch-ver: "2.4" + docker-image: "ubuntu22.04-cuda12.1.1-py3.11-torch2.5" + torch-ver: "2.5" # how long to run the job before automatically cancelling timeoutInMinutes: "180" # how much time to give 'run always even if cancelled tasks' before stopping them @@ -70,6 +71,11 @@ jobs: CUDA_version_mm="${CUDA_version//'.'/''}" echo "##vso[task.setvariable variable=CUDA_VERSION_MM]$CUDA_version_mm" echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${CUDA_version_mm}/torch_stable.html" + mkdir -p $(TORCH_HOME) + mkdir -p $(TRANSFORMERS_CACHE) + mkdir -p $(HF_HOME) + mkdir -p $(HF_HUB_CACHE) + mkdir -p $(PIP_CACHE_DIR) displayName: "set Env. vars" - bash: | echo "##vso[task.setvariable variable=ALLOW_SKIP_IF_OUT_OF_MEMORY]1" @@ -114,7 +120,7 @@ jobs: - bash: | python .github/assistant.py set-oldest-versions - condition: eq(variables['torch-ver'], '1.10.2') + condition: eq(variables['torch-ver'], '2.0') displayName: "Setting oldest versions" - bash: | @@ -135,6 +141,21 @@ jobs: displayName: "Show caches" - bash: | + python -m pytest torchmetrics --cov=torchmetrics \ + --timeout=240 --durations=50 \ + --reruns 2 --reruns-delay 1 + # --numprocesses=5 --dist=loadfile + env: + DOCTEST_DOWNLOAD_TIMEOUT: "180" + SKIP_SLOW_DOCTEST: "1" + workingDirectory: "src/" + timeoutInMinutes: "40" + displayName: "DocTesting" + + - bash: | + df -h . + ls -lh $(CACHED_REFERENCES) + ls -lh tests/ # Check if the file references exists if [ -f $(CACHED_REFERENCES) ]; then # Create a directory if it doesn't already exist @@ -145,25 +166,12 @@ jobs: else echo "The file '$(CACHED_REFERENCES)' does not exist." fi - du -h --max-depth=1 tests/ timeoutInMinutes: "5" # if pull request, copy the cache to the tests folder to be used in the next steps condition: eq(variables['Build.Reason'], 'PullRequest') continueOnError: "true" displayName: "Copy/Unzip cached refs" - - bash: | - python -m pytest torchmetrics --cov=torchmetrics \ - --timeout=240 --durations=50 \ - --reruns 2 --reruns-delay 1 - # --numprocesses=5 --dist=loadfile - env: - DOCTEST_DOWNLOAD_TIMEOUT: "180" - SKIP_SLOW_DOCTEST: "1" - workingDirectory: "src/" - timeoutInMinutes: "40" - displayName: "DocTesting" - - bash: | wget https://pl-public-data.s3.amazonaws.com/metrics/data.zip unzip -o data.zip @@ -172,6 +180,7 @@ jobs: displayName: "Pull testing data from S3" - bash: | + du -h --max-depth=1 . python -m pytest $(TEST_DIRS) \ -m "not DDP" --numprocesses=5 --dist=loadfile \ --cov=torchmetrics --timeout=240 --durations=100 \ @@ -179,7 +188,7 @@ jobs: workingDirectory: "tests/" # skip for PR if there is nothing to test, note that outside PR there is default 'unittests' condition: and(succeeded(), ne(variables['TEST_DIRS'], '')) - timeoutInMinutes: "90" + timeoutInMinutes: "95" displayName: "UnitTesting common" - bash: | @@ -191,13 +200,14 @@ jobs: workingDirectory: "tests/" # skip for PR if there is nothing to test, note that outside PR there is default 'unittests' condition: and(succeeded(), ne(variables['TEST_DIRS'], '')) - timeoutInMinutes: "90" + timeoutInMinutes: "95" displayName: "UnitTesting DDP" - bash: | + du -h --max-depth=1 tests/ # archive potentially updated cache to the machine filesystem to be reused with next jobs zip -q -r $(CACHED_REFERENCES) tests/_cache-references - du -h --max-depth=1 tests/ + ls -lh $(CACHED_REFERENCES) # set as extra step to not pollute general cache when jobs fails or crashes # so do this update only with successful jobs on master condition: and(succeeded(), ne(variables['Build.Reason'], 'PullRequest')) @@ -212,7 +222,6 @@ jobs: python -m coverage xml python -m codecov --token=$(CODECOV_TOKEN) --name="GPU-coverage" \ --commit=$(Build.SourceVersion) --flags=gpu,unittest --env=linux,azure - ls -l workingDirectory: "tests/" # skip for PR if there is nothing to test, note that outside PR there is default 'unittests' condition: and(succeeded(), ne(variables['TEST_DIRS'], '')) diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index f170b93c2c1..7b7211b9d61 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -16,13 +16,13 @@ We are always looking for help implementing new features or fixing bugs. - Add details on how to reproduce the issue - a minimal test case is always best, colab is also great. Note, that the sample code shall be minimal and if needed with publicly available data. -1. Try to fix it or recommend a solution. We highly recommend to use test-driven approach: +2. Try to fix it or recommend a solution. We highly recommend to use test-driven approach: - Convert your minimal code example to a unit/integration test with assert on expected results. - Start by debugging the issue... You can run just this particular test in your IDE and draft a fix. - Verify that your test case fails on the master branch and only passes with the fix applied. -1. Submit a PR! +3. Submit a PR! _**Note**, even if you do not find the solution, sending a PR with a test covering the issue is a valid contribution and we can help you or finish it with you :\]_ @@ -31,14 +31,14 @@ help you or finish it with you :\]_ 1. Submit a github issue - describe what is the motivation of such feature (adding the use case or an example is helpful). -1. Let's discuss to determine the feature scope. +2. Let's discuss to determine the feature scope. -1. Submit a PR! We recommend test driven approach to adding new features as well: +3. Submit a PR! We recommend test driven approach to adding new features as well: - Write a test for the functionality you want to add. - Write the functional code until the test passes. -1. Add/update the relevant tests! +4. Add/update the relevant tests! - [This PR](https://github.com/Lightning-AI/torchmetrics/pull/98) is a good example for adding a new metric @@ -71,7 +71,7 @@ In case you adding new dependencies, make sure that they are compatible with the ### Coding Style 1. Use f-strings for output formation (except logging when we stay with lazy `logging.info("Hello %s!", name)`. -1. You can use `pre-commit` to make sure your code style is correct. +2. You can use `pre-commit` to make sure your code style is correct. ### Documentation diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 150f7849963..928a3b2dfaf 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -19,8 +19,10 @@ Steps to reproduce the behavior...
Code sample - +```python +# Ideally attach a minimal code sample to reproduce the decried issue. +# Minimal means having the shortest code but still preserving the bug. +```
@@ -30,9 +32,9 @@ Minimal means having the shortest code but still preserving the bug. --> ### Environment -- TorchMetrics version (and how you installed TM, e.g. `conda`, `pip`, build from source): -- Python & PyTorch Version (e.g., 1.0): -- Any other relevant information such as OS (e.g., Linux): +- TorchMetrics version (if build from source, add commit SHA): ??? +- Python & PyTorch Version (e.g., 1.0): ??? +- Any other relevant information such as OS (e.g., Linux): ??? ### Additional context diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index d82ec8f0bea..d431c0abd92 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -3,6 +3,6 @@ contact_links: - name: Ask a Question url: https://github.com/Lightning-AI/torchmetrics/discussions/new about: Ask and answer TorchMetrics related questions - - name: 💬 Slack - url: https://app.slack.com/client/TR9DVT48M/CQXV8BRH9/thread/CQXV8BRH9-1591382895.254600 - about: Chat with our community + - name: 💬 Chat with us + url: https://discord.gg/VptPCZkGNa + about: Live chat with experts, engineers, and users in our Discord community. diff --git a/.github/assistant.py b/.github/assistant.py index 95652888022..ad054d96e2f 100644 --- a/.github/assistant.py +++ b/.github/assistant.py @@ -31,8 +31,8 @@ "3.10": "1.11", "3.11": "1.13", } -_path = lambda *ds: os.path.join(_PATH_ROOT, *ds) -REQUIREMENTS_FILES = (*glob.glob(_path("requirements", "*.txt")), _path("requirements.txt")) +_path_root = lambda *ds: os.path.join(_PATH_ROOT, *ds) +REQUIREMENTS_FILES = (*glob.glob(_path_root("requirements", "*.txt")), _path_root("requirements.txt")) class AssistantCLI: @@ -73,21 +73,35 @@ def set_min_torch_by_python(fpath: str = "requirements/base.txt") -> None: fp.write(requires) @staticmethod - def replace_min_requirements(fpath: str) -> None: - """Replace all `>=` by `==` in given file.""" - logging.info(f"processing: {fpath}") + def _replace_requirement(fpath: str, old_str: str = "", new_str: str = "") -> None: + """Replace all strings given file.""" + logging.info(f"processing '{old_str}' -> '{new_str}': {fpath}") with open(fpath, encoding="utf-8") as fp: req = fp.read() - req = req.replace(">=", "==") + req = req.replace(old_str, new_str) with open(fpath, "w", encoding="utf-8") as fp: fp.write(req) + @staticmethod + def replace_str_requirements(old_str: str, new_str: str, req_files: List[str] = REQUIREMENTS_FILES) -> None: + """Replace a particular string in all requirements files.""" + if isinstance(req_files, str): + req_files = [req_files] + for fpath in req_files: + AssistantCLI._replace_requirement(fpath, old_str=old_str, new_str=new_str) + + @staticmethod + def replace_min_requirements(fpath: str) -> None: + """Replace all `>=` by `==` in given file.""" + AssistantCLI._replace_requirement(fpath, old_str=">=", new_str="==") + @staticmethod def set_oldest_versions(req_files: List[str] = REQUIREMENTS_FILES) -> None: """Set the oldest version for requirements.""" AssistantCLI.set_min_torch_by_python() + if isinstance(req_files, str): + req_files = [req_files] for fpath in req_files: - logging.info(f"processing req: `{fpath}`") AssistantCLI.replace_min_requirements(fpath) @staticmethod diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml index 01627e1be4b..3c7c3eb69f1 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -13,19 +13,19 @@ concurrency: jobs: check-code: - uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@v0.11.3.post0 + uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@v0.11.7 with: - actions-ref: v0.11.3.post0 + actions-ref: v0.11.7 extra-typing: "typing" check-schema: - uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.3.post0 + uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.7 check-package: if: github.event.pull_request.draft == false - uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.11.3.post0 + uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.11.7 with: - actions-ref: v0.11.3.post0 + actions-ref: v0.11.7 artifact-name: dist-packages-${{ github.sha }} import-name: "torchmetrics" testing-matrix: | @@ -35,7 +35,7 @@ jobs: } check-md-links: - uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.11.3.post0 + uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.11.7 with: base-branch: master config-file: ".github/markdown-links-config.json" diff --git a/.github/workflows/ci-integrate.yml b/.github/workflows/ci-integrate.yml index a01bd076cb2..9732360f795 100644 --- a/.github/workflows/ci-integrate.yml +++ b/.github/workflows/ci-integrate.yml @@ -53,6 +53,8 @@ jobs: - name: source cashing uses: ./.github/actions/pull-caches + with: + requires: ${{ matrix.requires }} - name: set oldest if/only for integrations if: matrix.requires == 'oldest' run: python .github/assistant.py set-oldest-versions --req_files='["requirements/_integrate.txt"]' diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index d53b101acc9..7d44adf3aab 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -31,38 +31,32 @@ jobs: strategy: fail-fast: false matrix: - os: ["ubuntu-20.04"] - python-version: ["3.9"] + os: ["ubuntu-22.04"] + python-version: ["3.10"] pytorch-version: - - "1.10.2" - - "1.11.0" - - "1.12.1" - - "1.13.1" - "2.0.1" - "2.1.2" - "2.2.2" - - "2.3.0" + - "2.3.1" + - "2.4.1" + - "2.5.0" include: - # cover additional python nad PR combinations - - { os: "ubuntu-22.04", python-version: "3.8", pytorch-version: "1.13.1" } - - { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.0.1" } - - { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.2.2" } - - { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.3.0" } + # cover additional python and PT combinations + - { os: "ubuntu-20.04", python-version: "3.8", pytorch-version: "2.0.1", requires: "oldest" } + - { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.4.1" } + - { os: "ubuntu-22.04", python-version: "3.12", pytorch-version: "2.5.0" } # standard mac machine, not the M1 - - { os: "macOS-13", python-version: "3.8", pytorch-version: "1.13.1" } - { os: "macOS-13", python-version: "3.10", pytorch-version: "2.0.1" } - - { os: "macOS-13", python-version: "3.11", pytorch-version: "2.2.2" } # using the ARM based M1 machine - { os: "macOS-14", python-version: "3.10", pytorch-version: "2.0.1" } - - { os: "macOS-14", python-version: "3.11", pytorch-version: "2.3.0" } + - { os: "macOS-14", python-version: "3.12", pytorch-version: "2.5.0" } # some windows - - { os: "windows-2022", python-version: "3.8", pytorch-version: "1.13.1" } - { os: "windows-2022", python-version: "3.10", pytorch-version: "2.0.1" } - - { os: "windows-2022", python-version: "3.11", pytorch-version: "2.3.0" } + - { os: "windows-2022", python-version: "3.12", pytorch-version: "2.5.0" } # Future released version - - { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.4.0" } - - { os: "macOS-14", python-version: "3.11", pytorch-version: "2.4.0" } - - { os: "windows-2022", python-version: "3.11", pytorch-version: "2.4.0" } + #- { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.5.0" } + #- { os: "macOS-14", python-version: "3.11", pytorch-version: "2.5.0" } + #- { os: "windows-2022", python-version: "3.11", pytorch-version: "2.5.0" } env: FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }} PYPI_CACHE_DIR: "_ci-cache_PyPI" @@ -105,9 +99,9 @@ jobs: pytorch-version: ${{ matrix.pytorch-version }} pypi-dir: ${{ env.PYPI_CACHE_DIR }} - - name: Switch to PT test URL - if: ${{ matrix.pytorch-version == '2.4.0' }} - run: echo 'PIP_EXTRA_INDEX_URL=--extra-index-url https://download.pytorch.org/whl/test/cpu/' >> $GITHUB_ENV + #- name: Switch to PT test URL + # if: ${{ matrix.pytorch-version == '2.X.0' }} + # run: echo 'PIP_EXTRA_INDEX_URL=--extra-index-url https://download.pytorch.org/whl/test/cpu/' >> $GITHUB_ENV - name: Install pkg timeout-minutes: 25 run: | diff --git a/.github/workflows/clear-cache.yml b/.github/workflows/clear-cache.yml index eca3609a901..ecd7c6e3ff3 100644 --- a/.github/workflows/clear-cache.yml +++ b/.github/workflows/clear-cache.yml @@ -1,23 +1,40 @@ -name: Clear cache weekly +name: Clear cache ... on: + pull_request: + paths: + - ".github/workflows/clear-cache.yml" workflow_dispatch: inputs: pattern: description: "pattern for cleaning cache" - default: "pip|conda" + default: "pip-|conda" required: false type: string + age-days: + description: "setting the age of caches in days to be dropped" + required: true + type: number + default: 5 + schedule: + # on Sundays + - cron: "0 0 * * 0" jobs: cron-clear: - if: github.event_name == 'schedule' - uses: Lightning-AI/utilities/.github/workflows/clear-cache.yml@v0.11.3.post0 + if: github.event_name == 'schedule' || github.event_name == 'pull_request' + uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.7 with: + scripts-ref: v0.11.7 + dry-run: ${{ github.event_name == 'pull_request' }} pattern: "pip-latest" + age-days: 7 direct-clear: - if: github.event_name == 'workflow_dispatch' - uses: Lightning-AI/utilities/.github/workflows/clear-cache.yml@v0.11.3.post0 + if: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request' + uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.7 with: - pattern: ${{ inputs.pattern }} + scripts-ref: v0.11.7 + dry-run: ${{ github.event_name == 'pull_request' }} + pattern: ${{ inputs.pattern || 'pypi_wheels' }} # setting str in case of PR / debugging + age-days: ${{ fromJSON(inputs.age-days) || 0 }} # setting 0 in case of PR / debugging diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index b3af59d6483..946f64cbc0f 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -66,14 +66,14 @@ jobs: include: # These are the base images for PL release docker images, # so include at least all the combinations in release-dockers.yml. - - { python: "3.9", pytorch: "1.10", cuda: "11.3.1", ubuntu: "20.04" } - #- { python: "3.9", pytorch: "1.11", cuda: "11.8.0", ubuntu: "22.04" } - - { python: "3.9", pytorch: "1.13", cuda: "11.8.0", ubuntu: "22.04" } - - { python: "3.10", pytorch: "2.2", cuda: "12.1.1", ubuntu: "22.04" } - - { python: "3.11", pytorch: "2.2", cuda: "12.1.1", ubuntu: "22.04" } - - { python: "3.11", pytorch: "2.3", cuda: "12.1.1", ubuntu: "22.04" } + - { python: "3.10", pytorch: "2.0.1", cuda: "12.1.1", ubuntu: "22.04" } + - { python: "3.11", pytorch: "2.1.2", cuda: "12.1.1", ubuntu: "22.04" } + - { python: "3.11", pytorch: "2.2.2", cuda: "12.1.1", ubuntu: "22.04" } + - { python: "3.11", pytorch: "2.3.1", cuda: "12.1.1", ubuntu: "22.04" } + - { python: "3.11", pytorch: "2.4.1", cuda: "12.1.1", ubuntu: "22.04" } + - { python: "3.11", pytorch: "2.5.0", cuda: "12.1.1", ubuntu: "22.04" } # the future version - test or RC version - - { python: "3.11", pytorch: "2.4", cuda: "12.1.1", ubuntu: "22.04" } + #- { python: "3.11", pytorch: "2.6", cuda: "12.1.1", ubuntu: "22.04" } steps: - uses: actions/checkout@v4 @@ -84,6 +84,12 @@ jobs: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} + - name: shorten Torch version + run: | + # convert 1.10.2 to 1.10 + pt_version=$(echo ${{ matrix.pytorch }} | cut -d. -f1,2) + echo "PT_VERSION=$pt_version" >> $GITHUB_ENV + - name: Build (and Push) runner uses: docker/build-push-action@v6 with: @@ -94,5 +100,5 @@ jobs: CUDA_VERSION=${{ matrix.cuda }} file: dockers/ubuntu-cuda/Dockerfile push: ${{ env.PUSH_DOCKERHUB }} - tags: "pytorchlightning/torchmetrics:ubuntu${{ matrix.ubuntu }}-cuda${{ matrix.cuda }}-py${{ matrix.python }}-torch${{ matrix.pytorch }}" + tags: "pytorchlightning/torchmetrics:ubuntu${{ matrix.ubuntu }}-cuda${{ matrix.cuda }}-py${{ matrix.python }}-torch${{ env.PT_VERSION }}" timeout-minutes: 55 diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 5aeef255bfe..dce0f0192a2 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -44,22 +44,14 @@ jobs: - name: source cashing uses: ./.github/actions/pull-caches with: - requires: ${{ matrix.requires }} pytorch-version: ${{ matrix.pytorch-version }} pypi-dir: ${{ env.PYPI_CACHE }} - - name: Install Latex - if: ${{ matrix.target == 'html' }} - # install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux - run: | - sudo apt-get update --fix-missing - sudo apt-get install -y \ - texlive-latex-extra texlive-pictures texlive-fonts-recommended dvipng cm-super - - name: Install package & dependencies run: | make get-sphinx-template - pip install . -U -r requirements/_docs.txt \ + # install with -e so the path to source link comes from this project not from the installed package + pip install -e . -U -r requirements/_docs.txt \ --find-links="${PYPI_CACHE}" --find-links="${TORCH_URL}" - run: pip list - name: Full build for deployment @@ -70,7 +62,10 @@ jobs: run: echo "SPHINX_ENABLE_GALLERY=0" >> $GITHUB_ENV - name: make ${{ matrix.target }} working-directory: ./docs - run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going" + run: | + pwd + ls -la + make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going" - name: Upload built docs if: ${{ matrix.target == 'html' && github.event_name != 'pull_request' }} @@ -81,7 +76,7 @@ jobs: retention-days: ${{ steps.keep-artifact.outputs.DAYS }} - name: update cashing - if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' && matrix.target == 'html' }} continue-on-error: true uses: ./.github/actions/push-caches with: diff --git a/.github/workflows/publish-pkg.yml b/.github/workflows/publish-pkg.yml index b46277c7e26..7affbea96c2 100644 --- a/.github/workflows/publish-pkg.yml +++ b/.github/workflows/publish-pkg.yml @@ -67,7 +67,7 @@ jobs: - run: ls -lh dist/ # We do this, since failures on test.pypi aren't that bad - name: Publish to Test PyPI - uses: pypa/gh-action-pypi-publish@v1.9.0 + uses: pypa/gh-action-pypi-publish@v1.10.2 with: user: __token__ password: ${{ secrets.test_pypi_password }} @@ -94,7 +94,7 @@ jobs: path: dist - run: ls -lh dist/ - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@v1.9.0 + uses: pypa/gh-action-pypi-publish@v1.10.2 with: user: __token__ password: ${{ secrets.pypi_password }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f8f7097af6c..b36b25e59f4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -54,7 +54,7 @@ repos: exclude: pyproject.toml - repo: https://github.com/PyCQA/docformatter - rev: v1.7.5 + rev: 06907d0267368b49b9180eed423fae5697c1e909 # todo: fix for docformatter after last 1.7.5 hooks: - id: docformatter additional_dependencies: [tomli] @@ -69,6 +69,7 @@ repos: rev: 0.7.17 hooks: - id: mdformat + args: ["--number"] additional_dependencies: - mdformat-gfm - mdformat-black diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b70066b5bc..f31b3968ffb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,52 +8,117 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 --- -## [UnReleased] - 2022-MM-DD +## [UnReleased] - 2024-MM-DD ### Added -- Added a new audio metric `DNSMOS` ([#2525](https://github.com/PyTorchLightning/metrics/pull/2525)) +- Added `NormalizedRootMeanSquaredError` metric to regression subpackage ([#2442](https://github.com/Lightning-AI/torchmetrics/pull/2442)) -- Added `MetricInputTransformer` wrapper ([#2392](https://github.com/Lightning-AI/torchmetrics/pull/2392)) +- Added `LogAUC` metric to classification package ([#2377](https://github.com/Lightning-AI/torchmetrics/pull/2377)) +### Changed -- Added `input_format` argument to segmentation metrics ([#2572](https://github.com/Lightning-AI/torchmetrics/pull/2572)) +- Changed naming and input order arguments in `KLDivergence` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800)) -- Added better error messages for intersection detection metrics for wrong user input ([#2577](https://github.com/Lightning-AI/torchmetrics/pull/2577)) +### Removed +- Changed minimum supported Pytorch version to 2.0 ([#2671](https://github.com/Lightning-AI/torchmetrics/pull/2671)) -- Added multi-output support for MAE metric ([#2605](https://github.com/Lightning-AI/torchmetrics/pull/2605)) +- Removed `num_outputs` in `R2Score` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800)) -- Added `LogAUC` metric to classification package ([#2377](https://github.com/Lightning-AI/torchmetrics/pull/2377)) + +### Fixed + +- + + +--- + +## [1.5.1] - 2024-10-22 + +### Fixed + +- Changing `_modules` dict type in Pytorch 2.5 preventing to fail collections metrics ([#2793](https://github.com/Lightning-AI/torchmetrics/pull/2793)) + + +## [1.5.0] - 2024-10-18 + +### Added + +- Added segmentation metric `HausdorffDistance` ([#2122](https://github.com/Lightning-AI/torchmetrics/pull/2122)) +- Added audio metric `DNSMOS` ([#2525](https://github.com/PyTorchLightning/metrics/pull/2525)) +- Added shape metric `ProcrustesDistance` ([#2723](https://github.com/Lightning-AI/torchmetrics/pull/2723) +- Added `MetricInputTransformer` wrapper ([#2392](https://github.com/Lightning-AI/torchmetrics/pull/2392)) +- Added `input_format` argument to segmentation metrics ([#2572](https://github.com/Lightning-AI/torchmetrics/pull/2572)) +- Added `multi-output` support for MAE metric ([#2605](https://github.com/Lightning-AI/torchmetrics/pull/2605)) +- Added `truncation` argument to `BERTScore` ([#2776](https://github.com/Lightning-AI/torchmetrics/pull/2776)) ### Changed -- Calculate text color of ConfusionMatrix plot based on luminance +- Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649)) +- Updated `InfoLM` class to dynamically set `higher_is_better` ([#2674](https://github.com/Lightning-AI/torchmetrics/pull/2674)) + +### Deprecated +- Deprecated `num_outputs` in `R2Score` ([#2705](https://github.com/Lightning-AI/torchmetrics/pull/2705)) -### Removed +### Fixed -- +- Fixed corner case in `IoU` metric for single empty prediction tensors ([#2780](https://github.com/Lightning-AI/torchmetrics/pull/2780)) +- Fixed `PSNR` calculation for integer type input images ([#2788](https://github.com/Lightning-AI/torchmetrics/pull/2788)) +--- + +## [1.4.3] - 2024-10-10 ### Fixed -- Fixed bug in `MetricCollection` when using compute groups and `compute` is called more than once ([#2571](https://github.com/Lightning-AI/torchmetrics/pull/2571)) +- Fixed for Pearson changes inputs ([#2765](https://github.com/Lightning-AI/torchmetrics/pull/2765)) +- Fixed bug in `PESQ` metric where `NoUtterancesError` prevented calculating on a batch of data ([#2753](https://github.com/Lightning-AI/torchmetrics/pull/2753)) +- Fixed corner case in `MatthewsCorrCoef` ([#2743](https://github.com/Lightning-AI/torchmetrics/pull/2743)) -- Fixed class order of `panoptic_quality(..., return_per_class=True)` output ([#2548](https://github.com/Lightning-AI/torchmetrics/pull/2548)) +## [1.4.2] - 2022-09-12 +### Added -- Fixed `BootstrapWrapper` not being reset correctly ([#2574](https://github.com/Lightning-AI/torchmetrics/pull/2574)) +- Re-adding `Chrf` implementation ([#2701](https://github.com/Lightning-AI/torchmetrics/pull/2701)) +### Fixed + +- Fixed wrong aggregation in `segmentation.MeanIoU` ([#2698](https://github.com/Lightning-AI/torchmetrics/pull/2698)) +- Fixed handling zero division error in binary IoU (Jaccard index) calculation ([#2726](https://github.com/Lightning-AI/torchmetrics/pull/2726)) +- Corrected the padding related calculation errors in SSIM ([#2721](https://github.com/Lightning-AI/torchmetrics/pull/2721)) +- Fixed compatibility of audio domain with new `scipy` ([#2733](https://github.com/Lightning-AI/torchmetrics/pull/2733)) +- Fixed how `prefix`/`postfix` works in `MultitaskWrapper` ([#2722](https://github.com/Lightning-AI/torchmetrics/pull/2722)) +- Fixed flakiness in tests related to `torch.unique` with `dim=None` ([#2650](https://github.com/Lightning-AI/torchmetrics/pull/2650)) -- Fixed integration between `ClasswiseWrapper` and `MetricCollection` with custom `_filter_kwargs` method ([#2575](https://github.com/Lightning-AI/torchmetrics/pull/2575)) +## [1.4.1] - 2024-08-02 +### Changed + +- Calculate text color of `ConfusionMatrix` plot based on luminance ([#2590](https://github.com/Lightning-AI/torchmetrics/pull/2590)) +- Updated `_safe_divide` to allow `Accuracy` to run on the GPU ([#2640](https://github.com/Lightning-AI/torchmetrics/pull/2640)) +- Improved error messages for intersection detection metrics for wrong user input ([#2577](https://github.com/Lightning-AI/torchmetrics/pull/2577)) + +### Removed + +- Dropped `Chrf` implementation due to licensing issues with the upstream package ([#2668](https://github.com/Lightning-AI/torchmetrics/pull/2668)) + +### Fixed + +- Fixed bug in `MetricCollection` when using compute groups and `compute` is called more than once ([#2571](https://github.com/Lightning-AI/torchmetrics/pull/2571)) +- Fixed class order of `panoptic_quality(..., return_per_class=True)` output ([#2548](https://github.com/Lightning-AI/torchmetrics/pull/2548)) +- Fixed `BootstrapWrapper` not being reset correctly ([#2574](https://github.com/Lightning-AI/torchmetrics/pull/2574)) +- Fixed integration between `ClasswiseWrapper` and `MetricCollection` with custom `_filter_kwargs` method ([#2575](https://github.com/Lightning-AI/torchmetrics/pull/2575)) - Fixed BertScore calculation: pred target misalignment ([#2347](https://github.com/Lightning-AI/torchmetrics/pull/2347)) +- Fixed `_cumsum` helper function in multi-gpu ([#2636](https://github.com/Lightning-AI/torchmetrics/pull/2636)) +- Fixed bug in `MeanAveragePrecision.coco_to_tm` ([#2588](https://github.com/Lightning-AI/torchmetrics/pull/2588)) +- Fixed missed f-strings in exceptions/warnings ([#2667](https://github.com/Lightning-AI/torchmetrics/pull/2667)) ## [1.4.0] - 2024-05-03 diff --git a/Makefile b/Makefile index 92b255c2cc3..7ab37aab29d 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,6 @@ -.PHONY: clean test get-sphinx-template docs env data +.PHONY: clean test get-sphinx-template docs live-docs env data +export TOKENIZERS_PARALLELISM=false export FREEZE_REQUIREMENTS=1 # assume you have installed need packages export SPHINX_MOCK_REQUIREMENTS=1 @@ -39,6 +40,10 @@ docs: clean get-sphinx-template # apt-get install -y texlive-latex-extra dvipng texlive-pictures texlive-fonts-recommended cm-super cd docs && make html --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going" +live-docs: get-sphinx-template + pip install -e . --quiet -r requirements/_docs.txt + cd docs && make livehtml --jobs $(nproc) + env: pip install -e . -U -r requirements/_devel.txt diff --git a/dockers/ubuntu-cuda/Dockerfile b/dockers/ubuntu-cuda/Dockerfile index cb9aa1e723e..a93277a2926 100644 --- a/dockers/ubuntu-cuda/Dockerfile +++ b/dockers/ubuntu-cuda/Dockerfile @@ -73,6 +73,7 @@ RUN \ ENV PYTHONPATH="/usr/lib/python${PYTHON_VERSION}/site-packages" COPY requirements/ requirements/ +COPY .github/assistant.py ./assistant.py RUN \ # set particular PyTorch version @@ -85,6 +86,11 @@ RUN \ pip install -q "numpy<1.24" && \ CUDA_VERSION_MM=${CUDA_VERSION%.*} && \ CU_VERSION_MM=${CUDA_VERSION_MM//'.'/''} && \ + # requirements for assistant + pip install -q -U packaging fire && \ + # switch some packages to be GPU related + python assistant.py replace_str_requirements "onnxruntime" "onnxruntime_gpu" --req_files requirements/audio.txt && \ + # install develop environment pip install --no-cache-dir -r requirements/_devel.txt \ --find-links="https://download.pytorch.org/whl/cu${CU_VERSION_MM}/torch_stable.html" \ --extra-index-url="https://download.pytorch.org/whl/test/cu${CU_VERSION_MM}" && \ diff --git a/.readthedocs.yml b/docs/.readthedocs.yaml similarity index 98% rename from .readthedocs.yml rename to docs/.readthedocs.yaml index e797ae385fe..ba81077847d 100644 --- a/.readthedocs.yml +++ b/docs/.readthedocs.yaml @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# .readthedocs.yml +# .readthedocs.yaml # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details diff --git a/docs/Makefile b/docs/Makefile index 35e5650808f..bae43786043 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -17,3 +17,6 @@ help: # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +livehtml: + sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/source/_static/runllm.js b/docs/source/_static/runllm.js new file mode 100644 index 00000000000..8054f5a5051 --- /dev/null +++ b/docs/source/_static/runllm.js @@ -0,0 +1,15 @@ +document.addEventListener("DOMContentLoaded", function () { + var script = document.createElement("script"); + script.type = "module"; + script.id = "runllm-widget-script" + + script.src = "https://widget.runllm.com"; + + script.setAttribute("runllm-keyboard-shortcut", "Mod+j"); // cmd-j or ctrl-j to open the widget. + script.setAttribute("runllm-name", "TorchMetrics"); + script.setAttribute("runllm-position", "BOTTOM_RIGHT"); + script.setAttribute("runllm-assistant-id", "244"); + + script.async = true; + document.head.appendChild(script); +}); diff --git a/docs/source/conf.py b/docs/source/conf.py index 9484761ba7a..81f842e7a12 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -220,6 +220,7 @@ def _set_root_image_path(page_path: str) -> None: # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] html_css_files = ["css/custom.css"] +html_js_files = ["runllm.js"] # -- Options for HTMLHelp output --------------------------------------------- @@ -269,6 +270,11 @@ def _set_root_image_path(page_path: str) -> None: ), ] +# MathJax configuration +mathjax3_config = { + "tex": {"packages": {"[+]": ["ams", "newcommand", "configMacros"]}}, +} + # -- Options for Epub output ------------------------------------------------- # Bibliographic Dublin Core info. @@ -358,8 +364,7 @@ def package_list_from_file(file: str) -> list[str]: autodoc_mock_imports = MOCK_PACKAGES -# Resolve function -# This function is used to populate the (source) links in the API +# Resolve function - this function is used to populate the (source) links in the API def linkcode_resolve(domain, info) -> Optional[str]: # noqa: ANN001 return _linkcode_resolve(domain, info=info, github_user="Lightning-AI", github_repo="torchmetrics") @@ -442,4 +447,10 @@ def linkcode_resolve(domain, info) -> Optional[str]: # noqa: ANN001 "https://aclanthology.org/W17-4770", # A wavelet transform method to merge Landsat TM and SPOT panchromatic data "https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013", + # Improved normalization of time-lapse seismic data using normalized root mean square repeatability data ... + # ... to improve automatic production and seismic history matching in the Nelson field + "https://onlinelibrary.wiley.com/doi/abs/10.1111/1365-2478.12109", + # todo: these links seems to be unstable, referring to .devcontainer + "https://code.visualstudio.com", + "https://code.visualstudio.com/.*", ] diff --git a/docs/source/index.rst b/docs/source/index.rst index 58dac2c6fb1..46670de00e4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -231,6 +231,14 @@ Or directly from conda segmentation/* +.. toctree:: + :maxdepth: 2 + :name: shape + :caption: Shape + :glob: + + shape/* + .. toctree:: :maxdepth: 2 :name: text diff --git a/docs/source/links.rst b/docs/source/links.rst index e31b3362693..b7a4f63565e 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -91,7 +91,7 @@ .. _CER: https://rechtsprechung-im-ostseeraum.archiv.uni-greifswald.de/word-error-rate-character-error-rate-how-to-evaluate-a-model .. _MER: https://www.isca-speech.org/archive/interspeech_2004/morris04_interspeech.html .. _WIL: https://www.isca-speech.org/archive/interspeech_2004/morris04_interspeech.html -.. _WIP: https://infoscience.epfl.ch/record/82766 +.. _WIP: https://www.isca-archive.org/interspeech_2004/morris04_interspeech.pdf .. _TV: https://en.wikipedia.org/wiki/Total_variation_denoising .. _InfoLM: https://arxiv.org/abs/2112.01589 .. _alpha divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf @@ -171,4 +171,8 @@ .. _FLORES-200: https://arxiv.org/abs/2207.04672 .. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html .. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013 +.. _Normalized Root Mean Squared Error: https://onlinelibrary.wiley.com/doi/abs/10.1111/1365-2478.12109 .. _Generalized Dice Score: https://arxiv.org/abs/1707.03237 +.. _Hausdorff Distance: https://en.wikipedia.org/wiki/Hausdorff_distance +.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html +.. _Procrustes Disparity: https://en.wikipedia.org/wiki/Procrustes_analysis diff --git a/docs/source/multimodal/clip_iqa.rst b/docs/source/multimodal/clip_iqa.rst index 074f35a50bf..59734b5a05b 100644 --- a/docs/source/multimodal/clip_iqa.rst +++ b/docs/source/multimodal/clip_iqa.rst @@ -1,7 +1,7 @@ .. customcarditem:: :header: CLIP IQA :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg - :tags: Image + :tags: Multimodal .. include:: ../links.rst diff --git a/docs/source/multimodal/clip_score.rst b/docs/source/multimodal/clip_score.rst index 60166403fd0..12fe0fa2815 100644 --- a/docs/source/multimodal/clip_score.rst +++ b/docs/source/multimodal/clip_score.rst @@ -1,7 +1,7 @@ .. customcarditem:: :header: CLIP Score :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg - :tags: Image + :tags: Multimodal .. include:: ../links.rst diff --git a/docs/source/pages/implement.rst b/docs/source/pages/implement.rst index f7da4aa8ba7..1620ce29cd9 100644 --- a/docs/source/pages/implement.rst +++ b/docs/source/pages/implement.rst @@ -9,8 +9,8 @@ Implementing a Metric ##################### -While we strive to include as many metrics as possible in ``torchmetrics``, we cannot include them all. Therefore, we -have made it easy to implement your own metric and possible contribute it to ``torchmetrics``. This page will guide +While we strive to include as many metrics as possible in ``torchmetrics``, we cannot include them all. We have made it +easy to implement your own metric, and you can contribute it to ``torchmetrics`` if you wish. This page will guide you through the process. If you afterwards are interested in contributing your metric to ``torchmetrics``, please read the `contribution guidelines `_ and see this :ref:`section `. @@ -63,7 +63,7 @@ A few important things to note: * The ``dist_reduce_fx`` argument to ``add_state`` is used to specify how the metric states should be reduced between batches in distributed settings. In this case we use ``"sum"`` to sum the metric states across batches. A couple of - build in options are available: ``"sum"``, ``"mean"``, ``"cat"``, ``"min"`` or ``"max"``, but a custom reduction is + built-in options are available: ``"sum"``, ``"mean"``, ``"cat"``, ``"min"`` or ``"max"``, but a custom reduction is also supported. * In ``update`` we do not return anything but instead update the metric states in-place. @@ -101,7 +101,7 @@ because we need to calculate the rank of the predictions and targets. preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) # some intermediate computation... - r_preds, r_target = _rank_data(preds), _rank_dat(target) + r_preds, r_target = _rank_data(preds), _rank_data(target) preds_diff = r_preds - r_preds.mean(0) target_diff = r_target - r_target.mean(0) cov = (preds_diff * target_diff).mean(0) @@ -118,10 +118,10 @@ A few important things to note for this example: * When working with list states, The ``update(...)`` method should append the batch states to the list. -* In the the ``compute`` method the list states behave a bit differently dependeding on weather you are running in +* In the the ``compute`` method the list states behave a bit differently dependeding on whether you are running in distributed mode or not. In non-distributed mode the list states will be a list of tensors, while in distributed mode the list have already been concatenated into a single tensor. For this reason, we recommend always using the - ``dim_zero_cat`` helper function which will standardize the list states to be a single concatenate tensor regardless + ``dim_zero_cat`` helper function which will standardize the list states to be a single concatenated tensor regardless of the mode. * Calling the ``reset`` method will clear the list state, deleting any values inserted into it. For this reason, care @@ -179,7 +179,7 @@ used, that provides the common plotting functionality for most metrics in torchm return self._plot(val, ax) If the metric returns a more complex output, a custom implementation of the `plot` method is required. For more details -on the plotting API, see the this :ref:`page ` . In addti +on the plotting API, see the this :ref:`page ` . ******************************* Internal implementation details @@ -247,7 +247,7 @@ as long as they serve a general purpose. However, to keep all our metrics consis and tests gets formatted in the following way: 1. Start by reading our `contribution guidelines `_. -2. First implement the functional backend. This takes cares of all the logic that goes into the metric. The code should +2. First implement the functional backend. This takes care of all the logic that goes into the metric. The code should be put into a single file placed under ``src/torchmetrics/functional/"domain"/"new_metric".py`` where ``domain`` is the type of metric (classification, regression, text etc.) and ``new_metric`` is the name of the metric. In this file, there should be the following three functions: @@ -259,7 +259,7 @@ and tests gets formatted in the following way: .. note:: The `functional mean squared error `_ - metric is a great example of this division of logic. + metric is a is a great example of how to divide the logic. 3. In a corresponding file placed in ``src/torchmetrics/"domain"/"new_metric".py`` create the module interface: @@ -283,12 +283,12 @@ and tests gets formatted in the following way: both the functional and module interface. 2. In that file, start by defining a number of test inputs that your metric should be evaluated on. 3. Create a testclass ``class NewMetric(MetricTester)`` that inherits from ``tests.helpers.testers.MetricTester``. - This testclass should essentially implement the ``test_"new_metric"_class`` and ``test_"new_metric"_fn`` methods that + This test class should essentially implement the ``test_"new_metric"_class`` and ``test_"new_metric"_fn`` methods that respectively tests the module interface and the functional interface. 4. The testclass should be parameterized (using ``@pytest.mark.parametrize``) by the different test inputs defined initially. Additionally, the ``test_"new_metric"_class`` method should also be parameterized with an ``ddp`` parameter such that it gets tested in a distributed setting. If your metric has additional parameters, then make sure to also parameterize these - such that different combinations of inputs and parameters gets tested. + so that different combinations of inputs and parameters get tested. 5. (optional) If your metric raises any exception, please add tests that showcase this. .. note:: diff --git a/docs/source/pages/lightning.rst b/docs/source/pages/lightning.rst index 7f1bc70363b..a96196396b8 100644 --- a/docs/source/pages/lightning.rst +++ b/docs/source/pages/lightning.rst @@ -3,6 +3,7 @@ import torch from torch.nn import Module from lightning.pytorch import LightningModule + from lightning.pytorch.utilities import rank_zero_only from torchmetrics import Metric ################################# @@ -14,8 +15,8 @@ framework designed for scaling models without boilerplate. .. note:: - TorchMetrics always offers compatibility with the last 2 major PyTorch Lightning versions, but we recommend to always keep both frameworks - up-to-date for the best experience. + TorchMetrics always offers compatibility with the last 2 major PyTorch Lightning versions, but we recommend always + keeping both frameworks up-to-date for the best experience. While TorchMetrics was built to be used with native PyTorch, using TorchMetrics with Lightning offers additional benefits: @@ -73,7 +74,7 @@ method, Lightning will log the metric based on ``on_step`` and ``on_epoch`` flag ``sync_dist``, ``sync_dist_group`` and ``reduce_fx`` flags from ``self.log(...)`` don't affect the metric logging in any manner. The metric class contains its own distributed synchronization logic. - This however is only true for metrics that inherit the base class ``Metric``, + This, however is only true for metrics that inherit the base class ``Metric``, and thus the functional metric API provides no support for in-built distributed synchronization or reduction functions. @@ -107,7 +108,7 @@ also manually log the output of the metrics. class MyModule(LightningModule): - def __init__(self): + def __init__(self, num_classes): ... self.train_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes) self.valid_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes) @@ -156,6 +157,43 @@ Additionally, we highly recommend that the two ways of logging are not mixed as self.valid_acc.update(logits, y) self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True) +In general if you are logging multiple metrics we highly recommend that you combine them into a single metric object +using the :class:`~torchmetrics.MetricCollection` class and then replacing the ``self.log`` calls with ``self.log_dict``, +assuming that all metrics receive the same input. + +.. testcode:: python + + class MyModule(LightningModule): + + def __init__(self): + ... + self.train_metrics = torchmetrics.MetricCollection( + { + "accuracy": torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes), + "f1": torchmetrics.classification.F1(task="multiclass", num_classes=num_classes), + }, + prefix="train_", + ) + self.valid_metrics = self.train_metrics.clone(prefix="valid_") + + def training_step(self, batch, batch_idx): + x, y = batch + preds = self(x) + ... + batch_value = self.train_metrics(preds, y) + self.log_dict(batch_value) + + def on_train_epoch_end(self): + self.train_metrics.reset() + + def validation_step(self, batch, batch_idx): + logits = self(x) + ... + self.valid_metrics.update(logits, y) + + def on_validation_epoch_end(self, outputs): + self.log_dict(self.valid_metrics.compute()) + self.valid_metrics.reset() *************** Common Pitfalls @@ -163,6 +201,41 @@ Common Pitfalls The following contains a list of pitfalls to be aware of: +* Logging a `MetricCollection` object directly using ``self.log_dict`` is only supported if all metrics in the + collection return a scalar tensor. If any of the metrics in the collection return a non-scalar tensor, + the logging will fail. This can especially happen when either nesting multiple ``MetricCollection`` objects or when + using wrapper metrics such as :class:`~torchmetrics.wrappers.ClasswiseWrapper`, + :class:`~torchmetrics.wrappers.MinMaxMetric` etc. inside a ``MetricCollection`` since all these wrappers return + dicts or lists of tensors. It is still possible to log such nested metrics manually because the ``MetricCollection`` + object will try to flatten everything into a single dict. Example: + +.. testcode:: python + + class MyModule(LightningModule): + + def __init__(self): + super().__init__() + self.train_metrics = MetricCollection( + { + "macro_accuracy": MinMaxMetric(MulticlassAccuracy(num_classes=5, average="macro")), + "weighted_accuracy": MinMaxMetric(MulticlassAccuracy(num_classes=5, average="weighted")), + }, + prefix="train_", + ) + + def training_step(self, batch, batch_idx): + ... + # logging the MetricCollection object directly will fail + self.log_dict(self.train_metrics(preds, target)) + + # manually computing the result and then logging will work + batch_values = self.train_metrics(preds, target) + self.log_dict(batch_values, on_step=True, on_epoch=False) + ... + + def on_train_epoch_end(self): + self.train_metrics.reset() + * Modular metrics contain internal states that should belong to only one DataLoader. In case you are using multiple DataLoaders, it is recommended to initialize a separate modular metric instances for each DataLoader and use them separately. The same holds for using separate metrics for training, validation and testing. @@ -193,8 +266,31 @@ The following contains a list of pitfalls to be aware of: Because the object is logged in the first case, Lightning will reset the metric before calling the second line leading to errors or nonsense results. +* If you decorate a lightning method with the ``rank_zero_only`` decorator with the goal of only calculating a particular + metric on the main process, you need to disable the default behavior of the metric to synchronize the metric values + across all processes. This can be done by setting the ``sync_on_compute`` flag to ``False`` when initializing the + metric. Not doing so can lead to race conditions and processes hanging. + +.. testcode:: python + + class MyModule(LightningModule): + + def __init__(self, num_classes): + ... + self.metric = torchmetrics.image.FrechetInceptionDistance(sync_on_compute=False) + + @rank_zero_only + def validation_step(self, batch, batch_idx): + image, target = batch + generated_image = self(x) + ... + self.metric(image, real=True) + self.metric(generated_image, real=False) + val = self.metric.compute() # this will only be called on the main process + self.log('val_fid', val) + * Calling ``self.log("val", self.metric(preds, target))`` with the intention of logging the metric object. Because - ``self.metric(preds, target)`` corresponds to calling the forward method, this will return a tensor and not the + ``self.metric(preds, target)`` corresponds to calling the ``forward`` method, this will return a tensor and not the metric object. Such logging will be wrong in this case. Instead, it is essential to separate into several lines: .. testcode:: python @@ -207,7 +303,8 @@ The following contains a list of pitfalls to be aware of: self.accuracy(preds, y) # compute metrics self.log('train_acc_step', self.accuracy) # log metric object -* Using :class:`~torchmetrics.wrappers.MetricTracker` wrapper with Lightning is a special case, because the wrapper in itself is not a metric - i.e. it does not inherit from the base :class:`~torchmetrics.Metric` class but instead from :class:`~torch.nn.ModuleList`. Thus, - to log the output of this metric one needs to manually log the returned values (not the object) using ``self.log`` - and for epoch level logging this should be done in the appropriate ``on_{train|validation|test}_epoch_end`` method. +* Using :class:`~torchmetrics.wrappers.MetricTracker` wrapper with Lightning is a special case, because the wrapper in + itself is not a metric i.e. it does not inherit from the base :class:`~torchmetrics.Metric` class but instead from + :class:`~torch.nn.ModuleList`. Thus, to log the output of this metric one needs to manually log the returned values + (not the object) using ``self.log`` and for epoch level logging this should be done in the appropriate + ``on_{train|validation|test}_epoch_end`` method. diff --git a/docs/source/regression/kl_divergence.rst b/docs/source/regression/kl_divergence.rst index 4e65ba6fb01..7f5761ad9a8 100644 --- a/docs/source/regression/kl_divergence.rst +++ b/docs/source/regression/kl_divergence.rst @@ -1,7 +1,7 @@ .. customcarditem:: :header: KL Divergence :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg - :tags: Classification + :tags: Regression .. include:: ../links.rst diff --git a/docs/source/regression/normalized_root_mean_squared_error.rst b/docs/source/regression/normalized_root_mean_squared_error.rst new file mode 100644 index 00000000000..7bbc2f392d5 --- /dev/null +++ b/docs/source/regression/normalized_root_mean_squared_error.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Normalized Root Mean Squared Error (NRMSE) + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Regression + +.. include:: ../links.rst + +########################################## +Normalized Root Mean Squared Error (NRMSE) +########################################## + +Module Interface +________________ + +.. autoclass:: torchmetrics.NormalizedRootMeanSquaredError + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.normalized_root_mean_squared_error diff --git a/docs/source/segmentation/generalized_dice.rst b/docs/source/segmentation/generalized_dice.rst index 5c48fc670d1..f0abd2f5353 100644 --- a/docs/source/segmentation/generalized_dice.rst +++ b/docs/source/segmentation/generalized_dice.rst @@ -1,7 +1,7 @@ .. customcarditem:: :header: Generalized Dice Score :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg - :tags: Classification + :tags: Segmentation .. include:: ../links.rst diff --git a/docs/source/segmentation/hausdorff_distance.rst b/docs/source/segmentation/hausdorff_distance.rst new file mode 100644 index 00000000000..cfe1d3fdb5b --- /dev/null +++ b/docs/source/segmentation/hausdorff_distance.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Hausdorff Distance + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/text_classification.svg + :tags: segmentation + +.. include:: ../links.rst + +################## +Hausdorff Distance +################## + +Module Interface +________________ + +.. autoclass:: torchmetrics.segmentation.HausdorffDistance + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.segmentation.hausdorff_distance diff --git a/docs/source/segmentation/mean_iou.rst b/docs/source/segmentation/mean_iou.rst index 7fddd9f316d..9e5544db349 100644 --- a/docs/source/segmentation/mean_iou.rst +++ b/docs/source/segmentation/mean_iou.rst @@ -1,7 +1,7 @@ .. customcarditem:: :header: Mean Intersection over Union (mIoU) :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/object_detection.svg - :tags: segmentation + :tags: Segmentation ################################### Mean Intersection over Union (mIoU) diff --git a/docs/source/shape/procrustes.rst b/docs/source/shape/procrustes.rst new file mode 100644 index 00000000000..e69357c6473 --- /dev/null +++ b/docs/source/shape/procrustes.rst @@ -0,0 +1,22 @@ +.. customcarditem:: + :header: Procrustes Disparity + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: shape + +.. include:: ../links.rst + +#################### +Procrustes Disparity +#################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.shape.ProcrustesDisparity + :exclude-members: update, compute + + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.shape.procrustes_disparity diff --git a/examples/audio/pesq.py b/examples/audio/pesq.py new file mode 100644 index 00000000000..6afde2bfdd5 --- /dev/null +++ b/examples/audio/pesq.py @@ -0,0 +1,117 @@ +""" +Evaluating Speech Quality with PESQ metric +============================================== + +This notebook will guide you through calculating the Perceptual Evaluation of Speech Quality (PESQ) score, + a key metric in assessing how effective noise reduction and enhancement techniques are in improving speech quality. + PESQ is widely adopted in industries such as telecommunications, VoIP, and audio processing. + It provides an objective way to measure the perceived quality of speech signals from a human listener's perspective. + +Imagine being on a noisy street, trying to have a phone call. The technology behind the scenes aims + to clean up your voice and make it sound clearer on the other end. But how do engineers measure that improvement? + This is where PESQ comes in. In this notebook, we will simulate a similar scenario, applying a simple noise reduction + technique and using the PESQ score to evaluate how much the speech quality improves. +""" + +# %% +# Import necessary libraries +import matplotlib.pyplot as plt +import numpy as np +import torch +import torchaudio +from torchmetrics.audio import PerceptualEvaluationSpeechQuality + +# %% +# Generate Synthetic Clean and Noisy Audio Signals +# We'll generate a clean sine wave (representing a clean speech signal) and add white noise to simulate the noisy version. + + +def generate_sine_wave(frequency, duration, sample_rate, amplitude: float = 0.5): + """Generate a clean sine wave at a given frequency.""" + t = torch.linspace(0, duration, int(sample_rate * duration)) + return amplitude * torch.sin(2 * np.pi * frequency * t) + + +def add_noise(waveform: torch.Tensor, noise_factor: float = 0.05) -> torch.Tensor: + """Add white noise to a waveform.""" + noise = noise_factor * torch.randn(waveform.size()) + return waveform + noise + + +# Parameters for the synthetic audio +sample_rate = 16000 # 16 kHz typical for speech +duration = 3 # 3 seconds of audio +frequency = 440 # A4 note, can represent a simple speech-like tone + +# Generate the clean sine wave +clean_waveform = generate_sine_wave(frequency, duration, sample_rate) + +# Generate the noisy waveform by adding white noise +noisy_waveform = add_noise(clean_waveform) + + +# %% +# Apply Basic Noise Reduction Technique +# In this step, we apply a simple spectral gating method for noise reduction using torchaudio's +# `spectrogram` method. This is to simulate the enhancement of noisy speech. + + +def reduce_noise(noisy_signal: torch.Tensor, threshold: float = 0.2) -> torch.Tensor: + """Basic noise reduction using spectral gating.""" + # Compute the spectrogram + spec = torchaudio.transforms.Spectrogram()(noisy_signal) + + # Apply threshold-based gating: values below the threshold will be zeroed out + spec_denoised = spec * (spec > threshold) + + # Convert back to the waveform + return torchaudio.transforms.GriffinLim()(spec_denoised) + + +# Apply noise reduction to the noisy waveform +enhanced_waveform = reduce_noise(noisy_waveform) + +# %% +# Initialize the PESQ Metric +# PESQ can be computed in two modes: 'wb' (wideband) or 'nb' (narrowband). +# Here, we are using 'wb' mode for wideband speech quality evaluation. +pesq_metric = PerceptualEvaluationSpeechQuality(fs=sample_rate, mode="wb") + +# %% +# Compute PESQ Scores +# We will calculate the PESQ scores for both the noisy and enhanced versions compared to the clean signal. +# The PESQ scores give us a numerical evaluation of how well the enhanced speech +# compares to the clean speech. Higher scores indicate better quality. + +pesq_noisy = pesq_metric(clean_waveform, noisy_waveform) +pesq_enhanced = pesq_metric(clean_waveform, enhanced_waveform) + +print(f"PESQ Score for Noisy Audio: {pesq_noisy.item():.4f}") +print(f"PESQ Score for Enhanced Audio: {pesq_enhanced.item():.4f}") + +# %% +# Visualize the waveforms +# We can visualize the waveforms of the clean, noisy, and enhanced audio to see the differences. +fig, axs = plt.subplots(3, 1, figsize=(12, 9)) + +# Plot clean waveform +axs[0].plot(clean_waveform.numpy()) +axs[0].set_title("Clean Audio Waveform (Sine Wave)") +axs[0].set_xlabel("Time") +axs[0].set_ylabel("Amplitude") + +# Plot noisy waveform +axs[1].plot(noisy_waveform.numpy(), color="orange") +axs[1].set_title(f"Noisy Audio Waveform (PESQ: {pesq_noisy.item():.4f})") +axs[1].set_xlabel("Time") +axs[1].set_ylabel("Amplitude") + +# Plot enhanced waveform +axs[2].plot(enhanced_waveform.numpy(), color="green") +axs[2].set_title(f"Enhanced Audio Waveform (PESQ: {pesq_enhanced.item():.4f})") +axs[2].set_xlabel("Time") +axs[2].set_ylabel("Amplitude") + +# Adjust layout for better visualization +fig.tight_layout() +plt.show() diff --git a/examples/audio/signal_to_noise_ratio.py b/examples/audio/signal_to_noise_ratio.py index 910f80801fc..c7130a895e4 100644 --- a/examples/audio/signal_to_noise_ratio.py +++ b/examples/audio/signal_to_noise_ratio.py @@ -1,12 +1,13 @@ """Signal-to-Noise Ratio =============================== -The Signal-to-Noise Ratio (SNR) is a metric used to evaluate the quality of a signal by comparing the power of the signal to the power of background noise. In audio processing, SNR can be used to measure the quality of a reconstructed audio signal by comparing it to the original clean signal. +Imagine developing a song recognition application. The software's goal is to recognize a song even when it's played in a noisy environment, similar to Shazam. To achieve this, you want to enhance the audio quality by reducing the noise and evaluating the improvement using the Signal-to-Noise Ratio (SNR). + +In this example, we will demonstrate how to generate a clean signal, add varying levels of noise to simulate the noisy recording, use FFT for noise reduction, and then evaluate the quality of the reconstructed audio using SNR. """ # %% -# Here's a hypothetical Python example demonstrating the usage of the Signal-to-Noise Ratio to evaluate an audio reconstruction task: - +# Import necessary libraries from typing import Tuple import matplotlib.animation as animation @@ -15,22 +16,21 @@ import torch from torchmetrics.audio import SignalNoiseRatio -# Set seed for reproducibility -torch.manual_seed(42) -np.random.seed(42) +# %% +# Generate a clean signal (simulating a high-quality recording) -# %% -# Create a clean signal (sine wave) def generate_clean_signal(length: int = 1000) -> Tuple[np.ndarray, np.ndarray]: """Generate a clean signal (sine wave)""" t = np.linspace(0, 1, length) - signal = np.sin(2 * np.pi * 10 * t) # 10 Hz sine wave + signal = np.sin(2 * np.pi * 10 * t) # 10 Hz sine wave, representing the clean recording return t, signal # %% -# Add Gaussian noise to the signal +# Add Gaussian noise to the signal to simulate the noisy environment + + def add_noise(signal: np.ndarray, noise_level: float = 0.5) -> np.ndarray: """Add Gaussian noise to the signal.""" noise = noise_level * np.random.randn(signal.shape[0]) @@ -38,31 +38,49 @@ def add_noise(signal: np.ndarray, noise_level: float = 0.5) -> np.ndarray: # %% -# Generate and plot clean and noisy signals +# Apply FFT to filter out the noise + + +def fft_denoise(noisy_signal: np.ndarray, threshold: float) -> np.ndarray: + """Denoise the signal using FFT.""" + freq_domain = np.fft.fft(noisy_signal) # Filter frequencies using FFT + magnitude = np.abs(freq_domain) + filtered_freq_domain = freq_domain * (magnitude > threshold) + return np.fft.ifft(filtered_freq_domain).real # Perform inverse FFT to reconstruct the signal + + +# %% +# Generate and plot clean, noisy, and denoised signals to visualize the reconstruction + length = 1000 t, clean_signal = generate_clean_signal(length) noisy_signal = add_noise(clean_signal, noise_level=0.5) +denoised_signal = fft_denoise(noisy_signal, threshold=10) plt.figure(figsize=(12, 4)) -plt.plot(t, noisy_signal, label="Noisy Signal", color="blue", alpha=0.7) -plt.plot(t, clean_signal, label="Clean Signal", color="red", linewidth=3) +plt.plot(t, noisy_signal, label="Noisy environment", color="blue", alpha=0.7) +plt.plot(t, denoised_signal, label="Denoised signal", color="green", alpha=0.7) +plt.plot(t, clean_signal, label="Clean song", color="red", linewidth=3) plt.xlabel("Time") plt.ylabel("Amplitude") -plt.title("Clean Signal vs. Noisy Signal") +plt.title("Clean Song vs. Noisy Environment vs. Denoised Signal") plt.legend() plt.show() - # %% # Convert the signals to PyTorch tensors and calculate the SNR clean_signal_tensor = torch.tensor(clean_signal).float() noisy_signal_tensor = torch.tensor(noisy_signal).float() +denoised_signal_tensor = torch.tensor(denoised_signal).float() snr = SignalNoiseRatio() -score = snr(preds=noisy_signal_tensor, target=clean_signal_tensor) +initial_snr = snr(preds=noisy_signal_tensor, target=clean_signal_tensor) +reconstructed_snr = snr(preds=denoised_signal_tensor, target=clean_signal_tensor) +print(f"Initial SNR: {initial_snr:.2f}") +print(f"Reconstructed SNR: {reconstructed_snr:.2f}") # %% -# To show the effect of different noise levels on the SNR, we can create an animation that iterates over different noise levels and updates the plot accordingly: +# To show the effect of different noise levels on the SNR, we create an animation that iterates over different noise levels and updates the plot accordingly: fig, ax = plt.subplots(figsize=(12, 4)) noise_levels = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] @@ -71,21 +89,26 @@ def update(num: int) -> tuple: """Update the plot for each frame.""" t, clean_signal = generate_clean_signal(length) noisy_signal = add_noise(clean_signal, noise_level=noise_levels[num]) + denoised_signal = fft_denoise(noisy_signal, threshold=10) clean_signal_tensor = torch.tensor(clean_signal).float() noisy_signal_tensor = torch.tensor(noisy_signal).float() - score = snr(preds=noisy_signal_tensor, target=clean_signal_tensor) + denoised_signal_tensor = torch.tensor(denoised_signal).float() + initial_snr = snr(preds=noisy_signal_tensor, target=clean_signal_tensor) + reconstructed_snr = snr(preds=denoised_signal_tensor, target=clean_signal_tensor) ax.clear() - (clean,) = plt.plot(t, noisy_signal, label="Noisy Signal", color="blue", alpha=0.7) - (noisy,) = plt.plot(t, clean_signal, label="Clean Signal", color="red", linewidth=3) + (noisy,) = plt.plot(t, noisy_signal, label="Noisy Environment", color="blue", alpha=0.7) + (denoised,) = plt.plot(t, denoised_signal, label="Denoised Signal", color="green", alpha=0.7) + (clean,) = plt.plot(t, clean_signal, label="Clean Song", color="red", linewidth=3) ax.set_xlabel("Time") ax.set_ylabel("Amplitude") - ax.set_title(f"SNR: {score:.2f} - Noise level: {noise_levels[num]}") - ax.legend() + ax.set_title( + f"Initial SNR: {initial_snr:.2f} - Reconstructed SNR: {reconstructed_snr:.2f} - Noise level: {noise_levels[num]}" + ) + ax.legend(loc="upper right") ax.set_ylim(-3, 3) - return clean, noisy + return noisy, denoised, clean -ani = animation.FuncAnimation(fig, update, frames=len(noise_levels), interval=500) -plt.show() +ani = animation.FuncAnimation(fig, update, frames=len(noise_levels), interval=1000) diff --git a/examples/image/clip_score.py b/examples/image/clip_score.py index e465ed8ce1f..f73c5d68333 100644 --- a/examples/image/clip_score.py +++ b/examples/image/clip_score.py @@ -19,6 +19,7 @@ # %% # Get sample images + images = { "astronaut": astronaut(), "cat": cat(), @@ -27,6 +28,7 @@ # %% # Define a hypothetical captions for the images + captions = [ "A photo of an astronaut.", "A photo of a cat.", @@ -35,6 +37,7 @@ # %% # Define the models for CLIPScore + models = [ "openai/clip-vit-base-patch16", # "openai/clip-vit-base-patch32", @@ -44,6 +47,7 @@ # %% # Collect scores for each image-caption pair + score_results = [] for model in models: clip_score = CLIPScore(model_name_or_path=model) @@ -54,6 +58,7 @@ # %% # Create an animation to display the scores + fig, (ax_img, ax_table) = plt.subplots(1, 2, figsize=(10, 5)) diff --git a/examples/text/bertscore.py b/examples/text/bertscore.py index 1a10dc097cf..09e2fbff418 100644 --- a/examples/text/bertscore.py +++ b/examples/text/bertscore.py @@ -15,26 +15,26 @@ # %% # Define the prompt and target texts -prompt = "Economic recovery is underway with a 3.5% GDP growth and a decrease in unemployment. Experts forecast continued improvement with boosts from consumer spending and government projects." -target_text = "The economy is recovering, with GDP growth at 3.5% and unemployment at a two-year low. Experts expect this trend to continue due to higher consumer spending and government infrastructure investments." +prompt = "Economic recovery is underway with a 3.5% GDP growth and a decrease in unemployment. Experts forecast continued improvement with boosts from consumer spending and government projects. In summary: " +target_summary = "the recession is ending." # %% # Generate a sample text using the GPT-2 model -generated_text = pipe(prompt, max_new_tokens=20, do_sample=False, temperature=0, pad_token_id=tokenizer.eos_token_id)[ - 0 -]["generated_text"][len(prompt) :] +generated_summary = pipe(prompt, max_new_tokens=20, do_sample=False, pad_token_id=tokenizer.eos_token_id)[0][ + "generated_text" +][len(prompt) :].strip() # %% # Calculate the BERTScore of the generated text bertscore = BERTScore(model_name_or_path="roberta-base") -score = bertscore(preds=[generated_text], target=[target_text]) +score = bertscore(preds=[generated_summary], target=[target_summary]) print(f"Prompt: {prompt}") -print(f"Target Text: {target_text}") -print(f"Generated Text: {generated_text}") -print(f"BERTScore: {score['f1']}") +print(f"Target summary: {target_summary}") +print(f"Generated summary: {generated_summary}") +print(f"BERTScore: {score['f1']:.4f}") # %% # In addition, to illustrate BERTScore's robustness to paraphrasing, let's consider two candidate texts that are variations of the reference text. @@ -42,10 +42,11 @@ candidate_good = "it is cold today" candidate_bad = "it is warm outside" +# %% +# Here we see that using the BERTScore we are able to differentiate between the candidate texts based on their similarity to the reference text, whereas the ROUGE scores for the same text pairs are identical. rouge = ROUGEScore() -bertscore = BERTScore(model_name_or_path="roberta-base") -print("ROUGE for candidate_good:", rouge(preds=[candidate_good], target=[reference])["rouge1_fmeasure"]) -print("ROUGE for candidate_bad:", rouge(preds=[candidate_bad], target=[reference])["rouge1_fmeasure"]) -print("BERTScore for candidate_good:", bertscore(preds=[candidate_good], target=[reference])["f1"]) -print("BERTScore for candidate_bad:", bertscore(preds=[candidate_bad], target=[reference])["f1"]) +print(f"ROUGE for candidate_good: {rouge(preds=[candidate_good], target=[reference])['rouge1_fmeasure'].item()}") +print(f"ROUGE for candidate_bad: {rouge(preds=[candidate_bad], target=[reference])['rouge1_fmeasure'].item()}") +print(f"BERTScore for candidate_good: {bertscore(preds=[candidate_good], target=[reference])['f1'].item():.4f}") +print(f"BERTScore for candidate_bad: {bertscore(preds=[candidate_bad], target=[reference])['f1'].item():.4f}") diff --git a/pyproject.toml b/pyproject.toml index c0c0f86e6e8..5a765978081 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,12 +48,14 @@ lint.ignore = [ "D107", # Missing docstring in `__init__` "E731", # Do not assign a lambda expression, use a def "EXE002", # The file is executable but no shebang is present + "ISC001", # potehtional comflict, flagged by Ruff itsel ] lint.per-file-ignores."docs/source/conf.py" = [ "A001", "D103", ] lint.per-file-ignores."examples/*" = [ + "ANN", # any annotaions "D205", # 1 blank line required between summary line and description "D212", # [*] Multi-line docstring summary should start at the first line "D415", # First line should end with a period, question mark, or exclamation point @@ -82,7 +84,6 @@ lint.unfixable = [ lint.mccabe.max-complexity = 10 # Use Google-style docstrings. lint.pydocstyle.convention = "google" -lint.ignore-init-module-imports = true [tool.codespell] #skip = '*.py' diff --git a/requirements/_devel.txt b/requirements/_devel.txt index 596cc138133..6a8ea2b8e7f 100644 --- a/requirements/_devel.txt +++ b/requirements/_devel.txt @@ -20,3 +20,4 @@ -r classification_test.txt -r nominal_test.txt -r segmentation_test.txt +-r regression_test.txt diff --git a/requirements/_docs.txt b/requirements/_docs.txt index e729659d8b9..1fd1d103caf 100644 --- a/requirements/_docs.txt +++ b/requirements/_docs.txt @@ -1,6 +1,6 @@ sphinx ==5.3.0 myst-parser ==1.0.0 -pandoc ==2.3 +pandoc ==2.4 docutils ==0.19 sphinxcontrib-fulltoc >=1.0 sphinxcontrib-mockautodoc @@ -9,10 +9,11 @@ sphinx-autodoc-typehints ==1.23.0 sphinx-paramlinks ==0.6.0 sphinx-togglebutton ==0.3.2 sphinx-copybutton ==0.5.2 -sphinx-gallery ==0.16.0 +sphinx-autobuild ==2024.10.3 +sphinx-gallery ==0.18.0 -lightning >=1.8.0, <2.4.0 -lightning-utilities ==0.11.3.post0 +lightning >=1.8.0, <2.5.0 +lightning-utilities ==0.11.8 pydantic > 1.0.0, < 3.0.0 # integrations @@ -29,4 +30,4 @@ pydantic > 1.0.0, < 3.0.0 # todo: until this has resolution - https://github.com/sphinx-gallery/sphinx-gallery/issues/1290 # Image scikit-image ~=0.22; python_version < "3.10" -scikit-image ~=0.24; python_version >= "3.10" +scikit-image ~=0.24; python_version > "3.9" # we do not use `> =` because of oldest replcement diff --git a/requirements/_integrate.txt b/requirements/_integrate.txt index fb4472a6f92..87ee9a33585 100644 --- a/requirements/_integrate.txt +++ b/requirements/_integrate.txt @@ -1,4 +1,4 @@ # contentiously validated integration with these expected ranges # ToDo: investigate and add validation with 2.0+ on GPU -pytorch-lightning >=1.9.0, <2.4.0 +pytorch-lightning >=1.9.0, <2.6.0 diff --git a/requirements/_tests.txt b/requirements/_tests.txt index c4b0cb4cfa1..889b468659f 100644 --- a/requirements/_tests.txt +++ b/requirements/_tests.txt @@ -1,8 +1,10 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment +codecov ==2.1.13 coverage ==7.6.* -pytest ==8.2.2 +codecov ==2.1.13 +pytest ==8.3.* pytest-cov ==5.0.0 pytest-doctestplus ==1.2.1 pytest-rerunfailures ==14.0 @@ -10,11 +12,11 @@ pytest-timeout ==2.3.1 pytest-xdist ==3.6.1 phmdoctest ==1.4.0 -psutil <6.1.0 -pyGithub ==2.3.0 -fire <=0.6.0 +psutil ==6.* +pyGithub >2.0.0, <2.5.0 +fire ==0.7.* -cloudpickle >1.3, <=3.0.0 -scikit-learn >=1.1.1, <1.3.0; python_version < "3.9" -scikit-learn >=1.4.0, <1.6.0; python_version >= "3.9" -cachier ==3.0.0 +cloudpickle >1.3, <=3.1.0 +scikit-learn ==1.2.*; python_version < "3.9" +scikit-learn ==1.5.*; python_version > "3.8" # we do not use `> =` because of oldest replcement +cachier ==3.0.1 diff --git a/requirements/audio.txt b/requirements/audio.txt index ce717cab35a..670e3fed77d 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -3,10 +3,10 @@ # this need to be the same as used inside speechmetrics pesq >=0.0.4, <0.0.5 -pystoi >=0.3.0, <0.5.0 -torchaudio >=0.10.0, <2.5.0 +numpy <2.0 # strict, for compatibility reasons +pystoi >=0.4.0, <0.5.0 +torchaudio >=2.0.1, <2.6.0 gammatone >=1.0.0, <1.1.0 -librosa >=0.9.0, <0.11.0 -onnxruntime-gpu >=1.12.0, <1.19; sys_platform != 'darwin' -onnxruntime >=1.12.0, <1.19; sys_platform == 'darwin' # installing onnxruntime-gpu failed on macos -requests >=2.19.0, <2.32.0 +librosa >=0.10.0, <0.11.0 +onnxruntime >=1.12.0, <1.20 # installing onnxruntime_gpu-gpu failed on macos +requests >=2.19.0, <2.33.0 diff --git a/requirements/base.txt b/requirements/base.txt index 2ce12be4af2..669e36d9f6d 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -3,6 +3,6 @@ numpy >1.20.0 packaging >17.1 -torch >=1.10.0, <2.5.0 +torch >=2.0.0, <2.6.0 typing-extensions; python_version < '3.9' lightning-utilities >=0.8.0, <0.12.0 diff --git a/requirements/classification_test.txt b/requirements/classification_test.txt index 3939de21cf7..45f0b86c925 100644 --- a/requirements/classification_test.txt +++ b/requirements/classification_test.txt @@ -1,8 +1,8 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -pandas >=1.4.0, <=2.2.2 -netcal >1.0.0, <=1.3.5 # calibration_error -numpy <2.1.0 +pandas >1.4.0, <=2.2.3 +netcal >1.0.0, <1.4.0 # calibration_error +numpy <2.2.0 fairlearn # group_fairness PyTDC # locauc diff --git a/requirements/detection.txt b/requirements/detection.txt index cc65884fa16..6d23ed0c60c 100644 --- a/requirements/detection.txt +++ b/requirements/detection.txt @@ -1,5 +1,5 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -torchvision >=0.8, <0.20.0 +torchvision >=0.15.1, <0.21.0 pycocotools >2.0.0, <2.1.0 diff --git a/requirements/detection_test.txt b/requirements/detection_test.txt index 6515620c715..f3c576f3da3 100644 --- a/requirements/detection_test.txt +++ b/requirements/detection_test.txt @@ -1,4 +1,4 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -faster-coco-eval >=1.3.3 +faster-coco-eval >=1.6.3, <1.7.0 diff --git a/requirements/image.txt b/requirements/image.txt index 4f058b0e410..586633d12ca 100644 --- a/requirements/image.txt +++ b/requirements/image.txt @@ -1,6 +1,6 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -scipy >1.0.0, <1.14.0 -torchvision >=0.8, <0.20.0 +scipy >1.0.0, <1.15.0 +torchvision >=0.15.1, <0.21.0 torch-fidelity <=0.4.0 # bumping to allow install version from master, now used in testing diff --git a/requirements/image_test.txt b/requirements/image_test.txt index 692f5953431..956a370fa0e 100644 --- a/requirements/image_test.txt +++ b/requirements/image_test.txt @@ -5,6 +5,6 @@ scikit-image >=0.19.0, <0.25.0 kornia >=0.6.7, <0.8.0 pytorch-msssim ==1.0.0 sewar >=0.4.4, <=0.4.6 -numpy <2.1.0 +numpy <2.2.0 torch-fidelity @ git+https://github.com/toshas/torch-fidelity@master lpips <=0.1.4 diff --git a/requirements/integrate.txt b/requirements/integrate.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/requirements/multimodal.txt b/requirements/multimodal.txt index 37c0d169462..1a034aa3a1f 100644 --- a/requirements/multimodal.txt +++ b/requirements/multimodal.txt @@ -1,5 +1,5 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -transformers >=4.42.3, <4.43.0 +transformers >=4.42.3, <4.46.0 piq <=0.8.0 diff --git a/requirements/nominal_test.txt b/requirements/nominal_test.txt index 0cbbde2c23b..70beddada96 100644 --- a/requirements/nominal_test.txt +++ b/requirements/nominal_test.txt @@ -1,7 +1,8 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -pandas >1.0.0, <=2.2.2 # cannot pin version due to numpy version incompatibility -dython <=0.7.6 -scipy >1.0.0, <1.14.0 # cannot pin version due to some version conflicts with `oldest` CI configuration +pandas >1.4.0, <=2.2.3 # cannot pin version due to numpy version incompatibility +dython ==0.7.6 ; python_version <"3.9" +dython ~=0.7.8 ; python_version > "3.8" # we do not use `> =` +scipy >1.0.0, <1.15.0 # cannot pin version due to some version conflicts with `oldest` CI configuration statsmodels >0.13.5, <0.15.0 diff --git a/requirements/regression_test.txt b/requirements/regression_test.txt new file mode 100644 index 00000000000..859605fda3b --- /dev/null +++ b/requirements/regression_test.txt @@ -0,0 +1 @@ +permetrics==2.0.0 diff --git a/requirements/segmentation_test.txt b/requirements/segmentation_test.txt index e26b5ef76fb..75d7b97ac6c 100644 --- a/requirements/segmentation_test.txt +++ b/requirements/segmentation_test.txt @@ -1,5 +1,6 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -scipy >1.0.0, <1.14.0 -monai ==1.3.2 +scipy >1.0.0, <1.15.0 +monai ==1.3.2 ; python_version < "3.9" +monai ==1.4.0 ; python_version > "3.8" diff --git a/requirements/text.txt b/requirements/text.txt index 1ef4fed786d..62007c3e127 100644 --- a/requirements/text.txt +++ b/requirements/text.txt @@ -1,10 +1,10 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -nltk >=3.6, <=3.8.1 -tqdm >=4.41.0, <4.67.0 -regex >=2021.9.24, <=2024.5.15 -transformers >4.4.0, <4.43.0 +nltk >3.8.1, <=3.9.1 +tqdm <4.67.0 +regex >=2021.9.24, <=2024.9.11 +transformers >4.4.0, <4.46.0 mecab-python3 >=1.0.6, <1.1.0 ipadic >=1.0.0, <1.1.0 sentencepiece >=0.2.0, <0.3.0 diff --git a/requirements/text_test.txt b/requirements/text_test.txt index 10d94c146e5..04b2a2b921d 100644 --- a/requirements/text_test.txt +++ b/requirements/text_test.txt @@ -4,8 +4,8 @@ jiwer >=2.3.0, <3.1.0 rouge-score >0.1.0, <=0.1.2 bert_score ==0.3.13 -huggingface-hub <0.24 +huggingface-hub <0.27 sacrebleu >=2.3.0, <2.5.0 -mecab-ko >=1.0.0, <1.1.0 -mecab-ko-dic >=1.0.0, <1.1.0 +mecab-ko >=1.0.0, <1.1.0 ; python_version < "3.12" # strict # todo: unpin python_version +mecab-ko-dic >=1.0.0, <1.1.0 ; python_version < "3.12" # todo: unpin python_version diff --git a/requirements/typing.txt b/requirements/typing.txt index 18a6e706240..01c6897fa9c 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,5 +1,5 @@ -mypy ==1.10.1 -torch ==2.3.1 +mypy ==1.11.2 +torch ==2.5.0 types-PyYAML types-emoji diff --git a/requirements/visual.txt b/requirements/visual.txt index 269a45fc7bb..1cdc4060a8b 100644 --- a/requirements/visual.txt +++ b/requirements/visual.txt @@ -1,5 +1,5 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -matplotlib >=3.3.0, <3.10.0 +matplotlib >=3.6.0, <3.10.0 SciencePlots >= 2.0.0, <2.2.0 diff --git a/setup.py b/setup.py index 6f2e6f06455..2324b660cc0 100755 --- a/setup.py +++ b/setup.py @@ -245,5 +245,6 @@ def _prepare_extras(skip_pattern: str = "^_", skip_files: Tuple[str] = ("base.tx "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ], ) diff --git a/src/conftest.py b/src/conftest.py new file mode 100644 index 00000000000..5f4a26123d3 --- /dev/null +++ b/src/conftest.py @@ -0,0 +1,40 @@ +from pathlib import Path +from typing import Optional + +from lightning_utilities.core.imports import package_available + +if package_available("pytest") and package_available("doctest"): + import doctest + + import pytest + + MANUAL_SEED = doctest.register_optionflag("MANUAL_SEED") + + @pytest.fixture(autouse=True) + def reset_random_seed(seed: int = 42) -> None: # noqa: PT004 + """Reset the random seed before running each doctest.""" + import random + + import numpy as np + import torch + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + class DoctestModule(pytest.Module): + """A custom module class that augments collected doctests with the reset_random_seed fixture.""" + + def collect(self) -> GeneratorExit: + """Augment collected doctests with the reset_random_seed fixture.""" + for item in super().collect(): + if isinstance(item, pytest.DoctestItem): + item.add_marker(pytest.mark.usefixtures("reset_random_seed")) + yield item + + def pytest_collect_file(parent: Path, path: Path) -> Optional[DoctestModule]: + """Collect doctests and add the reset_random_seed fixture.""" + if path.ext == ".py": + return DoctestModule.from_parent(parent, path=Path(path)) + return None diff --git a/src/torchmetrics/__about__.py b/src/torchmetrics/__about__.py index dfc5a7505e2..2acf435134a 100644 --- a/src/torchmetrics/__about__.py +++ b/src/torchmetrics/__about__.py @@ -1,4 +1,4 @@ -__version__ = "1.5.0dev" +__version__ = "1.6.0dev" __author__ = "Lightning-AI et al." __author_email__ = "name@pytorchlightning.ai" __license__ = "Apache-2.0" diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 38a9141ce7e..8bc1615a36e 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -20,6 +20,13 @@ if not hasattr(PIL, "PILLOW_VERSION"): PIL.PILLOW_VERSION = PIL.__version__ +if package_available("scipy"): + import scipy.signal + + # back compatibility patch due to SMRMpy using scipy.signal.hamming + if not hasattr(scipy.signal, "hamming"): + scipy.signal.hamming = scipy.signal.windows.hamming + from torchmetrics import functional # noqa: E402 from torchmetrics.aggregation import ( # noqa: E402 CatMetric, @@ -108,6 +115,7 @@ MeanSquaredError, MeanSquaredLogError, MinkowskiDistance, + NormalizedRootMeanSquaredError, PearsonCorrCoef, R2Score, RelativeSquaredError, @@ -152,25 +160,23 @@ ) __all__ = [ - "functional", - "Accuracy", "AUROC", + "Accuracy", "AveragePrecision", "BLEUScore", "BootStrapper", + "CHRFScore", "CalibrationError", "CatMetric", - "ClasswiseWrapper", "CharErrorRate", - "CHRFScore", - "ConcordanceCorrCoef", + "ClasswiseWrapper", "CohenKappa", + "ConcordanceCorrCoef", "ConfusionMatrix", "CosineSimilarity", "CramersV", "CriticalSuccessIndex", "Dice", - "TweedieDevianceScore", "ErrorRelativeGlobalDimensionlessSynthesis", "ExactMatch", "ExplainedVariance", @@ -181,8 +187,8 @@ "HammingDistance", "HingeLoss", "JaccardIndex", - "KendallRankCorrCoef", "KLDivergence", + "KendallRankCorrCoef", "LogAUC", "LogCoshError", "MatchErrorRate", @@ -196,14 +202,16 @@ "Metric", "MetricCollection", "MetricTracker", - "MinkowskiDistance", "MinMaxMetric", "MinMetric", + "MinkowskiDistance", "ModifiedPanopticQuality", + "MultiScaleStructuralSimilarityIndexMeasure", "MultioutputWrapper", "MultitaskWrapper", - "MultiScaleStructuralSimilarityIndexMeasure", + "NormalizedRootMeanSquaredError", "PanopticQuality", + "PeakSignalNoiseRatio", "PearsonCorrCoef", "PearsonsContingencyCoefficient", "PermutationInvariantTraining", @@ -211,8 +219,8 @@ "Precision", "PrecisionAtFixedRecall", "PrecisionRecallCurve", - "PeakSignalNoiseRatio", "R2Score", + "ROC", "Recall", "RecallAtFixedPrecision", "RelativeAverageSpectralError", @@ -223,37 +231,38 @@ "RetrievalMRR", "RetrievalNormalizedDCG", "RetrievalPrecision", - "RetrievalRecall", - "RetrievalRPrecision", "RetrievalPrecisionRecallCurve", + "RetrievalRPrecision", + "RetrievalRecall", "RetrievalRecallAtFixedPrecision", - "ROC", "RootMeanSquaredErrorUsingSlidingWindow", "RunningMean", "RunningSum", + "SQuAD", "SacreBLEUScore", - "SignalDistortionRatio", "ScaleInvariantSignalDistortionRatio", "ScaleInvariantSignalNoiseRatio", + "SensitivityAtSpecificity", + "SignalDistortionRatio", "SignalNoiseRatio", "SpearmanCorrCoef", "Specificity", "SpecificityAtSensitivity", - "SensitivityAtSpecificity", "SpectralAngleMapper", "SpectralDistortionIndex", - "SQuAD", - "StructuralSimilarityIndexMeasure", "StatScores", + "StructuralSimilarityIndexMeasure", "SumMetric", "SymmetricMeanAbsolutePercentageError", "TheilsU", "TotalVariation", "TranslationEditRate", "TschuprowsT", + "TweedieDevianceScore", "UniversalImageQualityIndex", "WeightedMeanAbsolutePercentageError", "WordErrorRate", "WordInfoLost", "WordInfoPreserved", + "functional", ] diff --git a/src/torchmetrics/audio/__init__.py b/src/torchmetrics/audio/__init__.py index 6d21902b13e..14b987a7113 100644 --- a/src/torchmetrics/audio/__init__.py +++ b/src/torchmetrics/audio/__init__.py @@ -28,10 +28,17 @@ _ONNXRUNTIME_AVAILABLE, _PESQ_AVAILABLE, _PYSTOI_AVAILABLE, + _SCIPI_AVAILABLE, _TORCHAUDIO_AVAILABLE, - _TORCHAUDIO_GREATER_EQUAL_0_10, ) +if _SCIPI_AVAILABLE: + import scipy.signal + + # back compatibility patch due to SMRMpy using scipy.signal.hamming + if not hasattr(scipy.signal, "hamming"): + scipy.signal.hamming = scipy.signal.windows.hamming + __all__ = [ "PermutationInvariantTraining", "ScaleInvariantSignalDistortionRatio", @@ -52,7 +59,7 @@ __all__ += ["ShortTimeObjectiveIntelligibility"] -if _GAMMATONE_AVAILABLE and _TORCHAUDIO_AVAILABLE and _TORCHAUDIO_GREATER_EQUAL_0_10: +if _GAMMATONE_AVAILABLE and _TORCHAUDIO_AVAILABLE: from torchmetrics.audio.srmr import SpeechReverberationModulationEnergyRatio __all__ += ["SpeechReverberationModulationEnergyRatio"] diff --git a/src/torchmetrics/audio/_deprecated.py b/src/torchmetrics/audio/_deprecated.py index 4721ce1265c..c84604c9a2c 100644 --- a/src/torchmetrics/audio/_deprecated.py +++ b/src/torchmetrics/audio/_deprecated.py @@ -13,7 +13,6 @@ class _PermutationInvariantTraining(PermutationInvariantTraining): >>> import torch >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio - >>> _ = torch.manual_seed(42) >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] >>> target = torch.randn(3, 2, 5) # [batch, spk, time] >>> pit = _PermutationInvariantTraining(scale_invariant_signal_noise_ratio, @@ -79,12 +78,11 @@ class _SignalDistortionRatio(SignalDistortionRatio): """Wrapper for deprecated import. >>> import torch - >>> g = torch.manual_seed(1) >>> preds = torch.randn(8000) >>> target = torch.randn(8000) >>> sdr = _SignalDistortionRatio() >>> sdr(preds, target) - tensor(-12.0589) + tensor(-11.9930) >>> # use with pit >>> from torchmetrics.functional import signal_distortion_ratio >>> preds = torch.randn(4, 2, 8000) # [batch, spk, time] @@ -92,7 +90,7 @@ class _SignalDistortionRatio(SignalDistortionRatio): >>> pit = _PermutationInvariantTraining(signal_distortion_ratio, ... mode="speaker-wise", eval_func="max") >>> pit(preds, target) - tensor(-11.6051) + tensor(-11.7277) """ diff --git a/src/torchmetrics/audio/dnsmos.py b/src/torchmetrics/audio/dnsmos.py index a6d45aa11cb..74d035a7fd4 100644 --- a/src/torchmetrics/audio/dnsmos.py +++ b/src/torchmetrics/audio/dnsmos.py @@ -76,11 +76,10 @@ class DeepNoiseSuppressionMeanOpinionScore(Metric): Example: >>> from torch import randn >>> from torchmetrics.audio import DeepNoiseSuppressionMeanOpinionScore - >>> g = torch.manual_seed(1) >>> preds = randn(8000) >>> dnsmos = DeepNoiseSuppressionMeanOpinionScore(8000, False) >>> dnsmos(preds) - tensor([2.2285, 2.1132, 1.3972, 1.3652], dtype=torch.float64) + tensor([2.2..., 2.0..., 1.1..., 1.2...], dtype=torch.float64) """ diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index bcaf6d08d9c..dfcf623aa25 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -71,17 +71,16 @@ class PerceptualEvaluationSpeechQuality(Metric): If ``mode`` is not either ``"wb"`` or ``"nb"`` Example: - >>> import torch + >>> from torch import randn >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality - >>> g = torch.manual_seed(1) - >>> preds = torch.randn(8000) - >>> target = torch.randn(8000) + >>> preds = randn(8000) + >>> target = randn(8000) >>> pesq = PerceptualEvaluationSpeechQuality(8000, 'nb') >>> pesq(preds, target) - tensor(2.2076) + tensor(2.2885) >>> wb_pesq = PerceptualEvaluationSpeechQuality(16000, 'wb') >>> wb_pesq(preds, target) - tensor(1.7359) + tensor(1.6805) """ diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index 150b8b6d983..2def91d4d01 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -61,12 +61,11 @@ class PermutationInvariantTraining(Metric): see :ref:`Metric kwargs` for more info. Example: - >>> import torch + >>> from torch import randn >>> from torchmetrics.audio import PermutationInvariantTraining >>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio - >>> _ = torch.manual_seed(42) - >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] - >>> target = torch.randn(3, 2, 5) # [batch, spk, time] + >>> preds = randn(3, 2, 5) # [batch, spk, time] + >>> target = randn(3, 2, 5) # [batch, spk, time] >>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, ... mode="speaker-wise", eval_func="max") >>> pit(preds, target) diff --git a/src/torchmetrics/audio/sdr.py b/src/torchmetrics/audio/sdr.py index a4af4c13c60..9b8646aaa1f 100644 --- a/src/torchmetrics/audio/sdr.py +++ b/src/torchmetrics/audio/sdr.py @@ -70,23 +70,22 @@ class SignalDistortionRatio(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: - >>> import torch + >>> from torch import randn >>> from torchmetrics.audio import SignalDistortionRatio - >>> g = torch.manual_seed(1) - >>> preds = torch.randn(8000) - >>> target = torch.randn(8000) + >>> preds = randn(8000) + >>> target = randn(8000) >>> sdr = SignalDistortionRatio() >>> sdr(preds, target) - tensor(-12.0589) + tensor(-11.9930) >>> # use with pit >>> from torchmetrics.audio import PermutationInvariantTraining >>> from torchmetrics.functional.audio import signal_distortion_ratio - >>> preds = torch.randn(4, 2, 8000) # [batch, spk, time] - >>> target = torch.randn(4, 2, 8000) + >>> preds = randn(4, 2, 8000) # [batch, spk, time] + >>> target = randn(4, 2, 8000) >>> pit = PermutationInvariantTraining(signal_distortion_ratio, ... mode="speaker-wise", eval_func="max") >>> pit(preds, target) - tensor(-11.6051) + tensor(-11.7277) """ @@ -302,23 +301,22 @@ class SourceAggregatedSignalDistortionRatio(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: - >>> import torch + >>> from torch import randn >>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio - >>> g = torch.manual_seed(1) - >>> preds = torch.randn(2, 8000) # [..., spk, time] - >>> target = torch.randn(2, 8000) + >>> preds = randn(2, 8000) # [..., spk, time] + >>> target = randn(2, 8000) >>> sasdr = SourceAggregatedSignalDistortionRatio() >>> sasdr(preds, target) - tensor(-41.6579) + tensor(-50.8171) >>> # use with pit >>> from torchmetrics.audio import PermutationInvariantTraining >>> from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio - >>> preds = torch.randn(4, 2, 8000) # [batch, spk, time] - >>> target = torch.randn(4, 2, 8000) + >>> preds = randn(4, 2, 8000) # [batch, spk, time] + >>> target = randn(4, 2, 8000) >>> pit = PermutationInvariantTraining(source_aggregated_signal_distortion_ratio, ... mode="permutation-wise", eval_func="max") >>> pit(preds, target) - tensor(-41.2790) + tensor(-43.9780) """ diff --git a/src/torchmetrics/audio/snr.py b/src/torchmetrics/audio/snr.py index bbbe059c8e4..d8b9fd4c173 100644 --- a/src/torchmetrics/audio/snr.py +++ b/src/torchmetrics/audio/snr.py @@ -268,15 +268,13 @@ class ComplexScaleInvariantSignalNoiseRatio(Metric): If ``preds`` and ``target`` does not have the same shape. Example: - >>> import torch - >>> from torch import tensor + >>> from torch import randn >>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio - >>> g = torch.manual_seed(1) - >>> preds = torch.randn((1,257,100,2)) - >>> target = torch.randn((1,257,100,2)) + >>> preds = randn((1,257,100,2)) + >>> target = randn((1,257,100,2)) >>> c_si_snr = ComplexScaleInvariantSignalNoiseRatio() >>> c_si_snr(preds, target) - tensor(-63.4849) + tensor(-38.8832) """ diff --git a/src/torchmetrics/audio/srmr.py b/src/torchmetrics/audio/srmr.py index 620e8743fd2..0ced6d6f24f 100644 --- a/src/torchmetrics/audio/srmr.py +++ b/src/torchmetrics/audio/srmr.py @@ -24,11 +24,10 @@ _GAMMATONE_AVAILABLE, _MATPLOTLIB_AVAILABLE, _TORCHAUDIO_AVAILABLE, - _TORCHAUDIO_GREATER_EQUAL_0_10, ) from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE -if not all([_GAMMATONE_AVAILABLE, _TORCHAUDIO_AVAILABLE, _TORCHAUDIO_GREATER_EQUAL_0_10]): +if not all([_GAMMATONE_AVAILABLE, _TORCHAUDIO_AVAILABLE]): __doctest_skip__ = ["SpeechReverberationModulationEnergyRatio", "SpeechReverberationModulationEnergyRatio.plot"] elif not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["SpeechReverberationModulationEnergyRatio.plot"] @@ -76,13 +75,12 @@ class SpeechReverberationModulationEnergyRatio(Metric): If ``gammatone`` or ``torchaudio`` package is not installed Example: - >>> import torch + >>> from torch import randn >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio - >>> g = torch.manual_seed(1) - >>> preds = torch.randn(8000) + >>> preds = randn(8000) >>> srmr = SpeechReverberationModulationEnergyRatio(8000) >>> srmr(preds) - tensor(0.3354) + tensor(0.3191) """ @@ -106,7 +104,7 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) - if not _TORCHAUDIO_AVAILABLE or not _TORCHAUDIO_GREATER_EQUAL_0_10 or not _GAMMATONE_AVAILABLE: + if not _TORCHAUDIO_AVAILABLE or not _GAMMATONE_AVAILABLE: raise ModuleNotFoundError( "speech_reverberation_modulation_energy_ratio requires you to have `gammatone` and" " `torchaudio>=0.10` installed. Either install as ``pip install torchmetrics[audio]`` or " diff --git a/src/torchmetrics/audio/stoi.py b/src/torchmetrics/audio/stoi.py index c1f14ac2c7e..253dab3ea38 100644 --- a/src/torchmetrics/audio/stoi.py +++ b/src/torchmetrics/audio/stoi.py @@ -63,14 +63,13 @@ class ShortTimeObjectiveIntelligibility(Metric): If ``pystoi`` package is not installed Example: - >>> import torch + >>> from torch import randn >>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility - >>> g = torch.manual_seed(1) - >>> preds = torch.randn(8000) - >>> target = torch.randn(8000) + >>> preds = randn(8000) + >>> target = randn(8000) >>> stoi = ShortTimeObjectiveIntelligibility(8000, False) >>> stoi(preds, target) - tensor(-0.0100) + tensor(-0.084...) """ @@ -132,11 +131,10 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ :scale: 75 >>> # Example plotting a single value - >>> import torch + >>> from torch import randn >>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility - >>> g = torch.manual_seed(1) - >>> preds = torch.randn(8000) - >>> target = torch.randn(8000) + >>> preds = randn(8000) + >>> target = randn(8000) >>> metric = ShortTimeObjectiveIntelligibility(8000, False) >>> metric.update(preds, target) >>> fig_, ax_ = metric.plot() @@ -145,12 +143,11 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ :scale: 75 >>> # Example plotting multiple values - >>> import torch + >>> from torch import randn >>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility >>> metric = ShortTimeObjectiveIntelligibility(8000, False) - >>> g = torch.manual_seed(1) - >>> preds = torch.randn(8000) - >>> target = torch.randn(8000) + >>> preds = randn(8000) + >>> target = randn(8000) >>> values = [ ] >>> for _ in range(10): ... values.append(metric(preds, target)) diff --git a/src/torchmetrics/classification/group_fairness.py b/src/torchmetrics/classification/group_fairness.py index 4d6735901bb..8e38b24faeb 100644 --- a/src/torchmetrics/classification/group_fairness.py +++ b/src/torchmetrics/classification/group_fairness.py @@ -302,25 +302,23 @@ def plot( .. plot:: :scale: 75 - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import ones, rand, randint >>> # Example plotting a single value >>> from torchmetrics.classification import BinaryFairness >>> metric = BinaryFairness(2) - >>> metric.update(torch.rand(20), torch.randint(2,(20,)), torch.randint(2,(20,))) + >>> metric.update(rand(20), randint(2, (20,)), ones(20).long()) >>> fig_, ax_ = metric.plot() .. plot:: :scale: 75 - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import ones, rand, randint >>> # Example plotting multiple values >>> from torchmetrics.classification import BinaryFairness >>> metric = BinaryFairness(2) >>> values = [ ] >>> for _ in range(10): - ... values.append(metric(torch.rand(20), torch.randint(2,(20,)), torch.ones(20).long())) + ... values.append(metric(rand(20), randint(2, (20,) ), ones(20).long())) >>> fig_, ax_ = metric.plot(values) """ diff --git a/src/torchmetrics/classification/ranking.py b/src/torchmetrics/classification/ranking.py index dd58022197b..4d2df8a9151 100644 --- a/src/torchmetrics/classification/ranking.py +++ b/src/torchmetrics/classification/ranking.py @@ -66,10 +66,10 @@ class MultilabelCoverageError(Metric): Set to ``False`` for faster computations. Example: + >>> from torch import rand, randint >>> from torchmetrics.classification import MultilabelCoverageError - >>> _ = torch.manual_seed(42) - >>> preds = torch.rand(10, 5) - >>> target = torch.randint(2, (10, 5)) + >>> preds = rand(10, 5) + >>> target = randint(2, (10, 5)) >>> mlce = MultilabelCoverageError(num_labels=5) >>> mlce(preds, target) tensor(3.9000) @@ -186,10 +186,10 @@ class MultilabelRankingAveragePrecision(Metric): Set to ``False`` for faster computations. Example: + >>> from torch import rand, randint >>> from torchmetrics.classification import MultilabelRankingAveragePrecision - >>> _ = torch.manual_seed(42) - >>> preds = torch.rand(10, 5) - >>> target = torch.randint(2, (10, 5)) + >>> preds = rand(10, 5) + >>> target = randint(2, (10, 5)) >>> mlrap = MultilabelRankingAveragePrecision(num_labels=5) >>> mlrap(preds, target) tensor(0.7744) @@ -308,10 +308,10 @@ class MultilabelRankingLoss(Metric): Set to ``False`` for faster computations. Example: + >>> from torch import rand, randint >>> from torchmetrics.classification import MultilabelRankingLoss - >>> _ = torch.manual_seed(42) - >>> preds = torch.rand(10, 5) - >>> target = torch.randint(2, (10, 5)) + >>> preds = rand(10, 5) + >>> target = randint(2, (10, 5)) >>> mlrl = MultilabelRankingLoss(num_labels=5) >>> mlrl(preds, target) tensor(0.4167) diff --git a/src/torchmetrics/clustering/calinski_harabasz_score.py b/src/torchmetrics/clustering/calinski_harabasz_score.py index 5d15a33b0e5..483e4332148 100644 --- a/src/torchmetrics/clustering/calinski_harabasz_score.py +++ b/src/torchmetrics/clustering/calinski_harabasz_score.py @@ -54,14 +54,13 @@ class CalinskiHarabaszScore(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example:: - >>> import torch + >>> from torch import randn, randint >>> from torchmetrics.clustering import CalinskiHarabaszScore - >>> _ = torch.manual_seed(42) - >>> data = torch.randn(10, 3) - >>> labels = torch.randint(3, (10,)) + >>> data = randn(20, 3) + >>> labels = randint(3, (20,)) >>> metric = CalinskiHarabaszScore() >>> metric(data, labels) - tensor(3.0053) + tensor(2.2128) """ @@ -109,7 +108,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> import torch >>> from torchmetrics.clustering import CalinskiHarabaszScore >>> metric = CalinskiHarabaszScore() - >>> metric.update(torch.randn(10, 3), torch.randint(0, 2, (10,))) + >>> metric.update(torch.randn(20, 3), torch.randint(3, (20,))) >>> fig_, ax_ = metric.plot(metric.compute()) .. plot:: @@ -121,7 +120,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> metric = CalinskiHarabaszScore() >>> values = [ ] >>> for _ in range(10): - ... values.append(metric(torch.randn(10, 3), torch.randint(0, 2, (10,)))) + ... values.append(metric(torch.randn(20, 3), torch.randint(3, (20,)))) >>> fig_, ax_ = metric.plot(values) """ diff --git a/src/torchmetrics/clustering/davies_bouldin_score.py b/src/torchmetrics/clustering/davies_bouldin_score.py index c856be3ae59..40827b568cb 100644 --- a/src/torchmetrics/clustering/davies_bouldin_score.py +++ b/src/torchmetrics/clustering/davies_bouldin_score.py @@ -64,11 +64,10 @@ class DaviesBouldinScore(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example:: - >>> import torch + >>> from torch import randn, randint >>> from torchmetrics.clustering import DaviesBouldinScore - >>> _ = torch.manual_seed(42) - >>> data = torch.randn(10, 3) - >>> labels = torch.randint(3, (10,)) + >>> data = randn(10, 3) + >>> labels = randint(3, (10,)) >>> metric = DaviesBouldinScore() >>> metric(data, labels) tensor(1.2540) diff --git a/src/torchmetrics/clustering/dunn_index.py b/src/torchmetrics/clustering/dunn_index.py index 89565261f3e..9373db1045e 100644 --- a/src/torchmetrics/clustering/dunn_index.py +++ b/src/torchmetrics/clustering/dunn_index.py @@ -121,7 +121,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> metric = DunnIndex(p=2) >>> values = [ ] >>> for _ in range(10): - ... values.append(metric(torch.randn(10, 3), torch.randint(0, 2, (10,)))) + ... values.append(metric(torch.randn(50, 3), torch.randint(0, 2, (50,)))) >>> fig_, ax_ = metric.plot(values) """ diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index fbcf6a6ac51..0b7f927deb9 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -14,7 +14,7 @@ # this is just a bypass for this module name collision with built-in one from collections import OrderedDict from copy import deepcopy -from typing import Any, Dict, Hashable, Iterable, Iterator, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -31,6 +31,30 @@ __doctest_skip__ = ["MetricCollection.plot", "MetricCollection.plot_all"] +def _remove_prefix(string: str, prefix: str) -> str: + """Patch for older version with missing method `removeprefix`. + + >>> _remove_prefix("prefix_string", "prefix_") + 'string' + >>> _remove_prefix("not_prefix_string", "prefix_") + 'not_prefix_string' + + """ + return string[len(prefix) :] if string.startswith(prefix) else string + + +def _remove_suffix(string: str, suffix: str) -> str: + """Patch for older version with missing method `removesuffix`. + + >>> _remove_suffix("string_suffix", "_suffix") + 'string' + >>> _remove_suffix("string_suffix_missing", "_suffix") + 'string_suffix_missing' + + """ + return string[: -len(suffix)] if string.endswith(suffix) else string + + class MetricCollection(ModuleDict): """MetricCollection class can be used to chain metrics that have the same call pattern into one single class. @@ -343,7 +367,7 @@ def _compute_and_reduce( elif method_name == "forward": res = m(*args, **m._filter_kwargs(**kwargs)) else: - raise ValueError("method_name should be either 'compute' or 'forward', but got {method_name}") + raise ValueError(f"method_name should be either 'compute' or 'forward', but got {method_name}") result[k] = res _, duplicates = _flatten_dict(result) @@ -499,11 +523,12 @@ def _set_name(self, base: str) -> str: name = base if self.prefix is None else self.prefix + base return name if self.postfix is None else name + self.postfix - def _to_renamed_ordered_dict(self) -> OrderedDict: - od = OrderedDict() + def _to_renamed_dict(self) -> Mapping[str, Metric]: + # self._modules changed from OrderedDict to dict as of PyTorch 2.5.0 + dict_modules = OrderedDict() if isinstance(self._modules, OrderedDict) else {} for k, v in self._modules.items(): - od[self._set_name(k)] = v - return od + dict_modules[self._set_name(k)] = v + return dict_modules def __iter__(self) -> Iterator[Hashable]: """Return an iterator over the keys of the MetricDict.""" @@ -519,7 +544,7 @@ def keys(self, keep_base: bool = False) -> Iterable[Hashable]: """ if keep_base: return self._modules.keys() - return self._to_renamed_ordered_dict().keys() + return self._to_renamed_dict().keys() def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tuple[str, Metric]]: r"""Return an iterable of the ModuleDict key/value pairs. @@ -533,7 +558,7 @@ def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tu self._compute_groups_create_state_ref(copy_state) if keep_base: return self._modules.items() - return self._to_renamed_ordered_dict().items() + return self._to_renamed_dict().items() def values(self, copy_state: bool = True) -> Iterable[Metric]: """Return an iterable of the ModuleDict values. @@ -557,9 +582,9 @@ def __getitem__(self, key: str, copy_state: bool = True) -> Metric: """ self._compute_groups_create_state_ref(copy_state) if self.prefix: - key = key.removeprefix(self.prefix) + key = _remove_prefix(key, self.prefix) if self.postfix: - key = key.removesuffix(self.postfix) + key = _remove_suffix(key, self.postfix) return self._modules[key] @staticmethod diff --git a/src/torchmetrics/detection/__init__.py b/src/torchmetrics/detection/__init__.py index 5fd60cf4e4d..7932d4b33b8 100644 --- a/src/torchmetrics/detection/__init__.py +++ b/src/torchmetrics/detection/__init__.py @@ -12,22 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.detection.panoptic_qualities import ModifiedPanopticQuality, PanopticQuality -from torchmetrics.utilities.imports import ( - _TORCHVISION_GREATER_EQUAL_0_8, - _TORCHVISION_GREATER_EQUAL_0_13, -) +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE __all__ = ["ModifiedPanopticQuality", "PanopticQuality"] -if _TORCHVISION_GREATER_EQUAL_0_8: +if _TORCHVISION_AVAILABLE: + from torchmetrics.detection.ciou import CompleteIntersectionOverUnion + from torchmetrics.detection.diou import DistanceIntersectionOverUnion from torchmetrics.detection.giou import GeneralizedIntersectionOverUnion from torchmetrics.detection.iou import IntersectionOverUnion from torchmetrics.detection.mean_ap import MeanAveragePrecision - __all__ += ["MeanAveragePrecision", "GeneralizedIntersectionOverUnion", "IntersectionOverUnion"] - -if _TORCHVISION_GREATER_EQUAL_0_13: - from torchmetrics.detection.ciou import CompleteIntersectionOverUnion - from torchmetrics.detection.diou import DistanceIntersectionOverUnion - - __all__ += ["CompleteIntersectionOverUnion", "DistanceIntersectionOverUnion"] + __all__ += [ + "MeanAveragePrecision", + "GeneralizedIntersectionOverUnion", + "IntersectionOverUnion", + "CompleteIntersectionOverUnion", + "DistanceIntersectionOverUnion", + ] diff --git a/src/torchmetrics/detection/_deprecated.py b/src/torchmetrics/detection/_deprecated.py index 898f341bd62..c162c751554 100644 --- a/src/torchmetrics/detection/_deprecated.py +++ b/src/torchmetrics/detection/_deprecated.py @@ -1,17 +1,8 @@ from typing import Any, Collection from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from torchmetrics.utilities.prints import _deprecated_root_import_class -if not _TORCH_GREATER_EQUAL_1_12: - __doctest_skip__ = [ - "_PanopticQuality", - "_PanopticQuality.*", - "_ModifiedPanopticQuality", - "_ModifiedPanopticQuality.*", - ] - class _ModifiedPanopticQuality(ModifiedPanopticQuality): """Wrapper for deprecated import. diff --git a/src/torchmetrics/detection/_mean_ap.py b/src/torchmetrics/detection/_mean_ap.py index fd342608360..4de1f8fe762 100644 --- a/src/torchmetrics/detection/_mean_ap.py +++ b/src/torchmetrics/detection/_mean_ap.py @@ -22,13 +22,13 @@ from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator from torchmetrics.metric import Metric from torchmetrics.utilities.data import _cumsum -from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["MeanAveragePrecision.plot"] -if not _TORCHVISION_GREATER_EQUAL_0_8 or not _PYCOCOTOOLS_AVAILABLE: +if not _TORCHVISION_AVAILABLE or not _PYCOCOTOOLS_AVAILABLE: __doctest_skip__ = ["MeanAveragePrecision.plot", "MeanAveragePrecision"] log = logging.getLogger(__name__) @@ -327,10 +327,10 @@ def __init__( "`MAP` metric requires that `pycocotools` installed." " Please install with `pip install pycocotools` or `pip install torchmetrics[detection]`" ) - if not _TORCHVISION_GREATER_EQUAL_0_8: + if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( - "`MeanAveragePrecision` metric requires that `torchvision` version 0.8.0 or newer is installed." - " Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`." + "`MeanAveragePrecision` metric requires that `torchvision` is installed." + " Please install with `pip install torchmetrics[detection]`." ) allowed_box_formats = ("xyxy", "xywh", "cxcywh") diff --git a/src/torchmetrics/detection/ciou.py b/src/torchmetrics/detection/ciou.py index 0e5b304c44e..b6174c3b60c 100644 --- a/src/torchmetrics/detection/ciou.py +++ b/src/torchmetrics/detection/ciou.py @@ -17,10 +17,10 @@ from torchmetrics.detection.iou import IntersectionOverUnion from torchmetrics.functional.detection.ciou import _ciou_compute, _ciou_update -from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13 +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE -if not _TORCHVISION_GREATER_EQUAL_0_13: +if not _TORCHVISION_AVAILABLE: __doctest_skip__ = ["CompleteIntersectionOverUnion", "CompleteIntersectionOverUnion.plot"] elif not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["CompleteIntersectionOverUnion.plot"] @@ -110,10 +110,10 @@ def __init__( respect_labels: bool = True, **kwargs: Any, ) -> None: - if not _TORCHVISION_GREATER_EQUAL_0_13: + if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( - f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed." - " Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`." + f"Metric `{self._iou_type.upper()}` requires that `torchvision` is installed." + " Please install with `pip install torchmetrics[detection]`." ) super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs) diff --git a/src/torchmetrics/detection/diou.py b/src/torchmetrics/detection/diou.py index 87fa0933d2b..7eb3780a112 100644 --- a/src/torchmetrics/detection/diou.py +++ b/src/torchmetrics/detection/diou.py @@ -17,10 +17,10 @@ from torchmetrics.detection.iou import IntersectionOverUnion from torchmetrics.functional.detection.diou import _diou_compute, _diou_update -from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13 +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE -if not _TORCHVISION_GREATER_EQUAL_0_13: +if not _TORCHVISION_AVAILABLE: __doctest_skip__ = ["DistanceIntersectionOverUnion", "DistanceIntersectionOverUnion.plot"] elif not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["DistanceIntersectionOverUnion.plot"] @@ -110,10 +110,10 @@ def __init__( respect_labels: bool = True, **kwargs: Any, ) -> None: - if not _TORCHVISION_GREATER_EQUAL_0_13: + if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( - f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed." - " Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`." + f"Metric `{self._iou_type.upper()}` requires that `torchvision` is installed." + " Please install with `pip install torchmetrics[detection]`." ) super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs) diff --git a/src/torchmetrics/detection/giou.py b/src/torchmetrics/detection/giou.py index 0d4156561c4..d024adad817 100644 --- a/src/torchmetrics/detection/giou.py +++ b/src/torchmetrics/detection/giou.py @@ -17,10 +17,10 @@ from torchmetrics.detection.iou import IntersectionOverUnion from torchmetrics.functional.detection.giou import _giou_compute, _giou_update -from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE -if not _TORCHVISION_GREATER_EQUAL_0_8: +if not _TORCHVISION_AVAILABLE: __doctest_skip__ = ["GeneralizedIntersectionOverUnion", "GeneralizedIntersectionOverUnion.plot"] elif not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["GeneralizedIntersectionOverUnion.plot"] diff --git a/src/torchmetrics/detection/iou.py b/src/torchmetrics/detection/iou.py index 7b4a60200ca..ca4178d35ee 100644 --- a/src/torchmetrics/detection/iou.py +++ b/src/torchmetrics/detection/iou.py @@ -20,10 +20,10 @@ from torchmetrics.functional.detection.iou import _iou_compute, _iou_update from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat -from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE -if not _TORCHVISION_GREATER_EQUAL_0_8: +if not _TORCHVISION_AVAILABLE: __doctest_skip__ = ["IntersectionOverUnion", "IntersectionOverUnion.plot"] elif not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["IntersectionOverUnion.plot"] @@ -146,10 +146,10 @@ def __init__( ) -> None: super().__init__(**kwargs) - if not _TORCHVISION_GREATER_EQUAL_0_8: + if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( - f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.8.0 or newer is installed." - " Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`." + f"Metric `{self._iou_type.upper()}` requires that `torchvision` is installed." + " Please install with `pip install torchmetrics[detection]`." ) allowed_box_formats = ("xyxy", "xywh", "cxcywh") @@ -211,7 +211,8 @@ def compute(self) -> dict: """Computes IoU based on inputs passed in to ``update`` previously.""" score = torch.cat([mat[mat != self._invalid_val] for mat in self.iou_matrix], 0).mean() results: Dict[str, Tensor] = {f"{self._iou_type}": score} - + if torch.isnan(score): # if no valid boxes are found + results[f"{self._iou_type}"] = torch.tensor(0.0, device=score.device) if self.class_metrics: gt_labels = dim_zero_cat(self.groundtruth_labels) classes = gt_labels.unique().tolist() if len(gt_labels) > 0 else [] diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 20c99f094f9..f7aa6eb0276 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -31,14 +31,14 @@ _FASTER_COCO_EVAL_AVAILABLE, _MATPLOTLIB_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, - _TORCHVISION_GREATER_EQUAL_0_8, + _TORCHVISION_AVAILABLE, ) from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["MeanAveragePrecision.plot"] -if not _TORCHVISION_GREATER_EQUAL_0_8 or not (_PYCOCOTOOLS_AVAILABLE or _FASTER_COCO_EVAL_AVAILABLE): +if not (_PYCOCOTOOLS_AVAILABLE or _FASTER_COCO_EVAL_AVAILABLE): __doctest_skip__ = [ "MeanAveragePrecision.plot", "MeanAveragePrecision", @@ -124,19 +124,23 @@ class MeanAveragePrecision(Metric): - ``map_dict``: A dictionary containing the following key-values: - - map: (:class:`~torch.Tensor`), global mean average precision - - map_small: (:class:`~torch.Tensor`), mean average precision for small objects - - map_medium:(:class:`~torch.Tensor`), mean average precision for medium objects - - map_large: (:class:`~torch.Tensor`), mean average precision for large objects + - map: (:class:`~torch.Tensor`), global mean average precision which by default is defined as mAP50-95 e.g. the + mean average precision for IoU thresholds 0.50, 0.55, 0.60, ..., 0.95 averaged over all classes and areas. If + the IoU thresholds are changed this value will be calculated with the new thresholds. + - map_small: (:class:`~torch.Tensor`), mean average precision for small objects (area < 32^2 pixels) + - map_medium:(:class:`~torch.Tensor`), mean average precision for medium objects (32^2 pixels < area < 96^2 + pixels) + - map_large: (:class:`~torch.Tensor`), mean average precision for large objects (area > 96^2 pixels) - mar_{mdt[0]}: (:class:`~torch.Tensor`), mean average recall for `max_detection_thresholds[0]` (default 1) detection per image - mar_{mdt[1]}: (:class:`~torch.Tensor`), mean average recall for `max_detection_thresholds[1]` (default 10) detection per image - mar_{mdt[1]}: (:class:`~torch.Tensor`), mean average recall for `max_detection_thresholds[2]` (default 100) detection per image - - mar_small: (:class:`~torch.Tensor`), mean average recall for small objects - - mar_medium: (:class:`~torch.Tensor`), mean average recall for medium objects - - mar_large: (:class:`~torch.Tensor`), mean average recall for large objects + - mar_small: (:class:`~torch.Tensor`), mean average recall for small objects (area < 32^2 pixels) + - mar_medium: (:class:`~torch.Tensor`), mean average recall for medium objects (32^2 pixels < area < 96^2 + pixels) + - mar_large: (:class:`~torch.Tensor`), mean average recall for large objects (area > 96^2 pixels) - map_50: (:class:`~torch.Tensor`) (-1 if 0.5 not in the list of iou thresholds), mean average precision at IoU=0.50 - map_75: (:class:`~torch.Tensor`) (-1 if 0.75 not in the list of iou thresholds), mean average precision at @@ -150,8 +154,11 @@ class MeanAveragePrecision(Metric): For an example on how to use this metric check the `torchmetrics mAP example`_. .. note:: - ``map`` score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ]. - Caution: If the initialization parameters are changed, dictionary keys for mAR can change as well. + ``map`` score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ] + e.g. the mean average precision for IoU thresholds 0.50, 0.55, 0.60, ..., 0.95 averaged over all classes and + all areas and all max detections per image. If the IoU thresholds are changed this value will be calculated with + the new thresholds. Caution: If the initialization parameters are changed, dictionary keys for mAR can change as + well. .. note:: This metric supports, at the moment, two different backends for the evaluation. The default backend is @@ -383,10 +390,10 @@ def __init__( " Please install with `pip install pycocotools` or `pip install faster-coco-eval` or" " `pip install torchmetrics[detection]`." ) - if not _TORCHVISION_GREATER_EQUAL_0_8: + if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( - "`MeanAveragePrecision` metric requires that `torchvision` version 0.8.0 or newer is installed." - " Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`." + f"Metric `{self._iou_type.upper()}` requires that `torchvision` is installed." + " Please install with `pip install torchmetrics[detection]`." ) allowed_box_formats = ("xyxy", "xywh", "cxcywh") @@ -416,7 +423,7 @@ def __init__( if max_detection_thresholds is not None and len(max_detection_thresholds) != 3: raise ValueError( "When providing a list of max detection thresholds it should have length 3." - " Got value {len(max_detection_thresholds)}" + f" Got value {len(max_detection_thresholds)}" ) max_det_threshold, _ = torch.sort(torch.tensor(max_detection_thresholds or [1, 10, 100], dtype=torch.int)) self.max_detection_thresholds = max_det_threshold.tolist() @@ -524,65 +531,67 @@ def compute(self) -> dict: for anno in coco_preds.dataset["annotations"]: anno["area"] = anno[f"area_{i_type}"] - coco_eval = self.cocoeval(coco_target, coco_preds, iouType=i_type) # type: ignore[operator] - coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64) - coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64) - coco_eval.params.maxDets = self.max_detection_thresholds - - coco_eval.evaluate() - coco_eval.accumulate() - coco_eval.summarize() - stats = coco_eval.stats - result_dict.update(self._coco_stats_to_tensor_dict(stats, prefix=prefix)) - - summary = {} - if self.extended_summary: - summary = { - f"{prefix}ious": apply_to_collection( - coco_eval.ious, np.ndarray, lambda x: torch.tensor(x, dtype=torch.float32) - ), - f"{prefix}precision": torch.tensor(coco_eval.eval["precision"]), - f"{prefix}recall": torch.tensor(coco_eval.eval["recall"]), - f"{prefix}scores": torch.tensor(coco_eval.eval["scores"]), - } - result_dict.update(summary) - - # if class mode is enabled, evaluate metrics per class - if self.class_metrics: - if self.average == "micro": - # since micro averaging have all the data in one class, we need to reinitialize the coco_eval - # object in macro mode to get the per class stats + if len(coco_preds.imgs) == 0 or len(coco_target.imgs) == 0: + result_dict.update(self._coco_stats_to_tensor_dict(12 * [-1.0], prefix=prefix)) + else: + coco_eval = self.cocoeval(coco_target, coco_preds, iouType=i_type) # type: ignore[operator] + coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64) + coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64) + coco_eval.params.maxDets = self.max_detection_thresholds + + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + stats = coco_eval.stats + result_dict.update(self._coco_stats_to_tensor_dict(stats, prefix=prefix)) + + summary = {} + if self.extended_summary: + summary = { + f"{prefix}ious": apply_to_collection( + coco_eval.ious, np.ndarray, lambda x: torch.tensor(x, dtype=torch.float32) + ), + f"{prefix}precision": torch.tensor(coco_eval.eval["precision"]), + f"{prefix}recall": torch.tensor(coco_eval.eval["recall"]), + f"{prefix}scores": torch.tensor(coco_eval.eval["scores"]), + } + result_dict.update(summary) + + # if class mode is enabled, evaluate metrics per class + if self.class_metrics: + # regardless of average method, reinitialize dataset to get rid of internal state which can + # lead to wrong results when evaluating per class coco_preds, coco_target = self._get_coco_datasets(average="macro") coco_eval = self.cocoeval(coco_target, coco_preds, iouType=i_type) # type: ignore[operator] coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64) coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64) coco_eval.params.maxDets = self.max_detection_thresholds - map_per_class_list = [] - mar_per_class_list = [] - for class_id in self._get_classes(): - coco_eval.params.catIds = [class_id] - with contextlib.redirect_stdout(io.StringIO()): - coco_eval.evaluate() - coco_eval.accumulate() - coco_eval.summarize() - class_stats = coco_eval.stats - - map_per_class_list.append(torch.tensor([class_stats[0]])) - mar_per_class_list.append(torch.tensor([class_stats[8]])) - - map_per_class_values = torch.tensor(map_per_class_list, dtype=torch.float32) - mar_per_class_values = torch.tensor(mar_per_class_list, dtype=torch.float32) - else: - map_per_class_values = torch.tensor([-1], dtype=torch.float32) - mar_per_class_values = torch.tensor([-1], dtype=torch.float32) - prefix = "" if len(self.iou_type) == 1 else f"{i_type}_" - result_dict.update( - { - f"{prefix}map_per_class": map_per_class_values, - f"{prefix}mar_{self.max_detection_thresholds[-1]}_per_class": mar_per_class_values, - }, - ) + map_per_class_list = [] + mar_per_class_list = [] + for class_id in self._get_classes(): + coco_eval.params.catIds = [class_id] + with contextlib.redirect_stdout(io.StringIO()): + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + class_stats = coco_eval.stats + + map_per_class_list.append(torch.tensor([class_stats[0]])) + mar_per_class_list.append(torch.tensor([class_stats[8]])) + + map_per_class_values = torch.tensor(map_per_class_list, dtype=torch.float32) + mar_per_class_values = torch.tensor(mar_per_class_list, dtype=torch.float32) + else: + map_per_class_values = torch.tensor([-1], dtype=torch.float32) + mar_per_class_values = torch.tensor([-1], dtype=torch.float32) + prefix = "" if len(self.iou_type) == 1 else f"{i_type}_" + result_dict.update( + { + f"{prefix}map_per_class": map_per_class_values, + f"{prefix}mar_{self.max_detection_thresholds[-1]}_per_class": mar_per_class_values, + }, + ) result_dict.update({"classes": torch.tensor(self._get_classes(), dtype=torch.int32)}) return result_dict diff --git a/src/torchmetrics/detection/panoptic_qualities.py b/src/torchmetrics/detection/panoptic_qualities.py index 50b7c4c5594..b4629be8e69 100644 --- a/src/torchmetrics/detection/panoptic_qualities.py +++ b/src/torchmetrics/detection/panoptic_qualities.py @@ -26,17 +26,13 @@ _validate_inputs, ) from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_GREATER_EQUAL_1_12 +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["PanopticQuality.plot", "ModifiedPanopticQuality.plot"] -if not _TORCH_GREATER_EQUAL_1_12: - __doctest_skip__ = ["PanopticQuality", "PanopticQuality.*", "ModifiedPanopticQuality", "ModifiedPanopticQuality.*"] - - class PanopticQuality(Metric): r"""Compute the `Panoptic Quality`_ for panoptic segmentations. @@ -166,9 +162,6 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) - if not _TORCH_GREATER_EQUAL_1_12: - raise RuntimeError("Panoptic Quality metric requires PyTorch 1.12 or later") - things, stuffs = _parse_categories(things, stuffs) self.things = things self.stuffs = stuffs diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index fead74b6a46..f533900db74 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -101,6 +101,7 @@ mean_squared_error, mean_squared_log_error, minkowski_distance, + normalized_root_mean_squared_error, pearson_corrcoef, r2_score, relative_squared_error, @@ -147,14 +148,13 @@ "calibration_error", "char_error_rate", "chrf_score", - "concordance_corrcoef", "cohen_kappa", + "concordance_corrcoef", "confusion_matrix", "cosine_similarity", "cramers_v", "cramers_v_matrix", "critical_success_index", - "tweedie_deviance_score", "dice", "error_relative_global_dimensionless_synthesis", "exact_match", @@ -179,12 +179,14 @@ "mean_squared_log_error", "minkowski_distance", "multiscale_structural_similarity_index_measure", + "normalized_root_mean_squared_error", "pairwise_cosine_similarity", "pairwise_euclidean_distance", "pairwise_linear_similarity", "pairwise_manhattan_distance", "pairwise_minkowski_distance", "panoptic_quality", + "peak_signal_noise_ratio", "pearson_corrcoef", "pearsons_contingency_coefficient", "pearsons_contingency_coefficient_matrix", @@ -192,10 +194,11 @@ "perplexity", "pit_permutate", "precision", + "precision_at_fixed_recall", "precision_recall_curve", - "peak_signal_noise_ratio", "r2_score", "recall", + "recall_at_fixed_precision", "relative_average_spectral_error", "relative_squared_error", "retrieval_average_precision", @@ -203,24 +206,27 @@ "retrieval_hit_rate", "retrieval_normalized_dcg", "retrieval_precision", + "retrieval_precision_recall_curve", "retrieval_r_precision", "retrieval_recall", "retrieval_reciprocal_rank", - "retrieval_precision_recall_curve", "roc", "root_mean_squared_error_using_sliding_window", "rouge_score", "sacre_bleu_score", - "signal_distortion_ratio", "scale_invariant_signal_distortion_ratio", "scale_invariant_signal_noise_ratio", + "sensitivity_at_specificity", + "signal_distortion_ratio", "signal_noise_ratio", "spearman_corrcoef", "specificity", + "specificity_at_sensitivity", + "spectral_angle_mapper", "spectral_distortion_index", "squad", - "structural_similarity_index_measure", "stat_scores", + "structural_similarity_index_measure", "symmetric_mean_absolute_percentage_error", "theils_u", "theils_u_matrix", @@ -228,14 +234,10 @@ "translation_edit_rate", "tschuprows_t", "tschuprows_t_matrix", + "tweedie_deviance_score", "universal_image_quality_index", - "spectral_angle_mapper", "weighted_mean_absolute_percentage_error", "word_error_rate", "word_information_lost", "word_information_preserved", - "precision_at_fixed_recall", - "recall_at_fixed_precision", - "sensitivity_at_specificity", - "specificity_at_sensitivity", ] diff --git a/src/torchmetrics/functional/audio/__init__.py b/src/torchmetrics/functional/audio/__init__.py index c8a8b5a4bcc..ac228cf671a 100644 --- a/src/torchmetrics/functional/audio/__init__.py +++ b/src/torchmetrics/functional/audio/__init__.py @@ -28,10 +28,17 @@ _ONNXRUNTIME_AVAILABLE, _PESQ_AVAILABLE, _PYSTOI_AVAILABLE, + _SCIPI_AVAILABLE, _TORCHAUDIO_AVAILABLE, - _TORCHAUDIO_GREATER_EQUAL_0_10, ) +if _SCIPI_AVAILABLE: + import scipy.signal + + # back compatibility patch due to SMRMpy using scipy.signal.hamming + if not hasattr(scipy.signal, "hamming"): + scipy.signal.hamming = scipy.signal.windows.hamming + __all__ = [ "permutation_invariant_training", "pit_permutate", @@ -53,7 +60,7 @@ __all__ += ["short_time_objective_intelligibility"] -if _GAMMATONE_AVAILABLE and _TORCHAUDIO_AVAILABLE and _TORCHAUDIO_GREATER_EQUAL_0_10: +if _GAMMATONE_AVAILABLE and _TORCHAUDIO_AVAILABLE: from torchmetrics.functional.audio.srmr import speech_reverberation_modulation_energy_ratio __all__ += ["speech_reverberation_modulation_energy_ratio"] diff --git a/src/torchmetrics/functional/audio/_deprecated.py b/src/torchmetrics/functional/audio/_deprecated.py index ebeff731f9d..8b337318f7a 100644 --- a/src/torchmetrics/functional/audio/_deprecated.py +++ b/src/torchmetrics/functional/audio/_deprecated.py @@ -69,21 +69,20 @@ def _signal_distortion_ratio( ) -> Tensor: """Wrapper for deprecated import. - >>> import torch - >>> g = torch.manual_seed(1) - >>> preds = torch.randn(8000) - >>> target = torch.randn(8000) + >>> from torch import randn + >>> preds = randn(8000) + >>> target = randn(8000) >>> _signal_distortion_ratio(preds, target) - tensor(-12.0589) + tensor(-11.9930) >>> # use with permutation_invariant_training - >>> preds = torch.randn(4, 2, 8000) # [batch, spk, time] - >>> target = torch.randn(4, 2, 8000) + >>> preds = randn(4, 2, 8000) # [batch, spk, time] + >>> target = randn(4, 2, 8000) >>> best_metric, best_perm = _permutation_invariant_training(preds, target, _signal_distortion_ratio) >>> best_metric - tensor([-11.6375, -11.4358, -11.7148, -11.6325]) + tensor([-11.7748, -11.7948, -11.7160, -11.6254]) >>> best_perm tensor([[1, 0], - [0, 1], + [1, 0], [1, 0], [0, 1]]) diff --git a/src/torchmetrics/functional/audio/dnsmos.py b/src/torchmetrics/functional/audio/dnsmos.py index ce29b8538c2..9b0dca883db 100644 --- a/src/torchmetrics/functional/audio/dnsmos.py +++ b/src/torchmetrics/functional/audio/dnsmos.py @@ -216,10 +216,9 @@ def deep_noise_suppression_mean_opinion_score( Example: >>> from torch import randn >>> from torchmetrics.functional.audio.dnsmos import deep_noise_suppression_mean_opinion_score - >>> g = torch.manual_seed(1) >>> preds = randn(8000) >>> deep_noise_suppression_mean_opinion_score(preds, 8000, False) - tensor([2.2285, 2.1132, 1.3972, 1.3652], dtype=torch.float64) + tensor([2.2..., 2.0..., 1.1..., 1.2...], dtype=torch.float64) """ if not _LIBROSA_AVAILABLE or not _ONNXRUNTIME_AVAILABLE or not _REQUESTS_AVAILABLE: diff --git a/src/torchmetrics/functional/audio/pesq.py b/src/torchmetrics/functional/audio/pesq.py index c865ac80d26..516014434bf 100644 --- a/src/torchmetrics/functional/audio/pesq.py +++ b/src/torchmetrics/functional/audio/pesq.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + import numpy as np import torch from torch import Tensor @@ -68,13 +70,12 @@ def perceptual_evaluation_speech_quality( Example: >>> from torch import randn >>> from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality - >>> g = torch.manual_seed(1) >>> preds = randn(8000) >>> target = randn(8000) >>> perceptual_evaluation_speech_quality(preds, target, 8000, 'nb') - tensor(2.2076) + tensor(2.2885) >>> perceptual_evaluation_speech_quality(preds, target, 16000, 'wb') - tensor(1.7359) + tensor(1.6805) """ if not _PESQ_AVAILABLE: @@ -84,6 +85,11 @@ def perceptual_evaluation_speech_quality( ) import pesq as pesq_backend + def _issubtype_number(x: Any) -> bool: + return np.issubdtype(type(x), np.number) + + _filter_error_msg = np.vectorize(_issubtype_number) + if fs not in (8000, 16000): raise ValueError(f"Expected argument `fs` to either be 8000 or 16000 but got {fs}") if mode not in ("wb", "nb"): @@ -104,8 +110,8 @@ def perceptual_evaluation_speech_quality( pesq_val_np = np.empty(shape=(preds_np.shape[0])) for b in range(preds_np.shape[0]): pesq_val_np[b] = pesq_backend.pesq(fs, target_np[b, :], preds_np[b, :], mode) - pesq_val = torch.from_numpy(pesq_val_np) - pesq_val = pesq_val.reshape(preds.shape[:-1]) + pesq_val = torch.from_numpy(pesq_val_np[_filter_error_msg(pesq_val_np)].astype(np.float32)) + pesq_val = pesq_val.reshape(len(pesq_val)) if keep_same_device: return pesq_val.to(preds.device) diff --git a/src/torchmetrics/functional/audio/sdr.py b/src/torchmetrics/functional/audio/sdr.py index d549a9c96ed..f6cb0bf2a2e 100644 --- a/src/torchmetrics/functional/audio/sdr.py +++ b/src/torchmetrics/functional/audio/sdr.py @@ -124,23 +124,22 @@ def signal_distortion_ratio( If ``preds`` and ``target`` does not have the same shape Example: - >>> import torch + >>> from torch import randn >>> from torchmetrics.functional.audio import signal_distortion_ratio - >>> g = torch.manual_seed(1) - >>> preds = torch.randn(8000) - >>> target = torch.randn(8000) + >>> preds = randn(8000) + >>> target = randn(8000) >>> signal_distortion_ratio(preds, target) - tensor(-12.0589) + tensor(-11.9930) >>> # use with permutation_invariant_training >>> from torchmetrics.functional.audio import permutation_invariant_training - >>> preds = torch.randn(4, 2, 8000) # [batch, spk, time] - >>> target = torch.randn(4, 2, 8000) + >>> preds = randn(4, 2, 8000) # [batch, spk, time] + >>> target = randn(4, 2, 8000) >>> best_metric, best_perm = permutation_invariant_training(preds, target, signal_distortion_ratio) >>> best_metric - tensor([-11.6375, -11.4358, -11.7148, -11.6325]) + tensor([-11.7748, -11.7948, -11.7160, -11.6254]) >>> best_perm tensor([[1, 0], - [0, 1], + [1, 0], [1, 0], [0, 1]]) @@ -260,23 +259,22 @@ def source_aggregated_signal_distortion_ratio( SA-SDR with shape ``(...)`` Example: - >>> import torch + >>> from torch import randn >>> from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio - >>> g = torch.manual_seed(1) - >>> preds = torch.randn(2, 8000) # [..., spk, time] - >>> target = torch.randn(2, 8000) + >>> preds = randn(2, 8000) # [..., spk, time] + >>> target = randn(2, 8000) >>> source_aggregated_signal_distortion_ratio(preds, target) - tensor(-41.6579) + tensor(-50.8171) >>> # use with permutation_invariant_training >>> from torchmetrics.functional.audio import permutation_invariant_training - >>> preds = torch.randn(4, 2, 8000) # [batch, spk, time] - >>> target = torch.randn(4, 2, 8000) + >>> preds = randn(4, 2, 8000) # [batch, spk, time] + >>> target = randn(4, 2, 8000) >>> best_metric, best_perm = permutation_invariant_training(preds, target, ... source_aggregated_signal_distortion_ratio, mode="permutation-wise") >>> best_metric - tensor([-37.9511, -41.9124, -42.7369, -42.5155]) + tensor([-42.6290, -44.3500, -34.7503, -54.1828]) >>> best_perm - tensor([[1, 0], + tensor([[0, 1], [1, 0], [0, 1], [1, 0]]) diff --git a/src/torchmetrics/functional/audio/snr.py b/src/torchmetrics/functional/audio/snr.py index fb0a4946722..0adc53dff30 100644 --- a/src/torchmetrics/functional/audio/snr.py +++ b/src/torchmetrics/functional/audio/snr.py @@ -106,13 +106,12 @@ def complex_scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor, ze If ``preds`` and ``target`` does not have the same shape. Example: - >>> import torch + >>> from torch import randn >>> from torchmetrics.functional.audio import complex_scale_invariant_signal_noise_ratio - >>> g = torch.manual_seed(1) - >>> preds = torch.randn((1,257,100,2)) - >>> target = torch.randn((1,257,100,2)) + >>> preds = randn((1,257,100,2)) + >>> target = randn((1,257,100,2)) >>> complex_scale_invariant_signal_noise_ratio(preds, target) - tensor([-63.4849]) + tensor([-38.8832]) """ if preds.is_complex(): @@ -123,7 +122,7 @@ def complex_scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor, ze if (preds.ndim < 3 or preds.shape[-1] != 2) or (target.ndim < 3 or target.shape[-1] != 2): raise RuntimeError( "Predictions and targets are expected to have the shape (..., frequency, time, 2)," - " but got {preds.shape} and {target.shape}." + f" but got {preds.shape} and {target.shape}." ) preds = preds.reshape(*preds.shape[:-3], -1) diff --git a/src/torchmetrics/functional/audio/srmr.py b/src/torchmetrics/functional/audio/srmr.py index 03f4f764d8c..d098366df6b 100644 --- a/src/torchmetrics/functional/audio/srmr.py +++ b/src/torchmetrics/functional/audio/srmr.py @@ -27,10 +27,9 @@ from torchmetrics.utilities.imports import ( _GAMMATONE_AVAILABLE, _TORCHAUDIO_AVAILABLE, - _TORCHAUDIO_GREATER_EQUAL_0_10, ) -if not _TORCHAUDIO_AVAILABLE or not _TORCHAUDIO_GREATER_EQUAL_0_10 or not _GAMMATONE_AVAILABLE: +if not _TORCHAUDIO_AVAILABLE or not _GAMMATONE_AVAILABLE: __doctest_skip__ = ["speech_reverberation_modulation_energy_ratio"] @@ -221,15 +220,14 @@ def speech_reverberation_modulation_energy_ratio( If ``gammatone`` or ``torchaudio`` package is not installed Example: - >>> import torch + >>> from torch import randn >>> from torchmetrics.functional.audio import speech_reverberation_modulation_energy_ratio - >>> g = torch.manual_seed(1) - >>> preds = torch.randn(8000) + >>> preds = randn(8000) >>> speech_reverberation_modulation_energy_ratio(preds, 8000) - tensor([0.3354], dtype=torch.float64) + tensor([0.3191], dtype=torch.float64) """ - if not _TORCHAUDIO_AVAILABLE or not _TORCHAUDIO_GREATER_EQUAL_0_10 or not _GAMMATONE_AVAILABLE: + if not _TORCHAUDIO_AVAILABLE or not _GAMMATONE_AVAILABLE: raise ModuleNotFoundError( "speech_reverberation_modulation_energy_ratio requires you to have `gammatone` and" " `torchaudio>=0.10` installed. Either install as ``pip install torchmetrics[audio]`` or " diff --git a/src/torchmetrics/functional/audio/stoi.py b/src/torchmetrics/functional/audio/stoi.py index 79736956c9a..48e9e78510b 100644 --- a/src/torchmetrics/functional/audio/stoi.py +++ b/src/torchmetrics/functional/audio/stoi.py @@ -59,13 +59,12 @@ def short_time_objective_intelligibility( If ``preds`` and ``target`` does not have the same shape Example: - >>> import torch + >>> from torch import randn >>> from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility - >>> g = torch.manual_seed(1) - >>> preds = torch.randn(8000) - >>> target = torch.randn(8000) + >>> preds = randn(8000) + >>> target = randn(8000) >>> short_time_objective_intelligibility(preds, target, 8000).float() - tensor(-0.0100) + tensor(-0.084...) """ if not _PYSTOI_AVAILABLE: diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index b668c152e5b..1b93450ab69 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -94,7 +94,7 @@ def _binary_confusion_matrix_tensor_validation( _check_same_shape(preds, target) # Check that target only contains {0,1} values or value in ignore_index - unique_values = torch.unique(target) + unique_values = torch.unique(target, dim=None) if ignore_index is None: check = torch.any((unique_values != 0) & (unique_values != 1)) else: @@ -107,7 +107,7 @@ def _binary_confusion_matrix_tensor_validation( # If preds is label tensor, also check that it only contains {0,1} values if not preds.is_floating_point(): - unique_values = torch.unique(preds) + unique_values = torch.unique(preds, dim=None) if torch.any((unique_values != 0) & (unique_values != 1)): raise RuntimeError( f"Detected the following values in `preds`: {unique_values} but expected only" @@ -287,7 +287,7 @@ def _multiclass_confusion_matrix_tensor_validation( check_value = num_classes if ignore_index is None else num_classes + 1 for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else (): # noqa: RUF005 - num_unique_values = len(torch.unique(t)) + num_unique_values = len(torch.unique(t, dim=None)) if num_unique_values > check_value: raise RuntimeError( f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found" @@ -454,7 +454,7 @@ def _multilabel_confusion_matrix_tensor_validation( ) # Check that target only contains [0,1] values or value in ignore_index - unique_values = torch.unique(target) + unique_values = torch.unique(target, dim=None) if ignore_index is None: check = torch.any((unique_values != 0) & (unique_values != 1)) else: @@ -467,7 +467,7 @@ def _multilabel_confusion_matrix_tensor_validation( # If preds is label tensor, also check that it only contains [0,1] values if not preds.is_floating_point(): - unique_values = torch.unique(preds) + unique_values = torch.unique(preds, dim=None) if torch.any((unique_values != 0) & (unique_values != 1)): raise RuntimeError( f"Detected the following values in `preds`: {unique_values} but expected only" diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 1d240df68af..dfddd68255f 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -67,7 +67,7 @@ def _jaccard_index_reduce( raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") confmat = confmat.float() if average == "binary": - return confmat[1, 1] / (confmat[0, 1] + confmat[1, 0] + confmat[1, 1]) + return _safe_divide(confmat[1, 1], (confmat[0, 1] + confmat[1, 0] + confmat[1, 1]), zero_division=zero_division) ignore_index_cond = ignore_index is not None and 0 <= ignore_index < confmat.shape[0] multilabel = confmat.ndim == 3 diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index 544414ee4a8..45e0238dae5 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -64,12 +64,14 @@ def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: denom = cov_ypyp * cov_ytyt if denom == 0 and confmat.numel() == 4: - if tp == 0 or tn == 0: - a = tp + tn - - if fp == 0 or fn == 0: - b = fp + fn - + if fn == 0 and tn == 0: + a, b = tp, fp + elif fp == 0 and tn == 0: + a, b = tp, fn + elif tp == 0 and fn == 0: + a, b = tn, fp + elif tp == 0 and fp == 0: + a, b = tn, fn eps = torch.tensor(torch.finfo(torch.float32).eps, dtype=torch.float32, device=confmat.device) numerator = torch.sqrt(eps) * (a - b) denom = (tp + fp + eps) * (tp + fn + eps) * (tn + fp + eps) * (tn + fn + eps) diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index da12db561a1..c4607fd9489 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -148,7 +148,7 @@ def _binary_precision_recall_curve_tensor_validation( ) # Check that target only contains {0,1} values or value in ignore_index - unique_values = torch.unique(target) + unique_values = torch.unique(target, dim=None) if ignore_index is None: check = torch.any((unique_values != 0) & (unique_values != 1)) else: @@ -417,7 +417,7 @@ def _multiclass_precision_recall_curve_tensor_validation( f" but got {preds.shape} and {target.shape}" ) - num_unique_values = len(torch.unique(target)) + num_unique_values = len(torch.unique(target, dim=None)) check = num_unique_values > num_classes if ignore_index is None else num_unique_values > num_classes + 1 if check: raise RuntimeError( diff --git a/src/torchmetrics/functional/classification/ranking.py b/src/torchmetrics/functional/classification/ranking.py index d6130979cd2..87bade5e88d 100644 --- a/src/torchmetrics/functional/classification/ranking.py +++ b/src/torchmetrics/functional/classification/ranking.py @@ -87,10 +87,10 @@ def multilabel_coverage_error( Set to ``False`` for faster computations. Example: + >>> from torch import rand, randint >>> from torchmetrics.functional.classification import multilabel_coverage_error - >>> _ = torch.manual_seed(42) - >>> preds = torch.rand(10, 5) - >>> target = torch.randint(2, (10, 5)) + >>> preds = rand(10, 5) + >>> target = randint(2, (10, 5)) >>> multilabel_coverage_error(preds, target, num_labels=5) tensor(3.9000) @@ -160,10 +160,10 @@ def multilabel_ranking_average_precision( Set to ``False`` for faster computations. Example: + >>> from torch import rand, randint >>> from torchmetrics.functional.classification import multilabel_ranking_average_precision - >>> _ = torch.manual_seed(42) - >>> preds = torch.rand(10, 5) - >>> target = torch.randint(2, (10, 5)) + >>> preds = rand(10, 5) + >>> target = randint(2, (10, 5)) >>> multilabel_ranking_average_precision(preds, target, num_labels=5) tensor(0.7744) @@ -245,10 +245,10 @@ def multilabel_ranking_loss( Set to ``False`` for faster computations. Example: + >>> from torch import rand, randint >>> from torchmetrics.functional.classification import multilabel_ranking_loss - >>> _ = torch.manual_seed(42) - >>> preds = torch.rand(10, 5) - >>> target = torch.randint(2, (10, 5)) + >>> preds = rand(10, 5) + >>> target = randint(2, (10, 5)) >>> multilabel_ranking_loss(preds, target, num_labels=5) tensor(0.4167) diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 47c38ff72e1..565c212f9bd 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -67,7 +67,7 @@ def _binary_stat_scores_tensor_validation( _check_same_shape(preds, target) # Check that target only contains [0,1] values or value in ignore_index - unique_values = torch.unique(target) + unique_values = torch.unique(target, dim=None) if ignore_index is None: check = torch.any((unique_values != 0) & (unique_values != 1)) else: @@ -80,7 +80,7 @@ def _binary_stat_scores_tensor_validation( # If preds is label tensor, also check that it only contains [0,1] values if not preds.is_floating_point(): - unique_values = torch.unique(preds) + unique_values = torch.unique(preds, dim=None) if torch.any((unique_values != 0) & (unique_values != 1)): raise RuntimeError( f"Detected the following values in `preds`: {unique_values} but expected only" @@ -314,11 +314,11 @@ def _multiclass_stat_scores_tensor_validation( check_value = num_classes if ignore_index is None else num_classes + 1 for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else (): # noqa: RUF005 - num_unique_values = len(torch.unique(t)) + num_unique_values = len(torch.unique(t, dim=None)) if num_unique_values > check_value: raise RuntimeError( f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found" - f" {num_unique_values} in `target`." + f" {num_unique_values} in `{name}`. Found values: {torch.unique(t, dim=None)}." ) @@ -624,7 +624,7 @@ def _multilabel_stat_scores_tensor_validation( ) # Check that target only contains [0,1] values or value in ignore_index - unique_values = torch.unique(target) + unique_values = torch.unique(target, dim=None) if ignore_index is None: check = torch.any((unique_values != 0) & (unique_values != 1)) else: @@ -637,7 +637,7 @@ def _multilabel_stat_scores_tensor_validation( # If preds is label tensor, also check that it only contains [0,1] values if not preds.is_floating_point(): - unique_values = torch.unique(preds) + unique_values = torch.unique(preds, dim=None) if torch.any((unique_values != 0) & (unique_values != 1)): raise RuntimeError( f"Detected the following values in `preds`: {unique_values} but expected only" diff --git a/src/torchmetrics/functional/clustering/calinski_harabasz_score.py b/src/torchmetrics/functional/clustering/calinski_harabasz_score.py index 2c934dfd44d..7501ff8f15d 100644 --- a/src/torchmetrics/functional/clustering/calinski_harabasz_score.py +++ b/src/torchmetrics/functional/clustering/calinski_harabasz_score.py @@ -31,13 +31,12 @@ def calinski_harabasz_score(data: Tensor, labels: Tensor) -> Tensor: Scalar tensor with the Calinski Harabasz Score Example: - >>> import torch + >>> from torch import randn, randint >>> from torchmetrics.functional.clustering import calinski_harabasz_score - >>> _ = torch.manual_seed(42) - >>> data = torch.randn(10, 3) - >>> labels = torch.randint(0, 2, (10,)) + >>> data = randn(20, 3) + >>> labels = randint(0, 3, (20,)) >>> calinski_harabasz_score(data, labels) - tensor(3.4998) + tensor(2.2128) """ _validate_intrinsic_cluster_data(data, labels) diff --git a/src/torchmetrics/functional/clustering/davies_bouldin_score.py b/src/torchmetrics/functional/clustering/davies_bouldin_score.py index cc3e530d7d9..1d6a7222703 100644 --- a/src/torchmetrics/functional/clustering/davies_bouldin_score.py +++ b/src/torchmetrics/functional/clustering/davies_bouldin_score.py @@ -31,13 +31,12 @@ def davies_bouldin_score(data: Tensor, labels: Tensor) -> Tensor: Scalar tensor with the Davies bouldin score Example: - >>> import torch + >>> from torch import randn, randint >>> from torchmetrics.functional.clustering import davies_bouldin_score - >>> _ = torch.manual_seed(42) - >>> data = torch.randn(10, 3) - >>> labels = torch.randint(0, 2, (10,)) + >>> data = randn(20, 3) + >>> labels = randint(0, 3, (20,)) >>> davies_bouldin_score(data, labels) - tensor(1.3249) + tensor(2.7418) """ _validate_intrinsic_cluster_data(data, labels) diff --git a/src/torchmetrics/functional/detection/__init__.py b/src/torchmetrics/functional/detection/__init__.py index 8f818c7b2df..fab5ccb5f91 100644 --- a/src/torchmetrics/functional/detection/__init__.py +++ b/src/torchmetrics/functional/detection/__init__.py @@ -15,20 +15,19 @@ from torchmetrics.functional.detection.panoptic_qualities import modified_panoptic_quality, panoptic_quality from torchmetrics.utilities.imports import ( _TORCHVISION_AVAILABLE, - _TORCHVISION_GREATER_EQUAL_0_8, - _TORCHVISION_GREATER_EQUAL_0_13, ) __all__ = ["modified_panoptic_quality", "panoptic_quality"] -if _TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8: - from torchmetrics.functional.detection.giou import generalized_intersection_over_union - from torchmetrics.functional.detection.iou import intersection_over_union - - __all__ += ["generalized_intersection_over_union", "intersection_over_union"] - -if _TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_13: +if _TORCHVISION_AVAILABLE: from torchmetrics.functional.detection.ciou import complete_intersection_over_union from torchmetrics.functional.detection.diou import distance_intersection_over_union + from torchmetrics.functional.detection.giou import generalized_intersection_over_union + from torchmetrics.functional.detection.iou import intersection_over_union - __all__ += ["complete_intersection_over_union", "distance_intersection_over_union"] + __all__ += [ + "generalized_intersection_over_union", + "intersection_over_union", + "complete_intersection_over_union", + "distance_intersection_over_union", + ] diff --git a/src/torchmetrics/functional/detection/_deprecated.py b/src/torchmetrics/functional/detection/_deprecated.py index b2500d34f0d..ce0e1ba6acf 100644 --- a/src/torchmetrics/functional/detection/_deprecated.py +++ b/src/torchmetrics/functional/detection/_deprecated.py @@ -3,12 +3,8 @@ from torch import Tensor from torchmetrics.functional.detection.panoptic_qualities import modified_panoptic_quality, panoptic_quality -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from torchmetrics.utilities.prints import _deprecated_root_import_func -if not _TORCH_GREATER_EQUAL_1_12: - __doctest_skip__ = ["_panoptic_quality", "_modified_panoptic_quality"] - def _modified_panoptic_quality( preds: Tensor, diff --git a/src/torchmetrics/functional/detection/ciou.py b/src/torchmetrics/functional/detection/ciou.py index 9669029ba73..650651b2e4f 100644 --- a/src/torchmetrics/functional/detection/ciou.py +++ b/src/torchmetrics/functional/detection/ciou.py @@ -15,9 +15,9 @@ import torch -from torchmetrics.utilities.imports import _TORCHVISION_GREATER_EQUAL_0_13 +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE -if not _TORCHVISION_GREATER_EQUAL_0_13: +if not _TORCHVISION_AVAILABLE: __doctest_skip__ = ["complete_intersection_over_union"] @@ -113,11 +113,10 @@ def complete_intersection_over_union( [-0.3971, -0.1543, 0.5606]]) """ - if not _TORCHVISION_GREATER_EQUAL_0_13: + if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( - f"`{complete_intersection_over_union.__name__}` requires that `torchvision` version 0.13.0 or newer" - " is installed." - " Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`." + f"`{complete_intersection_over_union.__name__}` requires that `torchvision` is installed." + " Please install with `pip install torchmetrics[detection]`." ) iou = _ciou_update(preds, target, iou_threshold, replacement_val) return _ciou_compute(iou, aggregate) diff --git a/src/torchmetrics/functional/detection/diou.py b/src/torchmetrics/functional/detection/diou.py index 13fb0071fed..7a9a3d907a9 100644 --- a/src/torchmetrics/functional/detection/diou.py +++ b/src/torchmetrics/functional/detection/diou.py @@ -15,9 +15,9 @@ import torch -from torchmetrics.utilities.imports import _TORCHVISION_GREATER_EQUAL_0_13 +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE -if not _TORCHVISION_GREATER_EQUAL_0_13: +if not _TORCHVISION_AVAILABLE: __doctest_skip__ = ["distance_intersection_over_union"] @@ -113,11 +113,10 @@ def distance_intersection_over_union( [-0.3971, -0.1510, 0.5609]]) """ - if not _TORCHVISION_GREATER_EQUAL_0_13: + if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( - f"`{distance_intersection_over_union.__name__}` requires that `torchvision` version 0.13.0 or newer" - " is installed." - " Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`." + f"`{distance_intersection_over_union.__name__}` requires that `torchvision` is installed." + " Please install with `pip install torchmetrics[detection]`." ) iou = _diou_update(preds, target, iou_threshold, replacement_val) return _diou_compute(iou, aggregate) diff --git a/src/torchmetrics/functional/detection/giou.py b/src/torchmetrics/functional/detection/giou.py index cc39f813b41..feae12d3011 100644 --- a/src/torchmetrics/functional/detection/giou.py +++ b/src/torchmetrics/functional/detection/giou.py @@ -15,9 +15,9 @@ import torch -from torchmetrics.utilities.imports import _TORCHVISION_GREATER_EQUAL_0_8 +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE -if not _TORCHVISION_GREATER_EQUAL_0_8: +if not _TORCHVISION_AVAILABLE: __doctest_skip__ = ["generalized_intersection_over_union"] @@ -113,11 +113,10 @@ def generalized_intersection_over_union( [-0.6024, -0.4021, 0.5345]]) """ - if not _TORCHVISION_GREATER_EQUAL_0_8: + if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( - f"`{generalized_intersection_over_union.__name__}` requires that `torchvision` version 0.8.0 or newer" - " is installed." - " Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`." + f"`{generalized_intersection_over_union.__name__}` requires that `torchvision` is installed." + " Please install with `pip install torchmetrics[detection]`." ) iou = _giou_update(preds, target, iou_threshold, replacement_val) return _giou_compute(iou, aggregate) diff --git a/src/torchmetrics/functional/detection/iou.py b/src/torchmetrics/functional/detection/iou.py index 3d3cef26bb2..249b30dd2d9 100644 --- a/src/torchmetrics/functional/detection/iou.py +++ b/src/torchmetrics/functional/detection/iou.py @@ -15,9 +15,9 @@ import torch -from torchmetrics.utilities.imports import _TORCHVISION_GREATER_EQUAL_0_8 +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE -if not _TORCHVISION_GREATER_EQUAL_0_8: +if not _TORCHVISION_AVAILABLE: __doctest_skip__ = ["intersection_over_union"] @@ -114,10 +114,10 @@ def intersection_over_union( [0.0000, 0.0000, 0.5654]]) """ - if not _TORCHVISION_GREATER_EQUAL_0_8: + if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( - f"`{intersection_over_union.__name__}` requires that `torchvision` version 0.8.0 or newer is installed." - " Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`." + f"`{intersection_over_union.__name__}` requires that `torchvision` is installed." + " Please install with `pip install torchmetrics[detection]`." ) iou = _iou_update(preds, target, iou_threshold, replacement_val) return _iou_compute(iou, aggregate) diff --git a/src/torchmetrics/functional/detection/panoptic_qualities.py b/src/torchmetrics/functional/detection/panoptic_qualities.py index 2de9fa09bfa..019d243f0ba 100644 --- a/src/torchmetrics/functional/detection/panoptic_qualities.py +++ b/src/torchmetrics/functional/detection/panoptic_qualities.py @@ -25,10 +25,6 @@ _prepocess_inputs, _validate_inputs, ) -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12 - -if not _TORCH_GREATER_EQUAL_1_12: - __doctest_skip__ = ["panoptic_quality", "modified_panoptic_quality"] def panoptic_quality( @@ -152,9 +148,6 @@ def panoptic_quality( [1.0000, 1.0000, 1.0000]], dtype=torch.float64) """ - if not _TORCH_GREATER_EQUAL_1_12: - raise RuntimeError("Panoptic Quality metric requires PyTorch 1.12 or later") - things, stuffs = _parse_categories(things, stuffs) _validate_inputs(preds, target) void_color = _get_void_color(things, stuffs) diff --git a/src/torchmetrics/functional/image/_deprecated.py b/src/torchmetrics/functional/image/_deprecated.py index 6fc768ce58e..892d07afaa6 100644 --- a/src/torchmetrics/functional/image/_deprecated.py +++ b/src/torchmetrics/functional/image/_deprecated.py @@ -27,10 +27,9 @@ def _spectral_distortion_index( ) -> Tensor: """Wrapper for deprecated import. - >>> import torch - >>> _ = torch.manual_seed(42) - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) + >>> from torch import rand + >>> preds = rand([16, 3, 16, 16]) + >>> target = rand([16, 3, 16, 16]) >>> _spectral_distortion_index(preds, target) tensor(0.0234) @@ -47,12 +46,10 @@ def _error_relative_global_dimensionless_synthesis( ) -> Tensor: """Wrapper for deprecated import. - >>> import torch - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand([16, 1, 16, 16], generator=gen) + >>> from torch import rand + >>> preds = rand([16, 1, 16, 16]) >>> target = preds * 0.75 - >>> ergds = _error_relative_global_dimensionless_synthesis(preds, target) - >>> torch.round(ergds) + >>> _error_relative_global_dimensionless_synthesis(preds, target).round() tensor(10.) """ @@ -105,12 +102,11 @@ def _peak_signal_noise_ratio( def _relative_average_spectral_error(preds: Tensor, target: Tensor, window_size: int = 8) -> Tensor: """Wrapper for deprecated import. - >>> import torch - >>> gen = torch.manual_seed(22) - >>> preds = torch.rand(4, 3, 16, 16, generator=gen) - >>> target = torch.rand(4, 3, 16, 16, generator=gen) + >>> from torch import rand + >>> preds = rand(4, 3, 16, 16) + >>> target = rand(4, 3, 16, 16) >>> _relative_average_spectral_error(preds, target) - tensor(5114.66...) + tensor(5326.40...) """ _deprecated_root_import_func("relative_average_spectral_error", "image") @@ -122,12 +118,11 @@ def _root_mean_squared_error_using_sliding_window( ) -> Union[Optional[Tensor], Tuple[Optional[Tensor], Tensor]]: """Wrapper for deprecated import. - >>> import torch - >>> gen = torch.manual_seed(22) - >>> preds = torch.rand(4, 3, 16, 16, generator=gen) - >>> target = torch.rand(4, 3, 16, 16, generator=gen) + >>> from torch import rand + >>> preds = rand(4, 3, 16, 16) + >>> target = rand(4, 3, 16, 16) >>> _root_mean_squared_error_using_sliding_window(preds, target) - tensor(0.3999) + tensor(0.4158) """ _deprecated_root_import_func("root_mean_squared_error_using_sliding_window", "image") @@ -143,10 +138,9 @@ def _spectral_angle_mapper( ) -> Tensor: """Wrapper for deprecated import. - >>> import torch - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand([16, 3, 16, 16], generator=gen) - >>> target = torch.rand([16, 3, 16, 16], generator=gen) + >>> from torch import rand + >>> preds = rand([16, 3, 16, 16]) + >>> target = rand([16, 3, 16, 16]) >>> _spectral_angle_mapper(preds, target) tensor(0.5914) @@ -170,12 +164,11 @@ def _multiscale_structural_similarity_index_measure( ) -> Tensor: """Wrapper for deprecated import. - >>> import torch - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand([3, 3, 256, 256], generator=gen) + >>> from torch import rand + >>> preds = rand([3, 3, 256, 256]) >>> target = preds * 0.75 >>> _multiscale_structural_similarity_index_measure(preds, target, data_range=1.0) - tensor(0.9627) + tensor(0.9628) """ _deprecated_root_import_func("multiscale_structural_similarity_index_measure", "image") @@ -235,9 +228,8 @@ def _structural_similarity_index_measure( def _total_variation(img: Tensor, reduction: Literal["mean", "sum", "none", None] = "sum") -> Tensor: """Wrapper for deprecated import. - >>> import torch - >>> _ = torch.manual_seed(42) - >>> img = torch.rand(5, 3, 28, 28) + >>> from torch import rand + >>> img = rand(5, 3, 28, 28) >>> _total_variation(img) tensor(7546.8018) diff --git a/src/torchmetrics/functional/image/d_lambda.py b/src/torchmetrics/functional/image/d_lambda.py index e5ffa25b5cb..5921f51d32d 100644 --- a/src/torchmetrics/functional/image/d_lambda.py +++ b/src/torchmetrics/functional/image/d_lambda.py @@ -65,9 +65,9 @@ def _spectral_distortion_index_compute( - ``'none'``: no reduction will be applied Example: - >>> _ = torch.manual_seed(42) - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) + >>> from torch import rand + >>> preds = rand([16, 3, 16, 16]) + >>> target = rand([16, 3, 16, 16]) >>> preds, target = _spectral_distortion_index_update(preds, target) >>> _spectral_distortion_index_compute(preds, target) tensor(0.0234) @@ -139,10 +139,10 @@ def spectral_distortion_index( If ``p`` is not a positive integer. Example: + >>> from torch import rand >>> from torchmetrics.functional.image import spectral_distortion_index - >>> _ = torch.manual_seed(42) - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) + >>> preds = rand([16, 3, 16, 16]) + >>> target = rand([16, 3, 16, 16]) >>> spectral_distortion_index(preds, target) tensor(0.0234) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index e2190cb08a8..824cd3bfa8b 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -160,10 +160,10 @@ def _spatial_distortion_index_compute( If ``window_size`` is smaller than dimension of ``ms``. Example: - >>> _ = torch.manual_seed(42) - >>> preds = torch.rand([16, 3, 32, 32]) - >>> ms = torch.rand([16, 3, 16, 16]) - >>> pan = torch.rand([16, 3, 32, 32]) + >>> from torch import rand + >>> preds = rand([16, 3, 32, 32]) + >>> ms = rand([16, 3, 16, 16]) + >>> pan = rand([16, 3, 32, 32]) >>> preds, ms, pan, pan_lr = _spatial_distortion_index_update(preds, ms, pan) >>> _spatial_distortion_index_compute(preds, ms, pan, pan_lr) tensor(0.0090) @@ -250,11 +250,11 @@ def spatial_distortion_index( If ``window_size`` is not a positive integer. Example: + >>> from torch import rand >>> from torchmetrics.functional.image import spatial_distortion_index - >>> _ = torch.manual_seed(42) - >>> preds = torch.rand([16, 3, 32, 32]) - >>> ms = torch.rand([16, 3, 16, 16]) - >>> pan = torch.rand([16, 3, 32, 32]) + >>> preds = rand([16, 3, 32, 32]) + >>> ms = rand([16, 3, 16, 16]) + >>> pan = rand([16, 3, 32, 32]) >>> spatial_distortion_index(preds, ms, pan) tensor(0.0090) diff --git a/src/torchmetrics/functional/image/ergas.py b/src/torchmetrics/functional/image/ergas.py index 9d3032de9f1..41500c552e0 100644 --- a/src/torchmetrics/functional/image/ergas.py +++ b/src/torchmetrics/functional/image/ergas.py @@ -62,8 +62,8 @@ def _ergas_compute( - ``'none'`` or ``None``: no reduction will be applied Example: - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand([16, 1, 16, 16], generator=gen) + >>> from torch import rand + >>> preds = rand([16, 1, 16, 16]) >>> target = preds * 0.75 >>> preds, target = _ergas_update(preds, target) >>> torch.round(_ergas_compute(preds, target)) @@ -111,9 +111,9 @@ def error_relative_global_dimensionless_synthesis( If ``preds`` and ``target`` don't have ``BxCxHxW shape``. Example: + >>> from torch import rand >>> from torchmetrics.functional.image import error_relative_global_dimensionless_synthesis - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand([16, 1, 16, 16], generator=gen) + >>> preds = rand([16, 1, 16, 16]) >>> target = preds * 0.75 >>> error_relative_global_dimensionless_synthesis(preds, target) tensor(9.6193) diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index 64c237b087e..c557f61ead1 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -30,7 +30,7 @@ from torch import Tensor, nn from typing_extensions import Literal -from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13 +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE _weight_map = { "squeezenet1_1": "SqueezeNet1_1_Weights", @@ -52,13 +52,11 @@ def _get_net(net: str, pretrained: bool) -> nn.modules.container.Sequential: """ from torchvision import models as tv - if _TORCHVISION_GREATER_EQUAL_0_13: + if _TORCHVISION_AVAILABLE: if pretrained: pretrained_features = getattr(tv, net)(weights=getattr(tv, _weight_map[net]).IMAGENET1K_V1).features else: pretrained_features = getattr(tv, net)(weights=None).features - else: - pretrained_features = getattr(tv, net)(pretrained=pretrained).features return pretrained_features @@ -421,13 +419,12 @@ def learned_perceptual_image_patch_similarity( to ``True`` will instead expect input to be in the ``[0,1]`` range. Example: - >>> import torch - >>> _ = torch.manual_seed(123) + >>> from torch import rand >>> from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity - >>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1 - >>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1 + >>> img1 = (rand(10, 3, 100, 100) * 2) - 1 + >>> img2 = (rand(10, 3, 100, 100) * 2) - 1 >>> learned_perceptual_image_patch_similarity(img1, img2, net_type='squeeze') - tensor(0.1008) + tensor(0.1005) """ net = _NoTrainLpips(net=net_type).to(device=img1.device, dtype=img1.dtype) diff --git a/src/torchmetrics/functional/image/perceptual_path_length.py b/src/torchmetrics/functional/image/perceptual_path_length.py index 4b675ae9c54..58b0a7bae05 100644 --- a/src/torchmetrics/functional/image/perceptual_path_length.py +++ b/src/torchmetrics/functional/image/perceptual_path_length.py @@ -202,9 +202,8 @@ def perceptual_path_length( A tuple containing the mean, standard deviation and all distances. Example:: - >>> from torchmetrics.functional.image import perceptual_path_length >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torchmetrics.functional.image import perceptual_path_length >>> class DummyGenerator(torch.nn.Module): ... def __init__(self, z_size) -> None: ... super().__init__() diff --git a/src/torchmetrics/functional/image/psnr.py b/src/torchmetrics/functional/image/psnr.py index 8f6a3f20dba..adb80e2bc77 100644 --- a/src/torchmetrics/functional/image/psnr.py +++ b/src/torchmetrics/functional/image/psnr.py @@ -69,6 +69,11 @@ def _psnr_update( Default is None meaning scores will be reduced across all dimensions. """ + if not preds.is_floating_point(): + preds = preds.to(torch.float32) + if not target.is_floating_point(): + target = target.to(torch.float32) + if dim is None: sum_squared_error = torch.sum(torch.pow(preds - target, 2)) num_obs = tensor(target.numel(), device=target.device) diff --git a/src/torchmetrics/functional/image/psnrb.py b/src/torchmetrics/functional/image/psnrb.py index 8e190d9f563..f725ea4bd80 100644 --- a/src/torchmetrics/functional/image/psnrb.py +++ b/src/torchmetrics/functional/image/psnrb.py @@ -122,11 +122,10 @@ def peak_signal_noise_ratio_with_blocked_effect( Tensor with PSNRB score Example: - >>> import torch + >>> from torch import rand >>> from torchmetrics.functional.image import peak_signal_noise_ratio_with_blocked_effect - >>> _ = torch.manual_seed(42) - >>> preds = torch.rand(1, 1, 28, 28) - >>> target = torch.rand(1, 1, 28, 28) + >>> preds = rand(1, 1, 28, 28) + >>> target = rand(1, 1, 28, 28) >>> peak_signal_noise_ratio_with_blocked_effect(preds, target) tensor(7.8402) diff --git a/src/torchmetrics/functional/image/qnr.py b/src/torchmetrics/functional/image/qnr.py index 2e963345634..e34e6bc2473 100644 --- a/src/torchmetrics/functional/image/qnr.py +++ b/src/torchmetrics/functional/image/qnr.py @@ -63,12 +63,11 @@ def quality_with_no_reference( If ``alpha`` or ``beta`` is not a non-negative real number. Example: - >>> import torch + >>> from torch import rand >>> from torchmetrics.functional.image import quality_with_no_reference - >>> _ = torch.manual_seed(42) - >>> preds = torch.rand([16, 3, 32, 32]) - >>> ms = torch.rand([16, 3, 16, 16]) - >>> pan = torch.rand([16, 3, 32, 32]) + >>> preds = rand([16, 3, 32, 32]) + >>> ms = rand([16, 3, 16, 16]) + >>> pan = rand([16, 3, 32, 32]) >>> quality_with_no_reference(preds, ms, pan) tensor(0.9694) diff --git a/src/torchmetrics/functional/image/rase.py b/src/torchmetrics/functional/image/rase.py index 54d20c6eee0..388f2c237a3 100644 --- a/src/torchmetrics/functional/image/rase.py +++ b/src/torchmetrics/functional/image/rase.py @@ -80,12 +80,12 @@ def relative_average_spectral_error(preds: Tensor, target: Tensor, window_size: Relative Average Spectral Error (RASE) Example: + >>> from torch import rand >>> from torchmetrics.functional.image import relative_average_spectral_error - >>> g = torch.manual_seed(22) - >>> preds = torch.rand(4, 3, 16, 16) - >>> target = torch.rand(4, 3, 16, 16) + >>> preds = rand(4, 3, 16, 16) + >>> target = rand(4, 3, 16, 16) >>> relative_average_spectral_error(preds, target) - tensor(5114.66...) + tensor(5326.40...) Raises: ValueError: If ``window_size`` is not a positive integer. diff --git a/src/torchmetrics/functional/image/rmse_sw.py b/src/torchmetrics/functional/image/rmse_sw.py index 9ef9fa773de..a27582bd11a 100644 --- a/src/torchmetrics/functional/image/rmse_sw.py +++ b/src/torchmetrics/functional/image/rmse_sw.py @@ -104,7 +104,8 @@ def _rmse_sw_compute( """ rmse = rmse_val_sum / total_images if rmse_val_sum is not None else None if rmse_map is not None: - rmse_map /= total_images + # prevent overwrite the inputs + rmse_map = rmse_map / total_images return rmse, rmse_map @@ -124,12 +125,12 @@ def root_mean_squared_error_using_sliding_window( (Optionally) RMSE map Example: + >>> from torch import rand >>> from torchmetrics.functional.image import root_mean_squared_error_using_sliding_window - >>> g = torch.manual_seed(22) - >>> preds = torch.rand(4, 3, 16, 16) - >>> target = torch.rand(4, 3, 16, 16) + >>> preds = rand(4, 3, 16, 16) + >>> target = rand(4, 3, 16, 16) >>> root_mean_squared_error_using_sliding_window(preds, target) - tensor(0.3999) + tensor(0.4158) Raises: ValueError: If ``window_size`` is not a positive integer. diff --git a/src/torchmetrics/functional/image/sam.py b/src/torchmetrics/functional/image/sam.py index 9bfd58cfc8c..71927ff6b42 100644 --- a/src/torchmetrics/functional/image/sam.py +++ b/src/torchmetrics/functional/image/sam.py @@ -65,9 +65,9 @@ def _sam_compute( - ``'none'`` or ``None``: no reduction will be applied Example: - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand([16, 3, 16, 16], generator=gen) - >>> target = torch.rand([16, 3, 16, 16], generator=gen) + >>> from torch import rand + >>> preds = rand([16, 3, 16, 16]) + >>> target = rand([16, 3, 16, 16]) >>> preds, target = _sam_update(preds, target) >>> _sam_compute(preds, target) tensor(0.5914) @@ -106,10 +106,10 @@ def spectral_angle_mapper( If ``preds`` and ``target`` don't have ``BxCxHxW shape``. Example: + >>> from torch import rand >>> from torchmetrics.functional.image import spectral_angle_mapper - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand([16, 3, 16, 16], generator=gen) - >>> target = torch.rand([16, 3, 16, 16], generator=gen) + >>> preds = rand([16, 3, 16, 16],) + >>> target = rand([16, 3, 16, 16]) >>> spectral_angle_mapper(preds, target) tensor(0.5914) diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index 167ddbd37b5..a4b7bd85725 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -185,17 +185,16 @@ def spatial_correlation_coefficient( Tensor with scc score Example: - >>> import torch + >>> from torch import randn >>> from torchmetrics.functional.image import spatial_correlation_coefficient as scc - >>> _ = torch.manual_seed(42) - >>> x = torch.randn(5, 3, 16, 16) + >>> x = randn(5, 3, 16, 16) >>> scc(x, x) tensor(1.) - >>> x = torch.randn(5, 16, 16) + >>> x = randn(5, 16, 16) >>> scc(x, x) tensor(1.) - >>> x = torch.randn(5, 3, 16, 16) - >>> y = torch.randn(5, 3, 16, 16) + >>> x = randn(5, 3, 16, 16) + >>> y = randn(5, 3, 16, 16) >>> scc(x, y, reduction="none") tensor([0.0223, 0.0256, 0.0616, 0.0159, 0.0170]) diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index e33e3943fab..c61ef833fe3 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -124,11 +124,15 @@ def _ssim_update( dtype = preds.dtype gauss_kernel_size = [int(3.5 * s + 0.5) * 2 + 1 for s in sigma] - pad_h = (gauss_kernel_size[0] - 1) // 2 - pad_w = (gauss_kernel_size[1] - 1) // 2 + if gaussian_kernel: + pad_h = (gauss_kernel_size[0] - 1) // 2 + pad_w = (gauss_kernel_size[1] - 1) // 2 + else: + pad_h = (kernel_size[0] - 1) // 2 + pad_w = (kernel_size[1] - 1) // 2 if is_3d: - pad_d = (gauss_kernel_size[2] - 1) // 2 + pad_d = (kernel_size[2] - 1) // 2 preds = _reflection_pad_3d(preds, pad_d, pad_w, pad_h) target = _reflection_pad_3d(target, pad_d, pad_w, pad_h) if gaussian_kernel: @@ -164,25 +168,21 @@ def _ssim_update( ssim_idx_full_image = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower) - if is_3d: - ssim_idx = ssim_idx_full_image[..., pad_h:-pad_h, pad_w:-pad_w, pad_d:-pad_d] - else: - ssim_idx = ssim_idx_full_image[..., pad_h:-pad_h, pad_w:-pad_w] - if return_contrast_sensitivity: contrast_sensitivity = upper / lower if is_3d: contrast_sensitivity = contrast_sensitivity[..., pad_h:-pad_h, pad_w:-pad_w, pad_d:-pad_d] else: contrast_sensitivity = contrast_sensitivity[..., pad_h:-pad_h, pad_w:-pad_w] - return ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1), contrast_sensitivity.reshape( + + return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1), contrast_sensitivity.reshape( contrast_sensitivity.shape[0], -1 ).mean(-1) if return_full_image: - return ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1), ssim_idx_full_image + return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1), ssim_idx_full_image - return ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1) + return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1) def _ssim_compute( @@ -502,12 +502,12 @@ def multiscale_structural_similarity_index_measure( If one of the elements of ``sigma`` is not a ``positive number``. Example: + >>> from torch import rand >>> from torchmetrics.functional.image import multiscale_structural_similarity_index_measure - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand([3, 3, 256, 256], generator=gen) + >>> preds = rand([3, 3, 256, 256]) >>> target = preds * 0.75 >>> multiscale_structural_similarity_index_measure(preds, target, data_range=1.0) - tensor(0.9627) + tensor(0.9628) References: [1] Multi-Scale Structural Similarity For Image Quality Assessment by Zhou Wang, Eero P. Simoncelli and Alan C. diff --git a/src/torchmetrics/functional/image/tv.py b/src/torchmetrics/functional/image/tv.py index da888b97379..be8e3366caa 100644 --- a/src/torchmetrics/functional/image/tv.py +++ b/src/torchmetrics/functional/image/tv.py @@ -64,10 +64,9 @@ def total_variation(img: Tensor, reduction: Optional[Literal["mean", "sum", "non If ``img`` is not 4D tensor Example: - >>> import torch + >>> from torch import rand >>> from torchmetrics.functional.image import total_variation - >>> _ = torch.manual_seed(42) - >>> img = torch.rand(5, 3, 28, 28) + >>> img = rand(5, 3, 28, 28) >>> total_variation(img) tensor(7546.8018) diff --git a/src/torchmetrics/functional/multimodal/clip_iqa.py b/src/torchmetrics/functional/multimodal/clip_iqa.py index 1c60bf12877..49e710e248a 100644 --- a/src/torchmetrics/functional/multimodal/clip_iqa.py +++ b/src/torchmetrics/functional/multimodal/clip_iqa.py @@ -293,32 +293,29 @@ def clip_image_quality_assessment( Example:: Single prompt: + >>> from torch import randint >>> from torchmetrics.functional.multimodal import clip_image_quality_assessment - >>> import torch - >>> _ = torch.manual_seed(42) - >>> imgs = torch.randint(255, (2, 3, 224, 224)).float() + >>> imgs = randint(255, (2, 3, 224, 224)).float() >>> clip_image_quality_assessment(imgs, prompts=("quality",)) tensor([0.8894, 0.8902]) Example:: Multiple prompts: + >>> from torch import randint >>> from torchmetrics.functional.multimodal import clip_image_quality_assessment - >>> import torch - >>> _ = torch.manual_seed(42) - >>> imgs = torch.randint(255, (2, 3, 224, 224)).float() + >>> imgs = randint(255, (2, 3, 224, 224)).float() >>> clip_image_quality_assessment(imgs, prompts=("quality", "brightness")) - {'quality': tensor([0.8894, 0.8902]), 'brightness': tensor([0.5507, 0.5208])} + {'quality': tensor([0.8693, 0.8705]), 'brightness': tensor([0.5722, 0.4762])} Example:: Custom prompts. Must always be a tuple of length 2, with a positive and negative prompt. + >>> from torch import rand >>> from torchmetrics.functional.multimodal import clip_image_quality_assessment - >>> import torch - >>> _ = torch.manual_seed(42) - >>> imgs = torch.randint(255, (2, 3, 224, 224)).float() + >>> imgs = randint(255, (2, 3, 224, 224)).float() >>> clip_image_quality_assessment(imgs, prompts=(("Super good photo.", "Super bad photo."), "brightness")) - {'user_defined_0': tensor([0.9652, 0.9629]), 'brightness': tensor([0.5507, 0.5208])} + {'user_defined_0': tensor([0.9578, 0.9654]), 'brightness': tensor([0.5495, 0.5764])} """ prompts_list, prompts_names = _clip_iqa_format_prompts(prompts) diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py index bc0df92bd63..920eb6972e6 100644 --- a/src/torchmetrics/functional/multimodal/clip_score.py +++ b/src/torchmetrics/functional/multimodal/clip_score.py @@ -153,8 +153,6 @@ def clip_score( If the number of images and captions do not match Example: - >>> import torch - >>> _ = torch.manual_seed(42) >>> from torchmetrics.functional.multimodal import clip_score >>> score = clip_score(torch.randint(255, (3, 224, 224)), "a photo of a cat", "openai/clip-vit-base-patch16") >>> score.detach() diff --git a/src/torchmetrics/functional/nominal/__init__.py b/src/torchmetrics/functional/nominal/__init__.py index f29dd9302f0..772cb395895 100644 --- a/src/torchmetrics/functional/nominal/__init__.py +++ b/src/torchmetrics/functional/nominal/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix from torchmetrics.functional.nominal.fleiss_kappa import fleiss_kappa from torchmetrics.functional.nominal.pearson import ( diff --git a/src/torchmetrics/functional/nominal/cramers.py b/src/torchmetrics/functional/nominal/cramers.py index 75a606f6c19..33b89b92014 100644 --- a/src/torchmetrics/functional/nominal/cramers.py +++ b/src/torchmetrics/functional/nominal/cramers.py @@ -124,10 +124,10 @@ def cramers_v( Cramer's V statistic Example: + >>> from torch import randint, round >>> from torchmetrics.functional.nominal import cramers_v - >>> _ = torch.manual_seed(42) - >>> preds = torch.randint(0, 4, (100,)) - >>> target = torch.round(preds + torch.randn(100)).clamp(0, 4) + >>> preds = randint(0, 4, (100,)) + >>> target = round(preds + torch.randn(100)).clamp(0, 4) >>> cramers_v(preds, target) tensor(0.5284) @@ -161,9 +161,9 @@ def cramers_v_matrix( Cramer's V statistic for a dataset of categorical variables Example: + >>> from torch import randint >>> from torchmetrics.functional.nominal import cramers_v_matrix - >>> _ = torch.manual_seed(42) - >>> matrix = torch.randint(0, 4, (200, 5)) + >>> matrix = randint(0, 4, (200, 5)) >>> cramers_v_matrix(matrix) tensor([[1.0000, 0.0637, 0.0000, 0.0542, 0.1337], [0.0637, 1.0000, 0.0000, 0.0000, 0.0000], diff --git a/src/torchmetrics/functional/nominal/fleiss_kappa.py b/src/torchmetrics/functional/nominal/fleiss_kappa.py index f33e11d07ef..69990f552d8 100644 --- a/src/torchmetrics/functional/nominal/fleiss_kappa.py +++ b/src/torchmetrics/functional/nominal/fleiss_kappa.py @@ -78,21 +78,19 @@ def fleiss_kappa(ratings: Tensor, mode: Literal["counts", "probs"] = "counts") - Example: >>> # Ratings are provided as counts - >>> import torch + >>> from torch import randint >>> from torchmetrics.functional.nominal import fleiss_kappa - >>> _ = torch.manual_seed(42) - >>> ratings = torch.randint(0, 10, size=(100, 5)).long() # 100 samples, 5 categories, 10 raters + >>> ratings = randint(0, 10, size=(100, 5)).long() # 100 samples, 5 categories, 10 raters >>> fleiss_kappa(ratings) tensor(0.0089) Example: >>> # Ratings are provided as probabilities - >>> import torch + >>> from torch import randn >>> from torchmetrics.functional.nominal import fleiss_kappa - >>> _ = torch.manual_seed(42) - >>> ratings = torch.randn(100, 5, 10).softmax(dim=1) # 100 samples, 5 categories, 10 raters + >>> ratings = randn(100, 5, 10).softmax(dim=1) # 100 samples, 5 categories, 10 raters >>> fleiss_kappa(ratings, mode='probs') - tensor(-0.0105) + tensor(-0.0075) """ if mode not in ["counts", "probs"]: diff --git a/src/torchmetrics/functional/nominal/pearson.py b/src/torchmetrics/functional/nominal/pearson.py index bd25c701fe4..55fe1681bf7 100644 --- a/src/torchmetrics/functional/nominal/pearson.py +++ b/src/torchmetrics/functional/nominal/pearson.py @@ -114,10 +114,10 @@ def pearsons_contingency_coefficient( Pearson's Contingency Coefficient Example: + >>> from torch import randint, round >>> from torchmetrics.functional.nominal import pearsons_contingency_coefficient - >>> _ = torch.manual_seed(42) - >>> preds = torch.randint(0, 4, (100,)) - >>> target = torch.round(preds + torch.randn(100)).clamp(0, 4) + >>> preds = randint(0, 4, (100,)) + >>> target = round(preds + torch.randn(100)).clamp(0, 4) >>> pearsons_contingency_coefficient(preds, target) tensor(0.6948) @@ -151,9 +151,9 @@ def pearsons_contingency_coefficient_matrix( Pearson's Contingency Coefficient statistic for a dataset of categorical variables Example: + >>> from torch import randint >>> from torchmetrics.functional.nominal import pearsons_contingency_coefficient_matrix - >>> _ = torch.manual_seed(42) - >>> matrix = torch.randint(0, 4, (200, 5)) + >>> matrix = randint(0, 4, (200, 5)) >>> pearsons_contingency_coefficient_matrix(matrix) tensor([[1.0000, 0.2326, 0.1959, 0.2262, 0.2989], [0.2326, 1.0000, 0.1386, 0.1895, 0.1329], diff --git a/src/torchmetrics/functional/nominal/theils_u.py b/src/torchmetrics/functional/nominal/theils_u.py index 8bdaf38aa8a..f356dbfd03d 100644 --- a/src/torchmetrics/functional/nominal/theils_u.py +++ b/src/torchmetrics/functional/nominal/theils_u.py @@ -138,10 +138,10 @@ def theils_u( Tensor containing Theil's U statistic Example: + >>> from torch import randint >>> from torchmetrics.functional.nominal import theils_u - >>> _ = torch.manual_seed(42) - >>> preds = torch.randint(10, (10,)) - >>> target = torch.randint(10, (10,)) + >>> preds = randint(10, (10,)) + >>> target = randint(10, (10,)) >>> theils_u(preds, target) tensor(0.8530) @@ -172,9 +172,9 @@ def theils_u_matrix( Theil's U statistic for a dataset of categorical variables Example: + >>> from torch import randint >>> from torchmetrics.functional.nominal import theils_u_matrix - >>> _ = torch.manual_seed(42) - >>> matrix = torch.randint(0, 4, (200, 5)) + >>> matrix = randint(0, 4, (200, 5)) >>> theils_u_matrix(matrix) tensor([[1.0000, 0.0202, 0.0142, 0.0196, 0.0353], [0.0202, 1.0000, 0.0070, 0.0136, 0.0065], diff --git a/src/torchmetrics/functional/nominal/tschuprows.py b/src/torchmetrics/functional/nominal/tschuprows.py index 2ea20d57f19..22d256d33d1 100644 --- a/src/torchmetrics/functional/nominal/tschuprows.py +++ b/src/torchmetrics/functional/nominal/tschuprows.py @@ -130,10 +130,10 @@ def tschuprows_t( Tschuprow's T statistic Example: + >>> from torch import randint, round >>> from torchmetrics.functional.nominal import tschuprows_t - >>> _ = torch.manual_seed(42) - >>> preds = torch.randint(0, 4, (100,)) - >>> target = torch.round(preds + torch.randn(100)).clamp(0, 4) + >>> preds = randint(0, 4, (100,)) + >>> target = round(preds + torch.randn(100)).clamp(0, 4) >>> tschuprows_t(preds, target) tensor(0.4930) @@ -169,9 +169,9 @@ def tschuprows_t_matrix( Tschuprow's T statistic for a dataset of categorical variables Example: + >>> from torch import randint >>> from torchmetrics.functional.nominal import tschuprows_t_matrix - >>> _ = torch.manual_seed(42) - >>> matrix = torch.randint(0, 4, (200, 5)) + >>> matrix = randint(0, 4, (200, 5)) >>> tschuprows_t_matrix(matrix) tensor([[1.0000, 0.0637, 0.0000, 0.0542, 0.1337], [0.0637, 1.0000, 0.0000, 0.0000, 0.0000], diff --git a/src/torchmetrics/functional/nominal/utils.py b/src/torchmetrics/functional/nominal/utils.py index 258209326da..8c8cc166778 100644 --- a/src/torchmetrics/functional/nominal/utils.py +++ b/src/torchmetrics/functional/nominal/utils.py @@ -62,18 +62,18 @@ def _drop_empty_rows_and_cols(confmat: Tensor) -> Tensor: """Drop all rows and columns containing only zeros. Example: - >>> import torch + >>> from torch import randint >>> from torchmetrics.functional.nominal.utils import _drop_empty_rows_and_cols - >>> _ = torch.manual_seed(22) - >>> matrix = torch.randint(10, size=(3, 3)) + >>> matrix = randint(10, size=(4, 3)) >>> matrix[1, :] = matrix[:, 1] = 0 >>> matrix - tensor([[9, 0, 6], + tensor([[2, 0, 6], [0, 0, 0], - [2, 0, 8]]) + [0, 0, 0], + [3, 0, 4]]) >>> _drop_empty_rows_and_cols(matrix) - tensor([[9, 6], - [2, 8]]) + tensor([[2, 6], + [3, 4]]) """ confmat = confmat[confmat.sum(1) != 0] diff --git a/src/torchmetrics/functional/regression/__init__.py b/src/torchmetrics/functional/regression/__init__.py index c2dab8c5f59..063fbc059e3 100644 --- a/src/torchmetrics/functional/regression/__init__.py +++ b/src/torchmetrics/functional/regression/__init__.py @@ -23,6 +23,7 @@ from torchmetrics.functional.regression.mape import mean_absolute_percentage_error from torchmetrics.functional.regression.minkowski import minkowski_distance from torchmetrics.functional.regression.mse import mean_squared_error +from torchmetrics.functional.regression.nrmse import normalized_root_mean_squared_error from torchmetrics.functional.regression.pearson import pearson_corrcoef from torchmetrics.functional.regression.r2 import r2_score from torchmetrics.functional.regression.rse import relative_squared_error @@ -39,13 +40,14 @@ "kendall_rank_corrcoef", "kl_divergence", "log_cosh_error", - "mean_squared_log_error", "mean_absolute_error", - "mean_squared_error", - "pearson_corrcoef", "mean_absolute_percentage_error", "mean_absolute_percentage_error", + "mean_squared_error", + "mean_squared_log_error", "minkowski_distance", + "normalized_root_mean_squared_error", + "pearson_corrcoef", "r2_score", "relative_squared_error", "spearman_corrcoef", diff --git a/src/torchmetrics/functional/regression/concordance.py b/src/torchmetrics/functional/regression/concordance.py index e18afe2f619..501cf8da054 100644 --- a/src/torchmetrics/functional/regression/concordance.py +++ b/src/torchmetrics/functional/regression/concordance.py @@ -27,6 +27,8 @@ def _concordance_corrcoef_compute( ) -> Tensor: """Compute the final concordance correlation coefficient based on accumulated statistics.""" pearson = _pearson_corrcoef_compute(var_x, var_y, corr_xy, nb) + var_x = var_x / (nb - 1) + var_y = var_y / (nb - 1) return 2.0 * pearson * var_x.sqrt() * var_y.sqrt() / (var_x + var_y + (mean_x - mean_y) ** 2) diff --git a/src/torchmetrics/functional/regression/kl_divergence.py b/src/torchmetrics/functional/regression/kl_divergence.py index 43fed914a59..6e6563aee71 100644 --- a/src/torchmetrics/functional/regression/kl_divergence.py +++ b/src/torchmetrics/functional/regression/kl_divergence.py @@ -20,7 +20,6 @@ from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.compute import _safe_xlogy -from torchmetrics.utilities.prints import rank_zero_warn def _kld_update(p: Tensor, q: Tensor, log_prob: bool) -> Tuple[Tensor, int]: @@ -92,14 +91,6 @@ def kl_divergence( over data and :math:`Q` is often a prior or approximation of :math:`P`. It should be noted that the KL divergence is a non-symmetrical metric i.e. :math:`D_{KL}(P||Q) \neq D_{KL}(Q||P)`. - .. warning:: - The input order and naming in metric ``kl_divergence`` is set to be deprecated in v1.4 and changed in v1.5. - Input argument ``p`` will be renamed to ``target`` and will be moved to be the second argument of the metric. - Input argument ``q`` will be renamed to ``preds`` and will be moved to the first argument of the metric. - Thus, ``kl_divergence(p, q)`` will equal ``kl_divergence(target=q, preds=p)`` in the future to be consistent - with the rest of ``torchmetrics``. From v1.4 the two new arguments will be added as keyword arguments and - from v1.5 the two old arguments will be removed. - Args: p: data distribution with shape ``[N, d]`` q: prior or approximate distribution with shape ``[N, d]`` @@ -120,14 +111,5 @@ def kl_divergence( tensor(0.0853) """ - rank_zero_warn( - "The input order and naming in metric `kl_divergence` is set to be deprecated in v1.4 and changed in v1.5." - "Input argument `p` will be renamed to `target` and will be moved to be the second argument of the metric." - "Input argument `q` will be renamed to `preds` and will be moved to the first argument of the metric." - "Thus, `kl_divergence(p, q)` will equal `kl_divergence(target=q, preds=p)` in the future to be consistent with" - " the rest of torchmetrics. From v1.4 the two new arguments will be added as keyword arguments and from v1.5" - " the two old arguments will be removed.", - DeprecationWarning, - ) measures, total = _kld_update(p, q, log_prob) return _kld_compute(measures, total, reduction) diff --git a/src/torchmetrics/functional/regression/nrmse.py b/src/torchmetrics/functional/regression/nrmse.py new file mode 100644 index 00000000000..52cae36adb0 --- /dev/null +++ b/src/torchmetrics/functional/regression/nrmse.py @@ -0,0 +1,106 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.regression.mse import _mean_squared_error_update + + +def _normalized_root_mean_squared_error_update( + preds: Tensor, target: Tensor, num_outputs: int, normalization: Literal["mean", "range", "std", "l2"] = "mean" +) -> Tuple[Tensor, int, Tensor]: + """Updates and returns the sum of squared errors and the number of observations for NRMSE computation. + + Args: + preds: Predicted tensor + target: Ground truth tensor + num_outputs: Number of outputs in multioutput setting + normalization: type of normalization to be applied. Choose from "mean", "range", "std", "l2" + + """ + sum_squared_error, num_obs = _mean_squared_error_update(preds, target, num_outputs) + + target = target.view(-1) if num_outputs == 1 else target + if normalization == "mean": + denom = torch.mean(target, dim=0) + elif normalization == "range": + denom = torch.max(target, dim=0).values - torch.min(target, dim=0).values + elif normalization == "std": + denom = torch.std(target, correction=0, dim=0) + elif normalization == "l2": + denom = torch.norm(target, p=2, dim=0) + else: + raise ValueError( + f"Argument `normalization` should be either 'mean', 'range', 'std' or 'l2' but got {normalization}" + ) + return sum_squared_error, num_obs, denom + + +def _normalized_root_mean_squared_error_compute( + sum_squared_error: Tensor, num_obs: Union[int, Tensor], denom: Tensor +) -> Tensor: + """Calculates RMSE and normalizes it.""" + rmse = torch.sqrt(sum_squared_error / num_obs) + return rmse / denom + + +def normalized_root_mean_squared_error( + preds: Tensor, + target: Tensor, + normalization: Literal["mean", "range", "std", "l2"] = "mean", + num_outputs: int = 1, +) -> Tensor: + """Calculates the `Normalized Root Mean Squared Error`_ (NRMSE) also know as scatter index. + + Args: + preds: estimated labels + target: ground truth labels + normalization: type of normalization to be applied. Choose from "mean", "range", "std", "l2" which corresponds + to normalizing the RMSE by the mean of the target, the range of the target, the standard deviation of the + target or the L2 norm of the target. + num_outputs: Number of outputs in multioutput setting + + Return: + Tensor with the NRMSE score + + Example: + >>> import torch + >>> from torchmetrics.functional.regression import normalized_root_mean_squared_error + >>> preds = torch.tensor([0., 1, 2, 3]) + >>> target = torch.tensor([0., 1, 2, 2]) + >>> normalized_root_mean_squared_error(preds, target, normalization="mean") + tensor(0.4000) + >>> normalized_root_mean_squared_error(preds, target, normalization="range") + tensor(0.2500) + >>> normalized_root_mean_squared_error(preds, target, normalization="std") + tensor(0.6030) + >>> normalized_root_mean_squared_error(preds, target, normalization="l2") + tensor(0.1667) + + Example (multioutput): + >>> import torch + >>> from torchmetrics.functional.regression import normalized_root_mean_squared_error + >>> preds = torch.tensor([[0., 1], [2, 3], [4, 5], [6, 7]]) + >>> target = torch.tensor([[0., 1], [3, 3], [4, 5], [8, 9]]) + >>> normalized_root_mean_squared_error(preds, target, normalization="mean", num_outputs=2) + tensor([0.2981, 0.2222]) + + """ + sum_squared_error, num_obs, denom = _normalized_root_mean_squared_error_update( + preds, target, num_outputs=num_outputs, normalization=normalization + ) + return _normalized_root_mean_squared_error_compute(sum_squared_error, num_obs, denom) diff --git a/src/torchmetrics/functional/regression/pearson.py b/src/torchmetrics/functional/regression/pearson.py index c98bc65a85c..47b26344163 100644 --- a/src/torchmetrics/functional/regression/pearson.py +++ b/src/torchmetrics/functional/regression/pearson.py @@ -92,9 +92,10 @@ def _pearson_corrcoef_compute( nb: number of observations """ - var_x /= nb - 1 - var_y /= nb - 1 - corr_xy /= nb - 1 + # prevent overwrite the inputs + var_x = var_x / (nb - 1) + var_y = var_y / (nb - 1) + corr_xy = corr_xy / (nb - 1) # if var_x, var_y is float16 and on cpu, make it bfloat16 as sqrt is not supported for float16 # on cpu, remove this after https://github.com/pytorch/pytorch/issues/54774 is fixed if var_x.dtype == torch.float16 and var_x.device == torch.device("cpu"): diff --git a/src/torchmetrics/functional/regression/wmape.py b/src/torchmetrics/functional/regression/wmape.py index e461951c695..443badba325 100644 --- a/src/torchmetrics/functional/regression/wmape.py +++ b/src/torchmetrics/functional/regression/wmape.py @@ -74,10 +74,9 @@ def weighted_mean_absolute_percentage_error(preds: Tensor, target: Tensor) -> Te Tensor with WMAPE. Example: - >>> import torch - >>> _ = torch.manual_seed(42) - >>> preds = torch.randn(20,) - >>> target = torch.randn(20,) + >>> from torch import randn + >>> preds = randn(20,) + >>> target = randn(20,) >>> weighted_mean_absolute_percentage_error(preds, target) tensor(1.3967) diff --git a/src/torchmetrics/functional/segmentation/__init__.py b/src/torchmetrics/functional/segmentation/__init__.py index 3d23192a36a..068bf77d775 100644 --- a/src/torchmetrics/functional/segmentation/__init__.py +++ b/src/torchmetrics/functional/segmentation/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score +from torchmetrics.functional.segmentation.hausdorff_distance import hausdorff_distance from torchmetrics.functional.segmentation.mean_iou import mean_iou -__all__ = ["generalized_dice_score", "mean_iou"] +__all__ = ["generalized_dice_score", "mean_iou", "hausdorff_distance"] diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index 04f28584b10..47c5f30964b 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -125,11 +125,10 @@ def generalized_dice_score( The Generalized Dice Score Example: - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import randint >>> from torchmetrics.functional.segmentation import generalized_dice_score - >>> preds = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction - >>> target = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target >>> generalized_dice_score(preds, target, num_classes=5) tensor([0.4830, 0.4935, 0.5044, 0.4880]) >>> generalized_dice_score(preds, target, num_classes=5, per_class=True) diff --git a/src/torchmetrics/functional/segmentation/hausdorff_distance.py b/src/torchmetrics/functional/segmentation/hausdorff_distance.py new file mode 100644 index 00000000000..daadc90f6ba --- /dev/null +++ b/src/torchmetrics/functional/segmentation/hausdorff_distance.py @@ -0,0 +1,114 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Literal, Optional, Union + +import torch +from torch import Tensor + +from torchmetrics.functional.segmentation.utils import ( + _ignore_background, + edge_surface_distance, +) +from torchmetrics.utilities.checks import _check_same_shape + + +def _hausdorff_distance_validate_args( + num_classes: int, + include_background: bool, + distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", + spacing: Optional[Union[Tensor, List[float]]] = None, + directed: bool = False, + input_format: Literal["one-hot", "index"] = "one-hot", +) -> None: + """Validate the arguments of `hausdorff_distance` function.""" + if num_classes <= 0: + raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.") + if not isinstance(include_background, bool): + raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.") + if distance_metric not in ["euclidean", "chessboard", "taxicab"]: + raise ValueError( + f"Arg `distance_metric` must be one of 'euclidean', 'chessboard', 'taxicab', but got {distance_metric}." + ) + if spacing is not None and not isinstance(spacing, (list, Tensor)): + raise ValueError(f"Arg `spacing` must be a list or tensor, but got {type(spacing)}.") + if not isinstance(directed, bool): + raise ValueError(f"Expected argument `directed` must be a boolean, but got {directed}.") + if input_format not in ["one-hot", "index"]: + raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.") + + +def hausdorff_distance( + preds: Tensor, + target: Tensor, + num_classes: int, + include_background: bool = False, + distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", + spacing: Optional[Union[Tensor, List[float]]] = None, + directed: bool = False, + input_format: Literal["one-hot", "index"] = "one-hot", +) -> Tensor: + """Calculate `Hausdorff Distance`_ for semantic segmentation. + + Args: + preds: predicted binarized segmentation map + target: target binarized segmentation map + num_classes: number of classes + include_background: whether to include background class in calculation + distance_metric: distance metric to calculate surface distance. Choose one of `"euclidean"`, + `"chessboard"` or `"taxicab"` + spacing: spacing between pixels along each spatial dimension. If not provided the spacing is assumed to be 1 + directed: whether to calculate directed or undirected Hausdorff distance + input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors + or ``"index"`` for index tensors + + Returns: + Hausdorff Distance for each class and batch element + + Example: + >>> from torch import randint + >>> from torchmetrics.functional.segmentation import hausdorff_distance + >>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> hausdorff_distance(preds, target, num_classes=5) + tensor([[2.0000, 1.4142, 2.0000, 2.0000], + [1.4142, 2.0000, 2.0000, 2.0000], + [2.0000, 2.0000, 1.4142, 2.0000], + [2.0000, 2.8284, 2.0000, 2.2361]]) + + """ + _hausdorff_distance_validate_args(num_classes, include_background, distance_metric, spacing, directed, input_format) + _check_same_shape(preds, target) + + if input_format == "index": + preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) + target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) + + if not include_background: + preds, target = _ignore_background(preds, target) + + distances = torch.zeros(preds.shape[0], preds.shape[1], device=preds.device) + + # TODO: add support for batched inputs + for b in range(preds.shape[0]): + for c in range(preds.shape[1]): + dist = edge_surface_distance( + preds=preds[b, c], + target=target[b, c], + distance_metric=distance_metric, + spacing=spacing, + symmetric=not directed, + ) + distances[b, c] = torch.max(dist) if directed else torch.max(dist[0].max(), dist[1].max()) # type: ignore + return distances diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index 278257d04b1..184e9578cab 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -97,11 +97,10 @@ def mean_iou( The mean IoU score Example: - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import randint >>> from torchmetrics.functional.segmentation import mean_iou - >>> preds = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction - >>> target = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target >>> mean_iou(preds, target, num_classes=5) tensor([0.3193, 0.3305, 0.3382, 0.3246]) >>> mean_iou(preds, target, num_classes=5, per_class=True) diff --git a/src/torchmetrics/functional/segmentation/utils.py b/src/torchmetrics/functional/segmentation/utils.py index 6c2fed92df2..59d42e16171 100644 --- a/src/torchmetrics/functional/segmentation/utils.py +++ b/src/torchmetrics/functional/segmentation/utils.py @@ -32,7 +32,7 @@ def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: def check_if_binarized(x: Tensor) -> None: - """Check if the input is binarized. + """Check if tensor is binarized. Example: >>> from torchmetrics.functional.segmentation.utils import check_if_binarized @@ -200,9 +200,8 @@ def distance_transform( Args: x: The binary tensor to calculate the distance transform of. - sampling: Only relevant when distance is calculated using the euclidean distance. The sampling refers to the - pixel spacing in the image, i.e. the distance between two adjacent pixels. If not provided, the pixel - spacing is assumed to be 1. + sampling: The sampling refers to the pixel spacing in the image, i.e. the distance between two adjacent pixels. + If not provided, the pixel spacing is assumed to be 1. metric: The distance to use for the distance transform. Can be one of ``"euclidean"``, ``"chessboard"`` or ``"taxicab"``. engine: The engine to use for the distance transform. Can be one of ``["pytorch", "scipy"]``. In general, @@ -249,25 +248,25 @@ def distance_transform( raise ValueError(f"Expected argument `sampling` to have length 2 but got length `{len(sampling)}`.") if engine == "pytorch": + x = x.float() # calculate distance from every foreground pixel to every background pixel i0, j0 = torch.where(x == 0) i1, j1 = torch.where(x == 1) - dis_row = (i1.unsqueeze(1) - i0.unsqueeze(0)).abs_().mul_(sampling[0]) - dis_col = (j1.unsqueeze(1) - j0.unsqueeze(0)).abs_().mul_(sampling[1]) + dis_row = (i1.view(-1, 1) - i0.view(1, -1)).abs() + dis_col = (j1.view(-1, 1) - j0.view(1, -1)).abs() # # calculate distance h, _ = x.shape if metric == "euclidean": - dis_row = dis_row.float() - dis_row.pow_(2).add_(dis_col.pow_(2)).sqrt_() + dis = ((sampling[0] * dis_row) ** 2 + (sampling[1] * dis_col) ** 2).sqrt() if metric == "chessboard": - dis_row = dis_row.max(dis_col) + dis = torch.max(sampling[0] * dis_row, sampling[1] * dis_col).float() if metric == "taxicab": - dis_row.add_(dis_col) + dis = (sampling[0] * dis_row + sampling[1] * dis_col).float() # select only the closest distance - mindis, _ = torch.min(dis_row, dim=1) - z = torch.zeros_like(x, dtype=mindis.dtype).view(-1) + mindis, _ = torch.min(dis, dim=1) + z = torch.zeros_like(x).view(-1) z[i1 * h + j1] = mindis return z.view(x.shape) @@ -279,7 +278,7 @@ def distance_transform( if metric == "euclidean": return ndimage.distance_transform_edt(x.cpu().numpy(), sampling) - return ndimage.distance_transform_cdt(x.cpu().numpy(), metric=metric) + return ndimage.distance_transform_cdt(x.cpu().numpy(), sampling, metric=metric) def mask_edges( @@ -390,6 +389,38 @@ def surface_distance( return dis[preds] +def edge_surface_distance( + preds: Tensor, + target: Tensor, + distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", + spacing: Optional[Union[Tensor, List[float]]] = None, + symmetric: bool = False, +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Extracts the edges from the input masks and calculates the surface distance between them. + + Args: + preds: The predicted binary edge mask. + target: The target binary edge mask. + distance_metric: The distance metric to use. One of `["euclidean", "chessboard", "taxicab"]`. + spacing: The spacing between pixels along each spatial dimension. + symmetric: Whether to calculate the symmetric distance between the edges. + + Returns: + A tensor with length equal to the number of edges in predictions e.g. `preds.sum()`. Each element is the + distance from the corresponding edge in `preds` to the closest edge in `target`. If `symmetric` is `True`, the + function returns a tuple containing the distances from the predicted edges to the target edges and vice versa. + + """ + output = mask_edges(preds, target) + edges_preds, edges_target = output[0].bool(), output[1].bool() + if symmetric: + return ( + surface_distance(edges_preds, edges_target, distance_metric=distance_metric, spacing=spacing), + surface_distance(edges_target, edges_preds, distance_metric=distance_metric, spacing=spacing), + ) + return surface_distance(edges_preds, edges_target, distance_metric=distance_metric, spacing=spacing) + + @functools.lru_cache def get_neighbour_tables( spacing: Union[Tuple[int, int], Tuple[int, int, int]], device: Optional[torch.device] = None diff --git a/src/torchmetrics/functional/shape/__init__.py b/src/torchmetrics/functional/shape/__init__.py new file mode 100644 index 00000000000..7cf4118b053 --- /dev/null +++ b/src/torchmetrics/functional/shape/__init__.py @@ -0,0 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.functional.shape.procrustes import procrustes_disparity + +__all__ = ["procrustes_disparity"] diff --git a/src/torchmetrics/functional/shape/procrustes.py b/src/torchmetrics/functional/shape/procrustes.py new file mode 100644 index 00000000000..08068fd2454 --- /dev/null +++ b/src/torchmetrics/functional/shape/procrustes.py @@ -0,0 +1,66 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Union + +import torch +from torch import Tensor, linalg + +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.prints import rank_zero_warn + + +def procrustes_disparity( + point_cloud1: Tensor, point_cloud2: Tensor, return_all: bool = False +) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]: + """Runs procrustrus analysis on a batch of data points. + + Works similar ``scipy.spatial.procrustes`` but for batches of data points. + + Args: + point_cloud1: The first set of data points + point_cloud2: The second set of data points + return_all: If True, returns the scale and rotation matrices along with the disparity + + """ + _check_same_shape(point_cloud1, point_cloud2) + if point_cloud1.ndim != 3: + raise ValueError( + "Expected both datasets to be 3D tensors of shape (N, M, D), where N is the batch size, M is the number of" + f" data points and D is the dimensionality of the data points, but got {point_cloud1.ndim} dimensions." + ) + + point_cloud1 = point_cloud1 - point_cloud1.mean(dim=1, keepdim=True) + point_cloud2 = point_cloud2 - point_cloud2.mean(dim=1, keepdim=True) + point_cloud1 /= linalg.norm(point_cloud1, dim=[1, 2], keepdim=True) + point_cloud2 /= linalg.norm(point_cloud2, dim=[1, 2], keepdim=True) + + try: + u, w, v = linalg.svd( + torch.matmul(point_cloud2.transpose(1, 2), point_cloud1).transpose(1, 2), full_matrices=False + ) + except Exception as ex: + rank_zero_warn( + f"SVD calculation in procrustes_disparity failed with exception {ex}. Returning 0 disparity and identity" + " scale/rotation.", + UserWarning, + ) + return torch.tensor(0.0), torch.ones(point_cloud1.shape[0]), torch.eye(point_cloud1.shape[2]) + + rotation = torch.matmul(u, v) + scale = w.sum(1, keepdim=True) + point_cloud2 = scale[:, None] * torch.matmul(point_cloud2, rotation.transpose(1, 2)) + disparity = (point_cloud1 - point_cloud2).square().sum(dim=[1, 2]) + if return_all: + return disparity, scale, rotation + return disparity diff --git a/src/torchmetrics/functional/text/_deprecated.py b/src/torchmetrics/functional/text/_deprecated.py index fabfca2c0eb..169c3d5357b 100644 --- a/src/torchmetrics/functional/text/_deprecated.py +++ b/src/torchmetrics/functional/text/_deprecated.py @@ -245,10 +245,9 @@ def _match_error_rate(preds: Union[str, List[str]], target: Union[str, List[str] def _perplexity(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> Tensor: """Wrapper for deprecated import. - >>> import torch - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand(2, 8, 5, generator=gen) - >>> target = torch.randint(5, (2, 8), generator=gen) + >>> from torch import rand, randint + >>> preds = rand(2, 8, 5) + >>> target = randint(5, (2, 8)) >>> target[0, 6:] = -100 >>> _perplexity(preds, target, ignore_index=-100) tensor(5.8540) diff --git a/src/torchmetrics/functional/text/bert.py b/src/torchmetrics/functional/text/bert.py index cfdb8c743b4..71bec857a72 100644 --- a/src/torchmetrics/functional/text/bert.py +++ b/src/torchmetrics/functional/text/bert.py @@ -276,6 +276,7 @@ def bert_score( rescale_with_baseline: bool = False, baseline_path: Optional[str] = None, baseline_url: Optional[str] = None, + truncation: bool = False, ) -> Dict[str, Union[Tensor, List[float], str]]: """`Bert_score Evaluating Text Generation`_ for text similirity matching. @@ -323,6 +324,7 @@ def bert_score( of the files from `BERT_score`_ baseline_path: A path to the user's own local csv/tsv file with the baseline scale. baseline_url: A url path to the user's own csv/tsv file with the baseline scale. + truncation: An indication of whether the input sequences should be truncated to the maximum length. Returns: Python dictionary containing the keys ``precision``, ``recall`` and ``f1`` with corresponding values. @@ -349,11 +351,16 @@ def bert_score( """ if len(preds) != len(target): - raise ValueError("Number of predicted and reference sententes must be the same!") + raise ValueError( + "Expected number of predicted and reference sententes to be the same, but got" + f"{len(preds)} and {len(target)}" + ) if not isinstance(preds, (str, list, dict)): # dict for BERTScore class compute call preds = list(preds) if not isinstance(target, (str, list, dict)): # dict for BERTScore class compute call target = list(target) + if not isinstance(idf, bool): + raise ValueError(f"Expected argument `idf` to be a boolean, but got {idf}.") if verbose and (not _TQDM_AVAILABLE): raise ModuleNotFoundError( @@ -412,13 +419,14 @@ def bert_score( # We ignore mypy typing below as the proper typing is ensured by conditions above, only mypy cannot infer that. if _are_valid_lists: - target_dataset = TextDataset(target, tokenizer, max_length, idf=idf) # type: ignore + target_dataset = TextDataset(target, tokenizer, max_length, idf=idf, truncation=truncation) # type: ignore preds_dataset = TextDataset( preds, # type: ignore tokenizer, max_length, idf=idf, tokens_idf=target_dataset.tokens_idf, + truncation=truncation, ) elif _are_valid_tensors: target_dataset = TokenizedDataset(**target, idf=idf) # type: ignore diff --git a/src/torchmetrics/functional/text/chrf.py b/src/torchmetrics/functional/text/chrf.py index 375355b85cb..ca98778fade 100644 --- a/src/torchmetrics/functional/text/chrf.py +++ b/src/torchmetrics/functional/text/chrf.py @@ -11,26 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# referenced from -# Library Name: torchtext -# Authors: torchtext authors -# Date: 2021-11-25 -# Link: # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Copyright 2017 Maja Popovic - -# The program is distributed under the terms -# of the GNU General Public Licence (GPL) - -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. - -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . +# The code is derived from https://github.com/m-popovic/chrF/blob/6d3c384/chrF%2B%2B.py +# The original author and copyright holder have agreed to relicense the derived code under the Apache License, +# Version 2.0 (the "License") +# Reference to the approval: https://github.com/Lightning-AI/torchmetrics/pull/2701#issuecomment-2316891785 from collections import defaultdict from itertools import chain diff --git a/src/torchmetrics/functional/text/helper_embedding_metric.py b/src/torchmetrics/functional/text/helper_embedding_metric.py index 1ab911b9395..f2b59126c7d 100644 --- a/src/torchmetrics/functional/text/helper_embedding_metric.py +++ b/src/torchmetrics/functional/text/helper_embedding_metric.py @@ -195,10 +195,11 @@ def __init__( tokenizer: Any, max_length: int = 512, preprocess_text_fn: Callable[ - [List[str], Any, int], Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], Optional[Tensor]]] + [List[str], Any, int, bool], Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], Optional[Tensor]]] ] = _preprocess_text, idf: bool = False, tokens_idf: Optional[Dict[int, float]] = None, + truncation: bool = False, ) -> None: """Initialize text dataset class. @@ -209,9 +210,10 @@ def __init__( preprocess_text_fn: A function used for processing the input sentences. idf: An indication of whether calculate token inverse document frequencies to weight the model embeddings. tokens_idf: Inverse document frequencies (these should be calculated on reference sentences). + truncation: An indication of whether tokenized sequences should be padded only to the length of the longest """ - _text = preprocess_text_fn(text, tokenizer, max_length) + _text = preprocess_text_fn(text, tokenizer, max_length, truncation) if isinstance(_text, tuple): self.text, self.sorting_indices = _text else: diff --git a/src/torchmetrics/functional/text/perplexity.py b/src/torchmetrics/functional/text/perplexity.py index 39f832905cf..1561ffa412a 100644 --- a/src/torchmetrics/functional/text/perplexity.py +++ b/src/torchmetrics/functional/text/perplexity.py @@ -130,10 +130,9 @@ def perplexity(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None Perplexity value Examples: - >>> import torch - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand(2, 8, 5, generator=gen) - >>> target = torch.randint(5, (2, 8), generator=gen) + >>> from torch import rand, randint + >>> preds = rand(2, 8, 5) + >>> target = randint(5, (2, 8)) >>> target[0, 6:] = -100 >>> perplexity(preds, target, ignore_index=-100) tensor(5.8540) diff --git a/src/torchmetrics/functional/text/rouge.py b/src/torchmetrics/functional/text/rouge.py index baa0b6a018c..58c9a05fecf 100644 --- a/src/torchmetrics/functional/text/rouge.py +++ b/src/torchmetrics/functional/text/rouge.py @@ -48,10 +48,10 @@ def _ensure_nltk_punkt_is_downloaded() -> None: import nltk try: - nltk.data.find("tokenizers/punkt") + nltk.data.find("tokenizers/punkt_tab") except LookupError: try: - nltk.download("punkt", quiet=True, force=False, halt_on_error=False, raise_on_error=True) + nltk.download("punkt_tab", quiet=True, force=False, halt_on_error=False, raise_on_error=True) except ValueError as err: raise OSError( "`nltk` resource `punkt` is not available on a disk and cannot be downloaded as a machine is not " diff --git a/src/torchmetrics/image/_deprecated.py b/src/torchmetrics/image/_deprecated.py index 597f0a0c636..8b382b89cf7 100644 --- a/src/torchmetrics/image/_deprecated.py +++ b/src/torchmetrics/image/_deprecated.py @@ -17,11 +17,11 @@ class _ErrorRelativeGlobalDimensionlessSynthesis(ErrorRelativeGlobalDimensionlessSynthesis): """Wrapper for deprecated import. - >>> import torch - >>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) + >>> from torch import rand + >>> preds = rand([16, 1, 16, 16]) >>> target = preds * 0.75 >>> ergas = _ErrorRelativeGlobalDimensionlessSynthesis() - >>> torch.round(ergas(preds, target)) + >>> ergas(preds, target).round() tensor(10.) """ @@ -39,12 +39,12 @@ def __init__( class _MultiScaleStructuralSimilarityIndexMeasure(MultiScaleStructuralSimilarityIndexMeasure): """Wrapper for deprecated import. - >>> import torch - >>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) + >>> from torch import rand + >>> preds = rand([3, 3, 256, 256]) >>> target = preds * 0.75 >>> ms_ssim = _MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) >>> ms_ssim(preds, target) - tensor(0.9627) + tensor(0.9628) """ @@ -103,13 +103,12 @@ def __init__( class _RelativeAverageSpectralError(RelativeAverageSpectralError): """Wrapper for deprecated import. - >>> import torch - >>> g = torch.manual_seed(22) - >>> preds = torch.rand(4, 3, 16, 16) - >>> target = torch.rand(4, 3, 16, 16) + >>> from torch import rand + >>> preds = rand(4, 3, 16, 16) + >>> target = rand(4, 3, 16, 16) >>> rase = _RelativeAverageSpectralError() >>> rase(preds, target) - tensor(5114.66...) + tensor(5326.40...) """ @@ -125,13 +124,12 @@ def __init__( class _RootMeanSquaredErrorUsingSlidingWindow(RootMeanSquaredErrorUsingSlidingWindow): """Wrapper for deprecated import. - >>> import torch - >>> g = torch.manual_seed(22) - >>> preds = torch.rand(4, 3, 16, 16) - >>> target = torch.rand(4, 3, 16, 16) + >>> from torch import rand + >>> preds = rand(4, 3, 16, 16) + >>> target = rand(4, 3, 16, 16) >>> rmse_sw = RootMeanSquaredErrorUsingSlidingWindow() >>> rmse_sw(preds, target) - tensor(0.3999) + tensor(0.4158) """ @@ -147,10 +145,9 @@ def __init__( class _SpectralAngleMapper(SpectralAngleMapper): """Wrapper for deprecated import. - >>> import torch - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand([16, 3, 16, 16], generator=gen) - >>> target = torch.rand([16, 3, 16, 16], generator=gen) + >>> from torch import rand + >>> preds = rand([16, 3, 16, 16]) + >>> target = rand([16, 3, 16, 16]) >>> sam = _SpectralAngleMapper() >>> sam(preds, target) tensor(0.5914) @@ -169,10 +166,9 @@ def __init__( class _SpectralDistortionIndex(SpectralDistortionIndex): """Wrapper for deprecated import. - >>> import torch - >>> _ = torch.manual_seed(42) - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) + >>> from torch import rand + >>> preds = rand([16, 3, 16, 16]) + >>> target = rand([16, 3, 16, 16]) >>> sdi = _SpectralDistortionIndex() >>> sdi(preds, target) tensor(0.0234) @@ -229,10 +225,9 @@ def __init__( class _TotalVariation(TotalVariation): """Wrapper for deprecated import. - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import rand >>> tv = _TotalVariation() - >>> img = torch.rand(5, 3, 28, 28) + >>> img = rand(5, 3, 28, 28) >>> tv(img) tensor(7546.8018) diff --git a/src/torchmetrics/image/d_lambda.py b/src/torchmetrics/image/d_lambda.py index ff232569748..5b1e58de0dd 100644 --- a/src/torchmetrics/image/d_lambda.py +++ b/src/torchmetrics/image/d_lambda.py @@ -53,11 +53,10 @@ class SpectralDistortionIndex(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import rand >>> from torchmetrics.image import SpectralDistortionIndex - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) + >>> preds = rand([16, 3, 16, 16]) + >>> target = rand([16, 3, 16, 16]) >>> sdi = SpectralDistortionIndex() >>> sdi(preds, target) tensor(0.0234) @@ -126,11 +125,10 @@ def plot( :scale: 75 >>> # Example plotting a single value - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import rand >>> from torchmetrics.image import SpectralDistortionIndex - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) + >>> preds = rand([16, 3, 16, 16]) + >>> target = rand([16, 3, 16, 16]) >>> metric = SpectralDistortionIndex() >>> metric.update(preds, target) >>> fig_, ax_ = metric.plot() @@ -139,11 +137,10 @@ def plot( :scale: 75 >>> # Example plotting multiple values - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import rand >>> from torchmetrics.image import SpectralDistortionIndex - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) + >>> preds = rand([16, 3, 16, 16]) + >>> target = rand([16, 3, 16, 16]) >>> metric = SpectralDistortionIndex() >>> values = [ ] >>> for _ in range(10): diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py index 434561c558a..a8dacbaf447 100644 --- a/src/torchmetrics/image/d_s.py +++ b/src/torchmetrics/image/d_s.py @@ -74,13 +74,12 @@ class SpatialDistortionIndex(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import rand >>> from torchmetrics.image import SpatialDistortionIndex - >>> preds = torch.rand([16, 3, 32, 32]) + >>> preds = rand([16, 3, 32, 32]) >>> target = { - ... 'ms': torch.rand([16, 3, 16, 16]), - ... 'pan': torch.rand([16, 3, 32, 32]), + ... 'ms': rand([16, 3, 16, 16]), + ... 'pan': rand([16, 3, 32, 32]), ... } >>> sdi = SpatialDistortionIndex() >>> sdi(preds, target) @@ -191,13 +190,12 @@ def plot( :scale: 75 >>> # Example plotting a single value - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import rand >>> from torchmetrics.image import SpatialDistortionIndex - >>> preds = torch.rand([16, 3, 32, 32]) + >>> preds = rand([16, 3, 32, 32]) >>> target = { - ... 'ms': torch.rand([16, 3, 16, 16]), - ... 'pan': torch.rand([16, 3, 32, 32]), + ... 'ms': rand([16, 3, 16, 16]), + ... 'pan': rand([16, 3, 32, 32]), ... } >>> metric = SpatialDistortionIndex() >>> metric.update(preds, target) @@ -207,13 +205,12 @@ def plot( :scale: 75 >>> # Example plotting multiple values - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import rand >>> from torchmetrics.image import SpatialDistortionIndex - >>> preds = torch.rand([16, 3, 32, 32]) + >>> preds = rand([16, 3, 32, 32]) >>> target = { - ... 'ms': torch.rand([16, 3, 16, 16]), - ... 'pan': torch.rand([16, 3, 32, 32]), + ... 'ms': rand([16, 3, 16, 16]), + ... 'pan': rand([16, 3, 32, 32]), ... } >>> metric = SpatialDistortionIndex() >>> values = [ ] diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index f1d8da2073e..bf6b1c99a10 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -62,12 +62,12 @@ class ErrorRelativeGlobalDimensionlessSynthesis(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: - >>> import torch + >>> from torch import rand >>> from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis - >>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) + >>> preds = rand([16, 1, 16, 16]) >>> target = preds * 0.75 >>> ergas = ErrorRelativeGlobalDimensionlessSynthesis() - >>> torch.round(ergas(preds, target)) + >>> ergas(preds, target).round() tensor(10.) """ @@ -131,9 +131,9 @@ def plot( :scale: 75 >>> # Example plotting a single value - >>> import torch + >>> from torch import rand >>> from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis - >>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) + >>> preds = rand([16, 1, 16, 16]) >>> target = preds * 0.75 >>> metric = ErrorRelativeGlobalDimensionlessSynthesis() >>> metric.update(preds, target) @@ -143,9 +143,9 @@ def plot( :scale: 75 >>> # Example plotting multiple values - >>> import torch + >>> from torch import rand >>> from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis - >>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) + >>> preds = rand([16, 1, 16, 16]) >>> target = preds * 0.75 >>> metric = ErrorRelativeGlobalDimensionlessSynthesis() >>> values = [ ] diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index af4b93d7ad5..8c2e5d3cf76 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -265,8 +265,7 @@ class FrechetInceptionDistance(Metric): If ``reset_real_features`` is not an ``bool`` Example: - >>> import torch - >>> _ = torch.manual_seed(123) + >>> from torch import rand >>> from torchmetrics.image.fid import FrechetInceptionDistance >>> fid = FrechetInceptionDistance(feature=64) >>> # generate two slightly overlapping image intensity distributions @@ -275,7 +274,7 @@ class FrechetInceptionDistance(Metric): >>> fid.update(imgs_dist1, real=True) >>> fid.update(imgs_dist2, real=False) >>> fid.compute() - tensor(12.7202) + tensor(12.6388) """ diff --git a/src/torchmetrics/image/inception.py b/src/torchmetrics/image/inception.py index 4913a5fe50b..2b125542a17 100644 --- a/src/torchmetrics/image/inception.py +++ b/src/torchmetrics/image/inception.py @@ -84,15 +84,14 @@ class InceptionScore(Metric): If ``feature`` is not an ``str``, ``int`` or ``torch.nn.Module`` Example: - >>> import torch - >>> _ = torch.manual_seed(123) + >>> from torch import rand >>> from torchmetrics.image.inception import InceptionScore >>> inception = InceptionScore() >>> # generate some images >>> imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8) >>> inception.update(imgs) >>> inception.compute() - (tensor(1.0544), tensor(0.0117)) + (tensor(1.0549), tensor(0.0121)) """ diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py index e080116a33f..975edf800eb 100644 --- a/src/torchmetrics/image/kid.py +++ b/src/torchmetrics/image/kid.py @@ -148,17 +148,16 @@ class KernelInceptionDistance(Metric): If ``reset_real_features`` is not an ``bool`` Example: - >>> import torch - >>> _ = torch.manual_seed(123) + >>> from torch import randint >>> from torchmetrics.image.kid import KernelInceptionDistance >>> kid = KernelInceptionDistance(subset_size=50) >>> # generate two slightly overlapping image intensity distributions - >>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) - >>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) + >>> imgs_dist1 = randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) + >>> imgs_dist2 = randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) >>> kid.update(imgs_dist1, real=True) >>> kid.update(imgs_dist2, real=False) >>> kid.compute() - (tensor(0.0337), tensor(0.0023)) + (tensor(0.0312), tensor(0.0025)) """ diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index c094003fa4f..f1adb73648a 100644 --- a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -78,15 +78,14 @@ class LearnedPerceptualImagePatchSimilarity(Metric): If ``reduction`` is not one of ``"mean"`` or ``"sum"`` Example: - >>> import torch - >>> _ = torch.manual_seed(123) + >>> from torch import rand >>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity >>> lpips = LearnedPerceptualImagePatchSimilarity(net_type='squeeze') >>> # LPIPS needs the images to be in the [-1, 1] range. - >>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1 - >>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1 + >>> img1 = (rand(10, 3, 100, 100) * 2) - 1 + >>> img2 = (rand(10, 3, 100, 100) * 2) - 1 >>> lpips(img1, img2) - tensor(0.1046) + tensor(0.1024) """ diff --git a/src/torchmetrics/image/mifid.py b/src/torchmetrics/image/mifid.py index 8c1084e6458..7d7237b7416 100644 --- a/src/torchmetrics/image/mifid.py +++ b/src/torchmetrics/image/mifid.py @@ -129,13 +129,12 @@ class MemorizationInformedFrechetInceptionDistance(Metric): If ``reset_real_features`` is not an ``bool`` Example:: - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import randint >>> from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance >>> mifid = MemorizationInformedFrechetInceptionDistance(feature=64) >>> # generate two slightly overlapping image intensity distributions - >>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) - >>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) + >>> imgs_dist1 = randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) + >>> imgs_dist2 = randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) >>> mifid.update(imgs_dist1, real=True) >>> mifid.update(imgs_dist2, real=False) >>> mifid.compute() diff --git a/src/torchmetrics/image/perceptual_path_length.py b/src/torchmetrics/image/perceptual_path_length.py index 6b89b36b1c1..4c312404934 100644 --- a/src/torchmetrics/image/perceptual_path_length.py +++ b/src/torchmetrics/image/perceptual_path_length.py @@ -98,9 +98,7 @@ class PerceptualPathLength(Metric): If ``upper_discard`` is not a float between 0 and 1 or None. Example:: - >>> from torchmetrics.image import PerceptualPathLength >>> import torch - >>> _ = torch.manual_seed(42) >>> class DummyGenerator(torch.nn.Module): ... def __init__(self, z_size) -> None: ... super().__init__() @@ -112,10 +110,8 @@ class PerceptualPathLength(Metric): ... return torch.randn(num_samples, self.z_size) >>> generator = DummyGenerator(2) >>> ppl = PerceptualPathLength(num_samples=10) - >>> ppl(generator) # doctest: +SKIP - (tensor(0.2371), - tensor(0.1763), - tensor([0.3502, 0.1362, 0.2535, 0.0902, 0.1784, 0.0769, 0.5871, 0.0691, 0.3921])) + >>> ppl(generator) + (tensor(...), tensor(...), tensor([...])) """ diff --git a/src/torchmetrics/image/psnrb.py b/src/torchmetrics/image/psnrb.py index dc050e81fb8..bac58b84e46 100644 --- a/src/torchmetrics/image/psnrb.py +++ b/src/torchmetrics/image/psnrb.py @@ -51,12 +51,10 @@ class PeakSignalNoiseRatioWithBlockedEffect(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: - >>> import torch - >>> from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect + >>> from torch import rand >>> metric = PeakSignalNoiseRatioWithBlockedEffect() - >>> _ = torch.manual_seed(42) - >>> preds = torch.rand(2, 1, 10, 10) - >>> target = torch.rand(2, 1, 10, 10) + >>> preds = rand(2, 1, 10, 10) + >>> target = rand(2, 1, 10, 10) >>> metric(preds, target) tensor(7.2893) diff --git a/src/torchmetrics/image/qnr.py b/src/torchmetrics/image/qnr.py index 96023db8aab..b226a6b19fd 100644 --- a/src/torchmetrics/image/qnr.py +++ b/src/torchmetrics/image/qnr.py @@ -70,13 +70,12 @@ class QualityWithNoReference(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import rand >>> from torchmetrics.image import QualityWithNoReference - >>> preds = torch.rand([16, 3, 32, 32]) + >>> preds = rand([16, 3, 32, 32]) >>> target = { - ... 'ms': torch.rand([16, 3, 16, 16]), - ... 'pan': torch.rand([16, 3, 32, 32]), + ... 'ms': rand([16, 3, 16, 16]), + ... 'pan': rand([16, 3, 32, 32]), ... } >>> qnr = QualityWithNoReference() >>> qnr(preds, target) @@ -195,13 +194,12 @@ def plot( :scale: 75 >>> # Example plotting a single value - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import rand >>> from torchmetrics.image import QualityWithNoReference - >>> preds = torch.rand([16, 3, 32, 32]) + >>> preds = rand([16, 3, 32, 32]) >>> target = { - ... 'ms': torch.rand([16, 3, 16, 16]), - ... 'pan': torch.rand([16, 3, 32, 32]), + ... 'ms': rand([16, 3, 16, 16]), + ... 'pan': rand([16, 3, 32, 32]), ... } >>> metric = QualityWithNoReference() >>> metric.update(preds, target) @@ -211,13 +209,12 @@ def plot( :scale: 75 >>> # Example plotting multiple values - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import rand >>> from torchmetrics.image import QualityWithNoReference - >>> preds = torch.rand([16, 3, 32, 32]) + >>> preds = rand([16, 3, 32, 32]) >>> target = { - ... 'ms': torch.rand([16, 3, 16, 16]), - ... 'pan': torch.rand([16, 3, 32, 32]), + ... 'ms': rand([16, 3, 16, 16]), + ... 'pan': rand([16, 3, 32, 32]), ... } >>> metric = QualityWithNoReference() >>> values = [ ] diff --git a/src/torchmetrics/image/rase.py b/src/torchmetrics/image/rase.py index c422762eb68..b1eb32141a6 100644 --- a/src/torchmetrics/image/rase.py +++ b/src/torchmetrics/image/rase.py @@ -46,14 +46,12 @@ class RelativeAverageSpectralError(Metric): Relative Average Spectral Error (RASE) Example: - >>> import torch - >>> from torchmetrics.image import RelativeAverageSpectralError - >>> g = torch.manual_seed(22) - >>> preds = torch.rand(4, 3, 16, 16) - >>> target = torch.rand(4, 3, 16, 16) + >>> from torch import rand + >>> preds = rand(4, 3, 16, 16) + >>> target = rand(4, 3, 16, 16) >>> rase = RelativeAverageSpectralError() >>> rase(preds, target) - tensor(5114.66...) + tensor(5326.40...) Raises: ValueError: If ``window_size`` is not a positive integer. @@ -124,13 +122,12 @@ def plot( :scale: 75 >>> # Example plotting multiple values - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import rand >>> from torchmetrics.image import RelativeAverageSpectralError >>> metric = RelativeAverageSpectralError() >>> values = [ ] >>> for _ in range(10): - ... values.append(metric(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16))) + ... values.append(metric(rand(4, 3, 16, 16), rand(4, 3, 16, 16))) >>> fig_, ax_ = metric.plot(values) """ diff --git a/src/torchmetrics/image/rmse_sw.py b/src/torchmetrics/image/rmse_sw.py index e0838ec760d..c1f7c652879 100644 --- a/src/torchmetrics/image/rmse_sw.py +++ b/src/torchmetrics/image/rmse_sw.py @@ -43,13 +43,13 @@ class RootMeanSquaredErrorUsingSlidingWindow(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: + >>> from torch import rand >>> from torchmetrics.image import RootMeanSquaredErrorUsingSlidingWindow - >>> g = torch.manual_seed(22) - >>> preds = torch.rand(4, 3, 16, 16) - >>> target = torch.rand(4, 3, 16, 16) + >>> preds = rand(4, 3, 16, 16) + >>> target = rand(4, 3, 16, 16) >>> rmse_sw = RootMeanSquaredErrorUsingSlidingWindow() >>> rmse_sw(preds, target) - tensor(0.3999) + tensor(0.4158) Raises: ValueError: If ``window_size`` is not a positive integer. diff --git a/src/torchmetrics/image/sam.py b/src/torchmetrics/image/sam.py index 47d4839c5e2..7aad1782852 100644 --- a/src/torchmetrics/image/sam.py +++ b/src/torchmetrics/image/sam.py @@ -56,11 +56,10 @@ class SpectralAngleMapper(Metric): Tensor with SpectralAngleMapper score Example: - >>> import torch + >>> from torch import rand >>> from torchmetrics.image import SpectralAngleMapper - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand([16, 3, 16, 16], generator=gen) - >>> target = torch.rand([16, 3, 16, 16], generator=gen) + >>> preds = rand([16, 3, 16, 16]) + >>> target = rand([16, 3, 16, 16]) >>> sam = SpectralAngleMapper() >>> sam(preds, target) tensor(0.5914) @@ -141,11 +140,10 @@ def plot( :scale: 75 >>> # Example plotting single value - >>> import torch + >>> from torch import rand >>> from torchmetrics.image import SpectralAngleMapper - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand([16, 3, 16, 16], generator=gen) - >>> target = torch.rand([16, 3, 16, 16], generator=gen) + >>> preds = rand([16, 3, 16, 16]) + >>> target = rand([16, 3, 16, 16]) >>> metric = SpectralAngleMapper() >>> metric.update(preds, target) >>> fig_, ax_ = metric.plot() @@ -154,11 +152,10 @@ def plot( :scale: 75 >>> # Example plotting multiple values - >>> import torch + >>> from torch import rand >>> from torchmetrics.image import SpectralAngleMapper - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand([16, 3, 16, 16], generator=gen) - >>> target = torch.rand([16, 3, 16, 16], generator=gen) + >>> preds = rand([16, 3, 16, 16]) + >>> target = rand([16, 3, 16, 16]) >>> metric = SpectralAngleMapper() >>> values = [ ] >>> for _ in range(10): diff --git a/src/torchmetrics/image/scc.py b/src/torchmetrics/image/scc.py index 15ea2b96ecf..fac28658f63 100644 --- a/src/torchmetrics/image/scc.py +++ b/src/torchmetrics/image/scc.py @@ -39,11 +39,10 @@ class SpatialCorrelationCoefficient(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import randn >>> from torchmetrics.image import SpatialCorrelationCoefficient as SCC - >>> preds = torch.randn([32, 3, 64, 64]) - >>> target = torch.randn([32, 3, 64, 64]) + >>> preds = randn([32, 3, 64, 64]) + >>> target = randn([32, 3, 64, 64]) >>> scc = SCC() >>> scc(preds, target) tensor(0.0023) diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index 5056589fd14..648f9c26029 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -268,14 +268,13 @@ class MultiScaleStructuralSimilarityIndexMeasure(Metric): If ``normalize`` is neither `None`, `ReLU` nor `simple`. Example: + >>> from torch import rand >>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure - >>> import torch - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) + >>> preds = torch.rand([3, 3, 256, 256]) >>> target = preds * 0.75 >>> ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) >>> ms_ssim(preds, target) - tensor(0.9627) + tensor(0.9628) """ @@ -394,9 +393,9 @@ def plot( :scale: 75 >>> # Example plotting a single value + >>> from torch import rand >>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure - >>> import torch - >>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) + >>> preds = rand([3, 3, 256, 256]) >>> target = preds * 0.75 >>> metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) >>> metric.update(preds, target) @@ -406,9 +405,9 @@ def plot( :scale: 75 >>> # Example plotting multiple values + >>> from torch import rand >>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure - >>> import torch - >>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) + >>> preds = rand([3, 3, 256, 256]) >>> target = preds * 0.75 >>> metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) >>> values = [ ] diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index 972a07b745d..ca6276b7e86 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -53,9 +53,8 @@ class TotalVariation(Metric): If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'`` or ``None`` Example: - >>> import torch + >>> from torch import rand >>> from torchmetrics.image import TotalVariation - >>> _ = torch.manual_seed(42) >>> tv = TotalVariation() >>> img = torch.rand(5, 3, 28, 28) >>> tv(img) diff --git a/src/torchmetrics/image/vif.py b/src/torchmetrics/image/vif.py index 3e844932399..0358cb80e66 100644 --- a/src/torchmetrics/image/vif.py +++ b/src/torchmetrics/image/vif.py @@ -37,11 +37,10 @@ class VisualInformationFidelity(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import randn >>> from torchmetrics.image import VisualInformationFidelity - >>> preds = torch.randn([32, 3, 41, 41]) - >>> target = torch.randn([32, 3, 41, 41]) + >>> preds = randn([32, 3, 41, 41]) + >>> target = randn([32, 3, 41, 41]) >>> vif = VisualInformationFidelity() >>> vif(preds, target) tensor(0.0032) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 8a1f2b7eacf..940e393c6d1 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -284,7 +284,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: """Aggregate and evaluate batch input directly. Serves the dual purpose of both computing the metric on the current batch of inputs but also add the batch - statistics to the overall accumululating metric state. Input arguments are the exact same as corresponding + statistics to the overall accumulating metric state. Input arguments are the exact same as corresponding ``update`` method. The returned output is the exact same as the output of ``compute``. Args: @@ -361,7 +361,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any: def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: """Forward computation using single call to `update`. - This can be done when the global metric state is a sinple reduction of batch states. This can be unsafe for + This can be done when the global metric state is a simple reduction of batch states. This can be unsafe for certain metric cases but is also the fastest way to both accumulate globally and compute locally. """ @@ -802,7 +802,7 @@ def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module: """Overwrite `_apply` function such that we can also move metric states to the correct device. This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half` etc. methods - are called. Dtype conversion is garded and will only happen through the special `set_dtype` method. + are called. Dtype conversion is guarded and will only happen through the special `set_dtype` method. Args: fn: the function to apply @@ -985,19 +985,19 @@ def __floordiv__(self, other: Union["Metric", builtins.float, Tensor]) -> "Compo """Construct compositional metric using the floor division operator.""" return CompositionalMetric(torch.floor_divide, self, other) - def __ge__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] + def __ge__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the greater than or equal operator.""" return CompositionalMetric(torch.ge, self, other) - def __gt__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] + def __gt__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the greater than operator.""" return CompositionalMetric(torch.gt, self, other) - def __le__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] + def __le__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the less than or equal operator.""" return CompositionalMetric(torch.le, self, other) - def __lt__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] + def __lt__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the less than operator.""" return CompositionalMetric(torch.lt, self, other) @@ -1025,7 +1025,7 @@ def __pow__(self, other: Union["Metric", builtins.float, Tensor]) -> "Compositio """Construct compositional metric using the exponential/power operator.""" return CompositionalMetric(torch.pow, self, other) - def __radd__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] + def __radd__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the addition operator.""" return CompositionalMetric(torch.add, other, self) @@ -1042,11 +1042,11 @@ def __rmatmul__(self, other: Union["Metric", builtins.float, Tensor]) -> "Compos """Construct compositional metric using the matrix multiplication operator.""" return CompositionalMetric(torch.matmul, other, self) - def __rmod__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] + def __rmod__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the remainder operator.""" return CompositionalMetric(torch.fmod, other, self) - def __rmul__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] + def __rmul__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the multiplication operator.""" return CompositionalMetric(torch.mul, other, self) @@ -1058,11 +1058,11 @@ def __rpow__(self, other: Union["Metric", builtins.float, Tensor]) -> "Compositi """Construct compositional metric using the exponential/power operator.""" return CompositionalMetric(torch.pow, other, self) - def __rsub__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] + def __rsub__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the subtraction operator.""" return CompositionalMetric(torch.sub, other, self) - def __rtruediv__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] + def __rtruediv__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the true divide operator.""" return CompositionalMetric(torch.true_divide, other, self) @@ -1166,7 +1166,7 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt """ def update(self, *args: Any, **kwargs: Any) -> None: - """Redirect the call to the input which the conposition was formed from.""" + """Redirect the call to the input which the composition was formed from.""" if isinstance(self.metric_a, Metric): self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs)) @@ -1174,7 +1174,7 @@ def update(self, *args: Any, **kwargs: Any) -> None: self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs)) def compute(self) -> Any: - """Redirect the call to the input which the conposition was formed from.""" + """Redirect the call to the input which the composition was formed from.""" # also some parsing for kwargs? val_a = self.metric_a.compute() if isinstance(self.metric_a, Metric) else self.metric_a val_b = self.metric_b.compute() if isinstance(self.metric_b, Metric) else self.metric_b @@ -1216,7 +1216,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: return self._forward_cache def reset(self) -> None: - """Redirect the call to the input which the conposition was formed from.""" + """Redirect the call to the input which the composition was formed from.""" if isinstance(self.metric_a, Metric): self.metric_a.reset() diff --git a/src/torchmetrics/multimodal/clip_iqa.py b/src/torchmetrics/multimodal/clip_iqa.py index b3f934fee33..7a66de4ae3e 100644 --- a/src/torchmetrics/multimodal/clip_iqa.py +++ b/src/torchmetrics/multimodal/clip_iqa.py @@ -128,10 +128,9 @@ class CLIPImageQualityAssessment(Metric): Example:: Single prompt: + >>> from torch import randint >>> from torchmetrics.multimodal import CLIPImageQualityAssessment - >>> import torch - >>> _ = torch.manual_seed(42) - >>> imgs = torch.randint(255, (2, 3, 224, 224)).float() + >>> imgs = randint(255, (2, 3, 224, 224)).float() >>> metric = CLIPImageQualityAssessment() >>> metric(imgs) tensor([0.8894, 0.8902]) @@ -139,24 +138,22 @@ class CLIPImageQualityAssessment(Metric): Example:: Multiple prompts: + >>> from torch import randint >>> from torchmetrics.multimodal import CLIPImageQualityAssessment - >>> import torch - >>> _ = torch.manual_seed(42) - >>> imgs = torch.randint(255, (2, 3, 224, 224)).float() + >>> imgs = randint(255, (2, 3, 224, 224)).float() >>> metric = CLIPImageQualityAssessment(prompts=("quality", "brightness")) >>> metric(imgs) - {'quality': tensor([0.8894, 0.8902]), 'brightness': tensor([0.5507, 0.5208])} + {'quality': tensor([0.8693, 0.8705]), 'brightness': tensor([0.5722, 0.4762])} Example:: Custom prompts. Must always be a tuple of length 2, with a positive and negative prompt. + >>> from torch import randint >>> from torchmetrics.multimodal import CLIPImageQualityAssessment - >>> import torch - >>> _ = torch.manual_seed(42) - >>> imgs = torch.randint(255, (2, 3, 224, 224)).float() + >>> imgs = randint(255, (2, 3, 224, 224)).float() >>> metric = CLIPImageQualityAssessment(prompts=(("Super good photo.", "Super bad photo."), "brightness")) >>> metric(imgs) - {'user_defined_0': tensor([0.9652, 0.9629]), 'brightness': tensor([0.5507, 0.5208])} + {'user_defined_0': tensor([0.9578, 0.9654]), 'brightness': tensor([0.5495, 0.5764])} """ diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py index eb432e0be16..f385fbc145d 100644 --- a/src/torchmetrics/multimodal/clip_score.py +++ b/src/torchmetrics/multimodal/clip_score.py @@ -82,12 +82,12 @@ class CLIPScore(Metric): If transformers package is not installed or version is lower than 4.10.0 Example: - >>> import torch + >>> from torch import randint >>> from torchmetrics.multimodal.clip_score import CLIPScore >>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16") - >>> score = metric(torch.randint(255, (3, 224, 224), generator=torch.manual_seed(42)), "a photo of a cat") - >>> score.detach() - tensor(24.4255) + >>> score = metric(randint(255, (3, 224, 224)), "a photo of a cat") + >>> score.detach().round() + tensor(25.) """ diff --git a/src/torchmetrics/nominal/__init__.py b/src/torchmetrics/nominal/__init__.py index f23a7eb8c6b..e36da870308 100644 --- a/src/torchmetrics/nominal/__init__.py +++ b/src/torchmetrics/nominal/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from torchmetrics.nominal.cramers import CramersV from torchmetrics.nominal.fleiss_kappa import FleissKappa from torchmetrics.nominal.pearson import PearsonsContingencyCoefficient diff --git a/src/torchmetrics/nominal/cramers.py b/src/torchmetrics/nominal/cramers.py index 2d780e7e379..df47cc079d1 100644 --- a/src/torchmetrics/nominal/cramers.py +++ b/src/torchmetrics/nominal/cramers.py @@ -69,10 +69,10 @@ class CramersV(Metric): Example:: + >>> from torch import randint, randn >>> from torchmetrics.nominal import CramersV - >>> _ = torch.manual_seed(42) - >>> preds = torch.randint(0, 4, (100,)) - >>> target = torch.round(preds + torch.randn(100)).clamp(0, 4) + >>> preds = randint(0, 4, (100,)) + >>> target = (preds + randn(100)).round().clamp(0, 4) >>> cramers_v = CramersV(num_classes=5) >>> cramers_v(preds, target) tensor(0.5284) diff --git a/src/torchmetrics/nominal/fleiss_kappa.py b/src/torchmetrics/nominal/fleiss_kappa.py index 9518302cdac..619e8196d4c 100644 --- a/src/torchmetrics/nominal/fleiss_kappa.py +++ b/src/torchmetrics/nominal/fleiss_kappa.py @@ -54,23 +54,21 @@ class FleissKappa(Metric): Example: >>> # Ratings are provided as counts - >>> import torch + >>> from torch import randint >>> from torchmetrics.nominal import FleissKappa - >>> _ = torch.manual_seed(42) - >>> ratings = torch.randint(0, 10, size=(100, 5)).long() # 100 samples, 5 categories, 10 raters + >>> ratings = randint(0, 10, size=(100, 5)).long() # 100 samples, 5 categories, 10 raters >>> metric = FleissKappa(mode='counts') >>> metric(ratings) tensor(0.0089) Example: >>> # Ratings are provided as probabilities - >>> import torch + >>> from torch import randn >>> from torchmetrics.nominal import FleissKappa - >>> _ = torch.manual_seed(42) - >>> ratings = torch.randn(100, 5, 10).softmax(dim=1) # 100 samples, 5 categories, 10 raters + >>> ratings = randn(100, 5, 10).softmax(dim=1) # 100 samples, 5 categories, 10 raters >>> metric = FleissKappa(mode='probs') >>> metric(ratings) - tensor(-0.0105) + tensor(-0.0075) """ diff --git a/src/torchmetrics/nominal/pearson.py b/src/torchmetrics/nominal/pearson.py index 2fc88c8e851..a43cd548792 100644 --- a/src/torchmetrics/nominal/pearson.py +++ b/src/torchmetrics/nominal/pearson.py @@ -73,10 +73,10 @@ class PearsonsContingencyCoefficient(Metric): Example:: + >>> from torch import randint, randn >>> from torchmetrics.nominal import PearsonsContingencyCoefficient - >>> _ = torch.manual_seed(42) - >>> preds = torch.randint(0, 4, (100,)) - >>> target = torch.round(preds + torch.randn(100)).clamp(0, 4) + >>> preds = randint(0, 4, (100,)) + >>> target = (preds + randn(100)).round().clamp(0, 4) >>> pearsons_contingency_coefficient = PearsonsContingencyCoefficient(num_classes=5) >>> pearsons_contingency_coefficient(preds, target) tensor(0.6948) diff --git a/src/torchmetrics/nominal/theils_u.py b/src/torchmetrics/nominal/theils_u.py index f82c7658b1f..7f7f22ecb1b 100644 --- a/src/torchmetrics/nominal/theils_u.py +++ b/src/torchmetrics/nominal/theils_u.py @@ -59,10 +59,10 @@ class TheilsU(Metric): Example:: + >>> from torch import randint >>> from torchmetrics.nominal import TheilsU - >>> _ = torch.manual_seed(42) - >>> preds = torch.randint(10, (10,)) - >>> target = torch.randint(10, (10,)) + >>> preds = randint(10, (10,)) + >>> target = randint(10, (10,)) >>> metric = TheilsU(num_classes=10) >>> metric(preds, target) tensor(0.8530) diff --git a/src/torchmetrics/nominal/tschuprows.py b/src/torchmetrics/nominal/tschuprows.py index a14832b4121..9986fa2ec6f 100644 --- a/src/torchmetrics/nominal/tschuprows.py +++ b/src/torchmetrics/nominal/tschuprows.py @@ -69,10 +69,10 @@ class TschuprowsT(Metric): Example:: + >>> from torch import randint >>> from torchmetrics.nominal import TschuprowsT - >>> _ = torch.manual_seed(42) - >>> preds = torch.randint(0, 4, (100,)) - >>> target = torch.round(preds + torch.randn(100)).clamp(0, 4) + >>> preds = randint(0, 4, (100,)) + >>> target = (preds + torch.randn(100)).round().clamp(0, 4) >>> tschuprows_t = TschuprowsT(num_classes=5) >>> tschuprows_t(preds, target) tensor(0.4930) diff --git a/src/torchmetrics/regression/__init__.py b/src/torchmetrics/regression/__init__.py index 03ba8023a10..6a41c01bcdb 100644 --- a/src/torchmetrics/regression/__init__.py +++ b/src/torchmetrics/regression/__init__.py @@ -23,6 +23,7 @@ from torchmetrics.regression.mape import MeanAbsolutePercentageError from torchmetrics.regression.minkowski import MinkowskiDistance from torchmetrics.regression.mse import MeanSquaredError +from torchmetrics.regression.nrmse import NormalizedRootMeanSquaredError from torchmetrics.regression.pearson import PearsonCorrCoef from torchmetrics.regression.r2 import R2Score from torchmetrics.regression.rse import RelativeSquaredError @@ -36,14 +37,15 @@ "CosineSimilarity", "CriticalSuccessIndex", "ExplainedVariance", - "KendallRankCorrCoef", "KLDivergence", + "KendallRankCorrCoef", "LogCoshError", - "MeanSquaredLogError", "MeanAbsoluteError", "MeanAbsolutePercentageError", - "MinkowskiDistance", "MeanSquaredError", + "MeanSquaredLogError", + "MinkowskiDistance", + "NormalizedRootMeanSquaredError", "PearsonCorrCoef", "R2Score", "RelativeSquaredError", diff --git a/src/torchmetrics/regression/kl_divergence.py b/src/torchmetrics/regression/kl_divergence.py index 8c5c41d3d57..10fc6d90bc9 100644 --- a/src/torchmetrics/regression/kl_divergence.py +++ b/src/torchmetrics/regression/kl_divergence.py @@ -22,7 +22,6 @@ from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE -from torchmetrics.utilities.prints import rank_zero_warn if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["KLDivergence.plot"] @@ -47,14 +46,6 @@ class KLDivergence(Metric): - ``kl_divergence`` (:class:`~torch.Tensor`): A tensor with the KL divergence - .. warning:: - The input order and naming in metric ``KLDivergence`` is set to be deprecated in v1.4 and changed in v1.5. - Input argument ``p`` will be renamed to ``target`` and will be moved to be the second argument of the metric. - Input argument ``q`` will be renamed to ``preds`` and will be moved to the first argument of the metric. - Thus, ``KLDivergence(p, q)`` will equal ``KLDivergence(target=q, preds=p)`` in the future to be consistent - with the rest of ``torchmetrics``. From v1.4 the two new arguments will be added as keyword arguments and - from v1.5 the two old arguments will be removed. - Args: log_prob: bool indicating if input is log-probabilities or probabilities. If given as probabilities, will normalize to make sure the distributes sum to 1. @@ -102,15 +93,6 @@ def __init__( reduction: Literal["mean", "sum", "none", None] = "mean", **kwargs: Any, ) -> None: - rank_zero_warn( - "The input order and naming in metric `KLDivergence` is set to be deprecated in v1.4 and changed in v1.5." - "Input argument `p` will be renamed to `target` and will be moved to be the second argument of the metric." - "Input argument `q` will be renamed to `preds` and will be moved to the first argument of the metric." - "Thus, `KLDivergence(p, q)` will equal `KLDivergence(target=q, preds=p)` in the future to be consistent" - " with the rest of torchmetrics. From v1.4 the two new arguments will be added as keyword arguments and" - " from v1.5 the two old arguments will be removed.", - DeprecationWarning, - ) super().__init__(**kwargs) if not isinstance(log_prob, bool): raise TypeError(f"Expected argument `log_prob` to be bool but got {log_prob}") diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py new file mode 100644 index 00000000000..62562803542 --- /dev/null +++ b/src/torchmetrics/regression/nrmse.py @@ -0,0 +1,279 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Sequence, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.regression.nrmse import ( + _mean_squared_error_update, + _normalized_root_mean_squared_error_compute, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["NormalizedRootMeanSquaredError.plot"] + + +def _final_aggregation( + min_val: Tensor, + max_val: Tensor, + mean_val: Tensor, + var_val: Tensor, + target_squared: Tensor, + total: Tensor, + normalization: Literal["mean", "range", "std", "l2"] = "mean", +) -> Tensor: + """In the case of multiple devices we need to aggregate the statistics from the different devices.""" + if len(min_val) == 1: + if normalization == "mean": + return mean_val[0] + if normalization == "range": + return max_val[0] - min_val[0] + if normalization == "std": + return var_val[0] + if normalization == "l2": + return target_squared[0] + + min_val_1, max_val_1, mean_val_1, var_val_1, target_squared_1, total_1 = ( + min_val[0], + max_val[0], + mean_val[0], + var_val[0], + target_squared[0], + total[0], + ) + for i in range(1, len(min_val)): + min_val_2, max_val_2, mean_val_2, var_val_2, target_squared_2, total_2 = ( + min_val[i], + max_val[i], + mean_val[i], + var_val[i], + target_squared[i], + total[i], + ) + # update total and mean + total = total_1 + total_2 + mean = (total_1 * mean_val_1 + total_2 * mean_val_2) / total + + # update variance + _temp = (total_1 + 1) * mean - total_1 * mean_val_1 + var_val_1 += (_temp - mean_val_1) * (_temp - mean) - (_temp - mean) ** 2 + _temp = (total_2 + 1) * mean - total_2 * mean_val_2 + var_val_2 += (_temp - mean_val_2) * (_temp - mean) - (_temp - mean) ** 2 + var = var_val_1 + var_val_2 + + # update min and max and target squared + min_val = torch.min(min_val_1, min_val_2) + max_val = torch.max(max_val_1, max_val_2) + target_squared = target_squared_1 + target_squared_2 + + if normalization == "mean": + return mean + if normalization == "range": + return max_val - min_val + if normalization == "std": + return (var / total).sqrt() + return target_squared.sqrt() + + +class NormalizedRootMeanSquaredError(Metric): + r"""Calculates the `Normalized Root Mean Squared Error`_ (NRMSE) also know as scatter index. + + The metric is defined as: + + .. math:: + \text{NRMSE} = \frac{\text{RMSE}}{\text{denom}} + + where RMSE is the root mean squared error and `denom` is the normalization factor. The normalization factor can be + either be the mean, range, standard deviation or L2 norm of the target, which can be set using the `normalization` + argument. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): Predictions from model + - ``target`` (:class:`~torch.Tensor`): Ground truth values + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``nrmse`` (:class:`~torch.Tensor`): A tensor with the mean squared error + + Args: + normalization: type of normalization to be applied. Choose from "mean", "range", "std", "l2" which corresponds + to normalizing the RMSE by the mean of the target, the range of the target, the standard deviation of the + target or the L2 norm of the target. + num_outputs: Number of outputs in multioutput setting + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example:: + Single output normalized root mean squared error computation: + + >>> import torch + >>> from torchmetrics import NormalizedRootMeanSquaredError + >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0]) + >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0]) + >>> nrmse = NormalizedRootMeanSquaredError(normalization="mean") + >>> nrmse(preds, target) + tensor(0.1919) + >>> nrmse = NormalizedRootMeanSquaredError(normalization="range") + >>> nrmse(preds, target) + tensor(0.1701) + + Example:: + Multioutput normalized root mean squared error computation: + + >>> import torch + >>> from torchmetrics import NormalizedRootMeanSquaredError + >>> preds = torch.tensor([[0., 1], [2, 3], [4, 5], [6, 7]]) + >>> target = torch.tensor([[0., 1], [3, 3], [4, 5], [8, 9]]) + >>> nrmse = NormalizedRootMeanSquaredError(num_outputs=2) + >>> nrmse(preds, target) + tensor([0.2981, 0.2222]) + + """ + + is_differentiable: bool = True + higher_is_better: bool = False + full_state_update: bool = True + plot_lower_bound: float = 0.0 + + sum_squared_error: Tensor + total: Tensor + min_val: Tensor + max_val: Tensor + target_squared: Tensor + mean_val: Tensor + var_val: Tensor + + def __init__( + self, + normalization: Literal["mean", "range", "std", "l2"] = "mean", + num_outputs: int = 1, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + if normalization not in ("mean", "range", "std", "l2"): + raise ValueError( + f"Argument `normalization` should be either 'mean', 'range', 'std' or 'l2', but got {normalization}" + ) + self.normalization = normalization + + if not (isinstance(num_outputs, int) and num_outputs > 0): + raise ValueError(f"Expected num_outputs to be a positive integer but got {num_outputs}") + self.num_outputs = num_outputs + + self.add_state("sum_squared_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum") + self.add_state("total", default=torch.zeros(num_outputs), dist_reduce_fx=None) + self.add_state("min_val", default=float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None) + self.add_state("max_val", default=-float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None) + self.add_state("mean_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) + self.add_state("var_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) + self.add_state("target_squared", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets. + + See `mean_squared_error_update` for details. + + """ + sum_squared_error, num_obs = _mean_squared_error_update(preds, target, self.num_outputs) + self.sum_squared_error += sum_squared_error + target = target.view(-1) if self.num_outputs == 1 else target + + # Update min and max and target squared + self.min_val = torch.minimum(target.min(dim=0).values, self.min_val) + self.max_val = torch.maximum(target.max(dim=0).values, self.max_val) + self.target_squared += (target**2).sum(dim=0) + + # Update mean and variance + new_mean = (self.total * self.mean_val + target.sum(dim=0)) / (self.total + num_obs) + self.total += num_obs + new_var = ((target - new_mean) * (target - self.mean_val)).sum(dim=0) + self.mean_val = new_mean + self.var_val += new_var + + def compute(self) -> Tensor: + """Computes NRMSE over state. + + See `mean_squared_error_compute` for details. + + """ + if (self.num_outputs == 1 and self.mean_val.numel() > 1) or (self.num_outputs > 1 and self.mean_val.ndim > 1): + denom = _final_aggregation( + min_val=self.min_val, + max_val=self.max_val, + mean_val=self.mean_val, + var_val=self.var_val, + target_squared=self.target_squared, + total=self.total, + normalization=self.normalization, + ) + total = self.total.squeeze().sum(dim=0) + else: + if self.normalization == "mean": + denom = self.mean_val + elif self.normalization == "range": + denom = self.max_val - self.min_val + elif self.normalization == "std": + denom = torch.sqrt(self.var_val / self.total) + else: + denom = torch.sqrt(self.target_squared) + total = self.total + return _normalized_root_mean_squared_error_compute(self.sum_squared_error, total, denom) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import NormalizedRootMeanSquaredError + >>> metric = NormalizedRootMeanSquaredError() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import NormalizedRootMeanSquaredError + >>> metric = NormalizedRootMeanSquaredError() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/regression/r2.py b/src/torchmetrics/regression/r2.py index 611d62a745c..a54502e087d 100644 --- a/src/torchmetrics/regression/r2.py +++ b/src/torchmetrics/regression/r2.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Any, Optional, Sequence, Union -import torch from torch import Tensor, tensor from torchmetrics.functional.regression.r2 import _r2_score_compute, _r2_score_update @@ -38,8 +37,8 @@ class R2Score(Metric): where the parameter :math:`k` (the number of independent regressors) should be provided as the `adjusted` argument. The score is only proper defined when :math:`SS_{tot}\neq 0`, which can happen for near constant targets. In this - case a score of 0 is returned. By definition the score is bounded between 0 and 1, where 1 corresponds to the - predictions exactly matching the targets. + case a score of 0 is returned. By definition the score is bounded between :math:`-inf` and 1.0, with 1.0 indicating + perfect prediction, 0 indicating constant prediction and negative values indicating worse than constant prediction. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -65,23 +64,32 @@ class R2Score(Metric): * ``'variance_weighted'`` scores are weighted by their individual variances kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + .. warning:: + Argument ``num_outputs`` in ``R2Score`` has been deprecated because it is no longer necessary and will be + removed in v1.6.0 of TorchMetrics. The number of outputs is now automatically inferred from the shape + of the input tensors. + Raises: ValueError: If ``adjusted`` parameter is not an integer larger or equal to 0. ValueError: If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``. - Example: + Example (single output): + >>> from torch import tensor >>> from torchmetrics.regression import R2Score - >>> target = torch.tensor([3, -0.5, 2, 7]) - >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> target = tensor([3, -0.5, 2, 7]) + >>> preds = tensor([2.5, 0.0, 2, 8]) >>> r2score = R2Score() >>> r2score(preds, target) tensor(0.9486) - >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) - >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) - >>> r2score = R2Score(num_outputs=2, multioutput='raw_values') + Example (multioutput): + >>> from torch import tensor + >>> from torchmetrics.regression import R2Score + >>> target = tensor([[0.5, 1], [-1, 1], [7, -6]]) + >>> preds = tensor([[0, 2], [-1, 2], [8, -5]]) + >>> r2score = R2Score(multioutput='raw_values') >>> r2score(preds, target) tensor([0.9654, 0.9082]) @@ -90,7 +98,6 @@ class R2Score(Metric): is_differentiable: bool = True higher_is_better: bool = True full_state_update: bool = False - plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 sum_squared_error: Tensor @@ -100,15 +107,11 @@ class R2Score(Metric): def __init__( self, - num_outputs: int = 1, adjusted: int = 0, multioutput: str = "uniform_average", **kwargs: Any, ) -> None: super().__init__(**kwargs) - - self.num_outputs = num_outputs - if adjusted < 0 or not isinstance(adjusted, int): raise ValueError("`adjusted` parameter should be an integer larger or equal to 0.") self.adjusted = adjusted @@ -120,19 +123,19 @@ def __init__( ) self.multioutput = multioutput - self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") - self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") - self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") + self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("sum_error", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("residual", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" sum_squared_error, sum_error, residual, total = _r2_score_update(preds, target) - self.sum_squared_error += sum_squared_error - self.sum_error += sum_error - self.residual += residual - self.total += total + self.sum_squared_error = self.sum_squared_error + sum_squared_error + self.sum_error = self.sum_error + sum_error + self.residual = self.residual + residual + self.total = self.total + total def compute(self) -> Tensor: """Compute r2 score over the metric states.""" diff --git a/src/torchmetrics/regression/spearman.py b/src/torchmetrics/regression/spearman.py index 50e0e22337f..62be592919d 100644 --- a/src/torchmetrics/regression/spearman.py +++ b/src/torchmetrics/regression/spearman.py @@ -88,7 +88,7 @@ def __init__( " For large datasets, this may lead to large memory footprint." ) if not isinstance(num_outputs, int) and num_outputs < 1: - raise ValueError("Expected argument `num_outputs` to be an int larger than 0, but got {num_outputs}") + raise ValueError(f"Expected argument `num_outputs` to be an int larger than 0, but got {num_outputs}") self.num_outputs = num_outputs self.add_state("preds", default=[], dist_reduce_fx="cat") diff --git a/src/torchmetrics/regression/symmetric_mape.py b/src/torchmetrics/regression/symmetric_mape.py index 01806bb6594..82b5702f476 100644 --- a/src/torchmetrics/regression/symmetric_mape.py +++ b/src/torchmetrics/regression/symmetric_mape.py @@ -41,7 +41,7 @@ class SymmetricMeanAbsolutePercentageError(Metric): As output of ``forward`` and ``compute`` the metric returns the following output: - - ``smape`` (:class:`~torch.Tensor`): A tensor with non-negative floating point smape value between 0 and 1 + - ``smape`` (:class:`~torch.Tensor`): A tensor with non-negative floating point smape value between 0 and 2 Args: kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -60,6 +60,7 @@ class SymmetricMeanAbsolutePercentageError(Metric): higher_is_better: bool = False full_state_update: bool = False plot_lower_bound: float = 0.0 + plot_upper_bound: float = 2.0 sum_abs_per_error: Tensor total: Tensor diff --git a/src/torchmetrics/regression/wmape.py b/src/torchmetrics/regression/wmape.py index 067c995bfb2..42fadb906ef 100644 --- a/src/torchmetrics/regression/wmape.py +++ b/src/torchmetrics/regression/wmape.py @@ -51,10 +51,9 @@ class WeightedMeanAbsolutePercentageError(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: - >>> import torch - >>> _ = torch.manual_seed(42) - >>> preds = torch.randn(20,) - >>> target = torch.randn(20,) + >>> from torch import randn + >>> preds = randn(20,) + >>> target = randn(20,) >>> wmape = WeightedMeanAbsolutePercentageError() >>> wmape(preds, target) tensor(1.3967) diff --git a/src/torchmetrics/segmentation/__init__.py b/src/torchmetrics/segmentation/__init__.py index 5b609c2c738..6e9b1c63313 100644 --- a/src/torchmetrics/segmentation/__init__.py +++ b/src/torchmetrics/segmentation/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore +from torchmetrics.segmentation.hausdorff_distance import HausdorffDistance from torchmetrics.segmentation.mean_iou import MeanIoU -__all__ = ["GeneralizedDiceScore", "MeanIoU"] +__all__ = ["GeneralizedDiceScore", "MeanIoU", "HausdorffDistance"] diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index 66f09437000..95da9ab26d9 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -89,20 +89,19 @@ class GeneralizedDiceScore(Metric): If ``input_format`` is not one of ``"one-hot"`` or ``"index"`` Example: - >>> import torch - >>> _ = torch.manual_seed(0) + >>> from torch import randint >>> from torchmetrics.segmentation import GeneralizedDiceScore >>> gds = GeneralizedDiceScore(num_classes=3) - >>> preds = torch.randint(0, 2, (10, 3, 128, 128)) - >>> target = torch.randint(0, 2, (10, 3, 128, 128)) + >>> preds = randint(0, 2, (10, 3, 128, 128)) + >>> target = randint(0, 2, (10, 3, 128, 128)) >>> gds(preds, target) - tensor(0.4983) + tensor(0.4992) >>> gds = GeneralizedDiceScore(num_classes=3, per_class=True) >>> gds(preds, target) - tensor([0.4987, 0.4966, 0.4995]) + tensor([0.5001, 0.4993, 0.4982]) >>> gds = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False) >>> gds(preds, target) - tensor([0.4966, 0.4995]) + tensor([0.4993, 0.4982]) """ diff --git a/src/torchmetrics/segmentation/hausdorff_distance.py b/src/torchmetrics/segmentation/hausdorff_distance.py new file mode 100644 index 00000000000..f1e8812ed30 --- /dev/null +++ b/src/torchmetrics/segmentation/hausdorff_distance.py @@ -0,0 +1,157 @@ +# Copyright The Lightning team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Literal, Optional, Sequence, Union + +import torch +from torch import Tensor + +from torchmetrics.functional.segmentation.hausdorff_distance import ( + _hausdorff_distance_validate_args, + hausdorff_distance, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["HausdorffDistance.plot"] + + +class HausdorffDistance(Metric): + r"""Compute the `Hausdorff Distance`_ between two subsets of a metric space for semantic segmentation. + + .. math:: + d_{\Pi}(X,Y) = \max{/sup_{x\in X} {d(x,Y)}, /sup_{y\in Y} {d(X,y)}} + + where :math:`\X, \Y` are two subsets of a metric space with distance metric :math:`d`. The Hausdorff distance is + the maximum distance from a point in one set to the closest point in the other set. The Hausdorff distance is a + measure of the degree of mismatch between two sets. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being + the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` + can be provided, where the integer values correspond to the class index. The input type can be controlled + with the ``input_format`` argument. + - ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being + the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` + can be provided, where the integer values correspond to the class index. The input type can be controlled + with the ``input_format`` argument. + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``hausdorff_distance`` (:class:`~torch.Tensor`): A scalar float tensor with the Hausdorff distance averaged over + classes and samples + + Args: + num_classes: number of classes + include_background: whether to include background class in calculation + distance_metric: distance metric to calculate surface distance. Choose one of `"euclidean"`, + `"chessboard"` or `"taxicab"` + spacing: spacing between pixels along each spatial dimension. If not provided the spacing is assumed to be 1 + directed: whether to calculate directed or undirected Hausdorff distance + input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors + or ``"index"`` for index tensors + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> from torch import randint + >>> from torchmetrics.segmentation import HausdorffDistance + >>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> hausdorff_distance = HausdorffDistance(distance_metric="euclidean", num_classes=5) + >>> hausdorff_distance(preds, target) + tensor(1.9567) + + """ + + is_differentiable: bool = True + higher_is_better: bool = False + full_state_update: bool = False + plot_lower_bound: float = 0.0 + + score: Tensor + total: Tensor + + def __init__( + self, + num_classes: int, + include_background: bool = False, + distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", + spacing: Optional[Union[Tensor, List[float]]] = None, + directed: bool = False, + input_format: Literal["one-hot", "index"] = "one-hot", + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + _hausdorff_distance_validate_args( + num_classes, include_background, distance_metric, spacing, directed, input_format + ) + self.num_classes = num_classes + self.include_background = include_background + self.distance_metric = distance_metric + self.spacing = spacing + self.directed = directed + self.input_format = input_format + self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + score = hausdorff_distance( + preds, + target, + self.num_classes, + include_background=self.include_background, + distance_metric=self.distance_metric, + spacing=self.spacing, + directed=self.directed, + input_format=self.input_format, + ) + self.score += score.sum() + self.total += score.numel() + + def compute(self) -> Tensor: + """Compute final Hausdorff distance over states.""" + return self.score / self.total + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randint + >>> from torchmetrics.segmentation import HausdorffDistance + >>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> metric = HausdorffDistance(num_classes=5) + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/segmentation/mean_iou.py b/src/torchmetrics/segmentation/mean_iou.py index 77d465ebd21..0fe831f5231 100644 --- a/src/torchmetrics/segmentation/mean_iou.py +++ b/src/torchmetrics/segmentation/mean_iou.py @@ -70,20 +70,19 @@ class MeanIoU(Metric): If ``input_format`` is not one of ``"one-hot"`` or ``"index"`` Example: - >>> import torch - >>> _ = torch.manual_seed(0) + >>> from torch import randint >>> from torchmetrics.segmentation import MeanIoU >>> miou = MeanIoU(num_classes=3) - >>> preds = torch.randint(0, 2, (10, 3, 128, 128)) - >>> target = torch.randint(0, 2, (10, 3, 128, 128)) + >>> preds = randint(0, 2, (10, 3, 128, 128)) + >>> target = randint(0, 2, (10, 3, 128, 128)) >>> miou(preds, target) - tensor(0.3318) + tensor(0.3326) >>> miou = MeanIoU(num_classes=3, per_class=True) >>> miou(preds, target) - tensor([0.3322, 0.3303, 0.3329]) + tensor([0.3334, 0.3327, 0.3318]) >>> miou = MeanIoU(num_classes=3, per_class=True, include_background=False) >>> miou(preds, target) - tensor([0.3303, 0.3329]) + tensor([0.3327, 0.3318]) """ @@ -111,7 +110,8 @@ def __init__( self.input_format = input_format num_classes = num_classes - 1 if not include_background else num_classes - self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="mean") + self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum") + self.add_state("num_batches", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: """Update the state with the new data.""" @@ -120,10 +120,11 @@ def update(self, preds: Tensor, target: Tensor) -> None: ) score = _mean_iou_compute(intersection, union, per_class=self.per_class) self.score += score.mean(0) if self.per_class else score.mean() + self.num_batches += 1 def compute(self) -> Tensor: - """Update the state with the new data.""" - return self.score # / self.num_batches + """Compute the final Mean Intersection over Union (mIoU).""" + return self.score / self.num_batches def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/src/torchmetrics/shape/__init__.py b/src/torchmetrics/shape/__init__.py new file mode 100644 index 00000000000..263a1e395a2 --- /dev/null +++ b/src/torchmetrics/shape/__init__.py @@ -0,0 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.shape.procrustes import ProcrustesDisparity + +__all__ = ["ProcrustesDisparity"] diff --git a/src/torchmetrics/shape/procrustes.py b/src/torchmetrics/shape/procrustes.py new file mode 100644 index 00000000000..a924fb48a4a --- /dev/null +++ b/src/torchmetrics/shape/procrustes.py @@ -0,0 +1,137 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Sequence, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics import Metric +from torchmetrics.functional.shape.procrustes import procrustes_disparity +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["ProcrustesDisparity.plot"] + + +class ProcrustesDisparity(Metric): + r"""Compute the `Procrustes Disparity`_. + + The Procrustes Disparity is defined as the sum of the squared differences between two datasets after + applying a Procrustes transformation. The Procrustes Disparity is useful to compare two datasets + that are similar but not aligned. + + The metric works similar to ``scipy.spatial.procrustes`` but for batches of data points. The disparity is + aggregated over the batch, thus to get the individual disparities please use the functional version of this + metric: ``torchmetrics.functional.shape.procrustes.procrustes_disparity``. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``point_cloud1`` (torch.Tensor): A tensor of shape ``(N, M, D)`` with ``N`` being the batch size, + ``M`` the number of data points and ``D`` the dimensionality of the data points. + - ``point_cloud2`` (torch.Tensor): A tensor of shape ``(N, M, D)`` with ``N`` being the batch size, + ``M`` the number of data points and ``D`` the dimensionality of the data points. + + + As output to ``forward`` and ``compute`` the metric returns the following output: + + - ``gds`` (:class:`~torch.Tensor`): A scalar tensor with the Procrustes Disparity. + + Args: + reduction: Determines whether to return the mean disparity or the sum of the disparities. + Can be one of ``"mean"`` or ``"sum"``. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Raises: + ValueError: If ``average`` is not one of ``"mean"`` or ``"sum"``. + + Example: + >>> from torch import randn + >>> from torchmetrics.shape import ProcrustesDisparity + >>> metric = ProcrustesDisparity() + >>> point_cloud1 = randn(10, 50, 2) + >>> point_cloud2 = randn(10, 50, 2) + >>> metric(point_cloud1, point_cloud2) + tensor(0.9770) + + """ + + disparity: Tensor + total: Tensor + full_state_update: bool = False + is_differentiable: bool = False + higher_is_better: bool = False + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + + def __init__(self, reduction: Literal["mean", "sum"] = "mean", **kwargs: Any) -> None: + super().__init__(**kwargs) + if reduction not in ("mean", "sum"): + raise ValueError(f"Argument `reduction` must be one of ['mean', 'sum'], got {reduction}") + self.reduction = reduction + self.add_state("disparity", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, point_cloud1: torch.Tensor, point_cloud2: torch.Tensor) -> None: + """Update the Procrustes Disparity with the given datasets.""" + disparity: Tensor = procrustes_disparity(point_cloud1, point_cloud2) # type: ignore[assignment] + self.disparity += disparity.sum() + self.total += disparity.numel() + + def compute(self) -> torch.Tensor: + """Computes the Procrustes Disparity.""" + if self.reduction == "mean": + return self.disparity / self.total + return self.disparity + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.shape import ProcrustesDisparity + >>> metric = ProcrustesDisparity() + >>> metric.update(torch.randn(10, 50, 2), torch.randn(10, 50, 2)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.shape import ProcrustesDisparity + >>> metric = ProcrustesDisparity() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.randn(10, 50, 2), torch.randn(10, 50, 2))) + >>> fig_, ax_ = metric.plot(values) + + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/_deprecated.py b/src/torchmetrics/text/_deprecated.py index d3ba1c4010e..0c7ffef29af 100644 --- a/src/torchmetrics/text/_deprecated.py +++ b/src/torchmetrics/text/_deprecated.py @@ -144,10 +144,9 @@ def __init__( class _Perplexity(Perplexity): """Wrapper for deprecated import. - >>> import torch - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand(2, 8, 5, generator=gen) - >>> target = torch.randint(5, (2, 8), generator=gen) + >>> from torch import rand, randint + >>> preds = rand(2, 8, 5) + >>> target = randint(5, (2, 8)) >>> target[0, 6:] = -100 >>> perp = _Perplexity(ignore_index=-100) >>> perp(preds, target) diff --git a/src/torchmetrics/text/bert.py b/src/torchmetrics/text/bert.py index cd7d4ce0aaf..6e1bab1b9bd 100644 --- a/src/torchmetrics/text/bert.py +++ b/src/torchmetrics/text/bert.py @@ -107,6 +107,7 @@ class BERTScore(Metric): of the files from `BERT_score`_. baseline_path: A path to the user's own local csv/tsv file with the baseline scale. baseline_url: A url path to the user's own csv/tsv file with the baseline scale. + truncation: An indication of whether the input sequences should be truncated to the ``max_length``. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: @@ -150,6 +151,7 @@ def __init__( rescale_with_baseline: bool = False, baseline_path: Optional[str] = None, baseline_url: Optional[str] = None, + truncation: bool = False, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -169,6 +171,7 @@ def __init__( self.rescale_with_baseline = rescale_with_baseline self.baseline_path = baseline_path self.baseline_url = baseline_url + self.truncation = truncation if user_tokenizer: self.tokenizer = user_tokenizer @@ -210,7 +213,7 @@ def update(self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[s preds, self.tokenizer, self.max_length, - truncation=False, + truncation=self.truncation, sort_according_length=False, own_tokenizer=self.user_tokenizer, ) @@ -218,7 +221,7 @@ def update(self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[s target, self.tokenizer, self.max_length, - truncation=False, + truncation=self.truncation, sort_according_length=False, own_tokenizer=self.user_tokenizer, ) diff --git a/src/torchmetrics/text/infolm.py b/src/torchmetrics/text/infolm.py index b5c2de893f7..31fea4adc23 100644 --- a/src/torchmetrics/text/infolm.py +++ b/src/torchmetrics/text/infolm.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -111,12 +111,25 @@ class InfoLM(Metric): """ is_differentiable = False - higher_is_better = True preds_input_ids: List[Tensor] preds_attention_mask: List[Tensor] target_input_ids: List[Tensor] target_attention_mask: List[Tensor] + _information_measure_higher_is_better: ClassVar = { + # following values are <0 + "kl_divergence": True, + "alpha_divergence": True, + # following values are >0 + "beta_divergence": False, + "ab_divergence": False, + "renyi_divergence": False, + "l1_distance": False, + "l2_distance": False, + "l_infinity_distance": False, + "fisher_rao_distance": False, + } + def __init__( self, model_name_or_path: Union[str, os.PathLike] = "bert-base-uncased", @@ -156,6 +169,15 @@ def __init__( self.add_state("target_input_ids", [], dist_reduce_fx="cat") self.add_state("target_attention_mask", [], dist_reduce_fx="cat") + @property + def higher_is_better(self) -> bool: # type: ignore[override] + """Returns a bool indicating whether a higher value of the information measure is better. + + Done this way as depends on if the information measure is positive or negative. + + """ + return self._information_measure_higher_is_better[self.information_measure] + def update(self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[str]]) -> None: """Update state with predictions and targets.""" preds_input_ids, preds_attention_mask, target_input_ids, target_attention_mask = _infolm_update( diff --git a/src/torchmetrics/text/perplexity.py b/src/torchmetrics/text/perplexity.py index 51804881df8..d13eac2f402 100644 --- a/src/torchmetrics/text/perplexity.py +++ b/src/torchmetrics/text/perplexity.py @@ -48,11 +48,10 @@ class Perplexity(Metric): Additional keyword arguments, see :ref:`Metric kwargs` for more info. Examples: + >>> from torch import rand, randint >>> from torchmetrics.text import Perplexity - >>> import torch - >>> gen = torch.manual_seed(42) - >>> preds = torch.rand(2, 8, 5, generator=gen) - >>> target = torch.randint(5, (2, 8), generator=gen) + >>> preds = rand(2, 8, 5) + >>> target = randint(5, (2, 8)) >>> target[0, 6:] = -100 >>> perp = Perplexity(ignore_index=-100) >>> perp(preds, target) diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index 68cd344877d..ee11a36136f 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -56,8 +56,8 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tens """ num = num if num.is_floating_point() else num.float() denom = denom if denom.is_floating_point() else denom.float() - zero_division = torch.tensor(zero_division).float().to(num.device) - return torch.where(denom != 0, num / denom, zero_division) + zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype).to(num.device, non_blocking=True) + return torch.where(denom != 0, num / denom, zero_division_tensor) def _adjust_weights_safe_divide( diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index 739c9b09710..1a68e655c33 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -19,7 +19,7 @@ from torch import Tensor from torchmetrics.utilities.exceptions import TorchMetricsUserWarning -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13, _XLA_AVAILABLE +from torchmetrics.utilities.imports import _XLA_AVAILABLE from torchmetrics.utilities.prints import rank_zero_warn METRIC_EPS = 1e-6 @@ -115,8 +115,6 @@ def to_onehot( def _top_k_with_half_precision_support(x: Tensor, k: int = 1, dim: int = 1) -> Tensor: """torch.top_k does not support half precision on CPU.""" if x.dtype == torch.half and not x.is_cuda: - if not _TORCH_GREATER_EQUAL_1_13: - raise RuntimeError("Half precision (torch.float16) is not supported on CPU for PyTorch < 1.13.") idx = torch.argsort(x, dim=dim, stable=True).flip(dim) return idx.narrow(dim, 0, k) return x.topk(k=k, dim=dim).indices @@ -200,7 +198,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor: if minlength is None: minlength = len(torch.unique(x)) - if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or _TORCH_GREATER_EQUAL_1_12 and x.is_mps: + if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE and x.is_mps: mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1) return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0) @@ -215,7 +213,7 @@ def _cumsum(x: Tensor, dim: Optional[int] = 0, dtype: Optional[torch.dtype] = No "Expect some slowdowns.", TorchMetricsUserWarning, ) - return x.cpu().cumsum(dim=dim, dtype=dtype).cuda() + return x.cpu().cumsum(dim=dim, dtype=dtype).to(x.device) return torch.cumsum(x, dim=dim, dtype=dtype) diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index 68692683351..28bda373600 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -19,13 +19,10 @@ from lightning_utilities.core.imports import RequirementCache _PYTHON_VERSION = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" -_TORCH_LOWER_2_0 = RequirementCache("torch<2.0.0") -_TORCH_GREATER_EQUAL_1_11 = RequirementCache("torch>=1.11.0") -_TORCH_GREATER_EQUAL_1_12 = RequirementCache("torch>=1.12.0") -_TORCH_GREATER_EQUAL_1_13 = RequirementCache("torch>=1.13.0") -_TORCH_GREATER_EQUAL_2_0 = RequirementCache("torch>=2.0.0") _TORCH_GREATER_EQUAL_2_1 = RequirementCache("torch>=2.1.0") _TORCH_GREATER_EQUAL_2_2 = RequirementCache("torch>=2.2.0") +_TORCH_GREATER_EQUAL_2_5 = RequirementCache("torch>=2.5.0") +_TORCHMETRICS_GREATER_EQUAL_1_6 = RequirementCache("torchmetrics>=1.7.0") _NLTK_AVAILABLE = RequirementCache("nltk") _ROUGE_SCORE_AVAILABLE = RequirementCache("rouge_score") @@ -36,8 +33,6 @@ _LPIPS_AVAILABLE = RequirementCache("lpips") _PYCOCOTOOLS_AVAILABLE = RequirementCache("pycocotools") _TORCHVISION_AVAILABLE = RequirementCache("torchvision") -_TORCHVISION_GREATER_EQUAL_0_8 = RequirementCache("torchvision>=0.8.0") -_TORCHVISION_GREATER_EQUAL_0_13 = RequirementCache("torchvision>=0.13.0") _TQDM_AVAILABLE = RequirementCache("tqdm") _TRANSFORMERS_AVAILABLE = RequirementCache("transformers") _TRANSFORMERS_GREATER_EQUAL_4_4 = RequirementCache("transformers>=4.4.0") @@ -45,7 +40,6 @@ _PESQ_AVAILABLE = RequirementCache("pesq") _GAMMATONE_AVAILABLE = RequirementCache("gammatone") _TORCHAUDIO_AVAILABLE = RequirementCache("torchaudio") -_TORCHAUDIO_GREATER_EQUAL_0_10 = RequirementCache("torchaudio>=0.10.0") _REGEX_AVAILABLE = RequirementCache("regex") _PYSTOI_AVAILABLE = RequirementCache("pystoi") _REQUESTS_AVAILABLE = RequirementCache("requests") @@ -63,6 +57,7 @@ _MECAB_KO_DIC_AVAILABLE = RequirementCache("mecab_ko_dic") _IPADIC_AVAILABLE = RequirementCache("ipadic") _SENTENCEPIECE_AVAILABLE = RequirementCache("sentencepiece") +_SCIPI_AVAILABLE = RequirementCache("scipy") _SKLEARN_GREATER_EQUAL_1_3 = RequirementCache("scikit-learn>=1.3.0") _LATEX_AVAILABLE: bool = shutil.which("latex") is not None diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 12b2ebf338c..4c88c078050 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -326,7 +326,7 @@ def plot_curve( If `curve` does not have 3 elements, being in the wrong format """ if len(curve) < 2: - raise ValueError("Expected 2 or 3 elements in curve but got {len(curve)}") + raise ValueError(f"Expected 2 or 3 elements in curve but got {len(curve)}") x, y = curve[:2] _error_on_missing_matplotlib() diff --git a/src/torchmetrics/wrappers/bootstrapping.py b/src/torchmetrics/wrappers/bootstrapping.py index 9b904568b65..d59f7724c2a 100644 --- a/src/torchmetrics/wrappers/bootstrapping.py +++ b/src/torchmetrics/wrappers/bootstrapping.py @@ -75,15 +75,15 @@ class basically keeps multiple copies of the same base metric in memory and when Example:: >>> from pprint import pprint + >>> from torch import randint >>> from torchmetrics.wrappers import BootStrapper >>> from torchmetrics.classification import MulticlassAccuracy - >>> _ = torch.manual_seed(123) >>> base_metric = MulticlassAccuracy(num_classes=5, average='micro') >>> bootstrap = BootStrapper(base_metric, num_bootstraps=20) - >>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,))) + >>> bootstrap.update(randint(5, (20,)), randint(5, (20,))) >>> output = bootstrap.compute() >>> pprint(output) - {'mean': tensor(0.2205), 'std': tensor(0.0859)} + {'mean': tensor(0.2089), 'std': tensor(0.0772)} """ diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 78cb27ae46c..217c94d6bc0 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -44,58 +44,54 @@ class ClasswiseWrapper(WrapperMetric): Example:: Basic example where the output of a metric is unwrapped into a dictionary with the class index as keys: - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import randint, randn >>> from torchmetrics.wrappers import ClasswiseWrapper >>> from torchmetrics.classification import MulticlassAccuracy >>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None)) - >>> preds = torch.randn(10, 3).softmax(dim=-1) - >>> target = torch.randint(3, (10,)) + >>> preds = randn(10, 3).softmax(dim=-1) + >>> target = randint(3, (10,)) >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE {'multiclassaccuracy_0': tensor(0.5000), - 'multiclassaccuracy_1': tensor(0.7500), - 'multiclassaccuracy_2': tensor(0.)} + 'multiclassaccuracy_1': tensor(0.7500), + 'multiclassaccuracy_2': tensor(0.)} Example:: Using custom name via prefix and postfix: - >>> import torch - >>> _ = torch.manual_seed(42) + >>> from torch import randint, randn >>> from torchmetrics.wrappers import ClasswiseWrapper >>> from torchmetrics.classification import MulticlassAccuracy >>> metric_pre = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="acc-") >>> metric_post = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), postfix="-acc") - >>> preds = torch.randn(10, 3).softmax(dim=-1) - >>> target = torch.randint(3, (10,)) + >>> preds = randn(10, 3).softmax(dim=-1) + >>> target = randint(3, (10,)) >>> metric_pre(preds, target) # doctest: +NORMALIZE_WHITESPACE - {'acc-0': tensor(0.5000), - 'acc-1': tensor(0.7500), - 'acc-2': tensor(0.)} + {'acc-0': tensor(0.3333), 'acc-1': tensor(0.6667), 'acc-2': tensor(0.)} >>> metric_post(preds, target) # doctest: +NORMALIZE_WHITESPACE - {'0-acc': tensor(0.5000), - '1-acc': tensor(0.7500), - '2-acc': tensor(0.)} + {'0-acc': tensor(0.3333), '1-acc': tensor(0.6667), '2-acc': tensor(0.)} Example:: Providing labels as a list of strings: + >>> from torch import randint, randn >>> from torchmetrics.wrappers import ClasswiseWrapper >>> from torchmetrics.classification import MulticlassAccuracy >>> metric = ClasswiseWrapper( ... MulticlassAccuracy(num_classes=3, average=None), ... labels=["horse", "fish", "dog"] ... ) - >>> preds = torch.randn(10, 3).softmax(dim=-1) - >>> target = torch.randint(3, (10,)) + >>> preds = randn(10, 3).softmax(dim=-1) + >>> target = randint(3, (10,)) >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE - {'multiclassaccuracy_horse': tensor(0.3333), - 'multiclassaccuracy_fish': tensor(0.6667), - 'multiclassaccuracy_dog': tensor(0.)} + {'multiclassaccuracy_horse': tensor(0.), + 'multiclassaccuracy_fish': tensor(0.3333), + 'multiclassaccuracy_dog': tensor(0.4000)} Example:: Classwise can also be used in combination with :class:`~torchmetrics.MetricCollection`. In this case, everything will be flattened into a single dictionary: + >>> from torch import randint, randn >>> from torchmetrics import MetricCollection >>> from torchmetrics.wrappers import ClasswiseWrapper >>> from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall @@ -104,15 +100,15 @@ class ClasswiseWrapper(WrapperMetric): ... {'multiclassaccuracy': ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), labels), ... 'multiclassrecall': ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), labels)} ... ) - >>> preds = torch.randn(10, 3).softmax(dim=-1) - >>> target = torch.randint(3, (10,)) + >>> preds = randn(10, 3).softmax(dim=-1) + >>> target = randint(3, (10,)) >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE - {'multiclassaccuracy_horse': tensor(0.), + {'multiclassaccuracy_horse': tensor(0.6667), 'multiclassaccuracy_fish': tensor(0.3333), - 'multiclassaccuracy_dog': tensor(0.4000), - 'multiclassrecall_horse': tensor(0.), + 'multiclassaccuracy_dog': tensor(0.5000), + 'multiclassrecall_horse': tensor(0.6667), 'multiclassrecall_fish': tensor(0.3333), - 'multiclassrecall_dog': tensor(0.4000)} + 'multiclassrecall_dog': tensor(0.5000)} """ diff --git a/src/torchmetrics/wrappers/feature_share.py b/src/torchmetrics/wrappers/feature_share.py index b1fe0451534..1bd1b81783b 100644 --- a/src/torchmetrics/wrappers/feature_share.py +++ b/src/torchmetrics/wrappers/feature_share.py @@ -68,7 +68,6 @@ class FeatureShare(MetricCollection): Example:: >>> import torch - >>> _ = torch.manual_seed(42) >>> from torchmetrics.wrappers import FeatureShare >>> from torchmetrics.image import FrechetInceptionDistance, KernelInceptionDistance >>> # initialize the metrics diff --git a/src/torchmetrics/wrappers/multitask.py b/src/torchmetrics/wrappers/multitask.py index 556a7183638..fa2f04db97d 100644 --- a/src/torchmetrics/wrappers/multitask.py +++ b/src/torchmetrics/wrappers/multitask.py @@ -38,12 +38,27 @@ class MultitaskWrapper(WrapperMetric): task_metrics: Dictionary associating each task to a Metric or a MetricCollection. The keys of the dictionary represent the names of the tasks, and the values represent the metrics to use for each task. + prefix: + A string to append in front of the metric keys. If not provided, will default to an empty string. + postfix: + A string to append after the keys of the output dict. If not provided, will default to an empty string. + + .. note:: + The use pre prefix and postfix allows for easily creating task wrappers for training, validation and test. + The arguments are only changing the output keys of the computed metrics and not the input keys. This means + that a ``MultitaskWrapper`` initialized as ``MultitaskWrapper({"task": Metric()}, prefix="train_")`` will + still expect the input to be a dictionary with the key "task", but the output will be a dictionary with the key + "train_task". Raises: TypeError: If argument `task_metrics` is not an dictionary TypeError: If not all values in the `task_metrics` dictionary is instances of `Metric` or `MetricCollection` + ValueError: + If `prefix` is not a string + ValueError: + If `postfix` is not a string Example (with a single metric per class): >>> import torch @@ -91,18 +106,59 @@ class MultitaskWrapper(WrapperMetric): {'Classification': {'BinaryAccuracy': tensor(0.3333), 'BinaryF1Score': tensor(0.)}, 'Regression': {'MeanSquaredError': tensor(0.8333), 'MeanAbsoluteError': tensor(0.6667)}} + Example (with a prefix and postfix): + >>> import torch + >>> from torchmetrics.wrappers import MultitaskWrapper + >>> from torchmetrics.regression import MeanSquaredError + >>> from torchmetrics.classification import BinaryAccuracy + >>> + >>> classification_target = torch.tensor([0, 1, 0]) + >>> regression_target = torch.tensor([2.5, 5.0, 4.0]) + >>> targets = {"Classification": classification_target, "Regression": regression_target} + >>> classification_preds = torch.tensor([0, 0, 1]) + >>> regression_preds = torch.tensor([3.0, 5.0, 2.5]) + >>> preds = {"Classification": classification_preds, "Regression": regression_preds} + >>> + >>> metrics = MultitaskWrapper({ + ... "Classification": BinaryAccuracy(), + ... "Regression": MeanSquaredError() + ... }, prefix="train_") + >>> metrics.update(preds, targets) + >>> metrics.compute() + {'train_Classification': tensor(0.3333), 'train_Regression': tensor(0.8333)} + """ - is_differentiable = False + is_differentiable: bool = False def __init__( self, task_metrics: Dict[str, Union[Metric, MetricCollection]], + prefix: Optional[str] = None, + postfix: Optional[str] = None, ) -> None: - self._check_task_metrics_type(task_metrics) super().__init__() + + if not isinstance(task_metrics, dict): + raise TypeError(f"Expected argument `task_metrics` to be a dict. Found task_metrics = {task_metrics}") + + for metric in task_metrics.values(): + if not (isinstance(metric, (Metric, MetricCollection))): + raise TypeError( + "Expected each task's metric to be a Metric or a MetricCollection. " + f"Found a metric of type {type(metric)}" + ) + self.task_metrics = nn.ModuleDict(task_metrics) + if prefix is not None and not isinstance(prefix, str): + raise ValueError(f"Expected argument `prefix` to either be `None` or a string but got {prefix}") + self._prefix = prefix or "" + + if postfix is not None and not isinstance(postfix, str): + raise ValueError(f"Expected argument `postfix` to either be `None` or a string but got {postfix}") + self._postfix = postfix or "" + def items(self, flatten: bool = True) -> Iterable[Tuple[str, nn.Module]]: """Iterate over task and task metrics. @@ -114,9 +170,9 @@ def items(self, flatten: bool = True) -> Iterable[Tuple[str, nn.Module]]: for task_name, metric in self.task_metrics.items(): if flatten and isinstance(metric, MetricCollection): for sub_metric_name, sub_metric in metric.items(): - yield f"{task_name}_{sub_metric_name}", sub_metric + yield f"{self._prefix}{task_name}_{sub_metric_name}{self._postfix}", sub_metric else: - yield task_name, metric + yield f"{self._prefix}{task_name}{self._postfix}", metric def keys(self, flatten: bool = True) -> Iterable[str]: """Iterate over task names. @@ -129,9 +185,9 @@ def keys(self, flatten: bool = True) -> Iterable[str]: for task_name, metric in self.task_metrics.items(): if flatten and isinstance(metric, MetricCollection): for sub_metric_name in metric: - yield f"{task_name}_{sub_metric_name}" + yield f"{self._prefix}{task_name}_{sub_metric_name}{self._postfix}" else: - yield task_name + yield f"{self._prefix}{task_name}{self._postfix}" def values(self, flatten: bool = True) -> Iterable[nn.Module]: """Iterate over task metrics. @@ -147,18 +203,6 @@ def values(self, flatten: bool = True) -> Iterable[nn.Module]: else: yield metric - @staticmethod - def _check_task_metrics_type(task_metrics: Dict[str, Union[Metric, MetricCollection]]) -> None: - if not isinstance(task_metrics, dict): - raise TypeError(f"Expected argument `task_metrics` to be a dict. Found task_metrics = {task_metrics}") - - for metric in task_metrics.values(): - if not (isinstance(metric, (Metric, MetricCollection))): - raise TypeError( - "Expected each task's metric to be a Metric or a MetricCollection. " - f"Found a metric of type {type(metric)}" - ) - def update(self, task_preds: Dict[str, Any], task_targets: Dict[str, Any]) -> None: """Update each task's metric with its corresponding pred and target. @@ -179,9 +223,13 @@ def update(self, task_preds: Dict[str, Any], task_targets: Dict[str, Any]) -> No target = task_targets[task_name] metric.update(pred, target) + def _convert_output(self, output: Dict[str, Any]) -> Dict[str, Any]: + """Convert the output of the underlying metrics to a dictionary with the task names as keys.""" + return {f"{self._prefix}{task_name}{self._postfix}": task_output for task_name, task_output in output.items()} + def compute(self) -> Dict[str, Any]: """Compute metrics for all tasks.""" - return {task_name: metric.compute() for task_name, metric in self.task_metrics.items()} + return self._convert_output({task_name: metric.compute() for task_name, metric in self.task_metrics.items()}) def forward(self, task_preds: Dict[str, Tensor], task_targets: Dict[str, Tensor]) -> Dict[str, Any]: """Call underlying forward methods for all tasks and return the result as a dictionary.""" @@ -189,10 +237,10 @@ def forward(self, task_preds: Dict[str, Tensor], task_targets: Dict[str, Tensor] # value of full_state_update, and that also accumulates the results. Here, all computations are handled by the # underlying metrics, which all have their own value of full_state_update, and which all accumulate the results # by themselves. - return { + return self._convert_output({ task_name: metric(task_preds[task_name], task_targets[task_name]) for task_name, metric in self.task_metrics.items() - } + }) def reset(self) -> None: """Reset all underlying metrics.""" @@ -215,16 +263,8 @@ def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> """ multitask_copy = deepcopy(self) - if prefix is not None: - prefix = self._check_arg(prefix, "prefix") - multitask_copy.task_metrics = nn.ModuleDict({ - prefix + key: value for key, value in multitask_copy.task_metrics.items() - }) - if postfix is not None: - postfix = self._check_arg(postfix, "postfix") - multitask_copy.task_metrics = nn.ModuleDict({ - key + postfix: value for key, value in multitask_copy.task_metrics.items() - }) + multitask_copy._prefix = self._check_arg(prefix, "prefix") or "" + multitask_copy._postfix = self._check_arg(postfix, "prefix") or "" return multitask_copy def plot( diff --git a/src/torchmetrics/wrappers/tracker.py b/src/torchmetrics/wrappers/tracker.py index e898813f7b4..554902d1092 100644 --- a/src/torchmetrics/wrappers/tracker.py +++ b/src/torchmetrics/wrappers/tracker.py @@ -52,15 +52,14 @@ class MetricTracker(ModuleList): better (``True``) or lower is better (``False``). Example (single metric): + >>> from torch import randint >>> from torchmetrics.wrappers import MetricTracker >>> from torchmetrics.classification import MulticlassAccuracy - >>> _ = torch.manual_seed(42) >>> tracker = MetricTracker(MulticlassAccuracy(num_classes=10, average='micro')) >>> for epoch in range(5): ... tracker.increment() ... for batch_idx in range(5): - ... preds, target = torch.randint(10, (100,)), torch.randint(10, (100,)) - ... tracker.update(preds, target) + ... tracker.update(randint(10, (100,)), randint(10, (100,))) ... print(f"current acc={tracker.compute()}") current acc=0.1120000034570694 current acc=0.08799999952316284 @@ -76,36 +75,39 @@ class MetricTracker(ModuleList): tensor([0.1120, 0.0880, 0.1260, 0.0800, 0.1020]) Example (multiple metrics using MetricCollection): + >>> from torch import randn >>> from torchmetrics.wrappers import MetricTracker >>> from torchmetrics import MetricCollection >>> from torchmetrics.regression import MeanSquaredError, ExplainedVariance - >>> _ = torch.manual_seed(42) >>> tracker = MetricTracker(MetricCollection([MeanSquaredError(), ExplainedVariance()]), maximize=[False, True]) >>> for epoch in range(5): ... tracker.increment() ... for batch_idx in range(5): - ... preds, target = torch.randn(100), torch.randn(100) - ... tracker.update(preds, target) + ... tracker.update(randn(100), randn(100)) ... print(f"current stats={tracker.compute()}") # doctest: +NORMALIZE_WHITESPACE - current stats={'MeanSquaredError': tensor(1.8218), 'ExplainedVariance': tensor(-0.8969)} - current stats={'MeanSquaredError': tensor(2.0268), 'ExplainedVariance': tensor(-1.0206)} - current stats={'MeanSquaredError': tensor(1.9491), 'ExplainedVariance': tensor(-0.8298)} - current stats={'MeanSquaredError': tensor(1.9800), 'ExplainedVariance': tensor(-0.9199)} - current stats={'MeanSquaredError': tensor(2.2481), 'ExplainedVariance': tensor(-1.1622)} + current stats={'MeanSquaredError': tensor(2.3292), 'ExplainedVariance': tensor(-0.9516)} + current stats={'MeanSquaredError': tensor(2.1370), 'ExplainedVariance': tensor(-1.0775)} + current stats={'MeanSquaredError': tensor(2.1695), 'ExplainedVariance': tensor(-0.9945)} + current stats={'MeanSquaredError': tensor(2.1072), 'ExplainedVariance': tensor(-1.1878)} + current stats={'MeanSquaredError': tensor(2.0562), 'ExplainedVariance': tensor(-1.0754)} >>> from pprint import pprint >>> best_res, which_epoch = tracker.best_metric(return_step=True) >>> pprint(best_res) # doctest: +ELLIPSIS - {'ExplainedVariance': -0.829..., - 'MeanSquaredError': 1.821...} + {'ExplainedVariance': -0.951..., + 'MeanSquaredError': 2.056...} >>> which_epoch - {'MeanSquaredError': 0, 'ExplainedVariance': 2} + {'MeanSquaredError': 4, 'ExplainedVariance': 0} >>> pprint(tracker.compute_all()) - {'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622]), - 'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481])} + {'ExplainedVariance': tensor([-0.9516, -1.0775, -0.9945, -1.1878, -1.0754]), + 'MeanSquaredError': tensor([2.3292, 2.1370, 2.1695, 2.1072, 2.0562])} """ - def __init__(self, metric: Union[Metric, MetricCollection], maximize: Union[bool, List[bool]] = True) -> None: + maximize: Union[bool, List[bool]] + + def __init__( + self, metric: Union[Metric, MetricCollection], maximize: Optional[Union[bool, List[bool]]] = True + ) -> None: super().__init__() if not isinstance(metric, (Metric, MetricCollection)): raise TypeError( @@ -113,13 +115,42 @@ def __init__(self, metric: Union[Metric, MetricCollection], maximize: Union[bool f" `Metric` or `MetricCollection` but got {metric}" ) self._base_metric = metric - if not isinstance(maximize, (bool, list)): - raise ValueError("Argument `maximize` should either be a single bool or list of bool") - if isinstance(maximize, list) and isinstance(metric, MetricCollection) and len(maximize) != len(metric): - raise ValueError("The len of argument `maximize` should match the length of the metric collection") - if isinstance(metric, Metric) and not isinstance(maximize, bool): - raise ValueError("Argument `maximize` should be a single bool when `metric` is a single Metric") - self.maximize = maximize + + if maximize is None: + if isinstance(metric, Metric): + if getattr(metric, "higher_is_better", None) is None: + raise AttributeError( + f"The metric '{metric.__class__.__name__}' does not have a 'higher_is_better' attribute." + " Please provide the `maximize` argument explicitly." + ) + self.maximize = metric.higher_is_better # type: ignore[assignment] # this is false alarm + elif isinstance(metric, MetricCollection): + self.maximize = [] + for name, m in metric.items(): + if getattr(m, "higher_is_better", None) is None: + raise AttributeError( + f"The metric '{name}' in the MetricCollection does not have a 'higher_is_better' attribute." + " Please provide the `maximize` argument explicitly." + ) + self.maximize.append(m.higher_is_better) # type: ignore[arg-type] # this is false alarm + else: + rank_zero_warn( + "The default value for `maximize` will be changed from `True` to `None` in v1.7.0 of TorchMetrics," + "will automatically infer the value based on the `higher_is_better` attribute of the metric" + " (if such attribute exists) or raise an error if it does not. If you are explicitly setting the" + " `maximize` argument to either `True` or `False` already, you can ignore this warning.", + FutureWarning, + ) + + if not isinstance(maximize, (bool, list)): + raise ValueError("Argument `maximize` should either be a single bool or list of bool") + if isinstance(maximize, list) and not all(isinstance(m, bool) for m in maximize): + raise ValueError("Argument `maximize` is list but not type of bool.") + if isinstance(maximize, list) and isinstance(metric, MetricCollection) and len(maximize) != len(metric): + raise ValueError("The len of argument `maximize` should match the length of the metric collection") + if isinstance(metric, Metric) and not isinstance(maximize, bool): + raise ValueError("Argument `maximize` should be a single bool when `metric` is a single Metric") + self.maximize = maximize self._increment_called = False diff --git a/tests/README.md b/tests/README.md index 7f5cbd4e98a..6fce25567ef 100644 --- a/tests/README.md +++ b/tests/README.md @@ -7,16 +7,16 @@ the following command in the root directory of the project: pip install . -r requirements/_devel.txt ``` -Then for windows users, to execute the tests (unit tests and integration tests) run the following command (will only run non-DDP tests): +Then for Windows users, to execute the tests (unit tests and integration tests) run the following command (will only run non-DDP tests): ```bash pytest tests/ ``` -For linux/Mac users you will need to provide the `-m` argument to indicate if `ddp` tests should also be executed: +For Linux/Mac users you will need to provide the `-m` argument to indicate if `ddp` tests should also be executed: ```bash -pytest -m DDP tests/ # to run only DDP tests +USE_PYTEST_POOL="1" pytest -m DDP tests/ # to run only DDP tests pytest -m "not DDP" tests/ # to run all tests except DDP tests ``` diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_lightning.py index fec59ea22b0..05799b2711d 100644 --- a/tests/integrations/test_lightning.py +++ b/tests/integrations/test_lightning.py @@ -27,10 +27,10 @@ from torchmetrics import MetricCollection from torchmetrics.aggregation import SumMetric -from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision +from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision, MulticlassAccuracy from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError from torchmetrics.utilities.prints import rank_zero_only -from torchmetrics.wrappers import MultitaskWrapper +from torchmetrics.wrappers import ClasswiseWrapper, MinMaxMetric, MultitaskWrapper from integrations.lightning.boring_model import BoringModel @@ -504,3 +504,118 @@ def configure_optimizers(self): model = model.type(torch.half) assert model.metric.sum_value.dtype == torch.float32 + + +def test_collection_classwise_lightning_integration(tmpdir): + """Check the integration of ClasswiseWrapper, MetricCollection and LightningModule. + + See issue: https://github.com/Lightning-AI/torchmetrics/issues/2683 + + """ + + class TestModel(BoringModel): + def __init__(self) -> None: + super().__init__() + self.train_metrics = MetricCollection( + { + "macro_accuracy": MulticlassAccuracy(num_classes=5, average="macro"), + "classwise_accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=5, average=None)), + }, + prefix="train_", + ) + self.val_metrics = self.train_metrics.clone(prefix="val_") + + def training_step(self, batch, batch_idx): + loss = self(batch).sum() + preds = torch.randint(0, 5, (100,), device=batch.device) + target = torch.randint(0, 5, (100,), device=batch.device) + + batch_values = self.train_metrics(preds, target) + self.log_dict(batch_values, on_step=True, on_epoch=False) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + preds = torch.randint(0, 5, (100,), device=batch.device) + target = torch.randint(0, 5, (100,), device=batch.device) + self.val_metrics.update(preds, target) + + def on_validation_epoch_end(self): + self.log_dict(self.val_metrics.compute(), on_step=False, on_epoch=True) + self.val_metrics.reset() + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + ) + trainer.fit(model) + + logged = trainer.logged_metrics + + # check that all metrics are logged + assert "train_macro_accuracy" in logged + assert "val_macro_accuracy" in logged + for i in range(5): + assert f"train_multiclassaccuracy_{i}" in logged + assert f"val_multiclassaccuracy_{i}" in logged + + +def test_collection_minmax_lightning_integration(tmpdir): + """Check the integration of MinMaxWrapper, MetricCollection and LightningModule. + + See issue: https://github.com/Lightning-AI/torchmetrics/issues/2763 + + """ + + class TestModel(BoringModel): + def __init__(self) -> None: + super().__init__() + self.train_metrics = MetricCollection( + { + "macro_accuracy": MinMaxMetric(MulticlassAccuracy(num_classes=5, average="macro")), + "weighted_accuracy": MinMaxMetric(MulticlassAccuracy(num_classes=5, average="weighted")), + }, + prefix="train_", + ) + self.val_metrics = self.train_metrics.clone(prefix="val_") + + def training_step(self, batch, batch_idx): + loss = self(batch).sum() + preds = torch.randint(0, 5, (100,), device=batch.device) + target = torch.randint(0, 5, (100,), device=batch.device) + + batch_values = self.train_metrics(preds, target) + self.log_dict(batch_values, on_step=True, on_epoch=False) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + preds = torch.randint(0, 5, (100,), device=batch.device) + target = torch.randint(0, 5, (100,), device=batch.device) + self.val_metrics.update(preds, target) + + def on_validation_epoch_end(self): + self.log_dict(self.val_metrics.compute(), on_step=False, on_epoch=True) + self.val_metrics.reset() + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + ) + trainer.fit(model) + + logged = trainer.logged_metrics + + # check that all metrics are logged + for prefix in ["train_", "val_"]: + for metric in ["macro_accuracy", "weighted_accuracy"]: + for key in ["max", "min", "raw"]: + assert f"{prefix}{metric}_{key}" in logged diff --git a/tests/unittests/_helpers/testers.py b/tests/unittests/_helpers/testers.py index deb4c12324e..1b46d6f237f 100644 --- a/tests/unittests/_helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -32,7 +32,12 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O """Recursively assert that two results are within a certain tolerance.""" # single output compare if isinstance(tm_result, Tensor): - assert np.allclose(tm_result.detach().cpu().numpy(), ref_result, atol=atol, equal_nan=True) + assert np.allclose( + tm_result.detach().cpu().numpy() if isinstance(tm_result, Tensor) else tm_result, + ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result, + atol=atol, + equal_nan=True, + ), f"tm_result: {tm_result}, ref_result: {ref_result}" # multi output compare elif isinstance(tm_result, Sequence): for pl_res, ref_res in zip(tm_result, ref_result): @@ -40,7 +45,12 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O elif isinstance(tm_result, Dict): if key is None: raise KeyError("Provide Key for Dict based metric results.") - assert np.allclose(tm_result[key].detach().cpu().numpy(), ref_result, atol=atol, equal_nan=True) + assert np.allclose( + tm_result[key].detach().cpu().numpy() if isinstance(tm_result[key], Tensor) else tm_result[key], + ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result, + atol=atol, + equal_nan=True, + ), f"tm_result: {tm_result}, ref_result: {ref_result}" else: raise ValueError("Unknown format for comparison") @@ -147,6 +157,7 @@ def _class_test( # verify metrics work after being loaded from pickled state pickled_metric = pickle.dumps(metric) metric = pickle.loads(pickled_metric) + metric_clone = deepcopy(metric) for i in range(rank, num_batches, world_size): batch_kwargs_update = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()} @@ -154,6 +165,16 @@ def _class_test( # compute batch stats and aggregate for global stats batch_result = metric(preds[i], target[i], **batch_kwargs_update) + if rank == 0 and world_size == 1 and i == 0: # check only in non-ddp mode and first batch + # dummy check to make sure that forward/update works as expected + metric_clone.update(preds[i], target[i], **batch_kwargs_update) + update_result = metric_clone.compute() + if isinstance(batch_result, dict): + for key in batch_result: + _assert_allclose(batch_result, update_result[key], key=key) + else: + _assert_allclose(batch_result, update_result) + if metric.dist_sync_on_step and check_dist_sync_on_step and rank == 0: if isinstance(preds, Tensor): ddp_preds = torch.cat([preds[i + r] for r in range(world_size)]).cpu() diff --git a/tests/unittests/audio/test_sdr.py b/tests/unittests/audio/test_sdr.py index 61257588606..8d5a8c7ab8f 100644 --- a/tests/unittests/audio/test_sdr.py +++ b/tests/unittests/audio/test_sdr.py @@ -21,7 +21,6 @@ from torch import Tensor from torchmetrics.audio import SignalDistortionRatio from torchmetrics.functional import signal_distortion_ratio -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_11 from unittests import _Input from unittests._helpers import seed_all @@ -61,9 +60,6 @@ def _reference_sdr_batch( return sdr -@pytest.mark.skipif( # FIXME: figure out why tests leads to cuda errors on latest torch - _TORCH_GREATER_EQUAL_1_11 and torch.cuda.is_available(), reason="tests leads to cuda errors on latest torch" -) @pytest.mark.parametrize( "preds, target", [(inputs_1spk.preds, inputs_1spk.target), (inputs_2spk.preds, inputs_2spk.target)], @@ -137,12 +133,8 @@ def test_on_real_audio(): """Test that metric works on real audio signal.""" _, ref = wavfile.read(_SAMPLE_AUDIO_SPEECH) _, deg = wavfile.read(_SAMPLE_AUDIO_SPEECH_BAB_DB) - assert torch.allclose( - signal_distortion_ratio(torch.from_numpy(deg), torch.from_numpy(ref)).float(), - torch.tensor(0.2211), - rtol=0.0001, - atol=1e-4, - ) + sdr = signal_distortion_ratio(torch.from_numpy(deg), torch.from_numpy(ref)) + assert torch.allclose(sdr.float(), torch.tensor(0.2211), rtol=0.0001, atol=1e-4) def test_too_low_precision(): diff --git a/tests/unittests/audio/test_srmr.py b/tests/unittests/audio/test_srmr.py index 3b1b07862ce..e7370546478 100644 --- a/tests/unittests/audio/test_srmr.py +++ b/tests/unittests/audio/test_srmr.py @@ -20,7 +20,6 @@ from torch import Tensor from torchmetrics.audio.srmr import SpeechReverberationModulationEnergyRatio from torchmetrics.functional.audio.srmr import speech_reverberation_modulation_energy_ratio -from torchmetrics.utilities.imports import _TORCHAUDIO_GREATER_EQUAL_0_10 from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester @@ -63,8 +62,6 @@ def update(self, preds: Tensor, target: Tensor) -> None: super().update(preds=preds) -# FIXME: bring compatibility with torchaudio 0.10+ -@pytest.mark.skipif(not _TORCHAUDIO_GREATER_EQUAL_0_10, reason="torchaudio>=0.10.0 is required") @pytest.mark.parametrize( "preds, fs, fast, norm", [ diff --git a/tests/unittests/audio/test_stoi.py b/tests/unittests/audio/test_stoi.py index 54374098779..d7998aaf8b2 100644 --- a/tests/unittests/audio/test_stoi.py +++ b/tests/unittests/audio/test_stoi.py @@ -124,9 +124,5 @@ def test_on_real_audio(): """Test that metric works on real audio signal.""" rate, ref = wavfile.read(_SAMPLE_AUDIO_SPEECH) rate, deg = wavfile.read(_SAMPLE_AUDIO_SPEECH_BAB_DB) - assert torch.allclose( - short_time_objective_intelligibility(torch.from_numpy(deg), torch.from_numpy(ref), rate).float(), - torch.tensor(0.6739177), - rtol=0.0001, - atol=1e-4, - ) + stoi = short_time_objective_intelligibility(torch.from_numpy(deg), torch.from_numpy(ref), rate) + assert torch.allclose(stoi.float(), torch.tensor(0.6739177), rtol=1e-2, atol=5e-3) diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 55062ccbe29..77b333ce66a 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -33,7 +33,6 @@ MultilabelAveragePrecision, ) from torchmetrics.utilities.checks import _allclose_recursive -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_0 from unittests._helpers import seed_all from unittests._helpers.testers import DummyMetricDiff, DummyMetricMultiOutputDict, DummyMetricSum @@ -151,7 +150,6 @@ def test_metric_collection_args_kwargs(tmpdir): assert metric_collection["DummyMetricDiff"].x == -20 -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_0, reason="Test requires torch 2.0 or higher") @pytest.mark.parametrize( ("prefix", "postfix"), [ diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index db497cdb197..30d4a473a84 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -335,6 +335,77 @@ def test_multiclass_accuracy_half_gpu(self, inputs, dtype): dtype=dtype, ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + @pytest.mark.parametrize( + ("average", "use_deterministic_algorithms"), + [ + (None, True), # Defaults to "macro", but explicitly included for testing omission + # average=`macro` stays on GPU when `use_deterministic` is True. Otherwise syncs in `bincount` + ("macro", True), + ("micro", False), + ("micro", True), + ("weighted", True), + ], + ) + def test_multiclass_accuracy_gpu_sync_points( + self, inputs, dtype: torch.dtype, average: str, use_deterministic_algorithms: bool + ): + """Test GPU support of the metric, avoiding CPU sync points.""" + preds, target = inputs + + # Wrap the default functional to attach `sync_debug_mode` as `run_precision_test_gpu` handles moving data + # onto the GPU, so we cannot set the debug mode outside the call + def wrapped_multiclass_accuracy( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + ) -> torch.Tensor: + prev_sync_debug_mode = torch.cuda.get_sync_debug_mode() + torch.cuda.set_sync_debug_mode("error") + try: + validate_args = False # `validate_args` will require CPU sync for exceptions + # average = average #'micro' # default is `macro` which uses a `_bincount` that does a CPU sync + torch.use_deterministic_algorithms(mode=use_deterministic_algorithms) + return multiclass_accuracy(preds, target, num_classes, validate_args=validate_args, average=average) + finally: + torch.cuda.set_sync_debug_mode(prev_sync_debug_mode) + + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassAccuracy, + metric_functional=wrapped_multiclass_accuracy, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + @pytest.mark.parametrize( + ("average", "use_deterministic_algorithms"), + [ + # If you remove from this collection, please add items to `test_multiclass_accuracy_gpu_sync_points` + (None, False), + ("macro", False), + ("weighted", False), + ], + ) + def test_multiclass_accuracy_gpu_sync_points_uptodate( + self, inputs, dtype: torch.dtype, average: str, use_deterministic_algorithms: bool + ): + """Negative test for `test_multiclass_accuracy_gpu_sync_points`, to confirm completeness. + + Tests that `test_multiclass_accuracy_gpu_sync_points` is kept up to date, explicitly validating that known + failures still fail, so that if they're fixed they must be added to + `test_multiclass_accuracy_gpu_sync_points`. + + """ + with pytest.raises(RuntimeError, match="called a synchronizing CUDA operation"): + self.test_multiclass_accuracy_gpu_sync_points( + inputs=inputs, dtype=dtype, average=average, use_deterministic_algorithms=use_deterministic_algorithms + ) + _mc_k_target = torch.tensor([0, 1, 2]) _mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index a2000cc984e..86cfded8246 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -29,7 +29,6 @@ multiclass_calibration_error, ) from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_13 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -113,8 +112,6 @@ def test_binary_calibration_error_differentiability(self, inputs): def test_binary_calibration_error_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_13: - pytest.xfail(reason="torch.linspace in metric not supported before pytorch v1.13 for cpu + half") if (preds < 0).any() and dtype == torch.half: pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") self.run_precision_test_cpu( @@ -129,8 +126,6 @@ def test_binary_calibration_error_dtype_cpu(self, inputs, dtype): @pytest.mark.parametrize("dtype", [torch.half, torch.double]) def test_binary_calibration_error_dtype_gpu(self, inputs, dtype): """Test dtype support of the metric on GPU.""" - if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_13: - pytest.xfail(reason="torch.searchsorted in metric not supported before pytorch v1.13 for gpu + half") preds, target = inputs self.run_precision_test_gpu( preds=preds, diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index 6901868eac9..e7afdb557a6 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -26,6 +26,7 @@ MultilabelJaccardIndex, ) from torchmetrics.functional.classification.jaccard import ( + _jaccard_index_reduce, binary_jaccard_index, multiclass_jaccard_index, multilabel_jaccard_index, @@ -403,6 +404,26 @@ def test_corner_case(): assert torch.allclose(res, out) +def test_jaccard_index_zero_division(): + """Issue: https://github.com/Lightning-AI/torchmetrics/issues/2658.""" + # Test case where all pixels are background (zeros) + confmat = torch.tensor([[4, 0], [0, 0]]) + + # Test with zero_division=0.0 + result = _jaccard_index_reduce(confmat, average="binary", zero_division=0.0) + assert result == 0.0, f"Expected 0.0, but got {result}" + + # Test with zero_division=1.0 + result = _jaccard_index_reduce(confmat, average="binary", zero_division=1.0) + assert result == 1.0, f"Expected 1.0, but got {result}" + + # Test case with some foreground pixels + confmat = torch.tensor([[2, 1], [1, 1]]) + result = _jaccard_index_reduce(confmat, average="binary", zero_division=0.0) + expected = 1 / 3 + assert torch.isclose(result, torch.tensor(expected)), f"Expected {expected}, but got {result}" + + @pytest.mark.parametrize( ("metric", "kwargs"), [ diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index 03f649bc0ac..2f881604d09 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -331,6 +331,12 @@ def test_zero_case_in_multiclass(): torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]), 0.0, ), + ( + binary_matthews_corrcoef, + torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]), + torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + 0.0, + ), (binary_matthews_corrcoef, torch.zeros(10), torch.ones(10), -1.0), (binary_matthews_corrcoef, torch.ones(10), torch.zeros(10), -1.0), ( diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py index df01b67e4ab..e0689859dd4 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -33,7 +33,7 @@ multilabel_sensitivity_at_specificity, ) from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _SKLEARN_GREATER_EQUAL_1_3, _TORCH_GREATER_EQUAL_1_11 +from torchmetrics.utilities.imports import _SKLEARN_GREATER_EQUAL_1_3 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -84,7 +84,6 @@ def _reference_sklearn_sensitivity_at_specificity_binary(preds, target, min_spec @pytest.mark.skipif(not _SKLEARN_GREATER_EQUAL_1_3, reason="metric does not support scikit-learn versions below 1.3") -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11") @pytest.mark.parametrize("inputs", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) class TestBinarySensitivityAtSpecificity(MetricTester): """Test class for `BinarySensitivityAtSpecificity` metric.""" @@ -211,7 +210,6 @@ def _reference_sklearn_sensitivity_at_specificity_multiclass(preds, target, min_ @pytest.mark.skipif(not _SKLEARN_GREATER_EQUAL_1_3, reason="metric does not support scikit-learn versions below 1.3") -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11") @pytest.mark.parametrize( "inputs", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) ) @@ -343,7 +341,6 @@ def _reference_sklearn_sensitivity_at_specificity_multilabel(preds, target, min_ @pytest.mark.skipif(not _SKLEARN_GREATER_EQUAL_1_3, reason="metric does not support scikit-learn versions below 1.3") -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11") @pytest.mark.parametrize( "inputs", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5]) ) @@ -463,7 +460,6 @@ def test_multilabel_sensitivity_at_specificity_threshold_arg(self, inputs, min_s assert all(torch.allclose(r1[i], r2[i]) for i in range(len(r1))) -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11") @pytest.mark.parametrize( "metric", [ @@ -479,7 +475,6 @@ def test_valid_input_thresholds(recwarn, metric, thresholds): assert len(recwarn) == 0, "Warning was raised when it should not have been." -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11") @pytest.mark.parametrize( ("metric", "kwargs"), [ diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 3cf7090a632..5ea4c206bc0 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from functools import partial import numpy as np @@ -581,8 +582,8 @@ def test_support_for_int(): """See issue: https://github.com/Lightning-AI/torchmetrics/issues/1970.""" seed_all(42) metric = MulticlassStatScores(num_classes=4, average="none", multidim_average="samplewise", ignore_index=0) - prediction = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8) - label = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8) + prediction = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8) + label = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8) score = metric(preds=prediction, target=label) assert score.shape == (1, 4, 5) diff --git a/tests/unittests/detection/test_intersection.py b/tests/unittests/detection/test_intersection.py index c42a6763ba9..e76ce966474 100644 --- a/tests/unittests/detection/test_intersection.py +++ b/tests/unittests/detection/test_intersection.py @@ -24,10 +24,9 @@ from torchmetrics.functional.detection.diou import distance_intersection_over_union from torchmetrics.functional.detection.giou import generalized_intersection_over_union from torchmetrics.functional.detection.iou import intersection_over_union -from torchmetrics.utilities.imports import _TORCHVISION_GREATER_EQUAL_0_13 +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE -# todo: check if some older versions have these functions too? -if _TORCHVISION_GREATER_EQUAL_0_13: +if _TORCHVISION_AVAILABLE: from torchvision.ops import box_iou as tv_iou from torchvision.ops import complete_box_iou as tv_ciou from torchvision.ops import distance_box_iou as tv_diou @@ -63,6 +62,8 @@ def _tv_wrapper_class(preds, target, base_fn, respect_labels, iou_threshold, cla base_name = {tv_ciou: "ciou", tv_diou: "diou", tv_giou: "giou", tv_iou: "iou"}[base_fn] result = {f"{base_name}": score.cpu()} + if torch.isnan(score): + result.update({f"{base_name}": torch.tensor(0.0)}) if class_metrics: for cl in torch.cat(classes).unique().tolist(): class_score, numel = 0, 0 @@ -71,7 +72,6 @@ def _tv_wrapper_class(preds, target, base_fn, respect_labels, iou_threshold, cla class_score += masked_s[masked_s != -1].sum() numel += masked_s[masked_s != -1].numel() result.update({f"{base_name}/cl_{cl}": class_score.cpu() / numel}) - return result @@ -184,7 +184,6 @@ def _add_noise(x, scale=10): (GeneralizedIntersectionOverUnion, generalized_intersection_over_union, tv_giou), ], ) -@pytest.mark.skipif(not _TORCHVISION_GREATER_EQUAL_0_13, reason="test requires torchvision >= 0.13") class TestIntersectionMetrics(MetricTester): """Tester class for the different intersection metrics.""" @@ -328,6 +327,32 @@ def test_functional_error_on_wrong_input_shape(self, class_metric, functional_me with pytest.raises(ValueError, match="Expected target to be of shape.*"): functional_metric(torch.randn(25, 4), torch.randn(25, 25)) + def test_corner_case_only_one_empty_prediction(self, class_metric, functional_metric, reference_metric): + """Test that the metric does not crash when there is only one empty prediction.""" + target = [ + { + "boxes": torch.tensor([ + [8.0000, 70.0000, 76.0000, 110.0000], + [247.0000, 131.0000, 315.0000, 175.0000], + [361.0000, 177.0000, 395.0000, 203.0000], + ]), + "labels": torch.tensor([0, 0, 0]), + } + ] + preds = [ + { + "boxes": torch.empty(size=(0, 4)), + "labels": torch.tensor([], dtype=torch.int64), + "scores": torch.tensor([]), + } + ] + + metric = class_metric() + metric.update(preds, target) + res = metric.compute() + for val in res.values(): + assert val == torch.tensor(0.0) + def test_corner_case(): """See issue: https://github.com/Lightning-AI/torchmetrics/issues/1921.""" diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index d1b1bc628c9..221f1f87aef 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -29,14 +29,11 @@ from torchmetrics.utilities.imports import ( _FASTER_COCO_EVAL_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, - _TORCHVISION_GREATER_EQUAL_0_8, ) from unittests._helpers.testers import MetricTester from unittests.detection import _DETECTION_BBOX, _DETECTION_SEGM, _DETECTION_VAL -_pytest_condition = not (_PYCOCOTOOLS_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8) - def _skip_if_faster_coco_eval_missing(backend): if backend == "faster_coco_eval" and not _FASTER_COCO_EVAL_AVAILABLE: @@ -65,7 +62,7 @@ def _generate_coco_inputs(iou_type): _coco_segm_input = _generate_coco_inputs("segm") -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 and pycocotools is installed") +@pytest.mark.skipif(_PYCOCOTOOLS_AVAILABLE, reason="test requires that torchvision=>0.8.0 and pycocotools is installed") @pytest.mark.parametrize("iou_type", ["bbox", "segm"]) @pytest.mark.parametrize("backend", ["pycocotools", "faster_coco_eval"]) def test_tm_to_coco(tmpdir, iou_type, backend): @@ -175,7 +172,7 @@ def _compare_against_coco_fn(preds, target, iou_type, iou_thresholds=None, rec_t } -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 and pycocotools is installed") +@pytest.mark.skipif(_PYCOCOTOOLS_AVAILABLE, reason="test requires that torchvision=>0.8.0 and pycocotools is installed") @pytest.mark.parametrize("iou_type", ["bbox", "segm"]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) @pytest.mark.parametrize("backend", ["pycocotools", "faster_coco_eval"]) @@ -410,7 +407,7 @@ def test_compare_both_same_time(tmpdir, backend): }, ], [ - {"boxes": Tensor([]), "scores": Tensor([]), "labels": Tensor([])}, + {"boxes": Tensor([]), "scores": Tensor([]), "labels": IntTensor([])}, ], ], "target": [ @@ -450,7 +447,7 @@ def _generate_random_segm_input(device, batch_size=2, num_preds_size=10, num_gt_ return preds, targets -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") +@pytest.mark.skipif(_PYCOCOTOOLS_AVAILABLE, reason="test requires that torchvision=>0.8.0 is installed") @pytest.mark.parametrize( "backend", [ @@ -632,7 +629,12 @@ def test_segm_iou_empty_gt_mask(self, backend): [{"masks": torch.randint(0, 1, (1, 10, 10)).bool(), "scores": Tensor([0.5]), "labels": IntTensor([4])}], [{"masks": Tensor([]), "labels": IntTensor([])}], ) - metric.compute() + res = metric.compute() + for key, value in res.items(): + if key == "classes": + continue + assert value.item() == -1, f"Expected -1 for {key}" + assert res["classes"] == 4 def test_segm_iou_empty_pred_mask(self, backend): """Test empty predictions.""" @@ -641,7 +643,12 @@ def test_segm_iou_empty_pred_mask(self, backend): [{"masks": torch.BoolTensor([]), "scores": Tensor([]), "labels": IntTensor([])}], [{"masks": torch.randint(0, 1, (1, 10, 10)).bool(), "labels": IntTensor([4])}], ) - metric.compute() + res = metric.compute() + for key, value in res.items(): + if key == "classes": + continue + assert value.item() == -1, f"Expected -1 for {key}" + assert res["classes"] == 4 def test_error_on_wrong_input(self, backend): """Test class input validation.""" @@ -862,17 +869,18 @@ def test_average_argument(self, class_metrics, backend): _preds = apply_to_collection(deepcopy(_inputs["preds"]), IntTensor, lambda x: torch.ones_like(x)) _target = apply_to_collection(deepcopy(_inputs["target"]), IntTensor, lambda x: torch.ones_like(x)) + metric_micro = MeanAveragePrecision(average="micro", class_metrics=class_metrics, backend=backend) + metric_micro.update(deepcopy(_inputs["preds"][0]), deepcopy(_inputs["target"][0])) + metric_micro.update(deepcopy(_inputs["preds"][1]), deepcopy(_inputs["target"][1])) + result_micro = metric_micro.compute() + metric_macro = MeanAveragePrecision(average="macro", class_metrics=class_metrics, backend=backend) metric_macro.update(_preds[0], _target[0]) metric_macro.update(_preds[1], _target[1]) result_macro = metric_macro.compute() - metric_micro = MeanAveragePrecision(average="micro", class_metrics=class_metrics, backend=backend) - metric_micro.update(_inputs["preds"][0], _inputs["target"][0]) - metric_micro.update(_inputs["preds"][1], _inputs["target"][1]) - result_micro = metric_micro.compute() - if class_metrics: + print(result_macro["map_per_class"], result_micro["map_per_class"]) assert torch.allclose(result_macro["map_per_class"], result_micro["map_per_class"]) assert torch.allclose(result_macro["mar_100_per_class"], result_micro["mar_100_per_class"]) else: diff --git a/tests/unittests/detection/test_modified_panoptic_quality.py b/tests/unittests/detection/test_modified_panoptic_quality.py index 1d5a067a609..4c864d0e9af 100644 --- a/tests/unittests/detection/test_modified_panoptic_quality.py +++ b/tests/unittests/detection/test_modified_panoptic_quality.py @@ -18,7 +18,6 @@ import torch from torchmetrics.detection import ModifiedPanopticQuality from torchmetrics.functional.detection import modified_panoptic_quality -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from unittests import _Input from unittests._helpers import seed_all @@ -77,7 +76,6 @@ def _reference_fn_1_2(preds, target) -> np.ndarray: return np.array([23 / 30]) -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") class TestModifiedPanopticQuality(MetricTester): """Test class for `ModifiedPanopticQuality` metric.""" @@ -113,7 +111,6 @@ def test_panoptic_quality_functional(self): ) -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") def test_empty_metric(): """Test empty metric.""" with pytest.raises(ValueError, match="At least one of `things` and `stuffs` must be non-empty"): @@ -123,7 +120,6 @@ def test_empty_metric(): assert torch.isnan(metric.compute()) -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") def test_error_on_wrong_input(): """Test class input validation.""" with pytest.raises(TypeError, match="Expected argument `stuffs` to contain `int` categories.*"): @@ -166,7 +162,6 @@ def test_error_on_wrong_input(): metric.update(preds, preds) -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") def test_extreme_values(): """Test that the metric returns expected values in trivial cases.""" # Exact match between preds and target => metric is 1 @@ -175,7 +170,6 @@ def test_extreme_values(): assert modified_panoptic_quality(_INPUTS_0.target[0], _INPUTS_0.target[0] + 1, **_ARGS_0) == 0.0 -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") @pytest.mark.parametrize( ("inputs", "args", "cat_dim"), [ diff --git a/tests/unittests/detection/test_panoptic_quality.py b/tests/unittests/detection/test_panoptic_quality.py index c7333ccac06..245fb4097fc 100644 --- a/tests/unittests/detection/test_panoptic_quality.py +++ b/tests/unittests/detection/test_panoptic_quality.py @@ -18,7 +18,6 @@ import torch from torchmetrics.detection.panoptic_qualities import PanopticQuality from torchmetrics.functional.detection.panoptic_qualities import panoptic_quality -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from unittests import _Input from unittests._helpers import seed_all @@ -105,7 +104,6 @@ def _reference_fn_class_order(preds, target) -> np.ndarray: return np.array([1, 0, 2 / 3]) -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") class TestPanopticQuality(MetricTester): """Test class for `PanopticQuality` metric.""" @@ -147,7 +145,6 @@ def test_panoptic_quality_functional(self): ) -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") def test_empty_metric(): """Test empty metric.""" with pytest.raises(ValueError, match="At least one of `things` and `stuffs` must be non-empty"): @@ -157,7 +154,6 @@ def test_empty_metric(): assert torch.isnan(metric.compute()) -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") def test_error_on_wrong_input(): """Test class input validation.""" with pytest.raises(TypeError, match="Expected argument `stuffs` to contain `int` categories.*"): @@ -200,7 +196,6 @@ def test_error_on_wrong_input(): metric.update(preds, preds) -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") def test_extreme_values(): """Test that the metric returns expected values in trivial cases.""" # Exact match between preds and target => metric is 1 @@ -209,7 +204,6 @@ def test_extreme_values(): assert panoptic_quality(_INPUTS_0.target[0], _INPUTS_0.target[0] + 1, **_ARGS_0) == 0.0 -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") @pytest.mark.parametrize( ("inputs", "args", "cat_dim"), [ diff --git a/tests/unittests/image/test_psnr.py b/tests/unittests/image/test_psnr.py index 66724af1f4c..cdb0e58aac4 100644 --- a/tests/unittests/image/test_psnr.py +++ b/tests/unittests/image/test_psnr.py @@ -167,3 +167,16 @@ def test_missing_data_range(): with pytest.raises(ValueError, match="The `data_range` must be given when `dim` is not None."): peak_signal_noise_ratio(_inputs[0].preds, _inputs[0].target, data_range=None, dim=0) + + +def test_psnr_uint_dtype(): + """Check that automatic casting to float is done for uint dtype. + + See issue: https://github.com/Lightning-AI/torchmetrics/issues/2787 + + """ + preds = torch.randint(0, 255, _input_size, dtype=torch.uint8) + target = torch.randint(0, 255, _input_size, dtype=torch.uint8) + psnr = peak_signal_noise_ratio(preds, target) + prnr2 = peak_signal_noise_ratio(preds.float(), target.float()) + assert torch.allclose(psnr, prnr2) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index 6b464f2a97b..49954f45cd7 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -278,7 +278,16 @@ def test_ssim_invalid_inputs(pred, target, kernel, sigma, match): structural_similarity_index_measure(pred, target, kernel_size=kernel, sigma=sigma) -def test_ssim_unequal_kernel_size(): +@pytest.mark.parametrize( + ("sigma", "kernel_size", "result"), + [ + ((0.25, 0.5), None, torch.tensor(0.20977394)), + ((0.5, 0.25), None, torch.tensor(0.13884821)), + (None, (3, 5), torch.tensor(0.05032664)), + (None, (5, 3), torch.tensor(0.03472072)), + ], +) +def test_ssim_unequal_kernel_size(sigma, kernel_size, result): """Test the case where kernel_size[0] != kernel_size[1].""" preds = torch.tensor([ [ @@ -306,24 +315,18 @@ def test_ssim_unequal_kernel_size(): ] ] ]) - # kernel order matters - assert torch.isclose( - structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=(0.25, 0.5)), - torch.tensor(0.08869550), - ) - assert not torch.isclose( - structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=(0.5, 0.25)), - torch.tensor(0.08869550), - ) - - assert torch.isclose( - structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=(3, 5)), - torch.tensor(0.05131844), - ) - assert not torch.isclose( - structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=(5, 3)), - torch.tensor(0.05131844), - ) + if sigma is not None: + assert torch.isclose( + structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=sigma), + result, + atol=1e-04, + ) + else: + assert torch.isclose( + structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=kernel_size), + result, + atol=1e-04, + ) @pytest.mark.parametrize( @@ -341,3 +344,19 @@ def test_full_image_output(preds, target): assert len(out) == 2 assert out[0].numel() == 1 assert out[1].shape == preds[0].shape + + +def test_ssim_for_correct_padding(): + """Check that padding is correctly added and removed for SSIM. + + See issue: https://github.com/Lightning-AI/torchmetrics/issues/2718 + + """ + preds = torch.rand([3, 3, 256, 256]) + # let the edge of the image be 0 + target = preds.clone() + target[:, :, 0, :] = 0 + target[:, :, -1, :] = 0 + target[:, :, :, 0] = 0 + target[:, :, :, -1] = 0 + assert structural_similarity_index_measure(preds, target) < 1.0 diff --git a/tests/unittests/regression/test_kendall.py b/tests/unittests/regression/test_kendall.py index 017179069e0..69c32106aba 100644 --- a/tests/unittests/regression/test_kendall.py +++ b/tests/unittests/regression/test_kendall.py @@ -21,7 +21,7 @@ from scipy.stats import kendalltau from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef from torchmetrics.regression.kendall import KendallRankCorrCoef -from torchmetrics.utilities.imports import _SCIPY_GREATER_EQUAL_1_8, _TORCH_LOWER_2_0 +from torchmetrics.utilities.imports import _SCIPY_GREATER_EQUAL_1_8 from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input from unittests._helpers import seed_all @@ -84,11 +84,7 @@ def _reference_scipy_kendall(preds, target, alternative, variant): class TestKendallRankCorrCoef(MetricTester): """Test class for `KendallRankCorrCoef` metric.""" - # TODO - @pytest.mark.skipif( - sys.platform == "darwin" and not _TORCH_LOWER_2_0, - reason="Tests are not working on mac for newer version of PyTorch.", - ) + @pytest.mark.skipif(sys.platform == "darwin", reason="Fails on MacOS") # TODO: investigate @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_kendall_rank_corrcoef(self, preds, target, alternative, variant, ddp): """Test class implementation of metric.""" diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index f37e80e4d16..38c86817184 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -18,6 +18,7 @@ import numpy as np import pytest import torch +from permetrics.regression import RegressionMetric from sklearn.metrics import mean_absolute_error as sk_mean_absolute_error from sklearn.metrics import mean_absolute_percentage_error as sk_mean_abs_percentage_error from sklearn.metrics import mean_squared_error as sk_mean_squared_error @@ -29,6 +30,7 @@ mean_absolute_percentage_error, mean_squared_error, mean_squared_log_error, + normalized_root_mean_squared_error, weighted_mean_absolute_percentage_error, ) from torchmetrics.functional.regression.symmetric_mape import symmetric_mean_absolute_percentage_error @@ -39,6 +41,7 @@ MeanSquaredLogError, WeightedMeanAbsolutePercentageError, ) +from torchmetrics.regression.nrmse import NormalizedRootMeanSquaredError from torchmetrics.regression.symmetric_mape import SymmetricMeanAbsolutePercentageError from unittests import BATCH_SIZE, NUM_BATCHES, _Input @@ -114,66 +117,179 @@ def _reference_symmetric_mape( return np.average(output_errors, weights=multioutput) +def _reference_normalized_root_mean_squared_error( + y_true: np.ndarray, y_pred: np.ndarray, normalization: str = "mean", num_outputs: int = 1 +): + """Reference implementation of Normalized Root Mean Squared Error (NRMSE) metric.""" + if num_outputs == 1: + y_true = y_true.flatten() + y_pred = y_pred.flatten() + if normalization != "l2": + evaluator = RegressionMetric(y_true, y_pred) if normalization == "range" else RegressionMetric(y_pred, y_true) + arg_mapping = {"mean": 1, "range": 2, "std": 4} + return evaluator.normalized_root_mean_square_error(model=arg_mapping[normalization]) + # for l2 normalization we do not have a reference implementation + return np.sqrt(np.mean(np.square(y_true - y_pred), axis=0)) / np.linalg.norm(y_true, axis=0) + + def _reference_weighted_mean_abs_percentage_error(target, preds): + """Reference implementation of Weighted Mean Absolute Percentage Error (WMAPE) metric.""" return np.sum(np.abs(target - preds)) / np.sum(np.abs(target)) def _single_target_ref_wrapper(preds, target, sk_fn, metric_args): + """Reference implementation of single-target metrics.""" sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() - res = sk_fn(sk_target, sk_preds) - - return math.sqrt(res) if (metric_args and "squared" in metric_args and not metric_args["squared"]) else res + if metric_args and "normalization" in metric_args: + res = sk_fn(sk_target, sk_preds, normalization=metric_args["normalization"]) + else: + res = sk_fn(sk_target, sk_preds) + if metric_args and "squared" in metric_args and not metric_args["squared"]: + res = math.sqrt(res) + return res def _multi_target_ref_wrapper(preds, target, sk_fn, metric_args): + """Reference implementation of multi-target metrics.""" sk_preds = preds.view(-1, NUM_TARGETS).numpy() sk_target = target.view(-1, NUM_TARGETS).numpy() sk_kwargs = {"multioutput": "raw_values"} if metric_args and "num_outputs" in metric_args else {} - res = sk_fn(sk_target, sk_preds, **sk_kwargs) - return math.sqrt(res) if (metric_args and "squared" in metric_args and not metric_args["squared"]) else res + if metric_args and "normalization" in metric_args: + res = sk_fn(sk_target, sk_preds, **metric_args) + else: + res = sk_fn(sk_target, sk_preds, **sk_kwargs) + if metric_args and "squared" in metric_args and not metric_args["squared"]: + res = math.sqrt(res) + return res @pytest.mark.parametrize( - "preds, target, ref_metric", + ("preds", "target", "ref_metric"), [ (_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_wrapper), (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_ref_wrapper), ], ) @pytest.mark.parametrize( - "metric_class, metric_functional, sk_fn, metric_args", + ("metric_class", "metric_functional", "sk_fn", "metric_args"), [ - (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True}), - (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": False}), - (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True, "num_outputs": NUM_TARGETS}), - (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}), - (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {"num_outputs": NUM_TARGETS}), - (MeanAbsolutePercentageError, mean_absolute_percentage_error, sk_mean_abs_percentage_error, {}), - ( + pytest.param( + MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True}, id="mse_singleoutput" + ), + pytest.param( + MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": False}, id="rmse_singleoutput" + ), + pytest.param( + MeanSquaredError, + mean_squared_error, + sk_mean_squared_error, + {"squared": True, "num_outputs": NUM_TARGETS}, + id="mse_multioutput", + ), + pytest.param(MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}, id="mae_singleoutput"), + pytest.param( + MeanAbsoluteError, + mean_absolute_error, + sk_mean_absolute_error, + {"num_outputs": NUM_TARGETS}, + id="mae_multioutput", + ), + pytest.param( + MeanAbsolutePercentageError, + mean_absolute_percentage_error, + sk_mean_abs_percentage_error, + {}, + id="mape_singleoutput", + ), + pytest.param( SymmetricMeanAbsolutePercentageError, symmetric_mean_absolute_percentage_error, _reference_symmetric_mape, {}, + id="symmetric_mean_absolute_percentage_error", ), - (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}), - ( + pytest.param( + MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}, id="mean_squared_log_error" + ), + pytest.param( WeightedMeanAbsolutePercentageError, weighted_mean_absolute_percentage_error, _reference_weighted_mean_abs_percentage_error, {}, + id="weighted_mean_absolute_percentage_error", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "mean", "num_outputs": 1}, + id="nrmse_singleoutput_mean", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "range", "num_outputs": 1}, + id="nrmse_singleoutput_range", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "std", "num_outputs": 1}, + id="nrmse_singleoutput_std", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "l2", "num_outputs": 1}, + id="nrmse_multioutput_l2", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "mean", "num_outputs": NUM_TARGETS}, + id="nrmse_multioutput_mean", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "range", "num_outputs": NUM_TARGETS}, + id="nrmse_multioutput_range", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "std", "num_outputs": NUM_TARGETS}, + id="nrmse_multioutput_std", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "l2", "num_outputs": NUM_TARGETS}, + id="nrmse_multioutput_l2", ), ], ) class TestMeanError(MetricTester): """Test class for `MeanError` metric.""" + atol = 1e-5 + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_mean_error_class( self, preds, target, ref_metric, metric_class, metric_functional, sk_fn, metric_args, ddp ): """Test class implementation of metric.""" + if metric_args and "num_outputs" in metric_args and preds.ndim < 3: + pytest.skip("Test only runs for multi-output setting") self.run_class_metric_test( ddp=ddp, preds=preds, @@ -187,6 +303,8 @@ def test_mean_error_functional( self, preds, target, ref_metric, metric_class, metric_functional, sk_fn, metric_args ): """Test functional implementation of metric.""" + if metric_args and "num_outputs" in metric_args and preds.ndim < 3: + pytest.skip("Test only runs for multi-output setting") self.run_functional_metric_test( preds=preds, target=target, @@ -199,6 +317,8 @@ def test_mean_error_differentiability( self, preds, target, ref_metric, metric_class, metric_functional, sk_fn, metric_args ): """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + if metric_args and "num_outputs" in metric_args and preds.ndim < 3: + pytest.skip("Test only runs for multi-output setting") self.run_differentiability_test( preds=preds, target=target, @@ -225,6 +345,10 @@ def test_mean_error_half_cpu(self, preds, target, ref_metric, metric_class, metr # WeightedMeanAbsolutePercentageError half + cpu does not work due to missing support in torch.clamp pytest.xfail("WeightedMeanAbsolutePercentageError metric does not support cpu + half precision") + if metric_class == NormalizedRootMeanSquaredError: + # NormalizedRootMeanSquaredError half + cpu does not work due to missing support in torch.sqrt + pytest.xfail("NormalizedRootMeanSquaredError metric does not support cpu + half precision") + self.run_precision_test_cpu(preds, target, metric_class, metric_functional) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") @@ -234,10 +358,30 @@ def test_mean_error_half_gpu(self, preds, target, ref_metric, metric_class, metr @pytest.mark.parametrize( - "metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError, MeanAbsolutePercentageError] + "metric_class", + [ + MeanSquaredError, + MeanAbsoluteError, + MeanSquaredLogError, + MeanAbsolutePercentageError, + NormalizedRootMeanSquaredError, + ], ) def test_error_on_different_shape(metric_class): """Test that error is raised on different shapes of input.""" metric = metric_class() with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): metric(torch.randn(100), torch.randn(50)) + + +@pytest.mark.parametrize( + ("metric_class", "arguments", "error_msg"), + [ + (MeanSquaredError, {"squared": "something"}, "Expected argument `squared` to be a boolean.*"), + (NormalizedRootMeanSquaredError, {"normalization": "something"}, "Argument `normalization` should be either.*"), + ], +) +def test_error_on_wrong_extra_args(metric_class, arguments, error_msg): + """Test that error is raised on wrong extra arguments.""" + with pytest.raises(ValueError, match=error_msg): + metric_class(**arguments) diff --git a/tests/unittests/regression/test_pearson.py b/tests/unittests/regression/test_pearson.py index 0d23507aeed..07cbf3fd65c 100644 --- a/tests/unittests/regression/test_pearson.py +++ b/tests/unittests/regression/test_pearson.py @@ -164,3 +164,25 @@ def test_single_sample_update(): metric(torch.tensor([7.0]), torch.tensor([8.0])) res2 = metric.compute() assert torch.allclose(res1, res2) + + +def test_overwrite_reference_inputs(): + """Test that the normalizations does not overwrite inputs. + + Variables var_x, var_y, corr_xy are references to the object variables and get incorrectly scaled down such that + when you update again and compute you get very wrong values. + + """ + y = torch.randn(100) + y_pred = y + torch.randn(y.shape) / 5 + # Initialize Pearson correlation coefficient metric + pearson = PearsonCorrCoef() + # Compute the Pearson correlation coefficient + correlation = pearson(y, y_pred) + + pearson = PearsonCorrCoef() + for lower, upper in [(0, 33), (33, 66), (66, 99), (99, 100)]: + pearson.update(torch.tensor(y[lower:upper]), torch.tensor(y_pred[lower:upper])) + pearson.compute() + + assert torch.isclose(pearson.compute(), correlation) diff --git a/tests/unittests/regression/test_r2.py b/tests/unittests/regression/test_r2.py index 32ce2554ed7..8649a3392e9 100644 --- a/tests/unittests/regression/test_r2.py +++ b/tests/unittests/regression/test_r2.py @@ -60,17 +60,17 @@ def _multi_target_ref_wrapper(preds, target, adjusted, multioutput): @pytest.mark.parametrize("adjusted", [0, 5, 10]) @pytest.mark.parametrize("multioutput", ["raw_values", "uniform_average", "variance_weighted"]) @pytest.mark.parametrize( - "preds, target, ref_metric, num_outputs", + "preds, target, ref_metric", [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_wrapper, 1), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_ref_wrapper, NUM_TARGETS), + (_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_wrapper), + (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_ref_wrapper), ], ) class TestR2Score(MetricTester): """Test class for `R2Score` metric.""" @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_r2(self, adjusted, multioutput, preds, target, ref_metric, num_outputs, ddp): + def test_r2(self, adjusted, multioutput, preds, target, ref_metric, ddp): """Test class implementation of metric.""" self.run_class_metric_test( ddp, @@ -78,10 +78,10 @@ def test_r2(self, adjusted, multioutput, preds, target, ref_metric, num_outputs, target, R2Score, partial(ref_metric, adjusted=adjusted, multioutput=multioutput), - metric_args={"adjusted": adjusted, "multioutput": multioutput, "num_outputs": num_outputs}, + metric_args={"adjusted": adjusted, "multioutput": multioutput}, ) - def test_r2_functional(self, adjusted, multioutput, preds, target, ref_metric, num_outputs): + def test_r2_functional(self, adjusted, multioutput, preds, target, ref_metric): """Test functional implementation of metric.""" self.run_functional_metric_test( preds, @@ -91,35 +91,23 @@ def test_r2_functional(self, adjusted, multioutput, preds, target, ref_metric, n metric_args={"adjusted": adjusted, "multioutput": multioutput}, ) - def test_r2_differentiability(self, adjusted, multioutput, preds, target, ref_metric, num_outputs): + def test_r2_differentiability(self, adjusted, multioutput, preds, target, ref_metric): """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" self.run_differentiability_test( - preds=preds, - target=target, - metric_module=partial(R2Score, num_outputs=num_outputs), - metric_functional=r2_score, - metric_args={"adjusted": adjusted, "multioutput": multioutput}, + preds, target, R2Score, r2_score, {"adjusted": adjusted, "multioutput": multioutput} ) - def test_r2_half_cpu(self, adjusted, multioutput, preds, target, ref_metric, num_outputs): + def test_r2_half_cpu(self, adjusted, multioutput, preds, target, ref_metric): """Test dtype support of the metric on CPU.""" self.run_precision_test_cpu( - preds, - target, - partial(R2Score, num_outputs=num_outputs), - r2_score, - {"adjusted": adjusted, "multioutput": multioutput}, + preds, target, R2Score, r2_score, {"adjusted": adjusted, "multioutput": multioutput} ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") - def test_r2_half_gpu(self, adjusted, multioutput, preds, target, ref_metric, num_outputs): + def test_r2_half_gpu(self, adjusted, multioutput, preds, target, ref_metric): """Test dtype support of the metric on GPU.""" self.run_precision_test_gpu( - preds, - target, - partial(R2Score, num_outputs=num_outputs), - r2_score, - {"adjusted": adjusted, "multioutput": multioutput}, + preds, target, R2Score, r2_score, {"adjusted": adjusted, "multioutput": multioutput} ) diff --git a/tests/unittests/segmentation/inputs.py b/tests/unittests/segmentation/inputs.py new file mode 100644 index 00000000000..996b8364e9c --- /dev/null +++ b/tests/unittests/segmentation/inputs.py @@ -0,0 +1,28 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +__all__ = ["_Input"] + +from typing import NamedTuple + +from torch import Tensor + +from unittests._helpers import seed_all + +seed_all(42) + + +# extrinsic input for clustering metrics that requires predicted clustering labels and target clustering labels +class _Input(NamedTuple): + preds: Tensor + target: Tensor diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index a2bbab7b921..742f31cc8fd 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -16,13 +16,17 @@ import pytest import torch +from lightning_utilities.core.imports import RequirementCache from monai.metrics.generalized_dice import compute_generalized_dice from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, _Input +from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +seed_all(42) + _inputs1 = _Input( preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), @@ -48,10 +52,11 @@ def _reference_generalized_dice( if input_format == "index": preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) - val = compute_generalized_dice(preds, target, include_background=include_background) + monai_extra_arg = {"sum_over_classes": True} if RequirementCache("monai>=1.4.0") else {} + val = compute_generalized_dice(preds, target, include_background=include_background, **monai_extra_arg) if reduce: val = val.mean() - return val + return val.squeeze() @pytest.mark.parametrize( @@ -63,11 +68,11 @@ def _reference_generalized_dice( ], ) @pytest.mark.parametrize("include_background", [True, False]) -class TestMeanIoU(MetricTester): +class TestGeneralizedDiceScore(MetricTester): """Test class for `MeanIoU` metric.""" @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_mean_iou_class(self, preds, target, input_format, include_background, ddp): + def test_generalized_dice_class(self, preds, target, input_format, include_background, ddp): """Test class implementation of metric.""" self.run_class_metric_test( ddp=ddp, @@ -87,7 +92,7 @@ def test_mean_iou_class(self, preds, target, input_format, include_background, d }, ) - def test_mean_iou_functional(self, preds, target, input_format, include_background): + def test_generalized_dice_functional(self, preds, target, input_format, include_background): """Test functional implementation of metric.""" self.run_functional_metric_test( preds=preds, diff --git a/tests/unittests/segmentation/test_hausdorff_distance.py b/tests/unittests/segmentation/test_hausdorff_distance.py new file mode 100644 index 00000000000..afd77c1f4b2 --- /dev/null +++ b/tests/unittests/segmentation/test_hausdorff_distance.py @@ -0,0 +1,116 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Any + +import pytest +import torch +from monai.metrics.hausdorff_distance import compute_hausdorff_distance as monai_hausdorff_distance +from torchmetrics.functional.segmentation.hausdorff_distance import hausdorff_distance +from torchmetrics.segmentation.hausdorff_distance import HausdorffDistance + +from unittests import NUM_BATCHES, _Input +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester + +seed_all(42) +BATCH_SIZE = 4 # use smaller than normal batch size to reduce test time +NUM_CLASSES = 3 # use smaller than normal class size to reduce test time + +_inputs1 = _Input( + preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16, 16)), + target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16, 16)), +) +_inputs2 = _Input( + preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), + target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), +) + + +def reference_metric(preds, target, input_format, reduce, **kwargs: Any): + """Reference implementation of metric.""" + if input_format == "index": + preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) + target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) + score = monai_hausdorff_distance(preds, target, **kwargs) + return score.mean() if reduce else score + + +@pytest.mark.parametrize("inputs, input_format", [(_inputs1, "one-hot"), (_inputs2, "index")]) +@pytest.mark.parametrize("distance_metric", ["euclidean", "chessboard", "taxicab"]) +@pytest.mark.parametrize("directed", [True, False]) +@pytest.mark.parametrize("spacing", [None, [2, 2]]) +class TestHausdorffDistance(MetricTester): + """Test class for `HausdorffDistance` metric.""" + + atol = 1e-5 + + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_hausdorff_distance_class(self, inputs, input_format, distance_metric, directed, spacing, ddp): + """Test class implementation of metric.""" + if spacing is not None and distance_metric != "euclidean": + pytest.skip("Spacing is only supported for Euclidean distance metric.") + preds, target = inputs + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=HausdorffDistance, + reference_metric=partial( + reference_metric, + input_format=input_format, + distance_metric=distance_metric, + directed=directed, + spacing=spacing, + reduce=True, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "distance_metric": distance_metric, + "directed": directed, + "spacing": spacing, + "input_format": input_format, + }, + ) + + def test_hausdorff_distance_functional(self, inputs, input_format, distance_metric, directed, spacing): + """Test functional implementation of metric.""" + if spacing is not None and distance_metric != "euclidean": + pytest.skip("Spacing is only supported for Euclidean distance metric.") + preds, target = inputs + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=hausdorff_distance, + reference_metric=partial( + reference_metric, + input_format=input_format, + distance_metric=distance_metric, + directed=directed, + spacing=spacing, + reduce=False, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "distance_metric": distance_metric, + "directed": directed, + "spacing": spacing, + "input_format": input_format, + }, + ) + + +def test_hausdorff_distance_raises_error(): + """Check that metric raises appropriate errors.""" + preds, target = _inputs1 diff --git a/tests/unittests/segmentation/test_utils.py b/tests/unittests/segmentation/test_utils.py index d37941a6ff3..39cff09a2dd 100644 --- a/tests/unittests/segmentation/test_utils.py +++ b/tests/unittests/segmentation/test_utils.py @@ -14,6 +14,7 @@ import pytest import torch from monai.metrics.utils import get_code_to_measure_table +from monai.metrics.utils import get_edge_surface_distance as monai_get_edge_surface_distance from monai.metrics.utils import get_mask_edges as monai_get_mask_edges from monai.metrics.utils import get_surface_distance as monai_get_surface_distance from scipy.ndimage import binary_erosion as scibinary_erosion @@ -23,6 +24,7 @@ from torchmetrics.functional.segmentation.utils import ( binary_erosion, distance_transform, + edge_surface_distance, generate_binary_structure, get_neighbour_tables, mask_edges, @@ -231,3 +233,50 @@ def test_mask_edges(cases, spacing, crop, device): for r1, r2 in zip(res, reference_res): assert torch.allclose(r1.cpu().float(), torch.from_numpy(r2).float()) + + +@pytest.mark.parametrize( + "cases", + [ + ( + torch.tensor( + [[1, 1, 1, 1, 1], [1, 0, 0, 0, 1], [1, 0, 0, 0, 1], [1, 0, 0, 0, 1], [1, 1, 1, 1, 1]], dtype=torch.bool + ), + torch.tensor( + [[1, 1, 1, 1, 0], [1, 0, 0, 1, 0], [1, 0, 0, 1, 0], [1, 0, 0, 1, 0], [1, 1, 1, 1, 0]], dtype=torch.bool + ), + ), + (torch.randint(0, 2, (5, 5), dtype=torch.bool), torch.randint(0, 2, (5, 5), dtype=torch.bool)), + (torch.randint(0, 2, (50, 50), dtype=torch.bool), torch.randint(0, 2, (50, 50), dtype=torch.bool)), + ], +) +@pytest.mark.parametrize("distance_metric", ["euclidean", "chessboard", "taxicab"]) +@pytest.mark.parametrize("symmetric", [False, True]) +@pytest.mark.parametrize("spacing", [None, 1, 2]) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_edge_surface_distance(cases, distance_metric, symmetric, spacing, device): + """Test the edge surface distance function.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA device not available.") + if spacing == 2 and distance_metric != "euclidean": + pytest.skip("Only euclidean distance is supported for spacing != 1 in reference") + preds, target = cases + if spacing is not None: + spacing = preds.ndim * [spacing] + + res = edge_surface_distance( + preds.to(device), target.to(device), spacing=spacing, distance_metric=distance_metric, symmetric=symmetric + ) + _, reference_res, _ = monai_get_edge_surface_distance( + preds, + target, + spacing=tuple(spacing) if spacing is not None else spacing, + distance_metric=distance_metric, + symmetric=symmetric, + ) + + if symmetric: + assert torch.allclose(res[0].cpu(), reference_res[0].to(res[0].dtype)) + assert torch.allclose(res[1].cpu(), reference_res[1].to(res[1].dtype)) + else: + assert torch.allclose(res.cpu(), reference_res[0].to(res.dtype)) diff --git a/tests/unittests/shape/__init__.py b/tests/unittests/shape/__init__.py new file mode 100644 index 00000000000..94f1dec4a9f --- /dev/null +++ b/tests/unittests/shape/__init__.py @@ -0,0 +1,13 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/shape/test_procrustes.py b/tests/unittests/shape/test_procrustes.py new file mode 100644 index 00000000000..a3b89e13eb7 --- /dev/null +++ b/tests/unittests/shape/test_procrustes.py @@ -0,0 +1,95 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import numpy as np +import pytest +import torch +from scipy.spatial import procrustes as scipy_procrustes +from torchmetrics.functional.shape.procrustes import procrustes_disparity +from torchmetrics.shape.procrustes import ProcrustesDisparity + +from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester + +seed_all(42) + +NUM_TARGETS = 5 + + +_inputs = _Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 50, EXTRA_DIM), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 50, EXTRA_DIM), +) + + +def _reference_procrustes(point_cloud1, point_cloud2, reduction=None): + point_cloud1 = point_cloud1.numpy() + point_cloud2 = point_cloud2.numpy() + + if reduction is None: + return np.array([scipy_procrustes(d1, d2)[2] for d1, d2 in zip(point_cloud1, point_cloud2)]) + + disparity = 0 + for d1, d2 in zip(point_cloud1, point_cloud2): + disparity += scipy_procrustes(d1, d2)[2] + if reduction == "mean": + return disparity / len(point_cloud1) + return disparity + + +@pytest.mark.parametrize("point_cloud1, point_cloud2", [(_inputs.preds, _inputs.target)]) +class TestProcrustesDisparity(MetricTester): + """Test class for `ProcrustesDisparity` metric.""" + + @pytest.mark.parametrize("reduction", ["sum", "mean"]) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_procrustes_disparity(self, reduction, point_cloud1, point_cloud2, ddp): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp, + point_cloud1, + point_cloud2, + ProcrustesDisparity, + partial(_reference_procrustes, reduction=reduction), + metric_args={"reduction": reduction}, + ) + + def test_procrustes_disparity_functional(self, point_cloud1, point_cloud2): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + point_cloud1, + point_cloud2, + procrustes_disparity, + _reference_procrustes, + ) + + +def test_error_on_different_shape(): + """Test that error is raised on different shapes of input.""" + metric = ProcrustesDisparity() + with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): + metric(torch.randn(10, 100, 2), torch.randn(10, 50, 2)) + with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): + procrustes_disparity(torch.randn(10, 100, 2), torch.randn(10, 50, 2)) + + +def test_error_on_non_3d_input(): + """Test that error is raised if input is not 3-dimensional.""" + metric = ProcrustesDisparity() + with pytest.raises(ValueError, match="Expected both datasets to be 3D tensors of shape"): + metric(torch.randn(100), torch.randn(100)) + with pytest.raises(ValueError, match="Expected both datasets to be 3D tensors of shape"): + procrustes_disparity(torch.randn(100), torch.randn(100)) diff --git a/tests/unittests/test_deprecated.py b/tests/unittests/test_deprecated.py deleted file mode 100644 index f126fa06561..00000000000 --- a/tests/unittests/test_deprecated.py +++ /dev/null @@ -1,16 +0,0 @@ -import pytest -import torch -from torchmetrics.functional.regression import kl_divergence -from torchmetrics.regression import KLDivergence - - -def test_deprecated_kl_divergence_input_order(): - """Ensure that the deprecated input order for kl_divergence raises a warning.""" - preds = torch.randn(10, 2) - target = torch.randn(10, 2) - - with pytest.deprecated_call(match="The input order and naming in metric `kl_divergence` is set to be deprecated.*"): - kl_divergence(preds, target) - - with pytest.deprecated_call(match="The input order and naming in metric `KLDivergence` is set to be deprecated.*"): - KLDivergence() diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index dfd6d60a0e5..1d74f0c858d 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -174,10 +174,7 @@ def test_bertscore_differentiability( @skip_on_connection_issues() @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") -@pytest.mark.parametrize( - "idf", - [(False,), (True,)], -) +@pytest.mark.parametrize("idf", [True, False]) def test_bertscore_sorting(idf: bool): """Test that BERTScore is invariant to the order of the inputs.""" short = "Short text" @@ -191,3 +188,20 @@ def test_bertscore_sorting(idf: bool): # First index should be the self-comparison - sorting by length should not shuffle this assert score["f1"][0] > score["f1"][1] + + +@skip_on_connection_issues() +@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") +@pytest.mark.parametrize("truncation", [True, False]) +def test_bertscore_truncation(truncation: bool): + """Test that BERTScore truncation works as expected.""" + pred = ["abc " * 2000] + gt = ["def " * 2000] + bert_score = BERTScore(truncation=truncation) + + if truncation: + res = bert_score(pred, gt) + assert res["f1"] > 0.0 + else: + with pytest.raises(RuntimeError, match="The expanded size of the tensor.*must match.*"): + bert_score(pred, gt) diff --git a/tests/unittests/text/test_infolm.py b/tests/unittests/text/test_infolm.py index 1ee45cde02e..b3fd26026ca 100644 --- a/tests/unittests/text/test_infolm.py +++ b/tests/unittests/text/test_infolm.py @@ -182,3 +182,18 @@ def test_infolm_differentiability(self, preds, targets, information_measure, idf metric_functional=infolm, metric_args=metric_args, ) + + @skip_on_connection_issues() + def test_infolm_higher_is_better_property(self, preds, targets, information_measure, idf, alpha, beta): + """Test the `higher_is_better` property of the metric.""" + metric_args = { + "model_name_or_path": MODEL_NAME, + "information_measure": information_measure, + "idf": idf, + "alpha": alpha, + "beta": beta, + "max_length": MAX_LENGTH, + } + + metric = InfoLM(**metric_args) + assert metric.higher_is_better == metric._information_measure_higher_is_better[information_measure] diff --git a/tests/unittests/text/test_sacre_bleu.py b/tests/unittests/text/test_sacre_bleu.py index e8b66012011..bf2d45fbd5a 100644 --- a/tests/unittests/text/test_sacre_bleu.py +++ b/tests/unittests/text/test_sacre_bleu.py @@ -55,8 +55,8 @@ def test_bleu_score_class(self, ddp, preds, targets, tokenize, lowercase): """Test class implementation of metric.""" if _should_skip_tokenizer(tokenize): pytest.skip(reason="`ko-mecab` tokenizer requires `mecab-ko` package to be installed") - if tokenize == "flores200": - pytest.skip("flores200 tests are flaky") # TODO: figure out why + if tokenize == "flores200" or tokenize == "flores101": + pytest.skip("flores101 and flores200 tests are flaky") # TODO: figure out why metric_args = {"tokenize": tokenize, "lowercase": lowercase} original_sacrebleu = partial(_reference_sacre_bleu, tokenize=tokenize, lowercase=lowercase) @@ -75,6 +75,8 @@ def test_bleu_score_functional(self, preds, targets, tokenize, lowercase): """Test functional implementation of metric.""" if _should_skip_tokenizer(tokenize): pytest.skip(reason="`ko-mecab` tokenizer requires `mecab-ko` package to be installed") + if tokenize == "flores200" or tokenize == "flores101": + pytest.skip("flores101 and flores200 tests are flaky") # TODO: figure out why metric_args = {"tokenize": tokenize, "lowercase": lowercase} original_sacrebleu = partial(_reference_sacre_bleu, tokenize=tokenize, lowercase=lowercase) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 29f2d7c81de..efb7077682e 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -130,6 +130,7 @@ MeanSquaredError, MeanSquaredLogError, MinkowskiDistance, + NormalizedRootMeanSquaredError, PearsonCorrCoef, R2Score, RelativeSquaredError, @@ -150,11 +151,11 @@ RetrievalRecallAtFixedPrecision, RetrievalRPrecision, ) +from torchmetrics.shape import ProcrustesDisparity from torchmetrics.text import ( BERTScore, BLEUScore, CharErrorRate, - CHRFScore, EditDistance, ExtendedEditDistance, InfoLM, @@ -168,10 +169,6 @@ WordInfoLost, WordInfoPreserved, ) -from torchmetrics.utilities.imports import ( - _TORCH_GREATER_EQUAL_1_12, - _TORCHAUDIO_GREATER_EQUAL_0_10, -) from torchmetrics.utilities.plot import _get_col_row_split from torchmetrics.wrappers import ( BootStrapper, @@ -316,7 +313,6 @@ _audio_input, None, id="speech_reverberation_modulation_energy_ratio", - marks=pytest.mark.skipif(not _TORCHAUDIO_GREATER_EQUAL_0_10, reason="test requires torchaudio>=0.10"), ), pytest.param( partial(PermutationInvariantTraining, metric_func=scale_invariant_signal_noise_ratio, eval_func="max"), @@ -342,9 +338,6 @@ _panoptic_input, _panoptic_input, id="panoptic quality", - marks=pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_12, reason="Panoptic Quality metric requires PyTorch 1.12 or later" - ), ), pytest.param(BinaryAveragePrecision, _rand_input, _binary_randint_input, id="binary average precision"), pytest.param( @@ -477,6 +470,7 @@ pytest.param(MeanAbsoluteError, _rand_input, _rand_input, id="mean absolute error"), pytest.param(MeanAbsolutePercentageError, _rand_input, _rand_input, id="mean absolute percentage error"), pytest.param(partial(MinkowskiDistance, p=3), _rand_input, _rand_input, id="minkowski distance"), + pytest.param(NormalizedRootMeanSquaredError, _rand_input, _rand_input, id="normalized root mean squared error"), pytest.param(PearsonCorrCoef, _rand_input, _rand_input, id="pearson corr coef"), pytest.param(R2Score, _rand_input, _rand_input, id="r2 score"), pytest.param(RelativeSquaredError, _rand_input, _rand_input, id="relative squared error"), @@ -588,7 +582,6 @@ pytest.param(EditDistance, _text_input_1, _text_input_2, id="edit distance"), pytest.param(MatchErrorRate, _text_input_1, _text_input_2, id="match error rate"), pytest.param(BLEUScore, _text_input_3, _text_input_4, id="bleu score"), - pytest.param(CHRFScore, _text_input_3, _text_input_4, id="bleu score"), pytest.param( partial(InfoLM, model_name_or_path="google/bert_uncased_L-2_H-128_A-2", idf=False, verbose=False), _text_input_1, @@ -611,6 +604,12 @@ pytest.param(CalinskiHarabaszScore, lambda: torch.randn(100, 3), _nominal_input, id="calinski harabasz score"), pytest.param(NormalizedMutualInfoScore, _nominal_input, _nominal_input, id="normalized mutual info score"), pytest.param(DunnIndex, lambda: torch.randn(100, 3), _nominal_input, id="dunn index"), + pytest.param( + ProcrustesDisparity, + lambda: torch.randn(1, 100, 3), + lambda: torch.randn(1, 100, 3), + id="procrustes disparity", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 3]) diff --git a/tests/unittests/utilities/test_utilities.py b/tests/unittests/utilities/test_utilities.py index e61b2ec0e33..ca082dec969 100644 --- a/tests/unittests/utilities/test_utilities.py +++ b/tests/unittests/utilities/test_utilities.py @@ -31,7 +31,7 @@ ) from torchmetrics.utilities.distributed import class_reduce, reduce from torchmetrics.utilities.exceptions import TorchMetricsUserWarning -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_2_2 +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_2 def test_prints(): @@ -172,9 +172,6 @@ def test_recursive_allclose(inputs, expected): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU") @pytest.mark.xfail(sys.platform == "win32", reason="test will only fail on non-windows systems") -@pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_13, reason="earlier versions was silently non-deterministic, even in deterministic mode" -) def test_cumsum_still_not_supported(use_deterministic_algorithms): """Make sure that cumsum on gpu and deterministic mode still fails. @@ -188,7 +185,9 @@ def test_cumsum_still_not_supported(use_deterministic_algorithms): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU") def test_custom_cumsum(use_deterministic_algorithms): """Test custom cumsum implementation.""" - x = torch.arange(100).float().cuda() + # check that cumsum works as expected on non-default cuda device + device = torch.device("cuda:1") if torch.cuda.device_count() > 1 else torch.device("cuda:0") + x = torch.arange(100).float().to(device) if sys.platform != "win32": with pytest.warns( TorchMetricsUserWarning, match="You are trying to use a metric in deterministic mode on GPU that.*" @@ -217,8 +216,6 @@ def _reference_topk(x, dim, k): @pytest.mark.parametrize("dim", [0, 1]) def test_custom_topk(dtype, k, dim): """Test custom topk implementation.""" - if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_13: - pytest.skip("half precision topk not supported in Pytorch < 1.13") x = torch.randn(100, 10, dtype=dtype) top_k = select_topk(x, dim=dim, topk=k) assert top_k.shape == (100, 10) diff --git a/tests/unittests/wrappers/test_multitask.py b/tests/unittests/wrappers/test_multitask.py index 63af6f31b35..fb3ae8987cc 100644 --- a/tests/unittests/wrappers/test_multitask.py +++ b/tests/unittests/wrappers/test_multitask.py @@ -19,6 +19,7 @@ from torchmetrics import MetricCollection from torchmetrics.classification import BinaryAccuracy, BinaryF1Score from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_5 from torchmetrics.wrappers import MultitaskWrapper from unittests import BATCH_SIZE, NUM_BATCHES @@ -90,13 +91,15 @@ def test_error_on_wrong_keys(): "Classification": BinaryAccuracy(), }) + order_dict = "" if _TORCH_GREATER_EQUAL_2_5 else "o" + with pytest.raises( ValueError, match=re.escape( - "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`. " - "Found task_preds.keys() = dict_keys(['Classification']), task_targets.keys() = " - "dict_keys(['Classification', 'Regression']) and self.task_metrics.keys() = " - "odict_keys(['Classification', 'Regression'])" + "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`." + " Found task_preds.keys() = dict_keys(['Classification'])," + " task_targets.keys() = dict_keys(['Classification', 'Regression'])" + f" and self.task_metrics.keys() = {order_dict}dict_keys(['Classification', 'Regression'])" ), ): multitask_metrics.update(wrong_key_preds, _multitask_targets) @@ -104,9 +107,10 @@ def test_error_on_wrong_keys(): with pytest.raises( ValueError, match=re.escape( - "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`. " - "Found task_preds.keys() = dict_keys(['Classification', 'Regression']), task_targets.keys() = " - "dict_keys(['Classification']) and self.task_metrics.keys() = odict_keys(['Classification', 'Regression'])" + "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`." + " Found task_preds.keys() = dict_keys(['Classification', 'Regression'])," + " task_targets.keys() = dict_keys(['Classification'])" + f" and self.task_metrics.keys() = {order_dict}dict_keys(['Classification', 'Regression'])" ), ): multitask_metrics.update(_multitask_preds, wrong_key_targets) @@ -114,9 +118,10 @@ def test_error_on_wrong_keys(): with pytest.raises( ValueError, match=re.escape( - "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`. " - "Found task_preds.keys() = dict_keys(['Classification', 'Regression']), task_targets.keys() = " - "dict_keys(['Classification', 'Regression']) and self.task_metrics.keys() = odict_keys(['Classification'])" + "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`." + " Found task_preds.keys() = dict_keys(['Classification', 'Regression'])," + " task_targets.keys() = dict_keys(['Classification', 'Regression'])" + f" and self.task_metrics.keys() = {order_dict}dict_keys(['Classification'])" ), ): wrong_key_multitask_metrics.update(_multitask_preds, _multitask_targets) @@ -248,14 +253,24 @@ def test_key_value_items_method(method, flatten): def test_clone_with_prefix_and_postfix(): """Check that the clone method works with prefix and postfix arguments.""" - multitask_metrics = MultitaskWrapper({"Classification": BinaryAccuracy(), "Regression": MeanSquaredError()}) - cloned_metrics_with_prefix = multitask_metrics.clone(prefix="prefix_") - cloned_metrics_with_postfix = multitask_metrics.clone(postfix="_postfix") + multitask_metrics = MultitaskWrapper( + {"Classification": BinaryAccuracy(), "Regression": MeanSquaredError()}, + prefix="prefix_", + postfix="_postfix", + ) + assert set(multitask_metrics.keys()) == {"prefix_Classification_postfix", "prefix_Regression_postfix"} - # Check if the cloned metrics have the expected keys - assert set(cloned_metrics_with_prefix.task_metrics.keys()) == {"prefix_Classification", "prefix_Regression"} - assert set(cloned_metrics_with_postfix.task_metrics.keys()) == {"Classification_postfix", "Regression_postfix"} + output = multitask_metrics( + {"Classification": _classification_preds, "Regression": _regression_preds}, + {"Classification": _classification_target, "Regression": _regression_target}, + ) + assert set(output.keys()) == {"prefix_Classification_postfix", "prefix_Regression_postfix"} + + cloned_metrics = multitask_metrics.clone(prefix="new_prefix_", postfix="_new_postfix") + assert set(cloned_metrics.keys()) == {"new_prefix_Classification_new_postfix", "new_prefix_Regression_new_postfix"} - # Check if the cloned metrics have the expected values - assert isinstance(cloned_metrics_with_prefix.task_metrics["prefix_Classification"], BinaryAccuracy) - assert isinstance(cloned_metrics_with_prefix.task_metrics["prefix_Regression"], MeanSquaredError) + output = cloned_metrics( + {"Classification": _classification_preds, "Regression": _regression_preds}, + {"Classification": _classification_target, "Regression": _regression_target}, + ) + assert set(output.keys()) == {"new_prefix_Classification_new_postfix", "new_prefix_Regression_new_postfix"} diff --git a/tests/unittests/wrappers/test_tracker.py b/tests/unittests/wrappers/test_tracker.py index 93cdbc452ca..97c2ae37234 100644 --- a/tests/unittests/wrappers/test_tracker.py +++ b/tests/unittests/wrappers/test_tracker.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings + import pytest import torch -from torchmetrics import MetricCollection +from torchmetrics import Metric, MetricCollection from torchmetrics.classification import ( MulticlassAccuracy, MulticlassConfusionMatrix, @@ -22,6 +24,7 @@ MulticlassRecall, ) from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError +from torchmetrics.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_6 from torchmetrics.wrappers import MetricTracker, MultioutputWrapper from unittests._helpers import seed_all @@ -216,3 +219,37 @@ def test_metric_tracker_and_collection_multioutput(input_to_tracker, assert_type else: assert best_metric is None assert which_epoch is None + + +def test_tracker_futurewarning(): + """Check that future warning is raised for the maximize argument. + + Also to make sure that we remove it in future versions of TM. + + """ + if _TORCHMETRICS_GREATER_EQUAL_1_6: + # Check that for future versions that we remove the warning + with warnings.catch_warnings(): + warnings.simplefilter("error") + MetricTracker(MeanSquaredError(), maximize=True) + else: + with pytest.warns(FutureWarning, match="The default value for `maximize` will be changed from `True` to.*"): + MetricTracker(MeanSquaredError(), maximize=True) + + +@pytest.mark.parametrize( + "base_metric", + [ + MeanSquaredError(), + MeanAbsoluteError(), + MulticlassAccuracy(num_classes=10), + MetricCollection([MeanSquaredError(), MeanAbsoluteError()]), + ], +) +def test_tracker_higher_is_better_integration(base_metric): + """Check that the maximize argument is correctly set based on the metric higher_is_better attribute.""" + tracker = MetricTracker(base_metric, maximize=None) + if isinstance(base_metric, Metric): + assert tracker.maximize == base_metric.higher_is_better + else: + assert tracker.maximize == [m.higher_is_better for m in base_metric.values()]