diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index c50178ca..6162f8f7 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -26,7 +26,7 @@ jobs: fetch-depth: 0 - name: Setup Miniconda - uses: conda-incubator/setup-miniconda@v2.2.0 + uses: conda-incubator/setup-miniconda@v3.0.2 with: # installer-url: https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh installer-url: https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh @@ -74,7 +74,7 @@ jobs: cp benchmarks.log .asv/results/ working-directory: ${{ env.ASV_DIR }} - - uses: actions/upload-artifact@v3.1.3 + - uses: actions/upload-artifact@v4.3.1 if: always() with: name: asv-benchmark-results-${{ runner.os }} diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 14e52719..2dd80b37 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -33,7 +33,7 @@ jobs: with: fetch-depth: 0 - name: Set up conda - uses: conda-incubator/setup-miniconda@v2.2.0 + uses: conda-incubator/setup-miniconda@v3.0.2 with: auto-update-conda: false channels: conda-forge @@ -60,7 +60,7 @@ jobs: shell: 'bash -l {0}' steps: - uses: actions/checkout@v4.1.1 - - uses: conda-incubator/setup-miniconda@v2.2.0 + - uses: conda-incubator/setup-miniconda@v3.0.2 with: channels: conda-forge miniforge-variant: Mambaforge @@ -91,7 +91,7 @@ jobs: steps: - uses: actions/checkout@v4.1.1 - name: Set up conda - uses: conda-incubator/setup-miniconda@v2.2.0 + uses: conda-incubator/setup-miniconda@v3.0.2 with: auto-update-conda: false channels: conda-forge @@ -134,11 +134,11 @@ jobs: with: fetch-depth: 0 - name: Setup python - uses: actions/setup-python@v4.7.1 + uses: actions/setup-python@v5.0.0 with: python-version: '3.11' - name: Set up Julia - uses: julia-actions/setup-julia@v1.9.2 + uses: julia-actions/setup-julia@v1.9.6 with: version: 1.7.1 - name: Install dependencies diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 21c614f9..34ae4030 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -16,7 +16,7 @@ jobs: - uses: actions/checkout@v4.1.1 - name: Set up Python - uses: actions/setup-python@v4.7.1 + uses: actions/setup-python@v5.0.0 with: python-version: "3.10" @@ -45,7 +45,7 @@ jobs: - name: Publish a Python distribution to PyPI if: success() && github.event_name == 'release' - uses: pypa/gh-action-pypi-publish@v1.8.10 + uses: pypa/gh-action-pypi-publish@v1.8.11 with: user: __token__ password: ${{ secrets.PYPI_PASSWORD }} diff --git a/.github/workflows/updater.yaml b/.github/workflows/updater.yaml index 66674e97..223b3f23 100644 --- a/.github/workflows/updater.yaml +++ b/.github/workflows/updater.yaml @@ -21,3 +21,4 @@ jobs: with: # [Required] Access token with `workflow` scope. token: ${{ secrets.WORKFLOW_SECRET }} + pull_request_branch: gh-actions-update diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8f9e8dce..b71c00c8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ ci: # https://pre-commit.com/ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -14,38 +14,38 @@ repos: - id: mixed-line-ending # This wants to go before isort & flake8 - repo: https://github.com/PyCQA/autoflake - rev: "v2.0.2" + rev: "v2.2.1" hooks: - id: autoflake # isort should run before black as black sometimes tweaks the isort output args: ["--in-place", "--ignore-init-module-imports"] - repo: https://github.com/PyCQA/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort args: ["--profile", "black"] - repo: https://github.com/asottile/pyupgrade - rev: v3.3.1 + rev: v3.15.0 hooks: - id: pyupgrade args: - "--py38-plus" # https://github.com/python/black#version-control-integration - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 23.12.1 hooks: - id: black - id: black-jupyter - repo: https://github.com/keewis/blackdoc - rev: v0.3.8 + rev: v0.3.9 hooks: - id: blackdoc exclude: docs/index.rst - repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 + rev: 6.1.0 hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.1.1 + rev: v1.8.0 hooks: - id: mypy exclude: "properties|asv_bench" @@ -58,6 +58,6 @@ repos: types-pkg_resources, types-PyYAML, types-pytz, - typing-extensions==3.10.0.0, + typing-extensions, numpy, ] diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 55433717..471528b7 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -5,6 +5,7 @@ CHANGELOG X.X.X (unreleased) ------------------ +* Add support for additional datatypes in :py:func:`xbitinfo.xbitinfo.plot_bitinformation` (:pr:`218`, :issue:`168`) `Hauke Schulz`_. * Drop python 3.8 support and add python 3.11 (:pr:`175`) `Hauke Schulz`_. * Implement basic retrieval of bitinformation in python as alternative to julia implementation (:pr:`156`, :issue:`155`, :pr:`126`, :issue:`125`) `Hauke Schulz`_ with helpful comments from `Milan Klöwer`_. * Make julia binding to BitInformation.jl optional (:pr:`153`, :issue:`151`) `Aaron Spring`_. diff --git a/docs/chunking.ipynb b/docs/chunking.ipynb new file mode 100644 index 00000000..cf58b158 --- /dev/null +++ b/docs/chunking.ipynb @@ -0,0 +1,412 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a1f40619", + "metadata": {}, + "source": [ + "# Chunking" + ] + }, + { + "cell_type": "markdown", + "id": "b8e2d4f5-c444-404a-8444-3648cb0a94bf", + "metadata": {}, + "source": [ + "Geospatial data can vary in its information density from one part of the world to another. A dataset containing streets will be very dense in cities but contains little information in remote places like the Alps or even the ocean. The same is also true for datasets about the ocean or the atmosphere.\n", + "\n", + "By default the number of bits that need to be kept (`keepbits`) to preserve the requested amount of information is determined based on the entire dataset. This approach doesn't always result in the best compression rates as it preserves too many keepbits in regions with anomalously low information density. The following steps show how the `keepbits` can be retrieved and applied on subsets. In this case, subsets are defined as dataset chunks.\n", + "\n", + "This work is a result of the ECMWF Code4Earth 2023. Please have a look at the [presentation of this project](https://youtu.be/IOi4XvECpsQ?si=hwZkppNRa-J2XVZ9) for additional details." + ] + }, + { + "cell_type": "markdown", + "id": "e515b4bd-a302-45a9-8464-56b67a73a46c", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96e9149e-fc6d-4048-8e45-a29966e5c6b8", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from itertools import product\n", + "import numpy as np\n", + "\n", + "import xarray as xr\n", + "import xbitinfo as xb" + ] + }, + { + "cell_type": "markdown", + "id": "b64e0873-0a27-4757-947a-4a559a102288", + "metadata": {}, + "source": [ + "## Data loading" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "320224c9-06e2-428a-8614-8ed0d15eee82", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# load data\n", + "ds = xr.tutorial.load_dataset(\"air_temperature\")\n", + "\n", + "# Defining chunks that will be used for the reading/bitrounding/writing\n", + "chunks = {\n", + " \"lat\": 5,\n", + " \"lon\": 10,\n", + "}\n", + "\n", + "# Apply chunking\n", + "ds = ds.chunk(chunks)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3120040-79f1-4a7f-a61f-afec9fb3ca5b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "ds" + ] + }, + { + "cell_type": "markdown", + "id": "6b1b95de-f8e5-45c3-be3b-0555a67efb77", + "metadata": {}, + "source": [ + "## Zarr chunking and compressing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91343d2a-63ec-4d61-a369-cc99139297e4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def bitrounding(chunk, var=\"lat\"):\n", + " \"\"\"\n", + " Just a function that handles all the xbitinfo calls\n", + " \"\"\"\n", + " bitinfo = xb.get_bitinformation(chunk, dim=var, implementation=\"python\")\n", + " keepbits = xb.get_keepbits(bitinfo, 0.99)\n", + " bitround = xb.xr_bitround(chunk, keepbits)\n", + " return bitround, keepbits\n", + "\n", + "\n", + "def slices_from_chunks(chunks):\n", + " \"\"\"Translate chunks tuple to a set of slices in product order\n", + "\n", + " >>> slices_from_chunks(((2, 2), (3, 3, 3))) # doctest: +NORMALIZE_WHITESPACE\n", + " [(slice(0, 2, None), slice(0, 3, None)),\n", + " (slice(0, 2, None), slice(3, 6, None)),\n", + " (slice(0, 2, None), slice(6, 9, None)),\n", + " (slice(2, 4, None), slice(0, 3, None)),\n", + " (slice(2, 4, None), slice(3, 6, None)),\n", + " (slice(2, 4, None), slice(6, 9, None))]\n", + " \"\"\"\n", + " cumdims = []\n", + " for bds in chunks:\n", + " out = np.empty(len(bds) + 1, dtype=int)\n", + " out[0] = 0\n", + " np.cumsum(bds, out=out[1:])\n", + " cumdims.append(out)\n", + " slices = [\n", + " [slice(s, s + dim) for s, dim in zip(starts, shapes)]\n", + " for starts, shapes in zip(cumdims, chunks)\n", + " ]\n", + " return list(product(*slices))" + ] + }, + { + "cell_type": "markdown", + "id": "7221b47f-b8f4-4ebf-bc2b-cb61d12989be", + "metadata": { + "tags": [] + }, + "source": [ + "### Save dataset as compressed zarr after compressing individual chunks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c2ed18f-4dc8-4f5c-88ed-ae5ad41d1647", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%%capture\n", + "fn = \"air_bitrounded_by_chunks.zarr\" # Output filename\n", + "ds.to_compressed_zarr(fn, compute=False, mode=\"w\") # Creates empty file structure\n", + "\n", + "dims = ds.air.dims\n", + "len_dims = len(dims)\n", + "\n", + "slices = slices_from_chunks(ds.air.chunks)\n", + "\n", + "# Loop over each chunk\n", + "keepbits = []\n", + "for b, block in enumerate(ds.air.data.to_delayed().ravel()):\n", + " # Conversion of dask.delayed array to Dataset (as xbitinfo wants type xr.Dataset)\n", + " ds_block = xr.Dataset({\"air\": (dims, block.compute())})\n", + "\n", + " # Apply bitrounding\n", + " rounded_ds, keepbit = bitrounding(ds_block)\n", + " keepbits.append(keepbit)\n", + "\n", + " # Write individual chunk to disk\n", + " rounded_ds.to_zarr(fn, region={dims[d]: s for (d, s) in enumerate(slices[b])})" + ] + }, + { + "cell_type": "markdown", + "id": "5d628121-d5ec-4544-a47f-f47c86524b09", + "metadata": {}, + "source": [ + "### Plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8835a3c-8af0-4423-baf4-84aa9a386f67", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import matplotlib as mpl\n", + "import matplotlib.patheffects as pe" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3b5f657-cddd-4476-82a3-c3c2c1a6e7b6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# # Create a figure and axis and plot the air temperature\n", + "fig, ax = plt.subplots(figsize=(12, 6))\n", + "rounded_ds = xr.open_zarr(fn).isel(time=0)\n", + "rounded_ds[\"air\"].plot(ax=ax, cmap=\"RdBu_r\")\n", + "\n", + "slices = slices_from_chunks(rounded_ds.air.chunks)\n", + "\n", + "for i in range(len(slices)):\n", + " # Get chunk limits\n", + " dss = rounded_ds.isel(lat=slices[i][0], lon=slices[i][1])\n", + " lats = dss.lat\n", + " longs = dss.lon\n", + "\n", + " x = float(min(longs[0], longs[-1]))\n", + " y = float(min(lats[0], lats[-1]))\n", + " w = float(abs(longs[0] - longs[-1]))\n", + " h = float(abs(lats[0] - lats[-1]))\n", + "\n", + " # Draw rectangle\n", + " rect = mpl.patches.Rectangle(\n", + " (x, y),\n", + " width=w,\n", + " height=h,\n", + " facecolor=\"none\",\n", + " edgecolor=\"#E5E4E2\",\n", + " path_effects=[pe.withStroke(linewidth=3, foreground=\"gray\")],\n", + " )\n", + " ax.add_patch(rect)\n", + "\n", + " # Annotate number of keepbits\n", + " rx, ry = rect.get_xy()\n", + " cx = rx + rect.get_width() / 2.0\n", + " cy = ry + rect.get_height() / 2.0\n", + " ax.annotate(\n", + " f\"{int(keepbits[i].air):2}\",\n", + " (cx, cy),\n", + " color=\"k\",\n", + " weight=\"normal\",\n", + " fontsize=14,\n", + " ha=\"right\",\n", + " va=\"center\",\n", + " path_effects=[pe.withStroke(linewidth=2, foreground=\"w\")],\n", + " )\n", + "\n", + "fig.text(0.39, 0.94, f\"Keepbits \", weight=\"bold\", fontsize=16)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "b9e8fe5a-2e4e-4dfd-8026-0991e9988668", + "metadata": {}, + "source": [ + "## Reference compression\n", + "For comparision with other compression approaches the dataset is also saved as:\n", + "- uncompressed netCDF\n", + "- lossless compressed zarr\n", + "- lossy compressed zarr while preserving 99% of bitinformation" + ] + }, + { + "cell_type": "markdown", + "id": "a77919ff", + "metadata": {}, + "source": [ + "### Saving to uncompressed `NetCDF` file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e011a900-5da2-40be-a292-d81a0cafcd6d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Saving the dataset as NetCDF file\n", + "ds.to_netcdf(\"0.air_original.nc\")" + ] + }, + { + "cell_type": "markdown", + "id": "1cc93427", + "metadata": {}, + "source": [ + "### Save dataset as compressed zarr (without bitrounding)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19032fcd-93bc-48b8-ba1f-beba9673491b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "fn = \"air_compressed.zarr\" # Output filename\n", + "ds.to_compressed_zarr(fn, mode=\"w\") # Creates empty file structure" + ] + }, + { + "cell_type": "markdown", + "id": "648f759c", + "metadata": {}, + "source": [ + "### Save dataset as compressed zarr after applying bitrounding" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93eb4cd6", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "%%capture\n", + "fn = \"air_bitrounded.zarr\" # Output filename\n", + "rounded_ds, keepbits = bitrounding(ds)\n", + "rounded_ds.to_compressed_zarr(fn, mode=\"w\")" + ] + }, + { + "cell_type": "markdown", + "id": "d3b60c66-252d-48a6-af93-a00c9ca8f0ba", + "metadata": { + "tags": [] + }, + "source": [ + "## Summary" + ] + }, + { + "cell_type": "markdown", + "id": "b28089ea-22f9-45c6-abc9-b65bd946ac66", + "metadata": {}, + "source": [ + "Below are the file sizes resulting from the various compression techniques outlined above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "998581b5-6ad9-4f6f-9c61-d0bf1486ec7f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!du -hs *.nc *.zarr" + ] + }, + { + "cell_type": "markdown", + "id": "15c6975d-6909-4e2c-9395-0a64d39ed44f", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "fed34f3f-2bee-45d3-9bdf-1237b77cf1b8", + "metadata": {}, + "source": [ + "In this experiment, the sizes are minimized when applying bitrounding and compression to the dataset chunks. \n", + "\n", + "However, it's important to note that this outcome may not be universally applicable, check this for your dataset." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + }, + "toc-autonumbering": true + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/conf.py b/docs/conf.py index 7907a2ce..de102f46 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -152,8 +152,8 @@ copybutton_remove_prompts = True extlinks = { - "issue": ("https://github.com/observingClouds/xbitinfo/issues/%s", "GH#"), - "pr": ("https://github.com/observingClouds/xbitinfo/pull/%s", "GH#"), + "issue": ("https://github.com/observingClouds/xbitinfo/issues/%s", "GH#%s"), + "pr": ("https://github.com/observingClouds/xbitinfo/pull/%s", "GH#%s"), } diff --git a/docs/index.rst b/docs/index.rst index 2dcab627..96a32342 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -107,6 +107,18 @@ Credits ArtificialInformation_Filter.ipynb + +**User Guide** + +* :doc:`chunking` + +.. toctree:: + :maxdepth: 1 + :hidden: + :caption: User Guide + + chunking.ipynb + **Help & Reference** * :doc:`api` diff --git a/environment.yml b/environment.yml index fe242e9e..a75c755e 100644 --- a/environment.yml +++ b/environment.yml @@ -28,7 +28,6 @@ dependencies: - sphinx-book-theme - myst-nb - numcodecs>=0.10.0 - - pytest-lazy-fixture - pip - pip: - -e . diff --git a/setup.py b/setup.py index dfd973c0..fe091f16 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ with open("requirements.txt") as f: requirements = f.read().strip().split("\n") -test_requirements = ["pytest", "pytest-lazy-fixture", "pooch", "netcdf4", "dask"] +test_requirements = ["pytest", "pooch", "netcdf4", "dask"] extras_require = { "viz": ["matplotlib", "cmcrameri"], diff --git a/tests/test_bitinformation_pipeline.py b/tests/test_bitinformation_pipeline.py index 777d7dd9..05388b46 100644 --- a/tests/test_bitinformation_pipeline.py +++ b/tests/test_bitinformation_pipeline.py @@ -9,17 +9,18 @@ @pytest.mark.parametrize( "ds,dim,axis", [ - (pytest.lazy_fixture("ugrid_demo"), None, -1), - (pytest.lazy_fixture("icon_grid_demo"), "ncells", None), - (pytest.lazy_fixture("air_temperature"), "lon", None), - (pytest.lazy_fixture("rasm"), "x", None), - (pytest.lazy_fixture("ROMS_example"), "eta_rho", None), - (pytest.lazy_fixture("era52mt"), "time", None), - (pytest.lazy_fixture("eraint_uvz"), "longitude", None), + ("ugrid_demo", None, -1), + ("icon_grid_demo", "ncells", None), + ("air_temperature", "lon", None), + ("rasm", "x", None), + ("ROMS_example", "eta_rho", None), + ("era52mt", "time", None), + ("eraint_uvz", "longitude", None), ], ) -def test_full(ds, dim, axis): +def test_full(ds, dim, axis, request): """Test xbitinfo end to end.""" + ds = request.getfixturevalue(ds) # xbitinfo bitinfo = xb.get_bitinformation(ds, dim=dim, axis=axis) keepbits = xb.get_keepbits(bitinfo) diff --git a/tests/test_get_bitinformation.py b/tests/test_get_bitinformation.py index 0b370d07..0a38ca0f 100644 --- a/tests/test_get_bitinformation.py +++ b/tests/test_get_bitinformation.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - """Tests for `xbitinfo` package.""" import os @@ -33,7 +31,7 @@ def assert_different(a, b): numpy.testing.assert_array_equal """ __tracebackhide__ = True - assert type(a) == type(b) + assert isinstance(a, type(b)) if isinstance(a, (Variable, DataArray)): assert not a.equals(b), formatting.diff_array_repr(a, b, "equals") elif isinstance(a, Dataset): @@ -167,7 +165,7 @@ def test_get_bitinformation_dtype(rasm, dtype, implementation): ds = rasm.astype(dtype) v = list(ds.data_vars)[0] dtype_bits = dtype.replace("float", "") - assert len(xb.get_bitinformation(ds, dim="x")[v].coords["bit" + dtype_bits]) == int( + assert len(xb.get_bitinformation(ds, dim="x")[v].coords["bit" + dtype]) == int( dtype_bits ) @@ -206,7 +204,7 @@ def test_get_bitinformation_different_dtypes(rasm, implementation): ds["Tair32"] = ds.Tair.astype("float32") ds["Tair16"] = ds.Tair.astype("float16") bi = xb.get_bitinformation(ds, implementation=implementation) - for bitdim in ["bit16", "bit32", "bit64"]: + for bitdim in ["bitfloat16", "bitfloat32", "bitfloat64"]: assert bitdim in bi.dims assert bitdim in bi.coords @@ -228,17 +226,18 @@ def test_get_bitinformation_keep_attrs(rasm): @pytest.mark.parametrize( "ds,dim,axis", [ - (pytest.lazy_fixture("ugrid_demo"), None, -1), - (pytest.lazy_fixture("icon_grid_demo"), "ncells", None), - (pytest.lazy_fixture("air_temperature"), "lon", None), - (pytest.lazy_fixture("rasm"), "x", None), - (pytest.lazy_fixture("ROMS_example"), "eta_rho", None), - (pytest.lazy_fixture("era52mt"), "time", None), - (pytest.lazy_fixture("eraint_uvz"), "longitude", None), + ("ugrid_demo", None, -1), + ("icon_grid_demo", "ncells", None), + ("air_temperature", "lon", None), + ("rasm", "x", None), + ("ROMS_example", "eta_rho", None), + ("era52mt", "time", None), + ("eraint_uvz", "longitude", None), ], ) -def test_implementations_agree(ds, dim, axis): +def test_implementations_agree(ds, dim, axis, request): """Test whether the python and julia implementation retrieve the same results""" + ds = request.getfixturevalue(ds) bi_python = xb.get_bitinformation( ds, dim=dim, diff --git a/tests/test_visualisation.py b/tests/test_visualisation.py index 6a397b6c..d488c9fe 100644 --- a/tests/test_visualisation.py +++ b/tests/test_visualisation.py @@ -3,7 +3,7 @@ import xarray as xr import xbitinfo as xb -from xbitinfo.graphics import add_bitinfo_labels +from xbitinfo.graphics import add_bitinfo_labels, plot_bitinformation def test_add_bitinfo_labels(): @@ -50,3 +50,11 @@ def test_add_bitinfo_labels(): assert ax.texts[i + 5].get_text() == keepbits_text # Cleanup the plot plt.close() + + +@pytest.mark.parametrize("dtype", ["float64", "float32", "float16"]) +def test_plot_bitinformation(dtype): + rasm = xr.tutorial.load_dataset("air_temperature") + ds = rasm.astype(dtype) + info_per_bit = xb.get_bitinformation(ds, dim="lon") + plot_bitinformation(info_per_bit) diff --git a/xbitinfo/graphics.py b/xbitinfo/graphics.py index 7a3b93d6..46591956 100644 --- a/xbitinfo/graphics.py +++ b/xbitinfo/graphics.py @@ -2,7 +2,7 @@ import numpy as np import xarray as xr -from .xbitinfo import NMBITS, _cdf_from_info_per_bit, get_keepbits +from .xbitinfo import _cdf_from_info_per_bit, bit_partitioning, get_keepbits def add_bitinfo_labels( @@ -117,16 +117,12 @@ def add_bitinfo_labels( CDF = _cdf_from_info_per_bit(info_per_bit, dimension) CDF_DataArray = CDF[da.name] + data_type = np.dtype(dimension.replace("bit", "")) + _, _, n_exp, _ = bit_partitioning(data_type) if inflevels is None: inflevels = [] for i, keep in enumerate(keepbits): - if dimension == "bit16": - mantissa_index = keep + 5 - if dimension == "bit32": - mantissa_index = keep + 8 - if dimension == "bit64": - mantissa_index = keep + 11 - + mantissa_index = keep + n_exp inflevels.append(CDF_DataArray[mantissa_index].values) if keepbits is None: @@ -185,7 +181,29 @@ def add_bitinfo_labels( t_keepbits.set_bbox(dict(facecolor="white", alpha=0.9, edgecolor="white")) -def plot_bitinformation(bitinfo, information_filter=None, cmap="turku"): +def split_dataset_by_dims(info_per_bit): + """Split dataset by its dimensions. + + Parameters + ---------- + info_per_bit : dict + Information content of each bit for each variable in ``da``. This is the output from :py:func:`xbitinfo.xbitinfo.get_bitinformation`. + + Returns + ------- + var_by_dim : dict + Dictionary containing the dimensions of the datasets as keys and the dataset using the dimension as value. + """ + var_by_dim = {d: [] for d in info_per_bit.dims} + for var in info_per_bit.data_vars: + assert ( + len(info_per_bit[var].dims) == 1 + ), f"Variable {var} has more than one dimension." + var_by_dim[info_per_bit[var].dims[0]].append(var) + return var_by_dim + + +def plot_bitinformation(bitinfo, cmap="turku", crop=None): """Plot bitwise information content as in Klöwer et al. 2021 Figure 2. Klöwer, M., Razinger, M., Dominguez, J. J., Düben, P. D., & Palmer, T. N. (2021). @@ -198,6 +216,8 @@ def plot_bitinformation(bitinfo, information_filter=None, cmap="turku"): Containing the bitwise information content for each variable cmap : str or plt.cm Colormap. Defaults to ``"turku"``. + crop : int + Maximum bits to show in figure. Kwargs threshold(` `float ``) : defaults to ``0.7`` @@ -213,121 +233,79 @@ def plot_bitinformation(bitinfo, information_filter=None, cmap="turku"): >>> ds = xr.tutorial.load_dataset("air_temperature") >>> info_per_bit = xb.get_bitinformation(ds, dim="lon") >>> xb.plot_bitinformation(info_per_bit) -
+
""" import matplotlib.pyplot as plt - assert bitinfo.coords["dim"].shape <= ( - 1, - ), "Only bitinfo along one dimension is supported at the moment. Please select dimension before plotting." - + bitinfo = bitinfo.squeeze() assert ( - "bit32" in bitinfo.dims - ), "currently only works properly for float32 data, looking forward to your PR closing https://github.com/observingClouds/xbitinfo/issues/168" + "dim" not in bitinfo.dims + ), "Found dependence of bitinformation on dimension. Please reduce data first by e.g. `bitinfo.max(dim='dim')`" + vars_by_dim = split_dataset_by_dims(bitinfo) + bitinfo_all = bitinfo + subfigure_data = [None] * len(vars_by_dim) + for d, (dim, vars) in enumerate(vars_by_dim.items()): + bitinfo = bitinfo_all[vars] + data_type = np.dtype(dim.replace("bit", "")) + n_bits, n_sign, n_exp, n_mant = bit_partitioning(data_type) + nonmantissa_bits = n_bits - n_mant + if crop is None: + bits_to_show = n_bits + else: + bits_to_show = int(np.min([crop, n_bits])) + nvars = len(bitinfo) + varnames = list(bitinfo.keys()) - nvars = len(bitinfo) - varnames = bitinfo.keys() - - if information_filter == "Gradient": - infbits_dict = get_keepbits( - bitinfo, 0.99, information_filter, **{"threshold": 0.7, "tolerance": 0.001} - ) - infbits100_dict = get_keepbits( - bitinfo, - 0.999999999, - information_filter, - **{"threshold": 0.7, "tolerance": 0.001}, - ) - else: infbits_dict = get_keepbits(bitinfo, 0.99) infbits100_dict = get_keepbits(bitinfo, 0.999999999) - ICnan = np.zeros((nvars, 64)) - infbits = np.zeros(nvars) - infbits100 = np.zeros(nvars) - ICnan[:, :] = np.nan - for v, var in enumerate(varnames): - ic = bitinfo[var].squeeze(drop=True) - ICnan[v, : len(ic)] = ic - # infbits are all bits, infbits_dict were mantissa bits - infbits[v] = infbits_dict[var] + NMBITS[len(ic)] - infbits100[v] = infbits100_dict[var] + NMBITS[len(ic)] - ICnan = np.where(ICnan == 0, np.nan, ICnan) - ICcsum = np.nancumsum(ICnan, axis=1) - - infbitsy = np.hstack([0, np.repeat(np.arange(1, nvars), 2), nvars]) - infbitsx = np.repeat(infbits, 2) - infbitsx100 = np.repeat(infbits100, 2) - - fig_height = np.max([4, 4 + (nvars - 10) * 0.2]) # auto adjust to nvars - fig, ax1 = plt.subplots(1, 1, figsize=(12, fig_height), sharey=True) - ax1.invert_yaxis() - ax1.set_box_aspect(1 / 32 * nvars) - plt.tight_layout(rect=[0.06, 0.18, 0.8, 0.98]) - pos = ax1.get_position() - cax = fig.add_axes([pos.x0, 0.12, pos.x1 - pos.x0, 0.02]) - - ax1right = ax1.twinx() - ax1right.invert_yaxis() - ax1right.set_box_aspect(1 / 32 * nvars) - - if cmap == "turku": - import cmcrameri.cm as cmc - - cmap = cmc.turku_r - pcm = ax1.pcolormesh(ICnan, vmin=0, vmax=1, cmap=cmap) - cbar = plt.colorbar(pcm, cax=cax, orientation="horizontal") - cbar.set_label("information content [bit]") - - # 99% of real information enclosed - ax1.plot( - np.hstack([infbits, infbits[-1]]), - np.arange(nvars + 1), - "C1", - ds="steps-pre", - zorder=10, - label="99% of\ninformation", + ICnan = np.zeros((nvars, 64)) + infbits = np.zeros(nvars) + infbits100 = np.zeros(nvars) + ICnan[:, :] = np.nan + for v, var in enumerate(varnames): + ic = bitinfo[var].squeeze(drop=True) + ICnan[v, : len(ic)] = ic + # infbits are all bits, infbits_dict were mantissa bits + infbits[v] = infbits_dict[var] + nonmantissa_bits + infbits100[v] = infbits100_dict[var] + nonmantissa_bits + ICnan = np.where(ICnan == 0, np.nan, ICnan) + ICcsum = np.nancumsum(ICnan, axis=1) + + infbitsy = np.hstack([0, np.repeat(np.arange(1, nvars), 2), nvars]) + infbitsx = np.repeat(infbits, 2) + infbitsx100 = np.repeat(infbits100, 2) + + fig_height = np.max([4, 4 + (nvars - 10) * 0.2]) # auto adjust to nvars + + subfigure_data[d] = {} + subfigure_data[d]["fig_height"] = fig_height + subfigure_data[d]["nvars"] = nvars + subfigure_data[d]["varnames"] = varnames + subfigure_data[d]["ICnan"] = ICnan + subfigure_data[d]["ICcsum"] = ICcsum + subfigure_data[d]["infbits"] = infbits + subfigure_data[d]["infbitsx"] = infbitsx + subfigure_data[d]["infbitsy"] = infbitsy + subfigure_data[d]["infbitsx100"] = infbitsx100 + subfigure_data[d]["nbits"] = (n_sign, n_exp, n_bits, n_mant, nonmantissa_bits) + subfigure_data[d]["bits_to_show"] = bits_to_show + + fig_heights = [subfig["fig_height"] for subfig in subfigure_data] + fig = plt.figure(figsize=(12, sum(fig_heights) + 2 * 2)) + fig_heights_incl_cax = fig_heights + [2 / (sum(fig_heights) + 2)] * 2 + grid = fig.add_gridspec( + len(subfigure_data) + 2, 1, height_ratios=fig_heights_incl_cax ) - # grey shading - ax1.fill_betweenx( - infbitsy, infbitsx, np.ones(len(infbitsx)) * 32, alpha=0.4, color="grey" - ) - ax1.fill_betweenx( - infbitsy, infbitsx100, np.ones(len(infbitsx)) * 32, alpha=0.1, color="c" - ) - ax1.fill_betweenx( - infbitsy, - infbitsx100, - np.ones(len(infbitsx)) * 32, - alpha=0.3, - facecolor="none", - edgecolor="c", - ) + axs = [] + for i in range(len(subfigure_data) + 2): + ax = fig.add_subplot(grid[i, 0]) + axs.append(ax) - # for legend only - ax1.fill_betweenx( - [-1, -1], - [-1, -1], - [-1, -1], - color="burlywood", - label="last 1% of\ninformation", - alpha=0.5, - ) - ax1.fill_betweenx( - [-1, -1], - [-1, -1], - [-1, -1], - facecolor="teal", - edgecolor="c", - label="false information", - alpha=0.3, - ) - ax1.fill_betweenx([-1, -1], [-1, -1], [-1, -1], color="w", label="unused bits") - - ax1.axvline(1, color="k", lw=1, zorder=3) - ax1.axvline(9, color="k", lw=1, zorder=3) + if isinstance(axs, plt.Axes): + axs = [axs] fig.suptitle( "Real bitwise information content", @@ -337,48 +315,187 @@ def plot_bitinformation(bitinfo, information_filter=None, cmap="turku"): horizontalalignment="left", ) - ax1.set_xlim(0, 32) - ax1.set_ylim(nvars, 0) - ax1right.set_ylim(nvars, 0) - - ax1.set_yticks(np.arange(nvars) + 0.5) - ax1right.set_yticks(np.arange(nvars) + 0.5) - ax1.set_yticklabels(varnames) - ax1right.set_yticklabels([f"{i:4.1f}" for i in ICcsum[:, -1]]) - ax1right.set_ylabel("total information per value [bit]") - - ax1.text( - infbits[0] + 0.1, - 0.8, - f"{int(infbits[0]-9)} mantissa bits", - fontsize=8, - color="saddlebrown", - ) - for i in range(1, nvars): - ax1.text( - infbits[i] + 0.1, - (i) + 0.8, - f"{int(infbits[i]-9)}", - fontsize=8, - color="saddlebrown", + if cmap == "turku": + import cmcrameri.cm as cmc + + cmap = cmc.turku_r + + max_bits_to_show = np.max([d["bits_to_show"] for d in subfigure_data]) + + for d, subfig in enumerate(subfigure_data): + infbits = subfig["infbits"] + nvars = subfig["nvars"] + n_sign, n_exp, n_bits, n_mant, nonmantissa_bits = subfig["nbits"] + ICcsum = subfig["ICcsum"] + ICnan = subfig["ICnan"] + infbitsy = subfig["infbitsy"] + infbitsx = subfig["infbitsx"] + infbitsx100 = subfig["infbitsx100"] + varnames = subfig["varnames"] + bits_to_show = subfig["bits_to_show"] + + mbits_to_show = bits_to_show - nonmantissa_bits + + axs[d].invert_yaxis() + axs[d].set_box_aspect(1 / max_bits_to_show * nvars) + + ax1right = axs[d].twinx() + ax1right.invert_yaxis() + ax1right.set_box_aspect(1 / max_bits_to_show * nvars) + + pcm = axs[d].pcolormesh(ICnan, vmin=0, vmax=1, cmap=cmap) + + if d == len(subfigure_data) - 1: + cax = axs[len(subfigure_data)] + lax = axs[len(subfigure_data) + 1] + lax.axis("off") + cbar = plt.colorbar(pcm, cax=cax, orientation="horizontal") + cbar.set_label("information content [bit]") + + # 99% of real information enclosed + l0 = axs[d].plot( + np.hstack([infbits, infbits[-1]]), + np.arange(nvars + 1), + "C1", + ds="steps-pre", + zorder=10, + label="99% of\ninformation", ) - ax1.set_xticks([1, 9]) - ax1.set_xticks(np.hstack([np.arange(1, 8), np.arange(9, 32)]), minor=True) - ax1.set_xticklabels([]) - ax1.text(0, nvars + 1.2, "sign", rotation=90) - ax1.text(2, nvars + 1.2, "exponent bits", color="darkslategrey") - ax1.text(10, nvars + 1.2, "mantissa bits") + # grey shading + axs[d].fill_betweenx( + infbitsy, + infbitsx, + np.ones(len(infbitsx)) * bits_to_show, + alpha=0.4, + color="grey", + ) + axs[d].fill_betweenx( + infbitsy, + infbitsx100, + np.ones(len(infbitsx)) * bits_to_show, + alpha=0.1, + color="c", + ) + axs[d].fill_betweenx( + infbitsy, + infbitsx100, + np.ones(len(infbitsx)) * bits_to_show, + alpha=0.3, + facecolor="none", + edgecolor="c", + ) - for i in range(1, 9): - ax1.text( - i + 0.5, nvars + 0.5, i, ha="center", fontsize=7, color="darkslategrey" + # for legend only + l1 = axs[d].fill_betweenx( + [-1, -1], + [-1, -1], + [-1, -1], + color="burlywood", + label="last 1% of\ninformation", + alpha=0.5, ) + l2 = axs[d].fill_betweenx( + [-1, -1], + [-1, -1], + [-1, -1], + facecolor="teal", + edgecolor="c", + label="false information", + alpha=0.3, + ) + l3 = axs[d].fill_betweenx( + [-1, -1], [-1, -1], [-1, -1], color="w", label="unused bits", edgecolor="k" + ) + + if n_sign > 0: + axs[d].axvline(n_sign, color="k", lw=1, zorder=3) + axs[d].axvline(nonmantissa_bits, color="k", lw=1, zorder=3) - for i in range(1, 24): - ax1.text(8 + i + 0.5, nvars + 0.5, i, ha="center", fontsize=7) + axs[d].set_ylim(nvars, 0) + ax1right.set_ylim(nvars, 0) + + axs[d].set_yticks(np.arange(nvars) + 0.5) + ax1right.set_yticks(np.arange(nvars) + 0.5) + axs[d].set_yticklabels(varnames) + ax1right.set_yticklabels([f"{i:4.1f}" for i in ICcsum[:, -1]]) + if d == len(subfigure_data) // 2: + ax1right.set_ylabel("total information\nper value [bit]") + + axs[d].text( + infbits[0] + 0.1, + 0.8, + f"{int(infbits[0]-nonmantissa_bits)} mantissa bits", + fontsize=8, + color="saddlebrown", + ) + for i in range(1, nvars): + axs[d].text( + infbits[i] + 0.1, + (i) + 0.8, + f"{int(infbits[i]-9)}", + fontsize=8, + color="saddlebrown", + ) + + major_xticks = np.array([n_sign, n_sign + n_exp, n_bits], dtype="int") + axs[d].set_xticks(major_xticks[major_xticks <= bits_to_show]) + minor_xticks = np.hstack( + [ + np.arange(n_sign, nonmantissa_bits - 1), + np.arange(nonmantissa_bits, n_bits - 1), + ] + ) + axs[d].set_xticks( + minor_xticks[minor_xticks <= bits_to_show], + minor=True, + ) + axs[d].set_xticklabels([]) + if n_sign > 0: + axs[d].text(0, nvars + 1.2, "sign", rotation=90) + if n_exp > 0: + axs[d].text( + n_sign + n_exp / 2, + nvars + 1.2, + "exponent bits", + color="darkslategrey", + horizontalalignment="center", + verticalalignment="center", + ) + axs[d].text( + n_sign + n_exp + mbits_to_show / 2, + nvars + 1.2, + "mantissa bits", + horizontalalignment="center", + verticalalignment="center", + ) - ax1.legend(bbox_to_anchor=(1.08, 0.5), loc="center left", framealpha=0.6) + # Set xticklabels + ## Set exponent labels + for e, i in enumerate(range(n_sign, np.min([n_sign + n_exp, bits_to_show]))): + axs[d].text( + i + 0.5, + nvars + 0.5, + e + 1, + ha="center", + fontsize=7, + color="darkslategrey", + ) + ## Set mantissa labels + for m, i in enumerate( + range(n_sign + n_exp, np.min([n_sign + n_exp + n_mant, bits_to_show])) + ): + axs[d].text(i + 0.5, nvars + 0.5, m + 1, ha="center", fontsize=7) + + if d == len(subfigure_data) - 1: + lax.legend( + bbox_to_anchor=(0.5, 0), + loc="center", + framealpha=0.6, + ncol=4, + handles=[l0[0], l1, l2, l3], + ) + axs[d].set_xlim(0, bits_to_show) fig.show() diff --git a/xbitinfo/xbitinfo.py b/xbitinfo/xbitinfo.py index dbc003f3..0f428f3e 100644 --- a/xbitinfo/xbitinfo.py +++ b/xbitinfo/xbitinfo.py @@ -33,31 +33,38 @@ jl.eval("include(Main.path)") -NMBITS = {64: 12, 32: 9, 16: 6} # number of non mantissa bits for given dtype - - -def get_bit_coords(dtype_size): - """Get coordinates for bits assuming float dtypes.""" - if dtype_size == 16: - coords = ( - ["±"] - + [f"e{int(i)}" for i in range(1, 6)] - + [f"m{int(i-5)}" for i in range(6, 16)] - ) - elif dtype_size == 32: - coords = ( - ["±"] - + [f"e{int(i)}" for i in range(1, 9)] - + [f"m{int(i-8)}" for i in range(9, 32)] - ) - elif dtype_size == 64: - coords = ( - ["±"] - + [f"e{int(i)}" for i in range(1, 12)] - + [f"m{int(i-11)}" for i in range(12, 64)] - ) +def bit_partitioning(dtype): + if dtype.kind == "f": + n_bits = np.finfo(dtype).bits + n_sign = 1 + n_exponent = np.finfo(dtype).nexp + n_mantissa = np.finfo(dtype).nmant + elif dtype.kind == "i": + n_bits = np.iinfo(dtype).bits + n_sign = 1 + n_exponent = 0 + n_mantissa = n_bits - n_sign + elif dtype.kind == "u": + n_bits = np.iinfo(dtype).bits + n_sign = 0 + n_exponent = 0 + n_mantissa = n_bits - n_sign else: - raise ValueError(f"dtype of size {dtype_size} neither known nor implemented.") + raise ValueError(f"dtype {dtype} neither known nor implemented.") + assert ( + n_sign + n_exponent + n_mantissa == n_bits + ), "The components of the datatype could not be safely inferred." + return n_bits, n_sign, n_exponent, n_mantissa + + +def get_bit_coords(dtype): + """Get coordinates for bits based on dtype.""" + n_bits, n_sign, n_exponent, n_mantissa = bit_partitioning(dtype) + coords = ( + n_sign * ["±"] + + [f"e{int(i)}" for i in range(1, n_exponent + 1)] + + [f"m{int(i)}" for i in range(1, n_mantissa + 1)] + ) return coords @@ -65,13 +72,13 @@ def dict_to_dataset(info_per_bit): """Convert keepbits dictionary to :py:class:`xarray.Dataset`.""" dsb = xr.Dataset() for v in info_per_bit.keys(): - dtype_size = len(info_per_bit[v]["bitinfo"]) + dtype = np.dtype(info_per_bit[v]["dtype"]) dim = info_per_bit[v]["dim"] - dim_name = f"bit{dtype_size}" + dim_name = f"bit{dtype}" dsb[v] = xr.DataArray( info_per_bit[v]["bitinfo"], dims=[dim_name], - coords={dim_name: get_bit_coords(dtype_size), "dim": dim}, + coords={dim_name: get_bit_coords(dtype), "dim": dim}, name=v, attrs={ "long_name": f"{v} bitwise information", @@ -145,13 +152,13 @@ def get_bitinformation( # noqa: C901 ------- >>> ds = xr.tutorial.load_dataset("air_temperature") >>> xb.get_bitinformation(ds, dim="lon") # doctest: +ELLIPSIS - - Dimensions: (bit32: 32) + Size: 652B + Dimensions: (bitfloat32: 32) Coordinates: - * bit32 (bit32) >> xb.get_bitinformation(ds) - - Dimensions: (bit32: 32, dim: 3) + Size: 1kB + Dimensions: (bitfloat32: 32, dim: 3) Coordinates: - * bit32 (bit32) >> ds = xr.tutorial.load_dataset("air_temperature") >>> info_per_bit = xb.get_bitinformation(ds, dim="lon") >>> xb.get_keepbits(info_per_bit) - + Size: 28B Dimensions: (inflevel: 1) Coordinates: - dim >> xb.get_keepbits(info_per_bit, inflevel=0.99999999) - + Size: 28B Dimensions: (inflevel: 1) Coordinates: - dim >> xb.get_keepbits(info_per_bit, inflevel=1.0) - + Size: 28B Dimensions: (inflevel: 1) Coordinates: - dim >> info_per_bit = xb.get_bitinformation(ds) >>> xb.get_keepbits(info_per_bit) - + Size: 80B Dimensions: (dim: 3, inflevel: 1) Coordinates: - * dim (dim) 1.0).any(): raise ValueError("Please provide `inflevel` from interval [0.,1.]") - for bitdim in ["bit16", "bit32", "bit64"]: + for bitdim in [ + "bitfloat16", + "bitfloat32", + "bitfloat64", + "bitint16", + "bitint32", + "bitint64", + "bituint16", + "bituint32", + "bituint64", + ]: # get only variables of bitdim bit_vars = [v for v in info_per_bit.data_vars if bitdim in info_per_bit[v].dims] if bit_vars != []: - if information_filter == "Gradient": - cdf = get_cdf_without_artificial_information( - info_per_bit[bit_vars], - bitdim, - kwargs["threshold"], - kwargs["tolerance"], - bit_vars, - ) - else: - cdf = _cdf_from_info_per_bit(info_per_bit[bit_vars], bitdim) + cdf = _cdf_from_info_per_bit(info_per_bit[bit_vars], bitdim) + data_type = np.dtype(bitdim.replace("bit", "")) + n_bits, _, _, n_mant = bit_partitioning(data_type) + bitdim_non_mantissa_bits = n_bits - n_mant - bitdim_non_mantissa_bits = NMBITS[int(bitdim[3:])] keepmantissabits_bitdim = ( (cdf > inflevel).argmax(bitdim) + 1 - bitdim_non_mantissa_bits ) # keep all mantissa bits for 100% information if 1.0 in inflevel: - bitdim_all_mantissa_bits = int(bitdim[3:]) - bitdim_non_mantissa_bits + bitdim_all_mantissa_bits = n_bits - bitdim_non_mantissa_bits keepall = xr.ones_like(keepmantissabits_bitdim.sel(inflevel=1.0)) * ( bitdim_all_mantissa_bits ) @@ -814,7 +826,7 @@ class JsonCustomEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, (np.ndarray, np.number)): return obj.tolist() - elif isinstance(obj, (complex, np.complex)): + elif isinstance(obj, complex): return [obj.real, obj.imag] elif isinstance(obj, set): return list(obj)