diff --git a/c/driver/postgresql/copy/writer.h b/c/driver/postgresql/copy/writer.h index 3d74b81681..b97628f349 100644 --- a/c/driver/postgresql/copy/writer.h +++ b/c/driver/postgresql/copy/writer.h @@ -92,13 +92,20 @@ class PostgresCopyFieldWriter { public: virtual ~PostgresCopyFieldWriter() {} - void Init(struct ArrowArrayView* array_view) { array_view_ = array_view; }; + template + static std::unique_ptr Create(struct ArrowArrayView* array_view, Params&&... args) { + auto writer = std::make_unique(std::forward(args)...); + writer->Init(array_view); + return writer; + } virtual ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) { return ENOTSUP; } protected: + virtual void Init(struct ArrowArrayView* array_view) { array_view_ = array_view; }; + struct ArrowArrayView* array_view_; std::vector> children_; }; @@ -439,11 +446,9 @@ class PostgresCopyBinaryDictFieldWriter : public PostgresCopyFieldWriter { template class PostgresCopyListFieldWriter : public PostgresCopyFieldWriter { public: - explicit PostgresCopyListFieldWriter(uint32_t child_oid) : child_oid_{child_oid} {} - - void InitChild(std::unique_ptr child) { - child_ = std::move(child); - } + explicit PostgresCopyListFieldWriter(uint32_t child_oid, + std::unique_ptr child) + : child_oid_{child_oid}, child_{std::move(child)} {} ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override { if (index >= array_view_->length) { @@ -499,8 +504,8 @@ class PostgresCopyListFieldWriter : public PostgresCopyFieldWriter { } private: - std::unique_ptr child_; const uint32_t child_oid_; + std::unique_ptr child_; }; template @@ -569,123 +574,127 @@ static inline ArrowErrorCode MakeCopyFieldWriter( switch (schema_view.type) { case NANOARROW_TYPE_BOOL: - *out = std::make_unique(); - out->get()->Init(array_view); + using T = PostgresCopyBooleanFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; case NANOARROW_TYPE_INT8: case NANOARROW_TYPE_INT16: - case NANOARROW_TYPE_UINT8: - *out = std::make_unique>(); - out->get()->Init(array_view); + case NANOARROW_TYPE_UINT8: { + using T = PostgresCopyNetworkEndianFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; + } case NANOARROW_TYPE_INT32: - case NANOARROW_TYPE_UINT16: - *out = std::make_unique>(); - out->get()->Init(array_view); + case NANOARROW_TYPE_UINT16: { + using T = PostgresCopyNetworkEndianFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; + } case NANOARROW_TYPE_INT64: - case NANOARROW_TYPE_UINT32: - *out = std::make_unique>(); - out->get()->Init(array_view); + case NANOARROW_TYPE_UINT32: { + using T = PostgresCopyNetworkEndianFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; + } case NANOARROW_TYPE_DATE32: { constexpr int32_t kPostgresDateEpoch = 10957; - *out = std::make_unique< - PostgresCopyNetworkEndianFieldWriter>(); - out->get()->Init(array_view); + using T = PostgresCopyNetworkEndianFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; } case NANOARROW_TYPE_TIME64: { switch (schema_view.time_unit) { case NANOARROW_TIME_UNIT_MICRO: - *out = std::make_unique>(); - out->get()->Init(array_view); + using T = PostgresCopyNetworkEndianFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; default: return ADBC_STATUS_NOT_IMPLEMENTED; } } - case NANOARROW_TYPE_FLOAT: - *out = std::make_unique(); - out->get()->Init(array_view); + case NANOARROW_TYPE_FLOAT: { + using T = PostgresCopyFloatFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; - case NANOARROW_TYPE_DOUBLE: - *out = std::make_unique(); - out->get()->Init(array_view); + } + case NANOARROW_TYPE_DOUBLE: { + using T = PostgresCopyDoubleFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; + } case NANOARROW_TYPE_DECIMAL128: { + using T = PostgresCopyNumericFieldWriter; const auto precision = schema_view.decimal_precision; const auto scale = schema_view.decimal_scale; - *out = std::make_unique>( - precision, scale); - out->get()->Init(array_view); + *out = T::Create(array_view, precision, scale); return NANOARROW_OK; } case NANOARROW_TYPE_DECIMAL256: { + using T = PostgresCopyNumericFieldWriter; const auto precision = schema_view.decimal_precision; const auto scale = schema_view.decimal_scale; - *out = std::make_unique>( - precision, scale); - out->get()->Init(array_view); + *out = T::Create(array_view, precision, scale); return NANOARROW_OK; } case NANOARROW_TYPE_BINARY: case NANOARROW_TYPE_STRING: - case NANOARROW_TYPE_LARGE_STRING: - *out = std::make_unique(); - out->get()->Init(array_view); + case NANOARROW_TYPE_LARGE_STRING: { + using T = PostgresCopyBinaryFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; + } case NANOARROW_TYPE_TIMESTAMP: { switch (schema_view.time_unit) { - case NANOARROW_TIME_UNIT_NANO: - *out = std::make_unique< - PostgresCopyTimestampFieldWriter>(); - out->get()->Init(array_view); + case NANOARROW_TIME_UNIT_NANO: { + using T = PostgresCopyTimestampFieldWriter; + *out = T::Create(array_view); break; - case NANOARROW_TIME_UNIT_MILLI: - *out = std::make_unique< - PostgresCopyTimestampFieldWriter>(); - out->get()->Init(array_view); + } + case NANOARROW_TIME_UNIT_MILLI: { + using T = PostgresCopyTimestampFieldWriter; + *out = T::Create(array_view); break; - case NANOARROW_TIME_UNIT_MICRO: - *out = std::make_unique< - PostgresCopyTimestampFieldWriter>(); - out->get()->Init(array_view); + } + case NANOARROW_TIME_UNIT_MICRO: { + using T = PostgresCopyTimestampFieldWriter; + *out = T::Create(array_view); break; - case NANOARROW_TIME_UNIT_SECOND: - *out = std::make_unique< - PostgresCopyTimestampFieldWriter>(); - out->get()->Init(array_view); + } + case NANOARROW_TIME_UNIT_SECOND: { + using T = PostgresCopyTimestampFieldWriter; + *out = T::Create(array_view); break; + } } return NANOARROW_OK; } - case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: - *out = std::make_unique(); - out->get()->Init(array_view); + case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: { + using T = PostgresCopyIntervalFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; + } case NANOARROW_TYPE_DURATION: { switch (schema_view.time_unit) { - case NANOARROW_TIME_UNIT_SECOND: - *out = std::make_unique< - PostgresCopyDurationFieldWriter>(); - out->get()->Init(array_view); + case NANOARROW_TIME_UNIT_SECOND: { + using T = PostgresCopyDurationFieldWriter; + *out = T::Create(array_view); break; - case NANOARROW_TIME_UNIT_MILLI: - *out = std::make_unique< - PostgresCopyDurationFieldWriter>(); - out->get()->Init(array_view); + } + case NANOARROW_TIME_UNIT_MILLI: { + using T = PostgresCopyDurationFieldWriter; + *out = T::Create(array_view); break; - case NANOARROW_TIME_UNIT_MICRO: - *out = std::make_unique< - PostgresCopyDurationFieldWriter>(); - out->get()->Init(array_view); + } + case NANOARROW_TIME_UNIT_MICRO: { + using T = PostgresCopyDurationFieldWriter; + *out = T::Create(array_view); break; - case NANOARROW_TIME_UNIT_NANO: - *out = std::make_unique< - PostgresCopyDurationFieldWriter>(); - out->get()->Init(array_view); + } + case NANOARROW_TIME_UNIT_NANO: { + using T = PostgresCopyDurationFieldWriter; + *out = T::Create(array_view); break; + } } return NANOARROW_OK; } @@ -697,10 +706,11 @@ static inline ArrowErrorCode MakeCopyFieldWriter( case NANOARROW_TYPE_BINARY: case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_LARGE_BINARY: - case NANOARROW_TYPE_LARGE_STRING: - *out = std::make_unique(); - out->get()->Init(array_view); + case NANOARROW_TYPE_LARGE_STRING: { + using T = PostgresCopyBinaryDictFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; + } default: break; } @@ -724,17 +734,11 @@ static inline ArrowErrorCode MakeCopyFieldWriter( &child_writer, error)); if (schema_view.type == NANOARROW_TYPE_FIXED_SIZE_LIST) { - auto list_writer = - std::make_unique>(child_type.oid()); - list_writer->Init(array_view); - list_writer->InitChild(std::move(child_writer)); - *out = std::move(list_writer); + using T = PostgresCopyListFieldWriter; + *out = T::Create(array_view, child_type.oid(), std::move(child_writer)); } else { - auto list_writer = - std::make_unique>(child_type.oid()); - list_writer->Init(array_view); - list_writer->InitChild(std::move(child_writer)); - *out = std::move(list_writer); + using T = PostgresCopyListFieldWriter; + *out = T::Create(array_view, child_type.oid(), std::move(child_writer)); } return NANOARROW_OK; } @@ -752,7 +756,8 @@ class PostgresCopyStreamWriter { schema_ = schema; NANOARROW_RETURN_NOT_OK( ArrowArrayViewInitFromSchema(&array_view_.value, schema, nullptr)); - root_writer_.Init(&array_view_.value); + root_writer_ = PostgresCopyFieldTupleWriter::Create( + &array_view_.value); ArrowBufferInit(&buffer_.value); return NANOARROW_OK; } @@ -778,7 +783,7 @@ class PostgresCopyStreamWriter { } ArrowErrorCode WriteRecord(ArrowError* error) { - NANOARROW_RETURN_NOT_OK(root_writer_.Write(&buffer_.value, records_written_, error)); + NANOARROW_RETURN_NOT_OK(root_writer_->Write(&buffer_.value, records_written_, error)); records_written_++; return NANOARROW_OK; } @@ -794,7 +799,7 @@ class PostgresCopyStreamWriter { NANOARROW_RETURN_NOT_OK(MakeCopyFieldWriter(schema_->children[i], array_view_->children[i], type_resolver, &child_writer, error)); - root_writer_.AppendChild(std::move(child_writer)); + root_writer_->AppendChild(std::move(child_writer)); } return NANOARROW_OK; @@ -808,7 +813,7 @@ class PostgresCopyStreamWriter { } private: - PostgresCopyFieldTupleWriter root_writer_; + std::unique_ptr root_writer_; struct ArrowSchema* schema_; Handle array_view_; Handle buffer_;