diff --git a/asv_bench/benchmarks/import_xarray.py b/asv_bench/benchmarks/import_xarray.py deleted file mode 100644 index 94652e3b82a..00000000000 --- a/asv_bench/benchmarks/import_xarray.py +++ /dev/null @@ -1,9 +0,0 @@ -class ImportXarray: - def setup(self, *args, **kwargs): - def import_xr(): - import xarray # noqa: F401 - - self._import_xr = import_xr - - def time_import_xarray(self): - self._import_xr() diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 7288a368e47..c1aedd570bc 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -12,7 +12,6 @@ _process_cmap_cbar_kwargs, get_axis, label_from_attrs, - plt, ) # copied from seaborn @@ -135,7 +134,8 @@ def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None) # copied from seaborn def _parse_size(data, norm): - mpl = plt.matplotlib + + import matplotlib as mpl if data is None: return None @@ -544,6 +544,8 @@ def quiver(ds, x, y, ax, u, v, **kwargs): Wraps :py:func:`matplotlib:matplotlib.pyplot.quiver`. """ + import matplotlib as mpl + if x is None or y is None or u is None or v is None: raise ValueError("Must specify x, y, u, v for quiver plots.") @@ -558,7 +560,7 @@ def quiver(ds, x, y, ax, u, v, **kwargs): # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params if not cmap_params["norm"]: - cmap_params["norm"] = plt.Normalize( + cmap_params["norm"] = mpl.colors.Normalize( cmap_params.pop("vmin"), cmap_params.pop("vmax") ) @@ -574,6 +576,8 @@ def streamplot(ds, x, y, ax, u, v, **kwargs): Wraps :py:func:`matplotlib:matplotlib.pyplot.streamplot`. """ + import matplotlib as mpl + if x is None or y is None or u is None or v is None: raise ValueError("Must specify x, y, u, v for streamplot plots.") @@ -609,7 +613,7 @@ def streamplot(ds, x, y, ax, u, v, **kwargs): # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params if not cmap_params["norm"]: - cmap_params["norm"] = plt.Normalize( + cmap_params["norm"] = mpl.colors.Normalize( cmap_params.pop("vmin"), cmap_params.pop("vmax") ) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index cc6b1ffe777..f3daeeb7f3f 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -9,8 +9,8 @@ _get_nice_quiver_magnitude, _infer_xy_labels, _process_cmap_cbar_kwargs, + import_matplotlib_pyplot, label_from_attrs, - plt, ) # Overrides axes.labelsize, xtick.major.size, ytick.major.size @@ -116,6 +116,8 @@ def __init__( """ + plt = import_matplotlib_pyplot() + # Handle corner case of nonunique coordinates rep_col = col is not None and not data[col].to_index().is_unique rep_row = row is not None and not data[row].to_index().is_unique @@ -517,8 +519,10 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwar self: FacetGrid object """ + import matplotlib as mpl + if size is None: - size = plt.rcParams["axes.labelsize"] + size = mpl.rcParams["axes.labelsize"] nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template) @@ -615,6 +619,8 @@ def map(self, func, *args, **kwargs): self : FacetGrid object """ + plt = import_matplotlib_pyplot() + for ax, namedict in zip(self.axes.flat, self.name_dicts.flat): if namedict is not None: data = self.data.loc[namedict] diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 0d6bae29ee2..af860b22635 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -29,9 +29,9 @@ _resolve_intervals_2dplot, _update_axes, get_axis, + import_matplotlib_pyplot, label_from_attrs, legend_elements, - plt, ) # copied from seaborn @@ -83,6 +83,8 @@ def _parse_size(data, norm, width): If the data is categorical, normalize it to numbers. """ + plt = import_matplotlib_pyplot() + if data is None: return None @@ -680,6 +682,8 @@ def scatter( **kwargs : optional Additional keyword arguments to matplotlib """ + plt = import_matplotlib_pyplot() + # Handle facetgrids first if row or col: allargs = locals().copy() @@ -1107,6 +1111,8 @@ def newplotfunc( allargs["plotfunc"] = globals()[plotfunc.__name__] return _easy_facetgrid(darray, kind="dataarray", **allargs) + plt = import_matplotlib_pyplot() + if ( plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index a49302f7f87..9e7e78f4c44 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -47,12 +47,6 @@ def import_matplotlib_pyplot(): return plt -try: - plt = import_matplotlib_pyplot() -except ImportError: - plt = None - - def _determine_extend(calc_data, vmin, vmax): extend_min = calc_data.min() < vmin extend_max = calc_data.max() > vmax @@ -70,7 +64,7 @@ def _build_discrete_cmap(cmap, levels, extend, filled): """ Build a discrete colormap and normalization of the data. """ - mpl = plt.matplotlib + import matplotlib as mpl if len(levels) == 1: levels = [levels[0], levels[0]] @@ -121,7 +115,8 @@ def _build_discrete_cmap(cmap, levels, extend, filled): def _color_palette(cmap, n_colors): - ListedColormap = plt.matplotlib.colors.ListedColormap + import matplotlib.pyplot as plt + from matplotlib.colors import ListedColormap colors_i = np.linspace(0, 1.0, n_colors) if isinstance(cmap, (list, tuple)): @@ -182,7 +177,7 @@ def _determine_cmap_params( cmap_params : dict Use depends on the type of the plotting function """ - mpl = plt.matplotlib + import matplotlib as mpl if isinstance(levels, Iterable): levels = sorted(levels) @@ -290,13 +285,13 @@ def _determine_cmap_params( levels = np.asarray([(vmin + vmax) / 2]) else: # N in MaxNLocator refers to bins, not ticks - ticker = plt.MaxNLocator(levels - 1) + ticker = mpl.ticker.MaxNLocator(levels - 1) levels = ticker.tick_values(vmin, vmax) vmin, vmax = levels[0], levels[-1] # GH3734 if vmin == vmax: - vmin, vmax = plt.LinearLocator(2).tick_values(vmin, vmax) + vmin, vmax = mpl.ticker.LinearLocator(2).tick_values(vmin, vmax) if extend is None: extend = _determine_extend(calc_data, vmin, vmax) @@ -426,7 +421,10 @@ def _assert_valid_xy(darray, xy, name): def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): - if plt is None: + try: + import matplotlib as mpl + import matplotlib.pyplot as plt + except ImportError: raise ImportError("matplotlib is required for plot.utils.get_axis") if figsize is not None: @@ -439,7 +437,7 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): if ax is not None: raise ValueError("cannot provide both `size` and `ax` arguments") if aspect is None: - width, height = plt.rcParams["figure.figsize"] + width, height = mpl.rcParams["figure.figsize"] aspect = width / height figsize = (size * aspect, size) _, ax = plt.subplots(figsize=figsize) @@ -456,6 +454,9 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): def _maybe_gca(**kwargs): + + import matplotlib.pyplot as plt + # can call gcf unconditionally: either it exists or would be created by plt.axes f = plt.gcf() @@ -913,7 +914,9 @@ def _process_cmap_cbar_kwargs( def _get_nice_quiver_magnitude(u, v): - ticker = plt.MaxNLocator(3) + import matplotlib as mpl + + ticker = mpl.ticker.MaxNLocator(3) mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy())) magnitude = ticker.tick_values(0, mean)[-2] return magnitude @@ -988,7 +991,7 @@ def legend_elements( """ import warnings - mpl = plt.matplotlib + import matplotlib as mpl mlines = mpl.lines @@ -1125,6 +1128,7 @@ def _legend_add_subtitle(handles, labels, text, func): def _adjust_legend_subtitles(legend): """Make invisible-handle "subtitles" entries look more like titles.""" + plt = import_matplotlib_pyplot() # Legend title not in rcParams until 3.0 font_size = plt.rcParams.get("legend.title_fontsize", None)