Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perturbation space - add/subtract #328

Merged
merged 5 commits into from
Aug 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pertpy/tools/_coda/_base_coda.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,7 @@ def credible_effects(self, data: AnnData | MuData, modality_key: str = "coda", e
model_type = sample_adata.uns["scCODA_params"]["model_type"]

# If other than None for est_fdr is specified, recalculate intercept and effect DataFrames
if type(est_fdr) == float:
if isinstance(est_fdr, float):
if est_fdr < 0 or est_fdr > 1:
raise ValueError("est_fdr must be between 0 and 1!")
else:
Expand Down Expand Up @@ -1353,7 +1353,7 @@ def from_scanpy(
Returns:
AnnData: A data set with cells aggregated to the (sample x cell type) level
"""
if type(sample_identifier) == str:
if isinstance(sample_identifier, str):
sample_identifier = [sample_identifier]

if covariate_obs:
Expand All @@ -1362,7 +1362,7 @@ def from_scanpy(
covariate_obs = sample_identifier # type: ignore

# join sample identifiers
if type(sample_identifier) == list:
if isinstance(sample_identifier, list):
adata.obs["scCODA_sample_id"] = adata.obs[sample_identifier].agg("-".join, axis=1)
sample_identifier = "scCODA_sample_id"

Expand Down
4 changes: 2 additions & 2 deletions pertpy/tools/_coda/_tasccoda.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def prepare(
raise ValueError("Please specify the key in .uns that contains the tree structure!")

# toytree tree - only for legacy reasons, can be removed in the final version
if type(adata.uns[tree_key]) == tt.tree:
if isinstance(adata.uns[tree_key], tt.tree):
# Collapse singularities in the tree
phy_tree = collapse_singularities(adata.uns[tree_key])

Expand Down Expand Up @@ -232,7 +232,7 @@ def prepare(
pen_args["node_leaves"] = np.delete(np.array(node_leaves[:-1]), refs)

# ete tree
elif type(adata.uns[tree_key]) == ete.Tree:
elif isinstance(adata.uns[tree_key], ete.Tree):
# Collapse singularities in the tree
phy_tree = collapse_singularities_2(adata.uns[tree_key])

Expand Down
248 changes: 237 additions & 11 deletions pertpy/tools/_perturbation_space/_perturbation_space.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from abc import ABC, abstractmethod
from __future__ import annotations

from typing import Iterable

import numpy as np
import pandas as pd
from anndata import AnnData
from rich import print


class PerturbationSpace:
Expand All @@ -12,6 +16,9 @@ class PerturbationSpace:
whereas in a perturbation space, data points summarize whole perturbations.
"""

def __init__(self):
    # Flag flipped to True by compute_control_diff; lets subclasses/users
    # check whether the space has already been centered on the control.
    self.control_diff_computed = False

def compute_control_diff( # type: ignore
self,
adata: AnnData,
Expand All @@ -21,6 +28,7 @@ def compute_control_diff( # type: ignore
new_layer_key: str = "control_diff",
embedding_key: str = None,
new_embedding_key: str = "control_diff",
all_data: bool = False,
copy: bool = False,
):
"""Subtract mean of the control from the perturbation.
Expand All @@ -33,33 +41,251 @@ def compute_control_diff( # type: ignore
new_layer_key: the results are stored in the given layer. Defaults to 'differential diff'.
embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
new_embedding_key: Results are stored in a new embedding in `obsm` with this key. Defaults to 'control diff'.
all_data: if True, do the computation in all data representations (X, all layers and all embeddings)
copy: If True returns a new Anndata of same size with the new column; otherwise it updates the initial AnnData object.
"""
if reference_key not in adata.obs[target_col].unique():
raise ValueError(
f"Reference key {reference_key} not found in {target_col}. {reference_key} must be in obs column {target_col}."
)

if embedding_key is not None and embedding_key not in adata.obsm_keys():
raise ValueError(
f"Reference key {reference_key} not found in {target_col}. {reference_key} must be in obs column {target_col}."
)

if layer_key is not None and layer_key not in adata.layers.keys():
raise ValueError(f"Layer {layer_key!r} does not exist in the anndata.")

if copy:
adata = adata.copy()

control_mask = adata.obs[target_col] == reference_key
num_control = control_mask.sum()

if layer_key:
diff_matrix = adata.layers[layer_key] - np.mean(adata.layers[layer_key][~control_mask, :], axis=0)
adata[new_layer_key] = diff_matrix
if num_control == 1:
control_expression = adata.layers[layer_key][control_mask, :]
else:
control_expression = np.mean(adata.layers[layer_key][control_mask, :], axis=0)
diff_matrix = adata.layers[layer_key] - control_expression
adata.layers[new_layer_key] = diff_matrix

elif embedding_key:
diff_matrix = adata.obsm[embedding_key] - np.mean(adata.obsm[embedding_key][~control_mask, :], axis=0)
if embedding_key:
if num_control == 1:
control_expression = adata.obsm[embedding_key][control_mask, :]
else:
control_expression = np.mean(adata.obsm[embedding_key][control_mask, :], axis=0)
diff_matrix = adata.obsm[embedding_key] - control_expression
adata.obsm[new_embedding_key] = diff_matrix
else:
diff_matrix = adata.X - np.mean(adata.X[~control_mask, :], axis=0)

if (not layer_key and not embedding_key) or all_data:
if num_control == 1:
control_expression = adata.X[control_mask, :]
else:
control_expression = np.mean(adata.X[control_mask, :], axis=0)
diff_matrix = adata.X - control_expression
adata.X = diff_matrix

if all_data:
layers_keys = list(adata.layers.keys())
for local_layer_key in layers_keys:
if local_layer_key != layer_key and local_layer_key != new_layer_key:
diff_matrix = adata.layers[local_layer_key] - np.mean(
adata.layers[local_layer_key][control_mask, :], axis=0
)
adata.layers[local_layer_key + "_control_diff"] = diff_matrix

embedding_keys = list(adata.obsm_keys())
for local_embedding_key in embedding_keys:
if local_embedding_key != embedding_key and local_embedding_key != new_embedding_key:
diff_matrix = adata.obsm[local_embedding_key] - np.mean(
adata.obsm[local_embedding_key][control_mask, :], axis=0
)
adata.obsm[local_embedding_key + "_control_diff"] = diff_matrix

self.control_diff_computed = True

return adata

def add(self):
raise NotImplementedError
def add(
    self,
    adata: AnnData,
    perturbations: Iterable[str],
    reference_key: str = "control",
    ensure_consistency: bool = False,
):
    """Add perturbations linearly. Assumes input of size n_perts x dimensionality.

    Args:
        adata: Anndata object of size n_perts x dim, one observation per perturbation.
        perturbations: Obs names of the perturbations to add.
        reference_key: Perturbation source from which the summation starts. Defaults to 'control'.
        ensure_consistency: If True, runs compute_control_diff on all data representations
            first so the linear space is consistent (perturbation - perturbation == control).

    Returns:
        A new AnnData with the summed perturbation appended as the last observation
        (named "pert1+pert2+...", obs fields filled with NaN). If ``ensure_consistency``
        is True, returns a tuple of (new AnnData, control-diff AnnData).

    Raises:
        ValueError: If any name in ``perturbations`` is missing from ``adata.obs_names``.
    """
    new_pert_name = ""
    for perturbation in perturbations:
        if perturbation not in adata.obs_names:
            # BUG FIX: report the offending perturbation (was reference_key).
            raise ValueError(
                f"Perturbation {perturbation} not found in adata.obs_names. {perturbation} must be in adata.obs_names."
            )
        new_pert_name += perturbation + "+"

    if not ensure_consistency:
        print(
            "[bold yellow]Operation might be done in non-consistent space (perturbation - perturbation != control). \n"
            "Subtract control perturbation needed for consistency of space in all data representations. \n"
            "Run with ensure_consistency=True"
        )
    else:
        adata = self.compute_control_diff(adata, copy=True, all_data=True)

    # BUG FIX: initialize the sub-dicts ONCE, outside the loops; the original
    # reset them on every iteration, so only the last layer/embedding survived.
    data: dict[str, dict[str, np.ndarray]] = {"layers": {}, "embeddings": {}}

    for local_layer_key in adata.layers.keys():
        summed = adata[reference_key].layers[local_layer_key].copy()
        for perturbation in perturbations:
            summed += adata[perturbation].layers[local_layer_key]
        original_data = adata.layers[local_layer_key].copy()
        data["layers"][local_layer_key] = np.concatenate((original_data, summed))

    for local_embedding_key in adata.obsm_keys():
        summed = adata[reference_key].obsm[local_embedding_key].copy()
        for perturbation in perturbations:
            summed += adata[perturbation].obsm[local_embedding_key]
        original_data = adata.obsm[local_embedding_key].copy()
        data["embeddings"][local_embedding_key] = np.concatenate((original_data, summed))

    # Operate in X
    control = adata[reference_key].X.copy()
    for perturbation in perturbations:
        control += adata[perturbation].X

    # One NaN per obs column for the appended observation's metadata.
    new_pert_obs = [np.nan for _ in adata.obs]

    original_data = adata.X.copy()
    new_data = np.concatenate((original_data, control))
    new_perturbation = AnnData(X=new_data)

    original_obs_names = adata.obs_names
    new_obs_names = original_obs_names.append(pd.Index([new_pert_name[:-1]]))
    new_perturbation.obs_names = new_obs_names

    new_obs = adata.obs.copy()
    new_obs.loc[new_pert_name[:-1]] = new_pert_obs
    new_perturbation.obs = new_obs

    # Strip the suffix added by compute_control_diff so names stay stable.
    # BUG FIX: str has removesuffix (3.9+), not remove_suffix; also compute
    # key_name per key (the original set it once, outside the embeddings loop).
    for key, values in data["layers"].items():
        new_perturbation.layers[key.removesuffix("_control_diff")] = values

    for key, values in data["embeddings"].items():
        new_perturbation.obsm[key.removesuffix("_control_diff")] = values

    if ensure_consistency:
        return new_perturbation, adata

    return new_perturbation

def subtract(
    self,
    adata: AnnData,
    perturbations: Iterable[str],
    reference_key: str = "control",
    ensure_consistency: bool = False,
):
    """Subtract perturbations linearly. Assumes input of size n_perts x dimensionality.

    Args:
        adata: Anndata object of size n_perts x dim, one observation per perturbation.
        perturbations: Obs names of the perturbations to subtract.
        reference_key: Perturbation source from which the subtraction starts. Defaults to 'control'.
        ensure_consistency: If True, runs compute_control_diff on all data representations
            first so the linear space is consistent (perturbation - perturbation == control).

    Returns:
        A new AnnData with the subtracted perturbation appended as the last observation
        (named "reference-pert1-pert2-...", obs fields filled with NaN). If
        ``ensure_consistency`` is True, returns a tuple of (new AnnData, control-diff AnnData).

    Raises:
        ValueError: If any name in ``perturbations`` is missing from ``adata.obs_names``.
    """
    new_pert_name = reference_key + "-"
    for perturbation in perturbations:
        if perturbation not in adata.obs_names:
            # BUG FIX: report the offending perturbation (was reference_key).
            raise ValueError(
                f"Perturbation {perturbation} not found in adata.obs_names. {perturbation} must be in adata.obs_names."
            )
        new_pert_name += perturbation + "-"

    if not ensure_consistency:
        print(
            "[bold yellow]Operation might be done in non-consistent space (perturbation - perturbation != control).\n"
            "Subtract control perturbation needed for consistency of space in all data representations.\n"
            "Run with ensure_consistency=True"
        )
    else:
        adata = self.compute_control_diff(adata, copy=True, all_data=True)

    # BUG FIX: initialize the sub-dicts ONCE, outside the loops; the original
    # reset them on every iteration, so only the last layer/embedding survived.
    data: dict[str, dict[str, np.ndarray]] = {"layers": {}, "embeddings": {}}

    for local_layer_key in adata.layers.keys():
        remainder = adata[reference_key].layers[local_layer_key].copy()
        for perturbation in perturbations:
            remainder -= adata[perturbation].layers[local_layer_key]
        original_data = adata.layers[local_layer_key].copy()
        data["layers"][local_layer_key] = np.concatenate((original_data, remainder))

    for local_embedding_key in adata.obsm_keys():
        remainder = adata[reference_key].obsm[local_embedding_key].copy()
        for perturbation in perturbations:
            remainder -= adata[perturbation].obsm[local_embedding_key]
        original_data = adata.obsm[local_embedding_key].copy()
        data["embeddings"][local_embedding_key] = np.concatenate((original_data, remainder))

    # Operate in X
    control = adata[reference_key].X.copy()
    for perturbation in perturbations:
        control -= adata[perturbation].X

    # One NaN per obs column for the appended observation's metadata.
    new_pert_obs = [np.nan for _ in adata.obs]

    original_data = adata.X.copy()
    new_data = np.concatenate((original_data, control))
    new_perturbation = AnnData(X=new_data)

    original_obs_names = adata.obs_names
    new_obs_names = original_obs_names.append(pd.Index([new_pert_name[:-1]]))
    new_perturbation.obs_names = new_obs_names

    new_obs = adata.obs.copy()
    new_obs.loc[new_pert_name[:-1]] = new_pert_obs
    new_perturbation.obs = new_obs

    # Strip the suffix added by compute_control_diff so names stay stable.
    # BUG FIX: str has removesuffix (3.9+), not remove_suffix; also compute
    # key_name per key (the original set it once, outside the embeddings loop).
    for key, values in data["layers"].items():
        new_perturbation.layers[key.removesuffix("_control_diff")] = values

    for key, values in data["embeddings"].items():
        new_perturbation.obsm[key.removesuffix("_control_diff")] = values

    if ensure_consistency:
        return new_perturbation, adata

    return new_perturbation
2 changes: 1 addition & 1 deletion pertpy/tools/_perturbation_space/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def compute( # type: ignore
self.X = adata.obsm[embedding_key]

elif layer_key is not None:
if layer_key not in adata.obsm_keys():
if layer_key not in adata.layers.keys():
raise ValueError(f"Layer {layer_key!r} does not exist in the anndata.")
else:
self.X = adata.layers[layer_key]
Expand Down
2 changes: 1 addition & 1 deletion tests/tools/_distances/test_distance_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_distancetest(self, adata, distance):
tab = etest(adata, groupby="perturbation", contrast="control")
# Well-defined output
assert tab.shape[1] == 5
assert type(tab) == DataFrame
assert isinstance(tab, DataFrame)
# p-values are in [0,1]
assert tab["pvalue"].min() >= 0
assert tab["pvalue"].max() <= 1
Expand Down
Loading