From 42435376d7d06f07bbd0fbf26bcf7663d5c58e5f Mon Sep 17 00:00:00 2001 From: Vivian Nguyen Date: Mon, 1 Apr 2024 23:06:51 -0500 Subject: [PATCH] WIP update enumeration index values when extending --- apis/python/src/tiledbsoma/_collection.py | 2 +- apis/python/src/tiledbsoma/_dataframe.py | 46 ++++++--- apis/python/src/tiledbsoma/_tdb_handles.py | 4 +- apis/python/src/tiledbsoma/io/ingest.py | 1 + apis/python/src/tiledbsoma/soma_array.cc | 107 +++++++++++++++------ apis/python/tests/test_query_condition.py | 2 +- libtiledbsoma/src/soma/soma_array.cc | 3 +- libtiledbsoma/src/soma/soma_array.h | 3 +- 8 files changed, 120 insertions(+), 48 deletions(-) diff --git a/apis/python/src/tiledbsoma/_collection.py b/apis/python/src/tiledbsoma/_collection.py index e110374d3d..f3886b96d6 100644 --- a/apis/python/src/tiledbsoma/_collection.py +++ b/apis/python/src/tiledbsoma/_collection.py @@ -436,7 +436,7 @@ def __getitem__(self, key: str) -> CollectionElementType: wrapper = _tdb_handles.open(uri, mode, context, timestamp) entry.soma = _factory.reify_handle(wrapper) - + # Since we just opened this object, we own it and should close it. self._close_stack.enter_context(entry.soma) return cast(CollectionElementType, entry.soma) diff --git a/apis/python/src/tiledbsoma/_dataframe.py b/apis/python/src/tiledbsoma/_dataframe.py index 4abe91dc63..6c3f605064 100644 --- a/apis/python/src/tiledbsoma/_dataframe.py +++ b/apis/python/src/tiledbsoma/_dataframe.py @@ -9,6 +9,7 @@ from typing import Any, Optional, Sequence, Tuple, Type, Union, cast import numpy as np +import pandas as pd import pyarrow as pa import somacore import tiledb @@ -243,7 +244,7 @@ def create( domains = pa.StructArray.from_arrays(domains, names=index_column_names) extents = pa.StructArray.from_arrays(extents, names=index_column_names) - + plt_cfg = None if platform_config: ops = TileDBCreateOptions.from_platform_config(platform_config) @@ -254,14 +255,18 @@ def create( plt_cfg.goal_chunk_nnz = ops.goal_chunk_nnz plt_cfg.capacity = ops.capacity if ops.offsets_filters: - plt_cfg.offsets_filters = [info["_type"] for info in ops.offsets_filters] + plt_cfg.offsets_filters = [ + info["_type"] for info in ops.offsets_filters + ] if ops.validity_filters: - plt_cfg.validity_filters = [info["_type"] for info in ops.validity_filters] + plt_cfg.validity_filters = [ + info["_type"] for info in ops.validity_filters + ] plt_cfg.allows_duplicates = ops.allows_duplicates plt_cfg.tile_order = ops.tile_order plt_cfg.cell_order = ops.cell_order plt_cfg.consolidate_and_vacuum = ops.consolidate_and_vacuum - + # TODO add as kw args clib.SOMADataFrame.create( uri, @@ -455,18 +460,31 @@ def write( _util.check_type("values", values, (pa.Table,)) target_schema = [] - for input_field in values.schema: - target_field = self.schema.field(input_field.name) + for i, input_field in enumerate(values.schema): + name = input_field.name + target_field = self.schema.field(name) if pa.types.is_dictionary(target_field.type): if not pa.types.is_dictionary(input_field.type): - raise ValueError(f"{input_field.name} requires dictionary entry") - # extend enums in array schema as necessary - # get evolved enums - col = values.column(input_field.name).combine_chunks() - new_enums = self._handle._handle.extend_enumeration(col) - print(new_enums) - # cast that in table + raise ValueError(f"{name} requires dictionary entry") + col = values.column(name).combine_chunks() + new_enmr = self._handle._handle.extend_enumeration(name, col) + + if pa.types.is_binary( + target_field.type.value_type + ) or pa.types.is_large_binary(target_field.type.value_type): + new_enmr = np.array(new_enmr, "S") + elif pa.types.is_boolean(target_field.type.value_type): + new_enmr = np.array(new_enmr, bool) + + df = pd.Categorical( + col.to_pandas(), + ordered=target_field.type.ordered, + categories=new_enmr, + ) + values = values.set_column( + i, name, pa.DictionaryArray.from_pandas(df, type=target_field.type) + ) if pa.types.is_boolean(input_field.type): target_schema.append(target_field.with_type(pa.uint8())) @@ -476,7 +494,7 @@ def write( for batch in values.to_batches(): self._handle.write(batch) - + tiledb_create_options = TileDBCreateOptions.from_platform_config( platform_config ) diff --git a/apis/python/src/tiledbsoma/_tdb_handles.py b/apis/python/src/tiledbsoma/_tdb_handles.py index fb635edaa3..5968391bea 100644 --- a/apis/python/src/tiledbsoma/_tdb_handles.py +++ b/apis/python/src/tiledbsoma/_tdb_handles.py @@ -47,7 +47,7 @@ def open( uri: str, mode: options.OpenMode, context: SOMATileDBContext, - timestamp: Optional[OpenTimestamp] + timestamp: Optional[OpenTimestamp], ) -> "Wrapper[RawHandle]": """Determine whether the URI is an array or group, and open it.""" open_mode = clib.OpenMode.read if mode == "r" else clib.OpenMode.write @@ -68,7 +68,7 @@ def open( if not obj_type: raise DoesNotExistError(f"{uri!r} does not exist") - + if obj_type == "SOMADataFrame": return DataFrameWrapper._from_soma_object(soma_object, context) if open_mode == clib.OpenMode.read and obj_type == "SOMADenseNDArray": diff --git a/apis/python/src/tiledbsoma/io/ingest.py b/apis/python/src/tiledbsoma/io/ingest.py index a05a9385be..48e2906ac0 100644 --- a/apis/python/src/tiledbsoma/io/ingest.py +++ b/apis/python/src/tiledbsoma/io/ingest.py @@ -1219,6 +1219,7 @@ def _write_arrow_table( ) handle.write(arrow_table) + def _write_dataframe( df_uri: str, df: pd.DataFrame, diff --git a/apis/python/src/tiledbsoma/soma_array.cc b/apis/python/src/tiledbsoma/soma_array.cc index 2de72d0385..fe3d7e8a6f 100644 --- a/apis/python/src/tiledbsoma/soma_array.cc +++ b/apis/python/src/tiledbsoma/soma_array.cc @@ -762,7 +762,9 @@ void load_soma_array(py::module& m) { .def( "extend_enumeration", - [](SOMAArray& array, py::handle py_batch) -> py::object { + [](SOMAArray& array, + std::string attr_name, + py::handle py_batch) -> py::array { ArrowSchema arrow_schema; ArrowArray arrow_array; uintptr_t arrow_schema_ptr = (uintptr_t)(&arrow_schema); @@ -782,38 +784,87 @@ void load_soma_array(py::module& m) { if (dict->length != 0) { auto new_enmr = array.extend_enumeration( - arrow_schema.name, - dict->length, - enmr_data, - enmr_offsets); + attr_name, dict->length, enmr_data, enmr_offsets); auto emdr_format = arrow_schema.dictionary->format; switch (ArrowAdapter::to_tiledb_format(emdr_format)) { case TILEDB_STRING_ASCII: - case TILEDB_STRING_UTF8: case TILEDB_CHAR: - return py::cast(new_enmr.as_vector()); + case TILEDB_STRING_UTF8: { + auto result = new_enmr.as_vector(); + return py::array(py::cast(result)); + } case TILEDB_BOOL: - case TILEDB_INT8: - return py::cast(new_enmr.as_vector()); - case TILEDB_UINT8: - return py::cast(new_enmr.as_vector()); - case TILEDB_INT16: - return py::cast(new_enmr.as_vector()); - case TILEDB_UINT16: - return py::cast(new_enmr.as_vector()); - case TILEDB_INT32: - return py::cast(new_enmr.as_vector()); - case TILEDB_UINT32: - return py::cast(new_enmr.as_vector()); - case TILEDB_INT64: - return py::cast(new_enmr.as_vector()); - case TILEDB_UINT64: - return py::cast(new_enmr.as_vector()); - case TILEDB_FLOAT32: - return py::cast(new_enmr.as_vector()); - case TILEDB_FLOAT64: - return py::cast(new_enmr.as_vector()); + case TILEDB_INT8: { + auto result = new_enmr.as_vector(); + return py::array( + py::dtype("int8"), + result.size(), + result.data()); + } + case TILEDB_UINT8: { + auto result = new_enmr.as_vector(); + return py::array( + py::dtype("uint8"), + result.size(), + result.data()); + } + case TILEDB_INT16: { + auto result = new_enmr.as_vector(); + return py::array( + py::dtype("int16"), + result.size(), + result.data()); + } + case TILEDB_UINT16: { + auto result = new_enmr.as_vector(); + return py::array( + py::dtype("uint16"), + result.size(), + result.data()); + } + case TILEDB_INT32: { + auto result = new_enmr.as_vector(); + return py::array( + py::dtype("int32"), + result.size(), + result.data()); + } + case TILEDB_UINT32: { + auto result = new_enmr.as_vector(); + return py::array( + py::dtype("uint32"), + result.size(), + result.data()); + } + case TILEDB_INT64: { + auto result = new_enmr.as_vector(); + return py::array( + py::dtype("int64"), + result.size(), + result.data()); + } + case TILEDB_UINT64: { + auto result = new_enmr.as_vector(); + return py::array( + py::dtype("uint64"), + result.size(), + result.data()); + } + case TILEDB_FLOAT32: { + auto result = new_enmr.as_vector(); + return py::array( + py::dtype("float32"), + result.size(), + result.data()); + } + case TILEDB_FLOAT64: { + auto result = new_enmr.as_vector(); + return py::array( + py::dtype("float64"), + result.size(), + result.data()); + } default: throw TileDBSOMAError( "extend_enumeration: Unsupported dict " @@ -821,7 +872,7 @@ void load_soma_array(py::module& m) { } } else { - return py::cast(std::vector()); + return py::array(); } }) diff --git a/apis/python/tests/test_query_condition.py b/apis/python/tests/test_query_condition.py index 7504743ecf..53b8c494c6 100644 --- a/apis/python/tests/test_query_condition.py +++ b/apis/python/tests/test_query_condition.py @@ -31,7 +31,7 @@ def soma_query(uri, condition): sr.set_condition(qc, sr.schema) arrow_table = sr.read_next() assert sr.results_complete() - + return arrow_table diff --git a/libtiledbsoma/src/soma/soma_array.cc b/libtiledbsoma/src/soma/soma_array.cc index 5cbdafab4a..f25ebb6710 100644 --- a/libtiledbsoma/src/soma/soma_array.cc +++ b/libtiledbsoma/src/soma/soma_array.cc @@ -356,9 +356,10 @@ Enumeration SOMAArray::extend_enumeration( ArraySchemaEvolution se(*ctx_->tiledb_ctx()); se.extend_enumeration(enmr.extend(extend_values)); se.array_evolve(uri_); + return enmr.extend(extend_values); } - return enmr.extend(extend_values); + return enmr; } case TILEDB_BOOL: case TILEDB_INT8: diff --git a/libtiledbsoma/src/soma/soma_array.h b/libtiledbsoma/src/soma/soma_array.h index 76e914c8c9..2e56a2cf64 100644 --- a/libtiledbsoma/src/soma/soma_array.h +++ b/libtiledbsoma/src/soma/soma_array.h @@ -731,9 +731,10 @@ class SOMAArray : public SOMAObject { ArraySchemaEvolution se(*ctx_->tiledb_ctx()); se.extend_enumeration(enmr.extend(extend_values)); se.array_evolve(uri_); + return enmr.extend(extend_values); } - return enmr.extend(extend_values); + return enmr; } // Fills the metadata cache upon opening the array.