Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add obsm, etc to ExperimentAxisQuery #179

Merged
merged 22 commits into from
Dec 4, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 153 additions & 16 deletions python-spec/src/somacore/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,18 +236,34 @@ def obsp(self, layer: str) -> data.SparseRead:
return self._axisp_inner(_Axis.OBS, layer)

def varp(self, layer: str) -> data.SparseRead:
    """Returns a ``varp`` layer as a sparse read.

    Args:
        layer: The name of the ``varp`` layer to read.

    Lifecycle: maturing
    """
    return self._axisp_inner(_Axis.VAR, layer)

def obsm(self, layer: str) -> data.SparseRead:
    """Returns an ``obsm`` layer as a sparse read.

    Args:
        layer: The name of the ``obsm`` layer to read.

    Lifecycle: experimental
    """
    sparse_read = self._axism_inner(_Axis.OBS, layer)
    return sparse_read

def varm(self, layer: str) -> data.SparseRead:
    """Returns a ``varm`` layer as a sparse read.

    Args:
        layer: The name of the ``varm`` layer to read.

    Lifecycle: experimental
    """
    sparse_read = self._axism_inner(_Axis.VAR, layer)
    return sparse_read

def to_anndata(
self,
X_name: str,
*,
column_names: Optional[AxisColumnNames] = None,
X_layers: Sequence[str] = (),
obsm_layers: Sequence[str] = (),
obsp_layers: Sequence[str] = (),
varm_layers: Sequence[str] = (),
varp_layers: Sequence[str] = (),
) -> anndata.AnnData:
"""
Executes the query and return result as an ``AnnData`` in-memory object.
Expand All @@ -258,13 +274,25 @@ def to_anndata(
to read.
X_layers: Additional X layers to read and return
in the ``layers`` slot.
obsm_layers:
Additional obsm layers to read and return in the obsm slot.
obsp_layers:
Additional obsp layers to read and return in the obsp slot.
varm_layers:
Additional varm layers to read and return in the varm slot.
varp_layers:
Additional varp layers to read and return in the varp slot.

Lifecycle: maturing
"""
return self._read(
X_name,
column_names=column_names or AxisColumnNames(obs=None, var=None),
X_layers=X_layers,
obsm_layers=obsm_layers,
obsp_layers=obsp_layers,
varm_layers=varm_layers,
varp_layers=varp_layers,
).to_anndata()

# Context management
Expand Down Expand Up @@ -306,19 +334,32 @@ def _read(
*,
column_names: AxisColumnNames,
X_layers: Sequence[str],
obsm_layers: Sequence[str] = (),
obsp_layers: Sequence[str] = (),
varm_layers: Sequence[str] = (),
varp_layers: Sequence[str] = (),
) -> "_AxisQueryResult":
"""Reads the entire query result into in-memory Arrow tables.
"""Reads the entire query result in memory.
bkmartinjr marked this conversation as resolved.
Show resolved Hide resolved


This is a low-level routine intended to be used by loaders for other
in-core formats, such as AnnData, which can be created from the
resulting Tables.
resulting objects.

Args:
X_name: The X layer to read and return in the ``X`` slot.
column_names: The columns in the ``var`` and ``obs`` dataframes
to read.
X_layers: Additional X layers to read and return
in the ``layers`` slot.
obsm_layers:
Additional obsm layers to read and return in the obsm slot.
obsp_layers:
Additional obsp layers to read and return in the obsp slot.
varm_layers:
Additional varm layers to read and return in the varm slot.
varp_layers:
Additional varp layers to read and return in the varp slot.
"""
x_collection = self._ms.X
all_x_names = [X_name] + list(X_layers)
Expand All @@ -333,6 +374,22 @@ def _read(
raise NotImplementedError("Dense array unsupported")
all_x_arrays[_xname] = x_array

def _read_axis_mappings(fn, axis, keys: Sequence[str]) -> Dict[str, np.ndarray]:
return {key: fn(axis, key) for key in keys}

obsm_ft = self._threadpool.submit(
bkmartinjr marked this conversation as resolved.
Show resolved Hide resolved
_read_axis_mappings, self._axism_inner_ndarray, _Axis.OBS, obsm_layers
)
obsp_ft = self._threadpool.submit(
_read_axis_mappings, self._axisp_inner_ndarray, _Axis.OBS, obsp_layers
)
varm_ft = self._threadpool.submit(
_read_axis_mappings, self._axism_inner_ndarray, _Axis.VAR, varm_layers
)
varp_ft = self._threadpool.submit(
_read_axis_mappings, self._axisp_inner_ndarray, _Axis.VAR, varp_layers
)

obs_table, var_table = self._read_both_axes(column_names)

x_matrices = {
Expand All @@ -343,7 +400,23 @@ def _read(
}

x = x_matrices.pop(X_name)
return _AxisQueryResult(obs=obs_table, var=var_table, X=x, X_layers=x_matrices)

obs = obs_table.to_pandas()
obs.index = obs.index.astype(str)

var = var_table.to_pandas()
var.index = var.index.astype(str)

return _AxisQueryResult(
obs=obs,
var=var,
X=x,
obsm=obsm_ft.result(),
obsp=obsp_ft.result(),
varm=varm_ft.result(),
varp=varp_ft.result(),
X_layers=x_matrices,
)

def _read_both_axes(
self,
Expand Down Expand Up @@ -433,9 +506,64 @@ def _axisp_inner(
f" stored in {p_name} layer {layer!r}"
)

joinids = getattr(self._joinids, axis.value)
joinids = axis.getattr_from(self._joinids)
return ap_layer.read((joinids, joinids))

def _axism_inner(
    self,
    axis: "_Axis",
    layer: str,
) -> data.SparseRead:
    """Reads one ``obsm``/``varm`` layer, restricted to this query's joinids.

    Args:
        axis: The axis (obs or var) whose ``*m`` collection is read.
        layer: The name of the layer inside that collection.

    Returns:
        The result of calling ``read`` on the layer with the axis joinids
        selected on the first dimension and everything on the second.

    Raises:
        ValueError: If the measurement has no ``*m`` collection for ``axis``,
            or the collection has no entry named ``layer``.
        TypeError: If the stored layer is not a ``SparseNDArray``.
    """
    m_name = f"{axis.value}m"

    # Look up the whole "<axis>m" collection on the measurement first.
    try:
        m_collection = axis.getitem_from(self._ms, suf="m")
    except KeyError:
        raise ValueError(f"Measurement does not contain {m_name} data") from None

    # Then the requested layer inside that collection.
    try:
        stored_layer = m_collection[layer]
    except KeyError as ke:
        raise ValueError(f"layer {layer!r} is not available in {m_name}") from ke

    if isinstance(stored_layer, data.SparseNDArray):
        # Rows are limited to the query's joinids; all columns are kept.
        row_joinids = axis.getattr_from(self._joinids)
        return stored_layer.read((row_joinids, slice(None)))

    raise TypeError(f"Unexpected SOMA type stored in '{m_name}' layer")

def _convert_to_ndarray(
self, axis: "_Axis", table: pa.Table, n_row: int, n_col: int
) -> np.ndarray:
indexer: pd.Index = axis.getattr_from(self.indexer, pre="by_")
idx = indexer(table["soma_dim_0"])
z = np.zeros(n_row * n_col, dtype=np.float32)
np.put(z, idx * n_col + table["soma_dim_1"], table["soma_data"])
return z.reshape(n_row, n_col)

def _axisp_inner_ndarray(
    self,
    axis: "_Axis",
    layer: str,
) -> np.ndarray:
    """Reads an ``obsp``/``varp`` layer into a dense square array.

    Args:
        axis: The axis (obs or var) whose pairwise layer is read.
        layer: The name of the pairwise layer.

    Returns:
        A ``float32`` array of shape ``(n, n)``, where ``n`` is the number of
        joinids selected on ``axis``. Cells absent from the sparse read are
        zero.
    """
    n_row = n_col = len(axis.getattr_from(self._joinids))

    table = self._axisp_inner(axis, layer).tables().concat()

    # A pairwise layer is read with the axis joinids on BOTH dimensions, so
    # both coordinate columns hold joinids. Each has to be translated to a
    # position in the query result; using the raw ``soma_dim_1`` joinids as
    # column offsets is only correct when the joinids happen to be 0..n-1.
    to_positions = axis.getattr_from(self.indexer, pre="by_")
    row_positions = np.asarray(to_positions(table["soma_dim_0"]))
    col_positions = np.asarray(to_positions(table["soma_dim_1"]))
    values = np.asarray(table["soma_data"])

    # Fill a flat buffer using row-major offsets, then view it as 2-D.
    dense = np.zeros(n_row * n_col, dtype=np.float32)
    np.put(dense, row_positions * n_col + col_positions, values)
    return dense.reshape(n_row, n_col)

def _axism_inner_ndarray(
    self,
    axis: "_Axis",
    layer: str,
) -> np.ndarray:
    """Reads an ``obsm``/``varm`` layer into a dense array.

    Args:
        axis: The axis (obs or var) whose ``*m`` layer is read.
        layer: The name of the layer.

    Returns:
        The array produced by ``_convert_to_ndarray``: one row per joinid
        selected on ``axis`` and as many columns as the stored layer's
        second dimension.
    """
    m_collection = axis.getitem_from(self._ms, suf="m")

    # Only the column count comes from the stored (two-dimensional) shape;
    # the row count is the size of the query selection.
    stored_shape = m_collection[layer].shape
    _stored_rows, n_col = stored_shape
    selected_joinids = axis.getattr_from(self._joinids)
    n_row = len(selected_joinids)

    sparse_read = self._axism_inner(axis, layer)
    table = sparse_read.tables().concat()
    return self._convert_to_ndarray(axis, table, n_row, n_col)

@property
def _obs_df(self) -> data.DataFrame:
    """Shortcut for ``self.experiment.obs``."""
    experiment = self.experiment
    return experiment.obs
Expand Down Expand Up @@ -466,24 +594,33 @@ def _threadpool(self) -> futures.ThreadPoolExecutor:
class _AxisQueryResult:
    """The result of running :meth:`ExperimentAxisQuery.read`. Private."""

    obs: pd.DataFrame
    """Experiment.obs query slice, as a pandas DataFrame"""
    var: pd.DataFrame
    """Experiment.ms[...].var query slice, as a pandas DataFrame"""
    X: sparse.csr_matrix
    """Experiment.ms[...].X[...] query slice, as a SciPy sparse.csr_matrix"""
    X_layers: Dict[str, sparse.csr_matrix] = attrs.field(factory=dict)
    """Any additional X layers requested, as SciPy sparse.csr_matrix(s)"""
    obsm: Dict[str, np.ndarray] = attrs.field(factory=dict)
    """Any obsm layers requested, keyed by layer name, as numpy ndarrays"""
    obsp: Dict[str, np.ndarray] = attrs.field(factory=dict)
    """Any obsp layers requested, keyed by layer name, as numpy ndarrays"""
    varm: Dict[str, np.ndarray] = attrs.field(factory=dict)
    """Any varm layers requested, keyed by layer name, as numpy ndarrays"""
    varp: Dict[str, np.ndarray] = attrs.field(factory=dict)
    """Any varp layers requested, keyed by layer name, as numpy ndarrays"""

    def to_anndata(self) -> anndata.AnnData:
        """Builds an ``AnnData`` object from the fields of this result.

        ``obs`` and ``var`` are passed through as stored. Each mapping-valued
        field is passed as ``None`` when it is empty.
        """
        return anndata.AnnData(
            X=self.X,
            obs=self.obs,
            var=self.var,
            obsm=(self.obsm or None),
            obsp=(self.obsp or None),
            varm=(self.varm or None),
            varp=(self.varp or None),
            layers=(self.X_layers or None),
        )


Expand Down