ExperimentAxisQuery - to_anndata obsp/varp as sparse matrices (#3387) (#3413)

* small performance improvement

* obsp/varp as sparse matrices

Co-authored-by: Bruce Martin <[email protected]>
github-actions[bot] and bkmartinjr authored Dec 9, 2024
1 parent f48de86 commit 7df7343
Showing 3 changed files with 94 additions and 75 deletions.
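
In user-visible terms, `ExperimentAxisQuery.to_anndata` now fills the AnnData `obsp`/`varp` slots with `scipy.sparse.csr_matrix` objects rather than dense numpy arrays. A minimal sketch of the new behavior (not part of the diff; the experiment URI, measurement, and layer names are hypothetical placeholders):

    import tiledbsoma
    from scipy import sparse

    # "my-experiment", "RNA", "data", and "connectivities" are hypothetical names.
    with tiledbsoma.Experiment.open("my-experiment") as exp:
        with exp.axis_query("RNA") as query:
            ad = query.to_anndata("data", obsp_layers=["connectivities"])
            assert sparse.isspmatrix_csr(ad.obsp["connectivities"])  # was np.ndarray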
apis/python/src/tiledbsoma/_query.py (152 changes: 83 additions & 69 deletions)
@@ -6,7 +6,7 @@
 """Implementation of a SOMA Experiment.
 """
 import enum
-from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -164,12 +164,12 @@ class AxisQueryResult:
     """Any additional X layers requested, as SciPy sparse.csr_matrix(s)"""
     obsm: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict)
     """Experiment.obsm query slice, as a numpy ndarray"""
-    obsp: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict)
-    """Experiment.obsp query slice, as a numpy ndarray"""
+    obsp: Dict[str, sp.csr_matrix] = attrs.field(factory=dict)
+    """Experiment.obsp query slice, as SciPy sparse.csr_matrix(s)"""
     varm: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict)
     """Experiment.varm query slice, as a numpy ndarray"""
-    varp: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict)
-    """Experiment.varp query slice, as a numpy ndarray"""
+    varp: Dict[str, sp.csr_matrix] = attrs.field(factory=dict)
+    """Experiment.varp query slice, as SciPy sparse.csr_matrix(s)"""

     def to_anndata(self) -> AnnData:
         return AnnData(
@@ -371,26 +371,32 @@ def obsp(self, layer: str) -> SparseRead:
         Lifecycle: maturing
         """
-        return self._axisp_inner(Axis.OBS, layer)
+        joinids = self._joinids.obs
+        return self._axisp_get_array(Axis.OBS, layer).read((joinids, joinids))

     def varp(self, layer: str) -> SparseRead:
         """Returns a ``varp`` layer as a sparse read.
         Lifecycle: maturing
         """
-        return self._axisp_inner(Axis.VAR, layer)
+        joinids = self._joinids.var
+        return self._axisp_get_array(Axis.VAR, layer).read((joinids, joinids))

     def obsm(self, layer: str) -> SparseRead:
         """Returns an ``obsm`` layer as a sparse read.
         Lifecycle: maturing
         """
-        return self._axism_inner(Axis.OBS, layer)
+        return self._axism_get_array(Axis.OBS, layer).read(
+            (self._joinids.obs, slice(None))
+        )

     def varm(self, layer: str) -> SparseRead:
         """Returns a ``varm`` layer as a sparse read.
         Lifecycle: maturing
         """
-        return self._axism_inner(Axis.VAR, layer)
+        return self._axism_get_array(Axis.VAR, layer).read(
+            (self._joinids.var, slice(None))
+        )

     def obs_scene_ids(self) -> pa.Array:
         """Returns a pyarrow array with scene ids that contain obs from this
@@ -509,6 +515,7 @@ def _read(
         varp_layers:
             Additional varp layers to read and return in the varp slot.
         """
+        tp = self._threadpool
         x_collection = self._ms.X
         all_x_names = [X_name] + list(X_layers)
         all_x_arrays: Dict[str, SparseNDArray] = {}
@@ -522,7 +529,7 @@
                 raise NotImplementedError("Dense array unsupported")
             all_x_arrays[_xname] = x_array

-        obs_table, var_table = self._threadpool.map(
+        obs_table, var_table = tp.map(
             self._read_axis_dataframe,
             (Axis.OBS, Axis.VAR),
             (column_names, column_names),
@@ -531,32 +538,34 @@
         var_joinids = self.var_joinids()

         x_matrices = {
-            _xname: self._threadpool.submit(
-                _read_as_csr, layer, obs_joinids, var_joinids, self._indexer
+            _xname: tp.submit(
+                _read_as_csr,
+                layer,
+                obs_joinids,
+                var_joinids,
+                self._indexer.by_obs,
+                self._indexer.by_var,
             )
             for _xname, layer in all_x_arrays.items()
         }
         x_future = x_matrices.pop(X_name)

-        def _read_axis_mappings(
-            fn: Callable[[Axis, str], npt.NDArray[Any]],
-            axis: Axis,
-            keys: Sequence[str],
-        ) -> Dict[str, Future[npt.NDArray[Any]]]:
-            return {key: self._threadpool.submit(fn, axis, key) for key in keys}
-
-        obsm_future = _read_axis_mappings(
-            self._axism_inner_ndarray, Axis.OBS, obsm_layers
-        )
-        obsp_future = _read_axis_mappings(
-            self._axisp_inner_ndarray, Axis.OBS, obsp_layers
-        )
-        varm_future = _read_axis_mappings(
-            self._axism_inner_ndarray, Axis.VAR, varm_layers
-        )
-        varp_future = _read_axis_mappings(
-            self._axisp_inner_ndarray, Axis.VAR, varp_layers
-        )
+        obsm_future = {
+            key: tp.submit(self._axism_inner_ndarray, Axis.OBS, key)
+            for key in obsm_layers
+        }
+        varm_future = {
+            key: tp.submit(self._axism_inner_ndarray, Axis.VAR, key)
+            for key in varm_layers
+        }
+        obsp_future = {
+            key: tp.submit(self._axisp_inner_sparray, Axis.OBS, key)
+            for key in obsp_layers
+        }
+        varp_future = {
+            key: tp.submit(self._axisp_inner_sparray, Axis.VAR, key)
+            for key in varp_layers
+        }

         obs = obs_table.to_pandas()
         obs.index = obs.index.astype(str)
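
The dict comprehensions above replace the removed `_read_axis_mappings` helper: one task per layer is submitted to the shared thread pool, and the futures are collected by key to be resolved later. A standalone sketch of the pattern (the loader and keys are stand-ins):

    from concurrent.futures import ThreadPoolExecutor

    def load_layer(axis: str, key: str) -> str:
        return f"{axis}/{key}"  # stand-in for an actual array read

    with ThreadPoolExecutor() as tp:
        futures = {key: tp.submit(load_layer, "obs", key) for key in ("foo", "bar")}
        results = {key: f.result() for key, f in futures.items()}  # block per future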
@@ -625,11 +634,11 @@ def _read_axis_dataframe(
             arrow_table = arrow_table.drop(["soma_joinid"])
         return arrow_table

-    def _axisp_inner(
+    def _axisp_get_array(
         self,
         axis: Axis,
         layer: str,
-    ) -> SparseRead:
+    ) -> SparseNDArray:
         p_name = f"{axis.value}p"
         try:
             ms = self._ms
@@ -638,23 +647,22 @@
             raise ValueError(f"Measurement does not contain {p_name} data")

         try:
-            ap_layer = axisp[layer]
+            axisp_layer = axisp[layer]
         except KeyError:
             raise ValueError(f"layer {layer!r} is not available in {p_name}")
-        if not isinstance(ap_layer, SparseNDArray):
+        if not isinstance(axisp_layer, SparseNDArray):
             raise TypeError(
-                f"Unexpected SOMA type {type(ap_layer).__name__}"
+                f"Unexpected SOMA type {type(axisp_layer).__name__}"
                 f" stored in {p_name} layer {layer!r}"
             )

-        joinids = axis.getattr_from(self._joinids)
-        return ap_layer.read((joinids, joinids))
+        return axisp_layer

-    def _axism_inner(
+    def _axism_get_array(
         self,
         axis: Axis,
         layer: str,
-    ) -> SparseRead:
+    ) -> SparseNDArray:
         m_name = f"{axis.value}m"

         try:
@@ -671,8 +679,7 @@
         if not isinstance(axism_layer, SparseNDArray):
             raise TypeError(f"Unexpected SOMA type stored in '{m_name}' layer")

-        joinids = axis.getattr_from(self._joinids)
-        return axism_layer.read((joinids, slice(None)))
+        return axism_layer

     def _convert_to_ndarray(
         self, axis: Axis, table: pa.Table, n_row: int, n_col: int
@@ -686,24 +693,34 @@
         np.put(z, idx * n_col + table["soma_dim_1"], table["soma_data"])
         return z.reshape(n_row, n_col)

-    def _axisp_inner_ndarray(
+    def _axisp_inner_sparray(
         self,
         axis: Axis,
         layer: str,
-    ) -> npt.NDArray[np.float32]:
-        n_row = n_col = len(axis.getattr_from(self._joinids))
-
-        table = self._axisp_inner(axis, layer).tables().concat()
-        return self._convert_to_ndarray(axis, table, n_row, n_col)
+    ) -> sp.csr_matrix:
+        joinids = axis.getattr_from(self._joinids)
+        indexer = cast(
+            Callable[[Numpyable], npt.NDArray[np.intp]],
+            axis.getattr_from(self.indexer, pre="by_"),
+        )
+        return _read_as_csr(
+            self._axisp_get_array(axis, layer), joinids, joinids, indexer, indexer
+        )

     def _axism_inner_ndarray(
         self,
         axis: Axis,
         layer: str,
     ) -> npt.NDArray[np.float32]:
-        table = self._axism_inner(axis, layer).tables().concat()
+        joinids = axis.getattr_from(self._joinids)
+        table = (
+            self._axism_get_array(axis, layer)
+            .read((joinids, slice(None)))
+            .tables()
+            .concat()
+        )

-        n_row = len(axis.getattr_from(self._joinids))
+        n_row = len(joinids)
         n_col = len(table["soma_dim_1"].unique())

         return self._convert_to_ndarray(axis, table, n_row, n_col)
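
The `axis.getattr_from(self.indexer, pre="by_")` call resolves `by_obs` or `by_var` on the query's indexer from the axis value; the surrounding `cast` only narrows the static type. A simplified sketch of that prefix dispatch (the real `Axis` enum carries more machinery):

    import enum

    class Axis(enum.Enum):
        OBS = "obs"
        VAR = "var"

        def getattr_from(self, obj: object, pre: str = "") -> object:
            # Axis.OBS.getattr_from(indexer, pre="by_") resolves indexer.by_obs
            return getattr(obj, f"{pre}{self.value}")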
@@ -796,19 +813,20 @@ def load_joinids(df: DataFrame, axq: AxisQuery) -> pa.IntegerArray:

 def _read_as_csr(
     matrix: SparseNDArray,
-    obs_joinids_arr: pa.IntegerArray,
-    var_joinids_arr: pa.IntegerArray,
-    indexer: AxisIndexer,
+    d0_joinids_arr: pa.IntegerArray,
+    d1_joinids_arr: pa.IntegerArray,
+    d0_indexer: Callable[[Numpyable], npt.NDArray[np.intp]],
+    d1_indexer: Callable[[Numpyable], npt.NDArray[np.intp]],
 ) -> sp.csr_matrix:

-    obs_joinids = obs_joinids_arr.to_numpy()
-    var_joinids = var_joinids_arr.to_numpy()
+    d0_joinids = d0_joinids_arr.to_numpy()
+    d1_joinids = d1_joinids_arr.to_numpy()
     nnz = matrix.nnz

     # if able, downcast from int64 - reduces working memory
     index_dtype = (
         np.int32
-        if max(len(obs_joinids), len(var_joinids)) < np.iinfo(np.int32).max
+        if max(len(d0_joinids), len(d1_joinids)) < np.iinfo(np.int32).max
         else np.int64
     )
     pa_schema = pa.schema(
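
The "reduces working memory" comment is easy to quantify: every non-zero carries two index entries, so downcasting both columns from int64 to int32 halves index storage. Back-of-envelope, with an illustrative non-zero count:

    nnz = 100_000_000                      # illustrative
    saved_gib = 2 * nnz * (8 - 4) / 2**30  # two index columns, 8 -> 4 bytes each
    print(f"{saved_gib:.2f} GiB")          # ~0.75 GiB saved per 100M non-zeros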
@@ -825,12 +843,8 @@ def _read_and_reindex(
        def _reindex(batch: pa.RecordBatch) -> pa.RecordBatch:
            return pa.RecordBatch.from_pydict(
                {
-                    "soma_dim_0": indexer.by_obs(batch["soma_dim_0"]).astype(
-                        index_dtype
-                    ),
-                    "soma_dim_1": indexer.by_var(batch["soma_dim_1"]).astype(
-                        index_dtype
-                    ),
+                    "soma_dim_0": d0_indexer(batch["soma_dim_0"]).astype(index_dtype),
+                    "soma_dim_1": d1_indexer(batch["soma_dim_1"]).astype(index_dtype),
                    "soma_data": batch["soma_data"],
                },
                schema=pa_schema,
@@ -858,25 +872,25 @@ def _reindex(batch: pa.RecordBatch) -> pa.RecordBatch:
     splits = list(
         range(
             partition_size,
-            len(obs_joinids) - partition_size + 1,
+            len(d0_joinids) - partition_size + 1,
             partition_size,
         )
     )
     if len(splits) > 0:
-        obs_joinids_splits = np.array_split(np.partition(obs_joinids, splits), splits)
+        d0_joinids_splits = np.array_split(np.partition(d0_joinids, splits), splits)
         tp = matrix.context.threadpool
         tbl = pa.concat_tables(
             tp.map(
                 _read_and_reindex,
-                (matrix,) * len(obs_joinids_splits),
-                obs_joinids_splits,
-                (var_joinids,) * len(obs_joinids_splits),
+                (matrix,) * len(d0_joinids_splits),
+                d0_joinids_splits,
+                (d1_joinids,) * len(d0_joinids_splits),
             )
         )

     else:
-        tbl = _read_and_reindex(matrix, obs_joinids, var_joinids)
+        tbl = _read_and_reindex(matrix, d0_joinids, d1_joinids)

     return CompressedMatrix.from_soma(
-        tbl, (len(obs_joinids), len(var_joinids)), "csr", True, matrix.context
+        tbl, (len(d0_joinids), len(d1_joinids)), "csr", True, matrix.context
     ).to_scipy()
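
The chunking above uses `np.partition` with multiple pivots: the smallest `partition_size` join ids land in the first chunk, the next-smallest in the second, and so on, without paying for a full sort. A small illustration:

    import numpy as np

    joinids = np.array([9, 2, 7, 1, 8, 3], dtype=np.int64)
    splits = [2, 4]  # one boundary every partition_size elements
    chunks = np.array_split(np.partition(joinids, splits), splits)
    # chunks cover {1, 2}, {3, 7}, {8, 9}; order within a chunk is unspecified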
apis/python/src/tiledbsoma/io/outgest.py (5 changes: 3 additions & 2 deletions)
@@ -108,9 +108,10 @@ def _read_partitioned_sparse(X: SparseNDArray, d0_size: int) -> pa.Table:
     # density of matrix. Magic number determined empirically, as a tradeoff
     # between concurrency and fixed query overhead.
     tgt_point_count = 96 * 1024**2
+    nnz = X.nnz
     partition_sz = (
-        max(1024 * round(d0_size * tgt_point_count / X.nnz / 1024), 1024)
-        if X.nnz > 0
+        max(1024 * round(d0_size * tgt_point_count / nnz / 1024), 1024)
+        if nnz > 0
         else d0_size
     )
     partitions = [
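
Hoisting `X.nnz` into a local avoids fetching the non-zero count twice. The formula itself targets roughly `tgt_point_count` non-zeros per partition, rounded to a multiple of 1024 rows; worked through with illustrative numbers:

    d0_size = 1_000_000              # rows in the array (illustrative)
    nnz = 2_000_000_000              # non-zeros (illustrative)
    tgt_point_count = 96 * 1024**2
    partition_sz = max(1024 * round(d0_size * tgt_point_count / nnz / 1024), 1024)
    print(partition_sz)  # 50176 rows -> ~100M points per partition at this density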
apis/python/tests/test_experiment_query.py (12 changes: 8 additions & 4 deletions)
@@ -637,19 +637,23 @@ def test_experiment_query_to_anndata_obsp_varp(soma_experiment):
         ad = query.to_anndata("raw", obsp_layers=["foo"], varp_layers=["bar"])
         assert set(ad.obsp.keys()) == {"foo"}
         obsp = ad.obsp["foo"]
-        assert isinstance(obsp, np.ndarray)
+        assert isinstance(obsp, sparse.spmatrix)
+        assert sparse.isspmatrix_csr(obsp)
         assert obsp.shape == (query.n_obs, query.n_obs)

+        assert (query.obsp("foo").coos().concat().to_scipy() != obsp).nnz == 0
         assert np.array_equal(
-            query.obsp("foo").coos().concat().to_scipy().todense(), obsp
+            query.obsp("foo").coos().concat().to_scipy().todense(), obsp.todense()
         )

         assert set(ad.varp.keys()) == {"bar"}
         varp = ad.varp["bar"]
-        assert isinstance(varp, np.ndarray)
+        assert isinstance(varp, sparse.spmatrix)
+        assert sparse.isspmatrix_csr(varp)
         assert varp.shape == (query.n_vars, query.n_vars)
+        assert (query.varp("bar").coos().concat().to_scipy() != varp).nnz == 0
         assert np.array_equal(
-            query.varp("bar").coos().concat().to_scipy().todense(), varp
+            query.varp("bar").coos().concat().to_scipy().todense(), varp.todense()
         )
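
The `(a != b).nnz == 0` assertion is the usual equality check for SciPy sparse matrices, which `np.array_equal` only handles after densifying. A self-contained illustration:

    import numpy as np
    from scipy import sparse

    a = sparse.csr_matrix(np.eye(3))
    b = a.copy()
    assert (a != b).nnz == 0  # elementwise != has no non-zeros, so a equals b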


