Skip to content

Commit

Permalink
update test_namedarray
Browse files Browse the repository at this point in the history
  • Loading branch information
andersy005 committed Oct 28, 2023
1 parent d6240de commit 1364345
Showing 1 changed file with 70 additions and 0 deletions.
70 changes: 70 additions & 0 deletions xarray/tests/test_namedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@

from xarray.namedarray._typing import (
_AttrsLike,
_Dim,
_DimsLike,
_DType,
_Shape,
_ShapeLike,
duckarray,
)
from xarray.namedarray.utils import Default
Expand Down Expand Up @@ -61,6 +63,14 @@ def random_inputs() -> np.ndarray[Any, np.dtype[np.float32]]:
return np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))


@pytest.fixture
def data() -> NamedArray[Any, np.dtype[np.float32]]:
dtype_float = np.dtype(np.float32)
narr_float: NamedArray[Any, np.dtype[np.float32]]
narr_float = NamedArray(("x",), np.array([1.5, 3.2], dtype=dtype_float))
return narr_float


def test_namedarray_init() -> None:
dtype = np.dtype(np.int8)
expected = np.array([1, 2], dtype=dtype)
Expand Down Expand Up @@ -421,3 +431,63 @@ def _new(
var_float2: Variable[Any, np.dtype[np.float32]]
var_float2 = var_float._replace(("x",), np_val2)
assert var_float2.dtype == dtype_float


@pytest.mark.parametrize(
"new_dims, new_shape",
[
(["x", "y"], (2, 1)), # basic case, expanding along existing dimensions
(["x", "y", "z"], (2, 1, 1)), # adding a new dimension
(["z", "x", "y"], (1, 2, 1)), # adding a new dimension with different order
(["x"], (2,)), # reducing dimensions
({"x": 2, "y": 1}, (2, 1)), # using dict for dims
(
{"x": 2, "y": 1, "z": 1},
(2, 1, 1),
), # using dict for dims, adding new dimension
],
)
def test_expand_dims(
data: NamedArray[Any, np.dtype[np.float32]],
new_dims: _DimsLike,
new_shape: _ShapeLike,
) -> None:
actual = data.expand_dims(new_dims)
# Ensure the expected dims match, especially when new dimensions are added
expected_dims = (
tuple(new_dims)
if isinstance(new_dims, (list, tuple))
else tuple(new_dims.keys())
)
expected = NamedArray(expected_dims, data._data.reshape(*new_shape))
assert np.array_equal(actual.data, expected.data)
assert actual.dims == expected.dims


def test_expand_dims_object_dtype() -> None:
data: NamedArray[Any, np.dtype[object]]
x = np.empty([], dtype=object)
x[()] = ("a", 1)
data = NamedArray([], x)
actual = data.expand_dims(("x",), (3,))
exp_values = np.empty((3,), dtype=object)
for i in range(3):
exp_values[i] = ("a", 1)
assert np.array_equal(actual.data, exp_values)


@pytest.mark.parametrize(
"dims",
[
{"x": 2, "y": 1}, # basic case, broadcasting along existing dimensions
{"x": 2, "y": 3}, # increasing size of existing dimension
{"x": 2, "y": 1, "z": 1}, # adding a new dimension
{"z": 1, "x": 2, "y": 1}, # adding a new dimension with different order
],
)
def test_broadcast_to(
data: NamedArray[Any, np.dtype[np.float32]],
dims: Mapping[Any, _Dim],
) -> None:
actual = data.broadcast_to(dims)
assert actual.sizes == dims

0 comments on commit 1364345

Please sign in to comment.