From 160d3d81935090504c521a2c47ce568f33516b8d Mon Sep 17 00:00:00 2001 From: Daniel Bolin Date: Wed, 13 Dec 2023 14:08:39 -0500 Subject: [PATCH] Code documentation --- containers/azimuth/context/main.py | 74 +++++++++++++-- containers/celltypist/context/main.py | 33 ++++++- containers/crosswalking/context/main.py | 77 ++++++++++++++-- containers/extract-summary/context/main.py | 39 +++++++- containers/gene-expression/context/main.py | 88 +++++++++++++++--- containers/popv/context/main.py | 101 +++++++++++++++++++-- src/algorithm/algorithm.py | 42 +++++++++ src/algorithm/arguments.py | 8 ++ src/algorithm/organ.py | 41 +++++++++ src/algorithm/report.py | 34 +++++++ steps/js/options-util.js | 18 ++++ 11 files changed, 519 insertions(+), 36 deletions(-) diff --git a/containers/azimuth/context/main.py b/containers/azimuth/context/main.py index d2b66a4..1100a33 100644 --- a/containers/azimuth/context/main.py +++ b/containers/azimuth/context/main.py @@ -20,6 +20,7 @@ def __init__(self): super().__init__(OrganLookup) def do_run(self, matrix: Path, organ: str, options: AzimuthOptions): + """Annotate data using azimuth.""" data = anndata.read_h5ad(matrix) reference_data = self.find_reference_data(organ, options["reference_data_dir"]) annotation_level = self.find_annotation_level( @@ -42,7 +43,15 @@ def do_run(self, matrix: Path, organ: str, options: AzimuthOptions): return data, annotation_level - def create_clean_matrix(self, matrix: anndata.AnnData): + def create_clean_matrix(self, matrix: anndata.AnnData) -> anndata.AnnData: + """Creates a copy of the data with all observation columns removed. + + Args: + matrix (anndata.AnnData): Original data + + Returns: + anndata.AnnData: Cleaned data + """ clean_obs = pandas.DataFrame(index=matrix.obs.index) clean_matrix = matrix.copy() clean_matrix.obs = clean_obs @@ -50,15 +59,43 @@ def create_clean_matrix(self, matrix: anndata.AnnData): def copy_annotations( self, matrix: anndata.AnnData, annotated_matrix: anndata.AnnData - ): + ) -> None: + """Copies annotations from one matrix to another. + + Args: + matrix (anndata.AnnData): Matrix to copy to + annotated_matrix (anndata.AnnData): Matrix to copy from + """ matrix.obs = matrix.obs.join(annotated_matrix.obs, rsuffix="_azimuth") - def run_azimuth_scripts(self, matrix_path: Path, reference_data: Path): + def run_azimuth_scripts(self, matrix_path: Path, reference_data: Path) -> str: + """Creates a subprocess running the Azimuth annotation R script. + + Args: + matrix_path (Path): Path to data file + reference_data (Path): Path to organ reference data directory + + Returns: + str: Path to the output data file + """ script_command = ["Rscript", "/run_azimuth.R", matrix_path, reference_data] subprocess.run(script_command, capture_output=True, check=True, text=True) return "./result.h5ad" - def find_reference_data(self, organ: str, dir: Path): + def find_reference_data(self, organ: str, dir: Path) -> Path: + """Finds the reference data directory for an organ. + + Args: + organ (str): Organ name + dir (Path): Directory to search + + Raises: + ValueError: If no reference data could be found + + Returns: + Path: The data directory + """ + def is_reference_data_candidate(path: Path): return path.is_dir() and organ.lower() in path.name.lower() @@ -71,14 +108,39 @@ def is_reference_data_candidate(path: Path): # idx.annoy and ref.Rds is always located inside an 'azimuth' subdirectory return subdir / "azimuth" - def find_annotation_level(self, organ: str, path: Path): + def find_annotation_level(self, organ: str, path: Path) -> str: + """Finds the column name which contains the predictions. + + Args: + organ (str): Organ name + path (Path): Path to file containing information about column names + + Returns: + str: Column name + """ with open(path) as file: levels_by_organ = json.load(file) return "predicted." + levels_by_organ[organ] def _find_in_dir( self, dir: Path, cond: t.Callable[[Path], bool], error_msg: str, warn_msg: str - ): + ) -> Path: + """Search a directory for a entry which passes the provided test. + + Args: + dir (Path): Directory to search + cond (t.Callable[[Path], bool]): Test used to match sub entries + error_msg (str): Error message used when no entries match + warn_msg (str): Warning message use when multiple entries match + + Raises: + ValueError: If there are no matching sub entries + + Returns: + Path: + The matching entry. + If multiple entries match the one with the shortest name is returned. + """ candidates = list(filter(cond, dir.iterdir())) candidates.sort(key=lambda path: len(path.name)) diff --git a/containers/celltypist/context/main.py b/containers/celltypist/context/main.py index 55dafe4..b72b4e9 100644 --- a/containers/celltypist/context/main.py +++ b/containers/celltypist/context/main.py @@ -19,10 +19,12 @@ def __init__(self, mapping_file: Path): super().__init__(mapping_file) def get_builtin_options(self): + """Get builtin celltypist models.""" models = celltypist.models.get_all_models() return map(lambda model: (model, self.from_raw(model)), models) def from_raw(self, id: str): + """Load a celltypist model.""" return celltypist.models.Model.load(id) @@ -31,6 +33,7 @@ def __init__(self): super().__init__(CelltypistOrganLookup, "predicted_labels") def do_run(self, matrix: Path, organ: celltypist.Model, options: CelltypistOptions): + """Annotate data using celltypist.""" data = scanpy.read_h5ad(matrix) data = self.normalize(data) data, var_names = self.normalize_var_names(data, options) @@ -39,6 +42,17 @@ def do_run(self, matrix: Path, organ: celltypist.Model, options: CelltypistOptio return data def normalize(self, data: scanpy.AnnData) -> scanpy.AnnData: + """Normalizes data according to celltypist requirements. + + Celltypist requires data to be log1p normalized with 10,000 counts per cell. + See https://github.com/Teichlab/celltypist for details. + + Args: + data (scanpy.AnnData): Original data to be normalized + + Returns: + scanpy.AnnData: Normalized data + """ primary_column = "feature_name" alternative_primary_column = "gene_symbol" if primary_column not in data.var.columns: @@ -58,6 +72,15 @@ def normalize(self, data: scanpy.AnnData) -> scanpy.AnnData: def normalize_var_names( self, data: scanpy.AnnData, options: CelltypistOptions ) -> t.Tuple[scanpy.AnnData, pandas.Index]: + """Normalizes variable names, replacing ensemble ids with the corresponding gene name. + + Args: + data (scanpy.AnnData): Data with potentially non-normalized names + options (CelltypistOptions): Options containing the ensemble id mapping file path + + Returns: + t.Tuple[scanpy.AnnData, pandas.Index]: The normalized data along with the original names + """ lookup = self.load_ensemble_lookup(options) names = data.var_names @@ -68,7 +91,15 @@ def getNewName(name: str): data.var_names = t.cast(t.Any, names.map(getNewName)) return data, names - def load_ensemble_lookup(self, options: CelltypistOptions): + def load_ensemble_lookup(self, options: CelltypistOptions) -> t.Dict[str, str]: + """Load a file mapping ensemble id to gene names. + + Args: + options (CelltypistOptions): Options with the mapping file path + + Returns: + t.Dict[str, str]: Loaded mapping + """ with open(options["ensemble_lookup"]) as file: reader = csv.DictReader(file) lookup: t.Dict[str, str] = {} diff --git a/containers/crosswalking/context/main.py b/containers/crosswalking/context/main.py index ae29664..cb22463 100644 --- a/containers/crosswalking/context/main.py +++ b/containers/crosswalking/context/main.py @@ -7,12 +7,28 @@ def filter_crosswalk_table(table: pd.DataFrame, *columns: str) -> pd.DataFrame: - """Filters the table to remove empty rows and keep only necessary columns""" + """Filter the crosswalk table to only include specified columns. + + Also removes empty rows and cast values to string. + + Args: + table (pd.DataFrame): Original full crosswalk table + + Returns: + pd.DataFrame: Filtered table + """ return table[list(columns)].dropna().astype(str).drop_duplicates() -def generate_iri(label: str): - """generate IRIs for labels not found in crosswalk tables""" +def generate_iri(label: str) -> str: + """Create a temporary IRI based on a label. + + Args: + label (str): Label for the row + + Returns: + str: Temporary IRI + """ suffix = label.lower().strip() suffix = re.sub(r"\W+", "-", suffix) suffix = re.sub(r"[^a-z0-9-]+", "", suffix) @@ -29,7 +45,21 @@ def crosswalk( table_clid_column: str, table_match_column: str, ) -> anndata.AnnData: - """Gives each cell a CL ID and Match type using crosswalk table""" + """Crosswalks the data adding CLIDs and match types using a crosswalk table. + + Args: + matrix (anndata.AnnData): Data to crosswalk + data_label_column (str): Column used to match against the table + data_clid_column (str): Column to store CLIDs in + data_match_column (str): Column to store match type in + table (pd.DataFrame): Crosswalk table + table_label_column (str): Column used to match against the data + table_clid_column (str): Column storing CLIDs + table_match_column (str): Column storing match type + + Returns: + anndata.AnnData: Crosswalked data with CLIDs and match type added + """ column_map = { table_clid_column: data_clid_column, table_match_column: data_match_column, @@ -54,16 +84,37 @@ def crosswalk( return result -def _set_default_clid(obs: pd.DataFrame, clid_column: str, label_column: str): +def _set_default_clid(obs: pd.DataFrame, clid_column: str, label_column: str) -> None: + """Adds default CLIDs to rows that did not match against the crosswalk table. + + Args: + obs (pd.DataFrame): Data rows + clid_column (str): Column to check and update with default CLIDs + label_column (str): Column used when generating default CLIDs + """ defaults = obs.apply(lambda row: generate_iri(row[label_column]), axis=1) obs.loc[obs[clid_column].isna(), clid_column] = defaults -def _set_default_match(obs: pd.DataFrame, column: str): +def _set_default_match(obs: pd.DataFrame, column: str) -> None: + """Adds default match type to rows that did not match against the crosswalk table. + + Args: + obs (pd.DataFrame): Data rows + column (str): Column to check and update with default match type + """ obs.loc[obs[column].isna(), column] = "skos:exactMatch" def _get_empty_table(args: argparse.Namespace) -> pd.DataFrame: + """Creates an empty crosswalk table. + + Args: + args (argparse.Namespace): Same arguments as provided to `main` + + Returns: + pd.DataFrame: An empty table + """ return pd.DataFrame( columns=[ args.crosswalk_table_label_column, @@ -74,12 +125,24 @@ def _get_empty_table(args: argparse.Namespace) -> pd.DataFrame: def main(args: argparse.Namespace): + """Crosswalks a h5ad file and saves the result to another h5ad file. + + Args: + args (argparse.Namespace): + CLI arguments, must contain "matrix", + "annotation_column", "clid_column", "match_column", + "crosswalk_table", "crosswalk_table_label_column", + "crosswalk_table_clid_column", "crosswalk_table_match_column", and + "output_matrix" + """ matrix = crosswalk( args.matrix, args.annotation_column, args.clid_column, args.match_column, - args.crosswalk_table if args.crosswalk_table is not None else _get_empty_table(args), + args.crosswalk_table + if args.crosswalk_table is not None + else _get_empty_table(args), args.crosswalk_table_label_column, args.crosswalk_table_clid_column, args.crosswalk_table_match_column, diff --git a/containers/extract-summary/context/main.py b/containers/extract-summary/context/main.py index 38cdb77..fec5b9c 100644 --- a/containers/extract-summary/context/main.py +++ b/containers/extract-summary/context/main.py @@ -1,5 +1,6 @@ import argparse import json +import typing as t import anndata import pandas as pd @@ -8,6 +9,15 @@ def get_unique_rows_with_counts( matrix: anndata.AnnData, clid_column: str ) -> pd.DataFrame: + """Computes unique CLIDs and the total count for each. + + Args: + matrix (anndata.AnnData): Data + clid_column (str): Column with CLIDs + + Returns: + pd.DataFrame: A frame with unique CLIDs and counts added + """ counts = matrix.obs.value_counts(clid_column).reset_index() counts.columns = [clid_column, "count"] obs_with_counts = matrix.obs.merge(counts, how="left") @@ -20,7 +30,19 @@ def unique_rows_to_summary_rows( label_column: str, gene_expr_column: str, counts_column="count", -): +) -> t.List[dict]: + """Converts a data frame with unique CLIDs rows into cell summary rows. + + Args: + unique (pd.DataFrame): Data with unique CLIDs + clid_column (str): Column with CLIDs + label_column (str): Column with labels + gene_expr_column (str): Column with gene expressions + counts_column (str, optional): Column with the total counts. Defaults to "count". + + Returns: + t.List[dict]: A cell summary for each row in the source data + """ columns = [clid_column, label_column, gene_expr_column, counts_column] df = unique[columns].rename( columns={ @@ -33,11 +55,24 @@ def unique_rows_to_summary_rows( df["@type"] = "CellSummaryRow" df["percentage"] = df["count"] / df["count"].sum() - df["gene_expr"] = df["gene_expr"].astype(object).apply(lambda x: [] if pd.isna(x) else json.loads(x)) + df["gene_expr"] = ( + df["gene_expr"] + .astype(object) + .apply(lambda x: [] if pd.isna(x) else json.loads(x)) + ) return df.to_dict("records") def main(args: argparse.Namespace): + """Extract and save a cell summary from annotated data. + + Args: + args (argparse.Namespace): + CLI arguments, must contain "matrix", "annotation_method", + "cell_id_column", "cell_label_column", "gene_expr_column", + "cell_source, "jsonld_context", "output", and + "annotations_output" + """ unique_rows = get_unique_rows_with_counts(args.matrix, args.cell_id_column) summary_rows = unique_rows_to_summary_rows( unique_rows, args.cell_id_column, args.cell_label_column, args.gene_expr_column diff --git a/containers/gene-expression/context/main.py b/containers/gene-expression/context/main.py index 8ab0181..80c5add 100644 --- a/containers/gene-expression/context/main.py +++ b/containers/gene-expression/context/main.py @@ -1,5 +1,6 @@ import argparse import json +import typing as t from pathlib import Path import anndata @@ -10,50 +11,101 @@ MIN_CELLS_PER_CT = 2 # need a filter to reomve CTs with one cell as sc.tl.rank_genes_groups gives an error -def filter_matrix(matrix: anndata.AnnData, clid_column: str): - """filters an anndata matrix to only include cells which are annotated to a cell type with more than MIN_CELLS_PER_CT""" +def filter_matrix(matrix: anndata.AnnData, clid_column: str) -> anndata.AnnData: + """Filters data to only include cells and cell types with at least `MIN_CELLS_PER_CT`. + + Args: + matrix (anndata.AnnData): Data to filter + clid_column (str): Cell type id column + + Returns: + anndata.AnnData: Filtered data + """ ct_counts = matrix.obs[clid_column].value_counts() valid_cts = ct_counts.index[ct_counts >= MIN_CELLS_PER_CT] mask = np.isin(matrix.obs[clid_column], valid_cts) return matrix[mask, :] -def format_marker_genes_df(df: pd.DataFrame, clid_column: str): - """formats the output from sc.tl.rank_genes_groups to a dataframe with celltype as one column and list of marker genes as other""" +def format_marker_genes_df(df: pd.DataFrame, clid_column: str) -> pd.DataFrame: + """Format the output of `scanpy.tl.rank_genes_groups` into a dataframe with + cell type and marker genes as columns. + + Args: + df (pd.DataFrame): Output from `scanpy.tl.rank_genes_groups` + clid_column (str): Cell type id column + + Returns: + pd.DataFrame: A data frame with `clid_column` and marker_genes columns + """ df = df.transpose() df["marker_genes"] = df.apply(lambda row: row.tolist(), axis=1) df = df["marker_genes"].rename_axis(clid_column).reset_index() return df -def get_mean_expr_value(matrix, clid_column, cell_type, gene): - """gets the mean expression value for a cell type and a gene""" +def get_mean_expr_value( + matrix: anndata.AnnData, clid_column: str, cell_type: str, gene: str +) -> float: + """Computes the mean expression for a cell type and gene. + + Args: + matrix (anndata.AnnData): Data + clid_column (str): Cell type id column + cell_type (str): Cell type name + gene (str): Gene name + + Returns: + float: The mean expression + """ cell_indices = [ matrix.obs.index.get_loc(cell_index) for cell_index in matrix.obs[matrix.obs[clid_column] == cell_type].index ] mean_expr = matrix.X[cell_indices, matrix.var.index.get_loc(gene)].mean() - return mean_expr + return float(mean_expr) + +def get_marker_genes_with_expr( + matrix: anndata.AnnData, clid_column: str, cell_type: str, marker_genes: str +) -> t.List[dict]: + """Get the mean expression for all marker genes. -def get_marker_genes_with_expr(matrix, clid_column, cell_type, marker_genes): - """gets the mean expression values for all marker genes""" + Args: + matrix (anndata.AnnData): Data + clid_column (str): Cell type id column + cell_type (str): Cell type name + gene (str): Gene name + + Returns: + t.List[dict]: Mean expressions, the dicts contains "gene_label" and "mean_gene_expr_value" + """ output = [] for gene in marker_genes: output.append( { "gene_label": gene, - "mean_gene_expr_value": float( - get_mean_expr_value(matrix, clid_column, cell_type, gene) + "mean_gene_expr_value": get_mean_expr_value( + matrix, clid_column, cell_type, gene ), } ) return output -def get_gene_expr(matrix: anndata.AnnData, clid_column: str, gene_expr_column: str): - """gets the marker genes and mean expression values for all cells in the anndata matrix""" +def get_gene_expr( + matrix: anndata.AnnData, clid_column: str, gene_expr_column: str +) -> anndata.AnnData: + """Computes and adds gene mean expressions for all cells in the annotated data. + + Args: + matrix (anndata.AnnData): Data + clid_column (str): Cell type id column + gene_expr_column (str): Column to store gene expression on + Returns: + anndata.AnnData: Updated data with gene expressions + """ matrix.raw = matrix # for getting gene names as output for sc.tl.rank_genes_groups filtered_matrix = filter_matrix(matrix, clid_column) sc.tl.rank_genes_groups(filtered_matrix, groupby=clid_column, n_genes=10) @@ -71,7 +123,7 @@ def get_gene_expr(matrix: anndata.AnnData, clid_column: str, gene_expr_column: s merged_obs = matrix.obs.merge( ct_marker_genes_df[[clid_column, gene_expr_column]], how="left" ) - merged_obs.fillna({gene_expr_column: '[]'}, inplace=True) + merged_obs.fillna({gene_expr_column: "[]"}, inplace=True) merged_obs[gene_expr_column] = merged_obs[gene_expr_column].apply(json.dumps) merged_obs.index = matrix.obs.index matrix.obs = merged_obs @@ -79,6 +131,14 @@ def get_gene_expr(matrix: anndata.AnnData, clid_column: str, gene_expr_column: s def main(args: argparse.Namespace): + """Computes gene mean expression for all cells in an annotated h5ad file and + saves the result to another h5ad file. + + Args: + args (argparse.Namespace): + CLI arguments, must contain "matrix", "clid_column", + "gene_expr_column", and "output_matrix" + """ matrix = get_gene_expr(args.matrix, args.clid_column, args.gene_expr_column) matrix.write_h5ad(args.output_matrix) diff --git a/containers/popv/context/main.py b/containers/popv/context/main.py index 1ce8431..6626f08 100644 --- a/containers/popv/context/main.py +++ b/containers/popv/context/main.py @@ -1,14 +1,13 @@ +import csv import typing as t from logging import warn from pathlib import Path +import anndata import numpy import popv import scanpy -import anndata import torch -import pandas -import csv from src.algorithm import Algorithm, OrganLookup, add_common_arguments @@ -33,6 +32,7 @@ def __init__(self): super().__init__(OrganLookup, "popv_prediction") def do_run(self, matrix: Path, organ: str, options: PopvOptions): + """Annotate data using popv.""" data = scanpy.read_h5ad(matrix) data = self.prepare_query(data, organ, options) popv.annotation.annotate_data( @@ -55,6 +55,16 @@ def do_run(self, matrix: Path, organ: str, options: PopvOptions): def prepare_query( self, data: scanpy.AnnData, organ: str, options: PopvOptions ) -> scanpy.AnnData: + """Prepares the data to be annotated by popv. + + Args: + data (scanpy.AnnData): Unprepared data + organ (str): Organ name + options (PopvOptions): Additional options + + Returns: + scanpy.AnnData: Prepared data + """ reference_data_path = self.find_reference_data( options["reference_data_dir"], organ ) @@ -96,6 +106,15 @@ def prepare_query( def get_n_samples_per_label( self, reference_data: scanpy.AnnData, options: PopvOptions ) -> int: + """Computes the number of samples by label in the reference data. + + Args: + reference_data (scanpy.AnnData): Reference data + options (PopvOptions): Additional options + + Returns: + int: Number of samples per label + """ ref_labels_key = options["ref_labels_key"] n_samples_per_label = options["samples_per_label"] if ref_labels_key in reference_data.obs.columns: @@ -104,6 +123,19 @@ def get_n_samples_per_label( return n_samples_per_label def find_reference_data(self, dir: Path, organ: str) -> Path: + """Finds the reference data directory for an organ. + + Args: + dir (Path): Directory to search + organ (str): Organ name + + Raises: + ValueError: If no reference data could be found + + Returns: + Path: The data directory + """ + def is_reference_data_candidate(path: Path): return ( path.is_file() @@ -119,6 +151,19 @@ def is_reference_data_candidate(path: Path): ) def find_model_dir(self, dir: Path, organ: str) -> Path: + """Find the model data directory for an organ. + + Args: + dir (Path): Directory to search + organ (str): Organ name + + Raises: + ValueError: If no model data could be found + + Returns: + Path: The data directory + """ + def is_model_candidate(path: Path): return path.is_dir() and organ.lower() in path.name.lower() @@ -131,7 +176,23 @@ def is_model_candidate(path: Path): def _find_in_dir( self, dir: Path, cond: t.Callable[[Path], bool], error_msg: str, warn_msg: str - ): + ) -> Path: + """Search a directory for a entry which passes the provided test. + + Args: + dir (Path): Directory to search + cond (t.Callable[[Path], bool]): Test used to match sub entries + error_msg (str): Error message used when no entries match + warn_msg (str): Warning message use when multiple entries match + + Raises: + ValueError: If there are no matching sub entries + + Returns: + Path: + The matching entry. + If multiple entries match the one with the shortest name is returned. + """ candidates = list(filter(cond, dir.iterdir())) candidates.sort(key=lambda path: len(path.name)) @@ -144,6 +205,15 @@ def _find_in_dir( def normalize_var_names( self, data: scanpy.AnnData, options: PopvOptions ) -> scanpy.AnnData: + """Normalizes variable names, replacing ensemble ids with the corresponding gene name. + + Args: + data (scanpy.AnnData): Data with potentially non-normalized names + options (PopvOptions): Options containing the ensemble id mapping file path + + Returns: + scanpy.AnnData: The normalized data + """ lookup = self.load_ensemble_lookup(options) names = data.var_names @@ -155,6 +225,14 @@ def getNewName(name: str): return data def load_ensemble_lookup(self, options: PopvOptions): + """Load a file mapping ensemble id to gene names. + + Args: + options (PopvOptions): Options with the mapping file path + + Returns: + t.Dict[str, str]: Loaded mapping + """ with open(options["ensemble_lookup"]) as file: reader = csv.DictReader(file) lookup: t.Dict[str, str] = {} @@ -166,9 +244,20 @@ def add_model_genes( self, data: scanpy.AnnData, model_path: Path, - query_layers_key: str, + query_layers_key: t.Optional[str], ) -> scanpy.AnnData: - """Adds genes from model not present in input data to input data. Needed for preprocessing bug""" + """Adds genes from model not present in input data to input data. + + Solves a preprocessing bug. + + Args: + data (scanpy.AnnData): Data to fix + model_path (Path): Model data directory + query_layers_key (str, optional): Data layer to fix + + Returns: + scanpy.AnnData: The fixed data + """ model_genes = torch.load( Path.joinpath(model_path, "scvi/model.pt"), map_location="cpu" )["var_names"] diff --git a/src/algorithm/algorithm.py b/src/algorithm/algorithm.py index 35e492a..ef30c39 100644 --- a/src/algorithm/algorithm.py +++ b/src/algorithm/algorithm.py @@ -14,6 +14,13 @@ class Algorithm(t.Generic[Organ, Options], abc.ABC): + """An annotation algorithm. + + Attributes: + organ_lookup (t.Callable[[Path], OrganLookup[Organ]]): Callable to create an organ lookup + prediction_column (t.Optional[str]): Column in annotated data with the predictions + """ + def __init__( self, organ_lookup: t.Callable[[Path], OrganLookup[Organ]], @@ -32,6 +39,19 @@ def run( output_report: Path, **options, ) -> AlgorithmReport: + """Runs the algorithm to annotate data. + + Args: + matrix (Path): Path to h5ad data file + organ (str): Raw organ identifier + organ_mapping (Path): Path to json file containing organ mapping information + output_matrix (Path): Path where the annotated h5ad file will be written + output_annotations (Path): Path where the annotation csv will be written + output_report (Path): Path where the algorithm report json will be written + + Returns: + AlgorithmReport: Report containing the status of the run + """ report = AlgorithmReport(output_matrix, output_annotations, output_report) try: lookup = self.organ_lookup(organ_mapping) @@ -45,9 +65,31 @@ def run( @abc.abstractmethod def do_run(self, matrix: Path, organ: Organ, options: Options) -> RunResult: + """Perform a annotation run. Must be overridden in subclasses. + + Args: + matrix (Path): Path to the h5ad data file + organ (Organ): Organ associated with the data + options (Options): Additional algorithm specific options + + Returns: + RunResult: + Annotated data either in-memory or a path to a h5ad. + Can also return a tuple where the first element is + the annotated data and the second element is the name + of the column that stores the predictions. + """ ... def __post_process_result(self, result: RunResult) -> anndata.AnnData: + """Normalize the result of a run. + + Args: + result (RunResult): Non-normalized result value + + Returns: + anndata.AnnData: Loaded h5ad data + """ prediction_column = self.prediction_column if isinstance(result, tuple): result, prediction_column = result diff --git a/src/algorithm/arguments.py b/src/algorithm/arguments.py index 50297a0..915a316 100644 --- a/src/algorithm/arguments.py +++ b/src/algorithm/arguments.py @@ -6,6 +6,14 @@ def add_common_arguments( parser: t.Optional[argparse.ArgumentParser] = None, ) -> argparse.ArgumentParser: + """Add arguments common to all algorithms. + + Args: + parser (argparse.ArgumentParser, optional): An existing parser to add argument to. Defaults to None. + + Returns: + argparse.ArgumentParser: A parser with common arguments added + """ if parser is None: parser = argparse.ArgumentParser(description="Compute annotations") diff --git a/src/algorithm/organ.py b/src/algorithm/organ.py index 6c614da..b8417e4 100644 --- a/src/algorithm/organ.py +++ b/src/algorithm/organ.py @@ -9,21 +9,57 @@ @dataclasses.dataclass class OrganLookup(t.Generic[Organ]): + """Lookup from raw organ name to algorithm specific organ data. + + Attributes: + mapping_file (Path): Path to file mapping raw organ name to data + """ + mapping_file: Path def get(self, id: str) -> Organ: + """Get the algorithm specific data for a raw organ name. + + Args: + id (str): Organ uberon id + + Raises: + ValueError: If the organ is not supported by the algorithm + + Returns: + Organ: Algorithm specific data + """ for key, organ in self.__get_options(): if key.lower() == id.lower(): return organ raise ValueError(f"Organ '{id}' is not supported") def get_builtin_options(self) -> t.Iterable[t.Tuple[str, Organ]]: + """Get builtin organ mapping options. + + Returns: + t.Iterable[t.Tuple[str, Organ]]: Entries mapping organ to data + """ return [] def from_raw(self, raw: t.Any) -> Organ: + """Convert a raw mapping value to algorithm specific data. + Can be overridden in subclasses. + + Args: + raw (t.Any): Raw value from the mapping file + + Returns: + Organ: Converted organ data + """ return raw def __get_options(self) -> t.Iterable[t.Tuple[str, Organ]]: + """Gets all options, builtin and from the mapping file. + + Yields: + Iterator[t.Tuple[str, Organ]]: Each entry from builtin and the mapping file + """ yield from self.get_builtin_options() try: for key, value in self.__load_mapping_file(): @@ -32,6 +68,11 @@ def __get_options(self) -> t.Iterable[t.Tuple[str, Organ]]: logging.warn(f"Invalid format of organ mapping file '{self.mapping_file}'") def __load_mapping_file(self) -> t.Iterable[t.Tuple[str, t.Any]]: + """Load the mapping json file. + + Yields: + Iterator[t.Tuple[str, t.Any]]: Each entry in the mapping + """ if not self.mapping_file.exists() or not self.mapping_file.is_file(): return with open(self.mapping_file) as file: diff --git a/src/algorithm/report.py b/src/algorithm/report.py index c50fa0a..72ccf9c 100644 --- a/src/algorithm/report.py +++ b/src/algorithm/report.py @@ -18,6 +18,17 @@ class Status(enum.Enum): @dataclasses.dataclass class AlgorithmReport: + """Stores information about an algorithm run. + + Attributes: + matrix (Path): Data h5ad output file path + annotations (Path): Annotations csv file path + report (Path): Report json file path + status (Status): Algorithm run status + data (anndata.AnnData, optional): Annotated data for a successful run + failure_cause (t.Any): Data associated with a failed run + """ + matrix: Path annotations: Path report: Path @@ -26,31 +37,49 @@ class AlgorithmReport: failure_cause: t.Any = None def is_success(self) -> bool: + """Gets whether the run was successful. + """ return self.status == Status.SUCCESS def set_success(self, data: anndata.AnnData): + """Set the run to success and add the result data to the report. + + Args: + data (anndata.AnnData): Result data + """ self.status = Status.SUCCESS self.data = data self.failure_cause = None return self def set_failure(self, cause: t.Any): + """Set the run to failure and add the cause to the report. + + Args: + cause (t.Any): Any value but usually the error that caused the failure + """ self.status = Status.FAILURE self.data = anndata.AnnData() self.failure_cause = cause return self def save(self): + """Saves the report and associated data to file. + """ self.save_matrix() self.save_report() def save_matrix(self): + """Saves the matrix and annotations to file. + """ matrix = self.data if matrix is not None: matrix.obs.to_csv(self.annotations) matrix.write_h5ad(self.matrix) def save_report(self): + """Saves the report status to file. + """ result = {"status": self.status.value} if not self.is_success(): self.format_cause(result) @@ -59,6 +88,11 @@ def save_report(self): json.dump(result, file, indent=4) def format_cause(self, result: dict): + """Format a failure cause and add it to a result dict. + + Args: + result (dict): Dict to add formatted cause to + """ cause = self.failure_cause if isinstance(cause, Exception): result["cause"] = repr(cause) diff --git a/steps/js/options-util.js b/steps/js/options-util.js index 25b40a8..7b5982e 100644 --- a/steps/js/options-util.js +++ b/steps/js/options-util.js @@ -1,5 +1,11 @@ var ALGORITHMS = ["azimuth", "celltypist", "popv"]; +/** + * Finds the algorithm selected in the option object + * + * @param {object} obj Options + * @returns {string} Name of algorithm or null if no match was found + */ function _find_algorithm(obj) { for (var index = 0; index < ALGORITHMS.length; ++index) { var name = ALGORITHMS[index]; @@ -11,10 +17,22 @@ function _find_algorithm(obj) { return null; } +/** + * Selects an output directory based on the provided options + * + * @param {object} obj Options + * @returns The selected output directory + */ function selectOutputDirectory(obj) { return obj["directory"] || _find_algorithm(obj) || "."; } +/** + * Creates default summarize step options based on the provided options + * + * @param {object} obj Options + * @returns Default summarize options + */ function getDefaultSummarizeOptions(obj) { return { annotationMethod: _find_algorithm(obj) || "unknown",