Skip to content

Commit

Permalink
feat(c/driver/postgresql): Accept bulk ingest of dictionary-encoded s…
Browse files Browse the repository at this point in the history
…trings/binary (apache#1275)

This PR adds the ability for the Postgres driver to ingest
dictionary-encoded arrays. This comes up in R because factors are
relatively common and are converted by default to Arrow
dictionary-encoded strings for performance reasons.

Reprex in R:

``` r
library(adbcdrivermanager)

con <- adbcpostgresql::adbcpostgresql() |> 
  adbc_database_init(uri = Sys.getenv("ADBC_POSTGRESQL_TEST_URI")) |> 
  adbc_connection_init()

df <- data.frame(x = letters, y = factor(letters))
write_adbc(df, con, "some_table")
#> Error in adbc_statement_execute_query(stmt): [libpq] Failed to create table: ERROR:  relation "some_table" already exists
#> 
#> Query was: CREATE TABLE "public" . "some_table" ("x" TEXT, "y" TEXT)
read_adbc(con, "SELECT * from some_table") |> 
  as.data.frame() |> 
  str()
#> 'data.frame':    26 obs. of  2 variables:
#>  $ x: chr  "a" "b" "c" "d" ...
#>  $ y: chr  "a" "b" "c" "d" ...
```

<sup>Created on 2023-11-09 with [reprex
v2.0.2](https://reprex.tidyverse.org)</sup>

There is probably some opportunity to consolidate some of the code that
currently lives in the `BindStream` into the `PostgresType` and/or
`PostgresTypeResolver`...I'm happy to poke away at that at some point
but in the meantime it seemed like it wasn't too onerous to tack on
dictionary support here.
  • Loading branch information
paleolimbot authored and vleslief-ms committed Nov 9, 2023
1 parent c0f47f9 commit 5a76d63
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 46 deletions.
116 changes: 73 additions & 43 deletions c/driver/postgresql/postgres_copy_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -1231,13 +1231,13 @@ class PostgresCopyDurationFieldWriter : public PostgresCopyFieldWriter {
switch (TU) {
case NANOARROW_TIME_UNIT_SECOND:
if ((overflow_safe = raw_value <= kMaxSafeSecondsToMicros &&
raw_value >= kMinSafeSecondsToMicros)) {
raw_value >= kMinSafeSecondsToMicros)) {
value = raw_value * 1000000;
}
break;
case NANOARROW_TIME_UNIT_MILLI:
if ((overflow_safe = raw_value <= kMaxSafeMillisToMicros &&
raw_value >= kMinSafeMillisToMicros)) {
raw_value >= kMinSafeMillisToMicros)) {
value = raw_value * 1000;
}
break;
Expand All @@ -1251,11 +1251,8 @@ class PostgresCopyDurationFieldWriter : public PostgresCopyFieldWriter {

if (!overflow_safe) {
ArrowErrorSet(
error,
"Row %" PRId64 " duration value %" PRId64 " with unit %d would overflow",
index,
raw_value,
TU);
error, "Row %" PRId64 " duration value %" PRId64 " with unit %d would overflow",
index, raw_value, TU);
return ADBC_STATUS_INVALID_ARGUMENT;
}

Expand All @@ -1273,8 +1270,7 @@ class PostgresCopyDurationFieldWriter : public PostgresCopyFieldWriter {
class PostgresCopyBinaryFieldWriter : public PostgresCopyFieldWriter {
public:
ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
struct ArrowBufferView buffer_view =
ArrowArrayViewGetBytesUnsafe(array_view_, index);
struct ArrowBufferView buffer_view = ArrowArrayViewGetBytesUnsafe(array_view_, index);
NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, buffer_view.size_bytes, error));
NANOARROW_RETURN_NOT_OK(
ArrowBufferAppend(buffer, buffer_view.data.as_uint8, buffer_view.size_bytes));
Expand All @@ -1283,6 +1279,26 @@ class PostgresCopyBinaryFieldWriter : public PostgresCopyFieldWriter {
}
};

// COPY field writer for dictionary-encoded columns whose dictionary values are
// binary or string data. For each row, the dictionary key stored at `index` is
// looked up and the referenced dictionary value is written as an int32 byte
// length followed by the raw bytes. A null dictionary value is written as a
// bare length of -1 (the NULL marker in PostgreSQL's binary COPY format).
class PostgresCopyBinaryDictFieldWriter : public PostgresCopyFieldWriter {
 public:
  ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
    // The array itself holds integer keys; the payload lives in the dictionary.
    const int64_t key = ArrowArrayViewGetIntUnsafe(array_view_, index);

    // Null dictionary entry: emit only the -1 length and no payload bytes.
    if (ArrowArrayViewIsNull(array_view_->dictionary, key)) {
      constexpr int32_t null_field_size = -1;
      NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, null_field_size, error));
      return ADBC_STATUS_OK;
    }

    // Non-null entry: length prefix first, then the value bytes verbatim.
    struct ArrowBufferView value =
        ArrowArrayViewGetBytesUnsafe(array_view_->dictionary, key);
    NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, value.size_bytes, error));
    NANOARROW_RETURN_NOT_OK(
        ArrowBufferAppend(buffer, value.data.as_uint8, value.size_bytes));
    return ADBC_STATUS_OK;
  }
};

template <enum ArrowTimeUnit TU>
class PostgresCopyTimestampFieldWriter : public PostgresCopyFieldWriter {
public:
Expand All @@ -1297,13 +1313,13 @@ class PostgresCopyTimestampFieldWriter : public PostgresCopyFieldWriter {
switch (TU) {
case NANOARROW_TIME_UNIT_SECOND:
if ((overflow_safe = raw_value <= kMaxSafeSecondsToMicros &&
raw_value >= kMinSafeSecondsToMicros)) {
raw_value >= kMinSafeSecondsToMicros)) {
value = raw_value * 1000000;
}
break;
case NANOARROW_TIME_UNIT_MILLI:
if ((overflow_safe = raw_value <= kMaxSafeMillisToMicros &&
raw_value >= kMinSafeMillisToMicros)) {
raw_value >= kMinSafeMillisToMicros)) {
value = raw_value * 1000;
}
break;
Expand All @@ -1316,12 +1332,10 @@ class PostgresCopyTimestampFieldWriter : public PostgresCopyFieldWriter {
}

if (!overflow_safe) {
ArrowErrorSet(
error,
"Row %" PRId64 " timestamp value %" PRId64 " with unit %d would overflow",
index,
raw_value,
TU);
ArrowErrorSet(error,
"Row %" PRId64 " timestamp value %" PRId64
" with unit %d would overflow",
index, raw_value, TU);
return ADBC_STATUS_INVALID_ARGUMENT;
}

Expand All @@ -1334,9 +1348,12 @@ class PostgresCopyTimestampFieldWriter : public PostgresCopyFieldWriter {
}
};

static inline ArrowErrorCode MakeCopyFieldWriter(
const struct ArrowSchemaView& schema_view, PostgresCopyFieldWriter** out,
ArrowError* error) {
static inline ArrowErrorCode MakeCopyFieldWriter(struct ArrowSchema* schema,
PostgresCopyFieldWriter** out,
ArrowError* error) {
struct ArrowSchemaView schema_view;
NANOARROW_RETURN_NOT_OK(ArrowSchemaViewInit(&schema_view, schema, error));

switch (schema_view.type) {
case NANOARROW_TYPE_BOOL:
*out = new PostgresCopyBooleanFieldWriter();
Expand Down Expand Up @@ -1368,21 +1385,21 @@ static inline ArrowErrorCode MakeCopyFieldWriter(
*out = new PostgresCopyBinaryFieldWriter();
return NANOARROW_OK;
case NANOARROW_TYPE_TIMESTAMP: {
switch (schema_view.time_unit) {
case NANOARROW_TIME_UNIT_NANO:
*out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_NANO>();
break;
case NANOARROW_TIME_UNIT_MILLI:
*out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MILLI>();
break;
case NANOARROW_TIME_UNIT_MICRO:
*out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MICRO>();
break;
case NANOARROW_TIME_UNIT_SECOND:
*out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_SECOND>();
break;
}
return NANOARROW_OK;
switch (schema_view.time_unit) {
case NANOARROW_TIME_UNIT_NANO:
*out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_NANO>();
break;
case NANOARROW_TIME_UNIT_MILLI:
*out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MILLI>();
break;
case NANOARROW_TIME_UNIT_MICRO:
*out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MICRO>();
break;
case NANOARROW_TIME_UNIT_SECOND:
*out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_SECOND>();
break;
}
return NANOARROW_OK;
}
case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
*out = new PostgresCopyIntervalFieldWriter();
Expand All @@ -1405,10 +1422,27 @@ static inline ArrowErrorCode MakeCopyFieldWriter(
}
return NANOARROW_OK;
}
case NANOARROW_TYPE_DICTIONARY: {
struct ArrowSchemaView value_view;
NANOARROW_RETURN_NOT_OK(
ArrowSchemaViewInit(&value_view, schema->dictionary, error));
switch (value_view.type) {
case NANOARROW_TYPE_BINARY:
case NANOARROW_TYPE_STRING:
case NANOARROW_TYPE_LARGE_BINARY:
case NANOARROW_TYPE_LARGE_STRING:
*out = new PostgresCopyBinaryDictFieldWriter();
return NANOARROW_OK;
default:
break;
}
}
default:
ArrowErrorSet(error, "COPY Writer not implemented for type %d", schema_view.type);
return EINVAL;
break;
}

ArrowErrorSet(error, "COPY Writer not implemented for type %d", schema_view.type);
return EINVAL;
}

class PostgresCopyStreamWriter {
Expand Down Expand Up @@ -1450,13 +1484,9 @@ class PostgresCopyStreamWriter {
}

for (int64_t i = 0; i < schema_->n_children; i++) {
struct ArrowSchemaView schema_view;
if (ArrowSchemaViewInit(&schema_view, schema_->children[i], error) !=
NANOARROW_OK) {
return ADBC_STATUS_INTERNAL;
}
PostgresCopyFieldWriter* child_writer = nullptr;
NANOARROW_RETURN_NOT_OK(MakeCopyFieldWriter(schema_view, &child_writer, error));
NANOARROW_RETURN_NOT_OK(
MakeCopyFieldWriter(schema_->children[i], &child_writer, error));
root_writer_.AppendChild(std::unique_ptr<PostgresCopyFieldWriter>(child_writer));
}

Expand Down
3 changes: 3 additions & 0 deletions c/driver/postgresql/postgres_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,9 @@ inline ArrowErrorCode PostgresType::FromSchema(const PostgresTypeResolver& resol
PostgresType::FromSchema(resolver, schema->children[0], &child, error));
return resolver.FindArray(child.oid(), out, error);
}
case NANOARROW_TYPE_DICTIONARY:
// Dictionary arrays resolve to their dictionary value type when binding or ingesting
return PostgresType::FromSchema(resolver, schema->dictionary, out, error);

default:
ArrowErrorSet(error, "Can't map Arrow type '%s' to Postgres type",
Expand Down
9 changes: 9 additions & 0 deletions c/driver/postgresql/postgres_type_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,15 @@ TEST(PostgresTypeTest, PostgresTypeFromSchema) {
EXPECT_EQ(type.child(0).type_id(), PostgresTypeId::kBool);
schema.reset();

ASSERT_EQ(ArrowSchemaInitFromType(schema.get(), NANOARROW_TYPE_INT64), NANOARROW_OK);
ASSERT_EQ(ArrowSchemaAllocateDictionary(schema.get()), NANOARROW_OK);
ASSERT_EQ(ArrowSchemaInitFromType(schema->dictionary, NANOARROW_TYPE_STRING),
NANOARROW_OK);
EXPECT_EQ(PostgresType::FromSchema(resolver, schema.get(), &type, nullptr),
NANOARROW_OK);
EXPECT_EQ(type.type_id(), PostgresTypeId::kText);
schema.reset();

ArrowError error;
ASSERT_EQ(ArrowSchemaInitFromType(schema.get(), NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO),
NANOARROW_OK);
Expand Down
1 change: 0 additions & 1 deletion c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,6 @@ class PostgresStatementTest : public ::testing::Test,
void TestSqlIngestUInt16() { GTEST_SKIP() << "Not implemented"; }
void TestSqlIngestUInt32() { GTEST_SKIP() << "Not implemented"; }
void TestSqlIngestUInt64() { GTEST_SKIP() << "Not implemented"; }
void TestSqlIngestStringDictionary() { GTEST_SKIP() << "Not implemented"; }

void TestSqlPrepareErrorParamCountMismatch() { GTEST_SKIP() << "Not yet implemented"; }
void TestSqlPrepareGetParameterSchema() { GTEST_SKIP() << "Not yet implemented"; }
Expand Down
55 changes: 53 additions & 2 deletions c/driver/postgresql/statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,33 @@ struct BindStream {
type_id = PostgresTypeId::kInterval;
param_lengths[i] = 16;
break;
case ArrowType::NANOARROW_TYPE_DICTIONARY: {
struct ArrowSchemaView value_view;
CHECK_NA(INTERNAL,
ArrowSchemaViewInit(&value_view, bind_schema->children[i]->dictionary,
nullptr),
error);
switch (value_view.type) {
case NANOARROW_TYPE_BINARY:
case NANOARROW_TYPE_LARGE_BINARY:
type_id = PostgresTypeId::kBytea;
param_lengths[i] = 0;
break;
case NANOARROW_TYPE_STRING:
case NANOARROW_TYPE_LARGE_STRING:
type_id = PostgresTypeId::kText;
param_lengths[i] = 0;
break;
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('",
bind_schema->children[i]->name,
"') has unsupported dictionary value parameter type ",
ArrowTypeString(value_view.type));
return ADBC_STATUS_NOT_IMPLEMENTED;
}
break;
}
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", bind_schema->children[i]->name,
Expand Down Expand Up @@ -567,8 +594,8 @@ struct BindStream {
}

ArrowBuffer buffer = writer.WriteBuffer();
if (PQputCopyData(conn, reinterpret_cast<char*>(buffer.data),
buffer.size_bytes) <= 0) {
if (PQputCopyData(conn, reinterpret_cast<char*>(buffer.data), buffer.size_bytes) <=
0) {
SetError(error, "Error writing tuple field data: %s", PQerrorMessage(conn));
return ADBC_STATUS_IO;
}
Expand Down Expand Up @@ -1029,6 +1056,30 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
create += " INTERVAL";
break;
case ArrowType::NANOARROW_TYPE_DICTIONARY: {
struct ArrowSchemaView value_view;
CHECK_NA(INTERNAL,
ArrowSchemaViewInit(&value_view, source_schema.children[i]->dictionary,
nullptr),
error);
switch (value_view.type) {
case NANOARROW_TYPE_BINARY:
case NANOARROW_TYPE_LARGE_BINARY:
create += " BYTEA";
break;
case NANOARROW_TYPE_STRING:
case NANOARROW_TYPE_LARGE_STRING:
create += " TEXT";
break;
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", source_schema.children[i]->name,
"') has unsupported dictionary value type for ingestion ",
ArrowTypeString(value_view.type));
return ADBC_STATUS_NOT_IMPLEMENTED;
}
break;
}
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", source_schema.children[i]->name,
Expand Down

0 comments on commit 5a76d63

Please sign in to comment.