Skip to content

Commit

Permalink
Simplify score_genes (#3097)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored Jun 4, 2024
1 parent 4f40d68 commit 21aecd9
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 58 deletions.
6 changes: 1 addition & 5 deletions scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,7 @@ def _check_use_raw(adata: AnnData, use_raw: None | bool) -> bool:
"""
if use_raw is not None:
return use_raw
else:
if adata.raw is not None:
return True
else:
return False
return adata.raw is not None


# --------------------------------------------------------------------------------
Expand Down
19 changes: 9 additions & 10 deletions scanpy/get/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,22 +433,21 @@ def _get_obs_rep(
is_obsm = obsm is not None
is_obsp = obsp is not None
choices_made = sum((is_layer, is_raw, is_obsm, is_obsp))
assert choices_made <= 1
assert choices_made in {0, 1}
if choices_made == 0:
return adata.X
elif is_layer:
if is_layer:
return adata.layers[layer]
elif use_raw:
if use_raw:
return adata.raw.X
elif is_obsm:
if is_obsm:
return adata.obsm[obsm]
elif is_obsp:
if is_obsp:
return adata.obsp[obsp]
else:
assert False, (
"That was unexpected. Please report this bug at:\n\n\t"
" https://github.com/scverse/scanpy/issues"
)
raise AssertionError(
"That was unexpected. Please report this bug at:\n\n\t"
"https://github.com/scverse/scanpy/issues"
)


def _set_obs_rep(
Expand Down
9 changes: 6 additions & 3 deletions scanpy/tests/test_score_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
if TYPE_CHECKING:
from typing import Literal

HERE = Path(__file__).parent / Path("_data/")
from numpy.typing import NDArray


def _create_random_gene_names(n_genes, name_length):
HERE = Path(__file__).parent / "_data"


def _create_random_gene_names(n_genes, name_length) -> NDArray[np.str_]:
"""
creates a bunch of random gene names (just CAPS letters)
"""
Expand Down Expand Up @@ -68,7 +71,7 @@ def test_score_with_reference():
sc.pp.scale(adata)

sc.tl.score_genes(adata, gene_list=adata.var_names[:100], score_name="Test")
with Path(HERE, "score_genes_reference_paul2015.pkl").open("rb") as file:
with (HERE / "score_genes_reference_paul2015.pkl").open("rb") as file:
reference = pickle.load(file)
# np.testing.assert_allclose(reference, adata.obs["Test"].to_numpy())
np.testing.assert_array_equal(reference, adata.obs["Test"].to_numpy())
Expand Down
67 changes: 27 additions & 40 deletions scanpy/tools/_score_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@

from .. import logging as logg
from .._compat import old_positionals
from ..get import _get_obs_rep

if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Literal

from anndata import AnnData
from numpy.typing import NDArray
from numpy.typing import DTypeLike, NDArray
from scipy.sparse import csc_matrix, csr_matrix

from .._utils import AnyRandom
Expand Down Expand Up @@ -143,20 +144,16 @@ def score_genes(
# Basically we need to compare genes against random genes in a matched
# interval of expression.

_adata = adata.raw if use_raw else adata
_adata_subset = (
_adata[:, gene_pool] if len(gene_pool) < len(_adata.var_names) else _adata
)
# average expression of genes
if issparse(_adata_subset.X):
obs_avg = pd.Series(
np.array(_sparse_nanmean(_adata_subset.X, axis=0)).flatten(),
index=gene_pool,
)
else:
obs_avg = pd.Series(np.nanmean(_adata_subset.X, axis=0), index=gene_pool)
def get_subset(genes: pd.Index[str]):
x = _get_obs_rep(adata, use_raw=use_raw)
if len(genes) == len(var_names):
return x
idx = var_names.get_indexer(genes)
return x[:, idx]

# Sometimes (and I don't know how) missing data may be there, with nansfor
# average expression of genes
obs_avg = pd.Series(_nan_means(get_subset(gene_pool), axis=0), index=gene_pool)
# Sometimes (and I don’t know how) missing data may be there, with NaNs for missing entries
obs_avg = obs_avg[np.isfinite(obs_avg)]

n_items = int(np.round(len(obs_avg) / (n_bins - 1)))
Expand All @@ -170,19 +167,11 @@ def score_genes(
r_genes = r_genes.to_series().sample(ctrl_size).index
control_genes = control_genes.union(r_genes.difference(gene_list))

X_list = _adata[:, gene_list].X
if issparse(X_list):
X_list = np.array(_sparse_nanmean(X_list, axis=1)).flatten()
else:
X_list = np.nanmean(X_list, axis=1, dtype="float64")

X_control = _adata[:, control_genes].X
if issparse(X_control):
X_control = np.array(_sparse_nanmean(X_control, axis=1)).flatten()
else:
X_control = np.nanmean(X_control, axis=1, dtype="float64")

score = X_list - X_control
means_list, means_control = (
_nan_means(get_subset(genes), axis=1, dtype="float64")
for genes in (gene_list, control_genes)
)
score = means_list - means_control

adata.obs[score_name] = pd.Series(
np.array(score).ravel(), index=adata.obs_names, dtype="float64"
Expand All @@ -200,6 +189,14 @@ def score_genes(
return adata if copy else None


def _nan_means(
x, *, axis: Literal[0, 1], dtype: DTypeLike | None = None
) -> NDArray[np.float64]:
if issparse(x):
return np.array(_sparse_nanmean(x, axis=axis)).flatten()
return np.nanmean(x, axis=axis, dtype=dtype)


@old_positionals("s_genes", "g2m_genes", "copy")
def score_genes_cell_cycle(
adata: AnnData,
Expand Down Expand Up @@ -253,25 +250,15 @@ def score_genes_cell_cycle(

adata = adata.copy() if copy else adata
ctrl_size = min(len(s_genes), len(g2m_genes))
# add s-score
score_genes(
adata, gene_list=s_genes, score_name="S_score", ctrl_size=ctrl_size, **kwargs
)
# add g2m-score
score_genes(
adata,
gene_list=g2m_genes,
score_name="G2M_score",
ctrl_size=ctrl_size,
**kwargs,
)
for genes, name in [(s_genes, "S_score"), (g2m_genes, "G2M_score")]:
score_genes(adata, genes, score_name=name, ctrl_size=ctrl_size, **kwargs)
scores = adata.obs[["S_score", "G2M_score"]]

# default phase is S
phase = pd.Series("S", index=scores.index)

# if G2M is higher than S, it's G2M
phase[scores.G2M_score > scores.S_score] = "G2M"
phase[scores["G2M_score"] > scores["S_score"]] = "G2M"

# if all scores are negative, it's G1...
phase[np.all(scores < 0, axis=1)] = "G1"
Expand Down

0 comments on commit 21aecd9

Please sign in to comment.