Skip to content

Commit

Permalink
Change behavior of None default metrics and allow results/plotting to work with bio/batch separately. (#181)
Browse files Browse the repository at this point in the history


---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
adamgayoso and pre-commit-ci[bot] authored Nov 4, 2024
1 parent a2575da commit e31aa1c
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 16 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ and this project adheres to [Semantic Versioning][].
### Added

- Add `progress_bar` argument to {class}`scib_metrics.benchmark.Benchmarker` {pr}`152`.
- Add ability of {class}`scib_metrics.benchmark.Benchmarker` plotting code to handle missing sets of metrics {pr}`181`.

### Changed

- Leiden clustering now has a seed argument for reproducibility {pr}`173`.
- Changed Leiden clustering to accept a seed argument for reproducibility {pr}`173`.
- Changed the meaning of passing `None` to `bio_conservation_metrics` or `batch_correction_metrics` in {class}`scib_metrics.benchmark.Benchmarker`: it now skips the corresponding set of metrics {pr}`181`.

### Fixed

Expand Down
3 changes: 2 additions & 1 deletion docs/notebooks/large_scale.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"import scanpy as sc\n",
"from scvi.data import cellxgene\n",
"\n",
"from scib_metrics.benchmark import Benchmarker, BioConservation\n",
"from scib_metrics.benchmark import Benchmarker, BioConservation, BatchCorrection\n",
"\n",
"%matplotlib inline"
]
Expand Down Expand Up @@ -294,6 +294,7 @@
" embedding_obsm_keys=[\"Unintegrated\", \"scANVI\", \"scVI\"],\n",
" pre_integrated_embedding_obsm_key=\"X_pca\",\n",
" bio_conservation_metrics=biocons,\n",
" batch_correction_metrics=BatchCorrection(),\n",
" n_jobs=-1,\n",
")\n",
"bm.prepare(neighbor_computer=faiss_brute_force_nn)\n",
Expand Down
4 changes: 3 additions & 1 deletion docs/notebooks/lung_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"import numpy as np\n",
"import scanpy as sc\n",
"\n",
"from scib_metrics.benchmark import Benchmarker\n",
"from scib_metrics.benchmark import Benchmarker, BioConservation, BatchCorrection\n",
"\n",
"%matplotlib inline"
]
Expand Down Expand Up @@ -271,6 +271,8 @@
" adata,\n",
" batch_key=\"batch\",\n",
" label_key=\"cell_type\",\n",
" bio_conservation_metrics=BioConservation(),\n",
" batch_correction_metrics=BatchCorrection(),\n",
" embedding_obsm_keys=[\"Unintegrated\", \"Scanorama\", \"LIGER\", \"Harmony\", \"scVI\", \"scANVI\"],\n",
" n_jobs=6,\n",
")\n",
Expand Down
33 changes: 23 additions & 10 deletions src/scib_metrics/benchmark/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,17 @@ def __init__(
batch_key: str,
label_key: str,
embedding_obsm_keys: list[str],
bio_conservation_metrics: BioConservation | None = None,
batch_correction_metrics: BatchCorrection | None = None,
bio_conservation_metrics: BioConservation | None,
batch_correction_metrics: BatchCorrection | None,
pre_integrated_embedding_obsm_key: str | None = None,
n_jobs: int = 1,
progress_bar: bool = True,
):
self._adata = adata
self._embedding_obsm_keys = embedding_obsm_keys
self._pre_integrated_embedding_obsm_key = pre_integrated_embedding_obsm_key
self._bio_conservation_metrics = bio_conservation_metrics if bio_conservation_metrics else BioConservation()
self._batch_correction_metrics = batch_correction_metrics if batch_correction_metrics else BatchCorrection()
self._bio_conservation_metrics = bio_conservation_metrics
self._batch_correction_metrics = batch_correction_metrics
self._results = pd.DataFrame(columns=list(self._embedding_obsm_keys) + [_METRIC_TYPE])
self._emb_adatas = {}
self._neighbor_values = (15, 50, 90)
Expand All @@ -157,10 +157,14 @@ def __init__(
self._n_jobs = n_jobs
self._progress_bar = progress_bar

self._metric_collection_dict = {
"Bio conservation": self._bio_conservation_metrics,
"Batch correction": self._batch_correction_metrics,
}
if self._bio_conservation_metrics is None and self._batch_correction_metrics is None:
raise ValueError("Either batch or bio metrics must be defined.")

self._metric_collection_dict = {}
if self._bio_conservation_metrics is not None:
self._metric_collection_dict.update({"Bio conservation": self._bio_conservation_metrics})
if self._batch_correction_metrics is not None:
self._metric_collection_dict.update({"Batch correction": self._batch_correction_metrics})

def prepare(self, neighbor_computer: Callable[[np.ndarray, int], NeighborsResults] | None = None) -> None:
"""Prepare the data for benchmarking.
Expand Down Expand Up @@ -279,7 +283,10 @@ def get_results(self, min_max_scale: bool = True, clean_names: bool = True) -> p
# Compute scores
per_class_score = df.groupby(_METRIC_TYPE).mean().transpose()
# This is the default scIB weighting from the manuscript
per_class_score["Total"] = 0.4 * per_class_score["Batch correction"] + 0.6 * per_class_score["Bio conservation"]
if self._batch_correction_metrics is not None and self._bio_conservation_metrics is not None:
per_class_score["Total"] = (
0.4 * per_class_score["Batch correction"] + 0.6 * per_class_score["Bio conservation"]
)
df = pd.concat([df.transpose(), per_class_score], axis=1)
df.loc[_METRIC_TYPE, per_class_score.columns] = _AGGREGATE_SCORE
return df
Expand All @@ -302,7 +309,13 @@ def plot_results_table(self, min_max_scale: bool = True, show: bool = True, save
# Do not want to plot what kind of metric it is
plot_df = df.drop(_METRIC_TYPE, axis=0)
# Sort by total score
plot_df = plot_df.sort_values(by="Total", ascending=False).astype(np.float64)
if self._batch_correction_metrics is not None and self._bio_conservation_metrics is not None:
sort_col = "Total"
elif self._batch_correction_metrics is not None:
sort_col = "Batch correction"
else:
sort_col = "Bio conservation"
plot_df = plot_df.sort_values(by=sort_col, ascending=False).astype(np.float64)
plot_df["Method"] = plot_df.index

# Split columns by metric type, using df as it doesn't have the new method col
Expand Down
22 changes: 19 additions & 3 deletions tests/test_benchmarker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@

def test_benchmarker():
ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata()
bm = Benchmarker(ad, batch_key, labels_key, emb_keys)
bm = Benchmarker(
ad,
batch_key,
labels_key,
emb_keys,
batch_correction_metrics=BatchCorrection(),
bio_conservation_metrics=BioConservation(),
)
bm.benchmark()
results = bm.get_results()
assert isinstance(results, pd.DataFrame)
Expand Down Expand Up @@ -36,15 +43,24 @@ def test_benchmarker_custom_metric_booleans():
def test_benchmarker_custom_metric_callable():
bioc = BioConservation(clisi_knn={"perplexity": 10})
ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata()
bm = Benchmarker(ad, batch_key, labels_key, emb_keys, bio_conservation_metrics=bioc)
bm = Benchmarker(
ad, batch_key, labels_key, emb_keys, bio_conservation_metrics=bioc, batch_correction_metrics=BatchCorrection()
)
bm.benchmark()
results = bm.get_results(clean_names=False)
assert "clisi_knn" in results.columns


def test_benchmarker_custom_near_neighs():
ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata()
bm = Benchmarker(ad, batch_key, labels_key, emb_keys)
bm = Benchmarker(
ad,
batch_key,
labels_key,
emb_keys,
bio_conservation_metrics=BioConservation(),
batch_correction_metrics=BatchCorrection(),
)
bm.prepare(neighbor_computer=jax_approx_min_k)
bm.benchmark()
results = bm.get_results()
Expand Down

0 comments on commit e31aa1c

Please sign in to comment.