Skip to content

Commit

Permalink
update docstring and typehint
Browse files Browse the repository at this point in the history
  • Loading branch information
zktuong committed Jan 7, 2025
1 parent 32c4d3a commit d71eb86
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 158 deletions.
68 changes: 34 additions & 34 deletions src/ktplotspy/plot/plot_cpdb.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#!/usr/bin/env python
import numpy as np
import pandas as pd
import re
from typing import Literal

import numpy as np
import pandas as pd
from plotnine import (
aes,
element_blank,
Expand All @@ -24,16 +25,15 @@
theme,
theme_bw,
)
from typing import List, Literal, Optional, Union, Tuple, Dict

from ktplotspy.utils.settings import (
DEFAULT_V5_COL_START,
DEFAULT_COL_START,
DEFAULT_CELLSIGN_ALPHA,
DEFAULT_CLASS_COL,
DEFAULT_COL_START,
DEFAULT_COLUMNS,
DEFAULT_SEP,
DEFAULT_SPEC_PAT,
DEFAULT_CELLSIGN_ALPHA,
DEFAULT_COLUMNS,
DEFAULT_V5_COL_START,
)
from ktplotspy.utils.support import (
ensure_categorical,
Expand All @@ -54,36 +54,36 @@ def plot_cpdb(
means: pd.DataFrame,
pvals: pd.DataFrame,
celltype_key: str,
interaction_scores: Optional[pd.DataFrame] = None,
cellsign: Optional[pd.DataFrame] = None,
interaction_scores: pd.DataFrame | None = None,
cellsign: pd.DataFrame | None = None,
degs_analysis: bool = False,
splitby_key: Optional[str] = None,
splitby_key: str | None = None,
alpha: float = 0.05,
keep_significant_only: bool = True,
genes: Optional[Union[List[str], str]] = None,
gene_family: Optional[Union[List[str], Literal["chemokines", "th1", "th2", "th17", "treg", "costimulatory", "coinhibitory"]]] = None,
interacting_pairs: Optional[Union[List[str], str]] = None,
custom_gene_family: Optional[Dict[str, List[str]]] = None,
genes: list[str] | str | None = None,
gene_family: list[str] | Literal["chemokines", "th1", "th2", "th17", "treg", "costimulatory", "coinhibitory"] | None = None,
interacting_pairs: list[str] | str | None = None,
custom_gene_family: dict[str, list[str]] | None = None,
standard_scale: bool = True,
cluster_rows: bool = True,
cmap_name: str = "viridis",
max_size: int = 8,
max_highlight_size: int = 3,
default_style: bool = True,
highlight_col: str = "#d62728",
highlight_size: Optional[int] = None,
special_character_regex_pattern: Optional[str] = None,
exclude_interactions: Optional[Union[List[str], str]] = None,
highlight_size: int | None = None,
special_character_regex_pattern: str | None = None,
exclude_interactions: list[str] | str | None = None,
title: str = "",
return_table: bool = False,
figsize: Tuple[Union[int, float], Union[int, float]] = (6.4, 4.8),
figsize: tuple[int | float, int | float] = (6.4, 4.8),
min_interaction_score: int = 0,
scale_alpha_by_interaction_scores: bool = False,
scale_alpha_by_cellsign: bool = False,
filter_by_cellsign: bool = False,
keep_id_cp_interaction: bool = False,
result_precision: int = 3,
) -> Union[ggplot, pd.DataFrame]:
) -> ggplot | pd.DataFrame:
"""Plotting CellPhoneDB results as a dot plot.
Parameters
Expand All @@ -102,26 +102,26 @@ def plot_cpdb(
celltype_key : str
Column name in `adata.obs` storing the celltype annotations.
Values in this column should match the second column of the input `meta.txt` used for CellPhoneDB.
interaction_scores : Optional[pd.DataFrame], optional
interaction_scores : pd.DataFrame | None, optional
Data frame corresponding to `interaction_scores.txt` from CellPhoneDB version 5 onwards.
cellsign : Optional[pd.DataFrame], optional
cellsign : pd.DataFrame | None, optional
Data frame corresponding to `CellSign.txt` from CellPhoneDB version 5 onwards.
degs_analysis : bool, optional
Whether CellPhoneDB was run in `deg_analysis` mode.
splitby_key : Optional[str], optional
splitby_key : str | None, optional
If provided, will attempt to split the output plot/table by groups.
In order for this to work, the second column of the input `meta.txt` used for CellPhoneDB MUST be this format: {splitby}_{celltype}.
alpha : float, optional
P value threshold value for significance.
keep_significant_only : bool, optional
Whether or not to trim to significant (p<0.05) hits.
genes : Optional[Union[List[str], str]], optional
genes : list[str] | str | None, optional
If provided, will attempt to plot only interactions containing the specified gene(s).
gene_family : Optional[Union[List[str], Literal["chemokines", "th1", "th2", "th17", "treg", "costimulatory", "coinhibitory"]]], optional
gene_family : list[str] | Literal["chemokines", "th1", "th2", "th17", "treg", "costimulatory", "coinhibitory"] | None, optional
If provided, will attempt to plot a predetermined set of chemokines or genes associated with Th1, Th2, Th17, Treg, costimulatory or coinhibitory molecules.
interacting_pairs : Optional[Union[List[str], str]], optional
interacting_pairs : list[str] | str | None, optional
If provided, will attempt to plot only interactions containing the specified interacting pair(s). Ignores `genes` and `gene_family` if provided.
custom_gene_family : Optional[Dict[str, List[str]]], optional
custom_gene_family : dict[str, list[str]] | None, optional
If provided, will update the gene_family dictionary with this custom dictionary.
Both `gene_family` (name of the custom family) and `custom_gene_family` (dictionary holding this new family)
must be specified for this to work.
Expand All @@ -139,20 +139,20 @@ def plot_cpdb(
Whether or not to plot in default style or inspired from `squidpy`'s plotting style.
highlight_col : str, optional
Colour of highlights marking significant hits.
highlight_size : Optional[int], optional
highlight_size : int | None, optional
Size of highlights marking significant hits.
special_character_regex_pattern : Optional[str], optional
special_character_regex_pattern : str | None, optional
Regex string pattern to perform substitution.
This option should not really be used unless there are really REALLY special characters that you really REALLY want to keep.
Rather than using this option, the easiest way is to not name your celltypes with weird characters.
Just use alphanumeric characters and underscores if necessary.
exclude_interactions : Optional[Union[List, str]], optional
exclude_interactions : list[str] | str | None, optional
If provided, the interactions will be removed from the output.
title : str, optional
Plot title.
return_table : bool, optional
Whether or not to return the results as a dataframe.
figsize : Tuple[Union[int, float], Union[int, float]], optional
figsize : tuple[int | float, int | float], optional
Figure size.
min_interaction_score: int, optional
Filtering the interactions shown by including only those above the given interaction score.
Expand All @@ -166,17 +166,17 @@ def plot_cpdb(
Whether to keep the original `id_cp_interaction` value when plotting.
result_precision: int, optional
Number of decimal places used for p-values. Defaults to 3.
Returns
-------
Union[ggplot, pd.DataFrame]
ggplot | pd.DataFrame
Either a plotnine `ggplot` plot or a pandas `Data frame` holding the results.
Raises
------
KeyError
Raised if `genes` and `gene_family` are both provided, if an invalid key is given for `gene_family`, or if `interaction_scores` and `cellsign` are both provided.
"""

if special_character_regex_pattern is None:
special_character_regex_pattern = DEFAULT_SPEC_PAT
# prepare data
Expand Down Expand Up @@ -230,13 +230,13 @@ def plot_cpdb(
for gfg in query_group[gf.lower()]:
query.append(gfg)
else:
raise KeyError("gene_family needs to be one of the following: {}".format(query_group.keys()))
raise KeyError(f"gene_family needs to be one of the following: {query_group.keys()}")
query = list(set(query))
else:
if gene_family.lower() in query_group:
query = query_group[gene_family.lower()]
else:
raise KeyError("gene_family needs to be one of the following: {}".format(query_group.keys()))
raise KeyError(f"gene_family needs to be one of the following: {query_group.keys()}")
else:
query = [i for i in means_mat.interacting_pair if re.search("", i)]
elif genes is not None:
Expand Down
62 changes: 30 additions & 32 deletions src/ktplotspy/plot/plot_cpdb_chord.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
#!/usr/bin/env python
import re
import matplotlib.pyplot as plt
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from collections import defaultdict
from matplotlib.lines import Line2D
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D
from pycircos import Garc, Gcircle
from typing import Optional, Tuple, Dict, Union

from ktplotspy.utils.settings import DEFAULT_SEP # DEFAULT_PAL
from ktplotspy.utils.support import celltype_fraction, celltype_means, find_complex, flatten, generate_df, present
from ktplotspy.plot import plot_cpdb
from ktplotspy.utils.settings import DEFAULT_SEP # DEFAULT_PAL
from ktplotspy.utils.support import celltype_fraction, celltype_means, find_complex, flatten, generate_df


def plot_cpdb_chord(
Expand All @@ -22,20 +20,20 @@ def plot_cpdb_chord(
pvals: pd.DataFrame,
deconvoluted: pd.DataFrame,
celltype_key: str,
face_col_dict: Optional[Dict[str, str]] = None,
edge_col_dict: Optional[Dict[str, str]] = None,
face_col_dict: dict[str, str] | None = None,
edge_col_dict: dict[str, str] | None = None,
edge_cmap: LinearSegmentedColormap = plt.cm.nipy_spectral,
remove_self: bool = True,
gap: Union[int, float] = 2,
scale_lw: Union[int, float] = 10,
size: Union[int, float] = 50,
interspace: Union[int, float] = 2,
raxis_range: Tuple[int, int] = (950, 1000),
labelposition: Union[int, float] = 80,
gap: int | float = 2,
scale_lw: int | float = 10,
size: int | float = 50,
interspace: int | float = 2,
raxis_range: tuple[int, int] = (950, 1000),
labelposition: int | float = 80,
label_visible: bool = True,
figsize: Tuple[Union[int, float], Union[int, float]] = (8, 8),
legend_params: Dict = {"loc": "center left", "bbox_to_anchor": (1, 1), "frameon": False},
layer: Optional[str] = None,
figsize: tuple[int | float, int | float] = (8, 8),
legend_params: dict = {"loc": "center left", "bbox_to_anchor": (1, 1), "frameon": False},
layer: str | None = None,
**kwargs,
) -> Gcircle:
"""Plotting cellphonedb results as a chord diagram.
Expand All @@ -54,42 +52,42 @@ def plot_cpdb_chord(
celltype_key : str
Column name in `adata.obs` storing the celltype annotations.
Values in this column should match the second column of the input `meta.txt` used for `cellphonedb`.
face_col_dict : Optional[Dict[str, str]], optional
face_col_dict : dict[str, str] | None, optional
dictionary of celltype : face colours.
If not provided, will try and use `.uns` from `adata` if correct slot is present.
edge_col_dict : Optional[Dict[str, str]], optional
edge_col_dict : dict[str, str] | None, optional
Dictionary of interactions : edge colours. Otherwise, will use edge_cmap option.
edge_cmap : LinearSegmentedColormap, optional
a `LinearSegmentedColormap` to generate edge colors.
remove_self : bool, optional
whether to remove self edges.
gap : Union[int, float], optional
gap : int | float, optional
relative size of gaps between edges on arc.
scale_lw : Union[int, float], optional
scale_lw : int | float, optional
numeric value to scale width of lines.
size : Union[int, float], optional
size : int | float, optional
Width of the arc section. If record is provided, the value is
instead set by the sequence length of the record. In reality
the actual arc section width in the resultant circle is determined
by the ratio of size to the combined sum of the size and interspace
values of the Garc class objects in the Gcircle class object.
interspace : Union[int, float], optional
interspace : int | float, optional
Distance angle (deg) to the adjacent arc section in clockwise
sequence. The actual interspace size in the circle is determined by
the actual arc section width in the resultant circle is determined
by the ratio of size to the combined sum of the size and interspace
values of the Garc class objects in the Gcircle class object.
raxis_range : Tuple[int, int], optional
raxis_range : tuple[int, int], optional
Radial axis range where line plot is drawn.
labelposition : Union[int, float], optional
labelposition : int | float, optional
Relative label height from the center of the arc section.
label_visible : bool, optional
Whether the labels are shown. The default is True.
figsize : Tuple[Union[int, float], Union[int, float]], optional
figsize : tuple[int | float, int | float], optional
size of figure.
legend_params : Dict, optional
legend_params : dict, optional
additional arguments for `plt.legend`.
layer : Optional[str], optional
layer : str | None, optional
slot in `AnnData.layers` to access. If `None`, uses `.X`.
**kwargs
passed to `plot_cpdb`.
Expand Down Expand Up @@ -234,7 +232,7 @@ def plot_cpdb_chord(
celltype_end_dict = {r: k + gap for k, r in enumerate(celltypes)}
interactions = sorted(list(set(tmpdf["interaction_celltype"])))
interaction_start_dict = {r: k * gap for k, r in enumerate(interactions)}
interaction_end_dict = {r: k + gap for k, r in enumerate(interactions)}
# interaction_end_dict = {r: k + gap for k, r in enumerate(interactions)}
tmpdf["from"] = [celltype_start_dict[x] for x in tmpdf.producer]
tmpdf["to"] = [celltype_end_dict[x] for x in tmpdf.receiver]
tmpdf["interaction_value"] = [
Expand All @@ -258,7 +256,7 @@ def plot_cpdb_chord(
face_col_dict = dict(zip(adata.obs[celltype_key].cat.categories, adata.uns[celltype_key + "_colors"]))
else:
face_col_dict = dict(zip(list(set(adata.obs[celltype_key])), adata.uns[celltype_key + "_colors"]))
for i, j in tmpdf.iterrows():
for _, j in tmpdf.iterrows():
name = j["producer"]
if face_col_dict is None:
col = None
Expand All @@ -276,7 +274,7 @@ def plot_cpdb_chord(
)
circle.add_garc(arc)
circle.set_garcs(-180, 180)
for i, j in tmpdf.iterrows():
for _, j in tmpdf.iterrows():
if pd.notnull(j["interaction_value"]):
lr = j["converted_pair"]
start_size = j["start"] + j["interaction_value"] / scale_lw
Expand Down
17 changes: 8 additions & 9 deletions src/ktplotspy/plot/plot_cpdb_heatmap.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
#!/usr/bin/env python
from itertools import product

import numpy as np
import pandas as pd
import seaborn as sns

from itertools import product
from matplotlib.colors import ListedColormap
from typing import Optional, Union, Dict, List

from ktplotspy.utils.settings import DEFAULT_CLASS_COL, DEFAULT_COL_START, DEFAULT_CPDB_SEP, DEFAULT_V5_COL_START
from ktplotspy.utils.support import diverging_palette
from ktplotspy.utils.settings import DEFAULT_V5_COL_START, DEFAULT_COL_START, DEFAULT_CLASS_COL, DEFAULT_CPDB_SEP


def plot_cpdb_heatmap(
pvals: pd.DataFrame,
cell_types: Optional[List[str]] = None,
cell_types: list[str] | None = None,
degs_analysis: bool = False,
log1p_transform: bool = False,
alpha: float = 0.05,
Expand All @@ -23,21 +22,21 @@ def plot_cpdb_heatmap(
low_col: str = "#104e8b",
mid_col: str = "#ffdab9",
high_col: str = "#8b0a50",
cmap: Optional[Union[str, ListedColormap]] = None,
cmap: str | ListedColormap | None = None,
title: str = "",
return_tables: bool = False,
symmetrical: bool = True,
default_sep: str = DEFAULT_CPDB_SEP,
**kwargs,
) -> Union[sns.matrix.ClusterGrid, Dict]:
) -> sns.matrix.ClusterGrid | dict:
"""Plot cellphonedb results as total counts of interactions.
Parameters
----------
adata : AnnData
`AnnData` object with the `.obs` storing the `celltype_key`.
The `.obs_names` must match the first column of the input `meta.txt` used for `cellphonedb`.
cell_types : Optional[List[str]], optional
cell_types : list[str] | None, optional
List of cell types to include in the heatmap. If `None`, all cell types are included.
pvals : pd.DataFrame
Dataframe corresponding to `pvalues.txt` or `relevant_interactions.txt` from cellphonedb.
Expand Down Expand Up @@ -75,7 +74,7 @@ def plot_cpdb_heatmap(
Returns
-------
Union[sns.matrix.ClusterGrid, Dict]
sns.matrix.ClusterGrid | dict
Either heatmap of cellphonedb interactions or dataframe containing the interaction network.
"""
all_intr = pvals.copy()
Expand Down
Loading

0 comments on commit d71eb86

Please sign in to comment.