diff --git a/setup.py b/setup.py index fe091f16..d9b6bdb9 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,6 @@ """The setup script.""" - from setuptools import find_packages, setup with open("README.md") as readme_file: diff --git a/tests/__init__.py b/tests/__init__.py index 554896f8..61492edc 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,5 @@ """Unit test package for xbitinfo.""" + import importlib from distutils import version diff --git a/tests/test_bitround.py b/tests/test_bitround.py index aaa126b9..86cd195c 100644 --- a/tests/test_bitround.py +++ b/tests/test_bitround.py @@ -1,6 +1,4 @@ -import numpy as np import pytest -import xarray as xr from dask import is_dask_collection from xarray.testing import assert_allclose, assert_equal diff --git a/tests/test_get_bitinformation.py b/tests/test_get_bitinformation.py index 0a38ca0f..68b53804 100644 --- a/tests/test_get_bitinformation.py +++ b/tests/test_get_bitinformation.py @@ -1,4 +1,5 @@ """Tests for `xbitinfo` package.""" + import os import numpy as np diff --git a/tests/test_get_keepbits.py b/tests/test_get_keepbits.py index cec0b99f..cd68938a 100644 --- a/tests/test_get_keepbits.py +++ b/tests/test_get_keepbits.py @@ -1,11 +1,12 @@ """Tests for `xbitinfo.get_keepbits`.""" + import pytest import xarray as xr import xbitinfo as xb -@pytest.fixture +@pytest.fixture() def rasm_info_per_bit(rasm): return xb.get_bitinformation(rasm, axis=0) diff --git a/tests/test_save_compressed.py b/tests/test_save_compressed.py index 09e21878..c31c0f92 100644 --- a/tests/test_save_compressed.py +++ b/tests/test_save_compressed.py @@ -3,7 +3,6 @@ import numcodecs import pytest -import xarray as xr import zarr import xbitinfo as xb diff --git a/xbitinfo/_py_bitinfo.py b/xbitinfo/_py_bitinfo.py index 620e3146..00cc7d14 100644 --- a/xbitinfo/_py_bitinfo.py +++ b/xbitinfo/_py_bitinfo.py @@ -1,6 +1,5 @@ import dask.array as da import numpy as np -import numpy.ma as nm def exponent_bias(dtype): diff --git a/xbitinfo/bitround.py b/xbitinfo/bitround.py index 91858b4b..34e35bd5 100644 --- a/xbitinfo/bitround.py +++ b/xbitinfo/bitround.py @@ -37,12 +37,12 @@ def _keepbits_interface(da, keepbits): else: raise ValueError(f"name {v} not for in keepbits: {keepbits.keys()}") elif isinstance(keepbits, xr.Dataset): - assert keepbits.coords["inflevel"].shape <= ( - 1, + assert ( + keepbits.coords["inflevel"].shape <= (1,) ), "Information content is only allowed for one 'inflevel' here. Please make a selection." if "dim" in keepbits.coords: - assert keepbits.coords["dim"].shape <= ( - 1, + assert ( + keepbits.coords["dim"].shape <= (1,) ), "Information content is only allowed along one dimension here. Please select one `dim`. To find the maximum keepbits, simply use `keepbits.max(dim='dim')`" v = da.name if v in keepbits.keys(): @@ -50,11 +50,11 @@ def _keepbits_interface(da, keepbits): else: raise ValueError(f"name {v} not for in keepbits: {keepbits.keys()}") elif isinstance(keepbits, xr.DataArray): - assert keepbits.coords["inflevel"].shape <= ( - 1, + assert ( + keepbits.coords["inflevel"].shape <= (1,) ), "Information content is only allowed for one 'inflevel' here. Please make a selection." - assert keepbits.coords["dim"].shape <= ( - 1, + assert ( + keepbits.coords["dim"].shape <= (1,) ), "Information content is only allowed along one dimension here. Please select one `dim`. To find the maximum keepbits, simply use `keepbits.max(dim='dim')`" v = da.name if v == keepbits.name: @@ -176,7 +176,9 @@ def bitround_along_dim( >>> ds_bitrounded_along_lon = xb.bitround.bitround_along_dim( ... ds, info_per_bit, dim="lon" ... ) - >>> (ds - ds_bitrounded_along_lon)["air"].isel(time=0).plot() # doctest: +ELLIPSIS + >>> (ds - ds_bitrounded_along_lon)["air"].isel( + ... time=0 + ... ).plot() # doctest: +ELLIPSIS """ @@ -186,14 +188,12 @@ def bitround_along_dim( elif inflevels is not None: stride = ds[dim].size // len(inflevels) for i, inf in enumerate(inflevels): # last slice might be a bit larger - ds_slice = ds.isel( - { - dim: slice( - stride * i, - stride * (i + 1) if i != len(inflevels) - 1 else None, - ) - } - ) + ds_slice = ds.isel({ + dim: slice( + stride * i, + stride * (i + 1) if i != len(inflevels) - 1 else None, + ) + }) keepbits_slice = get_keepbits(info_per_bit, inf) if inf != 1: ds_slice_bitrounded = xr_bitround(ds_slice, keepbits_slice) diff --git a/xbitinfo/graphics.py b/xbitinfo/graphics.py index bef4037a..2aa70837 100644 --- a/xbitinfo/graphics.py +++ b/xbitinfo/graphics.py @@ -149,12 +149,10 @@ def add_bitinfo_labels( # write inflevel t = ax.text( - da.isel( - { - x_dim_name: int(stride * (i + 0.5)), - y_dim_name: da[y_dim_name].size // 2, - } - )[lon_coord_name].values, + da.isel({ + x_dim_name: int(stride * (i + 0.5)), + y_dim_name: da[y_dim_name].size // 2, + })[lon_coord_name].values, label_latitude - label_latitude_offset, str(round(inf * 100, 2)) + "%", horizontalalignment="center", @@ -166,12 +164,10 @@ def add_bitinfo_labels( for i, keep in enumerate(keepbits): # write keepbits t_keepbits = ax.text( - da.isel( - { - x_dim_name: int(stride * (i + 0.5)), - y_dim_name: da[y_dim_name].size // 2, - } - )[lon_coord_name].values, + da.isel({ + x_dim_name: int(stride * (i + 0.5)), + y_dim_name: da[y_dim_name].size // 2, + })[lon_coord_name].values, label_latitude + label_latitude_offset, f"keepbits = {keep}", horizontalalignment="center", @@ -420,7 +416,7 @@ def plot_bitinformation(bitinfo, cmap="turku", crop=None): axs[d].text( infbits[0] + 0.1, 0.8, - f"{int(infbits[0]-nonmantissa_bits)} mantissa bits", + f"{int(infbits[0] - nonmantissa_bits)} mantissa bits", fontsize=8, color="saddlebrown", ) @@ -428,19 +424,17 @@ def plot_bitinformation(bitinfo, cmap="turku", crop=None): axs[d].text( infbits[i] + 0.1, (i) + 0.8, - f"{int(infbits[i]-9)}", + f"{int(infbits[i] - 9)}", fontsize=8, color="saddlebrown", ) major_xticks = np.array([n_sign, n_sign + n_exp, n_bits], dtype="int") axs[d].set_xticks(major_xticks[major_xticks <= bits_to_show]) - minor_xticks = np.hstack( - [ - np.arange(n_sign, nonmantissa_bits - 1), - np.arange(nonmantissa_bits, n_bits - 1), - ] - ) + minor_xticks = np.hstack([ + np.arange(n_sign, nonmantissa_bits - 1), + np.arange(nonmantissa_bits, n_bits - 1), + ]) axs[d].set_xticks( minor_xticks[minor_xticks <= bits_to_show], minor=True, diff --git a/xbitinfo/save_compressed.py b/xbitinfo/save_compressed.py index 05f8ae63..fd4ded1e 100644 --- a/xbitinfo/save_compressed.py +++ b/xbitinfo/save_compressed.py @@ -36,7 +36,9 @@ def get_compress_encoding_nc( Example ------- - >>> ds = xr.Dataset({"Tair": (("time", "x", "y"), np.random.rand(36, 20, 10))}) + >>> ds = xr.Dataset( + ... {"Tair": (("time", "x", "y"), np.random.rand(36, 20, 10))} + ... ) >>> get_compress_encoding_nc(ds) {'Tair': {'zlib': True, 'shuffle': True, 'complevel': 9, 'chunksizes': (36, 20, 10)}} >>> get_compress_encoding_nc(ds, for_cdo=True) @@ -187,7 +189,9 @@ class ToCompressed_Zarr: >>> ds = xr.tutorial.load_dataset("rasm") >>> path = "compressed_rasm.zarr" >>> ds.to_compressed_zarr(path, mode="w") - >>> ds.to_compressed_zarr(path, compressor=numcodecs.Blosc("zlib"), mode="w") + >>> ds.to_compressed_zarr( + ... path, compressor=numcodecs.Blosc("zlib"), mode="w" + ... ) >>> ds.to_compressed_zarr( ... path, compressor={"Tair": numcodecs.Blosc("zstd")}, mode="w" ... ) diff --git a/xbitinfo/xbitinfo.py b/xbitinfo/xbitinfo.py index 4450a98f..5156dec9 100644 --- a/xbitinfo/xbitinfo.py +++ b/xbitinfo/xbitinfo.py @@ -22,7 +22,7 @@ if not already_ran and julia_installed: already_ran = install(quiet=True) jl = Julia(compiled_modules=False, debug=False) - from julia import Main # noqa: E402 + from julia import Main path_to_julia_functions = os.path.join( os.path.dirname(__file__), "bitinformation_wrapper.jl" @@ -105,7 +105,7 @@ def dict_to_dataset(info_per_bit): return dsb -def get_bitinformation( # noqa: C901 +def get_bitinformation( ds, dim=None, axis=None, @@ -273,7 +273,7 @@ def _jl_get_bitinformation(ds, var, axis, dim, kwargs={}): axis_jl = ds[var].get_axis_num(dim) + 1 except ValueError: logging.info(f"Variable {var} does not have dimension {dim}. Skipping.") - return + return None assert isinstance(axis_jl, int) Main.dim = axis_jl kwargs_str = _get_bitinformation_kwargs_handler(ds[var], kwargs) @@ -314,7 +314,7 @@ def _py_get_bitinformation(ds, var, axis, dim, kwargs={}): axis = ds[var].get_axis_num(dim) except ValueError: logging.info(f"Variable {var} does not have dimension {dim}. Skipping.") - return + return None info_per_bit = {} logging.info("Calling python implementation now") info_per_bit["bitinfo"] = pb.bitinformation(X, axis=axis).compute() @@ -378,11 +378,13 @@ def load_bitinformation(label): label_file = label + ".json" if os.path.exists(label_file): with open(label_file) as f: - logging.debug(f"Load bitinformation from {label+'.json'}") + logging.debug(f"Load bitinformation from {label + '.json'}") info_per_bit = json.load(f) return dict_to_dataset(info_per_bit) else: - raise FileNotFoundError(f"No bitinformation could be found at {label+'.json'}") + raise FileNotFoundError( + f"No bitinformation could be found at {label + '.json'}" + ) def get_keepbits(info_per_bit, inflevel=0.99):