diff --git a/.gitignore b/.gitignore index 21c18c17ff7..e8a97ed5328 100644 --- a/.gitignore +++ b/.gitignore @@ -79,3 +79,6 @@ doc/team-panel.txt doc/external-examples-gallery.txt doc/notebooks-examples-gallery.txt doc/videos-gallery.txt + +# MyPy Report +mypy_report/ diff --git a/doc/user-guide/options.rst b/doc/user-guide/options.rst index 12844eccbe4..a260bfec6c5 100644 --- a/doc/user-guide/options.rst +++ b/doc/user-guide/options.rst @@ -16,7 +16,7 @@ Xarray offers a small number of configuration options through :py:func:`set_opti - ``display_max_rows`` - ``display_style`` -2. Control behaviour during operations: ``arithmetic_join``, ``keep_attrs``, ``use_bottleneck``. +2. Control behaviour during operations: ``arithmetic_broadcast``, ``arithmetic_join``, ``keep_attrs``, ``use_bottleneck``. 3. Control colormaps for plots:``cmap_divergent``, ``cmap_sequential``. 4. Aspects of file reading: ``file_cache_maxsize``, ``warn_on_unclosed_files``. diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 80e53a5ee22..cf494cbd13b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -64,6 +64,9 @@ Mathias Hauser, Matt Savoie, Maximilian Roos, Rambaud Pierrick, Tom Nicholas New Features ~~~~~~~~~~~~ +- Added the ability to control broadcasting for alignment, and new gloal option ``arithmetic_broadcast`` + (:issue:`6806`, :pull:`8698`). + By `Etienne Schalk `_. - Added a simple ``nbytes`` representation in DataArrays and Dataset ``repr``. (:issue:`8690`, :pull:`8702`). By `Etienne Schalk `_. diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 13e3400d170..d3bf15eb340 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -5,7 +5,17 @@ from collections import defaultdict from collections.abc import Hashable, Iterable, Mapping from contextlib import suppress -from typing import TYPE_CHECKING, Any, Callable, Final, Generic, TypeVar, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Final, + Generic, + TypeVar, + cast, + get_args, + overload, +) import numpy as np import pandas as pd @@ -19,7 +29,7 @@ indexes_all_equal, safe_cast_to_index, ) -from xarray.core.types import T_Alignable +from xarray.core.types import JoinOptions, T_Alignable from xarray.core.utils import is_dict_like, is_full_slice from xarray.core.variable import Variable, as_compatible_data, calculate_dimensions @@ -28,7 +38,6 @@ from xarray.core.dataset import Dataset from xarray.core.types import ( Alignable, - JoinOptions, T_DataArray, T_Dataset, T_DuckArray, @@ -113,6 +122,7 @@ class Aligner(Generic[T_Alignable]): results: tuple[T_Alignable, ...] objects_matching_indexes: tuple[dict[MatchingIndexKey, Index], ...] join: str + broadcast: bool exclude_dims: frozenset[Hashable] exclude_vars: frozenset[Hashable] copy: bool @@ -133,6 +143,7 @@ def __init__( self, objects: Iterable[T_Alignable], join: str = "inner", + broadcast: bool = True, indexes: Mapping[Any, Any] | None = None, exclude_dims: str | Iterable[Hashable] = frozenset(), exclude_vars: Iterable[Hashable] = frozenset(), @@ -145,9 +156,10 @@ def __init__( self.objects = tuple(objects) self.objects_matching_indexes = () - if join not in ["inner", "outer", "override", "exact", "left", "right"]: + if join not in get_args(JoinOptions): raise ValueError(f"invalid value for join: {join}") self.join = join + self.broadcast = broadcast self.copy = copy self.fill_value = fill_value @@ -264,13 +276,19 @@ def find_matching_indexes(self) -> None: self.all_indexes = all_indexes self.all_index_vars = all_index_vars - if self.join == "override": + if self.join == "override" or not self.broadcast: for dim_sizes in all_indexes_dim_sizes.values(): for dim, sizes in dim_sizes.items(): if len(sizes) > 1: + message = ( + "join='override'" + if self.join == "override" + else "broadcast=False" + ) raise ValueError( - "cannot align objects with join='override' with matching indexes " - f"along dimension {dim!r} that don't have the same size" + f"cannot align objects with indexes " + f"along dimension {dim!r} that don't have the same size " + f"({sizes!r}) when {message}" ) def find_matching_unindexed_dims(self) -> None: @@ -478,6 +496,20 @@ def assert_unindexed_dim_sizes_equal(self) -> None: f"because of conflicting dimension sizes: {sizes!r}" + add_err_msg ) + def assert_equal_dimension_names(self) -> None: + # When broadcasting is disabled, only allows objects having the exact same dimensions' names. + if self.broadcast: + return + + unique_dims = set(tuple(o.dims) for o in self.objects) + all_objects_have_same_dims = len(unique_dims) == 1 + if not all_objects_have_same_dims: + raise ValueError( + f"cannot align objects with broadcast=False " + f"because given objects do not share the same dimension names " + f"({[tuple(o.dims) for o in self.objects]!r})." + ) + def override_indexes(self) -> None: objects = list(self.objects) @@ -568,6 +600,7 @@ def align(self) -> None: self.results = (obj.copy(deep=self.copy),) return + self.assert_equal_dimension_names() self.find_matching_indexes() self.find_matching_unindexed_dims() self.assert_no_index_conflict() @@ -595,6 +628,7 @@ def align( /, *, join: JoinOptions = "inner", + broadcast: bool = True, copy: bool = True, indexes=None, exclude: str | Iterable[Hashable] = frozenset(), @@ -609,6 +643,7 @@ def align( /, *, join: JoinOptions = "inner", + broadcast: bool = True, copy: bool = True, indexes=None, exclude: str | Iterable[Hashable] = frozenset(), @@ -624,6 +659,7 @@ def align( /, *, join: JoinOptions = "inner", + broadcast: bool = True, copy: bool = True, indexes=None, exclude: str | Iterable[Hashable] = frozenset(), @@ -640,6 +676,7 @@ def align( /, *, join: JoinOptions = "inner", + broadcast: bool = True, copy: bool = True, indexes=None, exclude: str | Iterable[Hashable] = frozenset(), @@ -657,6 +694,7 @@ def align( /, *, join: JoinOptions = "inner", + broadcast: bool = True, copy: bool = True, indexes=None, exclude: str | Iterable[Hashable] = frozenset(), @@ -668,6 +706,7 @@ def align( def align( *objects: T_Alignable, join: JoinOptions = "inner", + broadcast: bool = True, copy: bool = True, indexes=None, exclude: str | Iterable[Hashable] = frozenset(), @@ -678,6 +717,7 @@ def align( def align( *objects: T_Alignable, join: JoinOptions = "inner", + broadcast: bool = True, copy: bool = True, indexes=None, exclude: str | Iterable[Hashable] = frozenset(), @@ -710,7 +750,9 @@ def align( - "override": if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. - + broadcast : bool, optional + Disallow automatic broadcasting of all objects along dimensions that are present in some but not all objects. + If False, this will raise an error when all objects do *not* have the same dimensions. copy : bool, default: True If ``copy=True``, data in the return values is always copied. If ``copy=False`` and reindexing is unnecessary, or can be performed with @@ -874,6 +916,7 @@ def align( aligner = Aligner( objects, join=join, + broadcast=broadcast, copy=copy, indexes=indexes, exclude_dims=exclude, @@ -886,6 +929,7 @@ def align( def deep_align( objects: Iterable[Any], join: JoinOptions = "inner", + broadcast: bool = True, copy: bool = True, indexes=None, exclude: str | Iterable[Hashable] = frozenset(), @@ -946,6 +990,7 @@ def is_alignable(obj): aligned = align( *targets, join=join, + broadcast=broadcast, copy=copy, indexes=indexes, exclude=exclude, diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py index 452c7115b75..1a30a0b4d2f 100644 --- a/xarray/core/arithmetic.py +++ b/xarray/core/arithmetic.py @@ -78,6 +78,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): "`.values`)." ) + broadcast = OPTIONS["arithmetic_broadcast"] join = dataset_join = OPTIONS["arithmetic_join"] return apply_ufunc( @@ -86,6 +87,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): input_core_dims=((),) * ufunc.nin, output_core_dims=((),) * ufunc.nout, join=join, + broadcast=broadcast, dataset_join=dataset_join, dataset_fill_value=np.nan, kwargs=kwargs, diff --git a/xarray/core/computation.py b/xarray/core/computation.py index f29f6c4dd35..d3406264b75 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -282,6 +282,7 @@ def apply_dataarray_vfunc( *args, signature: _UFuncSignature, join: JoinOptions = "inner", + broadcast: bool = True, exclude_dims=frozenset(), keep_attrs="override", ) -> tuple[DataArray, ...] | DataArray: @@ -295,6 +296,7 @@ def apply_dataarray_vfunc( deep_align( args, join=join, + broadcast=broadcast, copy=False, exclude=exclude_dims, raise_on_invalid=False, @@ -494,6 +496,7 @@ def apply_dataset_vfunc( signature: _UFuncSignature, join="inner", dataset_join="exact", + broadcast: bool = True, fill_value=_NO_FILL_VALUE, exclude_dims=frozenset(), keep_attrs="override", @@ -518,6 +521,7 @@ def apply_dataset_vfunc( deep_align( args, join=join, + broadcast=broadcast, copy=False, exclude=exclude_dims, raise_on_invalid=False, @@ -1906,6 +1910,7 @@ def dot( subscripts = ",".join(subscripts_list) subscripts += "->..." + "".join(dim_map[d] for d in output_core_dims[0]) + broadcast = OPTIONS["arithmetic_broadcast"] join = OPTIONS["arithmetic_join"] # using "inner" emulates `(a * b).sum()` for all joins (except "exact") if join != "exact": @@ -1920,6 +1925,7 @@ def dot( input_core_dims=input_core_dims, output_core_dims=output_core_dims, join=join, + broadcast=broadcast, dask="allowed", ) return result.transpose(*all_dims, missing_dims="ignore") diff --git a/xarray/core/concat.py b/xarray/core/concat.py index d95cbccd36a..677e798d955 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -146,6 +146,7 @@ def concat( - "override": if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ "override"} or callable, default: "override" A callable or a string indicating how to combine attrs of the objects being diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index c00fe1a9e67..ebd5ff737da 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4671,8 +4671,11 @@ def _binary_op( if isinstance(other, (Dataset, GroupBy)): return NotImplemented if isinstance(other, DataArray): + broadcast = OPTIONS["arithmetic_broadcast"] align_type = OPTIONS["arithmetic_join"] - self, other = align(self, other, join=align_type, copy=False) + self, other = align( + self, other, join=align_type, broadcast=broadcast, copy=False + ) other_variable_or_arraylike: DaCompatible = getattr(other, "variable", other) other_coords = getattr(other, "coords", None) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 884e302b8be..47e3244c0db 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7584,8 +7584,11 @@ def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset: if isinstance(other, GroupBy): return NotImplemented align_type = OPTIONS["arithmetic_join"] if join is None else join + broadcast = OPTIONS["arithmetic_broadcast"] if isinstance(other, (DataArray, Dataset)): - self, other = align(self, other, join=align_type, copy=False) + self, other = align( + self, other, join=align_type, broadcast=broadcast, copy=False + ) g = f if not reflexive else lambda x, y: f(y, x) ds = self._calculate_binary_op(g, other, join=align_type) keep_attrs = _get_keep_attrs(default=False) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index a689620e524..9989deb23c8 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -541,6 +541,7 @@ def merge_coords( objects: Iterable[CoercibleMapping], compat: CompatOptions = "minimal", join: JoinOptions = "outer", + broadcast: bool = True, priority_arg: int | None = None, indexes: Mapping[Any, Index] | None = None, fill_value: object = dtypes.NA, @@ -554,7 +555,12 @@ def merge_coords( _assert_compat_valid(compat) coerced = coerce_pandas_values(objects) aligned = deep_align( - coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value + coerced, + join=join, + broadcast=broadcast, + copy=False, + indexes=indexes, + fill_value=fill_value, ) collected = collect_variables_and_indexes(aligned, indexes=indexes) prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) @@ -647,6 +653,7 @@ def merge_core( objects: Iterable[CoercibleMapping], compat: CompatOptions = "broadcast_equals", join: JoinOptions = "outer", + broadcast: bool = True, combine_attrs: CombineAttrsOptions = "override", priority_arg: int | None = None, explicit_coords: Iterable[Hashable] | None = None, @@ -709,7 +716,12 @@ def merge_core( coerced = coerce_pandas_values(objects) aligned = deep_align( - coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value + coerced, + join=join, + broadcast=broadcast, + copy=False, + indexes=indexes, + fill_value=fill_value, ) for pos, obj in skip_align_objs: diff --git a/xarray/core/options.py b/xarray/core/options.py index 18e3484e9c4..0c049ebcb28 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Literal, TypedDict +from typing import TYPE_CHECKING, Any, Literal, TypedDict from xarray.core.utils import FrozenDict @@ -9,6 +9,7 @@ from matplotlib.colors import Colormap Options = Literal[ + "arithmetic_broadcast", "arithmetic_join", "cmap_divergent", "cmap_sequential", @@ -34,6 +35,7 @@ ] class T_Options(TypedDict): + arithmetic_broadcast: bool arithmetic_join: Literal["inner", "outer", "left", "right", "exact"] cmap_divergent: str | Colormap cmap_sequential: str | Colormap @@ -59,6 +61,7 @@ class T_Options(TypedDict): OPTIONS: T_Options = { + "arithmetic_broadcast": True, "arithmetic_join": "inner", "cmap_divergent": "RdBu_r", "cmap_sequential": "viridis", @@ -91,26 +94,35 @@ def _positive_integer(value: int) -> bool: return isinstance(value, int) and value > 0 +def _is_boolean(value: Any) -> bool: + return isinstance(value, bool) + + +def _is_boolean_or_default(value: Any) -> bool: + return value in (True, False, "default") + + _VALIDATORS = { + "arithmetic_broadcast": _is_boolean, "arithmetic_join": _JOIN_OPTIONS.__contains__, "display_max_rows": _positive_integer, "display_values_threshold": _positive_integer, "display_style": _DISPLAY_OPTIONS.__contains__, "display_width": _positive_integer, - "display_expand_attrs": lambda choice: choice in [True, False, "default"], - "display_expand_coords": lambda choice: choice in [True, False, "default"], - "display_expand_data_vars": lambda choice: choice in [True, False, "default"], - "display_expand_data": lambda choice: choice in [True, False, "default"], - "display_expand_indexes": lambda choice: choice in [True, False, "default"], - "display_default_indexes": lambda choice: choice in [True, False, "default"], - "enable_cftimeindex": lambda value: isinstance(value, bool), + "display_expand_attrs": _is_boolean_or_default, + "display_expand_coords": _is_boolean_or_default, + "display_expand_data_vars": _is_boolean_or_default, + "display_expand_data": _is_boolean_or_default, + "display_expand_indexes": _is_boolean_or_default, + "display_default_indexes": _is_boolean_or_default, + "enable_cftimeindex": _is_boolean, "file_cache_maxsize": _positive_integer, - "keep_attrs": lambda choice: choice in [True, False, "default"], - "use_bottleneck": lambda value: isinstance(value, bool), - "use_numbagg": lambda value: isinstance(value, bool), - "use_opt_einsum": lambda value: isinstance(value, bool), - "use_flox": lambda value: isinstance(value, bool), - "warn_for_unclosed_files": lambda value: isinstance(value, bool), + "keep_attrs": _is_boolean_or_default, + "use_bottleneck": _is_boolean, + "use_numbagg": _is_boolean, + "use_opt_einsum": _is_boolean, + "use_flox": _is_boolean, + "warn_for_unclosed_files": _is_boolean, } @@ -157,6 +169,8 @@ class set_options: Parameters ---------- + arithmetic_broadcast : bool, default: True + Whether to allow or disallow broadcasting arithmetic_join : {"inner", "outer", "left", "right", "exact"}, default: "inner" DataArray/Dataset alignment in binary operations: diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index df0899509cb..2e6e638f5b1 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -89,6 +89,13 @@ def _importorskip( has_pynio, requires_pynio = _importorskip("Nio") has_cftime, requires_cftime = _importorskip("cftime") has_dask, requires_dask = _importorskip("dask") +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="The current Dask DataFrame implementation is deprecated.", + category=DeprecationWarning, + ) + has_dask_expr, requires_dask_expr = _importorskip("dask_expr") has_bottleneck, requires_bottleneck = _importorskip("bottleneck") has_rasterio, requires_rasterio = _importorskip("rasterio") has_zarr, requires_zarr = _importorskip("zarr") diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 0cf4cc03a09..5e63d9d0e7f 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from copy import deepcopy from typing import TYPE_CHECKING, Any, Callable @@ -1261,3 +1262,90 @@ def test_concat_index_not_same_dim() -> None: match=r"Cannot concatenate along dimension 'x' indexes with dimensions.*", ): concat([ds1, ds2], dim="x") + + +def test_concat_join_coordinate_variables_non_asked_dims(): + ds1 = Dataset( + coords={ + "x_center": ("x_center", [1, 2, 3]), + "x_outer": ("x_outer", [0.5, 1.5, 2.5, 3.5]), + }, + ) + + ds2 = Dataset( + coords={ + "x_center": ("x_center", [4, 5, 6]), + "x_outer": ("x_outer", [4.5, 5.5, 6.5]), + }, + ) + + expected_wrongly_concatenated_xds = Dataset( + coords={ + "x_center": ("x_center", [1, 2, 3, 4, 5, 6]), + "x_outer": ("x_outer", [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5]), + }, + ) + + # Using join='outer' + # default's broadcast=True will allow the concatenation to surprisingly happen + # even if `x_outer` sizes do not match + actual_xds = concat( + [ds1, ds2], + join="outer", + dim="x_center", + data_vars="different", + coords="different", + ) + assert_identical(actual_xds, expected_wrongly_concatenated_xds) + + # Using join='exact' + with pytest.raises( + ValueError, + match=re.escape( + "cannot align objects with join='exact' where " + "index/labels/sizes are not equal along these coordinates (dimensions): " + "'x_outer' ('x_outer',)" + ), + ): + concat( + [ds1, ds2], + join="exact", + dim="x_center", + data_vars="different", + coords="different", + ) + + +@pytest.mark.parametrize("join", ("outer", "exact")) +def test_concat_join_non_coordinate_variables(join: JoinOptions): + ds1 = Dataset( + data_vars={ + "a": ("x_center", [1, 2, 3]), + "b": ("x_outer", [0.5, 1.5, 2.5, 3.5]), + }, + ) + + ds2 = Dataset( + data_vars={ + "a": ("x_center", [4, 5, 6]), + "b": ("x_outer", [4.5, 5.5, 6.5]), + }, + ) + + # Whether join='outer' or join='exact' modes are used, + # the concatenation fails because of the behavior disallowing alignment + # of non-indexed dimensions (not attached to a coordinate variable). + with pytest.raises( + ValueError, + match=( + r"cannot reindex or align along dimension 'x_outer' " + r"because of conflicting dimension sizes: {3, 4}" + ), + ): + concat( + [ds1, ds2], + join=join, + dim="x_center", + data_vars="different", + coords="different", + ) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 2829fd7d49c..0edb97430ae 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -51,6 +51,7 @@ requires_bottleneck, requires_cupy, requires_dask, + requires_dask_expr, requires_iris, requires_numexpr, requires_pint, @@ -3203,6 +3204,106 @@ def test_align_str_dtype(self) -> None: assert_identical(expected_b, actual_b) assert expected_b.x.dtype == actual_b.x.dtype + @pytest.mark.parametrize("broadcast", [True, False]) + def test_broadcast_on_vs_off_same_dim_same_size(self, broadcast: bool) -> None: + xda = xr.DataArray([1], dims="x") + + aligned_1, aligned_2 = xr.align(xda, xda, join="exact", broadcast=broadcast) + assert_identical(aligned_1, xda) + assert_identical(aligned_2, xda) + + @pytest.mark.parametrize("broadcast", [True, False]) + def test_broadcast_on_vs_off_same_dim_differing_sizes(self: bool) -> None: + xda_1 = xr.DataArray([1], dims="x") + xda_2 = xr.DataArray([1, 2], dims="x") + + with pytest.raises( + ValueError, + match=re.escape( + "cannot reindex or align along dimension 'x' because of " + "conflicting dimension sizes: {1, 2}" + ), + ): + xr.align(xda_1, xda_2, join="exact", broadcast=broadcast) + + def test_broadcast_on_vs_off_differing_dims_same_sizes(self) -> None: + xda_1 = xr.DataArray([1], dims="x1") + xda_2 = xr.DataArray([1], dims="x2") + + aligned_1, aligned_2 = xr.align(xda_1, xda_2, join="exact", broadcast=True) + assert_identical(aligned_1, xda_1) + assert_identical(aligned_2, xda_2) + + with pytest.raises( + ValueError, + match=re.escape( + "cannot align objects with broadcast=False " + "because given objects do not share the same dimension names " + "([('x1',), ('x2',)])" + ), + ): + xr.align(xda_1, xda_2, join="exact", broadcast=False) + + def test_broadcast_on_vs_off_global_option(self) -> None: + xda_1 = xr.DataArray([1], dims="x1") + xda_2 = xr.DataArray([1], dims="x2") + + with xr.set_options(arithmetic_broadcast=True): + expected_xda = xr.DataArray([[1.0]], dims=("x1", "x2")) + actual_xda = xda_1 / xda_2 + assert_identical(expected_xda, actual_xda) + + with xr.set_options(arithmetic_broadcast=False): + with pytest.raises( + ValueError, + match=re.escape( + "cannot align objects with broadcast=False " + "because given objects do not share the same dimension names " + "([('x1',), ('x2',)])" + ), + ): + xda_1 / xda_2 + + def test_broadcast_on_vs_off_differing_dims_differing_sizes(self) -> None: + xda_1 = xr.DataArray([1], dims="x1") + xda_2 = xr.DataArray([1, 2], dims="x2") + + aligned_1, aligned_2 = xr.align(xda_1, xda_2, join="exact", broadcast=True) + assert_identical(aligned_1, xda_1) + assert_identical(aligned_2, xda_2) + + with pytest.raises( + ValueError, + match=re.escape( + "cannot align objects with broadcast=False " + "because given objects do not share the same dimension names " + "([('x1',), ('x2',)])" + ), + ): + xr.align(xda_1, xda_2, join="exact", broadcast=False) + + def test_broadcast_on_vs_off_2d(self) -> None: + xda_1 = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=("y1", "x1")) + xda_2 = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=("y2", "x2")) + xda_3 = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=("y3", "x3")) + + aligned_1, aligned_2, aligned_3 = xr.align( + xda_1, xda_2, xda_3, join="exact", broadcast=True + ) + assert_identical(aligned_1, xda_1) + assert_identical(aligned_2, xda_2) + assert_identical(aligned_3, xda_3) + + with pytest.raises( + ValueError, + match=re.escape( + "cannot align objects with broadcast=False " + "because given objects do not share the same dimension names " + "([('y1', 'x1'), ('y2', 'x2'), ('y3', 'x3')])" + ), + ): + xr.align(xda_1, xda_2, xda_3, join="exact", broadcast=False) + def test_broadcast_arrays(self) -> None: x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x") y = DataArray([1, 2], coords=[("b", [3, 4])], name="y") @@ -3381,6 +3482,7 @@ def test_to_dataframe_0length(self) -> None: assert len(actual) == 0 assert_array_equal(actual.index.names, list("ABC")) + @requires_dask_expr @requires_dask def test_to_dask_dataframe(self) -> None: arr_np = np.arange(3 * 4).reshape(3, 4)