Skip to content

Commit

Permalink
[python/c++] Use C++ bindings for SparseNDArray write path
Browse files Browse the repository at this point in the history
* Also modify SOMA typechecks to be case insensitive
  • Loading branch information
nguyenv committed Apr 22, 2024
1 parent c52724a commit 1223191
Show file tree
Hide file tree
Showing 25 changed files with 1,113 additions and 680 deletions.
156 changes: 132 additions & 24 deletions apis/python/src/tiledbsoma/_sparse_nd_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
from somacore.options import PlatformConfig
from typing_extensions import Self

import tiledb

from . import _util

# This package's pybind11 code
Expand All @@ -39,8 +37,12 @@
SparseCOOTensorReadIter,
TableReadIter,
)
from ._tdb_handles import ArrayWrapper, SparseNDArrayWrapper
from ._types import NTuple
from ._tdb_handles import SparseNDArrayWrapper
from ._types import NTuple, OpenTimestamp
from .options._soma_tiledb_context import (
SOMATileDBContext,
_validate_soma_tiledb_context,
)
from .options._tiledb_create_options import TileDBCreateOptions

_UNBATCHED = options.BatchSize()
Expand Down Expand Up @@ -96,13 +98,106 @@ class SparseNDArray(NDArray, somacore.SparseNDArray):

__slots__ = ()

_wrapper_type = ArrayWrapper
_wrapper_type = SparseNDArrayWrapper
_reader_wrapper_type = SparseNDArrayWrapper

# Inherited from somacore
# * ndim accessor
# * is_sparse: Final = True

@classmethod
def create(
    cls,
    uri: str,
    *,
    type: pa.DataType,
    shape: Sequence[Union[int, None]],
    platform_config: Optional[options.PlatformConfig] = None,
    context: Optional[SOMATileDBContext] = None,
    tiledb_timestamp: Optional[OpenTimestamp] = None,
) -> Self:
    """Creates a new ``SparseNDArray`` at ``uri`` via the C++ bindings and
    returns it opened for writing.

    Args:
        uri: Location at which to create the array.
        type: Arrow type of the ``soma_data`` attribute; must be one of the
            primitive types in the mapping below.
        shape: Capacity of each dimension; a ``None`` entry requests the
            maximum capacity for that dimension.
        platform_config: Optional creation options (TileDB filters, tile/cell
            order, etc.).
        context: Optional ``SOMATileDBContext``; a default is supplied when
            omitted.
        tiledb_timestamp: Optional timestamp at which to create the array.

    Raises:
        TypeError: if ``type`` is not a supported primitive Arrow type.
        ValueError: if the underlying C++ create call fails.
    """
    context = _validate_soma_tiledb_context(context)

    # Parse creation options exactly once; the original parsed them again
    # for every dimension inside the loop below.
    create_options = TileDBCreateOptions.from_platform_config(platform_config)

    plt_cfg = None
    if platform_config:
        plt_cfg = clib.PlatformConfig()
        plt_cfg.dataframe_dim_zstd_level = create_options.dataframe_dim_zstd_level
        plt_cfg.sparse_nd_array_dim_zstd_level = (
            create_options.sparse_nd_array_dim_zstd_level
        )
        plt_cfg.write_X_chunked = create_options.write_X_chunked
        plt_cfg.goal_chunk_nnz = create_options.goal_chunk_nnz
        plt_cfg.capacity = create_options.capacity
        if create_options.offsets_filters:
            plt_cfg.offsets_filters = [
                info["_type"] for info in create_options.offsets_filters
            ]
        if create_options.validity_filters:
            plt_cfg.validity_filters = [
                info["_type"] for info in create_options.validity_filters
            ]
        plt_cfg.allows_duplicates = create_options.allows_duplicates
        plt_cfg.tile_order = create_options.tile_order
        plt_cfg.cell_order = create_options.cell_order
        plt_cfg.consolidate_and_vacuum = create_options.consolidate_and_vacuum

    index_column_names = []
    domains = []
    extents = []
    for dim_idx, dim_shape in enumerate(shape):
        dim_name = f"soma_dim_{dim_idx}"
        dim_capacity, dim_extent = cls._dim_capacity_and_extent(
            dim_name,
            dim_shape,
            create_options,
        )
        index_column_names.append(dim_name)
        # Each dimension's domain is [0, capacity - 1], stored as int64.
        domains.append(pa.array([0, dim_capacity - 1], type=pa.int64()))
        extents.append(pa.array([dim_extent], type=pa.int64()))

    domains = pa.StructArray.from_arrays(domains, names=index_column_names)
    extents = pa.StructArray.from_arrays(extents, names=index_column_names)

    # Map supported Arrow primitive types to the single-token Arrow C data
    # interface format strings the C++ layer expects.
    to_pa_ctype = {
        pa.bool_(): "b",
        pa.int8(): "c",
        pa.int16(): "s",
        pa.int32(): "i",
        pa.int64(): "l",
        pa.uint8(): "C",
        pa.uint16(): "S",
        pa.uint32(): "I",
        pa.uint64(): "L",
        pa.float32(): "f",
        pa.float64(): "g",
        pa.timestamp("s"): "tss:",
        pa.timestamp("ms"): "tsm:",
        pa.timestamp("us"): "tsu:",
        pa.timestamp("ns"): "tsn:",
    }
    if type not in to_pa_ctype:
        raise TypeError(f"Invalid pyarrow type {type} for SparseNDArray")

    timestamp_ms = context._open_timestamp_ms(tiledb_timestamp)
    try:
        clib.SOMASparseNDArray.create(
            uri,
            format=to_pa_ctype[type],
            index_column_names=index_column_names,
            domains=domains,
            extents=extents,
            ctx=context.native_context,
            platform_config=plt_cfg,
            timestamp=(0, timestamp_ms),
        )
    except SOMAError as e:
        # Chain the original SOMAError so the underlying cause is not lost.
        raise ValueError(e) from e

    handle = cls._wrapper_type.open(uri, "w", context, tiledb_timestamp)
    return cls(
        handle,
        _dont_call_this_use_create_or_open_instead="tiledbsoma-internal-code",
    )

@property
def nnz(self) -> int:
"""
Expand Down Expand Up @@ -214,15 +309,23 @@ def write(
Experimental.
"""

arr = self._handle.writer
tiledb_create_options = TileDBCreateOptions.from_platform_config(
platform_config
)

if isinstance(values, pa.SparseCOOTensor):
# Write bulk data
data, coords = values.to_numpy()
arr[tuple(c for c in coords.T)] = data
self._handle._handle.write_ndarray(
[
np.array(
c,
dtype=self.schema.field(f"soma_dim_{i}").type.to_pandas_dtype(),
)
for i, c in enumerate(coords.T)
],
data,
)

# Write bounding-box metadata. Note COO can be N-dimensional.
maxes = [e - 1 for e in values.shape]
Expand All @@ -242,7 +345,16 @@ def write(
# Write bulk data
# TODO: the ``to_scipy`` function is not zero copy. Need to explore zero-copy options.
sp = values.to_scipy().tocoo()
arr[sp.row, sp.col] = sp.data
self._handle._handle.write_ndarray(
[
np.array(
c,
dtype=self.schema.field(f"soma_dim_{i}").type.to_pandas_dtype(),
)
for i, c in enumerate([sp.row, sp.col])
],
sp.data,
)

# Write bounding-box metadata. Note CSR and CSC are necessarily 2-dimensional.
nr, nc = values.shape
Expand All @@ -256,16 +368,12 @@ def write(

if isinstance(values, pa.Table):
# Write bulk data
data = values.column("soma_data").to_numpy()
coord_tbl = values.drop(["soma_data"])
coords = tuple(
coord_tbl.column(f"soma_dim_{n}").to_numpy()
for n in range(coord_tbl.num_columns)
)
arr[coords] = data
for batch in values.to_batches():
self._handle._handle.write(batch)

# Write bounding-box metadata
maxes = []
coord_tbl = values.drop(["soma_data"])
for i in range(coord_tbl.num_columns):
coords = values.column(f"soma_dim_{i}")
if coords:
Expand Down Expand Up @@ -348,15 +456,15 @@ def used_shape(self) -> Tuple[Tuple[int, int], ...]:
Compare this to ``shape`` which returns the available/writable capacity.
"""
retval = []
with tiledb.open(self.uri, ctx=self.context.tiledb_ctx):
for i in range(20):
lower_key = f"soma_dim_{i}_domain_lower"
lower_val = self.metadata.get(lower_key)
upper_key = f"soma_dim_{i}_domain_upper"
upper_val = self.metadata.get(upper_key)
if lower_val is None or upper_val is None:
break
retval.append((lower_val, upper_val))
# with tiledb.open(self.uri, ctx=self.context.tiledb_ctx):
for i in range(20):
lower_key = f"soma_dim_{i}_domain_lower"
lower_val = self.metadata.get(lower_key)
upper_key = f"soma_dim_{i}_domain_upper"
upper_val = self.metadata.get(upper_key)
if lower_val is None or upper_val is None:
break
retval.append((lower_val, upper_val))
if not retval:
raise SOMAError(
f"Array {self.uri} was not written with bounding box support. "
Expand Down
20 changes: 10 additions & 10 deletions apis/python/src/tiledbsoma/_tdb_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,25 +94,25 @@ def open(
if not soma_type:
raise DoesNotExistError(f"{uri!r} does not exist")

if soma_type == "SOMADataFrame":
soma_type = soma_type.lower()

if soma_type == "somadataframe":
return DataFrameWrapper._from_soma_object(soma_object, context)
if open_mode == clib.OpenMode.read and soma_type == "SOMADenseNDArray":
if open_mode == clib.OpenMode.read and soma_type == "somadensendarray":
return DenseNDArrayWrapper._from_soma_object(soma_object, context)
if open_mode == clib.OpenMode.read and soma_type == "SOMASparseNDArray":
if soma_type == "somasparsendarray":
return SparseNDArrayWrapper._from_soma_object(soma_object, context)

if soma_type in (
"SOMADataFrame",
"SOMADenseNDArray",
"SOMASparseNDArray",
"somadensendarray",
"array",
):
return ArrayWrapper.open(uri, mode, context, timestamp)

if soma_type in (
"SOMACollection",
"SOMAExperiment",
"SOMAMeasurement",
"somacollection",
"somaexperiment",
"somameasurement",
"group",
):
return GroupWrapper.open(uri, mode, context, timestamp)
Expand Down Expand Up @@ -589,7 +589,7 @@ def _write(self) -> None:
# There were no changes (e.g., it's a read handle). Do nothing.
return
# Only try to get the writer if there are changes to be made.
if isinstance(self.owner, DataFrameWrapper):
if isinstance(self.owner, (DataFrameWrapper, SparseNDArrayWrapper)):
meta = self.owner.meta
for key, mod in self._mods.items():
if mod in (_DictMod.ADDED, _DictMod.UPDATED):
Expand Down
6 changes: 5 additions & 1 deletion apis/python/src/tiledbsoma/_tiledb_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,11 @@ def _create_internal(
schema: tiledb.ArraySchema,
context: SOMATileDBContext,
tiledb_timestamp: Optional[OpenTimestamp],
) -> Union[_tdb_handles.ArrayWrapper, _tdb_handles.DataFrameWrapper]:
) -> Union[
_tdb_handles.ArrayWrapper,
_tdb_handles.DataFrameWrapper,
_tdb_handles.SparseNDArrayWrapper,
]:
"""Creates the TileDB Array for this type and returns an opened handle.
This does the work of creating a TileDB Array with the provided schema
Expand Down
7 changes: 6 additions & 1 deletion apis/python/src/tiledbsoma/_tiledb_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class TileDBObject(somacore.SOMAObject, Generic[_WrapperType_co]):
_wrapper_type: Union[
Type[_WrapperType_co],
Type[_tdb_handles.DataFrameWrapper],
Type[_tdb_handles.SparseNDArrayWrapper],
]
"""Class variable of the Wrapper class used to open this object type."""

Expand Down Expand Up @@ -109,7 +110,11 @@ def open(

def __init__(
self,
handle: Union[_WrapperType_co, _tdb_handles.DataFrameWrapper],
handle: Union[
_WrapperType_co,
_tdb_handles.DataFrameWrapper,
_tdb_handles.SparseNDArrayWrapper,
],
*,
_dont_call_this_use_create_or_open_instead: str = "unset",
):
Expand Down
22 changes: 22 additions & 0 deletions apis/python/src/tiledbsoma/soma_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,26 @@ void write(SOMAArray& array, py::handle py_batch) {
}
}

void write_ndarray(
SOMAArray& array, std::vector<py::array> coords, py::array data) {
for (uint64_t i = 0; i < coords.size(); ++i) {
py::buffer_info coords_info = coords[i].request();
array.set_column_data(
"soma_dim_" + std::to_string(i),
coords[i].size(),
(const void*)coords_info.ptr);
}

py::buffer_info data_info = data.request();
array.set_column_data("soma_data", data.size(), (const void*)data_info.ptr);

try {
array.write();
} catch (const std::exception& e) {
TPY_ERROR_LOC(e.what());
}
}

py::dict meta(SOMAArray& array) {
py::dict results;

Expand Down Expand Up @@ -575,6 +595,8 @@ void load_soma_array(py::module& m) {

.def("write", write)

.def("write_ndarray", write_ndarray)

.def("nnz", &SOMAArray::nnz, py::call_guard<py::gil_scoped_release>())

.def_property_readonly("shape", &SOMAArray::shape)
Expand Down
21 changes: 15 additions & 6 deletions apis/python/src/tiledbsoma/soma_object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,31 @@ void load_soma_object(py::module& m) {
auto soma_obj = SOMAObject::open(
uri, mode, context, timestamp, clib_type);
auto soma_obj_type = soma_obj->type();
if (soma_obj_type == "SOMADataFrame")

if (soma_obj_type.has_value()) {
std::transform(
soma_obj_type->begin(),
soma_obj_type->end(),
soma_obj_type->begin(),
[](unsigned char c) { return std::tolower(c); });
}

if (soma_obj_type == "somadataframe")
return py::cast(
dynamic_cast<SOMADataFrame&>(*soma_obj));
else if (soma_obj_type == "SOMASparseNDArray")
else if (soma_obj_type == "somasparsendarray")
return py::cast(
dynamic_cast<SOMASparseNDArray&>(*soma_obj));
else if (soma_obj_type == "SOMADenseNDArray")
else if (soma_obj_type == "somadensendarray")
return py::cast(
dynamic_cast<SOMADenseNDArray&>(*soma_obj));
else if (soma_obj_type == "SOMACollection")
else if (soma_obj_type == "somacollection")
return py::cast(
dynamic_cast<SOMACollection&>(*soma_obj));
else if (soma_obj_type == "SOMAExperiment")
else if (soma_obj_type == "somaexperiment")
return py::cast(
dynamic_cast<SOMAExperiment&>(*soma_obj));
else if (soma_obj_type == "SOMAMeasurement")
else if (soma_obj_type == "somameasurement")
return py::cast(
dynamic_cast<SOMAMeasurement&>(*soma_obj));
return py::none();
Expand Down
Loading

0 comments on commit 1223191

Please sign in to comment.