diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 13267d1..4357e2f 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -20,7 +20,7 @@ jobs: strategy: fail-fast: false matrix: - python: ["3.8", "3.10"] + python: ["3.9", "3.10"] os: [ubuntu-latest] env: diff --git a/CHANGELOG.md b/CHANGELOG.md index 521f1bc..2dc59e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,13 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html +## 0.4.0 (2023-09-19) + +- Drop Python 3.8 ([#107][]) +- Fix jax one-hot error ([#107][]) + +[#107]: https://github.com/YosefLab/scib-metrics/pull/107 + ## 0.3.3 (2023-03-29) ### Fixed diff --git a/README.md b/README.md index a6b0178..8bd0467 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Please refer to the [documentation][link-docs]. ## Installation -You need to have Python 3.8 or newer installed on your system. If you don't have +You need to have Python 3.9 or newer installed on your system. If you don't have Python installed, we recommend installing [Miniconda](https://docs.conda.io/en/latest/miniconda.html). There are several alternative options to install scib-metrics: diff --git a/pyproject.toml b/pyproject.toml index defb0fc..84bb887 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ name = "scib-metrics" version = "0.3.3" description = "Accelerated and Python-only scIB metrics" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = {file = "LICENSE"} authors = [ {name = "Adam Gayoso"}, @@ -28,7 +28,7 @@ dependencies = [ "pandas", "scipy", "scikit-learn", - "scanpy", + "scanpy>=1.9", "rich", "pynndescent", "igraph>0.9.0", @@ -53,6 +53,7 @@ doc = [ "ipython", "ipykernel", "sphinx-copybutton", + "numba>=0.57.1", ] test = [ "pytest", @@ -63,6 +64,7 @@ test = [ # For vscode Python extension testing "flake8", "black", + "numba>=0.57.1", ] parallel = [ "joblib" @@ -95,8 +97,7 @@ xfail_strict = true [tool.ruff] src = ["src"] -line-length = 119 -target-version = "py38" +line-length = 120 select = [ "F", # Errors detected by Pyflakes "E", # Error detected by Pycodestyle @@ -152,7 +153,6 @@ convention = "numpy" [tool.black] line-length = 120 -target-version = ['py38'] include = '\.pyi?$' exclude = ''' ( diff --git a/src/scib_metrics/_kbet.py b/src/scib_metrics/_kbet.py index abda553..ca415a7 100644 --- a/src/scib_metrics/_kbet.py +++ b/src/scib_metrics/_kbet.py @@ -1,6 +1,6 @@ import logging from functools import partial -from typing import Tuple, Union +from typing import Union import chex import jax @@ -102,7 +102,7 @@ def kbet_per_label( alpha: float = 0.05, diffusion_n_comps: int = 100, return_df: bool = False, -) -> Union[float, Tuple[float, pd.DataFrame]]: +) -> Union[float, tuple[float, pd.DataFrame]]: """Compute kBET score per cell type label as in :cite:p:`luecken2022benchmarking`. This approximates the method used in the original scib package. Notably, the underlying diff --git a/src/scib_metrics/_nmi_ari.py b/src/scib_metrics/_nmi_ari.py index 97c09c8..371b293 100644 --- a/src/scib_metrics/_nmi_ari.py +++ b/src/scib_metrics/_nmi_ari.py @@ -1,6 +1,5 @@ import logging import warnings -from typing import Dict, Tuple import numpy as np import scanpy as sc @@ -30,14 +29,14 @@ def _compute_nmi_ari_cluster_labels( X: np.ndarray, labels: np.ndarray, resolution: float = 1.0, -) -> Tuple[float, float]: +) -> tuple[float, float]: labels_pred = _compute_clustering_leiden(X, resolution) nmi = normalized_mutual_info_score(labels, labels_pred, average_method="arithmetic") ari = adjusted_rand_score(labels, labels_pred) return nmi, ari -def nmi_ari_cluster_labels_kmeans(X: np.ndarray, labels: np.ndarray) -> Dict[str, float]: +def nmi_ari_cluster_labels_kmeans(X: np.ndarray, labels: np.ndarray) -> dict[str, float]: """Compute nmi and ari between k-means clusters and labels. This deviates from the original implementation in scib by using k-means @@ -69,7 +68,7 @@ def nmi_ari_cluster_labels_kmeans(X: np.ndarray, labels: np.ndarray) -> Dict[str def nmi_ari_cluster_labels_leiden( X: spmatrix, labels: np.ndarray, optimize_resolution: bool = True, resolution: float = 1.0, n_jobs: int = 1 -) -> Dict[str, float]: +) -> dict[str, float]: """Compute nmi and ari between leiden clusters and labels. This deviates from the original implementation in scib by using leiden instead of diff --git a/src/scib_metrics/benchmark/_core.py b/src/scib_metrics/benchmark/_core.py index a1e2a69..f0c1fcb 100644 --- a/src/scib_metrics/benchmark/_core.py +++ b/src/scib_metrics/benchmark/_core.py @@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass from enum import Enum from functools import partial -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Optional, Union import matplotlib import matplotlib.pyplot as plt @@ -20,7 +20,7 @@ import scib_metrics from scib_metrics.nearest_neighbors import NeighborsOutput, pynndescent -Kwargs = Dict[str, Any] +Kwargs = dict[str, Any] MetricType = Union[bool, Kwargs] _LABELS = "labels" @@ -131,7 +131,7 @@ def __init__( adata: AnnData, batch_key: str, label_key: str, - embedding_obsm_keys: List[str], + embedding_obsm_keys: list[str], bio_conservation_metrics: Optional[BioConservation] = None, batch_correction_metrics: Optional[BatchCorrection] = None, pre_integrated_embedding_obsm_key: Optional[str] = None, diff --git a/src/scib_metrics/utils/_lisi.py b/src/scib_metrics/utils/_lisi.py index a386023..5ac76c9 100644 --- a/src/scib_metrics/utils/_lisi.py +++ b/src/scib_metrics/utils/_lisi.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Tuple, Union +from typing import Union import chex import jax @@ -23,7 +23,7 @@ class _NeighborProbabilityState: @jax.jit -def _Hbeta(knn_dists_row: jnp.ndarray, beta: float) -> Tuple[jnp.ndarray, jnp.ndarray]: +def _Hbeta(knn_dists_row: jnp.ndarray, beta: float) -> tuple[jnp.ndarray, jnp.ndarray]: P = jnp.exp(-knn_dists_row * beta) sumP = jnp.nansum(P) H = jnp.where(sumP == 0, 0, jnp.log(sumP) + beta * jnp.nansum(knn_dists_row * P) / sumP) @@ -34,7 +34,7 @@ def _Hbeta(knn_dists_row: jnp.ndarray, beta: float) -> Tuple[jnp.ndarray, jnp.nd @jax.jit def _get_neighbor_probability( knn_dists_row: jnp.ndarray, perplexity: float, tol: float -) -> Tuple[jnp.ndarray, jnp.ndarray]: +) -> tuple[jnp.ndarray, jnp.ndarray]: beta = 1 betamin = -jnp.inf betamax = jnp.inf diff --git a/src/scib_metrics/utils/_pca.py b/src/scib_metrics/utils/_pca.py index fd9c0be..0e47ceb 100644 --- a/src/scib_metrics/utils/_pca.py +++ b/src/scib_metrics/utils/_pca.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional import jax.numpy as jnp from chex import dataclass @@ -128,7 +128,7 @@ def pca( @jit def _pca( X: NdArray, -) -> Tuple[NdArray, NdArray, NdArray, NdArray, NdArray]: +) -> tuple[NdArray, NdArray, NdArray, NdArray, NdArray]: """Principal component analysis. Parameters diff --git a/src/scib_metrics/utils/_silhouette.py b/src/scib_metrics/utils/_silhouette.py index e3e1869..3407385 100644 --- a/src/scib_metrics/utils/_silhouette.py +++ b/src/scib_metrics/utils/_silhouette.py @@ -1,5 +1,4 @@ from functools import partial -from typing import Tuple import jax import jax.numpy as jnp @@ -13,7 +12,7 @@ @jax.jit def _silhouette_reduce( D_chunk: jnp.ndarray, start: int, labels: jnp.ndarray, label_freqs: jnp.ndarray -) -> Tuple[jnp.ndarray, jnp.ndarray]: +) -> tuple[jnp.ndarray, jnp.ndarray]: """Accumulate silhouette statistics for vertical chunk of X. Follows scikit-learn implementation. diff --git a/src/scib_metrics/utils/_utils.py b/src/scib_metrics/utils/_utils.py index 43d4d0d..51ff901 100644 --- a/src/scib_metrics/utils/_utils.py +++ b/src/scib_metrics/utils/_utils.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, Tuple +from typing import Optional import jax import jax.numpy as jnp @@ -33,7 +33,7 @@ def one_hot(y: NdArray, n_classes: Optional[int] = None) -> jnp.ndarray: one_hot: jnp.ndarray Array of shape (n_cells, n_classes). """ - n_classes = n_classes or jnp.max(y) + 1 + n_classes = n_classes or int(jax.device_get(jnp.max(y))) + 1 return nn.one_hot(jnp.ravel(y), n_classes) @@ -48,7 +48,7 @@ def check_square(X: ArrayLike): raise ValueError("X must be a square matrix") -def convert_knn_graph_to_idx(X: csr_matrix) -> Tuple[np.ndarray, np.ndarray]: +def convert_knn_graph_to_idx(X: csr_matrix) -> tuple[np.ndarray, np.ndarray]: """Convert a kNN graph to indices and distances.""" check_array(X, accept_sparse="csr") check_square(X)