
Add two distance metrics, three-way comparison and bootstrapping #608

Merged · 36 commits · Jun 24, 2024
Showing changes from 1 commit.

Commits
fd21dc1 add two distance metrics (wxicu, May 26, 2024)
3af7d89 add obsm_key param to distance test (wxicu, May 26, 2024)
3fe911b [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 26, 2024)
6d419c3 add agg fct (wxicu, Jun 2, 2024)
ca86025 speed up tests (wxicu, Jun 3, 2024)
0830535 Merge branch 'main' into distance (wxicu, Jun 3, 2024)
9fd4c2b add type (wxicu, Jun 3, 2024)
fc71eae add description (wxicu, Jun 3, 2024)
09e5fea Update pertpy/tools/_distances/_distances.py (wxicu, Jun 5, 2024)
ad23ca6 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 5, 2024)
dc74884 Update pertpy/tools/_distances/_distances.py (wxicu, Jun 5, 2024)
c774cd2 Update pertpy/tools/_distances/_distances.py (wxicu, Jun 5, 2024)
d413d67 Update pertpy/tools/_distances/_distances.py (wxicu, Jun 5, 2024)
b7f2cf7 Update pertpy/tools/_distances/_distances.py (wxicu, Jun 5, 2024)
e71f81c Update pertpy/tools/_distances/_distances.py (wxicu, Jun 5, 2024)
edaa6e6 Update pertpy/tools/_distances/_distances.py (wxicu, Jun 5, 2024)
317cfd5 update code (wxicu, Jun 6, 2024)
47b4134 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 6, 2024)
d261410 fix drug (wxicu, Jun 6, 2024)
4fd29a0 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 6, 2024)
30baefe add bootstrapping and metrics_3g (wxicu, Jun 7, 2024)
2c8127c speed up tests (wxicu, Jun 7, 2024)
57b14c3 remove test classes (wxicu, Jun 10, 2024)
78d00fa drop test classes (wxicu, Jun 10, 2024)
052fd00 update compare_de (wxicu, Jun 12, 2024)
3a8eac6 correct the comments (wxicu, Jun 12, 2024)
63ed17a speed tests (wxicu, Jun 13, 2024)
f9e0d36 speed up tests (wxicu, Jun 13, 2024)
2e65f9d split metrics_3g (wxicu, Jun 18, 2024)
2e7acf3 fix pre-commit (wxicu, Jun 18, 2024)
69163ff pin numpy <2 (wxicu, Jun 19, 2024)
67c54be unpin numpy (wxicu, Jun 20, 2024)
6e32f37 speed up mahalanobis distance (wxicu, Jun 20, 2024)
620e645 use scipy to calculate mahalanobis distance (wxicu, Jun 20, 2024)
10d3483 rename DGE to DGEEVAL (wxicu, Jun 23, 2024)
4a07252 Merge branch 'main' into distance (wxicu, Jun 23, 2024)
9 changes: 1 addition & 8 deletions pertpy/tools/_distances/_distances.py
```diff
@@ -1158,15 +1158,8 @@ def __init__(self, aggregation_func: Callable = np.mean) -> None:
         self.aggregation_func = aggregation_func

     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        cov = np.cov(X.T)
-        if cov.shape[0] == cov.shape[1] and np.linalg.matrix_rank(cov) == cov.shape[0]:  # check invertiblity
-            inverse_cov = np.linalg.inv(cov)
-        else:  # if not invertible, compute the (Moore-Penrose) pseudo-inverse of a matrix
-            inverse_cov = np.linalg.pinv(cov)
         return mahalanobis(
-            self.aggregation_func(X, axis=0),
-            self.aggregation_func(Y, axis=0),
-            inverse_cov,
+            self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0), np.linalg.inv(np.cov(X.T))
         )

     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
```
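For context, a minimal sketch of the quantity the simplified `__call__` computes, on toy arrays (names and shapes here are illustrative, not pertpy's API):

```python
import numpy as np
from scipy.spatial.distance import mahalanobis

# Mahalanobis distance between the aggregated (here: mean) profiles of two
# groups, scaled by the inverse covariance of X: the same expression as the
# one-liner above. Note that np.linalg.inv raises LinAlgError for a singular
# covariance matrix; the dropped pinv fallback silently handled that case.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))  # e.g. perturbed cells x features
Y = rng.normal(size=(200, 5))  # e.g. control cells x features

d = mahalanobis(X.mean(axis=0), Y.mean(axis=0), np.linalg.inv(np.cov(X.T)))
print(f"{d:.3f}")
```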
46 changes: 14 additions & 32 deletions pertpy/tools/_metrics_3g.py
```diff
@@ -19,9 +19,7 @@
 from numpy.typing import NDArray


-def compare_de(
-    X: np.ndarray, Y: np.ndarray, C: np.ndarray, shared_top: int = 100, **kwargs
-) -> dict:
+def compare_de(X: np.ndarray, Y: np.ndarray, C: np.ndarray, shared_top: int = 100, **kwargs) -> dict:
     """Compare DEG across real and simulated perturbations.

     Computes DEG for real and simulated perturbations vs. control and calculates
@@ -50,39 +48,25 @@
     for group in ("x", "y"):
         adata_joint = ad.concat((adatas_xy[group], adata_c), index_unique="-")

-        sc.tl.rank_genes_groups(
-            adata_joint, groupby="label", reference="ctrl", key_added="de", **kwargs
-        )
+        sc.tl.rank_genes_groups(adata_joint, groupby="label", reference="ctrl", key_added="de", **kwargs)

         srt_idx = np.argsort(adata_joint.uns["de"]["names"]["comp"])
         results[f"scores_{group}"] = adata_joint.uns["de"]["scores"]["comp"][srt_idx]
-        results[f"pvals_adj_{group}"] = adata_joint.uns["de"]["pvals_adj"]["comp"][
-            srt_idx
-        ]
+        results[f"pvals_adj_{group}"] = adata_joint.uns["de"]["pvals_adj"]["comp"][srt_idx]
         results[f"ranks_{group}"] = vars_ranks[srt_idx]

         top_names.append(adata_joint.uns["de"]["names"]["comp"][:shared_top])

     metrics = {}
-    metrics["shared_top_genes"] = (
-        len(set(top_names[0]).intersection(top_names[1])) / shared_top
-    )
-    metrics["scores_corr"] = results["scores_x"].corr(
-        results["scores_y"], method="pearson"
-    )
-    metrics["pvals_adj_corr"] = results["pvals_adj_x"].corr(
-        results["pvals_adj_y"], method="pearson"
-    )
-    metrics["scores_ranks_corr"] = results["ranks_x"].corr(
-        results["ranks_y"], method="spearman"
-    )
+    metrics["shared_top_genes"] = len(set(top_names[0]).intersection(top_names[1])) / shared_top
+    metrics["scores_corr"] = results["scores_x"].corr(results["scores_y"], method="pearson")
+    metrics["pvals_adj_corr"] = results["pvals_adj_x"].corr(results["pvals_adj_y"], method="pearson")
+    metrics["scores_ranks_corr"] = results["ranks_x"].corr(results["ranks_y"], method="spearman")

     return metrics


-def compare_class(
-    X: np.ndarray, Y: np.ndarray, C: np.ndarray, clf: Optional[ClassifierMixin] = None
-) -> float:
+def compare_class(X: np.ndarray, Y: np.ndarray, C: np.ndarray, clf: ClassifierMixin | None = None) -> float:
     """Compare classification accuracy between real and simulated perturbations.

     Trains a classifier on the real perturbation data + the control data and reports a normalized
```
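A usage sketch for `compare_de`, mirroring the new tests at the bottom of this diff (random arrays stand in for real expression matrices):

```python
import numpy as np
from pertpy.tools._metrics_3g import compare_de

# X = real perturbed, Y = simulated perturbed, C = control (cells x genes).
# shared_top sets how many top DE genes are checked for overlap.
rng = np.random.default_rng(0)
X, Y, C = (rng.normal(size=(100, 10)) for _ in range(3))

metrics = compare_de(X, Y, C, shared_top=5)
print(metrics["shared_top_genes"], metrics["scores_corr"])
```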
```diff
@@ -117,7 +101,7 @@ def compare_class(
 def compare_knn(
     X: np.ndarray,
     Y: np.ndarray,
-    C: Optional[np.ndarray] = None,
+    C: np.ndarray | None = None,
     n_neighbors: int = 20,
     use_Y_knn: bool = False,
     random_state: int = 0,
```
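The `Optional[np.ndarray]` to `np.ndarray | None` change is PEP 604 union syntax; a quick compatibility note:

```python
# PEP 604 unions require Python 3.10+ at runtime; on older interpreters the
# annotation-only form still works via the future import:
from __future__ import annotations

import numpy as np

def f(C: np.ndarray | None = None) -> None: ...
```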
```diff
@@ -169,8 +153,8 @@ def compare_knn(

     uq, uq_counts = np.unique(labels[indices], return_counts=True)
     uq_counts_norm = uq_counts / uq_counts.sum()
-    counts = dict(zip(label_groups, [0.0] * len(label_groups)))
-    for group, count_norm in zip(uq, uq_counts_norm):
+    counts = dict(zip(label_groups, [0.0] * len(label_groups), strict=False))
+    for group, count_norm in zip(uq, uq_counts_norm, strict=False):
         counts[group] = count_norm

     return counts
```
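Here `strict=False` only spells out `zip`'s default truncating behavior explicitly, which is what linters such as Ruff's B905 rule ask for on Python 3.10+; for example:

```python
labels = ["a", "b", "c"]
fracs = [0.5, 0.5]

# strict=False (the default behavior): iteration stops at the shorter input.
print(dict(zip(labels, fracs, strict=False)))  # {'a': 0.5, 'b': 0.5}
# strict=True would raise ValueError on the length mismatch instead.
```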
```diff
@@ -192,7 +176,7 @@ def compare_dist(
         pert: Real perturbed data.
         pred: Simulated perturbed data.
         ctrl: Control data
-        kind: Kind of metric to use.
+        mode: Mode to use.
     """

     metric_fct = partial(Distance(metric).metric_fct, **metric_kwds)
@@ -201,13 +185,11 @@
     elif mode == "scaled":
         from sklearn.preprocessing import MinMaxScaler

-        scaler = MinMaxScaler().fit(
-            np.vstack((pert, ctrl)) if _fit_to_pert_and_ctrl else ctrl
-        )
+        scaler = MinMaxScaler().fit(np.vstack((pert, ctrl)) if _fit_to_pert_and_ctrl else ctrl)
         pred = scaler.transform(pred)
         pert = scaler.transform(pert)
     else:
-        raise ValueError(f"Unknown mode {mod}. Please choose simple or scaled.")
+        raise ValueError(f"Unknown mode {mode}. Please choose simple or scaled.")

     d1 = metric_fct(pert, pred)
     d2 = metric_fct(ctrl, pred)
```
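A usage sketch for `compare_dist`, mirroring the new tests: "simple" compares the raw arrays, "scaled" first min-max scales them via scikit-learn's `MinMaxScaler`, and any other mode raises `ValueError` (toy arrays below):

```python
import numpy as np
from pertpy.tools._metrics_3g import compare_dist

rng = np.random.default_rng(0)
pert, pred, ctrl = (rng.normal(size=(100, 10)) for _ in range(3))

print(compare_dist(pert, pred, ctrl, mode="simple"))  # distances on raw data
print(compare_dist(pert, pred, ctrl, mode="scaled"))  # distances after scaling
```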
65 changes: 25 additions & 40 deletions tests/tools/_distances/test_distances.py
```diff
@@ -78,56 +78,50 @@ def distance_obj(self, request):
         Distance = pt.tl.Distance(distance, obsm_key="X_pca")
         return Distance

-    @mark.parametrize("distance", actual_distances + semi_distances + non_distances)
-    def test_distance(self, adata, distance):
-        Distance = pt.tl.Distance(distance, obsm_key="X_pca")
-        df = Distance.pairwise(adata, groupby="perturbation", show_progressbar=True)
-
-        assert isinstance(df, DataFrame)
+    @fixture
+    @mark.parametrize("distance", all_distances)
+    def pairwise_distance(self, adata, distance_obj, distance):
+        return distance_obj.pairwise(adata, groupby="perturbation", show_progressbar=True)

     @mark.parametrize("distance", actual_distances + semi_distances)
-    def test_distance_axioms(self, adata, distance):
+    def test_distance_axioms(self, pairwise_distance, distance):
         # This is equivalent to testing for a semimetric, defined as fulfilling all axioms except triangle inequality.
-        Distance = pt.tl.Distance(distance, obsm_key="X_pca")
-        df = Distance.pairwise(adata, groupby="perturbation", show_progressbar=True)
-
         # (M1) Definiteness
-        assert all(np.diag(df.values) == 0)  # distance to self is 0
+        assert all(np.diag(pairwise_distance.values) == 0)  # distance to self is 0

         # (M2) Positivity
-        assert len(df) == np.sum(df.values == 0)  # distance to other is not 0 (TODO)
-        assert all(df.values.flatten() >= 0)  # distance is non-negative
+        assert len(pairwise_distance) == np.sum(pairwise_distance.values == 0)  # distance to other is not 0
+        assert all(pairwise_distance.values.flatten() >= 0)  # distance is non-negative

         # (M3) Symmetry
-        assert np.sum(df.values - df.values.T) == 0
+        assert np.sum(pairwise_distance.values - pairwise_distance.values.T) == 0

     @mark.parametrize("distance", actual_distances)
-    def test_triangle_inequality(self, adata, distance):
+    def test_triangle_inequality(self, pairwise_distance, distance):
         # Test if distances are well-defined in accordance with metric axioms
-        Distance = pt.tl.Distance(distance, obsm_key="X_pca")
-        df = Distance.pairwise(adata, groupby="perturbation", show_progressbar=True)
-
         # (M4) Triangle inequality (we just probe this for a few random triplets)
         for _i in range(10):
             rng = np.random.default_rng()
-            triplet = rng.choice(df.index, size=3, replace=False)
-            assert df.loc[triplet[0], triplet[1]] + df.loc[triplet[1], triplet[2]] >= df.loc[triplet[0], triplet[2]]
+            triplet = rng.choice(pairwise_distance.index, size=3, replace=False)
+            assert (
+                pairwise_distance.loc[triplet[0], triplet[1]] + pairwise_distance.loc[triplet[1], triplet[2]]
+                >= pairwise_distance.loc[triplet[0], triplet[2]]
+            )

     @mark.parametrize("distance", all_distances)
-    def test_distance_layers(self, adata, distance_obj, distance):
-        df = distance_obj.pairwise(adata, groupby="perturbation")
-
-        assert isinstance(df, DataFrame)
-        assert df.columns.equals(df.index)
-        assert np.sum(df.values - df.values.T) == 0  # symmetry
+    def test_distance_layers(self, pairwise_distance, distance):
+        assert isinstance(pairwise_distance, DataFrame)
+        assert pairwise_distance.columns.equals(pairwise_distance.index)
+        assert np.sum(pairwise_distance.values - pairwise_distance.values.T) == 0  # symmetry

     @mark.parametrize("distance", actual_distances + pseudo_counts_distances)
     def test_distance_counts(self, adata, distance):
-        Distance = pt.tl.Distance(distance, layer_key="counts")
-        df = Distance.pairwise(adata, groupby="perturbation")
-        assert isinstance(df, DataFrame)
-        assert df.columns.equals(df.index)
-        assert np.sum(df.values - df.values.T) == 0
+        if distance != "mahalanobis":  # doesn't work, covariance matrix is a singular matrix, not invertible
+            Distance = pt.tl.Distance(distance, layer_key="counts")
+            df = Distance.pairwise(adata, groupby="perturbation")
+            assert isinstance(df, DataFrame)
+            assert df.columns.equals(df.index)
+            assert np.sum(df.values - df.values.T) == 0

     @mark.parametrize("distance", all_distances)
     def test_mutually_exclusive_keys(self, distance):
```
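The per-test `pt.tl.Distance(...).pairwise(...)` calls are folded into the shared `pairwise_distance` fixture above. A standalone sketch of the pattern, with a toy computation standing in for `Distance.pairwise` (with the default function scope the matrix is still built once per test, so the gain is deduplication rather than caching):

```python
import numpy as np
import pytest
from scipy.spatial.distance import cdist

@pytest.fixture
def pairwise(distance):  # "distance" is supplied by each test's parametrize mark
    rng = np.random.default_rng(0)
    pts = rng.normal(size=(10, 3))
    return cdist(pts, pts, metric=distance)  # toy stand-in for Distance.pairwise

@pytest.mark.parametrize("distance", ["euclidean", "cityblock"])
def test_symmetry(pairwise, distance):
    assert np.allclose(pairwise, pairwise.T)
```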
```diff
@@ -144,15 +138,6 @@ def test_distance_output_type(self, distance):
         d = Distance(X, Y)
         assert isinstance(d, float)

-    @mark.parametrize("distance", all_distances)
-    def test_distance_pairwise(self, adata, distance_obj, distance):
-        # Test consistency of pairwise distance results
-        df = distance_obj.pairwise(adata, groupby="perturbation")
-
-        assert isinstance(df, DataFrame)
-        assert df.columns.equals(df.index)
-        assert np.sum(df.values - df.values.T) == 0  # symmetry
-
     @mark.parametrize("distance", all_distances + onesided_only)
     def test_distance_onesided(self, adata, distance_obj, distance):
         # Test consistency of one-sided distance results
```
55 changes: 55 additions & 0 deletions tests/tools/test_metrics_3g.py
```diff
@@ -0,0 +1,55 @@
+import pertpy as pt
+import numpy as np
+import pytest
+from pertpy.tools._metrics_3g import (
+    compare_de,
+    compare_class,
+    compare_dist,
+    compare_knn,
+)
+
+
+class TestMetrics3G:
+    @pytest.fixture
+    def test_data(self):
+        rng = np.random.default_rng()
+        X = rng.normal(size=(100, 10))
+        Y = rng.normal(size=(100, 10))
+        C = rng.normal(size=(100, 10))
+        return X, Y, C
+
+    def test_compare_de(self, test_data):
+        X, Y, C = test_data
+        result = compare_de(X, Y, C, shared_top=5)
+        assert isinstance(result, dict)
+        required_keys = {
+            "shared_top_genes",
+            "scores_corr",
+            "pvals_adj_corr",
+            "scores_ranks_corr",
+        }
+        assert all(key in result for key in required_keys)
+
+    def test_compare_class(self, test_data):
+        X, Y, C = test_data
+        result = compare_class(X, Y, C)
+        assert result <= 1
+
+    def test_compare_knn(self, test_data):
+        X, Y, C = test_data
+        result = compare_knn(X, Y, C)
+        assert isinstance(result, dict)
+        assert "comp" in result
+        assert isinstance(result["comp"], float)
+
+        result_no_ctrl = compare_knn(X, Y)
+        assert isinstance(result_no_ctrl, dict)
+
+    def test_compare_dist(self, test_data):
+        X, Y, C = test_data
+        res_simple = compare_dist(X, Y, C, mode="simple")
+        assert isinstance(res_simple, float)
+        res_scaled = compare_dist(X, Y, C, mode="scaled")
+        assert isinstance(res_scaled, float)
+        with pytest.raises(ValueError):
+            compare_dist(X, Y, C, mode="new_mode")
```