Skip to content

Commit

Permalink
WIP update enumeration index values when extending
Browse files Browse the repository at this point in the history
  • Loading branch information
nguyenv committed Apr 2, 2024
1 parent 6f1f07c commit 4243537
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 48 deletions.
2 changes: 1 addition & 1 deletion apis/python/src/tiledbsoma/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def __getitem__(self, key: str) -> CollectionElementType:

wrapper = _tdb_handles.open(uri, mode, context, timestamp)
entry.soma = _factory.reify_handle(wrapper)

# Since we just opened this object, we own it and should close it.
self._close_stack.enter_context(entry.soma)
return cast(CollectionElementType, entry.soma)
Expand Down
46 changes: 32 additions & 14 deletions apis/python/src/tiledbsoma/_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Optional, Sequence, Tuple, Type, Union, cast

import numpy as np
import pandas as pd
import pyarrow as pa
import somacore
import tiledb
Expand Down Expand Up @@ -243,7 +244,7 @@ def create(

domains = pa.StructArray.from_arrays(domains, names=index_column_names)
extents = pa.StructArray.from_arrays(extents, names=index_column_names)

plt_cfg = None
if platform_config:
ops = TileDBCreateOptions.from_platform_config(platform_config)
Expand All @@ -254,14 +255,18 @@ def create(
plt_cfg.goal_chunk_nnz = ops.goal_chunk_nnz
plt_cfg.capacity = ops.capacity
if ops.offsets_filters:
plt_cfg.offsets_filters = [info["_type"] for info in ops.offsets_filters]
plt_cfg.offsets_filters = [
info["_type"] for info in ops.offsets_filters
]
if ops.validity_filters:
plt_cfg.validity_filters = [info["_type"] for info in ops.validity_filters]
plt_cfg.validity_filters = [
info["_type"] for info in ops.validity_filters
]
plt_cfg.allows_duplicates = ops.allows_duplicates
plt_cfg.tile_order = ops.tile_order
plt_cfg.cell_order = ops.cell_order
plt_cfg.consolidate_and_vacuum = ops.consolidate_and_vacuum

# TODO add as kw args
clib.SOMADataFrame.create(
uri,
Expand Down Expand Up @@ -455,18 +460,31 @@ def write(
_util.check_type("values", values, (pa.Table,))

target_schema = []
for input_field in values.schema:
target_field = self.schema.field(input_field.name)
for i, input_field in enumerate(values.schema):
name = input_field.name
target_field = self.schema.field(name)

if pa.types.is_dictionary(target_field.type):
if not pa.types.is_dictionary(input_field.type):
raise ValueError(f"{input_field.name} requires dictionary entry")
# extend enums in array schema as necessary
# get evolved enums
col = values.column(input_field.name).combine_chunks()
new_enums = self._handle._handle.extend_enumeration(col)
print(new_enums)
# cast that in table
raise ValueError(f"{name} requires dictionary entry")
col = values.column(name).combine_chunks()
new_enmr = self._handle._handle.extend_enumeration(name, col)

if pa.types.is_binary(
target_field.type.value_type
) or pa.types.is_large_binary(target_field.type.value_type):
new_enmr = np.array(new_enmr, "S")
elif pa.types.is_boolean(target_field.type.value_type):
new_enmr = np.array(new_enmr, bool)

df = pd.Categorical(
col.to_pandas(),
ordered=target_field.type.ordered,
categories=new_enmr,
)
values = values.set_column(
i, name, pa.DictionaryArray.from_pandas(df, type=target_field.type)
)

if pa.types.is_boolean(input_field.type):
target_schema.append(target_field.with_type(pa.uint8()))
Expand All @@ -476,7 +494,7 @@ def write(

for batch in values.to_batches():
self._handle.write(batch)

tiledb_create_options = TileDBCreateOptions.from_platform_config(
platform_config
)
Expand Down
4 changes: 2 additions & 2 deletions apis/python/src/tiledbsoma/_tdb_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def open(
uri: str,
mode: options.OpenMode,
context: SOMATileDBContext,
timestamp: Optional[OpenTimestamp]
timestamp: Optional[OpenTimestamp],
) -> "Wrapper[RawHandle]":
"""Determine whether the URI is an array or group, and open it."""
open_mode = clib.OpenMode.read if mode == "r" else clib.OpenMode.write
Expand All @@ -68,7 +68,7 @@ def open(

if not obj_type:
raise DoesNotExistError(f"{uri!r} does not exist")

if obj_type == "SOMADataFrame":
return DataFrameWrapper._from_soma_object(soma_object, context)
if open_mode == clib.OpenMode.read and obj_type == "SOMADenseNDArray":
Expand Down
1 change: 1 addition & 0 deletions apis/python/src/tiledbsoma/io/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,7 @@ def _write_arrow_table(
)
handle.write(arrow_table)


def _write_dataframe(
df_uri: str,
df: pd.DataFrame,
Expand Down
107 changes: 79 additions & 28 deletions apis/python/src/tiledbsoma/soma_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,9 @@ void load_soma_array(py::module& m) {

.def(
"extend_enumeration",
[](SOMAArray& array, py::handle py_batch) -> py::object {
[](SOMAArray& array,
std::string attr_name,
py::handle py_batch) -> py::array {
ArrowSchema arrow_schema;
ArrowArray arrow_array;
uintptr_t arrow_schema_ptr = (uintptr_t)(&arrow_schema);
Expand All @@ -782,46 +784,95 @@ void load_soma_array(py::module& m) {

if (dict->length != 0) {
auto new_enmr = array.extend_enumeration(
arrow_schema.name,
dict->length,
enmr_data,
enmr_offsets);
attr_name, dict->length, enmr_data, enmr_offsets);

auto emdr_format = arrow_schema.dictionary->format;
switch (ArrowAdapter::to_tiledb_format(emdr_format)) {
case TILEDB_STRING_ASCII:
case TILEDB_STRING_UTF8:
case TILEDB_CHAR:
return py::cast(new_enmr.as_vector<std::string>());
case TILEDB_STRING_UTF8: {
auto result = new_enmr.as_vector<std::string>();
return py::array(py::cast(result));
}
case TILEDB_BOOL:
case TILEDB_INT8:
return py::cast(new_enmr.as_vector<int8_t>());
case TILEDB_UINT8:
return py::cast(new_enmr.as_vector<uint8_t>());
case TILEDB_INT16:
return py::cast(new_enmr.as_vector<int16_t>());
case TILEDB_UINT16:
return py::cast(new_enmr.as_vector<uint16_t>());
case TILEDB_INT32:
return py::cast(new_enmr.as_vector<int32_t>());
case TILEDB_UINT32:
return py::cast(new_enmr.as_vector<uint32_t>());
case TILEDB_INT64:
return py::cast(new_enmr.as_vector<int64_t>());
case TILEDB_UINT64:
return py::cast(new_enmr.as_vector<uint64_t>());
case TILEDB_FLOAT32:
return py::cast(new_enmr.as_vector<float>());
case TILEDB_FLOAT64:
return py::cast(new_enmr.as_vector<double>());
case TILEDB_INT8: {
auto result = new_enmr.as_vector<int8_t>();
return py::array(
py::dtype("int8"),
result.size(),
result.data());
}
case TILEDB_UINT8: {
auto result = new_enmr.as_vector<uint8_t>();
return py::array(
py::dtype("uint8"),
result.size(),
result.data());
}
case TILEDB_INT16: {
auto result = new_enmr.as_vector<int16_t>();
return py::array(
py::dtype("int16"),
result.size(),
result.data());
}
case TILEDB_UINT16: {
auto result = new_enmr.as_vector<uint16_t>();
return py::array(
py::dtype("uint16"),
result.size(),
result.data());
}
case TILEDB_INT32: {
auto result = new_enmr.as_vector<int32_t>();
return py::array(
py::dtype("int32"),
result.size(),
result.data());
}
case TILEDB_UINT32: {
auto result = new_enmr.as_vector<uint32_t>();
return py::array(
py::dtype("uint32"),
result.size(),
result.data());
}
case TILEDB_INT64: {
auto result = new_enmr.as_vector<int64_t>();
return py::array(
py::dtype("int64"),
result.size(),
result.data());
}
case TILEDB_UINT64: {
auto result = new_enmr.as_vector<uint64_t>();
return py::array(
py::dtype("uint64"),
result.size(),
result.data());
}
case TILEDB_FLOAT32: {
auto result = new_enmr.as_vector<float>();
return py::array(
py::dtype("float32"),
result.size(),
result.data());
}
case TILEDB_FLOAT64: {
auto result = new_enmr.as_vector<double>();
return py::array(
py::dtype("float64"),
result.size(),
result.data());
}
default:
throw TileDBSOMAError(
"extend_enumeration: Unsupported dict "
"datatype");
}

} else {
return py::cast(std::vector<std::string>());
return py::array();
}
})

Expand Down
2 changes: 1 addition & 1 deletion apis/python/tests/test_query_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def soma_query(uri, condition):
sr.set_condition(qc, sr.schema)
arrow_table = sr.read_next()
assert sr.results_complete()

return arrow_table


Expand Down
3 changes: 2 additions & 1 deletion libtiledbsoma/src/soma/soma_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,10 @@ Enumeration SOMAArray::extend_enumeration(
ArraySchemaEvolution se(*ctx_->tiledb_ctx());
se.extend_enumeration(enmr.extend(extend_values));
se.array_evolve(uri_);
return enmr.extend(extend_values);
}

return enmr.extend(extend_values);
return enmr;
}
case TILEDB_BOOL:
case TILEDB_INT8:
Expand Down
3 changes: 2 additions & 1 deletion libtiledbsoma/src/soma/soma_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -731,9 +731,10 @@ class SOMAArray : public SOMAObject {
ArraySchemaEvolution se(*ctx_->tiledb_ctx());
se.extend_enumeration(enmr.extend(extend_values));
se.array_evolve(uri_);
return enmr.extend(extend_values);
}

return enmr.extend(extend_values);
return enmr;
}

// Fills the metadata cache upon opening the array.
Expand Down

0 comments on commit 4243537

Please sign in to comment.