diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index 6e39a3aa94f..17992751153 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -19,9 +19,11 @@ from xarray.namedarray._typing import ( _AttrsLike, + _Dim, _DimsLike, _DType, _Shape, + _ShapeLike, duckarray, ) from xarray.namedarray.utils import Default @@ -61,6 +63,14 @@ def random_inputs() -> np.ndarray[Any, np.dtype[np.float32]]: return np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) +@pytest.fixture +def data() -> NamedArray[Any, np.dtype[np.float32]]: + dtype_float = np.dtype(np.float32) + narr_float: NamedArray[Any, np.dtype[np.float32]] + narr_float = NamedArray(("x",), np.array([1.5, 3.2], dtype=dtype_float)) + return narr_float + + def test_namedarray_init() -> None: dtype = np.dtype(np.int8) expected = np.array([1, 2], dtype=dtype) @@ -421,3 +431,63 @@ def _new( var_float2: Variable[Any, np.dtype[np.float32]] var_float2 = var_float._replace(("x",), np_val2) assert var_float2.dtype == dtype_float + + +@pytest.mark.parametrize( + "new_dims, new_shape", + [ + (["x", "y"], (2, 1)), # basic case, expanding along existing dimensions + (["x", "y", "z"], (2, 1, 1)), # adding a new dimension + (["z", "x", "y"], (1, 2, 1)), # adding a new dimension with different order + (["x"], (2,)), # reducing dimensions + ({"x": 2, "y": 1}, (2, 1)), # using dict for dims + ( + {"x": 2, "y": 1, "z": 1}, + (2, 1, 1), + ), # using dict for dims, adding new dimension + ], +) +def test_expand_dims( + data: NamedArray[Any, np.dtype[np.float32]], + new_dims: _DimsLike, + new_shape: _ShapeLike, +) -> None: + actual = data.expand_dims(new_dims) + # Ensure the expected dims match, especially when new dimensions are added + expected_dims = ( + tuple(new_dims) + if isinstance(new_dims, (list, tuple)) + else tuple(new_dims.keys()) + ) + expected = NamedArray(expected_dims, data._data.reshape(*new_shape)) + assert np.array_equal(actual.data, expected.data) + assert actual.dims == expected.dims + + +def test_expand_dims_object_dtype() -> None: + data: NamedArray[Any, np.dtype[object]] + x = np.empty([], dtype=object) + x[()] = ("a", 1) + data = NamedArray([], x) + actual = data.expand_dims(("x",), (3,)) + exp_values = np.empty((3,), dtype=object) + for i in range(3): + exp_values[i] = ("a", 1) + assert np.array_equal(actual.data, exp_values) + + +@pytest.mark.parametrize( + "dims", + [ + {"x": 2, "y": 1}, # basic case, broadcasting along existing dimensions + {"x": 2, "y": 3}, # increasing size of existing dimension + {"x": 2, "y": 1, "z": 1}, # adding a new dimension + {"z": 1, "x": 2, "y": 1}, # adding a new dimension with different order + ], +) +def test_broadcast_to( + data: NamedArray[Any, np.dtype[np.float32]], + dims: Mapping[Any, _Dim], +) -> None: + actual = data.broadcast_to(dims) + assert actual.sizes == dims