Skip to content

Commit

Permalink
Drop Python 3.8 and fix one-hot error (#107)
Browse files Browse the repository at this point in the history
* Update pyproject.toml

* Update README.md

* Update pyproject.toml

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update pyproject.toml

* Update test.yaml

* Update pyproject.toml

* Update pyproject.toml

* Update pyproject.toml

* Update pyproject.toml

* Update pyproject.toml

* Update _utils.py

* Update CHANGELOG.md

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
adamgayoso and pre-commit-ci[bot] authored Sep 18, 2023
1 parent 5681f7e commit 81737c0
Show file tree
Hide file tree
Showing 11 changed files with 31 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python: ["3.8", "3.10"]
python: ["3.9", "3.10"]
os: [ubuntu-latest]

env:
Expand Down
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## 0.4.0 (2023-09-19)

- Drop Python 3.8 ([#107][])
- Fix jax one-hot error ([#107][])

[#107]: https://github.com/YosefLab/scib-metrics/pull/107

## 0.3.3 (2023-03-29)

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Please refer to the [documentation][link-docs].

## Installation

You need to have Python 3.8 or newer installed on your system. If you don't have
You need to have Python 3.9 or newer installed on your system. If you don't have
Python installed, we recommend installing [Miniconda](https://docs.conda.io/en/latest/miniconda.html).

There are several alternative options to install scib-metrics:
Expand Down
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ name = "scib-metrics"
version = "0.3.3"
description = "Accelerated and Python-only scIB metrics"
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9"
license = {file = "LICENSE"}
authors = [
{name = "Adam Gayoso"},
Expand All @@ -28,7 +28,7 @@ dependencies = [
"pandas",
"scipy",
"scikit-learn",
"scanpy",
"scanpy>=1.9",
"rich",
"pynndescent",
"igraph>0.9.0",
Expand All @@ -53,6 +53,7 @@ doc = [
"ipython",
"ipykernel",
"sphinx-copybutton",
"numba>=0.57.1",
]
test = [
"pytest",
Expand All @@ -63,6 +64,7 @@ test = [
# For vscode Python extension testing
"flake8",
"black",
"numba>=0.57.1",
]
parallel = [
"joblib"
Expand Down Expand Up @@ -95,8 +97,7 @@ xfail_strict = true

[tool.ruff]
src = ["src"]
line-length = 119
target-version = "py38"
line-length = 120
select = [
"F", # Errors detected by Pyflakes
"E", # Error detected by Pycodestyle
Expand Down Expand Up @@ -152,7 +153,6 @@ convention = "numpy"

[tool.black]
line-length = 120
target-version = ['py38']
include = '\.pyi?$'
exclude = '''
(
Expand Down
4 changes: 2 additions & 2 deletions src/scib_metrics/_kbet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from functools import partial
from typing import Tuple, Union
from typing import Union

import chex
import jax
Expand Down Expand Up @@ -102,7 +102,7 @@ def kbet_per_label(
alpha: float = 0.05,
diffusion_n_comps: int = 100,
return_df: bool = False,
) -> Union[float, Tuple[float, pd.DataFrame]]:
) -> Union[float, tuple[float, pd.DataFrame]]:
"""Compute kBET score per cell type label as in :cite:p:`luecken2022benchmarking`.
This approximates the method used in the original scib package. Notably, the underlying
Expand Down
7 changes: 3 additions & 4 deletions src/scib_metrics/_nmi_ari.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import warnings
from typing import Dict, Tuple

import numpy as np
import scanpy as sc
Expand Down Expand Up @@ -30,14 +29,14 @@ def _compute_nmi_ari_cluster_labels(
X: np.ndarray,
labels: np.ndarray,
resolution: float = 1.0,
) -> Tuple[float, float]:
) -> tuple[float, float]:
labels_pred = _compute_clustering_leiden(X, resolution)
nmi = normalized_mutual_info_score(labels, labels_pred, average_method="arithmetic")
ari = adjusted_rand_score(labels, labels_pred)
return nmi, ari


def nmi_ari_cluster_labels_kmeans(X: np.ndarray, labels: np.ndarray) -> Dict[str, float]:
def nmi_ari_cluster_labels_kmeans(X: np.ndarray, labels: np.ndarray) -> dict[str, float]:
"""Compute nmi and ari between k-means clusters and labels.
This deviates from the original implementation in scib by using k-means
Expand Down Expand Up @@ -69,7 +68,7 @@ def nmi_ari_cluster_labels_kmeans(X: np.ndarray, labels: np.ndarray) -> Dict[str

def nmi_ari_cluster_labels_leiden(
X: spmatrix, labels: np.ndarray, optimize_resolution: bool = True, resolution: float = 1.0, n_jobs: int = 1
) -> Dict[str, float]:
) -> dict[str, float]:
"""Compute nmi and ari between leiden clusters and labels.
This deviates from the original implementation in scib by using leiden instead of
Expand Down
6 changes: 3 additions & 3 deletions src/scib_metrics/benchmark/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import asdict, dataclass
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Optional, Union

import matplotlib
import matplotlib.pyplot as plt
Expand All @@ -20,7 +20,7 @@
import scib_metrics
from scib_metrics.nearest_neighbors import NeighborsOutput, pynndescent

Kwargs = Dict[str, Any]
Kwargs = dict[str, Any]
MetricType = Union[bool, Kwargs]

_LABELS = "labels"
Expand Down Expand Up @@ -131,7 +131,7 @@ def __init__(
adata: AnnData,
batch_key: str,
label_key: str,
embedding_obsm_keys: List[str],
embedding_obsm_keys: list[str],
bio_conservation_metrics: Optional[BioConservation] = None,
batch_correction_metrics: Optional[BatchCorrection] = None,
pre_integrated_embedding_obsm_key: Optional[str] = None,
Expand Down
6 changes: 3 additions & 3 deletions src/scib_metrics/utils/_lisi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Tuple, Union
from typing import Union

import chex
import jax
Expand All @@ -23,7 +23,7 @@ class _NeighborProbabilityState:


@jax.jit
def _Hbeta(knn_dists_row: jnp.ndarray, beta: float) -> Tuple[jnp.ndarray, jnp.ndarray]:
def _Hbeta(knn_dists_row: jnp.ndarray, beta: float) -> tuple[jnp.ndarray, jnp.ndarray]:
P = jnp.exp(-knn_dists_row * beta)
sumP = jnp.nansum(P)
H = jnp.where(sumP == 0, 0, jnp.log(sumP) + beta * jnp.nansum(knn_dists_row * P) / sumP)
Expand All @@ -34,7 +34,7 @@ def _Hbeta(knn_dists_row: jnp.ndarray, beta: float) -> Tuple[jnp.ndarray, jnp.nd
@jax.jit
def _get_neighbor_probability(
knn_dists_row: jnp.ndarray, perplexity: float, tol: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
) -> tuple[jnp.ndarray, jnp.ndarray]:
beta = 1
betamin = -jnp.inf
betamax = jnp.inf
Expand Down
4 changes: 2 additions & 2 deletions src/scib_metrics/utils/_pca.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Optional

import jax.numpy as jnp
from chex import dataclass
Expand Down Expand Up @@ -128,7 +128,7 @@ def pca(
@jit
def _pca(
X: NdArray,
) -> Tuple[NdArray, NdArray, NdArray, NdArray, NdArray]:
) -> tuple[NdArray, NdArray, NdArray, NdArray, NdArray]:
"""Principal component analysis.
Parameters
Expand Down
3 changes: 1 addition & 2 deletions src/scib_metrics/utils/_silhouette.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp
Expand All @@ -13,7 +12,7 @@
@jax.jit
def _silhouette_reduce(
D_chunk: jnp.ndarray, start: int, labels: jnp.ndarray, label_freqs: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]:
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Accumulate silhouette statistics for vertical chunk of X.
Follows scikit-learn implementation.
Expand Down
6 changes: 3 additions & 3 deletions src/scib_metrics/utils/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Optional, Tuple
from typing import Optional

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -33,7 +33,7 @@ def one_hot(y: NdArray, n_classes: Optional[int] = None) -> jnp.ndarray:
one_hot: jnp.ndarray
Array of shape (n_cells, n_classes).
"""
n_classes = n_classes or jnp.max(y) + 1
n_classes = n_classes or int(jax.device_get(jnp.max(y))) + 1
return nn.one_hot(jnp.ravel(y), n_classes)


Expand All @@ -48,7 +48,7 @@ def check_square(X: ArrayLike):
raise ValueError("X must be a square matrix")


def convert_knn_graph_to_idx(X: csr_matrix) -> Tuple[np.ndarray, np.ndarray]:
def convert_knn_graph_to_idx(X: csr_matrix) -> tuple[np.ndarray, np.ndarray]:
"""Convert a kNN graph to indices and distances."""
check_array(X, accept_sparse="csr")
check_square(X)
Expand Down

0 comments on commit 81737c0

Please sign in to comment.