diff --git a/apis/python/tests/test_dataframe.py b/apis/python/tests/test_dataframe.py index c7b956a583..eb654808ae 100644 --- a/apis/python/tests/test_dataframe.py +++ b/apis/python/tests/test_dataframe.py @@ -1968,3 +1968,35 @@ def test_pass_configs(tmp_path, arrow_schema): } ) ) + + +def test_arrow_table_sliced_writer(tmp_path): + """Tests writes of sliced Arrow tables, with fixed-length and variable-length attributes""" + uri = tmp_path.as_posix() + num_rows = 50 + + schema = pa.schema( + [ + ("myint", pa.int32()), + ("mystring", pa.large_string()), + ] + ) + + pydict = { + "soma_joinid": list(range(num_rows)), + "myint": [(e + 1) * 10 for e in range(num_rows)], + "mystring": ["s_%08d" % e for e in range(num_rows)], + } + table = pa.Table.from_pydict(pydict) + + domain = [[0, len(table) - 1]] + + with soma.DataFrame.create(uri, schema=schema, domain=domain) as sdf: + mid = num_rows // 2 + sdf.write(table[:mid]) + sdf.write(table[mid:]) + + with soma.DataFrame.open(uri) as sdf: + pdf = sdf.read().concat().to_pandas() + assert list(pdf["myint"]) == pydict["myint"] + assert list(pdf["mystring"]) == pydict["mystring"] diff --git a/libtiledbsoma/src/soma/managed_query.cc b/libtiledbsoma/src/soma/managed_query.cc index eaff167fd9..52f2d8213e 100644 --- a/libtiledbsoma/src/soma/managed_query.cc +++ b/libtiledbsoma/src/soma/managed_query.cc @@ -966,35 +966,45 @@ bool ManagedQuery::_cast_column_aux( ArrowSchema* schema, ArrowArray* array, ArraySchemaEvolution se) { (void)se; // se is unused in std::string specialization - const void* data = nullptr; - const void* offset = nullptr; - const void* validity = nullptr; - - if (array->n_buffers == 3) { - data = array->buffers[2]; - offset = array->buffers[1]; - validity = array->buffers[0]; - } else { - data = array->buffers[1]; - offset = nullptr; - validity = array->buffers[0]; + // A few things in play here: + // * Whether the column (array) has 3 buffers (validity, offset, data) + // or 2 (validity, data). + // * The data is always char* and the validity is always uint8* + // but the offsets are 32-bit or 64-bit. + // * The array-level offset might not be zero. (This happens + // when people pass of things like arrow_table[:n] or arrow_table[n:] + // from Python/R.) + + if (array->n_buffers != 3) { + throw TileDBSOMAError(std::format( + "[ManagedQuery] internal error: Arrow-table string column should " + "have 3 buffers; got {}", + array->n_buffers)); + } + + const char* data = (const char*)array->buffers[2]; + uint8_t* validity = (uint8_t*)array->buffers[0]; + + // If this is a table-slice, slice into the validity buffer. + if (validity != nullptr) { + validity += array->offset; } + // If this is a table-slice, do *not* slice into the data + // buffer since it is indexed via offsets, which we slice + // into below. if ((strcmp(schema->format, "U") == 0) || (strcmp(schema->format, "Z") == 0)) { + // If this is a table-slice, slice into the offsets buffer. + uint64_t* offset = (uint64_t*)array->buffers[1] + array->offset; setup_write_column( - schema->name, - array->length, - (const void*)data, - (uint64_t*)offset, - (uint8_t*)validity); + schema->name, array->length, (const void*)data, offset, validity); + } else { + // If this is a table-slice, slice into the offsets buffer. + uint32_t* offset = (uint32_t*)array->buffers[1] + array->offset; setup_write_column( - schema->name, - array->length, - (const void*)data, - (uint32_t*)offset, - (uint8_t*)validity); + schema->name, array->length, (const void*)data, offset, validity); } return false; } diff --git a/libtiledbsoma/src/soma/managed_query.h b/libtiledbsoma/src/soma/managed_query.h index 443642f9cb..be4f2332aa 100644 --- a/libtiledbsoma/src/soma/managed_query.h +++ b/libtiledbsoma/src/soma/managed_query.h @@ -614,11 +614,17 @@ class ManagedQuery { // additional processing steps UserType* buf; + // The array->offset is non-zero when we are passed sliced + // Arrow tables like arrow_table[:m] or arrow_table[m:]. if (array->n_buffers == 3) { buf = (UserType*)array->buffers[2] + array->offset; } else { buf = (UserType*)array->buffers[1] + array->offset; } + uint8_t* validity = (uint8_t*)array->buffers[0]; + if (validity != nullptr) { + validity += array->offset; + } bool has_attr = schema_->has_attribute(schema->name); if (has_attr && attr_has_enum(schema->name)) { @@ -646,7 +652,7 @@ class ManagedQuery { casted_values.size(), (const void*)casted_values.data(), (uint64_t*)nullptr, - (uint8_t*)array->buffers[0]); + validity); // Return false because we do not extend the enumeration return false;