Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andersy005 committed Oct 28, 2023
1 parent 1364345 commit b793f74
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 5 deletions.
2 changes: 1 addition & 1 deletion doc/api-hidden.rst
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@
IndexVariable.cumprod
IndexVariable.cumsum
IndexVariable.equals
IndexVariable.set_dims
IndexVariable.fillna
IndexVariable.get_axis_num
IndexVariable.get_level_variable
Expand All @@ -318,6 +317,7 @@
IndexVariable.rolling_window
IndexVariable.round
IndexVariable.searchsorted
IndexVariable.set_dims
IndexVariable.shift
IndexVariable.squeeze
IndexVariable.stack
Expand Down
8 changes: 4 additions & 4 deletions xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,12 +959,12 @@ def expand_dims(
Parameters
----------
dims : str or sequence of str or dict
Dimensions to include on the new object (must be a superset of the existing dimensions).
If a dict, values are used to provide the sizes of new dimensions; otherwise, new dimensions are inserted with length 1.
Dimensions to include on the new object (must be a superset of the existing dimensions).
If a dict, values are used to provide the sizes of new dimensions; otherwise, new dimensions are inserted with length 1.
shape : sequence of int, optional
Shape to broadcast the data to. Must be specified in the same order as `dims`.
If not provided, new dimensions are inserted with length 1.
Shape to broadcast the data to. Must be specified in the same order as `dims`.
If not provided, new dimensions are inserted with length 1.
"""

if isinstance(dims, str):
Expand Down
43 changes: 43 additions & 0 deletions xarray/tests/test_namedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
_Dim,
_DimsLike,
_DType,
_IntOrUnknown,
_Shape,
_ShapeLike,
duckarray,
Expand Down Expand Up @@ -71,6 +72,16 @@ def data() -> NamedArray[Any, np.dtype[np.float32]]:
return narr_float


@pytest.fixture
def random_data() -> NamedArray[Any, np.dtype[np.float32]]:
dtype_float = np.dtype(np.float32)
narr_float: NamedArray[Any, np.dtype[np.float32]]
narr_float = NamedArray(
("x", "y", "z"), np.arange(60).reshape(3, 4, 5).astype(dtype_float)
)
return narr_float


def test_namedarray_init() -> None:
dtype = np.dtype(np.int8)
expected = np.array([1, 2], dtype=dtype)
Expand Down Expand Up @@ -491,3 +502,35 @@ def test_broadcast_to(
) -> None:
actual = data.broadcast_to(dims)
assert actual.sizes == dims


@pytest.mark.parametrize(
"dims, expected_sizes",
[
# Basic case: reversing the dimensions
((), {"z": 5, "y": 4, "x": 3}),
(["y", "x", "z"], {"y": 4, "x": 3, "z": 5}),
(["y", "x", ...], {"y": 4, "x": 3, "z": 5}),
],
)
def test_permute_dims(
random_data: NamedArray[Any, np.dtype[np.float32]],
dims: _DimsLike,
expected_sizes: dict[_Dim, _IntOrUnknown],
) -> None:
actual = random_data.permute_dims(*dims)
assert actual.sizes == expected_sizes


@pytest.mark.parametrize(
"dims",
[
(["y", "x"]),
],
)
def test_permute_dims_errors(
random_data: NamedArray[Any, np.dtype[np.float32]],
dims: _DimsLike,
) -> None:
with pytest.raises(ValueError):
random_data.permute_dims(*dims)

0 comments on commit b793f74

Please sign in to comment.