diff --git a/src/metpy/plots/declarative.py b/src/metpy/plots/declarative.py index a141a4e7747..be0ede8d1f5 100644 --- a/src/metpy/plots/declarative.py +++ b/src/metpy/plots/declarative.py @@ -56,13 +56,14 @@ def lookup_map_feature(feature_name): return feat.with_scale(scaler) -def plot_kwargs(data): +def plot_kwargs(data, args): """Set the keyword arguments for MapPanel plotting.""" if hasattr(data.metpy, 'cartopy_crs'): # Conditionally add cartopy transform if we are on a map. kwargs = {'transform': data.metpy.cartopy_crs} else: kwargs = {} + kwargs.update(args) return kwargs @@ -102,8 +103,19 @@ def __dir__(self): lambda name: not (name in dir(HasTraits) or name.startswith('_')), dir(type(self)) ) - - mpl_args = Union([Dict(), Int(), Float(), Unicode()]) + + mpl_args = Dict(allow_none=True) + mpl_args.__doc__ = """Supply a dictionary of valid Matplotlib keyword arguments to modify + how the plot variable is drawn. + + Using this attribute you must choose the appropriate keyword arguments (kwargs) based on + what you are plotting (e.g., contours, color-filled contours, image plot, etc.). This is + available for all plot types (ContourPlot, FilledContourPlot, RasterPlot, ImagePlot, + BarbPlot, ArrowPlot, PlotGeometry, and PlotObs). For PlotObs, the kwargs re those to + specify the StationPlot object. NOTE: Setting the mpl_args trait will override + any other trait that corresponds to a specific kwarg for the particular plot type + (e.g., linecolor, linewidth). + """ class Panel(MetPyHasTraits): @@ -938,20 +950,17 @@ def _build(self): """Build the plot by calling any plotting methods as necessary.""" x_like, y_like, imdata = self.plotdata - kwargs = plot_kwargs(imdata) + kwargs = plot_kwargs(imdata, self.mpl_args) # If we're on a map, we use min/max for y and manually figure out origin to try to # avoid upside down images created by images where y[0] > y[-1], as well as # specifying the transform kwargs['extent'] = (x_like[0], x_like[-1], y_like.min(), y_like.max()) kwargs['origin'] = 'upper' if y_like[0] > y_like[-1] else 'lower' + kwargs.setdefault('cmap', self._cmap_obj) + kwargs.setdefault('norm', self._norm_obj) - self.handle = self.parent.ax.imshow( - imdata, - cmap=self._cmap_obj, - norm=self._norm_obj, - **kwargs - ) + self.handle = self.parent.ax.imshow(imdata, **kwargs) @exporter.export @@ -995,11 +1004,12 @@ def _build(self): """Build the plot by calling any plotting methods as necessary.""" x_like, y_like, imdata = self.plotdata - kwargs = plot_kwargs(imdata) + kwargs = plot_kwargs(imdata, self.mpl_args) + kwargs.setdefault('linewidths', self.linewidth) + kwargs.setdefault('colors', self.linecolor) + kwargs.setdefault('linestyles', self.linestyle) - self.handle = self.parent.ax.contour(x_like, y_like, imdata, self.contours, - colors=self.linecolor, linewidths=self.linewidth, - linestyles=self.linestyle, **kwargs) + self.handle = self.parent.ax.contour(x_like, y_like, imdata, self.contours, **kwargs) if self.clabels: self.handle.clabel(inline=1, fmt='%.0f', inline_spacing=8, use_clabeltext=True, fontsize=self.label_fontsize) @@ -1020,11 +1030,11 @@ def _build(self): """Build the plot by calling any plotting methods as necessary.""" x_like, y_like, imdata = self.plotdata - kwargs = plot_kwargs(imdata) + kwargs = plot_kwargs(imdata, self.mpl_args) + kwargs.setdefault('cmap', self._cmap_obj) + kwargs.setdefault('norm', self._norm_obj) - self.handle = self.parent.ax.contourf(x_like, y_like, imdata, self.contours, - cmap=self._cmap_obj, norm=self._norm_obj, - **kwargs) + self.handle = self.parent.ax.contourf(x_like, y_like, imdata, self.contours, **kwargs) @exporter.export @@ -1042,11 +1052,11 @@ def _build(self): """Build the raster plot by calling any plotting methods as necessary.""" x_like, y_like, imdata = self.plotdata - kwargs = plot_kwargs(imdata) + kwargs = plot_kwargs(imdata, self.mpl_args) + kwargs.setdefault('cmap', self._cmap_obj) + kwargs.setdefault('norm', self._norm_obj) - self.handle = self.parent.ax.pcolormesh(x_like, y_like, imdata, - cmap=self._cmap_obj, norm=self._norm_obj, - **kwargs) + self.handle = self.parent.ax.pcolormesh(x_like, y_like, imdata, **kwargs) @exporter.export @@ -1221,7 +1231,11 @@ def _build(self): """Build the plot by calling needed plotting methods as necessary.""" x_like, y_like, u, v = self.plotdata - kwargs = plot_kwargs(u) + kwargs = plot_kwargs(u, self.mpl_args) + kwargs.setdefault('color', self.color) + kwargs.setdefault('pivot', self.pivot) + kwargs.setdefault('length', self.barblength) + kwargs.setdefault('zorder', 2) # Conditionally apply the proper transform if 'transform' in kwargs and self.earth_relative: @@ -1232,7 +1246,7 @@ def _build(self): self.handle = self.parent.ax.barbs( x_like[wind_slice], y_like[wind_slice], u.values[wind_slice], v.values[wind_slice], - color=self.color, pivot=self.pivot, length=self.barblength, zorder=2, **kwargs) + **kwargs) @exporter.export @@ -1283,7 +1297,10 @@ def _build(self): """Build the plot by calling needed plotting methods as necessary.""" x_like, y_like, u, v = self.plotdata - kwargs = plot_kwargs(u) + kwargs = plot_kwargs(u, self.mpl_args) + kwargs.setdefault('color', self.color) + kwargs.setdefault('pivot', self.pivot) + kwargs.setdefault('scale', self.arrowscale) # Conditionally apply the proper transform if 'transform' in kwargs and self.earth_relative: @@ -1294,7 +1311,7 @@ def _build(self): self.handle = self.parent.ax.quiver( x_like[wind_slice], y_like[wind_slice], u.values[wind_slice], v.values[wind_slice], - color=self.color, pivot=self.pivot, scale=self.arrowscale, **kwargs) + **kwargs) # The order here needs to match the order of the tuple if self.arrowkey is not None: @@ -1569,9 +1586,12 @@ def _build(self): scale = 1. if self.parent._proj_obj == ccrs.PlateCarree() else 100000. point_locs = self.parent._proj_obj.transform_points(ccrs.PlateCarree(), lon, lat) subset = reduce_point_density(point_locs, self.reduce_points * scale) + kwargs = self.mpl_args + kwargs.setdefault('clip_on', True) + kwargs.setdefault('transform', ccrs.PlateCarree()) + kwargs.setdefault('fontsize', self.fontsize) - self.handle = StationPlot(self.parent.ax, lon[subset], lat[subset], clip_on=True, - transform=ccrs.PlateCarree(), fontsize=self.fontsize) + self.handle = StationPlot(self.parent.ax, lon[subset], lat[subset], **kwargs) for i, ob_type in enumerate(self.fields): field_kwargs = {} @@ -1669,6 +1689,17 @@ class PlotGeometry(MetPyHasTraits): the sequence of colors as needed. Default value is black. """ + stroke_width = Union([Instance(collections.abc.Iterable), Float()], default_value=[1], + allow_none=True) + stroke_width.__doc__ = """Stroke width(s) for polygons and lines. + + A single integer or floating point value or collection of values representing the size of + the stroke width. If a collection, the first value corresponds to the first Shapely + object in `geometry`, the second value corresponds to the second Shapely object, and so on. + If `stroke_width` is shorter than `geometry`, `stroke_width` cycles back to the beginning, + repeating the sequence of values as needed. Default value is 1. + """ + marker = Unicode(default_value='.', allow_none=False) marker.__doc__ = """Symbol used to denote points. @@ -1847,27 +1878,38 @@ def _build(self): else self.label_edgecolor) self.label_facecolor = (['none'] if self.label_facecolor is None else self.label_facecolor) + kwargs = self.mpl_args # Each Shapely object is plotted separately with its corresponding colors and label - for geo_obj, stroke, fill, label, fontcolor, fontoutline in zip( - self.geometry, cycle(self.stroke), cycle(self.fill), cycle(self.labels), - cycle(self.label_facecolor), cycle(self.label_edgecolor)): + for geo_obj, stroke, strokewidth, fill, label, fontcolor, fontoutline in zip( + self.geometry, cycle(self.stroke), cycle(self.stroke_width), cycle(self.fill), + cycle(self.labels), cycle(self.label_facecolor), cycle(self.label_edgecolor)): # Plot the Shapely object with the appropriate method and colors if isinstance(geo_obj, (MultiPolygon, Polygon)): - self.parent.ax.add_geometries([geo_obj], edgecolor=stroke, - facecolor=fill, crs=ccrs.PlateCarree()) + kwargs.setdefault('edgecolor', stroke) + kwargs.setdefault('linewidths', strokewidth) + kwargs.setdefault('facecolor', fill) + kwargs.setdefault('crs', ccrs.PlateCarree()) + self.parent.ax.add_geometries([geo_obj], **kwargs) elif isinstance(geo_obj, (MultiLineString, LineString)): - self.parent.ax.add_geometries([geo_obj], edgecolor=stroke, - facecolor='none', crs=ccrs.PlateCarree()) + kwargs.setdefault('edgecolor', stroke) + kwargs.setdefault('linewidths', strokewidth) + kwargs.setdefault('facecolor', 'none') + kwargs.setdefault('crs', ccrs.PlateCarree()) + self.parent.ax.add_geometries([geo_obj], **kwargs) elif isinstance(geo_obj, MultiPoint): + kwargs.setdefault('color', fill) + kwargs.setdefault('marker', self.marker) + kwargs.setdefault('transform', ccrs.PlateCarree()) for point in geo_obj.geoms: lon, lat = point.coords[0] - self.parent.ax.plot(lon, lat, color=fill, marker=self.marker, - transform=ccrs.PlateCarree()) + self.parent.ax.plot(lon, lat, **kwargs) elif isinstance(geo_obj, Point): + kwargs.setdefault('color', fill) + kwargs.setdefault('marker', self.marker) + kwargs.setdefault('transform', ccrs.PlateCarree()) lon, lat = geo_obj.coords[0] - self.parent.ax.plot(lon, lat, color=fill, marker=self.marker, - transform=ccrs.PlateCarree()) + self.parent.ax.plot(lon, lat, **kwargs) # Plot labels if provided if label: diff --git a/tests/plots/baseline/test_colorfill_args.png b/tests/plots/baseline/test_colorfill_args.png new file mode 100644 index 00000000000..44c18a47db7 Binary files /dev/null and b/tests/plots/baseline/test_colorfill_args.png differ diff --git a/tests/plots/baseline/test_declarative_raster_options.png b/tests/plots/baseline/test_declarative_raster_options.png new file mode 100644 index 00000000000..a441f92a20d Binary files /dev/null and b/tests/plots/baseline/test_declarative_raster_options.png differ diff --git a/tests/plots/baseline/test_declarative_sfc_obs_args.png b/tests/plots/baseline/test_declarative_sfc_obs_args.png new file mode 100644 index 00000000000..0237f1addbd Binary files /dev/null and b/tests/plots/baseline/test_declarative_sfc_obs_args.png differ diff --git a/tests/plots/test_declarative.py b/tests/plots/test_declarative.py index 62daa8b8820..dff7de959c7 100644 --- a/tests/plots/test_declarative.py +++ b/tests/plots/test_declarative.py @@ -32,7 +32,7 @@ def test_declarative_image(): img = ImagePlot() img.data = data.metpy.parse_cf('IR') - img.colormap = 'Greys_r' + img.mpl_args = {'cmap': 'Greys_r'} panel = MapPanel() panel.title = 'Test' @@ -376,7 +376,7 @@ def test_declarative_layers_plot_options(): contour.level = 700 * units.hPa contour.contours = 5 contour.linewidth = 1 - contour.linecolor = 'grey' + contour.mpl_args = {'colors': 'grey'} panel = MapPanel() panel.area = 'us' @@ -615,33 +615,6 @@ def test_colorfill(): return pc.figure -@pytest.mark.mpl_image_compare(remove_text=True, tolerance=0.009) -@needs_cartopy -def test_colorfill_args(): - """Test that we can use ContourFillPlot.""" - data = xr.open_dataset(get_test_data('narr_example.nc', as_file_obj=False)) - - contour = FilledContourPlot() - contour.data = data - contour.level = 700 * units.hPa - contour.field = 'Temperature' - contour.colormap = 'coolwarm' - contour.colorbar = 'vertical' - contour.mpl_args = {'alpha': 0.6} - - panel = MapPanel() - panel.area = (-110, -60, 25, 55) - panel.layers = [] - panel.plots = [contour] - - pc = PanelContainer() - pc.panel = panel - pc.size = (12, 8) - pc.draw() - - return pc.figure - - @pytest.mark.mpl_image_compare(remove_text=True, tolerance=0.02) def test_colorfill_with_image_range(cfeature): """Test that we can use ContourFillPlot with image_range bounds.""" @@ -812,7 +785,8 @@ def test_declarative_barb_options(): barb.field = ['u_wind', 'v_wind'] barb.skip = (10, 10) barb.color = 'blue' - barb.pivot = 'tip' + barb.pivot = 'middle' + barb.mpl_args = {'pivot': 'tip'} barb.barblength = 6.5 panel = MapPanel() @@ -841,7 +815,8 @@ def test_declarative_arrowplot(): arrows.field = ['u_wind', 'v_wind'] arrows.skip = (10, 10) arrows.color = 'blue' - arrows.pivot = 'mid' + arrows.pivot = 'tip' + arrows.mpl_args = {'pivot': 'mid'} arrows.arrowscale = 1000 panel = MapPanel() @@ -1314,6 +1289,39 @@ def test_declarative_sfc_obs(ccrs): return pc.figure +@pytest.mark.mpl_image_compare(remove_text=True, tolerance=0.025) +def test_declarative_sfc_obs_args(ccrs): + """Test making a surface observation plot with mpl arguments.""" + data = pd.read_csv(get_test_data('SFC_obs.csv', as_file_obj=False), + infer_datetime_format=True, parse_dates=['valid']) + + obs = PlotObs() + obs.data = data + obs.time = datetime(1993, 3, 12, 12) + obs.time_window = timedelta(minutes=15) + obs.level = None + obs.fields = ['tmpf'] + obs.colors = ['black'] + obs.mpl_args = {'fontsize': 12} + + # Panel for plot with Map features + panel = MapPanel() + panel.layout = (1, 1, 1) + panel.projection = ccrs.PlateCarree() + panel.area = 'in' + panel.layers = ['states'] + panel.plots = [obs] + + # Bringing it all together + pc = PanelContainer() + pc.size = (10, 10) + pc.panels = [panel] + + pc.draw() + + return pc.figure + + @pytest.mark.mpl_image_compare(remove_text=True, tolerance=0.016) @needs_cartopy def test_declarative_sfc_text(): @@ -1815,6 +1823,33 @@ def test_declarative_raster(): return pc.figure +@pytest.mark.mpl_image_compare(remove_text=True, tolerance=0.02) +@needs_cartopy +def test_declarative_raster_options(): + """Test making a raster plot.""" + data = xr.open_dataset(get_test_data('narr_example.nc', as_file_obj=False)) + + raster = RasterPlot() + raster.data = data + raster.colormap = 'viridis' + raster.field = 'Temperature' + raster.level = 700 * units.hPa + raster.mpl_args = {'alpha': 1, 'cmap': 'coolwarm'} + + panel = MapPanel() + panel.area = 'us' + panel.projection = 'lcc' + panel.layers = ['coastline'] + panel.plots = [raster] + + pc = PanelContainer() + pc.size = (8.0, 8) + pc.panels = [panel] + pc.draw() + + return pc.figure + + @pytest.mark.mpl_image_compare(remove_text=True, tolerance=0.607) @needs_cartopy def test_declarative_region_modifier_zoom_in(): @@ -1975,6 +2010,7 @@ def test_declarative_plot_geometry_polygons(): geo = PlotGeometry() geo.geometry = [slgt_risk_polygon, enh_risk_polygon] geo.stroke = ['#DDAA00', '#FF6600'] + geo.stroke_width = [1] geo.fill = None geo.labels = ['SLGT', 'ENH'] geo.label_facecolor = ['#FFE066', '#FFA366'] @@ -2019,6 +2055,7 @@ def test_declarative_plot_geometry_lines(ccrs): geo.stroke = 'green' geo.labels = ['Irma', '+/- 0.25 deg latitude'] geo.label_facecolor = None + geo.mpl_args = {'linewidth': 1} # Place plot in a panel and container panel = MapPanel()