diff --git a/pycontrails/core/met.py b/pycontrails/core/met.py index af0ef7542..fc9d283e4 100644 --- a/pycontrails/core/met.py +++ b/pycontrails/core/met.py @@ -2235,46 +2235,32 @@ def _is_zarr(ds: xr.Dataset | xr.DataArray) -> bool: return dask0.array.array.array.__class__.__name__ == "ZarrArrayWrapper" -def shift_longitude(data: XArrayType) -> XArrayType: - """Shift longitude values from [0, 360) to [-180, 180) domain. +def shift_longitude(data: XArrayType, bound: float = -180.0) -> XArrayType: + """Shift longitude values from any input domain to [bound, 360 + bound) domain. Sorts data by ascending longitude values. + Parameters ---------- data : XArrayType :class:`xr.Dataset` or :class:`xr.DataArray` with longitude dimension + bound : float, optional + Lower bound of the domain. + Output domain will be [bound, 360 + bound). + Defaults to -180, which results in longitude domain [-180, 180). + Returns ------- XArrayType - :class:`xr.Dataset` or :class:`xr.DataArray` with longitude values on [-180, 180). + :class:`xr.Dataset` or :class:`xr.DataArray` with longitude values on [a, 360 + a). """ return data.assign_coords( - longitude=((data["longitude"].values + 180.0) % 360.0) - 180.0 + longitude=((data["longitude"].values - bound) % 360.0) + bound ).sortby("longitude", ascending=True) -def shift_longitude_reverse(data: XArrayType) -> XArrayType: - """Shift longitude values from [-180, 180) to [0, 360) domain. - - Sorts data by ascending longitude values. - - Parameters - ---------- - data : XArrayType - :class:`xr.Dataset` or :class:`xr.DataArray` with longitude dimension - - Returns - ------- - XArrayType - :class:`xr.Dataset` or :class:`xr.DataArray` with longitude values on [0, 360). - """ - return data.assign_coords(longitude=(data["longitude"].values + 360.0) % 360.0).sortby( - "longitude", ascending=True - ) - - def _wrap_longitude(data: XArrayType) -> XArrayType: """Wrap longitude grid coordinates. diff --git a/tests/unit/test_met.py b/tests/unit/test_met.py index a8f8f032b..4981b076a 100644 --- a/tests/unit/test_met.py +++ b/tests/unit/test_met.py @@ -14,7 +14,7 @@ from pycontrails import DiskCacheStore, MetDataArray, MetDataset, MetVariable from pycontrails.core import cache as cache_module -from pycontrails.core.met import originates_from_ecmwf, shift_longitude, shift_longitude_reverse +from pycontrails.core.met import originates_from_ecmwf, shift_longitude from pycontrails.datalib.ecmwf import ERA5 from tests import OPEN3D_AVAILABLE @@ -216,7 +216,8 @@ def test_shift_longitude(zero_like_da: xr.DataArray) -> None: da = zero_like_da.copy() # assign random values to all coords - da.loc[:] = np.random.rand(*da.shape) # noqa: NPY002 + rng = np.random.default_rng() + da.loc[:] = rng.random(size=da.shape) # shift coordinates to [0, 360) manually lons = da["longitude"].values @@ -242,13 +243,14 @@ def test_shift_longitude(zero_like_da: xr.DataArray) -> None: ) -def test_shift_longitude_reverse(zero_like_da: xr.DataArray) -> None: +def test_shift_longitude_bound(zero_like_da: xr.DataArray) -> None: da = zero_like_da.copy() # assign random values to all coords - da.loc[:] = np.random.rand(*da.shape) # noqa: NPY002 + rng = np.random.default_rng() + da.loc[:] = rng.random(size=da.shape) - da2 = shift_longitude_reverse(da) + da2 = shift_longitude(da, bound=0) # longitude sorted assert np.all(np.diff(da2["longitude"].values) > 0) @@ -265,6 +267,22 @@ def test_shift_longitude_reverse(zero_like_da: xr.DataArray) -> None: da.sel({"longitude": -170})[0, 0, 0].values == da2.sel({"longitude": 190})[0, 0, 0].values ) + da2 = shift_longitude(da, bound=-10) + + # longitude sorted + assert np.all(np.diff(da2["longitude"].values) > 0) + + # longitude correctly translated + assert np.all(da2["longitude"].values < 350.0) + assert np.all(da2["longitude"].values >= -10.0) + + # test a few values + assert da.sel({"longitude": -10})[0, 0, 0].values == da2.sel({"longitude": -10})[0, 0, 0].values + assert da.sel({"longitude": -20})[0, 0, 0].values == da2.sel({"longitude": 340})[0, 0, 0].values + assert ( + da.sel({"longitude": -179})[0, 0, 0].values == da2.sel({"longitude": 181})[0, 0, 0].values + ) + @pytest.fixture() def zero_like_da() -> xr.DataArray: