From 6130a4b707bf3c92d119785c56f233291a4af698 Mon Sep 17 00:00:00 2001 From: patrickleonardy Date: Fri, 3 Nov 2023 14:32:54 +0100 Subject: [PATCH] make plot_fn return the graphs and not show them --- cobra/evaluation/evaluator.py | 34 ++++++++++++++++++------------ cobra/evaluation/pigs_tables.py | 6 +++--- cobra/evaluation/plotting_utils.py | 18 ++++++++++------ 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/cobra/evaluation/evaluator.py b/cobra/evaluation/evaluator.py index e196bc5..fef2e3d 100644 --- a/cobra/evaluation/evaluator.py +++ b/cobra/evaluation/evaluator.py @@ -158,7 +158,7 @@ def _compute_scalar_metrics(y_true: np.ndarray, lift_at=lift_at), 2) }) - def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8)): + def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure: """Plot ROC curve of the model. Parameters @@ -197,8 +197,8 @@ def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8)): if path: plt.savefig(path, format="png", dpi=300, bbox_inches="tight") - - plt.show() + plt.close() + return fig def plot_confusion_matrix(self, path: str=None, dim: tuple=(12, 8), labels: list=["0", "1"]): @@ -232,9 +232,10 @@ def plot_confusion_matrix(self, path: str=None, dim: tuple=(12, 8), if path: plt.savefig(path, format="png", dpi=300, bbox_inches="tight") - plt.show() + plt.close() + return fig - def plot_cumulative_response_curve(self, path: str=None, dim: tuple=(12, 8)): + def plot_cumulative_response_curve(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure: """Plot cumulative response curve. Parameters @@ -283,9 +284,10 @@ def plot_cumulative_response_curve(self, path: str=None, dim: tuple=(12, 8)): if path is not None: plt.savefig(path, format="png", dpi=300, bbox_inches="tight") - plt.show() + plt.close() + return fig - def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)): + def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure: """Plot lift per decile. Parameters @@ -332,9 +334,10 @@ def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)): if path is not None: plt.savefig(path, format="png", dpi=300, bbox_inches="tight") - plt.show() + plt.close() + return fig - def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8)): + def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure: """Plot cumulative gains per decile. Parameters @@ -376,7 +379,8 @@ def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8)): if path is not None: plt.savefig(path, format="png", dpi=300, bbox_inches="tight") - plt.show() + plt.close() + return fig @staticmethod def _find_optimal_cutoff(y_true: np.ndarray, @@ -658,7 +662,7 @@ def _compute_qq_residuals(y_true: np.ndarray, "residuals": df["z_res"].values, }) - def plot_predictions(self, path: str=None, dim: tuple=(12, 8)): + def plot_predictions(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure: """Plot predictions from the model against actual values. Parameters @@ -692,9 +696,10 @@ def plot_predictions(self, path: str=None, dim: tuple=(12, 8)): if path: plt.savefig(path, format="png", dpi=300, bbox_inches="tight") - plt.show() + plt.close() + return fig - def plot_qq(self, path: str=None, dim: tuple=(12, 8)): + def plot_qq(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure: """Display a Q-Q plot from the standardized prediction residuals. Parameters @@ -733,4 +738,5 @@ def plot_qq(self, path: str=None, dim: tuple=(12, 8)): if path: plt.savefig(path, format="png", dpi=300, bbox_inches="tight") - plt.show() \ No newline at end of file + plt.close() + return fig \ No newline at end of file diff --git a/cobra/evaluation/pigs_tables.py b/cobra/evaluation/pigs_tables.py index 583f4e6..c8d5826 100644 --- a/cobra/evaluation/pigs_tables.py +++ b/cobra/evaluation/pigs_tables.py @@ -107,7 +107,7 @@ def plot_incidence(pig_tables: pd.DataFrame, variable: str, model_type: str, column_order: list=None, - dim: tuple=(12, 8)): + dim: tuple=(12, 8)) -> plt.Figure: """Plots a Predictor Insights Graph (PIG), a graph in which the mean target value is plotted for a number of bins constructed from a predictor variable. When the target is a binary classification target, @@ -257,5 +257,5 @@ def plot_incidence(pig_tables: pd.DataFrame, plt.tight_layout() plt.margins(0.01) - # Show - plt.show() + plt.close() + return fig diff --git a/cobra/evaluation/plotting_utils.py b/cobra/evaluation/plotting_utils.py index 3a51f62..8703e34 100644 --- a/cobra/evaluation/plotting_utils.py +++ b/cobra/evaluation/plotting_utils.py @@ -8,7 +8,7 @@ def plot_univariate_predictor_quality(df_metric: pd.DataFrame, dim: tuple=(12, 8), - path: str=None): + path: str=None) -> plt.Figure: """Plot univariate quality of the predictors. Parameters @@ -58,7 +58,8 @@ def plot_univariate_predictor_quality(df_metric: pd.DataFrame, plt.gca().legend().set_title("") - plt.show() + plt.close() + return fig def plot_correlation_matrix(df_corr: pd.DataFrame, dim: tuple=(12, 8), @@ -81,7 +82,8 @@ def plot_correlation_matrix(df_corr: pd.DataFrame, if path is not None: plt.savefig(path, format="png", dpi=300, bbox_inches="tight") - plt.show() + plt.close() + return fig def plot_performance_curves(model_performance: pd.DataFrame, dim: tuple=(12, 8), @@ -89,7 +91,7 @@ def plot_performance_curves(model_performance: pd.DataFrame, colors: dict={"train": "#0099bf", "selection": "#ff9500", "validation": "#8064a2"}, - metric_name: str=None): + metric_name: str=None) -> plt.Figure: """Plot performance curves generated by the forward feature selection for the train-selection-validation sets. @@ -159,12 +161,13 @@ def plot_performance_curves(model_performance: pd.DataFrame, if path is not None: plt.savefig(path, format="png", dpi=300, bbox_inches="tight") - plt.show() + plt.close() + return fig def plot_variable_importance(df_variable_importance: pd.DataFrame, title: str=None, dim: tuple=(12, 8), - path: str=None): + path: str=None) -> plt.Figure: """Plot variable importance of a given model. Parameters @@ -199,4 +202,5 @@ def plot_variable_importance(df_variable_importance: pd.DataFrame, if path is not None: plt.savefig(path, format="png", dpi=300, bbox_inches="tight") - plt.show() + plt.close() + return fig