Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add obsm, etc to ExperimentAxisQuery #179

Merged
merged 22 commits into from
Dec 4, 2023
Merged
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 163 additions & 20 deletions python-spec/src/somacore/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,26 +227,42 @@ def X(
platform_config=platform_config,
)

def obsp(self, layer: str) -> data.SparseRead:
"""Returns an ``obsp`` layer as a sparse read.
def obsp(self, key: str) -> data.SparseRead:
"""Returns an ``obsp`` key as a sparse read.

Lifecycle: maturing
"""
return self._axisp_inner(_Axis.OBS, layer)
return self._axisp_inner(_Axis.OBS, key)

def varp(self, layer: str) -> data.SparseRead:
"""Returns an ``varp`` layer as a sparse read.
def varp(self, key: str) -> data.SparseRead:
"""Returns a ``varp`` key as a sparse read.

Lifecycle: maturing
"""
return self._axisp_inner(_Axis.VAR, layer)
return self._axisp_inner(_Axis.VAR, key)

def obsm(self, key: str) -> data.SparseRead:
"""Returns an ``obsm`` key as a sparse read.
Lifecycle: experimental
"""
return self._axism_inner(_Axis.OBS, key)

def varm(self, key: str) -> data.SparseRead:
"""Returns a ``varm`` key as a sparse read.
Lifecycle: experimental
"""
return self._axism_inner(_Axis.VAR, key)

def to_anndata(
self,
X_name: str,
*,
column_names: Optional[AxisColumnNames] = None,
X_layers: Sequence[str] = (),
obsm_keys: Sequence[str] = (),
obsp_keys: Sequence[str] = (),
varm_keys: Sequence[str] = (),
varp_keys: Sequence[str] = (),
) -> anndata.AnnData:
"""
Executes the query and return result as an ``AnnData`` in-memory object.
Expand All @@ -257,13 +273,25 @@ def to_anndata(
to read.
X_layers: Additional X layers to read and return
in the ``layers`` slot.
obsm_keys:
Additional obsm keys to read and return in the obsm slot.
obsp_keys:
Additional obsp keys to read and return in the obsp slot.
varm_keys:
Additional varm keys to read and return in the varm slot.
varp_keys:
Additional varp keys to read and return in the varp slot.

Lifecycle: maturing
"""
return self._read(
X_name,
column_names=column_names or AxisColumnNames(obs=None, var=None),
X_layers=X_layers,
obsm_keys=obsm_keys,
obsp_keys=obsp_keys,
varm_keys=varm_keys,
varp_keys=varp_keys,
).to_anndata()

# Context management
Expand Down Expand Up @@ -305,19 +333,32 @@ def _read(
*,
column_names: AxisColumnNames,
X_layers: Sequence[str],
obsm_keys: Sequence[str] = (),
obsp_keys: Sequence[str] = (),
varm_keys: Sequence[str] = (),
varp_keys: Sequence[str] = (),
) -> "_AxisQueryResult":
"""Reads the entire query result into in-memory Arrow tables.
"""Reads the entire query result in memory.
bkmartinjr marked this conversation as resolved.
Show resolved Hide resolved


This is a low-level routine intended to be used by loaders for other
in-core formats, such as AnnData, which can be created from the
resulting Tables.
resulting objects.

Args:
X_name: The X layer to read and return in the ``X`` slot.
column_names: The columns in the ``var`` and ``obs`` dataframes
to read.
X_layers: Additional X layers to read and return
in the ``layers`` slot.
obsm_keys:
Additional obsm keys to read and return in the obsm slot.
obsp_keys:
Additional obsp keys to read and return in the obsp slot.
varm_keys:
Additional varm keys to read and return in the varm slot.
varp_keys:
Additional varp keys to read and return in the varp slot.
"""
x_collection = self._ms.X
all_x_names = [X_name] + list(X_layers)
Expand All @@ -341,8 +382,45 @@ def _read(
for _xname in all_x_arrays
}

def _read_axis_mappings(fn, axis, keys: Sequence[str]) -> Dict[str, np.ndarray]:
return {key: fn(axis, key) for key in keys}

obsm_ft = self._threadpool.submit(
bkmartinjr marked this conversation as resolved.
Show resolved Hide resolved
_read_axis_mappings, self._axism_inner_ndarray, _Axis.OBS, obsm_keys
)
obsp_ft = self._threadpool.submit(
_read_axis_mappings, self._axisp_inner_ndarray, _Axis.OBS, obsp_keys
Copy link
Member

@pablo-gar pablo-gar Dec 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bkmartinjr and @ebezzi I believe you had a conversation about this offline (or maybe here but I don't see it). We know that the most common uses case of axisp arrays are numerically sparse, and even though anndata's schema says they should be numpy dense arrays, scanpy's methods fill it in with scipy sparse matrices.

Ultimately we would like to move to a world where either AnnData schema takes both (dense and sparse) or only sparse. Given that the ecosystem already violates the schema in favor of better numerical representation I lean towards tiledbsoma also violating the schema, and then we request AnnData's schema to be relaxed.

What are your thoughts?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle I have no issues. You might want to ping Isaac V. and see what he thinks

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks I wanted to make sure I didn't miss anything important. I don't see it as a blocker for now, I will file a few issues (here and in AnnData) to move towards the support of sparse matrices in the axisp arrays.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also look at this page, where they claim that obsm etc can be sparse. This is relative to the on-disk format, but I don't believe there is anything that converts them when loading in memory.

)
varm_ft = self._threadpool.submit(
_read_axis_mappings, self._axism_inner_ndarray, _Axis.VAR, varm_keys
)
varp_ft = self._threadpool.submit(
_read_axis_mappings, self._axisp_inner_ndarray, _Axis.VAR, varp_keys
)

obsm = obsm_ft.result()
obsp = obsp_ft.result()
varm = varm_ft.result()
varp = varp_ft.result()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like these could be passed as the named arguments to _AxisQueryResult without the need for a temporary variable.


x = x_matrices.pop(X_name)
return _AxisQueryResult(obs=obs_table, var=var_table, X=x, X_layers=x_matrices)

obs = obs_table.to_pandas()
obs.index = obs.index.astype(str)

var = var_table.to_pandas()
var.index = var.index.astype(str)

return _AxisQueryResult(
obs=obs,
var=var,
X=x,
obsm=obsm,
obsp=obsp,
varm=varm,
varp=varp,
X_layers=x_matrices,
)

def _read_both_axes(
self,
Expand Down Expand Up @@ -434,6 +512,62 @@ def _axisp_inner(
joinids = getattr(self._joinids, axis.value)
return axisp[layer].read((joinids, joinids))

def _axism_inner(
ebezzi marked this conversation as resolved.
Show resolved Hide resolved
self,
axis: "_Axis",
layer: str,
) -> data.SparseRead:
key = axis.value + "m"

if key not in self._ms:
raise ValueError(f"Measurement does not contain {key} data")

axism = self._ms.obsm if axis is _Axis.OBS else self._ms.varm
if not (layer and layer in axism):
raise ValueError(f"Must specify '{key}' layer")

ebezzi marked this conversation as resolved.
Show resolved Hide resolved
bkmartinjr marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(axism[layer], data.SparseNDArray):
raise TypeError(f"Unexpected SOMA type stored in '{key}' layer")

joinids = getattr(self._joinids, axis.value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

axis.getitem_from(self._joinids)

return axism[layer].read((joinids, slice(None)))

def _convert_to_ndarray(
self, axis: "_Axis", table: pa.Table, n_row: int, n_col: int
) -> np.ndarray:
idx = (self.indexer.by_obs if (axis is _Axis.OBS) else self.indexer.by_var)(
table["soma_dim_0"]
)
Z = np.zeros(n_row * n_col, dtype=np.float32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: recommend making z lowercase

np.put(Z, idx * n_col + table["soma_dim_1"], table["soma_data"])
return Z.reshape(n_row, n_col)

def _axisp_inner_ndarray(
bkmartinjr marked this conversation as resolved.
Show resolved Hide resolved
self,
axis: "_Axis",
layer: str,
) -> np.ndarray:
is_obs = axis is _Axis.OBS
n_row = n_col = len(self._joinids.obs) if is_obs else len(self._joinids.var)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not something for you to do; just thinking aloud here:

it might make sense to add a function to _Axis that gets the given attribute, so that this could look something like:

n_row = n_col = axis.getattr_from(self._joinids)

so it does self._joinids.obs and self._joinids.var by itself without you having to specify (and thus avoids potential obs/var switcheroos)

(maybe also have it take a suffix, so you could say axis.getattr_from(something, "m") to get something.obsm/something.varm)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

decided to see what this would look like here #183

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be great - I approved #183. If you want to merge it, I can accommodate the changes here, otherwise we can do it later.


table = self._axisp_inner(axis, layer).tables().concat()
return self._convert_to_ndarray(axis, table, n_row, n_col)

def _axism_inner_ndarray(
self,
axis: "_Axis",
layer: str,
) -> np.ndarray:
is_obs = axis is _Axis.OBS

axism = self._ms.obsm if is_obs else self._ms.varm

_, n_col = axism[layer].shape
n_row = len(self._joinids.obs) if is_obs else len(self._joinids.var)

table = self._axism_inner(axis, layer).tables().concat()
return self._convert_to_ndarray(axis, table, n_row, n_col)

@property
def _obs_df(self) -> data.DataFrame:
return self.experiment.obs
Expand Down Expand Up @@ -464,24 +598,33 @@ def _threadpool(self) -> futures.ThreadPoolExecutor:
class _AxisQueryResult:
"""The result of running :meth:`ExperimentAxisQuery.read`. Private."""

obs: pa.Table
"""Experiment.obs query slice, as an Arrow Table"""
var: pa.Table
"""Experiment.ms[...].var query slice, as an Arrow Table"""
obs: pd.DataFrame
"""Experiment.obs query slice, as a pandas DataFrame"""
var: pd.DataFrame
"""Experiment.ms[...].var query slice, as a pandas DataFrame"""
X: sparse.csr_matrix
"""Experiment.ms[...].X[...] query slice, as an SciPy sparse.csr_matrix """
X_layers: Dict[str, sparse.csr_matrix] = attrs.field(factory=dict)
"""Any additional X layers requested, as SciPy sparse.csr_matrix(s)"""
obsm: Dict[str, np.ndarray] = attrs.field(factory=dict)
"""Experiment.obsm query slice, as a numpy ndarray"""
obsp: Dict[str, np.ndarray] = attrs.field(factory=dict)
"""Experiment.obsp query slice, as a numpy ndarray"""
varm: Dict[str, np.ndarray] = attrs.field(factory=dict)
"""Experiment.varm query slice, as a numpy ndarray"""
varp: Dict[str, np.ndarray] = attrs.field(factory=dict)
"""Experiment.varp query slice, as a numpy ndarray"""

def to_anndata(self) -> anndata.AnnData:
obs = self.obs.to_pandas()
obs.index = obs.index.astype(str)

var = self.var.to_pandas()
var.index = var.index.astype(str)

return anndata.AnnData(
X=self.X, obs=obs, var=var, layers=(self.X_layers or None)
X=self.X,
obs=self.obs,
var=self.var,
obsm=(self.obsm or None),
obsp=(self.obsp or None),
varm=(self.varm or None),
varp=(self.varp or None),
layers=(self.X_layers or None),
)


Expand Down