Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Single matplotlib import" #6064

Merged
merged 2 commits into from
Dec 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions asv_bench/benchmarks/import_xarray.py

This file was deleted.

12 changes: 8 additions & 4 deletions xarray/plot/dataset_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
_process_cmap_cbar_kwargs,
get_axis,
label_from_attrs,
plt,
)

# copied from seaborn
Expand Down Expand Up @@ -135,7 +134,8 @@ def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None)

# copied from seaborn
def _parse_size(data, norm):
mpl = plt.matplotlib

import matplotlib as mpl

if data is None:
return None
Expand Down Expand Up @@ -544,6 +544,8 @@ def quiver(ds, x, y, ax, u, v, **kwargs):

Wraps :py:func:`matplotlib:matplotlib.pyplot.quiver`.
"""
import matplotlib as mpl

if x is None or y is None or u is None or v is None:
raise ValueError("Must specify x, y, u, v for quiver plots.")

Expand All @@ -558,7 +560,7 @@ def quiver(ds, x, y, ax, u, v, **kwargs):

# TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
if not cmap_params["norm"]:
cmap_params["norm"] = plt.Normalize(
cmap_params["norm"] = mpl.colors.Normalize(
cmap_params.pop("vmin"), cmap_params.pop("vmax")
)

Expand All @@ -574,6 +576,8 @@ def streamplot(ds, x, y, ax, u, v, **kwargs):

Wraps :py:func:`matplotlib:matplotlib.pyplot.streamplot`.
"""
import matplotlib as mpl

if x is None or y is None or u is None or v is None:
raise ValueError("Must specify x, y, u, v for streamplot plots.")

Expand Down Expand Up @@ -609,7 +613,7 @@ def streamplot(ds, x, y, ax, u, v, **kwargs):

# TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
if not cmap_params["norm"]:
cmap_params["norm"] = plt.Normalize(
cmap_params["norm"] = mpl.colors.Normalize(
cmap_params.pop("vmin"), cmap_params.pop("vmax")
)

Expand Down
10 changes: 8 additions & 2 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
_get_nice_quiver_magnitude,
_infer_xy_labels,
_process_cmap_cbar_kwargs,
import_matplotlib_pyplot,
label_from_attrs,
plt,
)

# Overrides axes.labelsize, xtick.major.size, ytick.major.size
Expand Down Expand Up @@ -116,6 +116,8 @@ def __init__(

"""

plt = import_matplotlib_pyplot()

# Handle corner case of nonunique coordinates
rep_col = col is not None and not data[col].to_index().is_unique
rep_row = row is not None and not data[row].to_index().is_unique
Expand Down Expand Up @@ -517,8 +519,10 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwar
self: FacetGrid object

"""
import matplotlib as mpl

if size is None:
size = plt.rcParams["axes.labelsize"]
size = mpl.rcParams["axes.labelsize"]

nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template)

Expand Down Expand Up @@ -615,6 +619,8 @@ def map(self, func, *args, **kwargs):
self : FacetGrid object

"""
plt = import_matplotlib_pyplot()

for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
if namedict is not None:
data = self.data.loc[namedict]
Expand Down
8 changes: 7 additions & 1 deletion xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
_resolve_intervals_2dplot,
_update_axes,
get_axis,
import_matplotlib_pyplot,
label_from_attrs,
legend_elements,
plt,
)

# copied from seaborn
Expand Down Expand Up @@ -83,6 +83,8 @@ def _parse_size(data, norm, width):

If the data is categorical, normalize it to numbers.
"""
plt = import_matplotlib_pyplot()

if data is None:
return None

Expand Down Expand Up @@ -680,6 +682,8 @@ def scatter(
**kwargs : optional
Additional keyword arguments to matplotlib
"""
plt = import_matplotlib_pyplot()

# Handle facetgrids first
if row or col:
allargs = locals().copy()
Expand Down Expand Up @@ -1107,6 +1111,8 @@ def newplotfunc(
allargs["plotfunc"] = globals()[plotfunc.__name__]
return _easy_facetgrid(darray, kind="dataarray", **allargs)

plt = import_matplotlib_pyplot()

if (
plotfunc.__name__ == "surface"
and not kwargs.get("_is_facetgrid", False)
Expand Down
34 changes: 19 additions & 15 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@ def import_matplotlib_pyplot():
return plt


try:
plt = import_matplotlib_pyplot()
except ImportError:
plt = None


def _determine_extend(calc_data, vmin, vmax):
extend_min = calc_data.min() < vmin
extend_max = calc_data.max() > vmax
Expand All @@ -70,7 +64,7 @@ def _build_discrete_cmap(cmap, levels, extend, filled):
"""
Build a discrete colormap and normalization of the data.
"""
mpl = plt.matplotlib
import matplotlib as mpl

if len(levels) == 1:
levels = [levels[0], levels[0]]
Expand Down Expand Up @@ -121,7 +115,8 @@ def _build_discrete_cmap(cmap, levels, extend, filled):


def _color_palette(cmap, n_colors):
ListedColormap = plt.matplotlib.colors.ListedColormap
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

colors_i = np.linspace(0, 1.0, n_colors)
if isinstance(cmap, (list, tuple)):
Expand Down Expand Up @@ -182,7 +177,7 @@ def _determine_cmap_params(
cmap_params : dict
Use depends on the type of the plotting function
"""
mpl = plt.matplotlib
import matplotlib as mpl

if isinstance(levels, Iterable):
levels = sorted(levels)
Expand Down Expand Up @@ -290,13 +285,13 @@ def _determine_cmap_params(
levels = np.asarray([(vmin + vmax) / 2])
else:
# N in MaxNLocator refers to bins, not ticks
ticker = plt.MaxNLocator(levels - 1)
ticker = mpl.ticker.MaxNLocator(levels - 1)
levels = ticker.tick_values(vmin, vmax)
vmin, vmax = levels[0], levels[-1]

# GH3734
if vmin == vmax:
vmin, vmax = plt.LinearLocator(2).tick_values(vmin, vmax)
vmin, vmax = mpl.ticker.LinearLocator(2).tick_values(vmin, vmax)

if extend is None:
extend = _determine_extend(calc_data, vmin, vmax)
Expand Down Expand Up @@ -426,7 +421,10 @@ def _assert_valid_xy(darray, xy, name):


def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
if plt is None:
try:
import matplotlib as mpl
import matplotlib.pyplot as plt
except ImportError:
raise ImportError("matplotlib is required for plot.utils.get_axis")

if figsize is not None:
Expand All @@ -439,7 +437,7 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
if ax is not None:
raise ValueError("cannot provide both `size` and `ax` arguments")
if aspect is None:
width, height = plt.rcParams["figure.figsize"]
width, height = mpl.rcParams["figure.figsize"]
aspect = width / height
figsize = (size * aspect, size)
_, ax = plt.subplots(figsize=figsize)
Expand All @@ -456,6 +454,9 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):


def _maybe_gca(**kwargs):

import matplotlib.pyplot as plt

# can call gcf unconditionally: either it exists or would be created by plt.axes
f = plt.gcf()

Expand Down Expand Up @@ -913,7 +914,9 @@ def _process_cmap_cbar_kwargs(


def _get_nice_quiver_magnitude(u, v):
ticker = plt.MaxNLocator(3)
import matplotlib as mpl

ticker = mpl.ticker.MaxNLocator(3)
mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy()))
magnitude = ticker.tick_values(0, mean)[-2]
return magnitude
Expand Down Expand Up @@ -988,7 +991,7 @@ def legend_elements(
"""
import warnings

mpl = plt.matplotlib
import matplotlib as mpl

mlines = mpl.lines

Expand Down Expand Up @@ -1125,6 +1128,7 @@ def _legend_add_subtitle(handles, labels, text, func):

def _adjust_legend_subtitles(legend):
"""Make invisible-handle "subtitles" entries look more like titles."""
plt = import_matplotlib_pyplot()

# Legend title not in rcParams until 3.0
font_size = plt.rcParams.get("legend.title_fontsize", None)
Expand Down