diff --git a/blobmodel/plotting.py b/blobmodel/plotting.py index f26a7e8..1c66076 100644 --- a/blobmodel/plotting.py +++ b/blobmodel/plotting.py @@ -5,14 +5,14 @@ import numpy as np import xarray as xr from matplotlib import animation +from typing import Union def show_model( dataset: xr.Dataset, variable: str = "n", interval: int = 100, - save: bool = False, - gif_name: str = "blobs.gif", + gif_name: Union[str, None] = None, fps: int = 10, ) -> None: """ @@ -26,10 +26,8 @@ def show_model( Variable to be animated (default: "n"). interval : int, optional Time interval between frames in milliseconds (default: 100). - save : bool, optional - If True, save the animation as a GIF (default: False). gif_name : str, optional - Set the name for the saved GIF (default: "blobs.gif"). + If not None, save the animation as a GIF and name it acoridingly. fps : int, optional Set the frames per second for the saved GIF (default: 10). @@ -103,7 +101,7 @@ def animate_2d(i: int) -> None: fig, animate_2d, frames=dataset["t"].values.size, interval=interval ) - if save: + if gif_name: ani.save(gif_name, writer="ffmpeg", fps=fps) plt.show() diff --git a/examples/1d_animation.py b/examples/1d_animation.py index 6ae010f..352c181 100644 --- a/examples/1d_animation.py +++ b/examples/1d_animation.py @@ -17,4 +17,4 @@ ) ds = bm.make_realization(speed_up=True, error=1e-2) -show_model(dataset=ds, interval=100, save=True) +show_model(dataset=ds, interval=100, gif_name="1d_animation.gif") diff --git a/examples/2d_animation.py b/examples/2d_animation.py index 4952886..ebb0a1e 100644 --- a/examples/2d_animation.py +++ b/examples/2d_animation.py @@ -18,4 +18,4 @@ # create data ds = bm.make_realization(speed_up=True, error=1e-2) # show animation and save as gif -show_model(dataset=ds, interval=100, save=True, gif_name="example.gif", fps=10) +show_model(dataset=ds, interval=100, gif_name="2d_animation.gif", fps=10) diff --git a/examples/compare_to_analytical_sol.py b/examples/compare_to_analytical_sol.py index 0af9088..fe8946b 100644 --- a/examples/compare_to_analytical_sol.py +++ b/examples/compare_to_analytical_sol.py @@ -21,7 +21,7 @@ blob_factory=bf, ) -ds = tmp.make_realization(file_name="profile_comparison.nc", speed_up=True, error=1e-2) +ds = tmp.make_realization(file_name="profile_comparison.nc", speed_up=False, error=1e-4) def plot_convergence_to_analytical_solution(ds): diff --git a/examples/single_blob.py b/examples/single_blob.py index fac8248..2d5879c 100644 --- a/examples/single_blob.py +++ b/examples/single_blob.py @@ -66,4 +66,4 @@ def is_one_dimensional(self) -> bool: ds = bm.make_realization(speed_up=True, error=1e-2) # show animation and save as gif -show_model(dataset=ds, interval=100, save=True, gif_name="example.gif", fps=10) +show_model(dataset=ds, interval=100, gif_name="example.gif", fps=10) diff --git a/tests/test_show_model.py b/tests/test_show_model.py index 7f3a326..4e78741 100644 --- a/tests/test_show_model.py +++ b/tests/test_show_model.py @@ -23,7 +23,7 @@ @patch("matplotlib.pyplot.show") def test_plot_2d(mock_show): warnings.filterwarnings("ignore") - show_model(dataset=ds_2d, interval=100, save=False, fps=10) + show_model(dataset=ds_2d, interval=100, gif_name=None, fps=10) bm_1d = Model( @@ -46,4 +46,4 @@ def test_plot_2d(mock_show): @patch("matplotlib.pyplot.show") def test_plot_1d(mock_show): warnings.filterwarnings("ignore") - show_model(dataset=ds_1d, interval=100, save=False, fps=10) + show_model(dataset=ds_1d, interval=100, gif_name=None, fps=10)