Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for cross product #5365

Merged
merged 120 commits into from
Dec 29, 2021
Merged
Show file tree
Hide file tree
Changes from 101 commits
Commits
Show all changes
120 commits
Select commit Hold shift + click to select a range
1490c16
Add support for cross
Illviljan May 23, 2021
03db734
Update test_computation.py
Illviljan May 23, 2021
c824e36
Update computation.py
Illviljan May 23, 2021
7ce39c7
Update computation.py
Illviljan May 23, 2021
916e661
Update test_computation.py
Illviljan May 23, 2021
654ad60
Update test_computation.py
Illviljan May 23, 2021
a6ac578
Update test_computation.py
Illviljan May 23, 2021
e0c1fac
add more tests
Illviljan May 23, 2021
7aebae7
Update xarray/core/computation.py
Illviljan May 23, 2021
b85e236
Merge branch 'master' into Illviljan-cross
Illviljan May 23, 2021
2b54a42
Merge branch 'Illviljan-cross' of https://github.com/Illviljan/xarray…
Illviljan May 23, 2021
be7b2c2
spatial_dim to dim
Illviljan May 23, 2021
4448006
Update computation.py
Illviljan May 23, 2021
af8b09c
use pad instead of concat
Illviljan May 23, 2021
a135e05
copy paste np.cross intro
Illviljan May 23, 2021
6f17b9b
Get last dim for each array, which is more inline with np.cross
Illviljan May 23, 2021
1fadb5f
examples in docs
Illviljan May 23, 2021
57239a4
Update computation.py
Illviljan May 23, 2021
265ef82
more doc examples
Illviljan May 23, 2021
dd60562
single dim required, tranpose after apply_ufunc
Illviljan May 24, 2021
a20cb86
add dims to tests
Illviljan May 24, 2021
7ce9315
Update computation.py
Illviljan May 24, 2021
d5a0ea8
reduce code
Illviljan May 25, 2021
ef94fa4
support xr.Variable
Illviljan May 25, 2021
1a85147
Update computation.py
Illviljan May 25, 2021
2ce3dbe
Update computation.py
Illviljan May 25, 2021
53c84c2
reduce code
Illviljan May 25, 2021
dded720
docstring explanations
Illviljan May 25, 2021
7058166
Use same terms
Illviljan May 25, 2021
cb57a55
docstring formatting
Illviljan May 25, 2021
e69ca81
reduce code
Illviljan May 25, 2021
4b2fc72
add tests for dask
Illviljan May 25, 2021
afe572d
simplify check, align used variables
Illviljan May 26, 2021
e137350
trim down tests
Illviljan May 26, 2021
1a26324
Update computation.py
Illviljan May 26, 2021
531a98b
simplify code
Illviljan May 27, 2021
2146406
Add type hints
Illviljan May 28, 2021
0940472
less type hints
Illviljan May 28, 2021
a7cc565
Update computation.py
Illviljan May 28, 2021
1d1f205
undo type hints
Illviljan May 28, 2021
9af7091
Update computation.py
Illviljan May 28, 2021
14decb3
Add support for datasets
Illviljan May 30, 2021
6f73c32
determine dtype with np.result_type
Illviljan Jun 2, 2021
72330ce
test datasets, daskify the inputs not the results
Illviljan Jun 6, 2021
bce2f3e
rechunk padded values, handle 1 sized datasets
Illviljan Jun 6, 2021
1636d25
expand only unique dims, squeeze out dims in tests
Illviljan Jun 6, 2021
b5b97a0
rechunk along the dim
Illviljan Jun 6, 2021
f77780f
Merge branch 'master' into Illviljan-cross
Illviljan Jun 7, 2021
02364ca
Attempt typing again
Illviljan Jun 17, 2021
e842c75
Merge branch 'master' into Illviljan-cross
Illviljan Jun 17, 2021
ed44400
Update __init__.py
Illviljan Jun 17, 2021
4fe9737
Update computation.py
Illviljan Jun 17, 2021
ec05780
Update computation.py
Illviljan Jun 17, 2021
36c5956
test fixing type in to_stacked_array
Illviljan Jun 17, 2021
cbf289c
test fixing to_stacked_array
Illviljan Jun 17, 2021
4cfd5be
small is large
Illviljan Jun 18, 2021
658a59f
Update computation.py
Illviljan Jun 18, 2021
ab5ae20
Update xarray/core/computation.py
Illviljan Jun 18, 2021
d65ca41
obfuscate variable_dim some
Illviljan Jun 19, 2021
20eef03
Update computation.py
Illviljan Jun 19, 2021
274af32
undo to_stacked_array changes
Illviljan Jun 19, 2021
f352303
test sample_dims typing
Illviljan Jun 19, 2021
0a773cb
to_stacked_array fixes
Illviljan Jun 19, 2021
d8da29f
add reindex_like check
Illviljan Jun 19, 2021
54a76c1
Update computation.py
Illviljan Jun 20, 2021
0a2dc2e
Update computation.py
Illviljan Jun 20, 2021
b3592f3
Update computation.py
Illviljan Jun 20, 2021
06772da
test forcing int type in chunk()
Illviljan Jun 20, 2021
cfd11f7
Update computation.py
Illviljan Jun 20, 2021
8451a9e
Merge branch 'master' into Illviljan-cross
Illviljan Jun 21, 2021
90553ed
test collection in to_stacked_array
Illviljan Jun 21, 2021
6eed96e
Update computation.py
Illviljan Jun 21, 2021
d3648e5
Update computation.py
Illviljan Jun 22, 2021
c639aa3
Update computation.py
Illviljan Jun 22, 2021
4c636f5
Update computation.py
Illviljan Jun 22, 2021
3bea936
Update computation.py
Illviljan Jun 22, 2021
4fc7fcb
Merge branch 'master' into Illviljan-cross
Illviljan Jun 23, 2021
19e8f93
Merge branch 'master' into Illviljan-cross
Illviljan Jun 24, 2021
f71a6f1
Merge branch 'main' into Illviljan-cross
Illviljan Jun 24, 2021
d4070ab
Merge branch 'main' into Illviljan-cross
Illviljan Jun 24, 2021
12da913
whats new and api.rst
Illviljan Jun 24, 2021
ea062e6
Update whats-new.rst
Illviljan Jun 24, 2021
ebd89e6
Merge branch 'main' into Illviljan-cross
Illviljan Jul 2, 2021
3c7122b
Merge branch 'main' into Illviljan-cross
Illviljan Jul 5, 2021
9af1198
Merge branch 'main' into Illviljan-cross
Illviljan Jul 18, 2021
27262e6
Merge branch 'main' into Illviljan-cross
Illviljan Jul 22, 2021
cc91e7c
Merge branch 'main' into Illviljan-cross
Illviljan Jul 25, 2021
629df59
Output as dataset if any input is a dataset
Illviljan Jul 26, 2021
972c7dc
Simplify the if terms instead of using pass.
Illviljan Jul 26, 2021
3c4ace0
Merge branch 'main' into Illviljan-cross
Illviljan Aug 30, 2021
49967d4
Update computation.py
Illviljan Aug 30, 2021
6ab7d19
Remove support for datasets
Illviljan Aug 30, 2021
20a6cb6
Update computation.py
Illviljan Aug 30, 2021
ba3fa9c
Add some typing to test.
Illviljan Aug 30, 2021
8b192f2
doctest fix
Illviljan Aug 30, 2021
a27965c
lint
Illviljan Aug 30, 2021
5ec65d2
Merge branch 'main' into Illviljan-cross
Illviljan Sep 8, 2021
b058084
Update xarray/core/computation.py
Illviljan Oct 3, 2021
f007ed5
Update xarray/core/computation.py
Illviljan Oct 5, 2021
e88ae9d
Update xarray/core/computation.py
Illviljan Oct 5, 2021
9aaee2b
Update computation.py
Illviljan Oct 5, 2021
5d6ecba
Update computation.py
Illviljan Oct 5, 2021
71fc9c1
Update computation.py
Illviljan Oct 5, 2021
a98b2e3
Update computation.py
Illviljan Oct 5, 2021
c95817b
Update computation.py
Illviljan Oct 6, 2021
408eb39
Can't narrow types with old type
Illviljan Oct 7, 2021
316b935
dim now keyword only
Illviljan Oct 7, 2021
3b5b030
use all_dims in transpose
Illviljan Oct 7, 2021
f9c5404
Merge branch 'main' into Illviljan-cross
Illviljan Oct 7, 2021
34b300d
if in transpose indeed needed
Illviljan Oct 7, 2021
cf13bf9
Update xarray/core/computation.py
Illviljan Oct 10, 2021
f2167a6
Update xarray/core/computation.py
Illviljan Oct 10, 2021
570a806
Update xarray/core/computation.py
Illviljan Oct 10, 2021
6f57ed6
Update computation.py
Illviljan Oct 10, 2021
52a986b
Update computation.py
Illviljan Oct 10, 2021
fa78e74
add todo comments
Illviljan Oct 10, 2021
f2d98b6
Merge branch 'main' into Illviljan-cross
Illviljan Oct 31, 2021
7449cd7
Merge branch 'main' into Illviljan-cross
Illviljan Dec 27, 2021
70d2a4b
Update whats-new.rst
Illviljan Dec 27, 2021
e6020e3
Merge branch 'main' into Illviljan-cross
Illviljan Dec 27, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Top-level functions
ones_like
cov
corr
cross
dot
polyval
map_blocks
Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ v0.19.1 (unreleased)

New Features
~~~~~~~~~~~~
- New top-level function :py:func:`cross`. (:issue:`3279`, :pull:`5365`).
By `Jimmy Westling <https://github.com/illviljan>`_.
- Added a :py:func:`get_options` method to xarray's root namespace (:issue:`5698`, :pull:`5716`)
By `Pushkar Kopparla <https://github.com/pkopparla>`_.
- Xarray now does a better job rendering variable names that are long LaTeX sequences when plotting (:issue:`5681`, :pull:`5682`).
Expand Down
12 changes: 11 additions & 1 deletion xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,16 @@
from .core.alignment import align, broadcast
from .core.combine import combine_by_coords, combine_nested
from .core.common import ALL_DIMS, full_like, ones_like, zeros_like
from .core.computation import apply_ufunc, corr, cov, dot, polyval, unify_chunks, where
from .core.computation import (
apply_ufunc,
corr,
cov,
cross,
dot,
polyval,
unify_chunks,
where,
)
from .core.concat import concat
from .core.dataarray import DataArray
from .core.dataset import Dataset
Expand Down Expand Up @@ -56,6 +65,7 @@
"dot",
"cov",
"corr",
"cross",
"full_like",
"get_options",
"infer_freq",
Expand Down
200 changes: 199 additions & 1 deletion xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
if TYPE_CHECKING:
from .coordinates import Coordinates
from .dataset import Dataset
from .types import T_Xarray
from .types import DaCompatible, T_Xarray

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
Expand Down Expand Up @@ -1387,6 +1387,204 @@ def _get_valid_values(da, other):
return corr


def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible:
"""
Return the cross product of two (arrays of) vectors.
Illviljan marked this conversation as resolved.
Show resolved Hide resolved

The cross product of `a` and `b` in :math:`R^3` is a vector
perpendicular to both `a` and `b`. If `a` and `b` are arrays of
vectors, and these axes can have dimensions 2 or 3. Where the
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
dimension of either `a` or `b` is 2, the third component of the
input vector is assumed to be zero and the cross product calculated
accordingly. In cases where both input vectors have dimension 2,
the z-component of the cross product is returned.

Parameters
----------
a, b : DataArray or Variable
Components of the first and second vector(s).
dim : hashable
The dimension along which the cross product will be computed.
Must be available in both vectors.

Examples
--------
Vector cross-product with 3 dimensions.

>>> a = xr.DataArray([1, 2, 3])
>>> b = xr.DataArray([4, 5, 6])
>>> xr.cross(a, b, "dim_0")
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
<xarray.DataArray (dim_0: 3)>
array([-3, 6, -3])
Dimensions without coordinates: dim_0

Vector cross-product with 2 dimensions, returns in the perpendicular
direction:

>>> a = xr.DataArray([1, 2])
>>> b = xr.DataArray([4, 5])
>>> xr.cross(a, b, "dim_0")
<xarray.DataArray ()>
array(-3)

Vector cross-product with 3 dimensions but zeros at the last axis
yields the same results as with 2 dimensions:

>>> a = xr.DataArray([1, 2, 0])
>>> b = xr.DataArray([4, 5, 0])
>>> xr.cross(a, b, "dim_0")
<xarray.DataArray (dim_0: 3)>
array([ 0, 0, -3])
Dimensions without coordinates: dim_0

One vector with dimension 2.

>>> a = xr.DataArray(
... [1, 2],
... dims=["cartesian"],
... coords=dict(cartesian=(["cartesian"], ["x", "y"])),
... )
>>> b = xr.DataArray(
... [4, 5, 6],
... dims=["cartesian"],
... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])),
... )
>>> xr.cross(a, b, "cartesian")
<xarray.DataArray (cartesian: 3)>
array([12, -6, -3])
Coordinates:
* cartesian (cartesian) object 'x' 'y' 'z'

One vector with dimension 2 but coords in other positions.

>>> a = xr.DataArray(
... [1, 2],
... dims=["cartesian"],
... coords=dict(cartesian=(["cartesian"], ["x", "z"])),
... )
>>> b = xr.DataArray(
... [4, 5, 6],
... dims=["cartesian"],
... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])),
... )
>>> xr.cross(a, b, "cartesian")
<xarray.DataArray (cartesian: 3)>
array([-10, 2, 5])
Coordinates:
* cartesian (cartesian) object 'x' 'y' 'z'

Multiple vector cross-products. Note that the direction of the
cross product vector is defined by the right-hand rule.

>>> a = xr.DataArray(
... [[1, 2, 3], [4, 5, 6]],
... dims=("time", "cartesian"),
... coords=dict(
... time=(["time"], [0, 1]),
... cartesian=(["cartesian"], ["x", "y", "z"]),
... ),
... )
>>> b = xr.DataArray(
... [[4, 5, 6], [1, 2, 3]],
... dims=("time", "cartesian"),
... coords=dict(
... time=(["time"], [0, 1]),
... cartesian=(["cartesian"], ["x", "y", "z"]),
... ),
... )
>>> xr.cross(a, b, "cartesian")
<xarray.DataArray (time: 2, cartesian: 3)>
array([[-3, 6, -3],
[ 3, -6, 3]])
Coordinates:
* time (time) int64 0 1
* cartesian (cartesian) <U1 'x' 'y' 'z'

Cross can used by on Datasets by converting to DataArrays and then
back to Datasets:
Illviljan marked this conversation as resolved.
Show resolved Hide resolved

>>> ds_a = xr.Dataset(dict(x=("dim_0", [1]), y=("dim_0", [2]), z=("dim_0", [3])))
>>> ds_b = xr.Dataset(dict(x=("dim_0", [4]), y=("dim_0", [5]), z=("dim_0", [6])))
>>> c = xr.cross(
... ds_a.to_array("cartesian"), ds_b.to_array("cartesian"), dim="cartesian"
... )
>>> c.to_dataset(dim="cartesian")
<xarray.Dataset>
Dimensions: (dim_0: 1)
Dimensions without coordinates: dim_0
Data variables:
x (dim_0) int64 -3
y (dim_0) int64 6
z (dim_0) int64 -3

See Also
--------
numpy.cross : Corresponding numpy function
"""

if dim not in a.dims:
raise ValueError(f"Dimension {dim!r} not on a")
elif dim not in b.dims:
raise ValueError(f"Dimension {dim!r} not on b")

if not 1 <= a.sizes[dim] <= 3:
raise ValueError(
f"The size of {dim!r} on a must be 1, 2, or 3 to be "
f"compatible with a cross product but is {a.sizes[dim]}"
)
elif not 1 <= b.sizes[dim] <= 3:
raise ValueError(
f"The size of {dim!r} on b must be 1, 2, or 3 to be "
f"compatible with a cross product but is {b.sizes[dim]}"
)

all_dims = list(dict.fromkeys(a.dims + b.dims))

if a.sizes[dim] != b.sizes[dim]:
# Arrays have different sizes. Append zeros where the smaller
# array is missing a value, zeros will not affect np.cross:

if dim in getattr(a, "coords", {}) and dim in getattr(b, "coords", {}):
# If the arrays have coords we know which indexes to fill
# with zeros:
a, b = align(
a,
b,
fill_value=0,
join="outer",
exclude=set(all_dims) - {dim},
)
elif min(a.sizes[dim], b.sizes[dim]) == 2:
# If the array doesn't have coords we can only infer
# that it has composite values if the size is at least 2.
# Once padded, rechunk the padded array because apply_ufunc
# requires core dimensions not to be chunked:
if a.sizes[dim] < b.sizes[dim]:
a = a.pad({dim: (0, 1)}, constant_values=0)
a = a.chunk({dim: -1}) if is_duck_dask_array(a.data) else a
else:
b = b.pad({dim: (0, 1)}, constant_values=0)
b = b.chunk({dim: -1}) if is_duck_dask_array(b.data) else b
else:
raise ValueError(
f"{dim!r} on {'a' if a.sizes[dim] == 1 else 'b'} is incompatible:"
" dimensions without coordinates must have have a length of 2 or 3"
)

c = apply_ufunc(
np.cross,
a,
b,
input_core_dims=[[dim], [dim]],
output_core_dims=[[dim] if a.sizes[dim] == 3 else []],
dask="parallelized",
output_dtypes=[np.result_type(*a, b)],
)
c = c.transpose(*[d for d in all_dims if d in c.dims])
Illviljan marked this conversation as resolved.
Show resolved Hide resolved

return c


def dot(*arrays, dims=None, **kwargs):
"""Generalized dot product for xarray objects. Like np.einsum, but
provides a simpler interface based on array dimensions.
Expand Down
107 changes: 107 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1930,3 +1930,110 @@ def test_polyval(use_dask, use_datetime) -> None:
da_pv = xr.polyval(da.x, coeffs)

xr.testing.assert_allclose(da, da_pv.T)


@pytest.mark.parametrize("use_dask", [False, True])
@pytest.mark.parametrize(
"a, b, ae, be, dim, axis",
[
[
xr.DataArray([1, 2, 3]),
xr.DataArray([4, 5, 6]),
[1, 2, 3],
[4, 5, 6],
"dim_0",
-1,
],
[
xr.DataArray([1, 2]),
xr.DataArray([4, 5, 6]),
[1, 2],
[4, 5, 6],
"dim_0",
-1,
],
[
xr.Variable(dims=["dim_0"], data=[1, 2, 3]),
xr.Variable(dims=["dim_0"], data=[4, 5, 6]),
[1, 2, 3],
[4, 5, 6],
"dim_0",
-1,
],
[
xr.Variable(dims=["dim_0"], data=[1, 2]),
xr.Variable(dims=["dim_0"], data=[4, 5, 6]),
[1, 2],
[4, 5, 6],
"dim_0",
-1,
],
[ # Test dim in the middle:
xr.DataArray(
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)),
dims=["time", "cartesian", "var"],
coords=dict(
time=(["time"], np.arange(0, 5)),
cartesian=(["cartesian"], ["x", "y", "z"]),
var=(["var"], [1, 1.5, 2, 2.5]),
),
),
xr.DataArray(
np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1,
dims=["time", "cartesian", "var"],
coords=dict(
time=(["time"], np.arange(0, 5)),
cartesian=(["cartesian"], ["x", "y", "z"]),
var=(["var"], [1, 1.5, 2, 2.5]),
),
),
np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)),
np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1,
"cartesian",
1,
],
[ # Test 1 sized arrays with coords:
xr.DataArray(
np.array([1]),
dims=["cartesian"],
coords=dict(cartesian=(["cartesian"], ["z"])),
),
xr.DataArray(
np.array([4, 5, 6]),
dims=["cartesian"],
coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])),
),
[0, 0, 1],
[4, 5, 6],
"cartesian",
-1,
],
[ # Test filling inbetween with coords:
xr.DataArray(
[1, 2],
dims=["cartesian"],
coords=dict(cartesian=(["cartesian"], ["x", "z"])),
),
xr.DataArray(
[4, 5, 6],
dims=["cartesian"],
coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])),
),
[1, 0, 2],
[4, 5, 6],
"cartesian",
-1,
],
],
)
def test_cross(a, b, ae, be, dim: str, axis: int, use_dask: bool) -> None:
expected = np.cross(ae, be, axis=axis)

if use_dask:
if not has_dask:
pytest.skip("test for dask.")
a = a.chunk()
b = b.chunk()

actual = xr.cross(a, b, dim=dim)
xr.testing.assert_duckarray_allclose(expected, actual)