Skip to content

Commit

Permalink
[c++] Fix bug with sliced Arrow-table writes, with string columns (#3433
Browse files Browse the repository at this point in the history
)
  • Loading branch information
johnkerl authored Dec 13, 2024
1 parent 8d92cb5 commit e8ef0ac
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 23 deletions.
32 changes: 32 additions & 0 deletions apis/python/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1968,3 +1968,35 @@ def test_pass_configs(tmp_path, arrow_schema):
}
)
)


def test_arrow_table_sliced_writer(tmp_path):
    """Write an Arrow table in two slices and check the read-back values.

    The table has one fixed-length attribute (int32) and one
    variable-length attribute (large_string); both must round-trip
    when the table is written as ``table[:mid]`` followed by ``table[mid:]``.
    """
    uri = tmp_path.as_posix()
    num_rows = 50
    row_ids = range(num_rows)

    # Source-of-truth column values; also used as the expected read-back.
    expected = {
        "soma_joinid": list(row_ids),
        "myint": [(i + 1) * 10 for i in row_ids],
        "mystring": ["s_%08d" % i for i in row_ids],
    }
    table = pa.Table.from_pydict(expected)

    schema = pa.schema(
        [
            ("myint", pa.int32()),
            ("mystring", pa.large_string()),
        ]
    )
    domain = [[0, len(table) - 1]]

    # Write the first half, then the second half, as separate sliced tables.
    half = num_rows // 2
    with soma.DataFrame.create(uri, schema=schema, domain=domain) as sdf:
        for piece in (table[:half], table[half:]):
            sdf.write(piece)

    with soma.DataFrame.open(uri) as sdf:
        readback = sdf.read().concat().to_pandas()
        for column in ("myint", "mystring"):
            assert list(readback[column]) == expected[column]
54 changes: 32 additions & 22 deletions libtiledbsoma/src/soma/managed_query.cc
Original file line number Diff line number Diff line change
Expand Up @@ -966,35 +966,45 @@ bool ManagedQuery::_cast_column_aux<std::string>(
ArrowSchema* schema, ArrowArray* array, ArraySchemaEvolution se) {
(void)se; // se is unused in std::string specialization

const void* data = nullptr;
const void* offset = nullptr;
const void* validity = nullptr;

if (array->n_buffers == 3) {
data = array->buffers[2];
offset = array->buffers[1];
validity = array->buffers[0];
} else {
data = array->buffers[1];
offset = nullptr;
validity = array->buffers[0];
// A few things in play here:
// * Whether the column (array) has 3 buffers (validity, offset, data)
// or 2 (validity, data).
// * The data is always char* and the validity is always uint8*
// but the offsets are 32-bit or 64-bit.
// * The array-level offset might not be zero. (This happens
// when people pass in things like arrow_table[:n] or arrow_table[n:]
// from Python/R.)

if (array->n_buffers != 3) {
throw TileDBSOMAError(std::format(
"[ManagedQuery] internal error: Arrow-table string column should "
"have 3 buffers; got {}",
array->n_buffers));
}

const char* data = (const char*)array->buffers[2];
uint8_t* validity = (uint8_t*)array->buffers[0];

// If this is a table-slice, slice into the validity buffer.
if (validity != nullptr) {
validity += array->offset;
}
// If this is a table-slice, do *not* slice into the data
// buffer since it is indexed via offsets, which we slice
// into below.

if ((strcmp(schema->format, "U") == 0) ||
(strcmp(schema->format, "Z") == 0)) {
// If this is a table-slice, slice into the offsets buffer.
uint64_t* offset = (uint64_t*)array->buffers[1] + array->offset;
setup_write_column(
schema->name,
array->length,
(const void*)data,
(uint64_t*)offset,
(uint8_t*)validity);
schema->name, array->length, (const void*)data, offset, validity);

} else {
// If this is a table-slice, slice into the offsets buffer.
uint32_t* offset = (uint32_t*)array->buffers[1] + array->offset;
setup_write_column(
schema->name,
array->length,
(const void*)data,
(uint32_t*)offset,
(uint8_t*)validity);
schema->name, array->length, (const void*)data, offset, validity);
}
return false;
}
Expand Down
8 changes: 7 additions & 1 deletion libtiledbsoma/src/soma/managed_query.h
Original file line number Diff line number Diff line change
Expand Up @@ -614,11 +614,17 @@ class ManagedQuery {
// additional processing steps

UserType* buf;
// The array->offset is non-zero when we are passed sliced
// Arrow tables like arrow_table[:m] or arrow_table[m:].
if (array->n_buffers == 3) {
buf = (UserType*)array->buffers[2] + array->offset;
} else {
buf = (UserType*)array->buffers[1] + array->offset;
}
uint8_t* validity = (uint8_t*)array->buffers[0];
if (validity != nullptr) {
validity += array->offset;
}

bool has_attr = schema_->has_attribute(schema->name);
if (has_attr && attr_has_enum(schema->name)) {
Expand Down Expand Up @@ -646,7 +652,7 @@ class ManagedQuery {
casted_values.size(),
(const void*)casted_values.data(),
(uint64_t*)nullptr,
(uint8_t*)array->buffers[0]);
validity);

// Return false because we do not extend the enumeration
return false;
Expand Down

0 comments on commit e8ef0ac

Please sign in to comment.