diff --git a/.github/ISSUE_TEMPLATE/bug-report.md b/.github/ISSUE_TEMPLATE/bug-report.md deleted file mode 100644 index 02bc5d0f7b0..00000000000 --- a/.github/ISSUE_TEMPLATE/bug-report.md +++ /dev/null @@ -1,39 +0,0 @@ ---- -name: Bug report -about: Create a report to help us improve -title: '' -labels: '' -assignees: '' - ---- - - - -**What happened**: - -**What you expected to happen**: - -**Minimal Complete Verifiable Example**: - -```python -# Put your MCVE code here -``` - -**Anything else we need to know?**: - -**Environment**: - -
Output of xr.show_versions() - - - - -
diff --git a/.github/ISSUE_TEMPLATE/bugreport.yml b/.github/ISSUE_TEMPLATE/bugreport.yml new file mode 100644 index 00000000000..255c7de07d9 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bugreport.yml @@ -0,0 +1,61 @@ +name: Bug Report +description: File a bug report to help us improve +title: '[Bug]: ' +labels: [bug, 'needs triage'] +assignees: [] +body: + - type: textarea + id: what-happened + attributes: + label: What happened? + description: | + Thanks for reporting a bug! Please describe what you were trying to get done. + Tell us what happened, what went wrong. + validations: + required: true + + - type: textarea + id: what-did-you-expect-to-happen + attributes: + label: What did you expect to happen? + description: | + Describe what you expected to happen. + validations: + required: false + + - type: textarea + id: sample-code + attributes: + label: Minimal Complete Verifiable Example + description: | + Minimal, self-contained copy-pastable example that generates the issue if possible. Please be concise with code posted. See guidelines below on how to provide a good bug report: + + - [Minimal Complete Verifiable Examples](https://stackoverflow.com/help/mcve) + - [Craft Minimal Bug Reports](http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) + + Bug reports that follow these guidelines are easier to diagnose, and so are often handled much more quickly. + This will be automatically formatted into code, so no need for markdown backticks. + render: python + + - type: textarea + id: log-output + attributes: + label: Relevant log output + description: Please copy and paste any relevant output. This will be automatically formatted into code, so no need for markdown backticks. + render: python + + - type: textarea + id: extra + attributes: + label: Anything else we need to know? + description: | + Please describe any other information you want to share. + + - type: textarea + id: show-versions + attributes: + label: Environment + description: | + Paste the output of `xr.show_versions()` here + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/feature-request.md b/.github/ISSUE_TEMPLATE/feature-request.md deleted file mode 100644 index 7021fe490aa..00000000000 --- a/.github/ISSUE_TEMPLATE/feature-request.md +++ /dev/null @@ -1,22 +0,0 @@ ---- -name: Feature request -about: Suggest an idea for this project -title: '' -labels: '' -assignees: '' - ---- - - - -**Is your feature request related to a problem? Please describe.** -A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] - -**Describe the solution you'd like** -A clear and concise description of what you want to happen. - -**Describe alternatives you've considered** -A clear and concise description of any alternative solutions or features you've considered. - -**Additional context** -Add any other context about the feature request here. diff --git a/.github/ISSUE_TEMPLATE/newfeature.yml b/.github/ISSUE_TEMPLATE/newfeature.yml new file mode 100644 index 00000000000..ec94b0f4b89 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/newfeature.yml @@ -0,0 +1,37 @@ +name: Feature Request +description: Suggest an idea for xarray +title: '[FEATURE]: ' +labels: [enhancement] +assignees: [] +body: + - type: textarea + id: description + attributes: + label: Is your feature request related to a problem? + description: | + Please do a quick search of existing issues to make sure that this has not been asked before. + Please provide a clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + validations: + required: true + - type: textarea + id: solution + attributes: + label: Describe the solution you'd like + description: | + A clear and concise description of what you want to happen. + - type: textarea + id: alternatives + attributes: + label: Describe alternatives you've considered + description: | + A clear and concise description of any alternative solutions or features you've considered. + validations: + required: false + - type: textarea + id: additional-context + attributes: + label: Additional context + description: | + Add any other context about the feature request here. + validations: + required: false diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index c7ea19a53cb..37b8d357c87 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,6 +2,5 @@ - [ ] Closes #xxxx - [ ] Tests added -- [ ] Passes `pre-commit run --all-files` - [ ] User visible changes (including notable bug fixes) are documented in `whats-new.rst` - [ ] New functions/methods are listed in `api.rst` diff --git a/.github/workflows/ci-pre-commit.yml b/.github/workflows/ci-pre-commit.yml deleted file mode 100644 index 4bc5bddfdbc..00000000000 --- a/.github/workflows/ci-pre-commit.yml +++ /dev/null @@ -1,17 +0,0 @@ -name: linting - -on: - push: - branches: "*" - pull_request: - branches: "*" - -jobs: - linting: - name: "pre-commit hooks" - runs-on: ubuntu-latest - if: github.repository == 'pydata/xarray' - steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - - uses: pre-commit/action@v2.0.3 diff --git a/asv_bench/benchmarks/import_xarray.py b/asv_bench/benchmarks/import_xarray.py deleted file mode 100644 index 94652e3b82a..00000000000 --- a/asv_bench/benchmarks/import_xarray.py +++ /dev/null @@ -1,9 +0,0 @@ -class ImportXarray: - def setup(self, *args, **kwargs): - def import_xr(): - import xarray # noqa: F401 - - self._import_xr = import_xr - - def time_import_xarray(self): - self._import_xr() diff --git a/doc/api.rst b/doc/api.rst index 9433ecfa56d..ef2694ea661 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -32,6 +32,7 @@ Top-level functions ones_like cov corr + cross dot polyval map_blocks diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1c4b49097a3..f991a4e2a89 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,8 @@ v0.21.0 (unreleased) New Features ~~~~~~~~~~~~ +- New top-level function :py:func:`cross`. (:issue:`3279`, :pull:`5365`). + By `Jimmy Westling `_. Breaking changes @@ -29,6 +31,8 @@ Breaking changes Deprecations ~~~~~~~~~~~~ +- Removed the lock kwarg from the zarr and pydap backends, completing the deprecation cycle started in :issue:`5256`. + By `Tom Nicholas `_. Bug fixes diff --git a/xarray/__init__.py b/xarray/__init__.py index 10f16e58081..81ab9f388a8 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -16,7 +16,16 @@ from .core.alignment import align, broadcast from .core.combine import combine_by_coords, combine_nested from .core.common import ALL_DIMS, full_like, ones_like, zeros_like -from .core.computation import apply_ufunc, corr, cov, dot, polyval, unify_chunks, where +from .core.computation import ( + apply_ufunc, + corr, + cov, + cross, + dot, + polyval, + unify_chunks, + where, +) from .core.concat import concat from .core.dataarray import DataArray from .core.dataset import Dataset @@ -60,6 +69,7 @@ "dot", "cov", "corr", + "cross", "full_like", "get_options", "infer_freq", diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index bc479f9a71d..ffaf3793928 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -1,5 +1,3 @@ -import warnings - import numpy as np from ..core import indexing @@ -126,15 +124,7 @@ def open_dataset( use_cftime=None, decode_timedelta=None, session=None, - lock=None, ): - # TODO remove after v0.19 - if lock is not None: - warnings.warn( - "The kwarg 'lock' has been deprecated for this backend, and is now " - "ignored. In the future passing lock will raise an error.", - DeprecationWarning, - ) store = PydapDataStore.open( filename_or_obj, diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 3eb6a3caf72..8bd343869ff 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -810,15 +810,7 @@ def open_dataset( chunk_store=None, storage_options=None, stacklevel=3, - lock=None, ): - # TODO remove after v0.19 - if lock is not None: - warnings.warn( - "The kwarg 'lock' has been deprecated for this backend, and is now " - "ignored. In the future passing lock will raise an error.", - DeprecationWarning, - ) filename_or_obj = _normalize_path(filename_or_obj) store = ZarrStore.open_group( diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 191b777107a..9fe93c88734 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -36,6 +36,7 @@ if TYPE_CHECKING: from .coordinates import Coordinates + from .dataarray import DataArray from .dataset import Dataset from .types import T_Xarray @@ -1373,6 +1374,214 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None): return corr +def cross( + a: Union[DataArray, Variable], b: Union[DataArray, Variable], *, dim: Hashable +) -> Union[DataArray, Variable]: + """ + Compute the cross product of two (arrays of) vectors. + + The cross product of `a` and `b` in :math:`R^3` is a vector + perpendicular to both `a` and `b`. The vectors in `a` and `b` are + defined by the values along the dimension `dim` and can have sizes + 1, 2 or 3. Where the size of either `a` or `b` is + 1 or 2, the remaining components of the input vector is assumed to + be zero and the cross product calculated accordingly. In cases where + both input vectors have dimension 2, the z-component of the cross + product is returned. + + Parameters + ---------- + a, b : DataArray or Variable + Components of the first and second vector(s). + dim : hashable + The dimension along which the cross product will be computed. + Must be available in both vectors. + + Examples + -------- + Vector cross-product with 3 dimensions: + + >>> a = xr.DataArray([1, 2, 3]) + >>> b = xr.DataArray([4, 5, 6]) + >>> xr.cross(a, b, dim="dim_0") + + array([-3, 6, -3]) + Dimensions without coordinates: dim_0 + + Vector cross-product with 2 dimensions, returns in the perpendicular + direction: + + >>> a = xr.DataArray([1, 2]) + >>> b = xr.DataArray([4, 5]) + >>> xr.cross(a, b, dim="dim_0") + + array(-3) + + Vector cross-product with 3 dimensions but zeros at the last axis + yields the same results as with 2 dimensions: + + >>> a = xr.DataArray([1, 2, 0]) + >>> b = xr.DataArray([4, 5, 0]) + >>> xr.cross(a, b, dim="dim_0") + + array([ 0, 0, -3]) + Dimensions without coordinates: dim_0 + + One vector with dimension 2: + + >>> a = xr.DataArray( + ... [1, 2], + ... dims=["cartesian"], + ... coords=dict(cartesian=(["cartesian"], ["x", "y"])), + ... ) + >>> b = xr.DataArray( + ... [4, 5, 6], + ... dims=["cartesian"], + ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), + ... ) + >>> xr.cross(a, b, dim="cartesian") + + array([12, -6, -3]) + Coordinates: + * cartesian (cartesian) >> a = xr.DataArray( + ... [1, 2], + ... dims=["cartesian"], + ... coords=dict(cartesian=(["cartesian"], ["x", "z"])), + ... ) + >>> b = xr.DataArray( + ... [4, 5, 6], + ... dims=["cartesian"], + ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), + ... ) + >>> xr.cross(a, b, dim="cartesian") + + array([-10, 2, 5]) + Coordinates: + * cartesian (cartesian) >> a = xr.DataArray( + ... [[1, 2, 3], [4, 5, 6]], + ... dims=("time", "cartesian"), + ... coords=dict( + ... time=(["time"], [0, 1]), + ... cartesian=(["cartesian"], ["x", "y", "z"]), + ... ), + ... ) + >>> b = xr.DataArray( + ... [[4, 5, 6], [1, 2, 3]], + ... dims=("time", "cartesian"), + ... coords=dict( + ... time=(["time"], [0, 1]), + ... cartesian=(["cartesian"], ["x", "y", "z"]), + ... ), + ... ) + >>> xr.cross(a, b, dim="cartesian") + + array([[-3, 6, -3], + [ 3, -6, 3]]) + Coordinates: + * time (time) int64 0 1 + * cartesian (cartesian) >> ds_a = xr.Dataset(dict(x=("dim_0", [1]), y=("dim_0", [2]), z=("dim_0", [3]))) + >>> ds_b = xr.Dataset(dict(x=("dim_0", [4]), y=("dim_0", [5]), z=("dim_0", [6]))) + >>> c = xr.cross( + ... ds_a.to_array("cartesian"), ds_b.to_array("cartesian"), dim="cartesian" + ... ) + >>> c.to_dataset(dim="cartesian") + + Dimensions: (dim_0: 1) + Dimensions without coordinates: dim_0 + Data variables: + x (dim_0) int64 -3 + y (dim_0) int64 6 + z (dim_0) int64 -3 + + See Also + -------- + numpy.cross : Corresponding numpy function + """ + + if dim not in a.dims: + raise ValueError(f"Dimension {dim!r} not on a") + elif dim not in b.dims: + raise ValueError(f"Dimension {dim!r} not on b") + + if not 1 <= a.sizes[dim] <= 3: + raise ValueError( + f"The size of {dim!r} on a must be 1, 2, or 3 to be " + f"compatible with a cross product but is {a.sizes[dim]}" + ) + elif not 1 <= b.sizes[dim] <= 3: + raise ValueError( + f"The size of {dim!r} on b must be 1, 2, or 3 to be " + f"compatible with a cross product but is {b.sizes[dim]}" + ) + + all_dims = list(dict.fromkeys(a.dims + b.dims)) + + if a.sizes[dim] != b.sizes[dim]: + # Arrays have different sizes. Append zeros where the smaller + # array is missing a value, zeros will not affect np.cross: + + if ( + not isinstance(a, Variable) # Only used to make mypy happy. + and dim in getattr(a, "coords", {}) + and not isinstance(b, Variable) # Only used to make mypy happy. + and dim in getattr(b, "coords", {}) + ): + # If the arrays have coords we know which indexes to fill + # with zeros: + a, b = align( + a, + b, + fill_value=0, + join="outer", + exclude=set(all_dims) - {dim}, + ) + elif min(a.sizes[dim], b.sizes[dim]) == 2: + # If the array doesn't have coords we can only infer + # that it has composite values if the size is at least 2. + # Once padded, rechunk the padded array because apply_ufunc + # requires core dimensions not to be chunked: + if a.sizes[dim] < b.sizes[dim]: + a = a.pad({dim: (0, 1)}, constant_values=0) + # TODO: Should pad or apply_ufunc handle correct chunking? + a = a.chunk({dim: -1}) if is_duck_dask_array(a.data) else a + else: + b = b.pad({dim: (0, 1)}, constant_values=0) + # TODO: Should pad or apply_ufunc handle correct chunking? + b = b.chunk({dim: -1}) if is_duck_dask_array(b.data) else b + else: + raise ValueError( + f"{dim!r} on {'a' if a.sizes[dim] == 1 else 'b'} is incompatible:" + " dimensions without coordinates must have have a length of 2 or 3" + ) + + c = apply_ufunc( + np.cross, + a, + b, + input_core_dims=[[dim], [dim]], + output_core_dims=[[dim] if a.sizes[dim] == 3 else []], + dask="parallelized", + output_dtypes=[np.result_type(a, b)], + ) + c = c.transpose(*all_dims, missing_dims="ignore") + + return c + + def dot(*arrays, dims=None, **kwargs): """Generalized dot product for xarray objects. Like np.einsum, but provides a simpler interface based on array dimensions. diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 7a40d6e64f8..972d9500777 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -43,8 +43,19 @@ def __init__(self, mod): self.available = duck_array_module is not None +dsk = DuckArrayModule("dask") +dask_version = dsk.version +dask_array_type = dsk.type + +sp = DuckArrayModule("sparse") +sparse_array_type = sp.type +sparse_version = sp.version + +cupy_array_type = DuckArrayModule("cupy").type + + def is_dask_collection(x): - if DuckArrayModule("dask").available: + if dsk.available: from dask.base import is_dask_collection return is_dask_collection(x) @@ -54,14 +65,3 @@ def is_dask_collection(x): def is_duck_dask_array(x): return is_duck_array(x) and is_dask_collection(x) - - -dsk = DuckArrayModule("dask") -dask_version = dsk.version -dask_array_type = dsk.type - -sp = DuckArrayModule("sparse") -sparse_array_type = sp.type -sparse_version = sp.version - -cupy_array_type = DuckArrayModule("cupy").type diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 7288a368e47..c1aedd570bc 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -12,7 +12,6 @@ _process_cmap_cbar_kwargs, get_axis, label_from_attrs, - plt, ) # copied from seaborn @@ -135,7 +134,8 @@ def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None) # copied from seaborn def _parse_size(data, norm): - mpl = plt.matplotlib + + import matplotlib as mpl if data is None: return None @@ -544,6 +544,8 @@ def quiver(ds, x, y, ax, u, v, **kwargs): Wraps :py:func:`matplotlib:matplotlib.pyplot.quiver`. """ + import matplotlib as mpl + if x is None or y is None or u is None or v is None: raise ValueError("Must specify x, y, u, v for quiver plots.") @@ -558,7 +560,7 @@ def quiver(ds, x, y, ax, u, v, **kwargs): # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params if not cmap_params["norm"]: - cmap_params["norm"] = plt.Normalize( + cmap_params["norm"] = mpl.colors.Normalize( cmap_params.pop("vmin"), cmap_params.pop("vmax") ) @@ -574,6 +576,8 @@ def streamplot(ds, x, y, ax, u, v, **kwargs): Wraps :py:func:`matplotlib:matplotlib.pyplot.streamplot`. """ + import matplotlib as mpl + if x is None or y is None or u is None or v is None: raise ValueError("Must specify x, y, u, v for streamplot plots.") @@ -609,7 +613,7 @@ def streamplot(ds, x, y, ax, u, v, **kwargs): # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params if not cmap_params["norm"]: - cmap_params["norm"] = plt.Normalize( + cmap_params["norm"] = mpl.colors.Normalize( cmap_params.pop("vmin"), cmap_params.pop("vmax") ) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index cc6b1ffe777..f3daeeb7f3f 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -9,8 +9,8 @@ _get_nice_quiver_magnitude, _infer_xy_labels, _process_cmap_cbar_kwargs, + import_matplotlib_pyplot, label_from_attrs, - plt, ) # Overrides axes.labelsize, xtick.major.size, ytick.major.size @@ -116,6 +116,8 @@ def __init__( """ + plt = import_matplotlib_pyplot() + # Handle corner case of nonunique coordinates rep_col = col is not None and not data[col].to_index().is_unique rep_row = row is not None and not data[row].to_index().is_unique @@ -517,8 +519,10 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwar self: FacetGrid object """ + import matplotlib as mpl + if size is None: - size = plt.rcParams["axes.labelsize"] + size = mpl.rcParams["axes.labelsize"] nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template) @@ -615,6 +619,8 @@ def map(self, func, *args, **kwargs): self : FacetGrid object """ + plt = import_matplotlib_pyplot() + for ax, namedict in zip(self.axes.flat, self.name_dicts.flat): if namedict is not None: data = self.data.loc[namedict] diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 0d6bae29ee2..af860b22635 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -29,9 +29,9 @@ _resolve_intervals_2dplot, _update_axes, get_axis, + import_matplotlib_pyplot, label_from_attrs, legend_elements, - plt, ) # copied from seaborn @@ -83,6 +83,8 @@ def _parse_size(data, norm, width): If the data is categorical, normalize it to numbers. """ + plt = import_matplotlib_pyplot() + if data is None: return None @@ -680,6 +682,8 @@ def scatter( **kwargs : optional Additional keyword arguments to matplotlib """ + plt = import_matplotlib_pyplot() + # Handle facetgrids first if row or col: allargs = locals().copy() @@ -1107,6 +1111,8 @@ def newplotfunc( allargs["plotfunc"] = globals()[plotfunc.__name__] return _easy_facetgrid(darray, kind="dataarray", **allargs) + plt = import_matplotlib_pyplot() + if ( plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index a49302f7f87..9e7e78f4c44 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -47,12 +47,6 @@ def import_matplotlib_pyplot(): return plt -try: - plt = import_matplotlib_pyplot() -except ImportError: - plt = None - - def _determine_extend(calc_data, vmin, vmax): extend_min = calc_data.min() < vmin extend_max = calc_data.max() > vmax @@ -70,7 +64,7 @@ def _build_discrete_cmap(cmap, levels, extend, filled): """ Build a discrete colormap and normalization of the data. """ - mpl = plt.matplotlib + import matplotlib as mpl if len(levels) == 1: levels = [levels[0], levels[0]] @@ -121,7 +115,8 @@ def _build_discrete_cmap(cmap, levels, extend, filled): def _color_palette(cmap, n_colors): - ListedColormap = plt.matplotlib.colors.ListedColormap + import matplotlib.pyplot as plt + from matplotlib.colors import ListedColormap colors_i = np.linspace(0, 1.0, n_colors) if isinstance(cmap, (list, tuple)): @@ -182,7 +177,7 @@ def _determine_cmap_params( cmap_params : dict Use depends on the type of the plotting function """ - mpl = plt.matplotlib + import matplotlib as mpl if isinstance(levels, Iterable): levels = sorted(levels) @@ -290,13 +285,13 @@ def _determine_cmap_params( levels = np.asarray([(vmin + vmax) / 2]) else: # N in MaxNLocator refers to bins, not ticks - ticker = plt.MaxNLocator(levels - 1) + ticker = mpl.ticker.MaxNLocator(levels - 1) levels = ticker.tick_values(vmin, vmax) vmin, vmax = levels[0], levels[-1] # GH3734 if vmin == vmax: - vmin, vmax = plt.LinearLocator(2).tick_values(vmin, vmax) + vmin, vmax = mpl.ticker.LinearLocator(2).tick_values(vmin, vmax) if extend is None: extend = _determine_extend(calc_data, vmin, vmax) @@ -426,7 +421,10 @@ def _assert_valid_xy(darray, xy, name): def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): - if plt is None: + try: + import matplotlib as mpl + import matplotlib.pyplot as plt + except ImportError: raise ImportError("matplotlib is required for plot.utils.get_axis") if figsize is not None: @@ -439,7 +437,7 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): if ax is not None: raise ValueError("cannot provide both `size` and `ax` arguments") if aspect is None: - width, height = plt.rcParams["figure.figsize"] + width, height = mpl.rcParams["figure.figsize"] aspect = width / height figsize = (size * aspect, size) _, ax = plt.subplots(figsize=figsize) @@ -456,6 +454,9 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): def _maybe_gca(**kwargs): + + import matplotlib.pyplot as plt + # can call gcf unconditionally: either it exists or would be created by plt.axes f = plt.gcf() @@ -913,7 +914,9 @@ def _process_cmap_cbar_kwargs( def _get_nice_quiver_magnitude(u, v): - ticker = plt.MaxNLocator(3) + import matplotlib as mpl + + ticker = mpl.ticker.MaxNLocator(3) mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy())) magnitude = ticker.tick_values(0, mean)[-2] return magnitude @@ -988,7 +991,7 @@ def legend_elements( """ import warnings - mpl = plt.matplotlib + import matplotlib as mpl mlines = mpl.lines @@ -1125,6 +1128,7 @@ def _legend_add_subtitle(handles, labels, text, func): def _adjust_legend_subtitles(legend): """Make invisible-handle "subtitles" entries look more like titles.""" + plt = import_matplotlib_pyplot() # Legend title not in rcParams until 3.0 font_size = plt.rcParams.get("legend.title_fontsize", None) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index af4fb77a7fb..c4183f2cdc9 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -5209,22 +5209,22 @@ def test_open_fsspec(): # single dataset url = "memory://out2.zarr" ds2 = open_dataset(url, engine="zarr") - assert ds0 == ds2 + xr.testing.assert_equal(ds0, ds2) # single dataset with caching url = "simplecache::memory://out2.zarr" ds2 = open_dataset(url, engine="zarr") - assert ds0 == ds2 + xr.testing.assert_equal(ds0, ds2) # multi dataset url = "memory://out*.zarr" ds2 = open_mfdataset(url, engine="zarr") - assert xr.concat([ds, ds0], dim="time") == ds2 + xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) # multi dataset with caching url = "simplecache::memory://out*.zarr" ds2 = open_mfdataset(url, engine="zarr") - assert xr.concat([ds, ds0], dim="time") == ds2 + xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) @requires_h5netcdf diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 77d3110104f..6af93607e6b 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1952,3 +1952,110 @@ def test_polyval(use_dask, use_datetime) -> None: da_pv = xr.polyval(da.x, coeffs) xr.testing.assert_allclose(da, da_pv.T) + + +@pytest.mark.parametrize("use_dask", [False, True]) +@pytest.mark.parametrize( + "a, b, ae, be, dim, axis", + [ + [ + xr.DataArray([1, 2, 3]), + xr.DataArray([4, 5, 6]), + [1, 2, 3], + [4, 5, 6], + "dim_0", + -1, + ], + [ + xr.DataArray([1, 2]), + xr.DataArray([4, 5, 6]), + [1, 2], + [4, 5, 6], + "dim_0", + -1, + ], + [ + xr.Variable(dims=["dim_0"], data=[1, 2, 3]), + xr.Variable(dims=["dim_0"], data=[4, 5, 6]), + [1, 2, 3], + [4, 5, 6], + "dim_0", + -1, + ], + [ + xr.Variable(dims=["dim_0"], data=[1, 2]), + xr.Variable(dims=["dim_0"], data=[4, 5, 6]), + [1, 2], + [4, 5, 6], + "dim_0", + -1, + ], + [ # Test dim in the middle: + xr.DataArray( + np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)), + dims=["time", "cartesian", "var"], + coords=dict( + time=(["time"], np.arange(0, 5)), + cartesian=(["cartesian"], ["x", "y", "z"]), + var=(["var"], [1, 1.5, 2, 2.5]), + ), + ), + xr.DataArray( + np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1, + dims=["time", "cartesian", "var"], + coords=dict( + time=(["time"], np.arange(0, 5)), + cartesian=(["cartesian"], ["x", "y", "z"]), + var=(["var"], [1, 1.5, 2, 2.5]), + ), + ), + np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)), + np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1, + "cartesian", + 1, + ], + [ # Test 1 sized arrays with coords: + xr.DataArray( + np.array([1]), + dims=["cartesian"], + coords=dict(cartesian=(["cartesian"], ["z"])), + ), + xr.DataArray( + np.array([4, 5, 6]), + dims=["cartesian"], + coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), + ), + [0, 0, 1], + [4, 5, 6], + "cartesian", + -1, + ], + [ # Test filling inbetween with coords: + xr.DataArray( + [1, 2], + dims=["cartesian"], + coords=dict(cartesian=(["cartesian"], ["x", "z"])), + ), + xr.DataArray( + [4, 5, 6], + dims=["cartesian"], + coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), + ), + [1, 0, 2], + [4, 5, 6], + "cartesian", + -1, + ], + ], +) +def test_cross(a, b, ae, be, dim: str, axis: int, use_dask: bool) -> None: + expected = np.cross(ae, be, axis=axis) + + if use_dask: + if not has_dask: + pytest.skip("test for dask.") + a = a.chunk() + b = b.chunk() + + actual = xr.cross(a, b, dim=dim) + xr.testing.assert_duckarray_allclose(expected, actual) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 16148c21b43..c8770601c30 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6497,14 +6497,14 @@ def test_deepcopy_obj_array(): def test_clip(ds): result = ds.clip(min=0.5) - assert result.min(...) >= 0.5 + assert all((result.min(...) >= 0.5).values()) result = ds.clip(max=0.5) - assert result.max(...) <= 0.5 + assert all((result.max(...) <= 0.5).values()) result = ds.clip(min=0.25, max=0.75) - assert result.min(...) >= 0.25 - assert result.max(...) <= 0.75 + assert all((result.min(...) >= 0.25).values()) + assert all((result.max(...) <= 0.75).values()) result = ds.clip(min=ds.mean("y"), max=ds.mean("y")) assert result.dims == ds.dims