Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 18, 2024
1 parent 7734455 commit 8105e96
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 51 deletions.
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

"""The setup script."""


from setuptools import find_packages, setup

with open("README.md") as readme_file:
Expand Down
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Unit test package for xbitinfo."""

import importlib
from distutils import version

Expand Down
2 changes: 0 additions & 2 deletions tests/test_bitround.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import numpy as np
import pytest
import xarray as xr
from dask import is_dask_collection
from xarray.testing import assert_allclose, assert_equal

Expand Down
1 change: 1 addition & 0 deletions tests/test_get_bitinformation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests for `xbitinfo` package."""

import os

import numpy as np
Expand Down
3 changes: 2 additions & 1 deletion tests/test_get_keepbits.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Tests for `xbitinfo.get_keepbits`."""

import pytest
import xarray as xr

import xbitinfo as xb


@pytest.fixture
@pytest.fixture()
def rasm_info_per_bit(rasm):
return xb.get_bitinformation(rasm, axis=0)

Expand Down
1 change: 0 additions & 1 deletion tests/test_save_compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numcodecs
import pytest
import xarray as xr
import zarr

import xbitinfo as xb
Expand Down
1 change: 0 additions & 1 deletion xbitinfo/_py_bitinfo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import dask.array as da
import numpy as np
import numpy.ma as nm


def exponent_bias(dtype):
Expand Down
34 changes: 17 additions & 17 deletions xbitinfo/bitround.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,24 +37,24 @@ def _keepbits_interface(da, keepbits):
else:
raise ValueError(f"name {v} not for in keepbits: {keepbits.keys()}")
elif isinstance(keepbits, xr.Dataset):
assert keepbits.coords["inflevel"].shape <= (
1,
assert (
keepbits.coords["inflevel"].shape <= (1,)
), "Information content is only allowed for one 'inflevel' here. Please make a selection."
if "dim" in keepbits.coords:
assert keepbits.coords["dim"].shape <= (
1,
assert (
keepbits.coords["dim"].shape <= (1,)
), "Information content is only allowed along one dimension here. Please select one `dim`. To find the maximum keepbits, simply use `keepbits.max(dim='dim')`"
v = da.name
if v in keepbits.keys():
keep = int(keepbits[v])
else:
raise ValueError(f"name {v} not for in keepbits: {keepbits.keys()}")
elif isinstance(keepbits, xr.DataArray):
assert keepbits.coords["inflevel"].shape <= (
1,
assert (
keepbits.coords["inflevel"].shape <= (1,)
), "Information content is only allowed for one 'inflevel' here. Please make a selection."
assert keepbits.coords["dim"].shape <= (
1,
assert (
keepbits.coords["dim"].shape <= (1,)
), "Information content is only allowed along one dimension here. Please select one `dim`. To find the maximum keepbits, simply use `keepbits.max(dim='dim')`"
v = da.name
if v == keepbits.name:
Expand Down Expand Up @@ -176,7 +176,9 @@ def bitround_along_dim(
>>> ds_bitrounded_along_lon = xb.bitround.bitround_along_dim(
... ds, info_per_bit, dim="lon"
... )
>>> (ds - ds_bitrounded_along_lon)["air"].isel(time=0).plot() # doctest: +ELLIPSIS
>>> (ds - ds_bitrounded_along_lon)["air"].isel(
... time=0
... ).plot() # doctest: +ELLIPSIS
<matplotlib.collections.QuadMesh object at ...>
"""
Expand All @@ -186,14 +188,12 @@ def bitround_along_dim(
elif inflevels is not None:
stride = ds[dim].size // len(inflevels)
for i, inf in enumerate(inflevels): # last slice might be a bit larger
ds_slice = ds.isel(
{
dim: slice(
stride * i,
stride * (i + 1) if i != len(inflevels) - 1 else None,
)
}
)
ds_slice = ds.isel({
dim: slice(
stride * i,
stride * (i + 1) if i != len(inflevels) - 1 else None,
)
})
keepbits_slice = get_keepbits(info_per_bit, inf)
if inf != 1:
ds_slice_bitrounded = xr_bitround(ds_slice, keepbits_slice)
Expand Down
34 changes: 14 additions & 20 deletions xbitinfo/graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,10 @@ def add_bitinfo_labels(

# write inflevel
t = ax.text(
da.isel(
{
x_dim_name: int(stride * (i + 0.5)),
y_dim_name: da[y_dim_name].size // 2,
}
)[lon_coord_name].values,
da.isel({
x_dim_name: int(stride * (i + 0.5)),
y_dim_name: da[y_dim_name].size // 2,
})[lon_coord_name].values,
label_latitude - label_latitude_offset,
str(round(inf * 100, 2)) + "%",
horizontalalignment="center",
Expand All @@ -166,12 +164,10 @@ def add_bitinfo_labels(
for i, keep in enumerate(keepbits):
# write keepbits
t_keepbits = ax.text(
da.isel(
{
x_dim_name: int(stride * (i + 0.5)),
y_dim_name: da[y_dim_name].size // 2,
}
)[lon_coord_name].values,
da.isel({
x_dim_name: int(stride * (i + 0.5)),
y_dim_name: da[y_dim_name].size // 2,
})[lon_coord_name].values,
label_latitude + label_latitude_offset,
f"keepbits = {keep}",
horizontalalignment="center",
Expand Down Expand Up @@ -420,27 +416,25 @@ def plot_bitinformation(bitinfo, cmap="turku", crop=None):
axs[d].text(
infbits[0] + 0.1,
0.8,
f"{int(infbits[0]-nonmantissa_bits)} mantissa bits",
f"{int(infbits[0] - nonmantissa_bits)} mantissa bits",
fontsize=8,
color="saddlebrown",
)
for i in range(1, nvars):
axs[d].text(
infbits[i] + 0.1,
(i) + 0.8,
f"{int(infbits[i]-9)}",
f"{int(infbits[i] - 9)}",
fontsize=8,
color="saddlebrown",
)

major_xticks = np.array([n_sign, n_sign + n_exp, n_bits], dtype="int")
axs[d].set_xticks(major_xticks[major_xticks <= bits_to_show])
minor_xticks = np.hstack(
[
np.arange(n_sign, nonmantissa_bits - 1),
np.arange(nonmantissa_bits, n_bits - 1),
]
)
minor_xticks = np.hstack([
np.arange(n_sign, nonmantissa_bits - 1),
np.arange(nonmantissa_bits, n_bits - 1),
])
axs[d].set_xticks(
minor_xticks[minor_xticks <= bits_to_show],
minor=True,
Expand Down
8 changes: 6 additions & 2 deletions xbitinfo/save_compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def get_compress_encoding_nc(
Example
-------
>>> ds = xr.Dataset({"Tair": (("time", "x", "y"), np.random.rand(36, 20, 10))})
>>> ds = xr.Dataset(
... {"Tair": (("time", "x", "y"), np.random.rand(36, 20, 10))}
... )
>>> get_compress_encoding_nc(ds)
{'Tair': {'zlib': True, 'shuffle': True, 'complevel': 9, 'chunksizes': (36, 20, 10)}}
>>> get_compress_encoding_nc(ds, for_cdo=True)
Expand Down Expand Up @@ -187,7 +189,9 @@ class ToCompressed_Zarr:
>>> ds = xr.tutorial.load_dataset("rasm")
>>> path = "compressed_rasm.zarr"
>>> ds.to_compressed_zarr(path, mode="w")
>>> ds.to_compressed_zarr(path, compressor=numcodecs.Blosc("zlib"), mode="w")
>>> ds.to_compressed_zarr(
... path, compressor=numcodecs.Blosc("zlib"), mode="w"
... )
>>> ds.to_compressed_zarr(
... path, compressor={"Tair": numcodecs.Blosc("zstd")}, mode="w"
... )
Expand Down
14 changes: 8 additions & 6 deletions xbitinfo/xbitinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
if not already_ran and julia_installed:
already_ran = install(quiet=True)
jl = Julia(compiled_modules=False, debug=False)
from julia import Main # noqa: E402
from julia import Main

path_to_julia_functions = os.path.join(
os.path.dirname(__file__), "bitinformation_wrapper.jl"
Expand Down Expand Up @@ -105,7 +105,7 @@ def dict_to_dataset(info_per_bit):
return dsb


def get_bitinformation( # noqa: C901
def get_bitinformation(
ds,
dim=None,
axis=None,
Expand Down Expand Up @@ -273,7 +273,7 @@ def _jl_get_bitinformation(ds, var, axis, dim, kwargs={}):
axis_jl = ds[var].get_axis_num(dim) + 1
except ValueError:
logging.info(f"Variable {var} does not have dimension {dim}. Skipping.")
return
return None
assert isinstance(axis_jl, int)
Main.dim = axis_jl
kwargs_str = _get_bitinformation_kwargs_handler(ds[var], kwargs)
Expand Down Expand Up @@ -314,7 +314,7 @@ def _py_get_bitinformation(ds, var, axis, dim, kwargs={}):
axis = ds[var].get_axis_num(dim)
except ValueError:
logging.info(f"Variable {var} does not have dimension {dim}. Skipping.")
return
return None
info_per_bit = {}
logging.info("Calling python implementation now")
info_per_bit["bitinfo"] = pb.bitinformation(X, axis=axis).compute()
Expand Down Expand Up @@ -378,11 +378,13 @@ def load_bitinformation(label):
label_file = label + ".json"
if os.path.exists(label_file):
with open(label_file) as f:
logging.debug(f"Load bitinformation from {label+'.json'}")
logging.debug(f"Load bitinformation from {label + '.json'}")
info_per_bit = json.load(f)
return dict_to_dataset(info_per_bit)
else:
raise FileNotFoundError(f"No bitinformation could be found at {label+'.json'}")
raise FileNotFoundError(
f"No bitinformation could be found at {label + '.json'}"
)


def get_keepbits(info_per_bit, inflevel=0.99):
Expand Down

0 comments on commit 8105e96

Please sign in to comment.