From 0497a6b444d7a0bcfe8f8d33396344cf8e96d7eb Mon Sep 17 00:00:00 2001 From: Daniel Bolin Date: Mon, 4 Mar 2024 15:44:02 -0500 Subject: [PATCH] Improve layer selection logic --- containers/azimuth/context/main.py | 18 ++---------------- containers/celltypist/context/main.py | 17 ++--------------- containers/popv/context/main.py | 8 +++----- src/util/layers.py | 26 ++++++++++++++++++++++++++ 4 files changed, 33 insertions(+), 36 deletions(-) create mode 100644 src/util/layers.py diff --git a/containers/azimuth/context/main.py b/containers/azimuth/context/main.py index d805099..df51afd 100644 --- a/containers/azimuth/context/main.py +++ b/containers/azimuth/context/main.py @@ -7,6 +7,7 @@ import pandas from src.algorithm import Algorithm, RunResult, add_common_arguments +from src.util.layers import set_data_layer class AzimuthOrganMetadata(t.TypedDict): @@ -41,8 +42,7 @@ def do_run( temp_index = self.create_temp_obs_index(data) clean_matrix_path = Path("clean_matrix.h5ad") clean_matrix = self.create_clean_matrix(data, temp_index) - - self.set_data_layer(clean_matrix, options["query_layers_key"]) + clean_matrix = set_data_layer(clean_matrix, options["query_layers_key"]) clean_matrix.write_h5ad(clean_matrix_path) annotated_matrix_path = self.run_azimuth_scripts( @@ -90,20 +90,6 @@ def create_clean_matrix( return clean_matrix - def set_data_layer( - self, matrix: anndata.AnnData, query_layers_key: t.Optional[str] - ) -> None: - """Set the data layer to use for annotating. - - Args: - matrix (anndata.AnnData): Matrix to update - query_layers_key (t.Optional[str]): A layer name or 'raw' - """ - if query_layers_key == "raw": - matrix.X = matrix.raw.X - elif query_layers_key is not None: - matrix.X = matrix.layers[query_layers_key].copy() - def copy_annotations( self, matrix: anndata.AnnData, diff --git a/containers/celltypist/context/main.py b/containers/celltypist/context/main.py index 3a23387..5084374 100644 --- a/containers/celltypist/context/main.py +++ b/containers/celltypist/context/main.py @@ -8,6 +8,7 @@ import scanpy from src.algorithm import Algorithm, RunResult, add_common_arguments +from src.util.layers import set_data_layer class CelltypistOrganMetadata(t.TypedDict): @@ -32,7 +33,7 @@ def do_run( ) -> RunResult: """Annotate data using celltypist.""" data = scanpy.read_h5ad(matrix) - self.set_data_layer(data, options["query_layers_key"]) + data = set_data_layer(data, options["query_layers_key"]) data = self.normalize(data) data, var_names = self.normalize_var_names(data, options) data = celltypist.annotate( @@ -42,20 +43,6 @@ def do_run( return {"data": data, "organ_level": metadata["model"].replace(".", "_")} - def set_data_layer( - self, matrix: scanpy.AnnData, query_layers_key: t.Optional[str] - ) -> None: - """Set the data layer to use for annotating. - - Args: - matrix (anndata.AnnData): Matrix to update - query_layers_key (t.Optional[str]): A layer name or 'raw' - """ - if query_layers_key == "raw": - matrix.X = matrix.raw.X - elif query_layers_key is not None: - matrix.X = matrix.layers[query_layers_key].copy() - def normalize(self, data: scanpy.AnnData) -> scanpy.AnnData: """Normalizes data according to celltypist requirements. diff --git a/containers/popv/context/main.py b/containers/popv/context/main.py index 12edabb..1e88189 100644 --- a/containers/popv/context/main.py +++ b/containers/popv/context/main.py @@ -10,6 +10,7 @@ import torch from src.algorithm import Algorithm, RunResult, add_common_arguments +from src.util.layers import set_data_layer class PopvOrganMetadata(t.TypedDict): @@ -84,12 +85,9 @@ def prepare_query( reference_data = scanpy.read_h5ad(reference_data_path) n_samples_per_label = self.get_n_samples_per_label(reference_data, options) data = self.normalize_var_names(data, options) + data = set_data_layer(data, options["query_layers_key"]) - if options["query_layers_key"] == "raw": - options["query_layers_key"] = None - data.X = numpy.rint(data.raw.X) - - if options["query_layers_key"] == "X": + if options["query_layers_key"] in ('X', 'raw'): options["query_layers_key"] = None data.X = numpy.rint(data.X) diff --git a/src/util/layers.py b/src/util/layers.py new file mode 100644 index 0000000..03a6b54 --- /dev/null +++ b/src/util/layers.py @@ -0,0 +1,26 @@ +import typing as t + +import anndata + + +def set_data_layer(matrix: anndata.AnnData, layer: t.Optional[str]) -> anndata.AnnData: + """Sets the active data layer. + If the layer does not exist it is ignored. + + Args: + matrix (anndata.AnnData): Original matrix with layers + layer (t.Optional[str]): Name of layer or 'X' or 'raw' + + Returns: + anndata.AnnData: A new matrix with the layer set as the X data matrix + """ + if layer in ('X', None): + return matrix + + matrix = matrix.copy() + if layer == 'raw' and matrix.raw is not None: + matrix.X = matrix.raw.X + elif layer in matrix.layers: + matrix.X = matrix.layers[layer].copy() + + return matrix