diff --git a/.github/ISSUE_TEMPLATE/bug-report.md b/.github/ISSUE_TEMPLATE/bug-report.md
deleted file mode 100644
index 02bc5d0f7b0..00000000000
--- a/.github/ISSUE_TEMPLATE/bug-report.md
+++ /dev/null
@@ -1,39 +0,0 @@
----
-name: Bug report
-about: Create a report to help us improve
-title: ''
-labels: ''
-assignees: ''
-
----
-
-
-
-**What happened**:
-
-**What you expected to happen**:
-
-**Minimal Complete Verifiable Example**:
-
-```python
-# Put your MCVE code here
-```
-
-**Anything else we need to know?**:
-
-**Environment**:
-
-Output of xr.show_versions()
-
-
-
-
-
diff --git a/.github/ISSUE_TEMPLATE/bugreport.yml b/.github/ISSUE_TEMPLATE/bugreport.yml
new file mode 100644
index 00000000000..255c7de07d9
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bugreport.yml
@@ -0,0 +1,61 @@
+name: Bug Report
+description: File a bug report to help us improve
+title: '[Bug]: '
+labels: [bug, 'needs triage']
+assignees: []
+body:
+ - type: textarea
+ id: what-happened
+ attributes:
+ label: What happened?
+ description: |
+ Thanks for reporting a bug! Please describe what you were trying to get done.
+ Tell us what happened, what went wrong.
+ validations:
+ required: true
+
+ - type: textarea
+ id: what-did-you-expect-to-happen
+ attributes:
+ label: What did you expect to happen?
+ description: |
+ Describe what you expected to happen.
+ validations:
+ required: false
+
+ - type: textarea
+ id: sample-code
+ attributes:
+ label: Minimal Complete Verifiable Example
+ description: |
+ Minimal, self-contained copy-pastable example that generates the issue if possible. Please be concise with code posted. See guidelines below on how to provide a good bug report:
+
+ - [Minimal Complete Verifiable Examples](https://stackoverflow.com/help/mcve)
+ - [Craft Minimal Bug Reports](http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports)
+
+ Bug reports that follow these guidelines are easier to diagnose, and so are often handled much more quickly.
+ This will be automatically formatted into code, so no need for markdown backticks.
+ render: python
+
+ - type: textarea
+ id: log-output
+ attributes:
+ label: Relevant log output
+ description: Please copy and paste any relevant output. This will be automatically formatted into code, so no need for markdown backticks.
+ render: python
+
+ - type: textarea
+ id: extra
+ attributes:
+ label: Anything else we need to know?
+ description: |
+ Please describe any other information you want to share.
+
+ - type: textarea
+ id: show-versions
+ attributes:
+ label: Environment
+ description: |
+ Paste the output of `xr.show_versions()` here
+ validations:
+ required: true
diff --git a/.github/ISSUE_TEMPLATE/feature-request.md b/.github/ISSUE_TEMPLATE/feature-request.md
deleted file mode 100644
index 7021fe490aa..00000000000
--- a/.github/ISSUE_TEMPLATE/feature-request.md
+++ /dev/null
@@ -1,22 +0,0 @@
----
-name: Feature request
-about: Suggest an idea for this project
-title: ''
-labels: ''
-assignees: ''
-
----
-
-
-
-**Is your feature request related to a problem? Please describe.**
-A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
-
-**Describe the solution you'd like**
-A clear and concise description of what you want to happen.
-
-**Describe alternatives you've considered**
-A clear and concise description of any alternative solutions or features you've considered.
-
-**Additional context**
-Add any other context about the feature request here.
diff --git a/.github/ISSUE_TEMPLATE/newfeature.yml b/.github/ISSUE_TEMPLATE/newfeature.yml
new file mode 100644
index 00000000000..ec94b0f4b89
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/newfeature.yml
@@ -0,0 +1,37 @@
+name: Feature Request
+description: Suggest an idea for xarray
+title: '[FEATURE]: '
+labels: [enhancement]
+assignees: []
+body:
+ - type: textarea
+ id: description
+ attributes:
+ label: Is your feature request related to a problem?
+ description: |
+ Please do a quick search of existing issues to make sure that this has not been asked before.
+ Please provide a clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
+ validations:
+ required: true
+ - type: textarea
+ id: solution
+ attributes:
+ label: Describe the solution you'd like
+ description: |
+ A clear and concise description of what you want to happen.
+ - type: textarea
+ id: alternatives
+ attributes:
+ label: Describe alternatives you've considered
+ description: |
+ A clear and concise description of any alternative solutions or features you've considered.
+ validations:
+ required: false
+ - type: textarea
+ id: additional-context
+ attributes:
+ label: Additional context
+ description: |
+ Add any other context about the feature request here.
+ validations:
+ required: false
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index c7ea19a53cb..37b8d357c87 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -2,6 +2,5 @@
- [ ] Closes #xxxx
- [ ] Tests added
-- [ ] Passes `pre-commit run --all-files`
- [ ] User visible changes (including notable bug fixes) are documented in `whats-new.rst`
- [ ] New functions/methods are listed in `api.rst`
diff --git a/.github/workflows/ci-pre-commit.yml b/.github/workflows/ci-pre-commit.yml
deleted file mode 100644
index 4bc5bddfdbc..00000000000
--- a/.github/workflows/ci-pre-commit.yml
+++ /dev/null
@@ -1,17 +0,0 @@
-name: linting
-
-on:
- push:
- branches: "*"
- pull_request:
- branches: "*"
-
-jobs:
- linting:
- name: "pre-commit hooks"
- runs-on: ubuntu-latest
- if: github.repository == 'pydata/xarray'
- steps:
- - uses: actions/checkout@v2
- - uses: actions/setup-python@v2
- - uses: pre-commit/action@v2.0.3
diff --git a/asv_bench/benchmarks/import_xarray.py b/asv_bench/benchmarks/import_xarray.py
deleted file mode 100644
index 94652e3b82a..00000000000
--- a/asv_bench/benchmarks/import_xarray.py
+++ /dev/null
@@ -1,9 +0,0 @@
-class ImportXarray:
- def setup(self, *args, **kwargs):
- def import_xr():
- import xarray # noqa: F401
-
- self._import_xr = import_xr
-
- def time_import_xarray(self):
- self._import_xr()
diff --git a/doc/api.rst b/doc/api.rst
index 9433ecfa56d..ef2694ea661 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -32,6 +32,7 @@ Top-level functions
ones_like
cov
corr
+ cross
dot
polyval
map_blocks
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 1c4b49097a3..f991a4e2a89 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -21,6 +21,8 @@ v0.21.0 (unreleased)
New Features
~~~~~~~~~~~~
+- New top-level function :py:func:`cross`. (:issue:`3279`, :pull:`5365`).
+ By `Jimmy Westling `_.
Breaking changes
@@ -29,6 +31,8 @@ Breaking changes
Deprecations
~~~~~~~~~~~~
+- Removed the lock kwarg from the zarr and pydap backends, completing the deprecation cycle started in :issue:`5256`.
+ By `Tom Nicholas `_.
Bug fixes
diff --git a/xarray/__init__.py b/xarray/__init__.py
index 10f16e58081..81ab9f388a8 100644
--- a/xarray/__init__.py
+++ b/xarray/__init__.py
@@ -16,7 +16,16 @@
from .core.alignment import align, broadcast
from .core.combine import combine_by_coords, combine_nested
from .core.common import ALL_DIMS, full_like, ones_like, zeros_like
-from .core.computation import apply_ufunc, corr, cov, dot, polyval, unify_chunks, where
+from .core.computation import (
+ apply_ufunc,
+ corr,
+ cov,
+ cross,
+ dot,
+ polyval,
+ unify_chunks,
+ where,
+)
from .core.concat import concat
from .core.dataarray import DataArray
from .core.dataset import Dataset
@@ -60,6 +69,7 @@
"dot",
"cov",
"corr",
+ "cross",
"full_like",
"get_options",
"infer_freq",
diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py
index bc479f9a71d..ffaf3793928 100644
--- a/xarray/backends/pydap_.py
+++ b/xarray/backends/pydap_.py
@@ -1,5 +1,3 @@
-import warnings
-
import numpy as np
from ..core import indexing
@@ -126,15 +124,7 @@ def open_dataset(
use_cftime=None,
decode_timedelta=None,
session=None,
- lock=None,
):
- # TODO remove after v0.19
- if lock is not None:
- warnings.warn(
- "The kwarg 'lock' has been deprecated for this backend, and is now "
- "ignored. In the future passing lock will raise an error.",
- DeprecationWarning,
- )
store = PydapDataStore.open(
filename_or_obj,
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 3eb6a3caf72..8bd343869ff 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -810,15 +810,7 @@ def open_dataset(
chunk_store=None,
storage_options=None,
stacklevel=3,
- lock=None,
):
- # TODO remove after v0.19
- if lock is not None:
- warnings.warn(
- "The kwarg 'lock' has been deprecated for this backend, and is now "
- "ignored. In the future passing lock will raise an error.",
- DeprecationWarning,
- )
filename_or_obj = _normalize_path(filename_or_obj)
store = ZarrStore.open_group(
diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index 191b777107a..9fe93c88734 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -36,6 +36,7 @@
if TYPE_CHECKING:
from .coordinates import Coordinates
+ from .dataarray import DataArray
from .dataset import Dataset
from .types import T_Xarray
@@ -1373,6 +1374,214 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None):
return corr
+def cross(
+ a: Union[DataArray, Variable], b: Union[DataArray, Variable], *, dim: Hashable
+) -> Union[DataArray, Variable]:
+ """
+ Compute the cross product of two (arrays of) vectors.
+
+ The cross product of `a` and `b` in :math:`R^3` is a vector
+ perpendicular to both `a` and `b`. The vectors in `a` and `b` are
+ defined by the values along the dimension `dim` and can have sizes
+ 1, 2 or 3. Where the size of either `a` or `b` is
+ 1 or 2, the remaining components of the input vector is assumed to
+ be zero and the cross product calculated accordingly. In cases where
+ both input vectors have dimension 2, the z-component of the cross
+ product is returned.
+
+ Parameters
+ ----------
+ a, b : DataArray or Variable
+ Components of the first and second vector(s).
+ dim : hashable
+ The dimension along which the cross product will be computed.
+ Must be available in both vectors.
+
+ Examples
+ --------
+ Vector cross-product with 3 dimensions:
+
+ >>> a = xr.DataArray([1, 2, 3])
+ >>> b = xr.DataArray([4, 5, 6])
+ >>> xr.cross(a, b, dim="dim_0")
+
+ array([-3, 6, -3])
+ Dimensions without coordinates: dim_0
+
+ Vector cross-product with 2 dimensions, returns in the perpendicular
+ direction:
+
+ >>> a = xr.DataArray([1, 2])
+ >>> b = xr.DataArray([4, 5])
+ >>> xr.cross(a, b, dim="dim_0")
+
+ array(-3)
+
+ Vector cross-product with 3 dimensions but zeros at the last axis
+ yields the same results as with 2 dimensions:
+
+ >>> a = xr.DataArray([1, 2, 0])
+ >>> b = xr.DataArray([4, 5, 0])
+ >>> xr.cross(a, b, dim="dim_0")
+
+ array([ 0, 0, -3])
+ Dimensions without coordinates: dim_0
+
+ One vector with dimension 2:
+
+ >>> a = xr.DataArray(
+ ... [1, 2],
+ ... dims=["cartesian"],
+ ... coords=dict(cartesian=(["cartesian"], ["x", "y"])),
+ ... )
+ >>> b = xr.DataArray(
+ ... [4, 5, 6],
+ ... dims=["cartesian"],
+ ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])),
+ ... )
+ >>> xr.cross(a, b, dim="cartesian")
+
+ array([12, -6, -3])
+ Coordinates:
+ * cartesian (cartesian) >> a = xr.DataArray(
+ ... [1, 2],
+ ... dims=["cartesian"],
+ ... coords=dict(cartesian=(["cartesian"], ["x", "z"])),
+ ... )
+ >>> b = xr.DataArray(
+ ... [4, 5, 6],
+ ... dims=["cartesian"],
+ ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])),
+ ... )
+ >>> xr.cross(a, b, dim="cartesian")
+
+ array([-10, 2, 5])
+ Coordinates:
+ * cartesian (cartesian) >> a = xr.DataArray(
+ ... [[1, 2, 3], [4, 5, 6]],
+ ... dims=("time", "cartesian"),
+ ... coords=dict(
+ ... time=(["time"], [0, 1]),
+ ... cartesian=(["cartesian"], ["x", "y", "z"]),
+ ... ),
+ ... )
+ >>> b = xr.DataArray(
+ ... [[4, 5, 6], [1, 2, 3]],
+ ... dims=("time", "cartesian"),
+ ... coords=dict(
+ ... time=(["time"], [0, 1]),
+ ... cartesian=(["cartesian"], ["x", "y", "z"]),
+ ... ),
+ ... )
+ >>> xr.cross(a, b, dim="cartesian")
+
+ array([[-3, 6, -3],
+ [ 3, -6, 3]])
+ Coordinates:
+ * time (time) int64 0 1
+ * cartesian (cartesian) >> ds_a = xr.Dataset(dict(x=("dim_0", [1]), y=("dim_0", [2]), z=("dim_0", [3])))
+ >>> ds_b = xr.Dataset(dict(x=("dim_0", [4]), y=("dim_0", [5]), z=("dim_0", [6])))
+ >>> c = xr.cross(
+ ... ds_a.to_array("cartesian"), ds_b.to_array("cartesian"), dim="cartesian"
+ ... )
+ >>> c.to_dataset(dim="cartesian")
+
+ Dimensions: (dim_0: 1)
+ Dimensions without coordinates: dim_0
+ Data variables:
+ x (dim_0) int64 -3
+ y (dim_0) int64 6
+ z (dim_0) int64 -3
+
+ See Also
+ --------
+ numpy.cross : Corresponding numpy function
+ """
+
+ if dim not in a.dims:
+ raise ValueError(f"Dimension {dim!r} not on a")
+ elif dim not in b.dims:
+ raise ValueError(f"Dimension {dim!r} not on b")
+
+ if not 1 <= a.sizes[dim] <= 3:
+ raise ValueError(
+ f"The size of {dim!r} on a must be 1, 2, or 3 to be "
+ f"compatible with a cross product but is {a.sizes[dim]}"
+ )
+ elif not 1 <= b.sizes[dim] <= 3:
+ raise ValueError(
+ f"The size of {dim!r} on b must be 1, 2, or 3 to be "
+ f"compatible with a cross product but is {b.sizes[dim]}"
+ )
+
+ all_dims = list(dict.fromkeys(a.dims + b.dims))
+
+ if a.sizes[dim] != b.sizes[dim]:
+ # Arrays have different sizes. Append zeros where the smaller
+ # array is missing a value, zeros will not affect np.cross:
+
+ if (
+ not isinstance(a, Variable) # Only used to make mypy happy.
+ and dim in getattr(a, "coords", {})
+ and not isinstance(b, Variable) # Only used to make mypy happy.
+ and dim in getattr(b, "coords", {})
+ ):
+ # If the arrays have coords we know which indexes to fill
+ # with zeros:
+ a, b = align(
+ a,
+ b,
+ fill_value=0,
+ join="outer",
+ exclude=set(all_dims) - {dim},
+ )
+ elif min(a.sizes[dim], b.sizes[dim]) == 2:
+ # If the array doesn't have coords we can only infer
+ # that it has composite values if the size is at least 2.
+ # Once padded, rechunk the padded array because apply_ufunc
+ # requires core dimensions not to be chunked:
+ if a.sizes[dim] < b.sizes[dim]:
+ a = a.pad({dim: (0, 1)}, constant_values=0)
+ # TODO: Should pad or apply_ufunc handle correct chunking?
+ a = a.chunk({dim: -1}) if is_duck_dask_array(a.data) else a
+ else:
+ b = b.pad({dim: (0, 1)}, constant_values=0)
+ # TODO: Should pad or apply_ufunc handle correct chunking?
+ b = b.chunk({dim: -1}) if is_duck_dask_array(b.data) else b
+ else:
+ raise ValueError(
+ f"{dim!r} on {'a' if a.sizes[dim] == 1 else 'b'} is incompatible:"
+ " dimensions without coordinates must have have a length of 2 or 3"
+ )
+
+ c = apply_ufunc(
+ np.cross,
+ a,
+ b,
+ input_core_dims=[[dim], [dim]],
+ output_core_dims=[[dim] if a.sizes[dim] == 3 else []],
+ dask="parallelized",
+ output_dtypes=[np.result_type(a, b)],
+ )
+ c = c.transpose(*all_dims, missing_dims="ignore")
+
+ return c
+
+
def dot(*arrays, dims=None, **kwargs):
"""Generalized dot product for xarray objects. Like np.einsum, but
provides a simpler interface based on array dimensions.
diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py
index 7a40d6e64f8..972d9500777 100644
--- a/xarray/core/pycompat.py
+++ b/xarray/core/pycompat.py
@@ -43,8 +43,19 @@ def __init__(self, mod):
self.available = duck_array_module is not None
+dsk = DuckArrayModule("dask")
+dask_version = dsk.version
+dask_array_type = dsk.type
+
+sp = DuckArrayModule("sparse")
+sparse_array_type = sp.type
+sparse_version = sp.version
+
+cupy_array_type = DuckArrayModule("cupy").type
+
+
def is_dask_collection(x):
- if DuckArrayModule("dask").available:
+ if dsk.available:
from dask.base import is_dask_collection
return is_dask_collection(x)
@@ -54,14 +65,3 @@ def is_dask_collection(x):
def is_duck_dask_array(x):
return is_duck_array(x) and is_dask_collection(x)
-
-
-dsk = DuckArrayModule("dask")
-dask_version = dsk.version
-dask_array_type = dsk.type
-
-sp = DuckArrayModule("sparse")
-sparse_array_type = sp.type
-sparse_version = sp.version
-
-cupy_array_type = DuckArrayModule("cupy").type
diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py
index 7288a368e47..c1aedd570bc 100644
--- a/xarray/plot/dataset_plot.py
+++ b/xarray/plot/dataset_plot.py
@@ -12,7 +12,6 @@
_process_cmap_cbar_kwargs,
get_axis,
label_from_attrs,
- plt,
)
# copied from seaborn
@@ -135,7 +134,8 @@ def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None)
# copied from seaborn
def _parse_size(data, norm):
- mpl = plt.matplotlib
+
+ import matplotlib as mpl
if data is None:
return None
@@ -544,6 +544,8 @@ def quiver(ds, x, y, ax, u, v, **kwargs):
Wraps :py:func:`matplotlib:matplotlib.pyplot.quiver`.
"""
+ import matplotlib as mpl
+
if x is None or y is None or u is None or v is None:
raise ValueError("Must specify x, y, u, v for quiver plots.")
@@ -558,7 +560,7 @@ def quiver(ds, x, y, ax, u, v, **kwargs):
# TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
if not cmap_params["norm"]:
- cmap_params["norm"] = plt.Normalize(
+ cmap_params["norm"] = mpl.colors.Normalize(
cmap_params.pop("vmin"), cmap_params.pop("vmax")
)
@@ -574,6 +576,8 @@ def streamplot(ds, x, y, ax, u, v, **kwargs):
Wraps :py:func:`matplotlib:matplotlib.pyplot.streamplot`.
"""
+ import matplotlib as mpl
+
if x is None or y is None or u is None or v is None:
raise ValueError("Must specify x, y, u, v for streamplot plots.")
@@ -609,7 +613,7 @@ def streamplot(ds, x, y, ax, u, v, **kwargs):
# TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
if not cmap_params["norm"]:
- cmap_params["norm"] = plt.Normalize(
+ cmap_params["norm"] = mpl.colors.Normalize(
cmap_params.pop("vmin"), cmap_params.pop("vmax")
)
diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py
index cc6b1ffe777..f3daeeb7f3f 100644
--- a/xarray/plot/facetgrid.py
+++ b/xarray/plot/facetgrid.py
@@ -9,8 +9,8 @@
_get_nice_quiver_magnitude,
_infer_xy_labels,
_process_cmap_cbar_kwargs,
+ import_matplotlib_pyplot,
label_from_attrs,
- plt,
)
# Overrides axes.labelsize, xtick.major.size, ytick.major.size
@@ -116,6 +116,8 @@ def __init__(
"""
+ plt = import_matplotlib_pyplot()
+
# Handle corner case of nonunique coordinates
rep_col = col is not None and not data[col].to_index().is_unique
rep_row = row is not None and not data[row].to_index().is_unique
@@ -517,8 +519,10 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwar
self: FacetGrid object
"""
+ import matplotlib as mpl
+
if size is None:
- size = plt.rcParams["axes.labelsize"]
+ size = mpl.rcParams["axes.labelsize"]
nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template)
@@ -615,6 +619,8 @@ def map(self, func, *args, **kwargs):
self : FacetGrid object
"""
+ plt = import_matplotlib_pyplot()
+
for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
if namedict is not None:
data = self.data.loc[namedict]
diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py
index 0d6bae29ee2..af860b22635 100644
--- a/xarray/plot/plot.py
+++ b/xarray/plot/plot.py
@@ -29,9 +29,9 @@
_resolve_intervals_2dplot,
_update_axes,
get_axis,
+ import_matplotlib_pyplot,
label_from_attrs,
legend_elements,
- plt,
)
# copied from seaborn
@@ -83,6 +83,8 @@ def _parse_size(data, norm, width):
If the data is categorical, normalize it to numbers.
"""
+ plt = import_matplotlib_pyplot()
+
if data is None:
return None
@@ -680,6 +682,8 @@ def scatter(
**kwargs : optional
Additional keyword arguments to matplotlib
"""
+ plt = import_matplotlib_pyplot()
+
# Handle facetgrids first
if row or col:
allargs = locals().copy()
@@ -1107,6 +1111,8 @@ def newplotfunc(
allargs["plotfunc"] = globals()[plotfunc.__name__]
return _easy_facetgrid(darray, kind="dataarray", **allargs)
+ plt = import_matplotlib_pyplot()
+
if (
plotfunc.__name__ == "surface"
and not kwargs.get("_is_facetgrid", False)
diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py
index a49302f7f87..9e7e78f4c44 100644
--- a/xarray/plot/utils.py
+++ b/xarray/plot/utils.py
@@ -47,12 +47,6 @@ def import_matplotlib_pyplot():
return plt
-try:
- plt = import_matplotlib_pyplot()
-except ImportError:
- plt = None
-
-
def _determine_extend(calc_data, vmin, vmax):
extend_min = calc_data.min() < vmin
extend_max = calc_data.max() > vmax
@@ -70,7 +64,7 @@ def _build_discrete_cmap(cmap, levels, extend, filled):
"""
Build a discrete colormap and normalization of the data.
"""
- mpl = plt.matplotlib
+ import matplotlib as mpl
if len(levels) == 1:
levels = [levels[0], levels[0]]
@@ -121,7 +115,8 @@ def _build_discrete_cmap(cmap, levels, extend, filled):
def _color_palette(cmap, n_colors):
- ListedColormap = plt.matplotlib.colors.ListedColormap
+ import matplotlib.pyplot as plt
+ from matplotlib.colors import ListedColormap
colors_i = np.linspace(0, 1.0, n_colors)
if isinstance(cmap, (list, tuple)):
@@ -182,7 +177,7 @@ def _determine_cmap_params(
cmap_params : dict
Use depends on the type of the plotting function
"""
- mpl = plt.matplotlib
+ import matplotlib as mpl
if isinstance(levels, Iterable):
levels = sorted(levels)
@@ -290,13 +285,13 @@ def _determine_cmap_params(
levels = np.asarray([(vmin + vmax) / 2])
else:
# N in MaxNLocator refers to bins, not ticks
- ticker = plt.MaxNLocator(levels - 1)
+ ticker = mpl.ticker.MaxNLocator(levels - 1)
levels = ticker.tick_values(vmin, vmax)
vmin, vmax = levels[0], levels[-1]
# GH3734
if vmin == vmax:
- vmin, vmax = plt.LinearLocator(2).tick_values(vmin, vmax)
+ vmin, vmax = mpl.ticker.LinearLocator(2).tick_values(vmin, vmax)
if extend is None:
extend = _determine_extend(calc_data, vmin, vmax)
@@ -426,7 +421,10 @@ def _assert_valid_xy(darray, xy, name):
def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
- if plt is None:
+ try:
+ import matplotlib as mpl
+ import matplotlib.pyplot as plt
+ except ImportError:
raise ImportError("matplotlib is required for plot.utils.get_axis")
if figsize is not None:
@@ -439,7 +437,7 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
if ax is not None:
raise ValueError("cannot provide both `size` and `ax` arguments")
if aspect is None:
- width, height = plt.rcParams["figure.figsize"]
+ width, height = mpl.rcParams["figure.figsize"]
aspect = width / height
figsize = (size * aspect, size)
_, ax = plt.subplots(figsize=figsize)
@@ -456,6 +454,9 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
def _maybe_gca(**kwargs):
+
+ import matplotlib.pyplot as plt
+
# can call gcf unconditionally: either it exists or would be created by plt.axes
f = plt.gcf()
@@ -913,7 +914,9 @@ def _process_cmap_cbar_kwargs(
def _get_nice_quiver_magnitude(u, v):
- ticker = plt.MaxNLocator(3)
+ import matplotlib as mpl
+
+ ticker = mpl.ticker.MaxNLocator(3)
mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy()))
magnitude = ticker.tick_values(0, mean)[-2]
return magnitude
@@ -988,7 +991,7 @@ def legend_elements(
"""
import warnings
- mpl = plt.matplotlib
+ import matplotlib as mpl
mlines = mpl.lines
@@ -1125,6 +1128,7 @@ def _legend_add_subtitle(handles, labels, text, func):
def _adjust_legend_subtitles(legend):
"""Make invisible-handle "subtitles" entries look more like titles."""
+ plt = import_matplotlib_pyplot()
# Legend title not in rcParams until 3.0
font_size = plt.rcParams.get("legend.title_fontsize", None)
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index af4fb77a7fb..c4183f2cdc9 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -5209,22 +5209,22 @@ def test_open_fsspec():
# single dataset
url = "memory://out2.zarr"
ds2 = open_dataset(url, engine="zarr")
- assert ds0 == ds2
+ xr.testing.assert_equal(ds0, ds2)
# single dataset with caching
url = "simplecache::memory://out2.zarr"
ds2 = open_dataset(url, engine="zarr")
- assert ds0 == ds2
+ xr.testing.assert_equal(ds0, ds2)
# multi dataset
url = "memory://out*.zarr"
ds2 = open_mfdataset(url, engine="zarr")
- assert xr.concat([ds, ds0], dim="time") == ds2
+ xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2)
# multi dataset with caching
url = "simplecache::memory://out*.zarr"
ds2 = open_mfdataset(url, engine="zarr")
- assert xr.concat([ds, ds0], dim="time") == ds2
+ xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2)
@requires_h5netcdf
diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py
index 77d3110104f..6af93607e6b 100644
--- a/xarray/tests/test_computation.py
+++ b/xarray/tests/test_computation.py
@@ -1952,3 +1952,110 @@ def test_polyval(use_dask, use_datetime) -> None:
da_pv = xr.polyval(da.x, coeffs)
xr.testing.assert_allclose(da, da_pv.T)
+
+
+@pytest.mark.parametrize("use_dask", [False, True])
+@pytest.mark.parametrize(
+ "a, b, ae, be, dim, axis",
+ [
+ [
+ xr.DataArray([1, 2, 3]),
+ xr.DataArray([4, 5, 6]),
+ [1, 2, 3],
+ [4, 5, 6],
+ "dim_0",
+ -1,
+ ],
+ [
+ xr.DataArray([1, 2]),
+ xr.DataArray([4, 5, 6]),
+ [1, 2],
+ [4, 5, 6],
+ "dim_0",
+ -1,
+ ],
+ [
+ xr.Variable(dims=["dim_0"], data=[1, 2, 3]),
+ xr.Variable(dims=["dim_0"], data=[4, 5, 6]),
+ [1, 2, 3],
+ [4, 5, 6],
+ "dim_0",
+ -1,
+ ],
+ [
+ xr.Variable(dims=["dim_0"], data=[1, 2]),
+ xr.Variable(dims=["dim_0"], data=[4, 5, 6]),
+ [1, 2],
+ [4, 5, 6],
+ "dim_0",
+ -1,
+ ],
+ [ # Test dim in the middle:
+ xr.DataArray(
+ np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)),
+ dims=["time", "cartesian", "var"],
+ coords=dict(
+ time=(["time"], np.arange(0, 5)),
+ cartesian=(["cartesian"], ["x", "y", "z"]),
+ var=(["var"], [1, 1.5, 2, 2.5]),
+ ),
+ ),
+ xr.DataArray(
+ np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1,
+ dims=["time", "cartesian", "var"],
+ coords=dict(
+ time=(["time"], np.arange(0, 5)),
+ cartesian=(["cartesian"], ["x", "y", "z"]),
+ var=(["var"], [1, 1.5, 2, 2.5]),
+ ),
+ ),
+ np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)),
+ np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1,
+ "cartesian",
+ 1,
+ ],
+ [ # Test 1 sized arrays with coords:
+ xr.DataArray(
+ np.array([1]),
+ dims=["cartesian"],
+ coords=dict(cartesian=(["cartesian"], ["z"])),
+ ),
+ xr.DataArray(
+ np.array([4, 5, 6]),
+ dims=["cartesian"],
+ coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])),
+ ),
+ [0, 0, 1],
+ [4, 5, 6],
+ "cartesian",
+ -1,
+ ],
+ [ # Test filling inbetween with coords:
+ xr.DataArray(
+ [1, 2],
+ dims=["cartesian"],
+ coords=dict(cartesian=(["cartesian"], ["x", "z"])),
+ ),
+ xr.DataArray(
+ [4, 5, 6],
+ dims=["cartesian"],
+ coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])),
+ ),
+ [1, 0, 2],
+ [4, 5, 6],
+ "cartesian",
+ -1,
+ ],
+ ],
+)
+def test_cross(a, b, ae, be, dim: str, axis: int, use_dask: bool) -> None:
+ expected = np.cross(ae, be, axis=axis)
+
+ if use_dask:
+ if not has_dask:
+ pytest.skip("test for dask.")
+ a = a.chunk()
+ b = b.chunk()
+
+ actual = xr.cross(a, b, dim=dim)
+ xr.testing.assert_duckarray_allclose(expected, actual)
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index 16148c21b43..c8770601c30 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -6497,14 +6497,14 @@ def test_deepcopy_obj_array():
def test_clip(ds):
result = ds.clip(min=0.5)
- assert result.min(...) >= 0.5
+ assert all((result.min(...) >= 0.5).values())
result = ds.clip(max=0.5)
- assert result.max(...) <= 0.5
+ assert all((result.max(...) <= 0.5).values())
result = ds.clip(min=0.25, max=0.75)
- assert result.min(...) >= 0.25
- assert result.max(...) <= 0.75
+ assert all((result.min(...) >= 0.25).values())
+ assert all((result.max(...) <= 0.75).values())
result = ds.clip(min=ds.mean("y"), max=ds.mean("y"))
assert result.dims == ds.dims