diff --git a/pygmt/helpers/testing.py b/pygmt/helpers/testing.py index 2b57b8e37e1..a398f3c30d5 100644 --- a/pygmt/helpers/testing.py +++ b/pygmt/helpers/testing.py @@ -172,6 +172,10 @@ def download_test_data(): # Other cache files "@EGM96_to_36.txt", "@MaunaLoa_CO2.txt", + "@RidgeTest.shp", + "@RidgeTest.shx", + "@RidgeTest.dbf", + "@RidgeTest.prj", "@Table_5_11.txt", "@Table_5_11_mean.xyz", "@fractures_06.txt", diff --git a/pygmt/src/plot.py b/pygmt/src/plot.py index 5f627383517..8345c065916 100644 --- a/pygmt/src/plot.py +++ b/pygmt/src/plot.py @@ -229,15 +229,13 @@ def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs): and data.geom_type.isin(["Point", "MultiPoint"]).all() ): # checking if the geometry of a geoDataFrame is Point or MultiPoint kwargs["S"] = "s0.2c" - elif ( - "S" not in kwargs and kind == "file" - ): # checking that the data is a file path to set default style + elif "S" not in kwargs and kind == "file" and data.endswith(".gmt"): + # checking that the data is a file path to set default style try: with open(which(data), mode="r", encoding="utf8") as file: line = file.readline() - if ( - "@GMULTIPOINT" in line or "@GPOINT" in line - ): # if the file is gmt style and geometry is set to Point + if "@GMULTIPOINT" in line or "@GPOINT" in line: + # if the file is gmt style and geometry is set to Point kwargs["S"] = "s0.2c" except FileNotFoundError: pass diff --git a/pygmt/src/plot3d.py b/pygmt/src/plot3d.py index 27164eda470..42fc2faf143 100644 --- a/pygmt/src/plot3d.py +++ b/pygmt/src/plot3d.py @@ -199,15 +199,13 @@ def plot3d( and data.geom_type.isin(["Point", "MultiPoint"]).all() ): # checking if the geometry of a geoDataFrame is Point or MultiPoint kwargs["S"] = "u0.2c" - elif ( - "S" not in kwargs and kind == "file" - ): # checking that the data is a file path to set default style + elif "S" not in kwargs and kind == "file" and data.endswith(".gmt"): + # checking that the data is a file path to set default style try: with open(which(data), mode="r", encoding="utf8") as file: line = file.readline() - if ( - "@GMULTIPOINT" in line or "@GPOINT" in line - ): # if the file is gmt style and geometry is set to Point + if "@GMULTIPOINT" in line or "@GPOINT" in line: + # if the file is gmt style and geometry is set to Point kwargs["S"] = "u0.2c" except FileNotFoundError: pass diff --git a/pygmt/tests/baseline/test_plot_shapefile.png.dvc b/pygmt/tests/baseline/test_plot_shapefile.png.dvc new file mode 100644 index 00000000000..c2a13316351 --- /dev/null +++ b/pygmt/tests/baseline/test_plot_shapefile.png.dvc @@ -0,0 +1,4 @@ +outs: +- md5: 75277741d098cf7a0bad7869b574afc9 + size: 24178 + path: test_plot_shapefile.png diff --git a/pygmt/tests/test_plot.py b/pygmt/tests/test_plot.py index fb46a872c0b..629935a047f 100644 --- a/pygmt/tests/test_plot.py +++ b/pygmt/tests/test_plot.py @@ -543,6 +543,17 @@ def test_plot_ogrgmt_file_multipoint_non_default_style(): @pytest.mark.mpl_image_compare +def test_plot_shapefile(): + """ + Make sure that plot works for shapefile. + + See https://github.com/GenericMappingTools/pygmt/issues/1616. + """ + fig = Figure() + fig.plot(data="@RidgeTest.shp", pen="1p") + return fig + + def test_plot_dataframe_incols(): """ Make sure that the incols parameter works for pandas.DataFrame.