Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Simplify plotting #1109

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 45 additions & 13 deletions scanpy/get.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""This module contains helper functions for accessing data."""
from typing import Optional, Iterable, Tuple
import warnings

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -134,28 +135,44 @@ def obs_df(
gene_names = pd.Series(adata.var_names, index=adata.var_names)
lookup_keys = []
not_found = []
found_twice = []
for key in keys:
in_obs, in_var_index = False, False
if key in adata.obs.columns:
lookup_keys.append(key)
elif key in gene_names.index:
lookup_keys.append(gene_names[key])
else:
in_obs = True
if key in gene_names.index:
in_var_index = True
if not in_obs:
lookup_keys.append(gene_names[key])
# Test failure cases
if not (in_obs or in_var_index):
not_found.append(key)
if len(not_found) > 0:
elif in_obs and in_var_index:
found_twice.append(key)
if len(not_found) > 0 or len(found_twice) > 0:
if use_raw:
if gene_symbols is None:
gene_error = "`adata.raw.var_names`"
else:
gene_error = "gene_symbols column `adata.raw.var[{}].values`".format(gene_symbols)
gene_error = f"gene_symbols column `adata.raw.var['{gene_symbols}']`"
else:
if gene_symbols is None:
gene_error = "`adata.var_names`"
else:
gene_error = "gene_symbols column `adata.var[{}].values`".format(gene_symbols)
raise KeyError(
f"Could not find keys '{not_found}' in columns of `adata.obs` or in"
f" {gene_error}."
)
gene_error = f"gene_symbols column `adata.var['{gene_symbols}']`"
if len(found_twice) > 0:
warnings.warn(
f"Found keys {found_twice} in columns of `obs` and in `{gene_error}`. \n\n"
"This will be an error in a future version of scanpy, but interpreting"
" as a variable name for now.",
FutureWarning
)
else:
raise KeyError(
f"Could not find keys '{not_found}' in columns of `adata.obs` or in"
f" {gene_error}."
)

# Make df
df = pd.DataFrame(index=adata.obs_names)
Expand Down Expand Up @@ -205,13 +222,28 @@ def var_df(
# Argument handling
lookup_keys = []
not_found = []
found_twice = []
for key in keys:
in_var, in_obs_index = False, False
if key in adata.var.columns:
in_var = True
lookup_keys.append(key)
elif key in adata.obs_names:
lookup_keys.append(key)
else:
if key in adata.obs_names:
in_obs_index = True
if not in_var:
lookup_keys.append(key)
# Test failure cases
if not (in_var or in_obs_index):
not_found.append(key)
elif in_var and in_obs_index:
found_twice.append(key)
if len(found_twice) > 0:
warnings.warn(
f"Found keys {found_twice} in columns of `var` and in `adata.obs_names`. \n\n"
"This will be an error in a future version of scanpy, but interpreting"
" as a observation name for now.",
FutureWarning
)
if len(not_found) > 0:
raise KeyError(
f"Could not find keys '{not_found}' in columns of `adata.var` or"
Expand Down
126 changes: 24 additions & 102 deletions scanpy/plotting/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2828,115 +2828,37 @@ def correlation_matrix(


def _prepare_dataframe(
adata: AnnData,
var_names: Union[_VarNames, Mapping[str, _VarNames]],
groupby: Optional[str] = None,
use_raw: Optional[bool] = None,
log: bool = False,
num_categories: int = 7,
adata,
var_names,
groupby=None,
use_raw=None,
log=False,
num_categories=7,
gene_symbols=None,
layer=None,
gene_symbols: Optional[str] = None,
):
"""
Given the anndata object, prepares a data frame in which the row index are the categories
defined by group by and the columns correspond to var_names.

Parameters
----------
adata
Annotated data matrix.
var_names
`var_names` should be a valid subset of `adata.var_names`.
groupby
The key of the observation grouping to consider. It is expected that
groupby is a categorical. If groupby is not a categorical observation,
it would be subdivided into `num_categories`.
use_raw
Use `raw` attribute of `adata` if present.
log
Use the log of the values
num_categories
Only used if groupby observation is not categorical. This value
determines the number of groups into which the groupby observation
should be subdivided.
gene_symbols
Key for field in .var that stores gene symbols.

Returns
-------
Tuple of `pandas.DataFrame` and list of categories.
"""
from scipy.sparse import issparse

sanitize_anndata(adata)
# Backwards compat
if use_raw is None and adata.raw is not None:
use_raw = True
if isinstance(var_names, str):
var_names = [var_names]

# sanitize_anndata(adata)
tidy_df = get.obs_df(
adata,
var_names if groupby is None else [groupby] + list(var_names),
gene_symbols=gene_symbols,
layer=layer,
use_raw=use_raw,
)
if groupby is not None:
if groupby not in adata.obs_keys():
raise ValueError(
'groupby has to be a valid observation. '
f'Given {groupby}, valid observations: {adata.obs_keys()}'
)

if gene_symbols is not None and gene_symbols in adata.var.columns:
# translate gene_symbols to var_names
# slow method but gives a meaningful error if no gene symbol is found:
translated_var_names = []
for symbol in var_names:
if symbol not in adata.var[gene_symbols].values:
logg.error(
f"Gene symbol {symbol!r} not found in given "
f"gene_symbols column: {gene_symbols!r}"
)
return
translated_var_names.append(
adata.var[adata.var[gene_symbols] == symbol].index[0]
)
symbols = var_names
var_names = translated_var_names
if layer is not None:
if layer not in adata.layers.keys():
raise KeyError(
f'Selected layer: {layer} is not in the layers list. '
f'The list of valid layers is: {adata.layers.keys()}'
)
matrix = adata[:, var_names].layers[layer]
elif use_raw:
matrix = adata.raw[:, var_names].X
if pd.api.types.is_categorical(adata.obs[groupby]):
tidy_df.set_index(groupby, inplace=True)
else:
tidy_df.index = pd.cut(tidy_df[groupby], num_categories)
# tidy_df.drop(columns=groupby, inplace=True)
else:
matrix = adata[:, var_names].X

if issparse(matrix):
matrix = matrix.toarray()
tidydf.index = pd.Series(np.repeat("", len(tidy_df))).astype("category")
if log:
matrix = np.log1p(matrix)

obs_tidy = pd.DataFrame(matrix, columns=var_names)
if groupby is None:
groupby = ''
categorical = pd.Series(np.repeat('', len(obs_tidy))).astype('category')
else:
if not is_categorical_dtype(adata.obs[groupby]):
# if the groupby column is not categorical, turn it into one
# by subdividing into `num_categories` categories
categorical = pd.cut(adata.obs[groupby], num_categories)
else:
categorical = adata.obs[groupby]

obs_tidy.set_index(categorical, groupby, inplace=True)
if gene_symbols is not None:
# translate the column names to the symbol names
obs_tidy.rename(
columns=dict([(var_names[x], symbols[x]) for x in range(len(var_names))]),
inplace=True,
)
categories = obs_tidy.index.categories

return categories, obs_tidy

tidy_df = np.log1p(tidy_df)
return (tidy_df.index.categories, tidy_df)

def _plot_gene_groups_brackets(
gene_groups_ax: Axes,
Expand Down
56 changes: 56 additions & 0 deletions scanpy/tests/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,42 @@ def test_obs_df(adata):
assert all(badkey_err.match(k) for k in badkeys)


def test_obs_df_key_collision(adata):
    """Check that `obs_df` emits a FutureWarning on ambiguous keys.

    A key is ambiguous here when it names both a column of `adata.obs` and
    an entry of `adata.var_names` (or of the chosen `gene_symbols` column).
    """
    genes = ["gene1", "gene2"]
    symbols = ["genesymbol1", "genesymbol2"]
    pristine = adata.copy()

    # Create the collision: give `obs` one all-zero column per variable name.
    zero_columns = pd.DataFrame(
        np.zeros(adata.shape),
        columns=adata.var_names,
        index=adata.obs_names,
    )
    adata.obs = adata.obs.join(zero_columns)

    with pytest.warns(FutureWarning, match=r"gene1.*gene2.*obs.*adata\.var_names"):
        collided = sc.get.obs_df(adata, genes)
    # TODO: Make this true:
    # Until this throws an error, it should favor returning values from X
    # for backwards compat.
    expected = sc.get.obs_df(pristine, genes)
    pd.testing.assert_frame_equal(collided, expected, check_dtype=False)

    # Same scenario, but the colliding keys come from the gene_symbols column.
    symbol_columns = sc.get.obs_df(adata, symbols, gene_symbols="gene_symbols")
    adata.obs = adata.obs.join(symbol_columns)
    symbol_warning = r"genesymbol1.*genesymbol2.*obs.*adata\.var\['gene_symbols'\]"
    with pytest.warns(FutureWarning, match=symbol_warning):
        sc.get.obs_df(adata, symbols, gene_symbols="gene_symbols")


def test_var_df(adata):
adata.varm["eye"] = np.eye(2)
adata.varm["sparse"] = sparse.csr_matrix(np.eye(2))
Expand All @@ -69,6 +105,26 @@ def test_var_df(adata):
assert all(badkey_err.match(k) for k in badkeys)


def test_var_df_key_collision(adata):
    """Check that `var_df` emits a FutureWarning on ambiguous keys.

    A key is ambiguous here when it names both a column of `adata.var` and
    an entry of `adata.obs_names`.
    """
    cells = ["cell1", "cell2"]
    pristine = adata.copy()

    # Create the collision: give `var` one all-zero column per observation
    # name. The frame is (n_vars, n_obs), hence the reversed shape.
    zero_columns = pd.DataFrame(
        np.zeros(adata.shape[::-1]),
        columns=adata.obs_names,
        index=adata.var_names,
    )
    adata.var = adata.var.join(zero_columns)

    with pytest.warns(FutureWarning, match=r"cell1.*cell2.*var.*adata\.obs_names"):
        collided = sc.get.var_df(adata, cells)
    # TODO: Make this true
    expected = sc.get.var_df(pristine, cells)
    pd.testing.assert_frame_equal(collided, expected, check_dtype=False)


def test_rank_genes_groups_df():
a = np.zeros((20, 3))
a[:10, 0] = 5
Expand Down