Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New alignment option: join='strict' #8698

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
bbe7d05
New alignment option: join='strict'
etienneschalk Feb 3, 2024
cddcaa1
Fix what's new newlines + retrigger CI
etienneschalk Feb 4, 2024
37a7b09
wrong join
etienneschalk Feb 4, 2024
a93e44b
Merge branch 'main' into eschalk/issue-8231-align
etienneschalk Feb 7, 2024
ed4873e
Merge branch 'main' into eschalk/issue-8231-align
etienneschalk Feb 9, 2024
c6b1df5
Added test align 2d three arrays
etienneschalk Feb 9, 2024
ed0414a
Added tests and use assert_identical
etienneschalk Feb 10, 2024
a8bc8dc
Merge branch 'main' into eschalk/issue-8231-align
etienneschalk Feb 10, 2024
451c96f
Merge branch 'main' into eschalk/issue-8231-align
etienneschalk Feb 11, 2024
f46b19b
Added tests for join=exact
etienneschalk Feb 14, 2024
95295d1
Try replacing join=strict by broadcast=False
etienneschalk Feb 15, 2024
80f5e1b
Merge branch 'eschalk/issue-8231-align' of github.com:etienneschalk/x…
etienneschalk Feb 15, 2024
d84c688
More broadcasts
etienneschalk Feb 17, 2024
06c2077
Merge branch 'main' into eschalk/issue-8231-align
etienneschalk Feb 17, 2024
a7148d6
CI failed: mypy + warnings
etienneschalk Feb 17, 2024
3a4e81a
Review - revert useless change
etienneschalk Feb 18, 2024
4247a86
Review - missing passed down param, useless newline
etienneschalk Feb 18, 2024
9db7db1
Apply suggestions from code review
etienneschalk Feb 18, 2024
9966d4e
Merge branch 'eschalk/issue-8231-align' of github.com:etienneschalk/x…
etienneschalk Feb 18, 2024
e4d0a84
review - error message
etienneschalk Feb 18, 2024
fe56de8
Review - Use dims
etienneschalk Feb 18, 2024
e4b0cd8
param test
etienneschalk Feb 18, 2024
6a4db7b
Merge branch 'main' into eschalk/issue-8231-align
etienneschalk Feb 18, 2024
26d7311
Remove faulty line whats-new
etienneschalk Feb 18, 2024
df8fd3a
Review - remove broadcast from apply_ufunc
etienneschalk Feb 18, 2024
651289f
Re-add missing line
etienneschalk Feb 18, 2024
5cc61b0
Merge branch 'main' into eschalk/issue-8231-align
etienneschalk Feb 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,6 @@ doc/team-panel.txt
doc/external-examples-gallery.txt
doc/notebooks-examples-gallery.txt
doc/videos-gallery.txt

# MyPy Report
mypy_report/
2 changes: 1 addition & 1 deletion doc/user-guide/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Xarray offers a small number of configuration options through :py:func:`set_opti
- ``display_max_rows``
- ``display_style``

2. Control behaviour during operations: ``arithmetic_join``, ``keep_attrs``, ``use_bottleneck``.
2. Control behaviour during operations: ``arithmetic_broadcast``, ``arithmetic_join``, ``keep_attrs``, ``use_bottleneck``.
3. Control colormaps for plots:``cmap_divergent``, ``cmap_sequential``.
4. Aspects of file reading: ``file_cache_maxsize``, ``warn_on_unclosed_files``.

Expand Down
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ Mathias Hauser, Matt Savoie, Maximilian Roos, Rambaud Pierrick, Tom Nicholas
New Features
~~~~~~~~~~~~

- Added the ability to control broadcasting for alignment, and a new global option ``arithmetic_broadcast``
(:issue:`6806`, :pull:`8698`).
By `Etienne Schalk <https://github.com/etienneschalk>`_.
- Added a simple ``nbytes`` representation in DataArrays and Dataset ``repr``.
(:issue:`8690`, :pull:`8702`).
By `Etienne Schalk <https://github.com/etienneschalk>`_.
Expand Down
61 changes: 53 additions & 8 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
from collections import defaultdict
from collections.abc import Hashable, Iterable, Mapping
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Callable, Final, Generic, TypeVar, cast, overload
from typing import (
TYPE_CHECKING,
Any,
Callable,
Final,
Generic,
TypeVar,
cast,
get_args,
overload,
)

import numpy as np
import pandas as pd
Expand All @@ -19,7 +29,7 @@
indexes_all_equal,
safe_cast_to_index,
)
from xarray.core.types import T_Alignable
from xarray.core.types import JoinOptions, T_Alignable
from xarray.core.utils import is_dict_like, is_full_slice
from xarray.core.variable import Variable, as_compatible_data, calculate_dimensions

Expand All @@ -28,7 +38,6 @@
from xarray.core.dataset import Dataset
from xarray.core.types import (
Alignable,
JoinOptions,
T_DataArray,
T_Dataset,
T_DuckArray,
Expand Down Expand Up @@ -113,6 +122,7 @@ class Aligner(Generic[T_Alignable]):
results: tuple[T_Alignable, ...]
objects_matching_indexes: tuple[dict[MatchingIndexKey, Index], ...]
join: str
broadcast: bool
exclude_dims: frozenset[Hashable]
exclude_vars: frozenset[Hashable]
copy: bool
Expand All @@ -133,6 +143,7 @@ def __init__(
self,
objects: Iterable[T_Alignable],
join: str = "inner",
broadcast: bool = True,
indexes: Mapping[Any, Any] | None = None,
exclude_dims: str | Iterable[Hashable] = frozenset(),
exclude_vars: Iterable[Hashable] = frozenset(),
Expand All @@ -145,9 +156,10 @@ def __init__(
self.objects = tuple(objects)
self.objects_matching_indexes = ()

if join not in ["inner", "outer", "override", "exact", "left", "right"]:
if join not in get_args(JoinOptions):
raise ValueError(f"invalid value for join: {join}")
self.join = join
self.broadcast = broadcast

self.copy = copy
self.fill_value = fill_value
Expand Down Expand Up @@ -264,13 +276,19 @@ def find_matching_indexes(self) -> None:
self.all_indexes = all_indexes
self.all_index_vars = all_index_vars

if self.join == "override":
if self.join == "override" or not self.broadcast:
for dim_sizes in all_indexes_dim_sizes.values():
for dim, sizes in dim_sizes.items():
if len(sizes) > 1:
message = (
"join='override'"
if self.join == "override"
else "broadcast=False"
)
raise ValueError(
"cannot align objects with join='override' with matching indexes "
f"along dimension {dim!r} that don't have the same size"
f"cannot align objects with indexes "
f"along dimension {dim!r} that don't have the same size "
f"({sizes!r}) when {message}"
)

def find_matching_unindexed_dims(self) -> None:
Expand Down Expand Up @@ -478,6 +496,20 @@ def assert_unindexed_dim_sizes_equal(self) -> None:
f"because of conflicting dimension sizes: {sizes!r}" + add_err_msg
)

def assert_equal_dimension_names(self) -> None:
    """Check that all objects share the same dimension names when broadcasting is off.

    Does nothing when ``self.broadcast`` is true. Otherwise, every object in
    ``self.objects`` must expose the same set of names through its ``dims``
    attribute. Names are compared as a set: objects whose dimensions differ
    only by order need no broadcasting and are therefore accepted. Zero or
    one object trivially passes.

    Raises
    ------
    ValueError
        If ``self.broadcast`` is false and at least two objects have
        different dimension names.
    """
    if self.broadcast:
        return

    # Materialized once: used for the comparison and for the error message,
    # where the per-object ordering is kept to help the user locate the mismatch.
    dims_per_object = [tuple(obj.dims) for obj in self.objects]

    # ``> 1`` (rather than ``!= 1``) so that an empty collection of objects
    # is not reported as a dimension mismatch.
    if len({frozenset(dims) for dims in dims_per_object}) > 1:
        raise ValueError(
            "cannot align objects with broadcast=False "
            "because given objects do not share the same dimension names "
            f"({dims_per_object!r})."
        )

def override_indexes(self) -> None:
objects = list(self.objects)

Expand Down Expand Up @@ -568,6 +600,7 @@ def align(self) -> None:
self.results = (obj.copy(deep=self.copy),)
return

self.assert_equal_dimension_names()
self.find_matching_indexes()
self.find_matching_unindexed_dims()
self.assert_no_index_conflict()
Expand Down Expand Up @@ -595,6 +628,7 @@ def align(
/,
*,
join: JoinOptions = "inner",
broadcast: bool = True,
copy: bool = True,
indexes=None,
exclude: str | Iterable[Hashable] = frozenset(),
Expand All @@ -609,6 +643,7 @@ def align(
/,
*,
join: JoinOptions = "inner",
broadcast: bool = True,
copy: bool = True,
indexes=None,
exclude: str | Iterable[Hashable] = frozenset(),
Expand All @@ -624,6 +659,7 @@ def align(
/,
*,
join: JoinOptions = "inner",
broadcast: bool = True,
copy: bool = True,
indexes=None,
exclude: str | Iterable[Hashable] = frozenset(),
Expand All @@ -640,6 +676,7 @@ def align(
/,
*,
join: JoinOptions = "inner",
broadcast: bool = True,
copy: bool = True,
indexes=None,
exclude: str | Iterable[Hashable] = frozenset(),
Expand All @@ -657,6 +694,7 @@ def align(
/,
*,
join: JoinOptions = "inner",
broadcast: bool = True,
copy: bool = True,
indexes=None,
exclude: str | Iterable[Hashable] = frozenset(),
Expand All @@ -668,6 +706,7 @@ def align(
def align(
*objects: T_Alignable,
join: JoinOptions = "inner",
broadcast: bool = True,
copy: bool = True,
indexes=None,
exclude: str | Iterable[Hashable] = frozenset(),
Expand All @@ -678,6 +717,7 @@ def align(
def align(
*objects: T_Alignable,
join: JoinOptions = "inner",
broadcast: bool = True,
copy: bool = True,
indexes=None,
exclude: str | Iterable[Hashable] = frozenset(),
Expand Down Expand Up @@ -710,7 +750,9 @@ def align(
- "override": if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.

broadcast : bool, default: True
If True, objects may have different dimensions: dimensions that are present in some but not all objects are allowed (automatic broadcasting).
If False, an error is raised when the objects do *not* all have the same dimensions.
copy : bool, default: True
If ``copy=True``, data in the return values is always copied. If
``copy=False`` and reindexing is unnecessary, or can be performed with
Expand Down Expand Up @@ -874,6 +916,7 @@ def align(
aligner = Aligner(
objects,
join=join,
broadcast=broadcast,
copy=copy,
indexes=indexes,
exclude_dims=exclude,
Expand All @@ -886,6 +929,7 @@ def align(
def deep_align(
objects: Iterable[Any],
join: JoinOptions = "inner",
broadcast: bool = True,
copy: bool = True,
indexes=None,
exclude: str | Iterable[Hashable] = frozenset(),
Expand Down Expand Up @@ -946,6 +990,7 @@ def is_alignable(obj):
aligned = align(
*targets,
join=join,
broadcast=broadcast,
copy=copy,
indexes=indexes,
exclude=exclude,
Expand Down
2 changes: 2 additions & 0 deletions xarray/core/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
"`.values`)."
)

broadcast = OPTIONS["arithmetic_broadcast"]
join = dataset_join = OPTIONS["arithmetic_join"]

return apply_ufunc(
Expand All @@ -86,6 +87,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
input_core_dims=((),) * ufunc.nin,
output_core_dims=((),) * ufunc.nout,
join=join,
broadcast=broadcast,
dataset_join=dataset_join,
dataset_fill_value=np.nan,
kwargs=kwargs,
Expand Down
6 changes: 6 additions & 0 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def apply_dataarray_vfunc(
*args,
signature: _UFuncSignature,
join: JoinOptions = "inner",
broadcast: bool = True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for being unclear (again). I don't think we need this in align at all.

We should simply be checking OPTIONS["arithmetic_broadcast"] in Variable._binary_op . The whole business with align was a misunderstanding of mine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will close this PR as it diverged too much from the original wanted behavior

exclude_dims=frozenset(),
keep_attrs="override",
) -> tuple[DataArray, ...] | DataArray:
Expand All @@ -295,6 +296,7 @@ def apply_dataarray_vfunc(
deep_align(
args,
join=join,
broadcast=broadcast,
copy=False,
exclude=exclude_dims,
raise_on_invalid=False,
Expand Down Expand Up @@ -494,6 +496,7 @@ def apply_dataset_vfunc(
signature: _UFuncSignature,
join="inner",
dataset_join="exact",
broadcast: bool = True,
fill_value=_NO_FILL_VALUE,
exclude_dims=frozenset(),
keep_attrs="override",
Expand All @@ -518,6 +521,7 @@ def apply_dataset_vfunc(
deep_align(
args,
join=join,
broadcast=broadcast,
copy=False,
exclude=exclude_dims,
raise_on_invalid=False,
Expand Down Expand Up @@ -1906,6 +1910,7 @@ def dot(
subscripts = ",".join(subscripts_list)
subscripts += "->..." + "".join(dim_map[d] for d in output_core_dims[0])

broadcast = OPTIONS["arithmetic_broadcast"]
join = OPTIONS["arithmetic_join"]
# using "inner" emulates `(a * b).sum()` for all joins (except "exact")
if join != "exact":
Expand All @@ -1920,6 +1925,7 @@ def dot(
input_core_dims=input_core_dims,
output_core_dims=output_core_dims,
join=join,
broadcast=broadcast,
dask="allowed",
)
return result.transpose(*all_dims, missing_dims="ignore")
Expand Down
1 change: 1 addition & 0 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def concat(
- "override": if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.

combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
"override"} or callable, default: "override"
A callable or a string indicating how to combine attrs of the objects being
Expand Down
5 changes: 4 additions & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4671,8 +4671,11 @@ def _binary_op(
if isinstance(other, (Dataset, GroupBy)):
return NotImplemented
if isinstance(other, DataArray):
broadcast = OPTIONS["arithmetic_broadcast"]
align_type = OPTIONS["arithmetic_join"]
self, other = align(self, other, join=align_type, copy=False)
self, other = align(
self, other, join=align_type, broadcast=broadcast, copy=False
)
other_variable_or_arraylike: DaCompatible = getattr(other, "variable", other)
other_coords = getattr(other, "coords", None)

Expand Down
5 changes: 4 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7584,8 +7584,11 @@ def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset:
if isinstance(other, GroupBy):
return NotImplemented
align_type = OPTIONS["arithmetic_join"] if join is None else join
broadcast = OPTIONS["arithmetic_broadcast"]
if isinstance(other, (DataArray, Dataset)):
self, other = align(self, other, join=align_type, copy=False)
self, other = align(
self, other, join=align_type, broadcast=broadcast, copy=False
)
g = f if not reflexive else lambda x, y: f(y, x)
ds = self._calculate_binary_op(g, other, join=align_type)
keep_attrs = _get_keep_attrs(default=False)
Expand Down
16 changes: 14 additions & 2 deletions xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ def merge_coords(
objects: Iterable[CoercibleMapping],
compat: CompatOptions = "minimal",
join: JoinOptions = "outer",
broadcast: bool = True,
priority_arg: int | None = None,
indexes: Mapping[Any, Index] | None = None,
fill_value: object = dtypes.NA,
Expand All @@ -554,7 +555,12 @@ def merge_coords(
_assert_compat_valid(compat)
coerced = coerce_pandas_values(objects)
aligned = deep_align(
coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value
coerced,
join=join,
broadcast=broadcast,
copy=False,
indexes=indexes,
fill_value=fill_value,
)
collected = collect_variables_and_indexes(aligned, indexes=indexes)
prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat)
Expand Down Expand Up @@ -647,6 +653,7 @@ def merge_core(
objects: Iterable[CoercibleMapping],
compat: CompatOptions = "broadcast_equals",
join: JoinOptions = "outer",
broadcast: bool = True,
combine_attrs: CombineAttrsOptions = "override",
priority_arg: int | None = None,
explicit_coords: Iterable[Hashable] | None = None,
Expand Down Expand Up @@ -709,7 +716,12 @@ def merge_core(

coerced = coerce_pandas_values(objects)
aligned = deep_align(
coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value
coerced,
join=join,
broadcast=broadcast,
copy=False,
indexes=indexes,
fill_value=fill_value,
)

for pos, obj in skip_align_objs:
Expand Down
Loading
Loading