diff --git a/src/sisl/viz/figure/figure.py b/src/sisl/viz/figure/figure.py index 9fa7ec1cf..239fb3291 100644 --- a/src/sisl/viz/figure/figure.py +++ b/src/sisl/viz/figure/figure.py @@ -68,6 +68,17 @@ def _build(self, plot_actions, *args, **kwargs): return fig + @classmethod + def fig_has_attr(cls, key: str) -> bool: + """Whether the figure that this class generates has a given attribute. + + Parameters + ----------- + key + the attribute to check for. + """ + return False + @staticmethod def _sanitize_plot_actions(plot_actions): def _flatten(plot_actions, out, level=0, root_i=0): diff --git a/src/sisl/viz/figure/matplotlib.py b/src/sisl/viz/figure/matplotlib.py index 0cf4fe73f..6ea99ded6 100644 --- a/src/sisl/viz/figure/matplotlib.py +++ b/src/sisl/viz/figure/matplotlib.py @@ -195,9 +195,16 @@ def _iter_multiaxis(self, plot_actions): yield sanitized_section_actions + @classmethod + def fig_has_attr(cls, key: str) -> bool: + return hasattr(plt.Axes, key) or hasattr(plt.Figure, key) + def __getattr__(self, key): if key != "axes": - return getattr(self.axes, key) + if hasattr(self.axes, key): + return getattr(self.axes, key) + elif key != "figure" and hasattr(self.figure, key): + return getattr(self.figure, key) raise AttributeError(key) def clear(self, layout=False): diff --git a/src/sisl/viz/figure/plotly.py b/src/sisl/viz/figure/plotly.py index 306010e4f..24a4bf8a6 100644 --- a/src/sisl/viz/figure/plotly.py +++ b/src/sisl/viz/figure/plotly.py @@ -402,6 +402,11 @@ def _iter_animation(self, plot_actions): self.update_layout(sliders=[slider], updatemenus=updatemenus) + @classmethod + def fig_has_attr(cls, key: str) -> bool: + print(key, hasattr(go.Figure, key)) + return hasattr(go.Figure, key) + def __getattr__(self, key): if key != "figure": return getattr(self.figure, key) diff --git a/src/sisl/viz/plot.py b/src/sisl/viz/plot.py index 5c16c0b75..f8eb357e8 100644 --- a/src/sisl/viz/plot.py +++ b/src/sisl/viz/plot.py @@ -6,25 +6,31 @@ from sisl.messages import deprecate from sisl.nodes import Workflow +from .figure import BACKENDS + class Plot(Workflow): """Base class for all plots""" def __getattr__(self, key): if key != "nodes": - # If an ipython key is requested, get the plot and look - # for the key in the plot. This is simply to enhance - # interactivity in a python notebook environment. - # However, this results in a (maybe undesired) behavior: - # The plot is updated when ipython requests it, without any - # explicit request to update it. This is how it has worked - # from the beggining, so it's probably best to keep it like - # this for now. - if "ipython" in key: - output = self.nodes.output.get() + # From the backend input, we find out which class is the figure going to be + # (even if no figure has been created yet or the latest figure was from a different backend) + # Then we check if the attribute will be available there. If it will, we update the plot and + # get the attribute on the updated plot. + # This is so that things like `plot.show()` work as expected. + # It has the downside that `.get()` is called even when for example a method of the figure is + # retreived to get its docs (e.g. in the helper messages of jupyter notebooks) + selected_backend = self.inputs.get("backend") + figure_cls = BACKENDS.get(selected_backend) + if figure_cls is not None and ( + hasattr(figure_cls, key) or figure_cls.fig_has_attr(key) + ): + return getattr(self.nodes.output.get(), key) else: - output = self.nodes.output._output - return getattr(output, key) + raise AttributeError( + f"'{key}' not found in {self.__class__.__name__} with backend '{selected_backend}'" + ) else: return super().__getattr__(key)