Skip to content

Commit

Permalink
Improve layer selection logic
Browse files Browse the repository at this point in the history
  • Loading branch information
axdanbol committed Mar 4, 2024
1 parent af0f051 commit 0497a6b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 36 deletions.
18 changes: 2 additions & 16 deletions containers/azimuth/context/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas

from src.algorithm import Algorithm, RunResult, add_common_arguments
from src.util.layers import set_data_layer


class AzimuthOrganMetadata(t.TypedDict):
Expand Down Expand Up @@ -41,8 +42,7 @@ def do_run(
temp_index = self.create_temp_obs_index(data)
clean_matrix_path = Path("clean_matrix.h5ad")
clean_matrix = self.create_clean_matrix(data, temp_index)

self.set_data_layer(clean_matrix, options["query_layers_key"])
clean_matrix = set_data_layer(clean_matrix, options["query_layers_key"])
clean_matrix.write_h5ad(clean_matrix_path)

annotated_matrix_path = self.run_azimuth_scripts(
Expand Down Expand Up @@ -90,20 +90,6 @@ def create_clean_matrix(

return clean_matrix

def set_data_layer(
self, matrix: anndata.AnnData, query_layers_key: t.Optional[str]
) -> None:
"""Set the data layer to use for annotating.
Args:
matrix (anndata.AnnData): Matrix to update
query_layers_key (t.Optional[str]): A layer name or 'raw'
"""
if query_layers_key == "raw":
matrix.X = matrix.raw.X
elif query_layers_key is not None:
matrix.X = matrix.layers[query_layers_key].copy()

def copy_annotations(
self,
matrix: anndata.AnnData,
Expand Down
17 changes: 2 additions & 15 deletions containers/celltypist/context/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import scanpy

from src.algorithm import Algorithm, RunResult, add_common_arguments
from src.util.layers import set_data_layer


class CelltypistOrganMetadata(t.TypedDict):
Expand All @@ -32,7 +33,7 @@ def do_run(
) -> RunResult:
"""Annotate data using celltypist."""
data = scanpy.read_h5ad(matrix)
self.set_data_layer(data, options["query_layers_key"])
data = set_data_layer(data, options["query_layers_key"])
data = self.normalize(data)
data, var_names = self.normalize_var_names(data, options)
data = celltypist.annotate(
Expand All @@ -42,20 +43,6 @@ def do_run(

return {"data": data, "organ_level": metadata["model"].replace(".", "_")}

def set_data_layer(
self, matrix: scanpy.AnnData, query_layers_key: t.Optional[str]
) -> None:
"""Set the data layer to use for annotating.
Args:
matrix (anndata.AnnData): Matrix to update
query_layers_key (t.Optional[str]): A layer name or 'raw'
"""
if query_layers_key == "raw":
matrix.X = matrix.raw.X
elif query_layers_key is not None:
matrix.X = matrix.layers[query_layers_key].copy()

def normalize(self, data: scanpy.AnnData) -> scanpy.AnnData:
"""Normalizes data according to celltypist requirements.
Expand Down
8 changes: 3 additions & 5 deletions containers/popv/context/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch

from src.algorithm import Algorithm, RunResult, add_common_arguments
from src.util.layers import set_data_layer


class PopvOrganMetadata(t.TypedDict):
Expand Down Expand Up @@ -84,12 +85,9 @@ def prepare_query(
reference_data = scanpy.read_h5ad(reference_data_path)
n_samples_per_label = self.get_n_samples_per_label(reference_data, options)
data = self.normalize_var_names(data, options)
data = set_data_layer(data, options["query_layers_key"])

if options["query_layers_key"] == "raw":
options["query_layers_key"] = None
data.X = numpy.rint(data.raw.X)

if options["query_layers_key"] == "X":
if options["query_layers_key"] in ('X', 'raw'):
options["query_layers_key"] = None
data.X = numpy.rint(data.X)

Expand Down
26 changes: 26 additions & 0 deletions src/util/layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import typing as t

import anndata


def set_data_layer(matrix: anndata.AnnData, layer: t.Optional[str]) -> anndata.AnnData:
"""Sets the active data layer.
If the layer does not exist it is ignored.
Args:
matrix (anndata.AnnData): Original matrix with layers
layer (t.Optional[str]): Name of layer or 'X' or 'raw'
Returns:
anndata.AnnData: A new matrix with the layer set as the X data matrix
"""
if layer in ('X', None):
return matrix

matrix = matrix.copy()
if layer == 'raw' and matrix.raw is not None:
matrix.X = matrix.raw.X
elif layer in matrix.layers:
matrix.X = matrix.layers[layer].copy()

return matrix

0 comments on commit 0497a6b

Please sign in to comment.