Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(c/driver/postgresql): Factory func for CopyWriter construction #1998

Merged
merged 1 commit into from
Jul 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 95 additions & 90 deletions c/driver/postgresql/copy/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,20 @@ class PostgresCopyFieldWriter {
public:
virtual ~PostgresCopyFieldWriter() {}

void Init(struct ArrowArrayView* array_view) { array_view_ = array_view; };
template <class T, typename... Params>
static std::unique_ptr<T> Create(struct ArrowArrayView* array_view, Params&&... args) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this even needs to be a static method? It could just be a free function and you can then just invoke

*out = MakeWriter<T>(array_view);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think would have to make the Init method public to do that - do you think that is better?

../driver/postgresql/copy/writer.h:564:15: error: ‘virtual void adbcpq::PostgresCopyFieldWriter::Init(ArrowArrayView*)’ is protected within this context
  564 |   writer->Init(array_view);
      |   ~~~~~~~~~~~~^~~~~~~~~~~~
../driver/postgresql/copy/writer.h:100:16: note: declared protected here
  100 |   virtual void Init(struct ArrowArrayView* array_view) { array_view_ = array_view; };

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah...

This is fine then. The T::Create<T> bugged me a little bit but it's not actually a problem.

auto writer = std::make_unique<T>(std::forward<Params>(args)...);
writer->Init(array_view);
return writer;
}

virtual ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) {
return ENOTSUP;
}

protected:
virtual void Init(struct ArrowArrayView* array_view) { array_view_ = array_view; };

struct ArrowArrayView* array_view_;
std::vector<std::unique_ptr<PostgresCopyFieldWriter>> children_;
};
Expand Down Expand Up @@ -439,11 +446,9 @@ class PostgresCopyBinaryDictFieldWriter : public PostgresCopyFieldWriter {
template <bool IsFixedSize>
class PostgresCopyListFieldWriter : public PostgresCopyFieldWriter {
public:
explicit PostgresCopyListFieldWriter(uint32_t child_oid) : child_oid_{child_oid} {}

void InitChild(std::unique_ptr<PostgresCopyFieldWriter> child) {
child_ = std::move(child);
}
explicit PostgresCopyListFieldWriter(uint32_t child_oid,
std::unique_ptr<PostgresCopyFieldWriter> child)
: child_oid_{child_oid}, child_{std::move(child)} {}

ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
if (index >= array_view_->length) {
Expand Down Expand Up @@ -499,8 +504,8 @@ class PostgresCopyListFieldWriter : public PostgresCopyFieldWriter {
}

private:
std::unique_ptr<PostgresCopyFieldWriter> child_;
const uint32_t child_oid_;
std::unique_ptr<PostgresCopyFieldWriter> child_;
};

template <enum ArrowTimeUnit TU>
Expand Down Expand Up @@ -569,123 +574,127 @@ static inline ArrowErrorCode MakeCopyFieldWriter(

switch (schema_view.type) {
case NANOARROW_TYPE_BOOL:
*out = std::make_unique<PostgresCopyBooleanFieldWriter>();
out->get()->Init(array_view);
using T = PostgresCopyBooleanFieldWriter;
*out = T::Create<T>(array_view);
return NANOARROW_OK;
case NANOARROW_TYPE_INT8:
case NANOARROW_TYPE_INT16:
case NANOARROW_TYPE_UINT8:
*out = std::make_unique<PostgresCopyNetworkEndianFieldWriter<int16_t>>();
out->get()->Init(array_view);
case NANOARROW_TYPE_UINT8: {
using T = PostgresCopyNetworkEndianFieldWriter<int16_t>;
*out = T::Create<T>(array_view);
return NANOARROW_OK;
}
case NANOARROW_TYPE_INT32:
case NANOARROW_TYPE_UINT16:
*out = std::make_unique<PostgresCopyNetworkEndianFieldWriter<int32_t>>();
out->get()->Init(array_view);
case NANOARROW_TYPE_UINT16: {
using T = PostgresCopyNetworkEndianFieldWriter<int32_t>;
*out = T::Create<T>(array_view);
return NANOARROW_OK;
}
case NANOARROW_TYPE_INT64:
case NANOARROW_TYPE_UINT32:
*out = std::make_unique<PostgresCopyNetworkEndianFieldWriter<int64_t>>();
out->get()->Init(array_view);
case NANOARROW_TYPE_UINT32: {
using T = PostgresCopyNetworkEndianFieldWriter<int64_t>;
*out = T::Create<T>(array_view);
return NANOARROW_OK;
}
case NANOARROW_TYPE_DATE32: {
constexpr int32_t kPostgresDateEpoch = 10957;
*out = std::make_unique<
PostgresCopyNetworkEndianFieldWriter<int32_t, kPostgresDateEpoch>>();
out->get()->Init(array_view);
using T = PostgresCopyNetworkEndianFieldWriter<int32_t, kPostgresDateEpoch>;
*out = T::Create<T>(array_view);
return NANOARROW_OK;
}
case NANOARROW_TYPE_TIME64: {
switch (schema_view.time_unit) {
case NANOARROW_TIME_UNIT_MICRO:
*out = std::make_unique<PostgresCopyNetworkEndianFieldWriter<int64_t>>();
out->get()->Init(array_view);
using T = PostgresCopyNetworkEndianFieldWriter<int64_t>;
*out = T::Create<T>(array_view);
return NANOARROW_OK;
default:
return ADBC_STATUS_NOT_IMPLEMENTED;
}
}
case NANOARROW_TYPE_FLOAT:
*out = std::make_unique<PostgresCopyFloatFieldWriter>();
out->get()->Init(array_view);
case NANOARROW_TYPE_FLOAT: {
using T = PostgresCopyFloatFieldWriter;
*out = T::Create<T>(array_view);
return NANOARROW_OK;
case NANOARROW_TYPE_DOUBLE:
*out = std::make_unique<PostgresCopyDoubleFieldWriter>();
out->get()->Init(array_view);
}
case NANOARROW_TYPE_DOUBLE: {
using T = PostgresCopyDoubleFieldWriter;
*out = T::Create<T>(array_view);
return NANOARROW_OK;
}
case NANOARROW_TYPE_DECIMAL128: {
using T = PostgresCopyNumericFieldWriter<NANOARROW_TYPE_DECIMAL128>;
const auto precision = schema_view.decimal_precision;
const auto scale = schema_view.decimal_scale;
*out = std::make_unique<PostgresCopyNumericFieldWriter<NANOARROW_TYPE_DECIMAL128>>(
precision, scale);
out->get()->Init(array_view);
*out = T::Create<T>(array_view, precision, scale);
return NANOARROW_OK;
}
case NANOARROW_TYPE_DECIMAL256: {
using T = PostgresCopyNumericFieldWriter<NANOARROW_TYPE_DECIMAL256>;
const auto precision = schema_view.decimal_precision;
const auto scale = schema_view.decimal_scale;
*out = std::make_unique<PostgresCopyNumericFieldWriter<NANOARROW_TYPE_DECIMAL256>>(
precision, scale);
out->get()->Init(array_view);
*out = T::Create<T>(array_view, precision, scale);
return NANOARROW_OK;
}
case NANOARROW_TYPE_BINARY:
case NANOARROW_TYPE_STRING:
case NANOARROW_TYPE_LARGE_STRING:
*out = std::make_unique<PostgresCopyBinaryFieldWriter>();
out->get()->Init(array_view);
case NANOARROW_TYPE_LARGE_STRING: {
using T = PostgresCopyBinaryFieldWriter;
*out = T::Create<T>(array_view);
return NANOARROW_OK;
}
case NANOARROW_TYPE_TIMESTAMP: {
switch (schema_view.time_unit) {
case NANOARROW_TIME_UNIT_NANO:
*out = std::make_unique<
PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_NANO>>();
out->get()->Init(array_view);
case NANOARROW_TIME_UNIT_NANO: {
using T = PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_NANO>;
*out = T::Create<T>(array_view);
break;
case NANOARROW_TIME_UNIT_MILLI:
*out = std::make_unique<
PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MILLI>>();
out->get()->Init(array_view);
}
case NANOARROW_TIME_UNIT_MILLI: {
using T = PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MILLI>;
*out = T::Create<T>(array_view);
break;
case NANOARROW_TIME_UNIT_MICRO:
*out = std::make_unique<
PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MICRO>>();
out->get()->Init(array_view);
}
case NANOARROW_TIME_UNIT_MICRO: {
using T = PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MICRO>;
*out = T::Create<T>(array_view);
break;
case NANOARROW_TIME_UNIT_SECOND:
*out = std::make_unique<
PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_SECOND>>();
out->get()->Init(array_view);
}
case NANOARROW_TIME_UNIT_SECOND: {
using T = PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_SECOND>;
*out = T::Create<T>(array_view);
break;
}
}
return NANOARROW_OK;
}
case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
*out = std::make_unique<PostgresCopyIntervalFieldWriter>();
out->get()->Init(array_view);
case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: {
using T = PostgresCopyIntervalFieldWriter;
*out = T::Create<T>(array_view);
return NANOARROW_OK;
}
case NANOARROW_TYPE_DURATION: {
switch (schema_view.time_unit) {
case NANOARROW_TIME_UNIT_SECOND:
*out = std::make_unique<
PostgresCopyDurationFieldWriter<NANOARROW_TIME_UNIT_SECOND>>();
out->get()->Init(array_view);
case NANOARROW_TIME_UNIT_SECOND: {
using T = PostgresCopyDurationFieldWriter<NANOARROW_TIME_UNIT_SECOND>;
*out = T::Create<T>(array_view);
break;
case NANOARROW_TIME_UNIT_MILLI:
*out = std::make_unique<
PostgresCopyDurationFieldWriter<NANOARROW_TIME_UNIT_MILLI>>();
out->get()->Init(array_view);
}
case NANOARROW_TIME_UNIT_MILLI: {
using T = PostgresCopyDurationFieldWriter<NANOARROW_TIME_UNIT_MILLI>;
*out = T::Create<T>(array_view);
break;
case NANOARROW_TIME_UNIT_MICRO:
*out = std::make_unique<
PostgresCopyDurationFieldWriter<NANOARROW_TIME_UNIT_MICRO>>();
out->get()->Init(array_view);
}
case NANOARROW_TIME_UNIT_MICRO: {
using T = PostgresCopyDurationFieldWriter<NANOARROW_TIME_UNIT_MICRO>;
*out = T::Create<T>(array_view);
break;
case NANOARROW_TIME_UNIT_NANO:
*out = std::make_unique<
PostgresCopyDurationFieldWriter<NANOARROW_TIME_UNIT_NANO>>();
out->get()->Init(array_view);
}
case NANOARROW_TIME_UNIT_NANO: {
using T = PostgresCopyDurationFieldWriter<NANOARROW_TIME_UNIT_NANO>;
*out = T::Create<T>(array_view);
break;
}
}
return NANOARROW_OK;
}
Expand All @@ -697,10 +706,11 @@ static inline ArrowErrorCode MakeCopyFieldWriter(
case NANOARROW_TYPE_BINARY:
case NANOARROW_TYPE_STRING:
case NANOARROW_TYPE_LARGE_BINARY:
case NANOARROW_TYPE_LARGE_STRING:
*out = std::make_unique<PostgresCopyBinaryDictFieldWriter>();
out->get()->Init(array_view);
case NANOARROW_TYPE_LARGE_STRING: {
using T = PostgresCopyBinaryDictFieldWriter;
*out = T::Create<T>(array_view);
return NANOARROW_OK;
}
default:
break;
}
Expand All @@ -724,17 +734,11 @@ static inline ArrowErrorCode MakeCopyFieldWriter(
&child_writer, error));

if (schema_view.type == NANOARROW_TYPE_FIXED_SIZE_LIST) {
auto list_writer =
std::make_unique<PostgresCopyListFieldWriter<true>>(child_type.oid());
list_writer->Init(array_view);
list_writer->InitChild(std::move(child_writer));
*out = std::move(list_writer);
using T = PostgresCopyListFieldWriter<true>;
*out = T::Create<T>(array_view, child_type.oid(), std::move(child_writer));
} else {
auto list_writer =
std::make_unique<PostgresCopyListFieldWriter<false>>(child_type.oid());
list_writer->Init(array_view);
list_writer->InitChild(std::move(child_writer));
*out = std::move(list_writer);
using T = PostgresCopyListFieldWriter<false>;
*out = T::Create<T>(array_view, child_type.oid(), std::move(child_writer));
}
return NANOARROW_OK;
}
Expand All @@ -752,7 +756,8 @@ class PostgresCopyStreamWriter {
schema_ = schema;
NANOARROW_RETURN_NOT_OK(
ArrowArrayViewInitFromSchema(&array_view_.value, schema, nullptr));
root_writer_.Init(&array_view_.value);
root_writer_ = PostgresCopyFieldTupleWriter::Create<PostgresCopyFieldTupleWriter>(
&array_view_.value);
ArrowBufferInit(&buffer_.value);
return NANOARROW_OK;
}
Expand All @@ -778,7 +783,7 @@ class PostgresCopyStreamWriter {
}

ArrowErrorCode WriteRecord(ArrowError* error) {
NANOARROW_RETURN_NOT_OK(root_writer_.Write(&buffer_.value, records_written_, error));
NANOARROW_RETURN_NOT_OK(root_writer_->Write(&buffer_.value, records_written_, error));
records_written_++;
return NANOARROW_OK;
}
Expand All @@ -794,7 +799,7 @@ class PostgresCopyStreamWriter {
NANOARROW_RETURN_NOT_OK(MakeCopyFieldWriter(schema_->children[i],
array_view_->children[i], type_resolver,
&child_writer, error));
root_writer_.AppendChild(std::move(child_writer));
root_writer_->AppendChild(std::move(child_writer));
}

return NANOARROW_OK;
Expand All @@ -808,7 +813,7 @@ class PostgresCopyStreamWriter {
}

private:
PostgresCopyFieldTupleWriter root_writer_;
std::unique_ptr<PostgresCopyFieldTupleWriter> root_writer_;
struct ArrowSchema* schema_;
Handle<struct ArrowArrayView> array_view_;
Handle<struct ArrowBuffer> buffer_;
Expand Down
Loading