From e31aa1cc7f8c6e8e3332369519aed7dfd6b1f15b Mon Sep 17 00:00:00 2001 From: Adam Gayoso Date: Mon, 4 Nov 2024 08:01:59 +0000 Subject: [PATCH] Change behavior of None default metrics and allow results/plotting to work with bio/batch separately. (#181) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 4 +++- docs/notebooks/large_scale.ipynb | 3 ++- docs/notebooks/lung_example.ipynb | 4 +++- src/scib_metrics/benchmark/_core.py | 33 ++++++++++++++++++++--------- tests/test_benchmarker.py | 22 ++++++++++++++++--- 5 files changed, 50 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e377d9..1e54044 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,10 +13,12 @@ and this project adheres to [Semantic Versioning][]. ### Added - Add `progress_bar` argument to {class}`scib_metrics.benchmark.Benchmarker` {pr}`152`. +- Add ability of {class}`scib_metrics.benchmark.Benchmarker` plotting code to handle missing sets of metrics {pr}`181`. ### Changed -- Leiden clustering now has a seed argument for reproducibility {pr}`173`. +- Changed Leiden clustering to accept a seed argument for reproducibility {pr}`173`. +- Changed passing `None` to `bio_conservation_metrics` or `batch_correction_metrics` in {class}`scib_metrics.benchmark.Benchmarker` to now skip that set of metrics {pr}`181`. 
### Fixed diff --git a/docs/notebooks/large_scale.ipynb b/docs/notebooks/large_scale.ipynb index 46e9743..2bae6b1 100644 --- a/docs/notebooks/large_scale.ipynb +++ b/docs/notebooks/large_scale.ipynb @@ -30,7 +30,7 @@ "import scanpy as sc\n", "from scvi.data import cellxgene\n", "\n", - "from scib_metrics.benchmark import Benchmarker, BioConservation\n", + "from scib_metrics.benchmark import Benchmarker, BioConservation, BatchCorrection\n", "\n", "%matplotlib inline" ] @@ -294,6 +294,7 @@ " embedding_obsm_keys=[\"Unintegrated\", \"scANVI\", \"scVI\"],\n", " pre_integrated_embedding_obsm_key=\"X_pca\",\n", " bio_conservation_metrics=biocons,\n", + " batch_correction_metrics=BatchCorrection(),\n", " n_jobs=-1,\n", ")\n", "bm.prepare(neighbor_computer=faiss_brute_force_nn)\n", diff --git a/docs/notebooks/lung_example.ipynb b/docs/notebooks/lung_example.ipynb index 73e4f9d..262f10b 100644 --- a/docs/notebooks/lung_example.ipynb +++ b/docs/notebooks/lung_example.ipynb @@ -25,7 +25,7 @@ "import numpy as np\n", "import scanpy as sc\n", "\n", - "from scib_metrics.benchmark import Benchmarker\n", + "from scib_metrics.benchmark import Benchmarker, BioConservation, BatchCorrection\n", "\n", "%matplotlib inline" ] @@ -271,6 +271,8 @@ " adata,\n", " batch_key=\"batch\",\n", " label_key=\"cell_type\",\n", + " bio_conservation_metrics=BioConservation(),\n", + " batch_correction_metrics=BatchCorrection(),\n", " embedding_obsm_keys=[\"Unintegrated\", \"Scanorama\", \"LIGER\", \"Harmony\", \"scVI\", \"scANVI\"],\n", " n_jobs=6,\n", ")\n", diff --git a/src/scib_metrics/benchmark/_core.py b/src/scib_metrics/benchmark/_core.py index dcfc259..6c6d427 100644 --- a/src/scib_metrics/benchmark/_core.py +++ b/src/scib_metrics/benchmark/_core.py @@ -136,8 +136,8 @@ def __init__( batch_key: str, label_key: str, embedding_obsm_keys: list[str], - bio_conservation_metrics: BioConservation | None = None, - batch_correction_metrics: BatchCorrection | None = None, + bio_conservation_metrics: 
BioConservation | None, + batch_correction_metrics: BatchCorrection | None, pre_integrated_embedding_obsm_key: str | None = None, n_jobs: int = 1, progress_bar: bool = True, @@ -145,8 +145,8 @@ def __init__( self._adata = adata self._embedding_obsm_keys = embedding_obsm_keys self._pre_integrated_embedding_obsm_key = pre_integrated_embedding_obsm_key - self._bio_conservation_metrics = bio_conservation_metrics if bio_conservation_metrics else BioConservation() - self._batch_correction_metrics = batch_correction_metrics if batch_correction_metrics else BatchCorrection() + self._bio_conservation_metrics = bio_conservation_metrics + self._batch_correction_metrics = batch_correction_metrics self._results = pd.DataFrame(columns=list(self._embedding_obsm_keys) + [_METRIC_TYPE]) self._emb_adatas = {} self._neighbor_values = (15, 50, 90) @@ -157,10 +157,14 @@ def __init__( self._n_jobs = n_jobs self._progress_bar = progress_bar - self._metric_collection_dict = { - "Bio conservation": self._bio_conservation_metrics, - "Batch correction": self._batch_correction_metrics, - } + if self._bio_conservation_metrics is None and self._batch_correction_metrics is None: + raise ValueError("Either batch or bio metrics must be defined.") + + self._metric_collection_dict = {} + if self._bio_conservation_metrics is not None: + self._metric_collection_dict.update({"Bio conservation": self._bio_conservation_metrics}) + if self._batch_correction_metrics is not None: + self._metric_collection_dict.update({"Batch correction": self._batch_correction_metrics}) def prepare(self, neighbor_computer: Callable[[np.ndarray, int], NeighborsResults] | None = None) -> None: """Prepare the data for benchmarking. 
@@ -279,7 +283,10 @@ def get_results(self, min_max_scale: bool = True, clean_names: bool = True) -> p # Compute scores per_class_score = df.groupby(_METRIC_TYPE).mean().transpose() # This is the default scIB weighting from the manuscript - per_class_score["Total"] = 0.4 * per_class_score["Batch correction"] + 0.6 * per_class_score["Bio conservation"] + if self._batch_correction_metrics is not None and self._bio_conservation_metrics is not None: + per_class_score["Total"] = ( + 0.4 * per_class_score["Batch correction"] + 0.6 * per_class_score["Bio conservation"] + ) df = pd.concat([df.transpose(), per_class_score], axis=1) df.loc[_METRIC_TYPE, per_class_score.columns] = _AGGREGATE_SCORE return df @@ -302,7 +309,13 @@ def plot_results_table(self, min_max_scale: bool = True, show: bool = True, save # Do not want to plot what kind of metric it is plot_df = df.drop(_METRIC_TYPE, axis=0) # Sort by total score - plot_df = plot_df.sort_values(by="Total", ascending=False).astype(np.float64) + if self._batch_correction_metrics is not None and self._bio_conservation_metrics is not None: + sort_col = "Total" + elif self._batch_correction_metrics is not None: + sort_col = "Batch correction" + else: + sort_col = "Bio conservation" + plot_df = plot_df.sort_values(by=sort_col, ascending=False).astype(np.float64) plot_df["Method"] = plot_df.index # Split columns by metric type, using df as it doesn't have the new method col diff --git a/tests/test_benchmarker.py b/tests/test_benchmarker.py index 0a73874..bbac237 100644 --- a/tests/test_benchmarker.py +++ b/tests/test_benchmarker.py @@ -7,7 +7,14 @@ def test_benchmarker(): ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata() - bm = Benchmarker(ad, batch_key, labels_key, emb_keys) + bm = Benchmarker( + ad, + batch_key, + labels_key, + emb_keys, + batch_correction_metrics=BatchCorrection(), + bio_conservation_metrics=BioConservation(), + ) bm.benchmark() results = bm.get_results() assert isinstance(results, pd.DataFrame) @@ 
-36,7 +43,9 @@ def test_benchmarker_custom_metric_booleans(): def test_benchmarker_custom_metric_callable(): bioc = BioConservation(clisi_knn={"perplexity": 10}) ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata() - bm = Benchmarker(ad, batch_key, labels_key, emb_keys, bio_conservation_metrics=bioc) + bm = Benchmarker( + ad, batch_key, labels_key, emb_keys, bio_conservation_metrics=bioc, batch_correction_metrics=BatchCorrection() + ) bm.benchmark() results = bm.get_results(clean_names=False) assert "clisi_knn" in results.columns @@ -44,7 +53,14 @@ def test_benchmarker_custom_metric_callable(): def test_benchmarker_custom_near_neighs(): ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata() - bm = Benchmarker(ad, batch_key, labels_key, emb_keys) + bm = Benchmarker( + ad, + batch_key, + labels_key, + emb_keys, + bio_conservation_metrics=BioConservation(), + batch_correction_metrics=BatchCorrection(), + ) bm.prepare(neighbor_computer=jax_approx_min_k) bm.benchmark() results = bm.get_results()