Skip to content

Commit

Permalink
Merge pull request #19 from icbi-lab/fix-plotting-dense-matrix
Browse files Browse the repository at this point in the history
fix plotting dense matrix
  • Loading branch information
grst authored Sep 6, 2021
2 parents c1ed198 + 546d6b5 commit 7ee25b0
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 16 deletions.
32 changes: 17 additions & 15 deletions infercnvpy/pl/_chromosome_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from scanpy.plotting._utils import savefig_or_show
import pandas as pd
from scipy.sparse import issparse


def chromosome_heatmap(
Expand All @@ -17,7 +18,7 @@ def chromosome_heatmap(
figsize: Tuple[int, int] = (16, 10),
show: Optional[bool] = None,
save: Union[str, bool, None] = None,
**kwargs
**kwargs,
) -> Optional[Dict[str, matplotlib.axes.Axes]]:
"""
Plot a heatmap of smoothed gene expression by chromosome.
Expand Down Expand Up @@ -80,7 +81,7 @@ def chromosome_heatmap(
var_group_labels=list(chr_pos_dict.keys()),
norm=norm,
show=False,
**kwargs
**kwargs,
)

return_ax_dic["heatmap_ax"].vlines(
Expand All @@ -102,7 +103,7 @@ def chromosome_heatmap_summary(
figsize: Tuple[int, int] = (16, 10),
show: Optional[bool] = None,
save: Union[str, bool, None] = None,
**kwargs
**kwargs,
) -> Optional[Dict[str, matplotlib.axes.Axes]]:
"""
Plot a heatmap of average of the smoothed gene expression by chromosome per
Expand Down Expand Up @@ -142,25 +143,26 @@ def chromosome_heatmap_summary(
"'cnv_leiden' is not in `adata.obs`. Did you run `tl.leiden()`?"
)

# TODO this dirty hack repeats reach row 10 times, since scanpy
# TODO this dirty hack repeats each row 10 times, since scanpy
# heatmap cannot really handle it if there's just one observation
# per row. Scanpy matrixplot is not an option, since it plots each
# gene individually.
groups = adata.obs[groupby].unique()
tmp_obs = pd.DataFrame()
tmp_obs[groupby] = np.hstack([np.repeat(x, 10) for x in groups])

def _get_group_mean(group):
group_mean = np.mean(
adata.obsm[f"X_{use_rep}"][adata.obs[groupby] == group, :], axis=0
)
if len(group_mean.shape) == 1:
# derived from an array instead of sparse matrix -> 1 dim instead of 2
group_mean = group_mean[np.newaxis, :]
return group_mean

tmp_adata = sc.AnnData(
X=np.vstack(
[
np.repeat(
np.mean(
adata.obsm["X_cnv"][adata.obs[groupby] == group, :], axis=0
),
10,
axis=0,
)
for group in groups
]
[np.repeat(_get_group_mean(group), 10, axis=0) for group in groups]
),
obs=tmp_obs,
)
Expand All @@ -187,7 +189,7 @@ def chromosome_heatmap_summary(
var_group_labels=list(chr_pos_dict.keys()),
norm=norm,
show=False,
**kwargs
**kwargs,
)

return_ax_dic["heatmap_ax"].vlines(
Expand Down
16 changes: 16 additions & 0 deletions infercnvpy/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,29 @@

@pytest.fixture(params=[np.array, sp.csr_matrix, sp.csc_matrix])
def adata_oligodendroma(request):
"""Adata with raw counts in .X parametrized to be either sparse or dense."""
adata = cnv.datasets.oligodendroglioma()

adata.X = request.param(adata.X.toarray())

return adata


@pytest.fixture(params=[np.array, sp.csr_matrix, sp.csc_matrix])
def adata_infercnv(request):
"""Adata with infercnv computed and results stored in `.obsm["X_cnv"]`.
The matrix in obsm is parametrized to be either sparse or dense."""
adata = cnv.datasets.oligodendroglioma()
cnv.tl.infercnv(adata)
cnv.tl.pca(adata)
cnv.pp.neighbors(adata)
cnv.tl.leiden(adata)

adata.obsm["X_cnv"] = request.param(adata.obsm["X_cnv"].toarray())

return adata


@pytest.fixture(params=[np.array, sp.csr_matrix, sp.csc_matrix])
def adata_mock(request):
obs = pd.DataFrame().assign(cat=["foo", "foo", "bar", "baz", "bar"])
Expand Down
10 changes: 10 additions & 0 deletions infercnvpy/tests/test_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .fixtures import adata_infercnv
import infercnvpy as cnv


def test_plot_chromosome_heatmap(adata_infercnv):
cnv.pl.chromosome_heatmap(adata_infercnv, show=False)


def test_plot_chromosome_heatmap_summary(adata_infercnv):
cnv.pl.chromosome_heatmap_summary(adata_infercnv, show=False)
2 changes: 1 addition & 1 deletion infercnvpy/tl/_infercnv.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def infercnv(
Layer from adata to use. If `None`, use `X`.
key_added
Key under which the cnv matrix will be stored in adata if `inplace=True`.
Will store the matrix in `adata.obs["X_{key_added}"] and additional information
Will store the matrix in `adata.obsm["X_{key_added}"] and additional information
in `adata.uns[key_added]`.
Returns
Expand Down

0 comments on commit 7ee25b0

Please sign in to comment.