Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport release-1.15] [python] Re-enable tiledbsoma.ExperimentAxisQuery #3479

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions apis/python/src/tiledbsoma/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,17 @@
except OSError:
# Otherwise try loading by name only.
ctypes.CDLL(libtiledbsoma_name)

from somacore import (
AffineTransform,
Axis,
AxisColumnNames,
AxisQuery,
CoordinateSpace,
AffineTransform,
ScaleTransform,
IdentityTransform,
ScaleTransform,
UniformScaleTransform,
AxisColumnNames,
AxisQuery,
)
from ._query import (
ExperimentAxisQuery,
)
from somacore.options import ResultOrder
Expand Down
38 changes: 19 additions & 19 deletions apis/python/src/tiledbsoma/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def obs(self) -> _T_co: ...
def var(self) -> _T_co: ...


class Axis(enum.Enum):
class AxisName(enum.Enum):
OBS = "obs"
VAR = "var"

Expand Down Expand Up @@ -376,29 +376,29 @@ def obsp(self, layer: str) -> SparseRead:
Lifecycle: maturing
"""
joinids = self._joinids.obs
return self._axisp_get_array(Axis.OBS, layer).read((joinids, joinids))
return self._axisp_get_array(AxisName.OBS, layer).read((joinids, joinids))

def varp(self, layer: str) -> SparseRead:
"""Returns a ``varp`` layer as a sparse read.

Lifecycle: maturing
"""
joinids = self._joinids.var
return self._axisp_get_array(Axis.VAR, layer).read((joinids, joinids))
return self._axisp_get_array(AxisName.VAR, layer).read((joinids, joinids))

def obsm(self, layer: str) -> SparseRead:
"""Returns an ``obsm`` layer as a sparse read.
Lifecycle: maturing
"""
return self._axism_get_array(Axis.OBS, layer).read(
return self._axism_get_array(AxisName.OBS, layer).read(
(self._joinids.obs, slice(None))
)

def varm(self, layer: str) -> SparseRead:
"""Returns a ``varm`` layer as a sparse read.
Lifecycle: maturing
"""
return self._axism_get_array(Axis.VAR, layer).read(
return self._axism_get_array(AxisName.VAR, layer).read(
(self._joinids.var, slice(None))
)

Expand All @@ -421,7 +421,7 @@ def obs_scene_ids(self) -> pa.Array:
)

full_table = obs_scene.read(
coords=((Axis.OBS.getattr_from(self._joinids), slice(None))),
coords=((AxisName.OBS.getattr_from(self._joinids), slice(None))),
result_order=ResultOrder.COLUMN_MAJOR,
value_filter="data != 0",
).concat()
Expand All @@ -448,7 +448,7 @@ def var_scene_ids(self) -> pa.Array:
)

full_table = var_scene.read(
coords=((Axis.VAR.getattr_from(self._joinids), slice(None))),
coords=((AxisName.VAR.getattr_from(self._joinids), slice(None))),
result_order=ResultOrder.COLUMN_MAJOR,
value_filter="data != 0",
).concat()
Expand Down Expand Up @@ -625,7 +625,7 @@ def _read(

obs_table, var_table = tp.map(
self._read_axis_dataframe,
(Axis.OBS, Axis.VAR),
(AxisName.OBS, AxisName.VAR),
(column_names, column_names),
)
obs_joinids = self.obs_joinids()
Expand All @@ -645,19 +645,19 @@ def _read(
x_future = x_matrices.pop(X_name)

obsm_future = {
key: tp.submit(self._axism_inner_ndarray, Axis.OBS, key)
key: tp.submit(self._axism_inner_ndarray, AxisName.OBS, key)
for key in obsm_layers
}
varm_future = {
key: tp.submit(self._axism_inner_ndarray, Axis.VAR, key)
key: tp.submit(self._axism_inner_ndarray, AxisName.VAR, key)
for key in varm_layers
}
obsp_future = {
key: tp.submit(self._axisp_inner_sparray, Axis.OBS, key)
key: tp.submit(self._axisp_inner_sparray, AxisName.OBS, key)
for key in obsp_layers
}
varp_future = {
key: tp.submit(self._axisp_inner_sparray, Axis.VAR, key)
key: tp.submit(self._axisp_inner_sparray, AxisName.VAR, key)
for key in varp_layers
}

Expand All @@ -680,7 +680,7 @@ def _read(

def _read_axis_dataframe(
self,
axis: Axis,
axis: AxisName,
axis_column_names: AxisColumnNames,
) -> pa.Table:
"""Reads the specified axis. Will cache join IDs if not present."""
Expand Down Expand Up @@ -730,7 +730,7 @@ def _read_axis_dataframe(

def _axisp_get_array(
self,
axis: Axis,
axis: AxisName,
layer: str,
) -> SparseNDArray:
p_name = f"{axis.value}p"
Expand All @@ -754,7 +754,7 @@ def _axisp_get_array(

def _axism_get_array(
self,
axis: Axis,
axis: AxisName,
layer: str,
) -> SparseNDArray:
m_name = f"{axis.value}m"
Expand All @@ -776,7 +776,7 @@ def _axism_get_array(
return axism_layer

def _convert_to_ndarray(
self, axis: Axis, table: pa.Table, n_row: int, n_col: int
self, axis: AxisName, table: pa.Table, n_row: int, n_col: int
) -> npt.NDArray[np.float32]:
indexer = cast(
Callable[[Numpyable], npt.NDArray[np.intp]],
Expand All @@ -789,7 +789,7 @@ def _convert_to_ndarray(

def _axisp_inner_sparray(
self,
axis: Axis,
axis: AxisName,
layer: str,
) -> sp.csr_matrix:
joinids = axis.getattr_from(self._joinids)
Expand All @@ -803,7 +803,7 @@ def _axisp_inner_sparray(

def _axism_inner_ndarray(
self,
axis: Axis,
axis: AxisName,
layer: str,
) -> npt.NDArray[np.float32]:
joinids = axis.getattr_from(self._joinids)
Expand Down Expand Up @@ -856,7 +856,7 @@ class JoinIDCache:
_cached_obs: pa.IntegerArray | None = None
_cached_var: pa.IntegerArray | None = None

def _is_cached(self, axis: Axis) -> bool:
def _is_cached(self, axis: AxisName) -> bool:
field = "_cached_" + axis.value
return getattr(self, field) is not None

Expand Down
25 changes: 15 additions & 10 deletions apis/python/tests/test_experiment_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@
from somacore import AxisQuery, options

import tiledbsoma as soma
from tiledbsoma import SOMATileDBContext, _factory, pytiledbsoma
from tiledbsoma import (
ExperimentAxisQuery,
SOMATileDBContext,
_factory,
pytiledbsoma,
)
from tiledbsoma._collection import CollectionBase
from tiledbsoma._experiment import Experiment
from tiledbsoma._query import Axis, ExperimentAxisQuery
from tiledbsoma._query import AxisName
from tiledbsoma.experiment_query import X_as_series

from tests._util import raises_no_typeguard
Expand Down Expand Up @@ -944,12 +949,12 @@ class IHaveObsVarStuff:

def test_axis_helpers() -> None:
thing = IHaveObsVarStuff(obs=1, var=2, the_obs_suf="observe", the_var_suf="vary")
assert 1 == Axis.OBS.getattr_from(thing)
assert 2 == Axis.VAR.getattr_from(thing)
assert "observe" == Axis.OBS.getattr_from(thing, pre="the_", suf="_suf")
assert "vary" == Axis.VAR.getattr_from(thing, pre="the_", suf="_suf")
assert 1 == AxisName.OBS.getattr_from(thing)
assert 2 == AxisName.VAR.getattr_from(thing)
assert "observe" == AxisName.OBS.getattr_from(thing, pre="the_", suf="_suf")
assert "vary" == AxisName.VAR.getattr_from(thing, pre="the_", suf="_suf")
ovdict = {"obs": "erve", "var": "y", "i_obscure": "hide", "i_varcure": "???"}
assert "erve" == Axis.OBS.getitem_from(ovdict)
assert "y" == Axis.VAR.getitem_from(ovdict)
assert "hide" == Axis.OBS.getitem_from(ovdict, pre="i_", suf="cure")
assert "???" == Axis.VAR.getitem_from(ovdict, pre="i_", suf="cure")
assert "erve" == AxisName.OBS.getitem_from(ovdict)
assert "y" == AxisName.VAR.getitem_from(ovdict)
assert "hide" == AxisName.OBS.getitem_from(ovdict, pre="i_", suf="cure")
assert "???" == AxisName.VAR.getitem_from(ovdict, pre="i_", suf="cure")