diff --git a/.env b/.env index be35921f94c3a..1358aafe824a6 100644 --- a/.env +++ b/.env @@ -61,7 +61,7 @@ GCC_VERSION="" GO=1.21.8 STATICCHECK=v0.4.7 HDFS=3.2.1 -JDK=8 +JDK=11 KARTOTHEK=latest # LLVM 12 and GCC 11 reports -Wmismatched-new-delete. LLVM=14 diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 8eb2682dc077d..d4211c2c81cb5 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -58,7 +58,7 @@ jobs: strategy: fail-fast: false matrix: - jdk: [8, 11, 17, 21, 22] + jdk: [11, 17, 21, 22] maven: [3.9.6] image: [java] env: diff --git a/ci/docker/conda-integration.dockerfile b/ci/docker/conda-integration.dockerfile index 78d2503b23df7..c602490d6b729 100644 --- a/ci/docker/conda-integration.dockerfile +++ b/ci/docker/conda-integration.dockerfile @@ -23,7 +23,7 @@ ARG arch=amd64 ARG maven=3.8.7 ARG node=16 ARG yarn=1.22 -ARG jdk=8 +ARG jdk=11 ARG go=1.21.8 # Install Archery and integration dependencies diff --git a/ci/docker/conda-python-hdfs.dockerfile b/ci/docker/conda-python-hdfs.dockerfile index fa4fa0d1fb772..4e5e1a402e282 100644 --- a/ci/docker/conda-python-hdfs.dockerfile +++ b/ci/docker/conda-python-hdfs.dockerfile @@ -20,7 +20,7 @@ ARG arch=amd64 ARG python=3.8 FROM ${repo}:${arch}-conda-python-${python} -ARG jdk=8 +ARG jdk=11 ARG maven=3.8.7 RUN mamba install -q -y \ maven=${maven} \ diff --git a/ci/docker/conda-python-spark.dockerfile b/ci/docker/conda-python-spark.dockerfile index 866f6f37f8bd9..d95fe58b529f6 100644 --- a/ci/docker/conda-python-spark.dockerfile +++ b/ci/docker/conda-python-spark.dockerfile @@ -20,7 +20,7 @@ ARG arch=amd64 ARG python=3.8 FROM ${repo}:${arch}-conda-python-${python} -ARG jdk=8 +ARG jdk=11 ARG maven=3.8.7 ARG numpy=latest diff --git a/ci/docker/java-jni-manylinux-201x.dockerfile b/ci/docker/java-jni-manylinux-201x.dockerfile index 8b73c73c1d240..479f4aa598b18 100644 --- a/ci/docker/java-jni-manylinux-201x.dockerfile +++ b/ci/docker/java-jni-manylinux-201x.dockerfile @@ -33,7 +33,7 @@ RUN vcpkg install \ --x-feature=s3 # Install Java -ARG java=1.8.0 +ARG java=11 ARG maven=3.9.3 RUN yum install -y java-$java-openjdk-devel && \ yum clean all && \ diff --git a/ci/docker/linux-apt-docs.dockerfile b/ci/docker/linux-apt-docs.dockerfile index 1c916840e071b..0804f3543c283 100644 --- a/ci/docker/linux-apt-docs.dockerfile +++ b/ci/docker/linux-apt-docs.dockerfile @@ -19,7 +19,7 @@ ARG base FROM ${base} ARG r=4.4 -ARG jdk=8 +ARG jdk=11 ENV PUPPETEER_EXECUTABLE_PATH=/usr/bin/chromium diff --git a/ci/docker/ubuntu-24.04-cpp-minimal.dockerfile b/ci/docker/ubuntu-24.04-cpp-minimal.dockerfile new file mode 100644 index 0000000000000..a995ab2a8bc2d --- /dev/null +++ b/ci/docker/ubuntu-24.04-cpp-minimal.dockerfile @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +ARG base=amd64/ubuntu:24.04 +FROM ${base} + +SHELL ["/bin/bash", "-o", "pipefail", "-c"] + +RUN echo "debconf debconf/frontend select Noninteractive" | \ + debconf-set-selections + +RUN apt-get update -y -q && \ + apt-get install -y -q \ + build-essential \ + ccache \ + cmake \ + curl \ + git \ + libssl-dev \ + libcurl4-openssl-dev \ + python3-pip \ + tzdata \ + tzdata-legacy \ + wget && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists* + +# Installs LLVM toolchain, for Gandiva and testing other compilers +# +# Note that this is installed before the base packages to improve iteration +# while debugging package list with docker build. +ARG llvm +RUN latest_system_llvm=14 && \ + if [ ${llvm} -gt ${latest_system_llvm} ]; then \ + apt-get update -y -q && \ + apt-get install -y -q --no-install-recommends \ + apt-transport-https \ + ca-certificates \ + gnupg \ + lsb-release \ + wget && \ + wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - && \ + code_name=$(lsb_release --codename --short) && \ + if [ ${llvm} -gt 10 ]; then \ + echo "deb https://apt.llvm.org/${code_name}/ llvm-toolchain-${code_name}-${llvm} main" > \ + /etc/apt/sources.list.d/llvm.list; \ + fi; \ + fi && \ + apt-get update -y -q && \ + apt-get install -y -q --no-install-recommends \ + clang-${llvm} \ + llvm-${llvm}-dev && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists* + +COPY ci/scripts/install_minio.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_minio.sh latest /usr/local + +COPY ci/scripts/install_gcs_testbench.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_gcs_testbench.sh default + +COPY ci/scripts/install_sccache.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_sccache.sh unknown-linux-musl /usr/local/bin + +ENV ARROW_ACERO=ON \ + ARROW_AZURE=OFF \ + ARROW_BUILD_TESTS=ON \ + ARROW_DATASET=ON \ + ARROW_FLIGHT=ON \ + ARROW_GANDIVA=ON \ + ARROW_GCS=ON \ + ARROW_HDFS=ON \ + ARROW_HOME=/usr/local \ + ARROW_INSTALL_NAME_RPATH=OFF \ + ARROW_ORC=ON \ + ARROW_PARQUET=ON \ + ARROW_S3=ON \ + ARROW_USE_CCACHE=ON \ + ARROW_WITH_BROTLI=ON \ + ARROW_WITH_BZ2=ON \ + ARROW_WITH_LZ4=ON \ + ARROW_WITH_OPENTELEMETRY=OFF \ + ARROW_WITH_SNAPPY=ON \ + ARROW_WITH_ZLIB=ON \ + ARROW_WITH_ZSTD=ON \ + CMAKE_GENERATOR="Unix Makefiles" \ + PARQUET_BUILD_EXAMPLES=ON \ + PARQUET_BUILD_EXECUTABLES=ON \ + PATH=/usr/lib/ccache/:$PATH \ + PYTHON=python3 diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 5b89a831ff7fe..1c8c40d6f9c52 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -2882,6 +2882,10 @@ macro(build_absl) set(ABSL_INCLUDE_DIR "${ABSL_PREFIX}/include") set(ABSL_CMAKE_ARGS "${EP_COMMON_CMAKE_ARGS}" -DABSL_RUN_TESTS=OFF "-DCMAKE_INSTALL_PREFIX=${ABSL_PREFIX}") + if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) + set(ABSL_CXX_FLAGS "${EP_CXX_FLAGS} -include stdint.h") + list(APPEND ABSL_CMAKE_ARGS "-DCMAKE_CXX_FLAGS=${ABSL_CXX_FLAGS}") + endif() set(ABSL_BUILD_BYPRODUCTS) set(ABSL_LIBRARIES) diff --git a/cpp/src/arrow/array/concatenate.cc b/cpp/src/arrow/array/concatenate.cc index 87e55246c78fe..b4638dd6593d8 100644 --- a/cpp/src/arrow/array/concatenate.cc +++ b/cpp/src/arrow/array/concatenate.cc @@ -75,6 +75,31 @@ struct Bitmap { bool AllSet() const { return data == nullptr; } }; +enum class OffsetBufferOpOutcome { + kOk, + kOffsetOverflow, +}; + +Status OffsetOverflowStatus() { + return Status::Invalid("offset overflow while concatenating arrays"); +} + 
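The enum above threads overflow information out of the low-level offset writers as a value instead of an error string, so each typed caller can attach its own suggested cast before surfacing a Status. A minimal self-contained sketch of the same outcome-enum pattern, with illustrative names that are not part of the Arrow API:

#include <cstdint>
#include <limits>
#include <string>
#include <vector>

enum class Outcome { kOk, kOverflow };

// Low-level writer: reports overflow as a value the caller can branch on.
Outcome ShiftOffsets(const std::vector<int32_t>& src, int32_t shift,
                     std::vector<int32_t>* dst) {
  for (int32_t offset : src) {
    if (offset > std::numeric_limits<int32_t>::max() - shift) {
      return Outcome::kOverflow;
    }
    dst->push_back(offset + shift);
  }
  return Outcome::kOk;
}

// Caller: only here is the outcome turned into a user-facing message, which
// is where a type-specific hint (e.g. "use large_utf8") can be attached.
std::string ConcatenateOffsetChunks(const std::vector<std::vector<int32_t>>& chunks,
                                    std::vector<int32_t>* dst) {
  int32_t total = 0;
  for (const auto& chunk : chunks) {
    if (ShiftOffsets(chunk, total, dst) == Outcome::kOverflow) {
      return "offset overflow; consider casting to a type with larger offsets";
    }
    if (!dst->empty()) total = dst->back();
  }
  return "ok";
}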
+#define RETURN_IF_NOT_OK_OUTCOME(outcome)        \
+  switch (outcome) {                             \
+    case OffsetBufferOpOutcome::kOk:             \
+      break;                                     \
+    case OffsetBufferOpOutcome::kOffsetOverflow: \
+      return OffsetOverflowStatus();             \
+  }
+
+struct ErrorHints {
+  /// \brief Suggested cast to avoid overflow during concatenation.
+  ///
+  /// If the concatenation of offsets overflows, this field might be set to
+  /// a type that uses larger offsets (e.g. large_utf8, large_list).
+  std::shared_ptr<DataType> suggested_cast;
+};
+
 // Allocate a buffer and concatenate bitmaps into it.
 Status ConcatenateBitmaps(const std::vector<Bitmap>& bitmaps, MemoryPool* pool,
                           std::shared_ptr<Buffer>* out) {
@@ -112,15 +137,16 @@ int64_t SumBufferSizesInBytes(const BufferVector& buffers) {
 // Write offsets in src into dst, adjusting them such that first_offset
 // will be the first offset written.
 template <typename Offset>
-Status PutOffsets(const Buffer& src, Offset first_offset, Offset* dst,
-                  Range* values_range);
+Result<OffsetBufferOpOutcome> PutOffsets(const Buffer& src, Offset first_offset,
+                                         Offset* dst, Range* values_range);
 
 // Concatenate buffers holding offsets into a single buffer of offsets,
 // also computing the ranges of values spanned by each buffer of offsets.
 template <typename Offset>
-Status ConcatenateOffsets(const BufferVector& buffers, MemoryPool* pool,
-                          std::shared_ptr<Buffer>* out,
-                          std::vector<Range>* values_ranges) {
+Result<OffsetBufferOpOutcome> ConcatenateOffsets(const BufferVector& buffers,
+                                                 MemoryPool* pool,
+                                                 std::shared_ptr<Buffer>* out,
+                                                 std::vector<Range>* values_ranges) {
   values_ranges->resize(buffers.size());
 
   // allocate output buffer
@@ -133,26 +159,30 @@ Status ConcatenateOffsets(const BufferVector& buffers, MemoryPool* pool,
   for (size_t i = 0; i < buffers.size(); ++i) {
     // the first offset from buffers[i] will be adjusted to values_length
     // (the cumulative length of values spanned by offsets in previous buffers)
-    RETURN_NOT_OK(PutOffsets(*buffers[i], values_length,
-                             out_data + elements_length, &(*values_ranges)[i]));
+    ARROW_ASSIGN_OR_RAISE(auto outcome, PutOffsets(*buffers[i], values_length,
+                                                   out_data + elements_length,
+                                                   &(*values_ranges)[i]));
+    if (ARROW_PREDICT_FALSE(outcome != OffsetBufferOpOutcome::kOk)) {
+      return outcome;
+    }
     elements_length += buffers[i]->size() / sizeof(Offset);
     values_length += static_cast<Offset>((*values_ranges)[i].length);
   }
 
   // the final element in out_data is the length of all values spanned by the offsets
   out_data[out_size_in_bytes / sizeof(Offset)] = values_length;
-  return Status::OK();
+  return OffsetBufferOpOutcome::kOk;
 }
 
 template <typename Offset>
-Status PutOffsets(const Buffer& src, Offset first_offset, Offset* dst,
-                  Range* values_range) {
+Result<OffsetBufferOpOutcome> PutOffsets(const Buffer& src, Offset first_offset,
+                                         Offset* dst, Range* values_range) {
   if (src.size() == 0) {
     // It's allowed to have an empty offsets buffer for a 0-length array
     // (see Array::Validate)
     values_range->offset = 0;
     values_range->length = 0;
-    return Status::OK();
+    return OffsetBufferOpOutcome::kOk;
   }
 
   // Get the range of offsets to transfer from src
@@ -162,8 +192,9 @@ Status PutOffsets(const Buffer& src, Offset first_offset, Offset* dst,
   // Compute the range of values which is spanned by this range of offsets
   values_range->offset = src_begin[0];
   values_range->length = *src_end - values_range->offset;
-  if (first_offset > std::numeric_limits<Offset>::max() - values_range->length) {
-    return Status::Invalid("offset overflow while concatenating arrays");
+  if (ARROW_PREDICT_FALSE(first_offset >
+                          std::numeric_limits<Offset>::max() - values_range->length)) {
+    return OffsetBufferOpOutcome::kOffsetOverflow;
   }
 
   // Write offsets into dst, ensuring that the first offset
written is @@ -175,12 +206,14 @@ Status PutOffsets(const Buffer& src, Offset first_offset, Offset* dst, std::transform(src_begin, src_end, dst, [displacement](Offset offset) { return SafeSignedAdd(offset, displacement); }); - return Status::OK(); + return OffsetBufferOpOutcome::kOk; } template -Status PutListViewOffsets(const ArrayData& input, offset_type* sizes, const Buffer& src, - offset_type displacement, offset_type* dst); +Result PutListViewOffsets(const ArrayData& input, + offset_type* sizes, const Buffer& src, + offset_type displacement, + offset_type* dst); // Concatenate buffers holding list-view offsets into a single buffer of offsets // @@ -198,10 +231,10 @@ Status PutListViewOffsets(const ArrayData& input, offset_type* sizes, const Buff // \param[in] in The child arrays // \param[in,out] sizes The concatenated sizes buffer template -Status ConcatenateListViewOffsets(const ArrayDataVector& in, offset_type* sizes, - const BufferVector& offset_buffers, - const std::vector& value_ranges, - MemoryPool* pool, std::shared_ptr* out) { +Result ConcatenateListViewOffsets( + const ArrayDataVector& in, offset_type* sizes, const BufferVector& offset_buffers, + const std::vector& value_ranges, MemoryPool* pool, + std::shared_ptr* out) { DCHECK_EQ(offset_buffers.size(), value_ranges.size()); // Allocate resulting offsets buffer and initialize it with zeros @@ -216,26 +249,32 @@ Status ConcatenateListViewOffsets(const ArrayDataVector& in, offset_type* sizes, for (size_t i = 0; i < offset_buffers.size(); ++i) { const auto displacement = static_cast(num_child_values - value_ranges[i].offset); - RETURN_NOT_OK(PutListViewOffsets(*in[i], /*sizes=*/sizes + elements_length, - /*src=*/*offset_buffers[i], displacement, - /*dst=*/out_offsets + elements_length)); + ARROW_ASSIGN_OR_RAISE(auto outcome, + PutListViewOffsets(*in[i], /*sizes=*/sizes + elements_length, + /*src=*/*offset_buffers[i], displacement, + /*dst=*/out_offsets + elements_length)); + if (ARROW_PREDICT_FALSE(outcome != OffsetBufferOpOutcome::kOk)) { + return outcome; + } elements_length += offset_buffers[i]->size() / sizeof(offset_type); num_child_values += value_ranges[i].length; if (num_child_values > std::numeric_limits::max()) { - return Status::Invalid("offset overflow while concatenating arrays"); + return OffsetBufferOpOutcome::kOffsetOverflow; } } DCHECK_EQ(elements_length, static_cast(out_size_in_bytes / sizeof(offset_type))); - return Status::OK(); + return OffsetBufferOpOutcome::kOk; } template -Status PutListViewOffsets(const ArrayData& input, offset_type* sizes, const Buffer& src, - offset_type displacement, offset_type* dst) { +Result PutListViewOffsets(const ArrayData& input, + offset_type* sizes, const Buffer& src, + offset_type displacement, + offset_type* dst) { if (src.size() == 0) { - return Status::OK(); + return OffsetBufferOpOutcome::kOk; } const auto& validity_buffer = input.buffers[0]; if (validity_buffer) { @@ -291,7 +330,7 @@ Status PutListViewOffsets(const ArrayData& input, offset_type* sizes, const Buff } } } - return Status::OK(); + return OffsetBufferOpOutcome::kOk; } class ConcatenateImpl { @@ -316,11 +355,17 @@ class ConcatenateImpl { } } - Status Concatenate(std::shared_ptr* out) && { + Status Concatenate(std::shared_ptr* out, ErrorHints* out_hints) && { if (out_->null_count != 0 && internal::may_have_validity_bitmap(out_->type->id())) { RETURN_NOT_OK(ConcatenateBitmaps(Bitmaps(0), pool_, &out_->buffers[0])); } - RETURN_NOT_OK(VisitTypeInline(*out_->type, this)); + auto status = 
VisitTypeInline(*out_->type, this); + if (!status.ok()) { + if (out_hints) { + out_hints->suggested_cast = std::move(suggested_cast_); + } + return status; + } *out = std::move(out_); return Status::OK(); } @@ -337,11 +382,29 @@ class ConcatenateImpl { return ConcatenateBuffers(buffers, pool_).Value(&out_->buffers[1]); } - Status Visit(const BinaryType&) { + Status Visit(const BinaryType& input_type) { std::vector value_ranges; ARROW_ASSIGN_OR_RAISE(auto index_buffers, Buffers(1, sizeof(int32_t))); - RETURN_NOT_OK(ConcatenateOffsets(index_buffers, pool_, &out_->buffers[1], - &value_ranges)); + ARROW_ASSIGN_OR_RAISE( + auto outcome, ConcatenateOffsets(index_buffers, pool_, &out_->buffers[1], + &value_ranges)); + switch (outcome) { + case OffsetBufferOpOutcome::kOk: + break; + case OffsetBufferOpOutcome::kOffsetOverflow: + switch (input_type.id()) { + case Type::BINARY: + suggested_cast_ = large_binary(); + break; + case Type::STRING: + suggested_cast_ = large_utf8(); + break; + default: + DCHECK(false) << "unexpected type id from BinaryType: " << input_type; + break; + } + return OffsetOverflowStatus(); + } ARROW_ASSIGN_OR_RAISE(auto value_buffers, Buffers(2, value_ranges)); return ConcatenateBuffers(value_buffers, pool_).Value(&out_->buffers[2]); } @@ -349,8 +412,10 @@ class ConcatenateImpl { Status Visit(const LargeBinaryType&) { std::vector value_ranges; ARROW_ASSIGN_OR_RAISE(auto index_buffers, Buffers(1, sizeof(int64_t))); - RETURN_NOT_OK(ConcatenateOffsets(index_buffers, pool_, &out_->buffers[1], - &value_ranges)); + ARROW_ASSIGN_OR_RAISE( + auto outcome, ConcatenateOffsets(index_buffers, pool_, &out_->buffers[1], + &value_ranges)); + RETURN_IF_NOT_OK_OUTCOME(outcome); ARROW_ASSIGN_OR_RAISE(auto value_buffers, Buffers(2, value_ranges)); return ConcatenateBuffers(value_buffers, pool_).Value(&out_->buffers[2]); } @@ -394,22 +459,44 @@ class ConcatenateImpl { return Status::OK(); } - Status Visit(const ListType&) { + Status Visit(const ListType& input_type) { std::vector value_ranges; ARROW_ASSIGN_OR_RAISE(auto index_buffers, Buffers(1, sizeof(int32_t))); - RETURN_NOT_OK(ConcatenateOffsets(index_buffers, pool_, &out_->buffers[1], - &value_ranges)); + ARROW_ASSIGN_OR_RAISE(auto offsets_outcome, + ConcatenateOffsets(index_buffers, pool_, + &out_->buffers[1], &value_ranges)); + switch (offsets_outcome) { + case OffsetBufferOpOutcome::kOk: + break; + case OffsetBufferOpOutcome::kOffsetOverflow: + suggested_cast_ = large_list(input_type.value_type()); + return OffsetOverflowStatus(); + } ARROW_ASSIGN_OR_RAISE(auto child_data, ChildData(0, value_ranges)); - return ConcatenateImpl(child_data, pool_).Concatenate(&out_->child_data[0]); + ErrorHints child_error_hints; + auto status = ConcatenateImpl(child_data, pool_) + .Concatenate(&out_->child_data[0], &child_error_hints); + if (!status.ok() && child_error_hints.suggested_cast) { + suggested_cast_ = list(std::move(child_error_hints.suggested_cast)); + } + return status; } Status Visit(const LargeListType&) { std::vector value_ranges; ARROW_ASSIGN_OR_RAISE(auto index_buffers, Buffers(1, sizeof(int64_t))); - RETURN_NOT_OK(ConcatenateOffsets(index_buffers, pool_, &out_->buffers[1], - &value_ranges)); + ARROW_ASSIGN_OR_RAISE( + auto outcome, ConcatenateOffsets(index_buffers, pool_, &out_->buffers[1], + &value_ranges)); + RETURN_IF_NOT_OK_OUTCOME(outcome); ARROW_ASSIGN_OR_RAISE(auto child_data, ChildData(0, value_ranges)); - return ConcatenateImpl(child_data, pool_).Concatenate(&out_->child_data[0]); + ErrorHints child_error_hints; + auto status = 
ConcatenateImpl(child_data, pool_) + .Concatenate(&out_->child_data[0], &child_error_hints); + if (!status.ok() && child_error_hints.suggested_cast) { + suggested_cast_ = large_list(std::move(child_error_hints.suggested_cast)); + } + return status; } template @@ -430,8 +517,17 @@ class ConcatenateImpl { } // Concatenate the values + ErrorHints child_error_hints; ARROW_ASSIGN_OR_RAISE(ArrayDataVector value_data, ChildData(0, value_ranges)); - RETURN_NOT_OK(ConcatenateImpl(value_data, pool_).Concatenate(&out_->child_data[0])); + auto values_status = ConcatenateImpl(value_data, pool_) + .Concatenate(&out_->child_data[0], &child_error_hints); + if (!values_status.ok()) { + if (child_error_hints.suggested_cast) { + suggested_cast_ = std::make_shared>( + std::move(child_error_hints.suggested_cast)); + } + return values_status; + } out_->child_data[0]->type = type.value_type(); // Concatenate the sizes first @@ -440,22 +536,39 @@ class ConcatenateImpl { // Concatenate the offsets ARROW_ASSIGN_OR_RAISE(auto offset_buffers, Buffers(1, sizeof(offset_type))); - RETURN_NOT_OK(ConcatenateListViewOffsets( - in_, /*sizes=*/out_->buffers[2]->mutable_data_as(), offset_buffers, - value_ranges, pool_, &out_->buffers[1])); - + ARROW_ASSIGN_OR_RAISE( + auto outcome, ConcatenateListViewOffsets( + in_, /*sizes=*/out_->buffers[2]->mutable_data_as(), + offset_buffers, value_ranges, pool_, &out_->buffers[1])); + switch (outcome) { + case OffsetBufferOpOutcome::kOk: + break; + case OffsetBufferOpOutcome::kOffsetOverflow: + if constexpr (T::type_id == Type::LIST_VIEW) { + suggested_cast_ = large_list_view(type.value_type()); + } + return OffsetOverflowStatus(); + } return Status::OK(); } - Status Visit(const FixedSizeListType& fixed_size_list) { - ARROW_ASSIGN_OR_RAISE(auto child_data, ChildData(0, fixed_size_list.list_size())); - return ConcatenateImpl(child_data, pool_).Concatenate(&out_->child_data[0]); + Status Visit(const FixedSizeListType& fsl_type) { + ARROW_ASSIGN_OR_RAISE(auto child_data, ChildData(0, fsl_type.list_size())); + ErrorHints hints; + auto status = + ConcatenateImpl(child_data, pool_).Concatenate(&out_->child_data[0], &hints); + if (!status.ok() && hints.suggested_cast) { + suggested_cast_ = + fixed_size_list(std::move(hints.suggested_cast), fsl_type.list_size()); + } + return status; } Status Visit(const StructType& s) { for (int i = 0; i < s.num_fields(); ++i) { ARROW_ASSIGN_OR_RAISE(auto child_data, ChildData(i)); - RETURN_NOT_OK(ConcatenateImpl(child_data, pool_).Concatenate(&out_->child_data[i])); + RETURN_NOT_OK(ConcatenateImpl(child_data, pool_) + .Concatenate(&out_->child_data[i], /*hints=*/nullptr)); } return Status::OK(); } @@ -570,8 +683,8 @@ class ConcatenateImpl { case UnionMode::SPARSE: { for (int i = 0; i < u.num_fields(); i++) { ARROW_ASSIGN_OR_RAISE(auto child_data, ChildData(i)); - RETURN_NOT_OK( - ConcatenateImpl(child_data, pool_).Concatenate(&out_->child_data[i])); + RETURN_NOT_OK(ConcatenateImpl(child_data, pool_) + .Concatenate(&out_->child_data[i], /*hints=*/nullptr)); } break; } @@ -581,8 +694,8 @@ class ConcatenateImpl { for (size_t j = 0; j < in_.size(); j++) { child_data[j] = in_[j]->child_data[i]; } - RETURN_NOT_OK( - ConcatenateImpl(child_data, pool_).Concatenate(&out_->child_data[i])); + RETURN_NOT_OK(ConcatenateImpl(child_data, pool_) + .Concatenate(&out_->child_data[i], /*hints=*/nullptr)); } break; } @@ -666,7 +779,8 @@ class ConcatenateImpl { storage_data[i]->type = e.storage_type(); } std::shared_ptr out_storage; - RETURN_NOT_OK(ConcatenateImpl(storage_data, 
pool_).Concatenate(&out_storage)); + RETURN_NOT_OK(ConcatenateImpl(storage_data, pool_) + .Concatenate(&out_storage, /*hints=*/nullptr)); out_storage->type = in_[0]->type; out_ = std::move(out_storage); return Status::OK(); @@ -797,11 +911,18 @@ class ConcatenateImpl { const ArrayDataVector& in_; MemoryPool* pool_; std::shared_ptr out_; + std::shared_ptr suggested_cast_; }; } // namespace -Result> Concatenate(const ArrayVector& arrays, MemoryPool* pool) { +namespace internal { + +Result> Concatenate( + const ArrayVector& arrays, MemoryPool* pool, + std::shared_ptr* out_suggested_cast) { + DCHECK(out_suggested_cast); + *out_suggested_cast = nullptr; if (arrays.size() == 0) { return Status::Invalid("Must pass at least one array"); } @@ -818,8 +939,31 @@ Result> Concatenate(const ArrayVector& arrays, MemoryPool } std::shared_ptr out_data; - RETURN_NOT_OK(ConcatenateImpl(data, pool).Concatenate(&out_data)); + ErrorHints hints; + auto status = ConcatenateImpl(data, pool).Concatenate(&out_data, &hints); + if (!status.ok()) { + if (hints.suggested_cast) { + DCHECK(status.IsInvalid()); + *out_suggested_cast = std::move(hints.suggested_cast); + } + return status; + } return MakeArray(std::move(out_data)); } +} // namespace internal + +Result> Concatenate(const ArrayVector& arrays, MemoryPool* pool) { + std::shared_ptr suggested_cast; + auto result = internal::Concatenate(arrays, pool, &suggested_cast); + if (!result.ok() && suggested_cast && arrays.size() > 0) { + DCHECK(result.status().IsInvalid()); + return Status::Invalid(result.status().message(), ", consider casting input from `", + *arrays[0]->type(), "` to `", *suggested_cast, "` first."); + } + return result; +} + +#undef RETURN_IF_NOT_OK_OUTCOME + } // namespace arrow diff --git a/cpp/src/arrow/array/concatenate.h b/cpp/src/arrow/array/concatenate.h index e7597aad812c4..aada5624d63a3 100644 --- a/cpp/src/arrow/array/concatenate.h +++ b/cpp/src/arrow/array/concatenate.h @@ -24,6 +24,22 @@ #include "arrow/util/visibility.h" namespace arrow { +namespace internal { + +/// \brief Concatenate arrays +/// +/// \param[in] arrays a vector of arrays to be concatenated +/// \param[in] pool memory to store the result will be allocated from this memory pool +/// \param[out] out_suggested_cast if a non-OK Result is returned, the function might set +/// out_suggested_cast to a cast suggestion that would allow concatenating the arrays +/// without overflow of offsets (e.g. 
string to large_string)
+///
+/// \return the concatenated array
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> Concatenate(const ArrayVector& arrays, MemoryPool* pool,
+                                           std::shared_ptr<DataType>* out_suggested_cast);
+
+}  // namespace internal
 
 /// \brief Concatenate arrays
 ///
diff --git a/cpp/src/arrow/array/concatenate_test.cc b/cpp/src/arrow/array/concatenate_test.cc
index af595e897f9ee..aea5311575299 100644
--- a/cpp/src/arrow/array/concatenate_test.cc
+++ b/cpp/src/arrow/array/concatenate_test.cc
@@ -29,6 +29,7 @@
 #include <string>
 #include <vector>
 
+#include <gmock/gmock-matchers.h>
 #include <gtest/gtest.h>
 
 #include "arrow/array.h"
@@ -42,6 +43,7 @@
 #include "arrow/testing/util.h"
 #include "arrow/type.h"
 #include "arrow/util/list_util.h"
+#include "arrow/util/unreachable.h"
 
 namespace arrow {
 
@@ -661,14 +663,103 @@ TEST_F(ConcatenateTest, ExtensionType) {
   });
 }
 
+std::shared_ptr<DataType> LargeVersionOfType(const std::shared_ptr<DataType>& type) {
+  switch (type->id()) {
+    case Type::BINARY:
+      return large_binary();
+    case Type::STRING:
+      return large_utf8();
+    case Type::LIST:
+      return large_list(static_cast<const ListType&>(*type).value_type());
+    case Type::LIST_VIEW:
+      return large_list_view(static_cast<const ListViewType&>(*type).value_type());
+    case Type::LARGE_BINARY:
+    case Type::LARGE_STRING:
+    case Type::LARGE_LIST:
+    case Type::LARGE_LIST_VIEW:
+      return type;
+    default:
+      Unreachable();
+  }
+}
+
+std::shared_ptr<DataType> fixed_size_list_of_1(std::shared_ptr<DataType> type) {
+  return fixed_size_list(std::move(type), 1);
+}
+
 TEST_F(ConcatenateTest, OffsetOverflow) {
-  auto fake_long = ArrayFromJSON(utf8(), "[\"\"]");
-  fake_long->data()->GetMutableValues<int32_t>(1)[1] =
+  using TypeFactory = std::shared_ptr<DataType> (*)(std::shared_ptr<DataType>);
+  static const std::vector<TypeFactory> kNestedTypeFactories = {
+      list, large_list, list_view, large_list_view, fixed_size_list_of_1,
+  };
+
+  auto* pool = default_memory_pool();
+  std::shared_ptr<DataType> suggested_cast;
+  for (auto& ty : {binary(), utf8()}) {
+    auto large_ty = LargeVersionOfType(ty);
+
+    auto fake_long = ArrayFromJSON(ty, "[\"\"]");
+    fake_long->data()->GetMutableValues<int32_t>(1)[1] =
+        std::numeric_limits<int32_t>::max();
+    // XXX: since the data fake_long claims to own isn't there, this would
+    // segfault if Concatenate didn't detect overflow and raise an error.
+    auto concatenate_status = Concatenate({fake_long, fake_long});
+    EXPECT_RAISES_WITH_MESSAGE_THAT(
+        Invalid,
+        ::testing::StrEq("Invalid: offset overflow while concatenating arrays, "
+                         "consider casting input from `" +
+                         ty->ToString() + "` to `large_" + ty->ToString() + "` first."),
+        concatenate_status);
+
+    concatenate_status =
+        internal::Concatenate({fake_long, fake_long}, pool, &suggested_cast);
+    // Message doesn't contain the suggested cast type when the caller
+    // asks for it by passing the output parameter.
+    EXPECT_RAISES_WITH_MESSAGE_THAT(
+        Invalid, ::testing::StrEq("Invalid: offset overflow while concatenating arrays"),
+        concatenate_status);
+    ASSERT_TRUE(large_ty->Equals(*suggested_cast));
+
+    // Check that the suggested cast is correct when concatenation
+    // fails due to the child array being too large.
+ for (auto factory : kNestedTypeFactories) { + auto nested_ty = factory(ty); + auto expected_suggestion = factory(large_ty); + auto fake_long_list = ArrayFromJSON(nested_ty, "[[\"\"]]"); + fake_long_list->data()->child_data[0] = fake_long->data(); + + ASSERT_RAISES(Invalid, internal::Concatenate({fake_long_list, fake_long_list}, pool, + &suggested_cast) + .status()); + ASSERT_TRUE(suggested_cast->Equals(*expected_suggestion)); + } + } + + auto list_ty = list(utf8()); + auto fake_long_list = ArrayFromJSON(list_ty, "[[\"Hello\"]]"); + fake_long_list->data()->GetMutableValues(1)[1] = std::numeric_limits::max(); - std::shared_ptr concatenated; - // XX since the data fake_long claims to own isn't there, this will segfault if - // Concatenate doesn't detect overflow and raise an error. - ASSERT_RAISES(Invalid, Concatenate({fake_long, fake_long}).status()); + ASSERT_RAISES(Invalid, internal::Concatenate({fake_long_list, fake_long_list}, pool, + &suggested_cast) + .status()); + ASSERT_TRUE(suggested_cast->Equals(LargeVersionOfType(list_ty))); + + auto list_view_ty = list_view(null()); + auto fake_long_list_view = ArrayFromJSON(list_view_ty, "[[], []]"); + { + constexpr int kInt32Max = std::numeric_limits::max(); + auto* values = fake_long_list_view->data()->child_data[0].get(); + auto* mutable_offsets = fake_long_list_view->data()->GetMutableValues(1); + auto* mutable_sizes = fake_long_list_view->data()->GetMutableValues(2); + values->length = 2 * static_cast(kInt32Max); + mutable_offsets[1] = kInt32Max; + mutable_offsets[0] = kInt32Max; + mutable_sizes[0] = kInt32Max; + } + ASSERT_RAISES(Invalid, internal::Concatenate({fake_long_list_view, fake_long_list_view}, + pool, &suggested_cast) + .status()); + ASSERT_TRUE(suggested_cast->Equals(LargeVersionOfType(list_view_ty))); } TEST_F(ConcatenateTest, DictionaryConcatenateWithEmptyUint16) { diff --git a/cpp/src/arrow/compute/kernels/vector_selection_test.cc b/cpp/src/arrow/compute/kernels/vector_selection_test.cc index aba016d6b7e8d..b38f3fcbd8ccd 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_test.cc @@ -28,6 +28,7 @@ #include "arrow/chunked_array.h" #include "arrow/compute/api.h" #include "arrow/compute/kernels/test_util.h" +#include "arrow/scalar.h" #include "arrow/table.h" #include "arrow/testing/builder.h" #include "arrow/testing/fixed_width_test_util.h" @@ -1101,33 +1102,114 @@ TEST(TestFilterMetaFunction, ArityChecking) { // ---------------------------------------------------------------------- // Take tests +// +// Shorthand notation (as defined in `TakeMetaFunction`): +// +// A = Array +// C = ChunkedArray +// R = RecordBatch +// T = Table +// +// (e.g. 
TakeCAC = Take(ChunkedArray, Array) -> ChunkedArray) +// +// The interface implemented by `TakeMetaFunction` is: +// +// Take(A, A) -> A (TakeAAA) +// Take(A, C) -> C (TakeACC) +// Take(C, A) -> C (TakeCAC) +// Take(C, C) -> C (TakeCCC) +// Take(R, A) -> R (TakeRAR) +// Take(T, A) -> T (TakeTAT) +// Take(T, C) -> T (TakeTCT) +// +// The tests extend the notation with a few "union kinds": +// +// X = Array | ChunkedArray +// +// Examples: +// +// TakeXA = {TakeAAA, TakeCAC}, +// TakeXX = {TakeAAA, TakeACC, TakeCAC, TakeCCC} +namespace { -void AssertTakeArrays(const std::shared_ptr& values, - const std::shared_ptr& indices, - const std::shared_ptr& expected) { - ASSERT_OK_AND_ASSIGN(std::shared_ptr actual, Take(*values, *indices)); - ValidateOutput(actual); - AssertArraysEqual(*expected, *actual, /*verbose=*/true); +Result> TakeAAA(const Array& values, const Array& indices) { + ARROW_ASSIGN_OR_RAISE(Datum out, Take(Datum(values), Datum(indices))); + return out.make_array(); } -Status TakeJSON(const std::shared_ptr& type, const std::string& values, - const std::shared_ptr& index_type, const std::string& indices, - std::shared_ptr* out) { - return Take(*ArrayFromJSON(type, values), *ArrayFromJSON(index_type, indices)) - .Value(out); +Result> TakeAAA( + const std::shared_ptr& type, const std::string& values, + const std::string& indices, const std::shared_ptr& index_type = int32()) { + return TakeAAA(*ArrayFromJSON(type, values), *ArrayFromJSON(index_type, indices)); } -void DoCheckTake(const std::shared_ptr& values, - const std::shared_ptr& indices, - const std::shared_ptr& expected) { - AssertTakeArrays(values, indices, expected); +// TakeACC is never tested directly, so it is not defined here + +Result TakeCAC(std::shared_ptr values, + std::shared_ptr indices) { + return Take(Datum{std::move(values)}, Datum{std::move(indices)}); +} + +Result TakeCAC(const std::shared_ptr& type, + const std::vector& values, const std::string& indices, + const std::shared_ptr& index_type = int8()) { + return TakeCAC(ChunkedArrayFromJSON(type, values), ArrayFromJSON(index_type, indices)); +} + +Result TakeCCC(std::shared_ptr values, + std::shared_ptr indices) { + return Take(Datum{std::move(values)}, Datum{std::move(indices)}); +} + +Result TakeCCC(const std::shared_ptr& type, + const std::vector& values, + const std::vector& indices) { + return TakeCCC(ChunkedArrayFromJSON(type, values), + ChunkedArrayFromJSON(int8(), indices)); +} + +Result TakeRAR(const std::shared_ptr& schm, const std::string& batch_json, + const std::string& indices, + const std::shared_ptr& index_type = int8()) { + auto batch = RecordBatchFromJSON(schm, batch_json); + return Take(Datum{std::move(batch)}, Datum{ArrayFromJSON(index_type, indices)}); +} + +Result TakeTAT(const std::shared_ptr& schm, + const std::vector& values, const std::string& indices, + const std::shared_ptr& index_type = int8()) { + return Take(Datum{TableFromJSON(schm, values)}, + Datum{ArrayFromJSON(index_type, indices)}); +} + +Result TakeTCT(const std::shared_ptr& schm, + const std::vector& values, + const std::vector& indices) { + return Take(Datum{TableFromJSON(schm, values)}, + Datum{ChunkedArrayFromJSON(int8(), indices)}); +} + +// Assert helpers for Take tests + +void DoAssertTakeAAA(const std::shared_ptr& values, + const std::shared_ptr& indices, + const std::shared_ptr& expected) { + ASSERT_OK_AND_ASSIGN(std::shared_ptr actual, TakeAAA(*values, *indices)); + ValidateOutput(actual); + AssertArraysEqual(*expected, *actual, /*verbose=*/true); +} + +void 
DoCheckTakeAAA(const std::shared_ptr& values, + const std::shared_ptr& indices, + const std::shared_ptr& expected) { + DoAssertTakeAAA(values, indices, expected); // Check sliced values ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(values->type(), 2)); ASSERT_OK_AND_ASSIGN(auto values_sliced, Concatenate({values_filler, values, values_filler})); values_sliced = values_sliced->Slice(2, values->length()); - AssertTakeArrays(values_sliced, indices, expected); + DoAssertTakeAAA(values_sliced, indices, expected); // Check sliced indices ASSERT_OK_AND_ASSIGN(auto zero, MakeScalar(indices->type(), int8_t{0})); @@ -1135,33 +1217,171 @@ void DoCheckTake(const std::shared_ptr& values, ASSERT_OK_AND_ASSIGN(auto indices_sliced, Concatenate({indices_filler, indices, indices_filler})); indices_sliced = indices_sliced->Slice(3, indices->length()); - AssertTakeArrays(values, indices_sliced, expected); -} - -void CheckTake(const std::shared_ptr& type, const std::string& values_json, - const std::string& indices_json, const std::string& expected_json) { + DoAssertTakeAAA(values, indices_sliced, expected); +} + +void DoCheckTakeCACWithArrays(const std::shared_ptr& values, + const std::shared_ptr& indices, + const std::shared_ptr& expected) { + auto pool = default_memory_pool(); + const bool indices_null_count_is_known = indices->null_count() != kUnknownNullCount; + + // We check TakeCAC by checking this equality: + // + // TakeAAA(Concat(V, V, V), I') == Concat(TakeCAC([V, V, V], I')) + // where + // V = values + // I = indices + // I' = Concat(I + 2 * V.length, I, I + V.length) + auto values3 = ArrayVector{values, values, values}; + ASSERT_OK_AND_ASSIGN(auto concat_values3, Concatenate(values3, pool)); + auto chunked_values3 = std::make_shared(values3); + std::shared_ptr concat_indices3; + { + auto double_length = + MakeScalar(indices->type(), static_cast(2 * values->length())); + auto zero = MakeScalar(indices->type(), 0); + auto length = MakeScalar(indices->type(), static_cast(values->length())); + ASSERT_OK_AND_ASSIGN(auto indices_prefix, Add(indices, *double_length)); + ASSERT_OK_AND_ASSIGN(auto indices_middle, Add(indices, *zero)); + ASSERT_OK_AND_ASSIGN(auto indices_suffix, Add(indices, *length)); + auto indices3 = ArrayVector{ + indices_prefix.make_array(), + indices_middle.make_array(), + indices_suffix.make_array(), + }; + ASSERT_OK_AND_ASSIGN(concat_indices3, Concatenate(indices3, pool)); + // Preserve the fact that indices->null_count() is unknown if it is unknown. 
+ if (!indices_null_count_is_known) { + concat_indices3->data()->null_count = kUnknownNullCount; + } + } + ASSERT_OK_AND_ASSIGN(auto concat_expected3, + Concatenate({expected, expected, expected})); + ASSERT_OK_AND_ASSIGN(Datum chunked_actual, TakeCAC(chunked_values3, concat_indices3)); + ValidateOutput(chunked_actual); + ASSERT_OK_AND_ASSIGN(auto concat_actual, + Concatenate(chunked_actual.chunked_array()->chunks())); + AssertArraysEqual(*concat_expected3, *concat_actual, /*verbose=*/true); + + // We check TakeCAC again by checking this equality: + // + // TakeAAA(V, I) == Concat(TakeCAC(C, I)) + // where + // K = V.length // 4 + // C = [V.slice(0, K), V.slice(K, 2*K), V.slice(3*K, N - 3*K)] + // V = values + // I = indices + const int64_t n = values->length(); + const int64_t k = n / 4; + if (k > 0) { + auto value_slices = ArrayVector{values->Slice(0, k), values->Slice(k, 2 * k), + values->Slice(3 * k, n - k)}; + auto chunked_values = std::make_shared(value_slices); + ASSERT_OK_AND_ASSIGN(chunked_actual, TakeCAC(chunked_values, indices)); + ValidateOutput(chunked_actual); + ASSERT_OK_AND_ASSIGN(concat_actual, + Concatenate(chunked_actual.chunked_array()->chunks())); + AssertArraysEqual(*concat_actual, *expected, /*verbose=*/true); + } +} + +// TakeXA = {TakeAAA, TakeCAC} +void DoCheckTakeXA(const std::shared_ptr& values, + const std::shared_ptr& indices, + const std::shared_ptr& expected) { + DoCheckTakeAAA(values, indices, expected); + DoCheckTakeCACWithArrays(values, indices, expected); +} + +// TakeXA = {TakeAAA, TakeCAC} +void CheckTakeXA(const std::shared_ptr& type, const std::string& values_json, + const std::string& indices_json, const std::string& expected_json) { auto values = ArrayFromJSON(type, values_json); auto expected = ArrayFromJSON(type, expected_json); for (auto index_type : {int8(), uint32()}) { auto indices = ArrayFromJSON(index_type, indices_json); - DoCheckTake(values, indices, expected); + DoCheckTakeXA(values, indices, expected); } } -void AssertTakeNull(const std::string& values, const std::string& indices, - const std::string& expected) { - CheckTake(null(), values, indices, expected); +void CheckTakeXADictionary(std::shared_ptr value_type, + const std::string& dictionary_values, + const std::string& dictionary_indices, + const std::string& indices, + const std::string& expected_indices) { + auto dict = ArrayFromJSON(value_type, dictionary_values); + auto type = dictionary(int8(), value_type); + ASSERT_OK_AND_ASSIGN( + auto values, + DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), dictionary_indices), dict)); + ASSERT_OK_AND_ASSIGN( + auto expected, + DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_indices), dict)); + auto take_indices = ArrayFromJSON(int8(), indices); + DoCheckTakeXA(values, take_indices, expected); } -void AssertTakeBoolean(const std::string& values, const std::string& indices, - const std::string& expected) { - CheckTake(boolean(), values, indices, expected); +void AssertTakeCAC(const std::shared_ptr& type, + const std::vector& values, const std::string& indices, + const std::vector& expected) { + ASSERT_OK_AND_ASSIGN(auto actual, TakeCAC(type, values, indices)); + ValidateOutput(actual); + AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual.chunked_array()); } +void AssertTakeCCC(const std::shared_ptr& type, + const std::vector& values, + const std::vector& indices, + const std::vector& expected) { + ASSERT_OK_AND_ASSIGN(auto actual, TakeCCC(type, values, indices)); + ValidateOutput(actual); + 
AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual.chunked_array());
+}
+
+void CheckTakeXCC(const Datum& values, const std::vector<std::string>& indices,
+                  const std::vector<std::string>& expected) {
+  EXPECT_TRUE(values.is_array() || values.is_chunked_array());
+  auto idx = ChunkedArrayFromJSON(int32(), indices);
+  ASSERT_OK_AND_ASSIGN(auto actual, Take(values, Datum{idx}));
+  ValidateOutput(actual);
+  AssertChunkedEqual(*ChunkedArrayFromJSON(values.type(), expected),
+                     *actual.chunked_array());
+}
+
+void AssertTakeRAR(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
+                   const std::string& indices, const std::string& expected_batch) {
+  for (auto index_type : {int8(), uint32()}) {
+    ASSERT_OK_AND_ASSIGN(auto actual, TakeRAR(schm, batch_json, indices, index_type));
+    ValidateOutput(actual);
+    ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch),
+                         *actual.record_batch());
+  }
+}
+
+void AssertTakeTAT(const std::shared_ptr<Schema>& schm,
+                   const std::vector<std::string>& table_json,
+                   const std::string& indices,
+                   const std::vector<std::string>& expected_table) {
+  ASSERT_OK_AND_ASSIGN(auto actual, TakeTAT(schm, table_json, indices));
+  ValidateOutput(actual);
+  ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual.table());
+}
+
+void AssertTakeTCT(const std::shared_ptr<Schema>& schm,
+                   const std::vector<std::string>& table_json,
+                   const std::vector<std::string>& indices,
+                   const std::vector<std::string>& expected_table) {
+  ASSERT_OK_AND_ASSIGN(auto actual, TakeTCT(schm, table_json, indices));
+  ValidateOutput(actual);
+  ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual.table());
+}
+
+// Validators used by random data tests
+
 template <typename ValuesType, typename IndexType>
-void ValidateTakeImpl(const std::shared_ptr<Array>& values,
-                      const std::shared_ptr<Array>& indices,
-                      const std::shared_ptr<Array>& result) {
+void ValidateTakeXAImpl(const std::shared_ptr<Array>& values,
+                        const std::shared_ptr<Array>& indices,
+                        const std::shared_ptr<Array>& result) {
   using ValuesArrayType = typename TypeTraits<ValuesType>::ArrayType;
   using IndexArrayType = typename TypeTraits<IndexType>::ArrayType;
   auto typed_values = checked_pointer_cast<ValuesArrayType>(values);
@@ -1185,39 +1405,45 @@ void ValidateTakeImpl(const std::shared_ptr<Array>& values,
         << i;
     }
   }
+  // DoCheckTakeCACWithArrays transforms the indices, which risks overflow,
+  // so we only call it if the index type is not too wide.
+ if (indices->type()->byte_width() <= 4) { + auto cast_options = CastOptions::Safe(TypeHolder{int64()}); + ASSERT_OK_AND_ASSIGN(auto indices64, Cast(indices, cast_options)); + DoCheckTakeCACWithArrays(values, indices64.make_array(), /*expected=*/result); + } } template -void ValidateTake(const std::shared_ptr& values, - const std::shared_ptr& indices) { - ASSERT_OK_AND_ASSIGN(Datum out, Take(values, indices)); - auto taken = out.make_array(); +void ValidateTakeXA(const std::shared_ptr& values, + const std::shared_ptr& indices) { + ASSERT_OK_AND_ASSIGN(auto taken, TakeAAA(*values, *indices)); ValidateOutput(taken); ASSERT_EQ(indices->length(), taken->length()); switch (indices->type_id()) { case Type::INT8: - ValidateTakeImpl(values, indices, taken); + ValidateTakeXAImpl(values, indices, taken); break; case Type::INT16: - ValidateTakeImpl(values, indices, taken); + ValidateTakeXAImpl(values, indices, taken); break; case Type::INT32: - ValidateTakeImpl(values, indices, taken); + ValidateTakeXAImpl(values, indices, taken); break; case Type::INT64: - ValidateTakeImpl(values, indices, taken); + ValidateTakeXAImpl(values, indices, taken); break; case Type::UINT8: - ValidateTakeImpl(values, indices, taken); + ValidateTakeXAImpl(values, indices, taken); break; case Type::UINT16: - ValidateTakeImpl(values, indices, taken); + ValidateTakeXAImpl(values, indices, taken); break; case Type::UINT32: - ValidateTakeImpl(values, indices, taken); + ValidateTakeXAImpl(values, indices, taken); break; case Type::UINT64: - ValidateTakeImpl(values, indices, taken); + ValidateTakeXAImpl(values, indices, taken); break; default: FAIL() << "Invalid index type"; @@ -1225,6 +1451,8 @@ void ValidateTake(const std::shared_ptr& values, } } +// ---- + template T GetMaxIndex(int64_t values_length) { int64_t max_index = values_length - 1; @@ -1239,13 +1467,15 @@ uint64_t GetMaxIndex(int64_t values_length) { return static_cast(values_length - 1); } +} // namespace + class TestTakeKernel : public ::testing::Test { - public: - void TestNoValidityBitmapButUnknownNullCount(const std::shared_ptr& values, - const std::shared_ptr& indices) { + private: + void DoTestNoValidityBitmapButUnknownNullCount(const std::shared_ptr& values, + const std::shared_ptr& indices) { ASSERT_EQ(values->null_count(), 0); ASSERT_EQ(indices->null_count(), 0); - auto expected = (*Take(values, indices)).make_array(); + ASSERT_OK_AND_ASSIGN(auto expected, TakeAAA(*values, *indices)); auto new_values = MakeArray(values->data()->Copy()); new_values->data()->buffers[0].reset(); @@ -1253,67 +1483,95 @@ class TestTakeKernel : public ::testing::Test { auto new_indices = MakeArray(indices->data()->Copy()); new_indices->data()->buffers[0].reset(); new_indices->data()->null_count = kUnknownNullCount; - auto result = (*Take(new_values, new_indices)).make_array(); - - AssertArraysEqual(*expected, *result); + DoCheckTakeXA(new_values, new_indices, expected); } - void TestNoValidityBitmapButUnknownNullCount(const std::shared_ptr& type, - const std::string& values, - const std::string& indices) { - TestNoValidityBitmapButUnknownNullCount(ArrayFromJSON(type, values), - ArrayFromJSON(int16(), indices)); + public: + void DoTestNoValidityBitmapButUnknownNullCount( + const std::shared_ptr& type, const std::string& values, + const std::string& indices, std::shared_ptr index_type = int8()) { + DoTestNoValidityBitmapButUnknownNullCount(ArrayFromJSON(type, values), + ArrayFromJSON(index_type, indices)); } void TestNumericBasics(const std::shared_ptr& type) { 
ARROW_SCOPED_TRACE("type = ", *type); - CheckTake(type, "[7, 8, 9]", "[]", "[]"); - CheckTake(type, "[7, 8, 9]", "[0, 1, 0]", "[7, 8, 7]"); - CheckTake(type, "[null, 8, 9]", "[0, 1, 0]", "[null, 8, null]"); - CheckTake(type, "[7, 8, 9]", "[null, 1, 0]", "[null, 8, 7]"); - CheckTake(type, "[null, 8, 9]", "[]", "[]"); - CheckTake(type, "[7, 8, 9]", "[0, 0, 0, 0, 0, 0, 2]", "[7, 7, 7, 7, 7, 7, 9]"); - + CheckTakeXA(type, "[7, 8, 9]", "[]", "[]"); + CheckTakeXA(type, "[7, 8, 9]", "[0, 1, 0]", "[7, 8, 7]"); + CheckTakeXA(type, "[null, 8, 9]", "[0, 1, 0]", "[null, 8, null]"); + CheckTakeXA(type, "[7, 8, 9]", "[null, 1, 0]", "[null, 8, 7]"); + CheckTakeXA(type, "[null, 8, 9]", "[]", "[]"); + CheckTakeXA(type, "[7, 8, 9]", "[0, 0, 0, 0, 0, 0, 2]", "[7, 7, 7, 7, 7, 7, 9]"); + + const std::string k789 = "[7, 8, 9]"; std::shared_ptr arr; - ASSERT_RAISES(IndexError, TakeJSON(type, "[7, 8, 9]", int8(), "[0, 9, 0]", &arr)); - ASSERT_RAISES(IndexError, TakeJSON(type, "[7, 8, 9]", int8(), "[0, -1, 0]", &arr)); + ASSERT_RAISES(IndexError, TakeAAA(type, k789, "[0, 9, 0]").Value(&arr)); + ASSERT_RAISES(IndexError, TakeAAA(type, k789, "[0, -1, 0]").Value(&arr)); + Datum chunked_arr; + ASSERT_RAISES(IndexError, + TakeCAC(type, {k789, k789}, "[0, 9, 0]").Value(&chunked_arr)); + ASSERT_RAISES(IndexError, + TakeCAC(type, {k789, k789}, "[0, -1, 0]").Value(&chunked_arr)); } }; template -class TestTakeKernelTyped : public TestTakeKernel {}; +class TestTakeKernelTyped : public TestTakeKernel { + protected: + virtual std::shared_ptr value_type() const { + if constexpr (is_parameter_free_type::value) { + return TypeTraits::type_singleton(); + } else { + EXPECT_TRUE(false) << "value_type() must be overridden for parameterized types"; + return nullptr; + } + } + + void TestNoValidityBitmapButUnknownNullCount( + const std::string& values, const std::string& indices, + const std::shared_ptr& index_type = int8()) { + return DoTestNoValidityBitmapButUnknownNullCount(this->value_type(), values, indices, + index_type); + } + + void CheckTakeXA(const std::string& values, const std::string& indices, + const std::string& expected) { + compute::CheckTakeXA(this->value_type(), values, indices, expected); + } +}; + +static const char kNull3[] = "[null, null, null]"; TEST_F(TestTakeKernel, TakeNull) { - AssertTakeNull("[null, null, null]", "[0, 1, 0]", "[null, null, null]"); - AssertTakeNull("[null, null, null]", "[0, 2]", "[null, null]"); + CheckTakeXA(null(), kNull3, "[0, 1, 0]", "[null, null, null]"); + CheckTakeXA(null(), kNull3, "[0, 2]", "[null, null]"); std::shared_ptr arr; + ASSERT_RAISES(IndexError, TakeAAA(null(), kNull3, "[0, 9, 0]").Value(&arr)); + ASSERT_RAISES(IndexError, TakeAAA(boolean(), kNull3, "[0, -1, 0]").Value(&arr)); + Datum chunked_arr; ASSERT_RAISES(IndexError, - TakeJSON(null(), "[null, null, null]", int8(), "[0, 9, 0]", &arr)); + TakeCAC(null(), {kNull3, kNull3}, "[0, 9, 0]").Value(&chunked_arr)); ASSERT_RAISES(IndexError, - TakeJSON(boolean(), "[null, null, null]", int8(), "[0, -1, 0]", &arr)); + TakeCAC(boolean(), {kNull3, kNull3}, "[0, -1, 0]").Value(&chunked_arr)); } TEST_F(TestTakeKernel, InvalidIndexType) { std::shared_ptr arr; - ASSERT_RAISES(NotImplemented, TakeJSON(null(), "[null, null, null]", float32(), - "[0.0, 1.0, 0.1]", &arr)); + ASSERT_RAISES(NotImplemented, + TakeAAA(null(), kNull3, "[0.0, 1.0, 0.1]", float32()).Value(&arr)); + Datum chunked_arr; + ASSERT_RAISES(NotImplemented, + TakeCAC(null(), {kNull3, kNull3}, "[0.0, 1.0, 0.1]", float32()) + .Value(&chunked_arr)); } 
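The shorthand above mirrors how `TakeMetaFunction` dispatches on Datum kinds: Array values with Array indices stay an Array, while ChunkedArray values yield a ChunkedArray. A usage sketch of the CAC case against the public compute API (the function name and builder-based setup are illustrative, not taken from these tests):

#include <memory>

#include <arrow/api.h>
#include <arrow/compute/api.h>

// Take(C, A) -> C: indices address the logical concatenation of the chunks.
arrow::Result<std::shared_ptr<arrow::ChunkedArray>> TakeCACExample() {
  arrow::Int32Builder values_builder;
  ARROW_RETURN_NOT_OK(values_builder.AppendValues(std::vector<int32_t>{7, 8}));
  ARROW_ASSIGN_OR_RAISE(auto chunk0, values_builder.Finish());
  ARROW_RETURN_NOT_OK(values_builder.AppendValues(std::vector<int32_t>{9}));
  ARROW_ASSIGN_OR_RAISE(auto chunk1, values_builder.Finish());
  auto values =
      std::make_shared<arrow::ChunkedArray>(arrow::ArrayVector{chunk0, chunk1});

  arrow::Int8Builder indices_builder;
  ARROW_RETURN_NOT_OK(indices_builder.AppendValues(std::vector<int8_t>{2, 0}));
  ARROW_ASSIGN_OR_RAISE(auto indices, indices_builder.Finish());

  // Selects 9 (logical index 2, which lives in the second chunk), then 7.
  ARROW_ASSIGN_OR_RAISE(
      arrow::Datum out,
      arrow::compute::Take(arrow::Datum(values), arrow::Datum(indices)));
  return out.chunked_array();  // chunked result, as in the TakeCAC helper
}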
-TEST_F(TestTakeKernel, TakeCCEmptyIndices) { - Datum dat = ChunkedArrayFromJSON(int8(), {"[]"}); - Datum idx = ChunkedArrayFromJSON(int32(), {}); - ASSERT_OK_AND_ASSIGN(auto out, Take(dat, idx)); - ValidateOutput(out); - AssertDatumsEqual(ChunkedArrayFromJSON(int8(), {"[]"}), out, true); -} - -TEST_F(TestTakeKernel, TakeACEmptyIndices) { - Datum dat = ArrayFromJSON(int8(), {"[]"}); - Datum idx = ChunkedArrayFromJSON(int32(), {}); - ASSERT_OK_AND_ASSIGN(auto out, Take(dat, idx)); - ValidateOutput(out); - AssertDatumsEqual(ChunkedArrayFromJSON(int8(), {"[]"}), out, true); +TEST_F(TestTakeKernel, TakeXCCEmptyIndices) { + auto expected = std::vector{"[]"}; + auto values = ArrayFromJSON(int8(), {"[1, 3, 3, 7]"}); + CheckTakeXCC(values, {"[]"}, expected); + auto chunked_values = std::make_shared(values); + CheckTakeXCC(chunked_values, {"[]"}, expected); } TEST_F(TestTakeKernel, DefaultOptions) { @@ -1329,18 +1587,25 @@ TEST_F(TestTakeKernel, DefaultOptions) { } TEST_F(TestTakeKernel, TakeBoolean) { - AssertTakeBoolean("[7, 8, 9]", "[]", "[]"); - AssertTakeBoolean("[true, false, true]", "[0, 1, 0]", "[true, false, true]"); - AssertTakeBoolean("[null, false, true]", "[0, 1, 0]", "[null, false, null]"); - AssertTakeBoolean("[true, false, true]", "[null, 1, 0]", "[null, false, true]"); + CheckTakeXA(boolean(), "[7, 8, 9]", "[]", "[]"); + CheckTakeXA(boolean(), "[true, false, true]", "[0, 1, 0]", "[true, false, true]"); + CheckTakeXA(boolean(), "[null, false, true]", "[0, 1, 0]", "[null, false, null]"); + CheckTakeXA(boolean(), "[true, false, true]", "[null, 1, 0]", "[null, false, true]"); - TestNoValidityBitmapButUnknownNullCount(boolean(), "[true, false, true]", "[1, 0, 0]"); + DoTestNoValidityBitmapButUnknownNullCount(boolean(), "[true, false, true]", + "[1, 0, 0]"); + const std::string kTrueFalseTrue = "[true, false, true]"; std::shared_ptr arr; + ASSERT_RAISES(IndexError, TakeAAA(boolean(), kTrueFalseTrue, "[0, 9, 0]").Value(&arr)); + ASSERT_RAISES(IndexError, TakeAAA(boolean(), kTrueFalseTrue, "[0, -1, 0]").Value(&arr)); + Datum chunked_arr; ASSERT_RAISES(IndexError, - TakeJSON(boolean(), "[true, false, true]", int8(), "[0, 9, 0]", &arr)); + TakeCAC(boolean(), {kTrueFalseTrue, kTrueFalseTrue}, "[0, 9, 0]") + .Value(&chunked_arr)); ASSERT_RAISES(IndexError, - TakeJSON(boolean(), "[true, false, true]", int8(), "[0, -1, 0]", &arr)); + TakeCAC(boolean(), {kTrueFalseTrue, kTrueFalseTrue}, "[0, -1, 0]") + .Value(&chunked_arr)); } TEST_F(TestTakeKernel, Temporal) { @@ -1349,8 +1614,8 @@ TEST_F(TestTakeKernel, Temporal) { this->TestNumericBasics(timestamp(TimeUnit::NANO, "Europe/Paris")); this->TestNumericBasics(duration(TimeUnit::SECOND)); this->TestNumericBasics(date32()); - CheckTake(date64(), "[0, 86400000, null]", "[null, 1, 1, 0]", - "[null, 86400000, 86400000, 0]"); + CheckTakeXA(date64(), "[0, 86400000, null]", "[null, 1, 1, 0]", + "[null, 86400000, 86400000, 0]"); } TEST_F(TestTakeKernel, Duration) { @@ -1363,177 +1628,184 @@ TEST_F(TestTakeKernel, Interval) { this->TestNumericBasics(month_interval()); auto type = day_time_interval(); - CheckTake(type, "[[1, -600], [2, 3000], null]", "[0, null, 2, 1]", - "[[1, -600], null, null, [2, 3000]]"); + CheckTakeXA(type, "[[1, -600], [2, 3000], null]", "[0, null, 2, 1]", + "[[1, -600], null, null, [2, 3000]]"); type = month_day_nano_interval(); - CheckTake(type, "[[1, -2, 34567890123456789], [2, 3, -34567890123456789], null]", - "[0, null, 2, 1]", - "[[1, -2, 34567890123456789], null, null, [2, 3, -34567890123456789]]"); + CheckTakeXA(type, "[[1, -2, 
34567890123456789], [2, 3, -34567890123456789], null]", + "[0, null, 2, 1]", + "[[1, -2, 34567890123456789], null, null, [2, 3, -34567890123456789]]"); } template -class TestTakeKernelWithNumeric : public TestTakeKernelTyped { - protected: - void AssertTake(const std::string& values, const std::string& indices, - const std::string& expected) { - CheckTake(type_singleton(), values, indices, expected); - } - - std::shared_ptr type_singleton() { - return TypeTraits::type_singleton(); - } -}; +class TestTakeKernelWithNumeric : public TestTakeKernelTyped {}; TYPED_TEST_SUITE(TestTakeKernelWithNumeric, NumericArrowTypes); TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) { - this->TestNumericBasics(this->type_singleton()); + this->TestNumericBasics(this->value_type()); } template class TestTakeKernelWithString : public TestTakeKernelTyped { public: - std::shared_ptr value_type() { - return TypeTraits::type_singleton(); - } - - void AssertTake(const std::string& values, const std::string& indices, - const std::string& expected) { - CheckTake(value_type(), values, indices, expected); - } - - void AssertTakeDictionary(const std::string& dictionary_values, - const std::string& dictionary_indices, - const std::string& indices, - const std::string& expected_indices) { - auto dict = ArrayFromJSON(value_type(), dictionary_values); - auto type = dictionary(int8(), value_type()); - ASSERT_OK_AND_ASSIGN(auto values, - DictionaryArray::FromArrays( - type, ArrayFromJSON(int8(), dictionary_indices), dict)); - ASSERT_OK_AND_ASSIGN( - auto expected, - DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_indices), dict)); - auto take_indices = ArrayFromJSON(int8(), indices); - AssertTakeArrays(values, take_indices, expected); + void AssertTakeXADictionary(const std::string& dictionary_values, + const std::string& dictionary_indices, + const std::string& indices, + const std::string& expected_indices) { + return CheckTakeXADictionary(this->value_type(), dictionary_values, + dictionary_indices, indices, expected_indices); } }; TYPED_TEST_SUITE(TestTakeKernelWithString, BaseBinaryArrowTypes); TYPED_TEST(TestTakeKernelWithString, TakeString) { - this->AssertTake(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["a", "b", "a"])"); - this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", "[null, \"b\", null]"); - this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b", "a"])"); + this->CheckTakeXA(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["a", "b", "a"])"); + this->CheckTakeXA(R"([null, "b", "c"])", "[0, 1, 0]", "[null, \"b\", null]"); + this->CheckTakeXA(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b", "a"])"); - this->TestNoValidityBitmapButUnknownNullCount(this->value_type(), R"(["a", "b", "c"])", - "[0, 1, 0]"); + this->TestNoValidityBitmapButUnknownNullCount(R"(["a", "b", "c"])", "[0, 1, 0]"); std::shared_ptr type = this->value_type(); + const std::string kABC = R"(["a", "b", "c"])"; std::shared_ptr arr; - ASSERT_RAISES(IndexError, - TakeJSON(type, R"(["a", "b", "c"])", int8(), "[0, 9, 0]", &arr)); - ASSERT_RAISES(IndexError, TakeJSON(type, R"(["a", "b", null, "ddd", "ee"])", int64(), - "[2, 5]", &arr)); + ASSERT_RAISES(IndexError, TakeAAA(type, kABC, "[0, 9, 0]").Value(&arr)); + ASSERT_RAISES(IndexError, TakeAAA(type, kABC, "[2, 5]").Value(&arr)); + Datum chunked_arr; + ASSERT_RAISES(IndexError, TakeCAC(type, {kABC, kABC}, "[0, 9, 0]").Value(&chunked_arr)); + ASSERT_RAISES(IndexError, TakeCAC(type, {kABC, kABC}, "[4, 10]").Value(&chunked_arr)); } TYPED_TEST(TestTakeKernelWithString, TakeDictionary) 
{ auto dict = R"(["a", "b", "c", "d", "e"])"; - this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[3, 4, 3]"); - this->AssertTakeDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[null, 4, null]"); - this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4, 3]"); + this->AssertTakeXADictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[3, 4, 3]"); + this->AssertTakeXADictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[null, 4, null]"); + this->AssertTakeXADictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4, 3]"); } class TestTakeKernelFSB : public TestTakeKernelTyped { public: - std::shared_ptr value_type() { return fixed_size_binary(3); } - - void AssertTake(const std::string& values, const std::string& indices, - const std::string& expected) { - CheckTake(value_type(), values, indices, expected); - } + std::shared_ptr value_type() const override { return fixed_size_binary(3); } }; TEST_F(TestTakeKernelFSB, TakeFixedSizeBinary) { - this->AssertTake(R"(["aaa", "bbb", "ccc"])", "[0, 1, 0]", R"(["aaa", "bbb", "aaa"])"); - this->AssertTake(R"([null, "bbb", "ccc"])", "[0, 1, 0]", "[null, \"bbb\", null]"); - this->AssertTake(R"(["aaa", "bbb", "ccc"])", "[null, 1, 0]", R"([null, "bbb", "aaa"])"); + const std::string kABC = R"(["aaa", "bbb", "ccc"])"; + this->CheckTakeXA(kABC, "[0, 1, 0]", R"(["aaa", "bbb", "aaa"])"); + this->CheckTakeXA(R"([null, "bbb", "ccc"])", "[0, 1, 0]", "[null, \"bbb\", null]"); + this->CheckTakeXA(kABC, "[null, 1, 0]", R"([null, "bbb", "aaa"])"); - this->TestNoValidityBitmapButUnknownNullCount(this->value_type(), - R"(["aaa", "bbb", "ccc"])", "[0, 1, 0]"); + this->TestNoValidityBitmapButUnknownNullCount(kABC, "[0, 1, 0]"); std::shared_ptr type = this->value_type(); + const std::string kABNullDE = R"(["aaa", "bbb", null, "ddd", "eee"])"; std::shared_ptr arr; + ASSERT_RAISES(IndexError, TakeAAA(type, kABC, "[0, 9, 0]").Value(&arr)); + ASSERT_RAISES(IndexError, TakeAAA(type, kABNullDE, "[2, 5]").Value(&arr)); + Datum chunked_arr; + ASSERT_RAISES(IndexError, TakeCAC(type, {kABC, kABC}, "[0, 9, 0]").Value(&chunked_arr)); ASSERT_RAISES(IndexError, - TakeJSON(type, R"(["aaa", "bbb", "ccc"])", int8(), "[0, 9, 0]", &arr)); - ASSERT_RAISES(IndexError, TakeJSON(type, R"(["aaa", "bbb", null, "ddd", "eee"])", - int64(), "[2, 5]", &arr)); + TakeCAC(type, {kABNullDE, kABC}, "[4, 10]").Value(&chunked_arr)); } -class TestTakeKernelWithList : public TestTakeKernelTyped {}; +using ListAndListViewArrowTypes = + ::testing::Types; + +template +class TestTakeKernelWithList : public TestTakeKernelTyped { + protected: + std::shared_ptr inner_type_ = nullptr; + + std::shared_ptr value_type(std::shared_ptr inner_type) const { + return std::make_shared(std::move(inner_type)); + } + + std::shared_ptr value_type() const override { + EXPECT_TRUE(inner_type_); + return value_type(inner_type_); + } + + std::vector> InnerListTypes() const { + return std::vector>{ + list(int32()), + large_list(int32()), + list_view(int32()), + large_list_view(int32()), + }; + } +}; + +TYPED_TEST_SUITE(TestTakeKernelWithList, ListAndListViewArrowTypes); -TEST_F(TestTakeKernelWithList, TakeListInt32) { +TYPED_TEST(TestTakeKernelWithList, TakeListInt32) { + this->inner_type_ = int32(); std::string list_json = "[[], [1,2], null, [3]]"; - for (auto& type : kListAndListViewTypes) { - CheckTake(type, list_json, "[]", "[]"); - CheckTake(type, list_json, "[3, 2, 1]", "[[3], null, [1,2]]"); - CheckTake(type, list_json, "[null, 3, 0]", "[null, [3], []]"); - CheckTake(type, list_json, "[null, null]", "[null, null]"); - 
CheckTake(type, list_json, "[3, 0, 0, 3]", "[[3], [], [], [3]]"); - CheckTake(type, list_json, "[0, 1, 2, 3]", list_json); - CheckTake(type, list_json, "[0, 0, 0, 0, 0, 0, 1]", - "[[], [], [], [], [], [], [1, 2]]"); + { + this->CheckTakeXA(list_json, "[]", "[]"); + this->CheckTakeXA(list_json, "[3, 2, 1]", "[[3], null, [1,2]]"); + this->CheckTakeXA(list_json, "[null, 3, 0]", "[null, [3], []]"); + this->CheckTakeXA(list_json, "[null, null]", "[null, null]"); + this->CheckTakeXA(list_json, "[3, 0, 0, 3]", "[[3], [], [], [3]]"); + this->CheckTakeXA(list_json, "[0, 1, 2, 3]", list_json); + this->CheckTakeXA(list_json, "[0, 0, 0, 0, 0, 0, 1]", + "[[], [], [], [], [], [], [1, 2]]"); - this->TestNoValidityBitmapButUnknownNullCount(type, "[[], [1,2], [3]]", "[0, 1, 0]"); + this->TestNoValidityBitmapButUnknownNullCount("[[], [1,2], [3]]", "[0, 1, 0]"); } } -TEST_F(TestTakeKernelWithList, TakeListListInt32) { +TYPED_TEST(TestTakeKernelWithList, TakeListListInt32) { std::string list_json = R"([ [], [[1], [2, null, 2], []], null, [[3, null], null] ])"; - for (auto& type : kNestedListAndListViewTypes) { - ARROW_SCOPED_TRACE("type = ", *type); - CheckTake(type, list_json, "[]", "[]"); - CheckTake(type, list_json, "[3, 2, 1]", R"([ + for (auto& inner_type : this->InnerListTypes()) { + this->inner_type_ = inner_type; + ARROW_SCOPED_TRACE("type = ", *this->value_type()); + this->CheckTakeXA(list_json, "[]", "[]"); + this->CheckTakeXA(list_json, "[3, 2, 1]", R"([ [[3, null], null], null, [[1], [2, null, 2], []] ])"); - CheckTake(type, list_json, "[null, 3, 0]", R"([ + this->CheckTakeXA(list_json, "[null, 3, 0]", R"([ null, [[3, null], null], [] ])"); - CheckTake(type, list_json, "[null, null]", "[null, null]"); - CheckTake(type, list_json, "[3, 0, 0, 3]", - "[[[3, null], null], [], [], [[3, null], null]]"); - CheckTake(type, list_json, "[0, 1, 2, 3]", list_json); - CheckTake(type, list_json, "[0, 0, 0, 0, 0, 0, 1]", - "[[], [], [], [], [], [], [[1], [2, null, 2], []]]"); + this->CheckTakeXA(list_json, "[null, null]", "[null, null]"); + this->CheckTakeXA(list_json, "[3, 0, 0, 3]", + "[[[3, null], null], [], [], [[3, null], null]]"); + this->CheckTakeXA(list_json, "[0, 1, 2, 3]", list_json); + this->CheckTakeXA(list_json, "[0, 0, 0, 0, 0, 0, 1]", + "[[], [], [], [], [], [], [[1], [2, null, 2], []]]"); this->TestNoValidityBitmapButUnknownNullCount( - type, "[[[1], [2, null, 2], []], [[3, null]]]", "[0, 1, 0]"); + "[[[1], [2, null, 2], []], [[3, null]]]", "[0, 1, 0]"); } } -class TestTakeKernelWithLargeList : public TestTakeKernelTyped {}; - -TEST_F(TestTakeKernelWithLargeList, TakeLargeListInt32) { +TYPED_TEST(TestTakeKernelWithList, TakeLargeListInt32) { + this->inner_type_ = int32(); std::string list_json = "[[], [1,2], null, [3]]"; - for (auto& type : kLargeListAndListViewTypes) { - ARROW_SCOPED_TRACE("type = ", *type); - CheckTake(type, list_json, "[]", "[]"); - CheckTake(type, list_json, "[null, 1, 2, 0]", "[null, [1,2], null, []]"); + { + ARROW_SCOPED_TRACE("type = ", *this->value_type()); + this->CheckTakeXA(list_json, "[]", "[]"); + this->CheckTakeXA(list_json, "[null, 1, 2, 0]", "[null, [1,2], null, []]"); } } class TestTakeKernelWithFixedSizeList : public TestTakeKernelTyped { protected: - void CheckTakeOnNestedLists(const std::shared_ptr& inner_type, - const std::vector& list_sizes, int64_t length) { + std::shared_ptr inner_type_ = nullptr; + + std::shared_ptr value_type() const override { + EXPECT_TRUE(inner_type_); + return fixed_size_list(inner_type_, 3); + } + + void 
CheckTakeXAOnNestedLists(const std::shared_ptr& inner_type, + const std::vector& list_sizes, int64_t length) { using NLG = ::arrow::util::internal::NestedListGenerator; // Create two equivalent lists: one as a FixedSizeList and another as a List. ASSERT_OK_AND_ASSIGN(auto fsl_list, @@ -1544,51 +1816,50 @@ class TestTakeKernelWithFixedSizeList : public TestTakeKernelTypedtype())); - DoCheckTake(fsl_list, indices, expected_fsl); + DoCheckTakeXA(fsl_list, indices, expected_fsl); } }; TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) { + inner_type_ = int32(); std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]"; - CheckTake(fixed_size_list(int32(), 3), list_json, "[]", "[]"); - CheckTake(fixed_size_list(int32(), 3), list_json, "[3, 2, 1]", - "[[7, 8, null], [4, 5, 6], [1, null, 3]]"); - CheckTake(fixed_size_list(int32(), 3), list_json, "[null, 2, 0]", - "[null, [4, 5, 6], null]"); - CheckTake(fixed_size_list(int32(), 3), list_json, "[null, null]", "[null, null]"); - CheckTake(fixed_size_list(int32(), 3), list_json, "[3, 0, 0, 3]", - "[[7, 8, null], null, null, [7, 8, null]]"); - CheckTake(fixed_size_list(int32(), 3), list_json, "[0, 1, 2, 3]", list_json); + CheckTakeXA(list_json, "[]", "[]"); + CheckTakeXA(list_json, "[3, 2, 1]", "[[7, 8, null], [4, 5, 6], [1, null, 3]]"); + CheckTakeXA(list_json, "[null, 2, 0]", "[null, [4, 5, 6], null]"); + CheckTakeXA(list_json, "[null, null]", "[null, null]"); + CheckTakeXA(list_json, "[3, 0, 0, 3]", "[[7, 8, null], null, null, [7, 8, null]]"); + CheckTakeXA(list_json, "[0, 1, 2, 3]", list_json); // No nulls in inner list values trigger the use of FixedWidthTakeExec() in // FSLTakeExec() std::string no_nulls_list_json = "[[0, 0, 0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]"; - CheckTake( - fixed_size_list(int32(), 3), no_nulls_list_json, "[2, 2, 2, 2, 2, 2, 1]", + CheckTakeXA( + no_nulls_list_json, "[2, 2, 2, 2, 2, 2, 1]", "[[4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [1, 2, 3]]"); - this->TestNoValidityBitmapButUnknownNullCount(fixed_size_list(int32(), 3), - "[[1, null, 3], [4, 5, 6], [7, 8, null]]", + this->TestNoValidityBitmapButUnknownNullCount("[[1, null, 3], [4, 5, 6], [7, 8, null]]", "[0, 1, 0]"); } TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListVarWidth) { + inner_type_ = utf8(); std::string list_json = R"([["zero", "one", ""], ["two", "", "three"], ["four", "five", "six"], ["seven", "eight", ""]])"; - CheckTake(fixed_size_list(utf8(), 3), list_json, "[]", "[]"); - CheckTake(fixed_size_list(utf8(), 3), list_json, "[3, 2, 1]", - R"([["seven", "eight", ""], ["four", "five", "six"], ["two", "", "three"]])"); - CheckTake(fixed_size_list(utf8(), 3), list_json, "[null, 2, 0]", - R"([null, ["four", "five", "six"], ["zero", "one", ""]])"); - CheckTake(fixed_size_list(utf8(), 3), list_json, R"([null, null])", "[null, null]"); - CheckTake( - fixed_size_list(utf8(), 3), list_json, "[3, 0, 0,3]", + CheckTakeXA(list_json, "[]", "[]"); + CheckTakeXA( + list_json, "[3, 2, 1]", + R"([["seven", "eight", ""], ["four", "five", "six"], ["two", "", "three"]])"); + CheckTakeXA(list_json, "[null, 2, 0]", + R"([null, ["four", "five", "six"], ["zero", "one", ""]])"); + CheckTakeXA(list_json, R"([null, null])", "[null, null]"); + CheckTakeXA( + list_json, "[3, 0, 0,3]", R"([["seven", "eight", ""], ["zero", "one", ""], ["zero", "one", ""], ["seven", "eight", ""]])"); - CheckTake(fixed_size_list(utf8(), 3), list_json, "[0, 1, 2, 3]", list_json); - CheckTake(fixed_size_list(utf8(), 3), list_json, "[2, 2, 2, 2, 2, 2, 
1]", - R"([ + CheckTakeXA(list_json, "[0, 1, 2, 3]", list_json); + CheckTakeXA(list_json, "[2, 2, 2, 2, 2, 2, 1]", + R"([ ["four", "five", "six"], ["four", "five", "six"], ["four", "five", "six"], ["four", "five", "six"], ["four", "five", "six"], ["four", "five", "six"], @@ -1606,11 +1877,14 @@ TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListModuloNesting) { NLG::VisitAllNestedListConfigurations( value_types, [this](const std::shared_ptr& inner_type, const std::vector& list_sizes) { - this->CheckTakeOnNestedLists(inner_type, list_sizes, /*length=*/5); + this->CheckTakeXAOnNestedLists(inner_type, list_sizes, /*length=*/5); }); } -class TestTakeKernelWithMap : public TestTakeKernelTyped {}; +class TestTakeKernelWithMap : public TestTakeKernelTyped { + protected: + std::shared_ptr value_type() const override { return map(utf8(), int32()); } +}; TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) { std::string map_json = R"([ @@ -1619,21 +1893,20 @@ TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) { [["cap", 8]], [] ])"; - CheckTake(map(utf8(), int32()), map_json, "[]", "[]"); - CheckTake(map(utf8(), int32()), map_json, "[3, 1, 3, 1, 3]", - "[[], null, [], null, []]"); - CheckTake(map(utf8(), int32()), map_json, "[2, 1, null]", R"([ + CheckTakeXA(map_json, "[]", "[]"); + CheckTakeXA(map_json, "[3, 1, 3, 1, 3]", "[[], null, [], null, []]"); + CheckTakeXA(map_json, "[2, 1, null]", R"([ [["cap", 8]], null, null ])"); - CheckTake(map(utf8(), int32()), map_json, "[2, 1, 0]", R"([ + CheckTakeXA(map_json, "[2, 1, 0]", R"([ [["cap", 8]], null, [["joe", 0], ["mark", null]] ])"); - CheckTake(map(utf8(), int32()), map_json, "[0, 1, 2, 3]", map_json); - CheckTake(map(utf8(), int32()), map_json, "[0, 0, 0, 0, 0, 0, 3]", R"([ + CheckTakeXA(map_json, "[0, 1, 2, 3]", map_json); + CheckTakeXA(map_json, "[0, 0, 0, 0, 0, 0, 3]", R"([ [["joe", 0], ["mark", null]], [["joe", 0], ["mark", null]], [["joe", 0], ["mark", null]], @@ -1644,31 +1917,34 @@ TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) { ])"); } -class TestTakeKernelWithStruct : public TestTakeKernelTyped {}; +class TestTakeKernelWithStruct : public TestTakeKernelTyped { + std::shared_ptr value_type() const override { + return struct_({field("a", int32()), field("b", utf8())}); + } +}; TEST_F(TestTakeKernelWithStruct, TakeStruct) { - auto struct_type = struct_({field("a", int32()), field("b", utf8())}); auto struct_json = R"([ null, {"a": 1, "b": ""}, {"a": 2, "b": "hello"}, {"a": 4, "b": "eh"} ])"; - CheckTake(struct_type, struct_json, "[]", "[]"); - CheckTake(struct_type, struct_json, "[3, 1, 3, 1, 3]", R"([ + this->CheckTakeXA(struct_json, "[]", "[]"); + this->CheckTakeXA(struct_json, "[3, 1, 3, 1, 3]", R"([ {"a": 4, "b": "eh"}, {"a": 1, "b": ""}, {"a": 4, "b": "eh"}, {"a": 1, "b": ""}, {"a": 4, "b": "eh"} ])"); - CheckTake(struct_type, struct_json, "[3, 1, 0]", R"([ + this->CheckTakeXA(struct_json, "[3, 1, 0]", R"([ {"a": 4, "b": "eh"}, {"a": 1, "b": ""}, null ])"); - CheckTake(struct_type, struct_json, "[0, 1, 2, 3]", struct_json); - CheckTake(struct_type, struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([ + this->CheckTakeXA(struct_json, "[0, 1, 2, 3]", struct_json); + this->CheckTakeXA(struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([ null, {"a": 2, "b": "hello"}, {"a": 2, "b": "hello"}, @@ -1678,16 +1954,30 @@ TEST_F(TestTakeKernelWithStruct, TakeStruct) { {"a": 2, "b": "hello"} ])"); - this->TestNoValidityBitmapButUnknownNullCount( - struct_type, R"([{"a": 1}, {"a": 2, "b": "hello"}])", "[0, 1, 0]"); + 
this->TestNoValidityBitmapButUnknownNullCount(R"([{"a": 1}, {"a": 2, "b": "hello"}])", + "[0, 1, 0]"); } -class TestTakeKernelWithUnion : public TestTakeKernelTyped {}; +template +class TestTakeKernelWithUnion : public TestTakeKernelTyped { + protected: + std::shared_ptr value_type() const override { + return std::make_shared( + FieldVector{ + field("a", int32()), + field("b", utf8()), + }, + std::vector{ + 2, + 5, + }); + } +}; + +TYPED_TEST_SUITE(TestTakeKernelWithUnion, UnionArrowTypes); -TEST_F(TestTakeKernelWithUnion, TakeUnion) { - for (const auto& union_type : - {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}), - sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) { +TYPED_TEST(TestTakeKernelWithUnion, TakeUnion) { + { auto union_json = R"([ [2, 222], [2, null], @@ -1697,22 +1987,22 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) { [2, 111], [5, null] ])"; - CheckTake(union_type, union_json, "[]", "[]"); - CheckTake(union_type, union_json, "[3, 0, 3, 0, 3]", R"([ + this->CheckTakeXA(union_json, "[]", "[]"); + this->CheckTakeXA(union_json, "[3, 0, 3, 0, 3]", R"([ [5, "eh"], [2, 222], [5, "eh"], [2, 222], [5, "eh"] ])"); - CheckTake(union_type, union_json, "[4, 2, 0, 6]", R"([ + this->CheckTakeXA(union_json, "[4, 2, 0, 6]", R"([ [2, null], [5, "hello"], [2, 222], [5, null] ])"); - CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json); - CheckTake(union_type, union_json, "[1, 2, 2, 2, 2, 2, 2]", R"([ + this->CheckTakeXA(union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json); + this->CheckTakeXA(union_json, "[1, 2, 2, 2, 2, 2, 2]", R"([ [2, null], [5, "hello"], [5, "hello"], @@ -1721,7 +2011,7 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) { [5, "hello"], [5, "hello"] ])"); - CheckTake(union_type, union_json, "[0, null, 1, null, 2, 2, 2]", R"([ + this->CheckTakeXA(union_json, "[0, null, 1, null, 2, 2, 2]", R"([ [2, 222], [2, null], [2, null], @@ -1735,72 +2025,58 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) { class TestPermutationsWithTake : public ::testing::Test { protected: - void DoTake(const Int16Array& values, const Int16Array& indices, - std::shared_ptr* out) { - ASSERT_OK_AND_ASSIGN(std::shared_ptr boxed_out, Take(values, indices)); + Result> DoTakeAAA( + const std::shared_ptr& values, + const std::shared_ptr& indices) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr boxed_out, TakeAAA(*values, *indices)); ValidateOutput(boxed_out); - *out = checked_pointer_cast(std::move(boxed_out)); + return checked_pointer_cast(std::move(boxed_out)); } - std::shared_ptr DoTake(const Int16Array& values, - const Int16Array& indices) { - std::shared_ptr out; - DoTake(values, indices, &out); - return out; - } - - std::shared_ptr DoTakeN(uint64_t n, std::shared_ptr array) { + Result> DoTakeN(uint64_t n, + std::shared_ptr array) { auto power_of_2 = array; - array = Identity(array->length()); + ARROW_ASSIGN_OR_RAISE(array, Identity(array->length())); while (n != 0) { if (n & 1) { - array = DoTake(*array, *power_of_2); + ARROW_ASSIGN_OR_RAISE(array, DoTakeAAA(array, power_of_2)); } - power_of_2 = DoTake(*power_of_2, *power_of_2); + ARROW_ASSIGN_OR_RAISE(power_of_2, DoTakeAAA(power_of_2, power_of_2)); n >>= 1; } return array; } template - void Shuffle(const Int16Array& array, Rng& gen, std::shared_ptr* shuffled) { + Result> Shuffle(const Int16Array& array, Rng& gen) { auto byte_length = array.length() * sizeof(int16_t); - ASSERT_OK_AND_ASSIGN(auto data, array.values()->CopySlice(0, byte_length)); + ARROW_ASSIGN_OR_RAISE(auto data, array.values()->CopySlice(0, 
byte_length)); auto mutable_data = reinterpret_cast(data->mutable_data()); std::shuffle(mutable_data, mutable_data + array.length(), gen); - shuffled->reset(new Int16Array(array.length(), data)); - } - - template - std::shared_ptr Shuffle(const Int16Array& array, Rng& gen) { - std::shared_ptr out; - Shuffle(array, gen, &out); - return out; + return std::make_shared(array.length(), data); } - void Identity(int64_t length, std::shared_ptr* identity) { + Result> Identity(int64_t length) { + std::shared_ptr identity; Int16Builder identity_builder; - ASSERT_OK(identity_builder.Resize(length)); + RETURN_NOT_OK(identity_builder.Resize(length)); for (int16_t i = 0; i < length; ++i) { identity_builder.UnsafeAppend(i); } - ASSERT_OK(identity_builder.Finish(identity)); - } - - std::shared_ptr Identity(int64_t length) { - std::shared_ptr out; - Identity(length, &out); - return out; + RETURN_NOT_OK(identity_builder.Finish(&identity)); + return identity; } - std::shared_ptr Inverse(const std::shared_ptr& permutation) { + Result> Inverse( + const std::shared_ptr& permutation) { auto length = static_cast(permutation->length()); std::vector cycle_lengths(length + 1, false); auto permutation_to_the_i = permutation; for (int16_t cycle_length = 1; cycle_length <= length; ++cycle_length) { cycle_lengths[cycle_length] = HasTrivialCycle(*permutation_to_the_i); - permutation_to_the_i = DoTake(*permutation, *permutation_to_the_i); + ARROW_ASSIGN_OR_RAISE(permutation_to_the_i, + DoTakeAAA(permutation, permutation_to_the_i)); } uint64_t cycle_to_identity_length = 1; @@ -1836,42 +2112,18 @@ TEST_F(TestPermutationsWithTake, InvertPermutation) { for (auto seed : std::vector({0, kRandomSeed, kRandomSeed * 2 - 1})) { std::default_random_engine gen(seed); for (int16_t length = 0; length < 1 << 10; ++length) { - auto identity = Identity(length); - auto permutation = Shuffle(*identity, gen); - auto inverse = Inverse(permutation); + ASSERT_OK_AND_ASSIGN(auto identity, Identity(length)); + ASSERT_OK_AND_ASSIGN(auto permutation, Shuffle(*identity, gen)); + ASSERT_OK_AND_ASSIGN(auto inverse, Inverse(permutation)); if (inverse == nullptr) { break; } - ASSERT_TRUE(DoTake(*inverse, *permutation)->Equals(identity)); + DoCheckTakeXA(inverse, permutation, identity); } } } -class TestTakeKernelWithRecordBatch : public TestTakeKernelTyped { - public: - void AssertTake(const std::shared_ptr& schm, const std::string& batch_json, - const std::string& indices, const std::string& expected_batch) { - std::shared_ptr actual; - - for (auto index_type : {int8(), uint32()}) { - ASSERT_OK(TakeJSON(schm, batch_json, index_type, indices, &actual)); - ValidateOutput(actual); - ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual); - } - } - - Status TakeJSON(const std::shared_ptr& schm, const std::string& batch_json, - const std::shared_ptr& index_type, const std::string& indices, - std::shared_ptr* out) { - auto batch = RecordBatchFromJSON(schm, batch_json); - ARROW_ASSIGN_OR_RAISE(Datum result, - Take(Datum(batch), Datum(ArrayFromJSON(index_type, indices)))); - *out = result.record_batch(); - return Status::OK(); - } -}; - -TEST_F(TestTakeKernelWithRecordBatch, TakeRecordBatch) { +TEST(TestTakeKernelWithRecordBatch, TakeRecordBatch) { std::vector> fields = {field("a", int32()), field("b", utf8())}; auto schm = schema(fields); @@ -1881,21 +2133,21 @@ TEST_F(TestTakeKernelWithRecordBatch, TakeRecordBatch) { {"a": 2, "b": "hello"}, {"a": 4, "b": "eh"} ])"; - this->AssertTake(schm, struct_json, "[]", "[]"); - this->AssertTake(schm, 
struct_json, "[3, 1, 3, 1, 3]", R"([ + AssertTakeRAR(schm, struct_json, "[]", "[]"); + AssertTakeRAR(schm, struct_json, "[3, 1, 3, 1, 3]", R"([ {"a": 4, "b": "eh"}, {"a": 1, "b": ""}, {"a": 4, "b": "eh"}, {"a": 1, "b": ""}, {"a": 4, "b": "eh"} ])"); - this->AssertTake(schm, struct_json, "[3, 1, 0]", R"([ + AssertTakeRAR(schm, struct_json, "[3, 1, 0]", R"([ {"a": 4, "b": "eh"}, {"a": 1, "b": ""}, {"a": null, "b": "yo"} ])"); - this->AssertTake(schm, struct_json, "[0, 1, 2, 3]", struct_json); - this->AssertTake(schm, struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([ + AssertTakeRAR(schm, struct_json, "[0, 1, 2, 3]", struct_json); + AssertTakeRAR(schm, struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([ {"a": null, "b": "yo"}, {"a": 2, "b": "hello"}, {"a": 2, "b": "hello"}, @@ -1906,115 +2158,41 @@ TEST_F(TestTakeKernelWithRecordBatch, TakeRecordBatch) { ])"); } -class TestTakeKernelWithChunkedArray : public TestTakeKernelTyped { - public: - void AssertTake(const std::shared_ptr& type, - const std::vector& values, const std::string& indices, - const std::vector& expected) { - std::shared_ptr actual; - ASSERT_OK(this->TakeWithArray(type, values, indices, &actual)); - ValidateOutput(actual); - AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual); +TEST(TestTakeKernelWithChunkedIndices, TakeChunkedArray) { + for (auto& ty : {boolean(), int8(), uint64()}) { + AssertTakeCAC(ty, {"[]"}, "[]", {"[]"}); + AssertTakeCCC(ty, {}, {}, {}); + AssertTakeCCC(ty, {}, {"[]"}, {"[]"}); + AssertTakeCCC(ty, {}, {"[null]"}, {"[null]"}); + AssertTakeCCC(ty, {"[]"}, {}, {}); + AssertTakeCCC(ty, {"[]"}, {"[]"}, {"[]"}); + AssertTakeCCC(ty, {"[]"}, {"[null]"}, {"[null]"}); } - void AssertChunkedTake(const std::shared_ptr& type, - const std::vector& values, - const std::vector& indices, - const std::vector& expected) { - std::shared_ptr actual; - ASSERT_OK(this->TakeWithChunkedArray(type, values, indices, &actual)); - ValidateOutput(actual); - AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual); - } + AssertTakeCAC(boolean(), {"[true]", "[false, true]"}, "[0, 1, 0, 2]", + {"[true, false, true, true]"}); + AssertTakeCCC(boolean(), {"[false]", "[true, false]"}, {"[0, 1, 0]", "[]", "[2]"}, + {"[false, true, false]", "[]", "[false]"}); + AssertTakeCAC(boolean(), {"[true]", "[false, true]"}, "[2, 1]", {"[true, false]"}); - Status TakeWithArray(const std::shared_ptr& type, - const std::vector& values, const std::string& indices, - std::shared_ptr* out) { - ARROW_ASSIGN_OR_RAISE(Datum result, Take(ChunkedArrayFromJSON(type, values), - ArrayFromJSON(int8(), indices))); - *out = result.chunked_array(); - return Status::OK(); - } + Datum chunked_arr; + for (auto& int_ty : SignedIntTypes()) { + AssertTakeCAC(int_ty, {"[7]", "[8, 9]"}, "[0, 1, 0, 2]", {"[7, 8, 7, 9]"}); + AssertTakeCCC(int_ty, {"[7]", "[8, 9]"}, {"[0, 1, 0]", "[]", "[2]"}, + {"[7, 8, 7]", "[]", "[9]"}); + AssertTakeCAC(int_ty, {"[7]", "[8, 9]"}, "[2, 1]", {"[9, 8]"}); - Status TakeWithChunkedArray(const std::shared_ptr& type, - const std::vector& values, - const std::vector& indices, - std::shared_ptr* out) { - ARROW_ASSIGN_OR_RAISE(Datum result, Take(ChunkedArrayFromJSON(type, values), - ChunkedArrayFromJSON(int8(), indices))); - *out = result.chunked_array(); - return Status::OK(); + ASSERT_RAISES(IndexError, + TakeCAC(int_ty, {"[7]", "[8, 9]"}, "[0, 5]").Value(&chunked_arr)); + ASSERT_RAISES( + IndexError, + TakeCCC(int_ty, {"[7]", "[8, 9]"}, {"[0, 1, 0]", "[5, 1]"}).Value(&chunked_arr)); + ASSERT_RAISES(IndexError, TakeCCC(int_ty, {}, 
{"[0]"}).Value(&chunked_arr)); + ASSERT_RAISES(IndexError, TakeCCC(int_ty, {"[]"}, {"[0]"}).Value(&chunked_arr)); } -}; - -TEST_F(TestTakeKernelWithChunkedArray, TakeChunkedArray) { - this->AssertTake(int8(), {"[]"}, "[]", {"[]"}); - this->AssertChunkedTake(int8(), {}, {}, {}); - this->AssertChunkedTake(int8(), {}, {"[]"}, {"[]"}); - this->AssertChunkedTake(int8(), {}, {"[null]"}, {"[null]"}); - this->AssertChunkedTake(int8(), {"[]"}, {}, {}); - this->AssertChunkedTake(int8(), {"[]"}, {"[]"}, {"[]"}); - this->AssertChunkedTake(int8(), {"[]"}, {"[null]"}, {"[null]"}); - - this->AssertTake(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0, 2]", {"[7, 8, 7, 9]"}); - this->AssertChunkedTake(int8(), {"[7]", "[8, 9]"}, {"[0, 1, 0]", "[]", "[2]"}, - {"[7, 8, 7]", "[]", "[9]"}); - this->AssertTake(int8(), {"[7]", "[8, 9]"}, "[2, 1]", {"[9, 8]"}); - - std::shared_ptr arr; - ASSERT_RAISES(IndexError, - this->TakeWithArray(int8(), {"[7]", "[8, 9]"}, "[0, 5]", &arr)); - ASSERT_RAISES(IndexError, this->TakeWithChunkedArray(int8(), {"[7]", "[8, 9]"}, - {"[0, 1, 0]", "[5, 1]"}, &arr)); - ASSERT_RAISES(IndexError, this->TakeWithChunkedArray(int8(), {}, {"[0]"}, &arr)); - ASSERT_RAISES(IndexError, this->TakeWithChunkedArray(int8(), {"[]"}, {"[0]"}, &arr)); } -class TestTakeKernelWithTable : public TestTakeKernelTyped { - public: - void AssertTake(const std::shared_ptr& schm, - const std::vector& table_json, const std::string& filter, - const std::vector& expected_table) { - std::shared_ptr
actual; - - ASSERT_OK(this->TakeWithArray(schm, table_json, filter, &actual)); - ValidateOutput(actual); - ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual); - } - - void AssertChunkedTake(const std::shared_ptr& schm, - const std::vector& table_json, - const std::vector& filter, - const std::vector& expected_table) { - std::shared_ptr
actual; - - ASSERT_OK(this->TakeWithChunkedArray(schm, table_json, filter, &actual)); - ValidateOutput(actual); - ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual); - } - - Status TakeWithArray(const std::shared_ptr& schm, - const std::vector& values, const std::string& indices, - std::shared_ptr
* out) { - ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(TableFromJSON(schm, values)), - Datum(ArrayFromJSON(int8(), indices)))); - *out = result.table(); - return Status::OK(); - } - - Status TakeWithChunkedArray(const std::shared_ptr& schm, - const std::vector& values, - const std::vector& indices, - std::shared_ptr
* out) { - ARROW_ASSIGN_OR_RAISE(Datum result, - Take(Datum(TableFromJSON(schm, values)), - Datum(ChunkedArrayFromJSON(int8(), indices)))); - *out = result.table(); - return Status::OK(); - } -}; - -TEST_F(TestTakeKernelWithTable, TakeTable) { +TEST(TestTakeKernelWithTable, TakeTable) { std::vector> fields = {field("a", int32()), field("b", utf8())}; auto schm = schema(fields); @@ -2022,11 +2200,12 @@ TEST_F(TestTakeKernelWithTable, TakeTable) { "[{\"a\": null, \"b\": \"yo\"},{\"a\": 1, \"b\": \"\"}]", "[{\"a\": 2, \"b\": \"hello\"},{\"a\": 4, \"b\": \"eh\"}]"}; - this->AssertTake(schm, table_json, "[]", {"[]"}); + AssertTakeTAT(schm, table_json, "[]", {"[]"}); std::vector expected_310 = { - "[{\"a\": 4, \"b\": \"eh\"},{\"a\": 1, \"b\": \"\"},{\"a\": null, \"b\": \"yo\"}]"}; - this->AssertTake(schm, table_json, "[3, 1, 0]", expected_310); - this->AssertChunkedTake(schm, table_json, {"[0, 1]", "[2, 3]"}, table_json); + "[{\"a\": 4, \"b\": \"eh\"},{\"a\": 1, \"b\": \"\"},{\"a\": null, \"b\": " + "\"yo\"}]"}; + AssertTakeTAT(schm, table_json, "[3, 1, 0]", expected_310); + AssertTakeTCT(schm, table_json, {"[0, 1]", "[2, 3]"}, table_json); } TEST(TestTakeMetaFunction, ArityChecking) { @@ -2066,14 +2245,14 @@ void CheckTakeRandom(const std::shared_ptr& values, int64_t indices_lengt max_index, null_probability); auto indices_no_nulls = rand->Numeric( indices_length, static_cast(0), max_index, /*null_probability=*/0.0); - ValidateTake(values, indices); - ValidateTake(values, indices_no_nulls); + ValidateTakeXA(values, indices); + ValidateTakeXA(values, indices_no_nulls); // Sliced indices array if (indices_length >= 2) { indices = indices->Slice(1, indices_length - 2); indices_no_nulls = indices_no_nulls->Slice(1, indices_length - 2); - ValidateTake(values, indices); - ValidateTake(values, indices_no_nulls); + ValidateTakeXA(values, indices); + ValidateTakeXA(values, indices_no_nulls); } } diff --git a/cpp/src/arrow/filesystem/s3fs.cc b/cpp/src/arrow/filesystem/s3fs.cc index 99cee19ed1e78..fd5b2e5be2a3a 100644 --- a/cpp/src/arrow/filesystem/s3fs.cc +++ b/cpp/src/arrow/filesystem/s3fs.cc @@ -51,6 +51,7 @@ #include #include #include +#include #include #include #include @@ -74,6 +75,7 @@ #include #include #include +#include #include // AWS_SDK_VERSION_{MAJOR,MINOR,PATCH} are available since 1.9.7. @@ -1335,7 +1337,7 @@ struct ObjectMetadataSetter { static std::unordered_map GetSetters() { return {{"ACL", CannedACLSetter()}, {"Cache-Control", StringSetter(&ObjectRequest::SetCacheControl)}, - {"Content-Type", StringSetter(&ObjectRequest::SetContentType)}, + {"Content-Type", ContentTypeSetter()}, {"Content-Language", StringSetter(&ObjectRequest::SetContentLanguage)}, {"Expires", DateTimeSetter(&ObjectRequest::SetExpires)}}; } @@ -1365,6 +1367,16 @@ struct ObjectMetadataSetter { }; } + /** We need a special setter here and cannot use `StringSetter` because, e.g. for + * `PutObjectRequest`, the setter is located in the base class (instead of the concrete + * class).
*/ + static Setter ContentTypeSetter() { + return [](const std::string& str, ObjectRequest* req) { + req->SetContentType(str); + return Status::OK(); + }; + } + static Result ParseACL(const std::string& v) { if (v.empty()) { return S3Model::ObjectCannedACL::NOT_SET; @@ -1583,6 +1595,15 @@ class ObjectInputFile final : public io::RandomAccessFile { // (for rational, see: https://github.com/apache/arrow/issues/34363) static constexpr int64_t kPartUploadSize = 10 * 1024 * 1024; +// Above this threshold, use a multi-part upload instead of a single-request upload. Only +// relevant if early validation of writing to the bucket is disabled (see +// `allow_delayed_open`). +static constexpr int64_t kMultiPartUploadThresholdSize = kPartUploadSize - 1; + +static_assert(kMultiPartUploadThresholdSize < kPartUploadSize, + "Multi-part upload threshold size must be strictly less than the actual " + "multi-part upload part size."); + // An OutputStream that writes to a S3 object class ObjectOutputStream final : public io::OutputStream { protected: @@ -1598,7 +1619,8 @@ class ObjectOutputStream final : public io::OutputStream { path_(path), metadata_(metadata), default_metadata_(options.default_metadata), - background_writes_(options.background_writes) {} + background_writes_(options.background_writes), + allow_delayed_open_(options.allow_delayed_open) {} ~ObjectOutputStream() override { // For compliance with the rest of the IO stack, Close rather than Abort, @@ -1606,29 +1628,47 @@ class ObjectOutputStream final : public io::OutputStream { io::internal::CloseFromDestructor(this); } + template + Status SetMetadataInRequest(ObjectRequest* request) { + std::shared_ptr metadata; + + if (metadata_ && metadata_->size() != 0) { + metadata = metadata_; + } else if (default_metadata_ && default_metadata_->size() != 0) { + metadata = default_metadata_; + } + + bool is_content_type_set{false}; + if (metadata) { + RETURN_NOT_OK(SetObjectMetadata(metadata, request)); + + is_content_type_set = metadata->Contains("Content-Type"); + } + + if (!is_content_type_set) { + // If we do not set anything then the SDK will default to application/xml + // which confuses some tools (https://github.com/apache/arrow/issues/11934) + // So we instead default to application/octet-stream which is less misleading + request->SetContentType("application/octet-stream"); + } + + return Status::OK(); + } + std::shared_ptr Self() { return std::dynamic_pointer_cast(shared_from_this()); } - Status Init() { + Status CreateMultipartUpload() { + DCHECK(ShouldBeMultipartUpload()); + ARROW_ASSIGN_OR_RAISE(auto client_lock, holder_->Lock()); // Initiate the multi-part upload S3Model::CreateMultipartUploadRequest req; req.SetBucket(ToAwsString(path_.bucket)); req.SetKey(ToAwsString(path_.key)); - if (metadata_ && metadata_->size() != 0) { - RETURN_NOT_OK(SetObjectMetadata(metadata_, &req)); - } else if (default_metadata_ && default_metadata_->size() != 0) { - RETURN_NOT_OK(SetObjectMetadata(default_metadata_, &req)); - } - - // If we do not set anything then the SDK will default to application/xml - // which confuses some tools (https://github.com/apache/arrow/issues/11934) - // So we instead default to application/octet-stream which is less misleading - if (!req.ContentTypeHasBeenSet()) { - req.SetContentType("application/octet-stream"); - } + RETURN_NOT_OK(SetMetadataInRequest(&req)); auto outcome = client_lock.Move()->CreateMultipartUpload(req); if (!outcome.IsSuccess()) { @@ -1637,7 +1677,19 @@ class ObjectOutputStream final : public
io::OutputStream { path_.key, "' in bucket '", path_.bucket, "': "), "CreateMultipartUpload", outcome.GetError()); } - upload_id_ = outcome.GetResult().GetUploadId(); + multipart_upload_id_ = outcome.GetResult().GetUploadId(); + + return Status::OK(); + } + + Status Init() { + // If we are allowed to do delayed I/O, we can use a single request to upload the + // data. If not, we use a multi-part upload and initiate it here to + // validate up front that writing to the bucket is possible. + if (!allow_delayed_open_) { + RETURN_NOT_OK(CreateMultipartUpload()); + } + upload_state_ = std::make_shared(); closed_ = false; return Status::OK(); @@ -1648,42 +1700,62 @@ class ObjectOutputStream final : public io::OutputStream { return Status::OK(); } - ARROW_ASSIGN_OR_RAISE(auto client_lock, holder_->Lock()); + if (IsMultipartCreated()) { + ARROW_ASSIGN_OR_RAISE(auto client_lock, holder_->Lock()); - S3Model::AbortMultipartUploadRequest req; - req.SetBucket(ToAwsString(path_.bucket)); - req.SetKey(ToAwsString(path_.key)); - req.SetUploadId(upload_id_); + S3Model::AbortMultipartUploadRequest req; + req.SetBucket(ToAwsString(path_.bucket)); + req.SetKey(ToAwsString(path_.key)); + req.SetUploadId(multipart_upload_id_); - auto outcome = client_lock.Move()->AbortMultipartUpload(req); - if (!outcome.IsSuccess()) { - return ErrorToStatus( - std::forward_as_tuple("When aborting multiple part upload for key '", path_.key, - "' in bucket '", path_.bucket, "': "), - "AbortMultipartUpload", outcome.GetError()); + auto outcome = client_lock.Move()->AbortMultipartUpload(req); + if (!outcome.IsSuccess()) { + return ErrorToStatus( + std::forward_as_tuple("When aborting multiple part upload for key '", + path_.key, "' in bucket '", path_.bucket, "': "), + "AbortMultipartUpload", outcome.GetError()); + } } + current_part_.reset(); holder_ = nullptr; closed_ = true; + return Status::OK(); } // OutputStream interface + bool ShouldBeMultipartUpload() const { + return pos_ > kMultiPartUploadThresholdSize || !allow_delayed_open_; + } + + bool IsMultipartCreated() const { return !multipart_upload_id_.empty(); } + Status EnsureReadyToFlushFromClose() { - if (current_part_) { - // Upload last part - RETURN_NOT_OK(CommitCurrentPart()); - } + if (ShouldBeMultipartUpload()) { + if (current_part_) { + // Upload last part + RETURN_NOT_OK(CommitCurrentPart()); + } - // S3 mandates at least one part, upload an empty one if necessary - if (part_number_ == 1) { - RETURN_NOT_OK(UploadPart("", 0)); + // S3 mandates at least one part, upload an empty one if necessary + if (part_number_ == 1) { + RETURN_NOT_OK(UploadPart("", 0)); + } + } else { + RETURN_NOT_OK(UploadUsingSingleRequest()); } return Status::OK(); } + Status CleanupAfterClose() { + holder_ = nullptr; + closed_ = true; + return Status::OK(); + } + Status FinishPartUploadAfterFlush() { ARROW_ASSIGN_OR_RAISE(auto client_lock, holder_->Lock()); @@ -1697,7 +1769,7 @@ class ObjectOutputStream final : public io::OutputStream { S3Model::CompleteMultipartUploadRequest req; req.SetBucket(ToAwsString(path_.bucket)); req.SetKey(ToAwsString(path_.key)); - req.SetUploadId(upload_id_); + req.SetUploadId(multipart_upload_id_); req.SetMultipartUpload(std::move(completed_upload)); auto outcome = @@ -1709,8 +1781,6 @@ class ObjectOutputStream final : public io::OutputStream { "CompleteMultipartUpload", outcome.GetError()); } - holder_ = nullptr; - closed_ = true; return Status::OK(); } @@ -1720,7 +1790,12 @@ class ObjectOutputStream final : public io::OutputStream {
RETURN_NOT_OK(EnsureReadyToFlushFromClose()); RETURN_NOT_OK(Flush()); - return FinishPartUploadAfterFlush(); + + if (IsMultipartCreated()) { + RETURN_NOT_OK(FinishPartUploadAfterFlush()); + } + + return CleanupAfterClose(); } Future<> CloseAsync() override { @@ -1729,8 +1804,12 @@ class ObjectOutputStream final : public io::OutputStream { RETURN_NOT_OK(EnsureReadyToFlushFromClose()); // Wait for in-progress uploads to finish (if async writes are enabled) - return FlushAsync().Then( - [self = Self()]() { return self->FinishPartUploadAfterFlush(); }); + return FlushAsync().Then([self = Self()]() { + if (self->IsMultipartCreated()) { + RETURN_NOT_OK(self->FinishPartUploadAfterFlush()); + } + return self->CleanupAfterClose(); + }); } bool closed() const override { return closed_; } @@ -1776,7 +1855,8 @@ class ObjectOutputStream final : public io::OutputStream { return Status::OK(); } - // Upload current buffer + // Upload current buffer. We're only reaching this point if we have accumulated + // enough data to upload. RETURN_NOT_OK(CommitCurrentPart()); } @@ -1810,40 +1890,73 @@ class ObjectOutputStream final : public io::OutputStream { } // Wait for background writes to finish std::unique_lock lock(upload_state_->mutex); - return upload_state_->pending_parts_completed; + return upload_state_->pending_uploads_completed; } // Upload-related helpers Status CommitCurrentPart() { + if (!IsMultipartCreated()) { + RETURN_NOT_OK(CreateMultipartUpload()); + } + ARROW_ASSIGN_OR_RAISE(auto buf, current_part_->Finish()); current_part_.reset(); current_part_size_ = 0; return UploadPart(buf); } - Status UploadPart(std::shared_ptr buffer) { - return UploadPart(buffer->data(), buffer->size(), buffer); + Status UploadUsingSingleRequest() { + std::shared_ptr buf; + if (current_part_ == nullptr) { + // In case the stream is closed directly after it has been opened without writing + // anything, we'll have to create an empty buffer. 
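+ // (A zero-byte PutObject is a valid request: it creates an empty object,
+ // which is how a stream that is opened and then closed without any writes
+ // materializes as an empty file.)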
+ buf = std::make_shared(""); + } else { + ARROW_ASSIGN_OR_RAISE(buf, current_part_->Finish()); + } + + current_part_.reset(); + current_part_size_ = 0; + return UploadUsingSingleRequest(buf); } - Status UploadPart(const void* data, int64_t nbytes, - std::shared_ptr owned_buffer = nullptr) { - S3Model::UploadPartRequest req; + template + using UploadResultCallbackFunction = + std::function, + int32_t part_number, OutcomeType outcome)>; + + static Result TriggerUploadRequest( + const Aws::S3::Model::PutObjectRequest& request, + const std::shared_ptr& holder) { + ARROW_ASSIGN_OR_RAISE(auto client_lock, holder->Lock()); + return client_lock.Move()->PutObject(request); + } + + static Result TriggerUploadRequest( + const Aws::S3::Model::UploadPartRequest& request, + const std::shared_ptr& holder) { + ARROW_ASSIGN_OR_RAISE(auto client_lock, holder->Lock()); + return client_lock.Move()->UploadPart(request); + } + + template + Status Upload( + RequestType&& req, + UploadResultCallbackFunction sync_result_callback, + UploadResultCallbackFunction async_result_callback, + const void* data, int64_t nbytes, std::shared_ptr owned_buffer = nullptr) { req.SetBucket(ToAwsString(path_.bucket)); req.SetKey(ToAwsString(path_.key)); - req.SetUploadId(upload_id_); - req.SetPartNumber(part_number_); + req.SetBody(std::make_shared(data, nbytes)); req.SetContentLength(nbytes); if (!background_writes_) { req.SetBody(std::make_shared(data, nbytes)); - ARROW_ASSIGN_OR_RAISE(auto client_lock, holder_->Lock()); - auto outcome = client_lock.Move()->UploadPart(req); - if (!outcome.IsSuccess()) { - return UploadPartError(req, outcome); - } else { - AddCompletedPart(upload_state_, part_number_, outcome.GetResult()); - } + + ARROW_ASSIGN_OR_RAISE(auto outcome, TriggerUploadRequest(req, holder_)); + + RETURN_NOT_OK(sync_result_callback(req, upload_state_, part_number_, outcome)); } else { // If the data isn't owned, make an immutable copy for the lifetime of the closure if (owned_buffer == nullptr) { @@ -1858,19 +1971,18 @@ class ObjectOutputStream final : public io::OutputStream { { std::unique_lock lock(upload_state_->mutex); - if (upload_state_->parts_in_progress++ == 0) { - upload_state_->pending_parts_completed = Future<>::Make(); + if (upload_state_->uploads_in_progress++ == 0) { + upload_state_->pending_uploads_completed = Future<>::Make(); } } // The closure keeps the buffer and the upload state alive auto deferred = [owned_buffer, holder = holder_, req = std::move(req), - state = upload_state_, + state = upload_state_, async_result_callback, part_number = part_number_]() mutable -> Status { - ARROW_ASSIGN_OR_RAISE(auto client_lock, holder->Lock()); - auto outcome = client_lock.Move()->UploadPart(req); - HandleUploadOutcome(state, part_number, req, outcome); - return Status::OK(); + ARROW_ASSIGN_OR_RAISE(auto outcome, TriggerUploadRequest(req, holder)); + + return async_result_callback(req, state, part_number, outcome); }; RETURN_NOT_OK(SubmitIO(io_context_, std::move(deferred))); } @@ -1880,9 +1992,118 @@ class ObjectOutputStream final : public io::OutputStream { return Status::OK(); } - static void HandleUploadOutcome(const std::shared_ptr& state, - int part_number, const S3Model::UploadPartRequest& req, - const Result& result) { + static Status UploadUsingSingleRequestError( + const Aws::S3::Model::PutObjectRequest& request, + const Aws::S3::Model::PutObjectOutcome& outcome) { + return ErrorToStatus( + std::forward_as_tuple("When uploading object with key '", request.GetKey(), + "' in bucket '", 
request.GetBucket(), "': "), + "PutObject", outcome.GetError()); + } + + Status UploadUsingSingleRequest(std::shared_ptr buffer) { + return UploadUsingSingleRequest(buffer->data(), buffer->size(), buffer); + } + + Status UploadUsingSingleRequest(const void* data, int64_t nbytes, + std::shared_ptr owned_buffer = nullptr) { + auto sync_result_callback = [](const Aws::S3::Model::PutObjectRequest& request, + std::shared_ptr state, + int32_t part_number, + Aws::S3::Model::PutObjectOutcome outcome) { + if (!outcome.IsSuccess()) { + return UploadUsingSingleRequestError(request, outcome); + } + return Status::OK(); + }; + + auto async_result_callback = [](const Aws::S3::Model::PutObjectRequest& request, + std::shared_ptr state, + int32_t part_number, + Aws::S3::Model::PutObjectOutcome outcome) { + HandleUploadUsingSingleRequestOutcome(state, request, outcome.GetResult()); + return Status::OK(); + }; + + Aws::S3::Model::PutObjectRequest req{}; + RETURN_NOT_OK(SetMetadataInRequest(&req)); + + return Upload( + std::move(req), std::move(sync_result_callback), std::move(async_result_callback), + data, nbytes, std::move(owned_buffer)); + } + + Status UploadPart(std::shared_ptr buffer) { + return UploadPart(buffer->data(), buffer->size(), buffer); + } + + static Status UploadPartError(const Aws::S3::Model::UploadPartRequest& request, + const Aws::S3::Model::UploadPartOutcome& outcome) { + return ErrorToStatus( + std::forward_as_tuple("When uploading part for key '", request.GetKey(), + "' in bucket '", request.GetBucket(), "': "), + "UploadPart", outcome.GetError()); + } + + Status UploadPart(const void* data, int64_t nbytes, + std::shared_ptr owned_buffer = nullptr) { + if (!IsMultipartCreated()) { + RETURN_NOT_OK(CreateMultipartUpload()); + } + + Aws::S3::Model::UploadPartRequest req{}; + req.SetPartNumber(part_number_); + req.SetUploadId(multipart_upload_id_); + + auto sync_result_callback = [](const Aws::S3::Model::UploadPartRequest& request, + std::shared_ptr state, + int32_t part_number, + Aws::S3::Model::UploadPartOutcome outcome) { + if (!outcome.IsSuccess()) { + return UploadPartError(request, outcome); + } else { + AddCompletedPart(state, part_number, outcome.GetResult()); + } + + return Status::OK(); + }; + + auto async_result_callback = [](const Aws::S3::Model::UploadPartRequest& request, + std::shared_ptr state, + int32_t part_number, + Aws::S3::Model::UploadPartOutcome outcome) { + HandleUploadPartOutcome(state, part_number, request, outcome.GetResult()); + return Status::OK(); + }; + + return Upload( + std::move(req), std::move(sync_result_callback), std::move(async_result_callback), + data, nbytes, std::move(owned_buffer)); + } + + static void HandleUploadUsingSingleRequestOutcome( + const std::shared_ptr& state, const S3Model::PutObjectRequest& req, + const Result& result) { + std::unique_lock lock(state->mutex); + if (!result.ok()) { + state->status &= result.status(); + } else { + const auto& outcome = *result; + if (!outcome.IsSuccess()) { + state->status &= UploadUsingSingleRequestError(req, outcome); + } + } + // GH-41862: avoid potential deadlock if the Future's callback is called + // with the mutex taken. 
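+ // Copying the future out, unlocking, and only then calling MarkFinished()
+ // ensures that a continuation which re-enters the stream cannot deadlock
+ // on state->mutex.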
+ auto fut = state->pending_uploads_completed; + lock.unlock(); + fut.MarkFinished(state->status); + } + + static void HandleUploadPartOutcome(const std::shared_ptr& state, + int part_number, + const S3Model::UploadPartRequest& req, + const Result& result) { std::unique_lock lock(state->mutex); if (!result.ok()) { state->status &= result.status(); @@ -1895,10 +2116,10 @@ class ObjectOutputStream final : public io::OutputStream { } } // Notify completion - if (--state->parts_in_progress == 0) { + if (--state->uploads_in_progress == 0) { // GH-41862: avoid potential deadlock if the Future's callback is called // with the mutex taken. - auto fut = state->pending_parts_completed; + auto fut = state->pending_uploads_completed; lock.unlock(); // State could be mutated concurrently if another thread writes to the // stream, but in this case the Flush() call is only advisory anyway. @@ -1923,14 +2144,6 @@ class ObjectOutputStream final : public io::OutputStream { state->completed_parts[slot] = std::move(part); } - static Status UploadPartError(const S3Model::UploadPartRequest& req, - const S3Model::UploadPartOutcome& outcome) { - return ErrorToStatus( - std::forward_as_tuple("When uploading part for key '", req.GetKey(), - "' in bucket '", req.GetBucket(), "': "), - "UploadPart", outcome.GetError()); - } - protected: std::shared_ptr holder_; const io::IOContext io_context_; @@ -1938,8 +2151,9 @@ class ObjectOutputStream final : public io::OutputStream { const std::shared_ptr metadata_; const std::shared_ptr default_metadata_; const bool background_writes_; + const bool allow_delayed_open_; - Aws::String upload_id_; + Aws::String multipart_upload_id_; bool closed_ = true; int64_t pos_ = 0; int32_t part_number_ = 1; @@ -1950,10 +2164,11 @@ class ObjectOutputStream final : public io::OutputStream { // in the completion handler. struct UploadState { std::mutex mutex; + // Only populated for multi-part uploads. Aws::Vector completed_parts; - int64_t parts_in_progress = 0; + int64_t uploads_in_progress = 0; Status status; - Future<> pending_parts_completed = Future<>::MakeFinished(Status::OK()); + Future<> pending_uploads_completed = Future<>::MakeFinished(Status::OK()); }; std::shared_ptr upload_state_; }; diff --git a/cpp/src/arrow/filesystem/s3fs.h b/cpp/src/arrow/filesystem/s3fs.h index fbbe9d0b3f42b..85d5ff8fed553 100644 --- a/cpp/src/arrow/filesystem/s3fs.h +++ b/cpp/src/arrow/filesystem/s3fs.h @@ -177,6 +177,16 @@ struct ARROW_EXPORT S3Options { /// to be true to address these scenarios. bool check_directory_existence_before_creation = false; + /// Whether to allow file-open methods to return before the actual open. + /// + /// Enabling this may reduce the latency of `OpenInputStream`, `OpenOutputStream`, + /// and similar methods, by reducing the number of roundtrips necessary. It may also + /// allow usage of more efficient S3 APIs for small files. + /// The downside is that failure conditions such as attempting to open a file in a + /// non-existent bucket will only be reported when actual I/O is done (at worst, + /// when attempting to close the file). + bool allow_delayed_open = false; + /// \brief Default metadata for OpenOutputStream. /// /// This will be ignored if non-empty metadata is passed to OpenOutputStream.
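A minimal standalone sketch (not Arrow API, just an illustration distilled from the s3fs.cc and s3fs.h changes above) of how `allow_delayed_open` interacts with the new upload-size threshold; the constants mirror `kPartUploadSize` and `kMultiPartUploadThresholdSize`, and the free function stands in for the `ObjectOutputStream::ShouldBeMultipartUpload()` member:

#include <cstdint>

constexpr int64_t kPartUploadSize = 10 * 1024 * 1024;  // 10 MiB per part
constexpr int64_t kMultiPartUploadThresholdSize = kPartUploadSize - 1;

// With delayed open enabled, anything accumulated up to the threshold is
// written with a single PutObject request at close time; larger streams fall
// back to the multipart API. With delayed open disabled, a multipart upload is
// created eagerly in Init() so that a missing bucket or a permission error
// surfaces when the file is opened rather than when it is closed.
bool ShouldBeMultipartUpload(int64_t pos, bool allow_delayed_open) {
  return pos > kMultiPartUploadThresholdSize || !allow_delayed_open;
}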
diff --git a/cpp/src/arrow/filesystem/s3fs_test.cc b/cpp/src/arrow/filesystem/s3fs_test.cc index 5a160a78ceea0..c33fa4f5aac97 100644 --- a/cpp/src/arrow/filesystem/s3fs_test.cc +++ b/cpp/src/arrow/filesystem/s3fs_test.cc @@ -45,7 +45,9 @@ #include #include #include +#include #include +#include #include #include @@ -450,25 +452,8 @@ class TestS3FS : public S3TestMixin { req.SetBucket(ToAwsString("empty-bucket")); ASSERT_OK(OutcomeToStatus("CreateBucket", client_->CreateBucket(req))); } - { - Aws::S3::Model::PutObjectRequest req; - req.SetBucket(ToAwsString("bucket")); - req.SetKey(ToAwsString("emptydir/")); - req.SetBody(std::make_shared("")); - ASSERT_OK(OutcomeToStatus("PutObject", client_->PutObject(req))); - // NOTE: no need to create intermediate "directories" somedir/ and - // somedir/subdir/ - req.SetKey(ToAwsString("somedir/subdir/subfile")); - req.SetBody(std::make_shared("sub data")); - ASSERT_OK(OutcomeToStatus("PutObject", client_->PutObject(req))); - req.SetKey(ToAwsString("somefile")); - req.SetBody(std::make_shared("some data")); - req.SetContentType("x-arrow/test"); - ASSERT_OK(OutcomeToStatus("PutObject", client_->PutObject(req))); - req.SetKey(ToAwsString("otherdir/1/2/3/otherfile")); - req.SetBody(std::make_shared("other data")); - ASSERT_OK(OutcomeToStatus("PutObject", client_->PutObject(req))); - } + + ASSERT_OK(PopulateTestBucket()); } void TearDown() override { @@ -478,6 +463,72 @@ class TestS3FS : public S3TestMixin { S3TestMixin::TearDown(); } + Status PopulateTestBucket() { + Aws::S3::Model::PutObjectRequest req; + req.SetBucket(ToAwsString("bucket")); + req.SetKey(ToAwsString("emptydir/")); + req.SetBody(std::make_shared("")); + RETURN_NOT_OK(OutcomeToStatus("PutObject", client_->PutObject(req))); + // NOTE: no need to create intermediate "directories" somedir/ and + // somedir/subdir/ + req.SetKey(ToAwsString("somedir/subdir/subfile")); + req.SetBody(std::make_shared("sub data")); + RETURN_NOT_OK(OutcomeToStatus("PutObject", client_->PutObject(req))); + req.SetKey(ToAwsString("somefile")); + req.SetBody(std::make_shared("some data")); + req.SetContentType("x-arrow/test"); + RETURN_NOT_OK(OutcomeToStatus("PutObject", client_->PutObject(req))); + req.SetKey(ToAwsString("otherdir/1/2/3/otherfile")); + req.SetBody(std::make_shared("other data")); + RETURN_NOT_OK(OutcomeToStatus("PutObject", client_->PutObject(req))); + + return Status::OK(); + } + + Status RestoreTestBucket() { + // First empty the test bucket, and then re-upload initial test files. 
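+ // (ListObjectsV2 returns at most 1000 keys per response, hence the
+ // continuation-token loop below when draining larger buckets.)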
+ + Aws::S3::Model::Delete delete_object; + { + // Mostly taken from + // https://github.com/awsdocs/aws-doc-sdk-examples/blob/main/cpp/example_code/s3/list_objects.cpp + Aws::S3::Model::ListObjectsV2Request req; + req.SetBucket(Aws::String{"bucket"}); + + Aws::String continuation_token; + do { + if (!continuation_token.empty()) { + req.SetContinuationToken(continuation_token); + } + + auto outcome = client_->ListObjectsV2(req); + + if (!outcome.IsSuccess()) { + return OutcomeToStatus("ListObjectsV2", outcome); + } else { + Aws::Vector objects = outcome.GetResult().GetContents(); + for (const auto& object : objects) { + delete_object.AddObjects( + Aws::S3::Model::ObjectIdentifier().WithKey(object.GetKey())); + } + + continuation_token = outcome.GetResult().GetNextContinuationToken(); + } + } while (!continuation_token.empty()); + } + + { + Aws::S3::Model::DeleteObjectsRequest req; + + req.SetDelete(std::move(delete_object)); + req.SetBucket(Aws::String{"bucket"}); + + RETURN_NOT_OK(OutcomeToStatus("DeleteObjects", client_->DeleteObjects(req))); + } + + return PopulateTestBucket(); + } + Result> MakeNewFileSystem( io::IOContext io_context = io::default_io_context()) { options_.ConfigureAccessKey(minio_->access_key(), minio_->secret_key()); @@ -518,11 +569,13 @@ class TestS3FS : public S3TestMixin { AssertFileInfo(infos[11], "empty-bucket", FileType::Directory); } - void TestOpenOutputStream() { + void TestOpenOutputStream(bool allow_delayed_open) { std::shared_ptr stream; - // Nonexistent - ASSERT_RAISES(IOError, fs_->OpenOutputStream("nonexistent-bucket/somefile")); + if (!allow_delayed_open) { + // Nonexistent + ASSERT_RAISES(IOError, fs_->OpenOutputStream("nonexistent-bucket/somefile")); + } // URI ASSERT_RAISES(Invalid, fs_->OpenOutputStream("s3:bucket/newfile1")); @@ -843,8 +896,8 @@ TEST_F(TestS3FS, GetFileInfoGenerator) { TEST_F(TestS3FS, GetFileInfoGeneratorStress) { // This test is slow because it needs to create a bunch of seed files. However, it is - // the only test that stresses listing and deleting when there are more than 1000 files - // and paging is required. + // the only test that stresses listing and deleting when there are more than 1000 + // files and paging is required. constexpr int32_t kNumDirs = 4; constexpr int32_t kNumFilesPerDir = 512; FileInfoVector expected_infos; @@ -1235,50 +1288,83 @@ TEST_F(TestS3FS, OpenInputFile) { ASSERT_RAISES(IOError, file->Seek(10)); } -TEST_F(TestS3FS, OpenOutputStreamBackgroundWrites) { TestOpenOutputStream(); } +struct S3OptionsTestParameters { + bool background_writes{false}; + bool allow_delayed_open{false}; -TEST_F(TestS3FS, OpenOutputStreamSyncWrites) { - options_.background_writes = false; - MakeFileSystem(); - TestOpenOutputStream(); -} + void ApplyToS3Options(S3Options* options) const { + options->background_writes = background_writes; + options->allow_delayed_open = allow_delayed_open; + } -TEST_F(TestS3FS, OpenOutputStreamAbortBackgroundWrites) { TestOpenOutputStreamAbort(); } + static std::vector GetCartesianProduct() { + return { + S3OptionsTestParameters{true, false}, + S3OptionsTestParameters{false, false}, + S3OptionsTestParameters{true, true}, + S3OptionsTestParameters{false, true}, + }; + } -TEST_F(TestS3FS, OpenOutputStreamAbortSyncWrites) { - options_.background_writes = false; - MakeFileSystem(); - TestOpenOutputStreamAbort(); -} + std::string ToString() const { + return std::string("background_writes = ") + (background_writes ? "true" : "false") + + ", allow_delayed_open = " + (allow_delayed_open ? 
"true" : "false"); + } +}; + +TEST_F(TestS3FS, OpenOutputStream) { + for (const auto& combination : S3OptionsTestParameters::GetCartesianProduct()) { + ARROW_SCOPED_TRACE(combination.ToString()); -TEST_F(TestS3FS, OpenOutputStreamDestructorBackgroundWrites) { - TestOpenOutputStreamDestructor(); + combination.ApplyToS3Options(&options_); + MakeFileSystem(); + TestOpenOutputStream(combination.allow_delayed_open); + ASSERT_OK(RestoreTestBucket()); + } } -TEST_F(TestS3FS, OpenOutputStreamDestructorSyncWrite) { - options_.background_writes = false; - MakeFileSystem(); - TestOpenOutputStreamDestructor(); +TEST_F(TestS3FS, OpenOutputStreamAbort) { + for (const auto& combination : S3OptionsTestParameters::GetCartesianProduct()) { + ARROW_SCOPED_TRACE(combination.ToString()); + + combination.ApplyToS3Options(&options_); + MakeFileSystem(); + TestOpenOutputStreamAbort(); + ASSERT_OK(RestoreTestBucket()); + } } -TEST_F(TestS3FS, OpenOutputStreamAsyncDestructorBackgroundWrites) { - TestOpenOutputStreamCloseAsyncDestructor(); +TEST_F(TestS3FS, OpenOutputStreamDestructor) { + for (const auto& combination : S3OptionsTestParameters::GetCartesianProduct()) { + ARROW_SCOPED_TRACE(combination.ToString()); + + combination.ApplyToS3Options(&options_); + MakeFileSystem(); + TestOpenOutputStreamDestructor(); + ASSERT_OK(RestoreTestBucket()); + } } -TEST_F(TestS3FS, OpenOutputStreamAsyncDestructorSyncWrite) { - options_.background_writes = false; - MakeFileSystem(); - TestOpenOutputStreamCloseAsyncDestructor(); +TEST_F(TestS3FS, OpenOutputStreamAsync) { + for (const auto& combination : S3OptionsTestParameters::GetCartesianProduct()) { + ARROW_SCOPED_TRACE(combination.ToString()); + + combination.ApplyToS3Options(&options_); + MakeFileSystem(); + TestOpenOutputStreamCloseAsyncDestructor(); + } } TEST_F(TestS3FS, OpenOutputStreamCloseAsyncFutureDeadlockBackgroundWrites) { TestOpenOutputStreamCloseAsyncFutureDeadlock(); + ASSERT_OK(RestoreTestBucket()); } TEST_F(TestS3FS, OpenOutputStreamCloseAsyncFutureDeadlockSyncWrite) { options_.background_writes = false; MakeFileSystem(); TestOpenOutputStreamCloseAsyncFutureDeadlock(); + ASSERT_OK(RestoreTestBucket()); } TEST_F(TestS3FS, OpenOutputStreamMetadata) { @@ -1396,8 +1482,8 @@ TEST_F(TestS3FS, CustomRetryStrategy) { auto retry_strategy = std::make_shared(); options_.retry_strategy = retry_strategy; MakeFileSystem(); - // Attempt to open file that doesn't exist. Should hit TestRetryStrategy::ShouldRetry() - // 3 times before bubbling back up here. + // Attempt to open file that doesn't exist. Should hit + // TestRetryStrategy::ShouldRetry() 3 times before bubbling back up here. 
ASSERT_RAISES(IOError, fs_->OpenInputStream("nonexistent-bucket/somefile")); ASSERT_EQ(retry_strategy->GetErrorsEncountered().size(), 3); for (const auto& error : retry_strategy->GetErrorsEncountered()) { diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 58a3ba4ab83e5..d0aee8ab9b3d2 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -584,8 +584,8 @@ arrow::Result> FlightClient::DoAction( arrow::Result FlightClient::CancelFlightInfo( const FlightCallOptions& options, const CancelFlightInfoRequest& request) { - ARROW_ASSIGN_OR_RAISE(auto body, request.SerializeToString()); - Action action{ActionType::kCancelFlightInfo.type, Buffer::FromString(body)}; + ARROW_ASSIGN_OR_RAISE(auto body, request.SerializeToBuffer()); + Action action{ActionType::kCancelFlightInfo.type, std::move(body)}; ARROW_ASSIGN_OR_RAISE(auto stream, DoAction(options, action)); ARROW_ASSIGN_OR_RAISE(auto result, stream->Next()); ARROW_ASSIGN_OR_RAISE(auto cancel_result, CancelFlightInfoResult::Deserialize( @@ -596,8 +596,8 @@ arrow::Result FlightClient::CancelFlightInfo( arrow::Result FlightClient::RenewFlightEndpoint( const FlightCallOptions& options, const RenewFlightEndpointRequest& request) { - ARROW_ASSIGN_OR_RAISE(auto body, request.SerializeToString()); - Action action{ActionType::kRenewFlightEndpoint.type, Buffer::FromString(body)}; + ARROW_ASSIGN_OR_RAISE(auto body, request.SerializeToBuffer()); + Action action{ActionType::kRenewFlightEndpoint.type, std::move(body)}; ARROW_ASSIGN_OR_RAISE(auto stream, DoAction(options, action)); ARROW_ASSIGN_OR_RAISE(auto result, stream->Next()); ARROW_ASSIGN_OR_RAISE(auto renewed_endpoint, @@ -716,8 +716,8 @@ arrow::Result FlightClient::DoExchange( ::arrow::Result FlightClient::SetSessionOptions( const FlightCallOptions& options, const SetSessionOptionsRequest& request) { RETURN_NOT_OK(CheckOpen()); - ARROW_ASSIGN_OR_RAISE(auto body, request.SerializeToString()); - Action action{ActionType::kSetSessionOptions.type, Buffer::FromString(body)}; + ARROW_ASSIGN_OR_RAISE(auto body, request.SerializeToBuffer()); + Action action{ActionType::kSetSessionOptions.type, std::move(body)}; ARROW_ASSIGN_OR_RAISE(auto stream, DoAction(options, action)); ARROW_ASSIGN_OR_RAISE(auto result, stream->Next()); ARROW_ASSIGN_OR_RAISE( @@ -730,8 +730,8 @@ ::arrow::Result FlightClient::SetSessionOptions( ::arrow::Result FlightClient::GetSessionOptions( const FlightCallOptions& options, const GetSessionOptionsRequest& request) { RETURN_NOT_OK(CheckOpen()); - ARROW_ASSIGN_OR_RAISE(auto body, request.SerializeToString()); - Action action{ActionType::kGetSessionOptions.type, Buffer::FromString(body)}; + ARROW_ASSIGN_OR_RAISE(auto body, request.SerializeToBuffer()); + Action action{ActionType::kGetSessionOptions.type, std::move(body)}; ARROW_ASSIGN_OR_RAISE(auto stream, DoAction(options, action)); ARROW_ASSIGN_OR_RAISE(auto result, stream->Next()); ARROW_ASSIGN_OR_RAISE( @@ -744,8 +744,8 @@ ::arrow::Result FlightClient::GetSessionOptions( ::arrow::Result FlightClient::CloseSession( const FlightCallOptions& options, const CloseSessionRequest& request) { RETURN_NOT_OK(CheckOpen()); - ARROW_ASSIGN_OR_RAISE(auto body, request.SerializeToString()); - Action action{ActionType::kCloseSession.type, Buffer::FromString(body)}; + ARROW_ASSIGN_OR_RAISE(auto body, request.SerializeToBuffer()); + Action action{ActionType::kCloseSession.type, std::move(body)}; ARROW_ASSIGN_OR_RAISE(auto stream, DoAction(options, action)); ARROW_ASSIGN_OR_RAISE(auto result, 
stream->Next()); ARROW_ASSIGN_OR_RAISE(auto close_session_result, diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc index 57f4f3e030420..caab357ef8f4a 100644 --- a/cpp/src/arrow/flight/flight_internals_test.cc +++ b/cpp/src/arrow/flight/flight_internals_test.cc @@ -79,8 +79,9 @@ void TestRoundtrip(const std::vector& values, ASSERT_OK(internal::ToProto(values[i], &pb_value)); if constexpr (std::is_same_v) { - ASSERT_OK_AND_ASSIGN(FlightInfo value, internal::FromProto(pb_value)); - EXPECT_EQ(values[i], value); + FlightInfo::Data info_data; + ASSERT_OK(internal::FromProto(pb_value, &info_data)); + EXPECT_EQ(values[i], FlightInfo{std::move(info_data)}); } else if constexpr (std::is_same_v) { std::string data; ASSERT_OK(internal::FromProto(pb_value, &data)); @@ -152,9 +153,11 @@ TEST(FlightTypes, BasicAuth) { } TEST(FlightTypes, Criteria) { - std::vector values = {{""}, {"criteria"}}; - std::vector reprs = {"", - ""}; + std::vector values = {Criteria{""}, Criteria{"criteria"}}; + std::vector reprs = { + "", + "", + }; ASSERT_NO_FATAL_FAILURE(TestRoundtrip(values, reprs)); } @@ -191,14 +194,14 @@ TEST(FlightTypes, FlightEndpoint) { Timestamp expiration_time( std::chrono::duration_cast(expiration_time_duration)); std::vector values = { - {{""}, {}, std::nullopt, {}}, - {{"foo"}, {}, std::nullopt, {}}, - {{"bar"}, {}, std::nullopt, {"\xDE\xAD\xBE\xEF"}}, - {{"foo"}, {}, expiration_time, {}}, - {{"foo"}, {location1}, std::nullopt, {}}, - {{"bar"}, {location1}, std::nullopt, {}}, - {{"foo"}, {location2}, std::nullopt, {}}, - {{"foo"}, {location1, location2}, std::nullopt, {"\xba\xdd\xca\xfe"}}, + {Ticket{""}, {}, std::nullopt, {}}, + {Ticket{"foo"}, {}, std::nullopt, {}}, + {Ticket{"bar"}, {}, std::nullopt, {"\xDE\xAD\xBE\xEF"}}, + {Ticket{"foo"}, {}, expiration_time, {}}, + {Ticket{"foo"}, {location1}, std::nullopt, {}}, + {Ticket{"bar"}, {location1}, std::nullopt, {}}, + {Ticket{"foo"}, {location2}, std::nullopt, {}}, + {Ticket{"foo"}, {location1, location2}, std::nullopt, {"\xba\xdd\xca\xfe"}}, }; std::vector reprs = { " locations=[] " @@ -299,9 +302,9 @@ TEST(FlightTypes, PollInfo) { TEST(FlightTypes, Result) { std::vector values = { - {Buffer::FromString("")}, - {Buffer::FromString("foo")}, - {Buffer::FromString("bar")}, + Result{Buffer::FromString("")}, + Result{Buffer::FromString("foo")}, + Result{Buffer::FromString("bar")}, }; std::vector reprs = { "", @@ -333,9 +336,9 @@ TEST(FlightTypes, SchemaResult) { TEST(FlightTypes, Ticket) { std::vector values = { - {""}, - {"foo"}, - {"bar"}, + Ticket{""}, + Ticket{"foo"}, + Ticket{"bar"}, }; std::vector reprs = { "", diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index e179f3406d65e..101bb06b21288 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -998,7 +998,8 @@ TEST_F(TestFlightClient, ListFlights) { } TEST_F(TestFlightClient, ListFlightsWithCriteria) { - ASSERT_OK_AND_ASSIGN(auto listing, client_->ListFlights(FlightCallOptions(), {"foo"})); + ASSERT_OK_AND_ASSIGN(auto listing, + client_->ListFlights(FlightCallOptions{}, Criteria{"foo"})); std::unique_ptr info; ASSERT_OK_AND_ASSIGN(info, listing->Next()); ASSERT_TRUE(info == nullptr); diff --git a/cpp/src/arrow/flight/serialization_internal.cc b/cpp/src/arrow/flight/serialization_internal.cc index 10600d055b3a8..fedfc7d5cd590 100644 --- a/cpp/src/arrow/flight/serialization_internal.cc +++ b/cpp/src/arrow/flight/serialization_internal.cc @@ -251,22 +251,28 @@ 
Status ToProto(const FlightDescriptor& descriptor, pb::FlightDescriptor* pb_desc // FlightInfo -arrow::Result FromProto(const pb::FlightInfo& pb_info) { - FlightInfo::Data info; - RETURN_NOT_OK(FromProto(pb_info.flight_descriptor(), &info.descriptor)); +Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info) { + RETURN_NOT_OK(FromProto(pb_info.flight_descriptor(), &info->descriptor)); - info.schema = pb_info.schema(); + info->schema = pb_info.schema(); - info.endpoints.resize(pb_info.endpoint_size()); + info->endpoints.resize(pb_info.endpoint_size()); for (int i = 0; i < pb_info.endpoint_size(); ++i) { - RETURN_NOT_OK(FromProto(pb_info.endpoint(i), &info.endpoints[i])); + RETURN_NOT_OK(FromProto(pb_info.endpoint(i), &info->endpoints[i])); } - info.total_records = pb_info.total_records(); - info.total_bytes = pb_info.total_bytes(); - info.ordered = pb_info.ordered(); - info.app_metadata = pb_info.app_metadata(); - return FlightInfo(std::move(info)); + info->total_records = pb_info.total_records(); + info->total_bytes = pb_info.total_bytes(); + info->ordered = pb_info.ordered(); + info->app_metadata = pb_info.app_metadata(); + return Status::OK(); +} + +Status FromProto(const pb::FlightInfo& pb_info, std::unique_ptr* info) { + FlightInfo::Data info_data; + RETURN_NOT_OK(FromProto(pb_info, &info_data)); + *info = std::make_unique(std::move(info_data)); + return Status::OK(); } Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* basic_auth) { @@ -315,8 +321,9 @@ Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info) { Status FromProto(const pb::PollInfo& pb_info, PollInfo* info) { if (pb_info.has_info()) { - ARROW_ASSIGN_OR_RAISE(auto flight_info, FromProto(pb_info.info())); - info->info = std::make_unique(std::move(flight_info)); + FlightInfo::Data info_data; + RETURN_NOT_OK(FromProto(pb_info.info(), &info_data)); + info->info = std::make_unique(std::move(info_data)); } if (pb_info.has_flight_descriptor()) { FlightDescriptor descriptor; @@ -340,6 +347,13 @@ Status FromProto(const pb::PollInfo& pb_info, PollInfo* info) { return Status::OK(); } +Status FromProto(const pb::PollInfo& pb_info, std::unique_ptr* info) { + PollInfo poll_info; + RETURN_NOT_OK(FromProto(pb_info, &poll_info)); + *info = std::make_unique(std::move(poll_info)); + return Status::OK(); +} + Status ToProto(const PollInfo& info, pb::PollInfo* pb_info) { if (info.info) { RETURN_NOT_OK(ToProto(*info.info, pb_info->mutable_info())); @@ -360,8 +374,9 @@ Status ToProto(const PollInfo& info, pb::PollInfo* pb_info) { Status FromProto(const pb::CancelFlightInfoRequest& pb_request, CancelFlightInfoRequest* request) { - ARROW_ASSIGN_OR_RAISE(FlightInfo info, FromProto(pb_request.info())); - request->info = std::make_unique(std::move(info)); + FlightInfo::Data info_data; + RETURN_NOT_OK(FromProto(pb_request.info(), &info_data)); + request->info = std::make_unique(std::move(info_data)); return Status::OK(); } diff --git a/cpp/src/arrow/flight/serialization_internal.h b/cpp/src/arrow/flight/serialization_internal.h index 90dde87d3a5eb..9922cb61ac004 100644 --- a/cpp/src/arrow/flight/serialization_internal.h +++ b/cpp/src/arrow/flight/serialization_internal.h @@ -60,8 +60,10 @@ Status FromProto(const pb::FlightDescriptor& pb_descr, FlightDescriptor* descr); Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint); Status FromProto(const pb::RenewFlightEndpointRequest& pb_request, RenewFlightEndpointRequest* request); -arrow::Result FromProto(const pb::FlightInfo& pb_info); 
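The serialization_internal.cc hunk above reshapes FlightInfo deserialization: the Result-returning FromProto(const pb::FlightInfo&) becomes a Status-returning overload that fills a caller-provided FlightInfo::Data, with a second overload layered on top for std::unique_ptr<FlightInfo>, and PollInfo gains the same unique_ptr variant. The matching declarations follow in serialization_internal.h. At call sites (the grpc_client.cc hunks further down make the same move) the pattern is roughly:

// Before: unwrap an arrow::Result<FlightInfo> and move out of it.
//   ARROW_ASSIGN_OR_RAISE(FlightInfo info, internal::FromProto(pb_info));

// After: the caller owns the output slot, so one helper serves stack values,
// vector elements, and the unique_ptr overload alike, and no FlightInfo is
// constructed until its Data is complete.
FlightInfo::Data info_data;
RETURN_NOT_OK(internal::FromProto(pb_info, &info_data));
FlightInfo info{std::move(info_data)};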
+Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info); +Status FromProto(const pb::FlightInfo& pb_info, std::unique_ptr* info); Status FromProto(const pb::PollInfo& pb_info, PollInfo* info); +Status FromProto(const pb::PollInfo& pb_info, std::unique_ptr* info); Status FromProto(const pb::CancelFlightInfoRequest& pb_request, CancelFlightInfoRequest* request); Status FromProto(const pb::SchemaResult& pb_result, std::string* result); @@ -92,6 +94,7 @@ Status ToProto(const Result& result, pb::Result* pb_result); Status ToProto(const CancelFlightInfoResult& result, pb::CancelFlightInfoResult* pb_result); Status ToProto(const Criteria& criteria, pb::Criteria* pb_criteria); +Status ToProto(const Location& location, pb::Location* pb_location); Status ToProto(const SchemaResult& result, pb::SchemaResult* pb_result); Status ToProto(const Ticket& ticket, pb::Ticket* pb_ticket); Status ToProto(const BasicAuth& basic_auth, pb::BasicAuth* pb_basic_auth); diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.cc b/cpp/src/arrow/flight/sql/example/sqlite_server.cc index 20b234e90ad3b..0651e6111c25d 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_server.cc +++ b/cpp/src/arrow/flight/sql/example/sqlite_server.cc @@ -126,7 +126,7 @@ arrow::Result> DoGetSQLiteQuery( arrow::Result> GetFlightInfoForCommand( const FlightDescriptor& descriptor, const std::shared_ptr& schema) { std::vector endpoints{ - FlightEndpoint{{descriptor.cmd}, {}, std::nullopt, ""}}; + FlightEndpoint{Ticket{descriptor.cmd}, {}, std::nullopt, ""}}; ARROW_ASSIGN_OR_RAISE(auto result, FlightInfo::Make(*schema, descriptor, endpoints, -1, -1, false)) @@ -389,7 +389,7 @@ class SQLiteFlightSqlServer::Impl { const ServerCallContext& context, const GetTables& command, const FlightDescriptor& descriptor) { std::vector endpoints{ - FlightEndpoint{{descriptor.cmd}, {}, std::nullopt, ""}}; + FlightEndpoint{Ticket{descriptor.cmd}, {}, std::nullopt, ""}}; bool include_schema = command.include_schema; ARROW_LOG(INFO) << "GetTables include_schema=" << include_schema; diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc index 63d1f5c5225fa..ac89976690877 100644 --- a/cpp/src/arrow/flight/sql/server.cc +++ b/cpp/src/arrow/flight/sql/server.cc @@ -477,13 +477,11 @@ arrow::Result PackActionResult(ActionBeginTransactionResult result) { } arrow::Result PackActionResult(CancelFlightInfoResult result) { - ARROW_ASSIGN_OR_RAISE(auto serialized, result.SerializeToString()); - return Result{Buffer::FromString(std::move(serialized))}; + return result.SerializeToBuffer(); } arrow::Result PackActionResult(const FlightEndpoint& endpoint) { - ARROW_ASSIGN_OR_RAISE(auto serialized, endpoint.SerializeToString()); - return Result{Buffer::FromString(std::move(serialized))}; + return endpoint.SerializeToBuffer(); } arrow::Result PackActionResult(CancelResult result) { @@ -525,21 +523,6 @@ arrow::Result PackActionResult(ActionCreatePreparedStatementResult resul return PackActionResult(pb_result); } -arrow::Result PackActionResult(SetSessionOptionsResult result) { - ARROW_ASSIGN_OR_RAISE(auto serialized, result.SerializeToString()); - return Result{Buffer::FromString(std::move(serialized))}; -} - -arrow::Result PackActionResult(GetSessionOptionsResult result) { - ARROW_ASSIGN_OR_RAISE(auto serialized, result.SerializeToString()); - return Result{Buffer::FromString(std::move(serialized))}; -} - -arrow::Result PackActionResult(CloseSessionResult result) { - ARROW_ASSIGN_OR_RAISE(auto serialized, 
result.SerializeToString()); - return Result{Buffer::FromString(std::move(serialized))}; -} - } // namespace arrow::Result StatementQueryTicket::Deserialize( @@ -908,23 +891,23 @@ Status FlightSqlServerBase::DoAction(const ServerCallContext& context, std::string_view body(*action.body); ARROW_ASSIGN_OR_RAISE(auto request, SetSessionOptionsRequest::Deserialize(body)); ARROW_ASSIGN_OR_RAISE(auto result, SetSessionOptions(context, request)); - ARROW_ASSIGN_OR_RAISE(auto packed_result, PackActionResult(std::move(result))); + ARROW_ASSIGN_OR_RAISE(auto packed_result, result.SerializeToBuffer()); - results.push_back(std::move(packed_result)); + results.emplace_back(std::move(packed_result)); } else if (action.type == ActionType::kGetSessionOptions.type) { std::string_view body(*action.body); ARROW_ASSIGN_OR_RAISE(auto request, GetSessionOptionsRequest::Deserialize(body)); ARROW_ASSIGN_OR_RAISE(auto result, GetSessionOptions(context, request)); - ARROW_ASSIGN_OR_RAISE(auto packed_result, PackActionResult(std::move(result))); + ARROW_ASSIGN_OR_RAISE(auto packed_result, result.SerializeToBuffer()); - results.push_back(std::move(packed_result)); + results.emplace_back(std::move(packed_result)); } else if (action.type == ActionType::kCloseSession.type) { std::string_view body(*action.body); ARROW_ASSIGN_OR_RAISE(auto request, CloseSessionRequest::Deserialize(body)); ARROW_ASSIGN_OR_RAISE(auto result, CloseSession(context, request)); - ARROW_ASSIGN_OR_RAISE(auto packed_result, PackActionResult(std::move(result))); + ARROW_ASSIGN_OR_RAISE(auto packed_result, result.SerializeToBuffer()); - results.push_back(std::move(packed_result)); + results.emplace_back(std::move(packed_result)); } else { google::protobuf::Any any; if (!any.ParseFromArray(action.body->data(), static_cast(action.body->size()))) { @@ -1063,7 +1046,7 @@ arrow::Result> FlightSqlServerBase::GetFlightInfoSql } std::vector endpoints{ - FlightEndpoint{{descriptor.cmd}, {}, std::nullopt, {}}}; + FlightEndpoint{Ticket{descriptor.cmd}, {}, std::nullopt, {}}}; ARROW_ASSIGN_OR_RAISE( auto result, FlightInfo::Make(*SqlSchema::GetSqlInfoSchema(), descriptor, endpoints, -1, -1, false)) diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index bf2f4c2b4effc..8b4245e74e843 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -604,11 +604,11 @@ std::vector ExampleFlightInfo() { Location location4 = *Location::ForGrpcTcp("foo4.bar.com", 12345); Location location5 = *Location::ForGrpcTcp("foo5.bar.com", 12345); - FlightEndpoint endpoint1({{"ticket-ints-1"}, {location1}, std::nullopt, {}}); - FlightEndpoint endpoint2({{"ticket-ints-2"}, {location2}, std::nullopt, {}}); - FlightEndpoint endpoint3({{"ticket-cmd"}, {location3}, std::nullopt, {}}); - FlightEndpoint endpoint4({{"ticket-dicts-1"}, {location4}, std::nullopt, {}}); - FlightEndpoint endpoint5({{"ticket-floats-1"}, {location5}, std::nullopt, {}}); + FlightEndpoint endpoint1({Ticket{"ticket-ints-1"}, {location1}, std::nullopt, {}}); + FlightEndpoint endpoint2({Ticket{"ticket-ints-2"}, {location2}, std::nullopt, {}}); + FlightEndpoint endpoint3({Ticket{"ticket-cmd"}, {location3}, std::nullopt, {}}); + FlightEndpoint endpoint4({Ticket{"ticket-dicts-1"}, {location4}, std::nullopt, {}}); + FlightEndpoint endpoint5({Ticket{"ticket-floats-1"}, {location5}, std::nullopt, {}}); FlightDescriptor descr1{FlightDescriptor::PATH, "", {"examples", "ints"}}; FlightDescriptor descr2{FlightDescriptor::CMD, "my_command", {}}; diff --git 
a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc index f799ba761c40d..6d8d40c2ebcf8 100644 --- a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc +++ b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc @@ -648,10 +648,10 @@ class UnaryUnaryAsyncCall : public ::grpc::ClientUnaryReactor, public internal:: void OnDone(const ::grpc::Status& status) override { if (status.ok()) { - auto result = internal::FromProto(pb_response); - client_status = result.status(); + FlightInfo::Data info_data; + client_status = internal::FromProto(pb_response, &info_data); if (client_status.ok()) { - listener->OnNext(std::move(result).MoveValueUnsafe()); + listener->OnNext(FlightInfo{std::move(info_data)}); } } Finish(status); @@ -889,7 +889,8 @@ class GrpcClientImpl : public internal::ClientTransport { pb::FlightInfo pb_info; while (!options.stop_token.IsStopRequested() && stream->Read(&pb_info)) { - ARROW_ASSIGN_OR_RAISE(FlightInfo info_data, internal::FromProto(pb_info)); + FlightInfo::Data info_data; + RETURN_NOT_OK(internal::FromProto(pb_info, &info_data)); flights.emplace_back(std::move(info_data)); } if (options.stop_token.IsStopRequested()) rpc.context.TryCancel(); @@ -939,7 +940,8 @@ class GrpcClientImpl : public internal::ClientTransport { stub_->GetFlightInfo(&rpc.context, pb_descriptor, &pb_response), &rpc.context); RETURN_NOT_OK(s); - ARROW_ASSIGN_OR_RAISE(auto info_data, internal::FromProto(pb_response)); + FlightInfo::Data info_data; + RETURN_NOT_OK(internal::FromProto(pb_response, &info_data)); *info = std::make_unique(std::move(info_data)); return Status::OK(); } diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc index 946ac2d176203..a78b6f825a0e9 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc @@ -118,7 +118,7 @@ class ClientConnection { params.flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER; params.name = "UcxClientImpl"; params.sockaddr.addr = reinterpret_cast(&connect_addr); - params.sockaddr.addrlen = addrlen; + params.sockaddr.addrlen = static_cast(addrlen); auto status = ucp_ep_create(ucp_worker_->get(), ¶ms, &remote_endpoint_); RETURN_NOT_OK(FromUcsStatus("ucp_ep_create", status)); diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc index cb9c8948ccf1e..b1096ece77b1b 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc @@ -258,7 +258,7 @@ class UcxServerImpl : public arrow::flight::internal::ServerTransport { params.field_mask = UCP_LISTENER_PARAM_FIELD_SOCK_ADDR | UCP_LISTENER_PARAM_FIELD_CONN_HANDLER; params.sockaddr.addr = reinterpret_cast(&listen_addr); - params.sockaddr.addrlen = addrlen; + params.sockaddr.addrlen = static_cast(addrlen); params.conn_handler.cb = HandleIncomingConnection; params.conn_handler.arg = this; @@ -376,7 +376,7 @@ class UcxServerImpl : public arrow::flight::internal::ServerTransport { std::unique_ptr info; std::string response; SERVER_RETURN_NOT_OK(driver, base_->GetFlightInfo(context, descriptor, &info)); - SERVER_RETURN_NOT_OK(driver, info->SerializeToString().Value(&response)); + SERVER_RETURN_NOT_OK(driver, info->SerializeToString(&response)); RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer, reinterpret_cast(response.data()), static_cast(response.size()))); @@ -397,7 +397,7 @@ class UcxServerImpl : public arrow::flight::internal::ServerTransport { 
std::unique_ptr info; std::string response; SERVER_RETURN_NOT_OK(driver, base_->PollFlightInfo(context, descriptor, &info)); - SERVER_RETURN_NOT_OK(driver, info->SerializeToString().Value(&response)); + SERVER_RETURN_NOT_OK(driver, info->SerializeToString(&response)); RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer, reinterpret_cast(response.data()), static_cast(response.size()))); diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index a04956a4ea3f7..bb5932a312567 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -81,20 +81,17 @@ Status SerializeToString(const char* name, const T& in, PBType* out_pb, // Result-returning ser/de functions (more convenient) template -arrow::Result DeserializeProtoString(const char* name, std::string_view serialized) { +arrow::Status DeserializeProtoString(const char* name, std::string_view serialized, + T* out) { PBType pb; RETURN_NOT_OK(ParseFromString(name, serialized, &pb)); - T out; - RETURN_NOT_OK(internal::FromProto(pb, &out)); - return out; + return internal::FromProto(pb, out); } template -arrow::Result SerializeToProtoString(const char* name, const T& in) { +Status SerializeToProtoString(const char* name, const T& in, std::string* out) { PBType pb; - std::string out; - RETURN_NOT_OK(SerializeToString(name, in, &pb, &out)); - return out; + return SerializeToString(name, in, &pb, out); } } // namespace @@ -154,18 +151,57 @@ Status MakeFlightError(FlightStatusCode code, std::string message, std::make_shared(code, std::move(extra_info))); } -bool FlightDescriptor::Equals(const FlightDescriptor& other) const { - if (type != other.type) { - return false; +static std::ostream& operator<<(std::ostream& os, std::vector values) { + os << '['; + std::string sep = ""; + for (const auto& v : values) { + os << sep << std::quoted(v); + sep = ", "; } - switch (type) { - case PATH: - return path == other.path; - case CMD: - return cmd == other.cmd; - default: - return false; + os << ']'; + + return os; +} + +template +static std::ostream& operator<<(std::ostream& os, std::map m) { + os << '{'; + std::string sep = ""; + if constexpr (std::is_convertible_v) { + // std::string, char*, std::string_view + for (const auto& [k, v] : m) { + os << sep << '[' << k << "]: " << std::quoted(v) << '"'; + sep = ", "; + } + } else { + for (const auto& [k, v] : m) { + os << sep << '[' << k << "]: " << v; + sep = ", "; + } } + os << '}'; + + return os; +} + +//------------------------------------------------------------ +// Wrapper types for Flight RPC protobuf messages + +std::string BasicAuth::ToString() const { + return arrow::util::StringBuilder(""); +} + +bool BasicAuth::Equals(const BasicAuth& other) const { + return (username == other.username) && (password == other.password); +} + +arrow::Status BasicAuth::Deserialize(std::string_view serialized, BasicAuth* out) { + return DeserializeProtoString("BasicAuth", serialized, out); +} + +arrow::Status BasicAuth::SerializeToString(std::string* out) const { + return SerializeToProtoString("BasicAuth", *this, out); } std::string FlightDescriptor::ToString() const { @@ -195,75 +231,28 @@ std::string FlightDescriptor::ToString() const { return ss.str(); } -Status FlightPayload::Validate() const { - static constexpr int64_t kInt32Max = std::numeric_limits::max(); - if (descriptor && descriptor->size() > kInt32Max) { - return Status::CapacityError("Descriptor size overflow (>= 2**31)"); - } - if (app_metadata && app_metadata->size() > kInt32Max) { - return 
Status::CapacityError("app_metadata size overflow (>= 2**31)"); +bool FlightDescriptor::Equals(const FlightDescriptor& other) const { + if (type != other.type) { + return false; } - if (ipc_message.body_length > kInt32Max) { - return Status::Invalid("Cannot send record batches exceeding 2GiB yet"); + switch (type) { + case PATH: + return path == other.path; + case CMD: + return cmd == other.cmd; + default: + return false; } - return Status::OK(); -} - -arrow::Result> SchemaResult::GetSchema( - ipc::DictionaryMemo* dictionary_memo) const { - // Create a non-owned Buffer to avoid copying - io::BufferReader schema_reader(std::make_shared(raw_schema_)); - return ipc::ReadSchema(&schema_reader, dictionary_memo); -} - -arrow::Result> SchemaResult::Make(const Schema& schema) { - std::string schema_in; - RETURN_NOT_OK(internal::SchemaToString(schema, &schema_in)); - return std::make_unique(std::move(schema_in)); -} - -std::string SchemaResult::ToString() const { - return ""; -} - -bool SchemaResult::Equals(const SchemaResult& other) const { - return raw_schema_ == other.raw_schema_; -} - -arrow::Result SchemaResult::SerializeToString() const { - return SerializeToProtoString("SchemaResult", *this); } -arrow::Result SchemaResult::Deserialize(std::string_view serialized) { - pb::SchemaResult pb_schema_result; - RETURN_NOT_OK(ParseFromString("SchemaResult", serialized, &pb_schema_result)); - return SchemaResult{pb_schema_result.schema()}; +arrow::Status FlightDescriptor::SerializeToString(std::string* out) const { + return SerializeToProtoString("FlightDescriptor", *this, out); } -arrow::Result FlightDescriptor::SerializeToString() const { - return SerializeToProtoString("FlightDescriptor", *this); -} - -arrow::Result FlightDescriptor::Deserialize( - std::string_view serialized) { +arrow::Status FlightDescriptor::Deserialize(std::string_view serialized, + FlightDescriptor* out) { return DeserializeProtoString( - "FlightDescriptor", serialized); -} - -std::string Ticket::ToString() const { - std::stringstream ss; - ss << ""; - return ss.str(); -} - -bool Ticket::Equals(const Ticket& other) const { return ticket == other.ticket; } - -arrow::Result Ticket::SerializeToString() const { - return SerializeToProtoString("Ticket", *this); -} - -arrow::Result Ticket::Deserialize(std::string_view serialized) { - return DeserializeProtoString("Ticket", serialized); + "FlightDescriptor", serialized, out); } arrow::Result FlightInfo::Make(const Schema& schema, @@ -279,7 +268,7 @@ arrow::Result FlightInfo::Make(const Schema& schema, data.ordered = ordered; data.app_metadata = std::move(app_metadata); RETURN_NOT_OK(internal::SchemaToString(schema, &data.schema)); - return FlightInfo(data); + return FlightInfo(std::move(data)); } arrow::Result> FlightInfo::GetSchema( @@ -294,16 +283,14 @@ arrow::Result> FlightInfo::GetSchema( return schema_; } -arrow::Result FlightInfo::SerializeToString() const { - return SerializeToProtoString("FlightInfo", *this); +arrow::Status FlightInfo::SerializeToString(std::string* out) const { + return SerializeToProtoString("FlightInfo", *this, out); } -arrow::Result> FlightInfo::Deserialize( - std::string_view serialized) { - pb::FlightInfo pb_info; - RETURN_NOT_OK(ParseFromString("FlightInfo", serialized, &pb_info)); - ARROW_ASSIGN_OR_RAISE(FlightInfo info, internal::FromProto(pb_info)); - return std::make_unique(std::move(info)); +arrow::Status FlightInfo::Deserialize(std::string_view serialized, + std::unique_ptr* out) { + return DeserializeProtoString>( + "FlightInfo", serialized, 
out); } std::string FlightInfo::ToString() const { @@ -340,17 +327,14 @@ bool FlightInfo::Equals(const FlightInfo& other) const { data_.app_metadata == other.data_.app_metadata; } -arrow::Result PollInfo::SerializeToString() const { - return SerializeToProtoString("PollInfo", *this); +arrow::Status PollInfo::SerializeToString(std::string* out) const { + return SerializeToProtoString("PollInfo", *this, out); } -arrow::Result> PollInfo::Deserialize( - std::string_view serialized) { - pb::PollInfo pb_info; - RETURN_NOT_OK(ParseFromString("PollInfo", serialized, &pb_info)); - PollInfo info; - RETURN_NOT_OK(internal::FromProto(pb_info, &info)); - return std::make_unique(std::move(info)); +arrow::Status PollInfo::Deserialize(std::string_view serialized, + std::unique_ptr* out) { + return DeserializeProtoString>("PollInfo", + serialized, out); } std::string PollInfo::ToString() const { @@ -427,54 +411,60 @@ bool CancelFlightInfoRequest::Equals(const CancelFlightInfoRequest& other) const return info == other.info; } -arrow::Result CancelFlightInfoRequest::SerializeToString() const { +arrow::Status CancelFlightInfoRequest::SerializeToString(std::string* out) const { return SerializeToProtoString("CancelFlightInfoRequest", - *this); + *this, out); } -arrow::Result CancelFlightInfoRequest::Deserialize( - std::string_view serialized) { +arrow::Status CancelFlightInfoRequest::Deserialize(std::string_view serialized, + CancelFlightInfoRequest* out) { return DeserializeProtoString( - "CancelFlightInfoRequest", serialized); + "CancelFlightInfoRequest", serialized, out); } -static const char* const SetSessionOptionStatusNames[] = {"Unspecified", "InvalidName", - "InvalidValue", "Error"}; -static const char* const CloseSessionStatusNames[] = {"Unspecified", "Closed", "Closing", - "NotClosable"}; - -// Helpers for stringifying maps containing various types -std::string ToString(const SetSessionOptionErrorValue& error_value) { - return SetSessionOptionStatusNames[static_cast(error_value)]; +std::string CancelFlightInfoResult::ToString() const { + std::stringstream ss; + ss << ""; + return ss.str(); } -std::ostream& operator<<(std::ostream& os, - const SetSessionOptionErrorValue& error_value) { - os << ToString(error_value); - return os; +bool CancelFlightInfoResult::Equals(const CancelFlightInfoResult& other) const { + return status == other.status; } -std::string ToString(const CloseSessionStatus& status) { - return CloseSessionStatusNames[static_cast(status)]; +arrow::Status CancelFlightInfoResult::SerializeToString(std::string* out) const { + return SerializeToProtoString("CancelFlightInfoResult", + *this, out); } -std::ostream& operator<<(std::ostream& os, const CloseSessionStatus& status) { - os << ToString(status); - return os; +arrow::Status CancelFlightInfoResult::Deserialize(std::string_view serialized, + CancelFlightInfoResult* out) { + return DeserializeProtoString( + "CancelFlightInfoResult", serialized, out); } -std::ostream& operator<<(std::ostream& os, std::vector values) { - os << '['; - std::string sep = ""; - for (const auto& v : values) { - os << sep << std::quoted(v); - sep = ", "; +std::ostream& operator<<(std::ostream& os, CancelStatus status) { + switch (status) { + case CancelStatus::kUnspecified: + os << "Unspecified"; + break; + case CancelStatus::kCancelled: + os << "Cancelled"; + break; + case CancelStatus::kCancelling: + os << "Cancelling"; + break; + case CancelStatus::kNotCancellable: + os << "NotCancellable"; + break; } - os << ']'; - return os; } +// Session management 
messages + +// SessionOptionValue + std::ostream& operator<<(std::ostream& os, const SessionOptionValue& v) { if (std::holds_alternative(v)) { os << ""; @@ -493,33 +483,6 @@ std::ostream& operator<<(std::ostream& os, const SessionOptionValue& v) { return os; } -std::ostream& operator<<(std::ostream& os, const SetSessionOptionsResult::Error& e) { - os << '{' << e.value << '}'; - return os; -} - -template -std::ostream& operator<<(std::ostream& os, std::map m) { - os << '{'; - std::string sep = ""; - if constexpr (std::is_convertible_v) { - // std::string, char*, std::string_view - for (const auto& [k, v] : m) { - os << sep << '[' << k << "]: " << std::quoted(v) << '"'; - sep = ", "; - } - } else { - for (const auto& [k, v] : m) { - os << sep << '[' << k << "]: " << v; - sep = ", "; - } - } - os << '}'; - - return os; -} - -namespace { static bool CompareSessionOptionMaps(const std::map& a, const std::map& b) { if (a.size() != b.size()) { @@ -540,15 +503,30 @@ static bool CompareSessionOptionMaps(const std::map(error_value)]; +} + +std::ostream& operator<<(std::ostream& os, + const SetSessionOptionErrorValue& error_value) { + os << ToString(error_value); + return os; +} // SetSessionOptionsRequest std::string SetSessionOptionsRequest::ToString() const { std::stringstream ss; - ss << " SetSessionOptionsRequest::Deserialize( - std::string_view serialized) { +arrow::Status SetSessionOptionsRequest::Deserialize(std::string_view serialized, + SetSessionOptionsRequest* out) { return DeserializeProtoString( - "SetSessionOptionsRequest", serialized); + "SetSessionOptionsRequest", serialized, out); } // SetSessionOptionsResult +std::ostream& operator<<(std::ostream& os, const SetSessionOptionsResult::Error& e) { + os << '{' << e.value << '}'; + return os; +} + std::string SetSessionOptionsResult::ToString() const { std::stringstream ss; - ss << " SetSessionOptionsResult::Deserialize( - std::string_view serialized) { +arrow::Status SetSessionOptionsResult::Deserialize(std::string_view serialized, + SetSessionOptionsResult* out) { return DeserializeProtoString( - "SetSessionOptionsResult", serialized); + "SetSessionOptionsResult", serialized, out); } // GetSessionOptionsRequest @@ -605,15 +586,15 @@ bool GetSessionOptionsRequest::Equals(const GetSessionOptionsRequest& other) con return true; } -arrow::Result GetSessionOptionsRequest::SerializeToString() const { +arrow::Status GetSessionOptionsRequest::SerializeToString(std::string* out) const { return SerializeToProtoString("GetSessionOptionsRequest", - *this); + *this, out); } -arrow::Result GetSessionOptionsRequest::Deserialize( - std::string_view serialized) { +arrow::Status GetSessionOptionsRequest::Deserialize(std::string_view serialized, + GetSessionOptionsRequest* out) { return DeserializeProtoString( - "GetSessionOptionsRequest", serialized); + "GetSessionOptionsRequest", serialized, out); } // GetSessionOptionsResult @@ -628,15 +609,15 @@ bool GetSessionOptionsResult::Equals(const GetSessionOptionsResult& other) const return CompareSessionOptionMaps(session_options, other.session_options); } -arrow::Result GetSessionOptionsResult::SerializeToString() const { +arrow::Status GetSessionOptionsResult::SerializeToString(std::string* out) const { return SerializeToProtoString("GetSessionOptionsResult", - *this); + *this, out); } -arrow::Result GetSessionOptionsResult::Deserialize( - std::string_view serialized) { +arrow::Status GetSessionOptionsResult::Deserialize(std::string_view serialized, + GetSessionOptionsResult* out) { return 
DeserializeProtoString( - "GetSessionOptionsResult", serialized); + "GetSessionOptionsResult", serialized, out); } // CloseSessionRequest @@ -645,23 +626,39 @@ std::string CloseSessionRequest::ToString() const { return " CloseSessionRequest::SerializeToString() const { - return SerializeToProtoString("CloseSessionRequest", *this); +arrow::Status CloseSessionRequest::SerializeToString(std::string* out) const { + return SerializeToProtoString("CloseSessionRequest", *this, + out); } -arrow::Result CloseSessionRequest::Deserialize( - std::string_view serialized) { +arrow::Status CloseSessionRequest::Deserialize(std::string_view serialized, + CloseSessionRequest* out) { return DeserializeProtoString( - "CloseSessionRequest", serialized); + "CloseSessionRequest", serialized, out); +} + +// CloseSessionStatus + +std::string ToString(const CloseSessionStatus& status) { + static constexpr const char* CloseSessionStatusNames[] = { + "Unspecified", + "Closed", + "Closing", + "NotClosable", + }; + return CloseSessionStatusNames[static_cast(status)]; +} + +std::ostream& operator<<(std::ostream& os, const CloseSessionStatus& status) { + os << ToString(status); + return os; } // CloseSessionResult std::string CloseSessionResult::ToString() const { std::stringstream ss; - ss << "("CloseSessionResult", *this, out); } -arrow::Result CloseSessionResult::Deserialize( - std::string_view serialized) { +arrow::Status CloseSessionResult::Deserialize(std::string_view serialized, + CloseSessionResult* out) { return DeserializeProtoString( - "CloseSessionResult", serialized); + "CloseSessionResult", serialized, out); +} + +// Ticket + +std::string Ticket::ToString() const { + std::stringstream ss; + ss << ""; + return ss.str(); +} + +bool Ticket::Equals(const Ticket& other) const { return ticket == other.ticket; } + +arrow::Status Ticket::SerializeToString(std::string* out) const { + return SerializeToProtoString("Ticket", *this, out); +} + +arrow::Status Ticket::Deserialize(std::string_view serialized, Ticket* out) { + return DeserializeProtoString("Ticket", serialized, out); } Location::Location() { uri_ = std::make_shared(); } @@ -718,7 +733,6 @@ arrow::Result Location::ForScheme(const std::string& scheme, return Location::Parse(uri_string.str()); } -std::string Location::ToString() const { return uri_->ToString(); } std::string Location::scheme() const { std::string scheme = uri_->scheme(); if (scheme.empty()) { @@ -728,6 +742,8 @@ std::string Location::scheme() const { return scheme; } +std::string Location::ToString() const { return uri_->ToString(); } + bool Location::Equals(const Location& other) const { return ToString() == other.ToString(); } @@ -781,13 +797,22 @@ bool FlightEndpoint::Equals(const FlightEndpoint& other) const { return true; } -arrow::Result FlightEndpoint::SerializeToString() const { - return SerializeToProtoString("FlightEndpoint", *this); +arrow::Status Location::SerializeToString(std::string* out) const { + return SerializeToProtoString("Location", *this, out); +} + +arrow::Status Location::Deserialize(std::string_view serialized, Location* out) { + return DeserializeProtoString("Location", serialized, out); +} + +arrow::Status FlightEndpoint::SerializeToString(std::string* out) const { + return SerializeToProtoString("FlightEndpoint", *this, out); } -arrow::Result FlightEndpoint::Deserialize(std::string_view serialized) { +arrow::Status FlightEndpoint::Deserialize(std::string_view serialized, + FlightEndpoint* out) { return DeserializeProtoString("FlightEndpoint", - serialized); + 
serialized, out); } std::string RenewFlightEndpointRequest::ToString() const { @@ -800,16 +825,30 @@ bool RenewFlightEndpointRequest::Equals(const RenewFlightEndpointRequest& other) return endpoint == other.endpoint; } -arrow::Result RenewFlightEndpointRequest::SerializeToString() const { +arrow::Status RenewFlightEndpointRequest::SerializeToString(std::string* out) const { return SerializeToProtoString( - "RenewFlightEndpointRequest", *this); + "RenewFlightEndpointRequest", *this, out); } -arrow::Result RenewFlightEndpointRequest::Deserialize( - std::string_view serialized) { +arrow::Status RenewFlightEndpointRequest::Deserialize(std::string_view serialized, + RenewFlightEndpointRequest* out) { return DeserializeProtoString("RenewFlightEndpointRequest", - serialized); + serialized, out); +} + +Status FlightPayload::Validate() const { + static constexpr int64_t kInt32Max = std::numeric_limits::max(); + if (descriptor && descriptor->size() > kInt32Max) { + return Status::CapacityError("Descriptor size overflow (>= 2**31)"); + } + if (app_metadata && app_metadata->size() > kInt32Max) { + return Status::CapacityError("app_metadata size overflow (>= 2**31)"); + } + if (ipc_message.body_length > kInt32Max) { + return Status::Invalid("Cannot send record batches exceeding 2GiB yet"); + } + return Status::OK(); } std::string ActionType::ToString() const { @@ -847,12 +886,13 @@ bool ActionType::Equals(const ActionType& other) const { return type == other.type && description == other.description; } -arrow::Result ActionType::SerializeToString() const { - return SerializeToProtoString("ActionType", *this); +arrow::Status ActionType::SerializeToString(std::string* out) const { + return SerializeToProtoString("ActionType", *this, out); } -arrow::Result ActionType::Deserialize(std::string_view serialized) { - return DeserializeProtoString("ActionType", serialized); +arrow::Status ActionType::Deserialize(std::string_view serialized, ActionType* out) { + return DeserializeProtoString("ActionType", serialized, + out); } std::string Criteria::ToString() const { @@ -863,12 +903,12 @@ bool Criteria::Equals(const Criteria& other) const { return expression == other.expression; } -arrow::Result Criteria::SerializeToString() const { - return SerializeToProtoString("Criteria", *this); +arrow::Status Criteria::SerializeToString(std::string* out) const { + return SerializeToProtoString("Criteria", *this, out); } -arrow::Result Criteria::Deserialize(std::string_view serialized) { - return DeserializeProtoString("Criteria", serialized); +arrow::Status Criteria::Deserialize(std::string_view serialized, Criteria* out) { + return DeserializeProtoString("Criteria", serialized, out); } std::string Action::ToString() const { @@ -889,12 +929,12 @@ bool Action::Equals(const Action& other) const { ((body == other.body) || (body && other.body && body->Equals(*other.body))); } -arrow::Result Action::SerializeToString() const { - return SerializeToProtoString("Action", *this); +arrow::Status Action::SerializeToString(std::string* out) const { + return SerializeToProtoString("Action", *this, out); } -arrow::Result Action::Deserialize(std::string_view serialized) { - return DeserializeProtoString("Action", serialized); +arrow::Status Action::Deserialize(std::string_view serialized, Action* out) { + return DeserializeProtoString("Action", serialized, out); } std::string Result::ToString() const { @@ -912,53 +952,48 @@ bool Result::Equals(const Result& other) const { return (body == other.body) || (body && other.body && 
body->Equals(*other.body)); } -arrow::Result Result::SerializeToString() const { - return SerializeToProtoString("Result", *this); +arrow::Status Result::SerializeToString(std::string* out) const { + return SerializeToProtoString("Result", *this, out); } -arrow::Result Result::Deserialize(std::string_view serialized) { - return DeserializeProtoString("Result", serialized); +arrow::Status Result::Deserialize(std::string_view serialized, Result* out) { + return DeserializeProtoString("Result", serialized, out); } -std::string CancelFlightInfoResult::ToString() const { - std::stringstream ss; - ss << ""; - return ss.str(); +arrow::Result> SchemaResult::GetSchema( + ipc::DictionaryMemo* dictionary_memo) const { + // Create a non-owned Buffer to avoid copying + io::BufferReader schema_reader(std::make_shared(raw_schema_)); + return ipc::ReadSchema(&schema_reader, dictionary_memo); } -bool CancelFlightInfoResult::Equals(const CancelFlightInfoResult& other) const { - return status == other.status; +arrow::Result> SchemaResult::Make(const Schema& schema) { + std::string schema_in; + RETURN_NOT_OK(internal::SchemaToString(schema, &schema_in)); + return std::make_unique(std::move(schema_in)); } -arrow::Result CancelFlightInfoResult::SerializeToString() const { - return SerializeToProtoString("CancelFlightInfoResult", - *this); +std::string SchemaResult::ToString() const { + return ""; } -arrow::Result CancelFlightInfoResult::Deserialize( - std::string_view serialized) { - return DeserializeProtoString( - "CancelFlightInfoResult", serialized); +bool SchemaResult::Equals(const SchemaResult& other) const { + return raw_schema_ == other.raw_schema_; } -std::ostream& operator<<(std::ostream& os, CancelStatus status) { - switch (status) { - case CancelStatus::kUnspecified: - os << "Unspecified"; - break; - case CancelStatus::kCancelled: - os << "Cancelled"; - break; - case CancelStatus::kCancelling: - os << "Cancelling"; - break; - case CancelStatus::kNotCancellable: - os << "NotCancellable"; - break; - } - return os; +arrow::Status SchemaResult::SerializeToString(std::string* out) const { + return SerializeToProtoString("SchemaResult", *this, out); +} + +arrow::Status SchemaResult::Deserialize(std::string_view serialized, SchemaResult* out) { + pb::SchemaResult pb_schema_result; + RETURN_NOT_OK(ParseFromString("SchemaResult", serialized, &pb_schema_result)); + *out = SchemaResult{pb_schema_result.schema()}; + return Status::OK(); } +//------------------------------------------------------------ + Status ResultStream::Drain() { while (true) { ARROW_ASSIGN_OR_RAISE(auto result, Next()); @@ -1046,23 +1081,6 @@ arrow::Result> SimpleResultStream::Next() { return std::make_unique(std::move(results_[position_++])); } -std::string BasicAuth::ToString() const { - return arrow::util::StringBuilder(""); -} - -bool BasicAuth::Equals(const BasicAuth& other) const { - return (username == other.username) && (password == other.password); -} - -arrow::Result BasicAuth::Deserialize(std::string_view serialized) { - return DeserializeProtoString("BasicAuth", serialized); -} - -arrow::Result BasicAuth::SerializeToString() const { - return SerializeToProtoString("BasicAuth", *this); -} - //------------------------------------------------------------ // Error propagation helpers diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index cdf03f21041ee..de93750f75b25 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -31,6 +31,7 @@ #include #include +#include "arrow/buffer.h" 
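+// Pulled in because the internal::BaseType helper added below defines an
+// inline SerializeToBuffer() that calls Buffer::FromString and returns
+// arrow::Result<std::shared_ptr<Buffer>>.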
#include "arrow/flight/type_fwd.h" #include "arrow/flight/visibility.h" #include "arrow/ipc/options.h" @@ -60,6 +61,18 @@ class Uri; namespace flight { +ARROW_FLIGHT_EXPORT +extern const char* kSchemeGrpc; +ARROW_FLIGHT_EXPORT +extern const char* kSchemeGrpcTcp; +ARROW_FLIGHT_EXPORT +extern const char* kSchemeGrpcUnix; +ARROW_FLIGHT_EXPORT +extern const char* kSchemeGrpcTls; + +class FlightClient; +class FlightServerBase; + /// \brief A timestamp compatible with Protocol Buffer's /// google.protobuf.Timestamp: /// @@ -159,29 +172,122 @@ struct ARROW_FLIGHT_EXPORT CertKeyPair { std::string pem_key; }; +namespace internal { + +template +struct remove_unique_ptr { + using type = T; +}; + +template +struct remove_unique_ptr> { + using type = T; +}; + +// Base CRTP type +template +struct BaseType { + protected: + using SuperT = BaseType; + using SelfT = typename remove_unique_ptr::type; + + const SelfT& self() const { return static_cast(*this); } + SelfT& self() { return static_cast(*this); } + + public: + BaseType() = default; + + friend bool operator==(const SelfT& left, const SelfT& right) { + return left.Equals(right); + } + friend bool operator!=(const SelfT& left, const SelfT& right) { + return !left.Equals(right); + } + + /// \brief Serialize this message to its wire-format representation. + inline arrow::Result SerializeToString() const { + std::string out; + ARROW_RETURN_NOT_OK(self().SelfT::SerializeToString(&out)); + return out; + } + + inline static arrow::Result Deserialize(std::string_view serialized) { + T out; + ARROW_RETURN_NOT_OK(SelfT::Deserialize(serialized, &out)); + return out; + } + + inline arrow::Result> SerializeToBuffer() const { + std::string out; + ARROW_RETURN_NOT_OK(self().SelfT::SerializeToString(&out)); + return Buffer::FromString(std::move(out)); + } +}; + +} // namespace internal + +//------------------------------------------------------------ +// Wrapper types for Flight RPC protobuf messages + +// A wrapper around arrow.flight.protocol.HandshakeRequest is not defined +// A wrapper around arrow.flight.protocol.HandshakeResponse is not defined + +/// \brief message for simple auth +struct ARROW_FLIGHT_EXPORT BasicAuth : public internal::BaseType { + std::string username; + std::string password; + + BasicAuth() = default; + BasicAuth(std::string username, std::string password) + : username(std::move(username)), password(std::move(password)) {} + + std::string ToString() const; + bool Equals(const BasicAuth& other) const; + + using SuperT::Deserialize; + using SuperT::SerializeToString; + + /// \brief Serialize this message to its wire-format representation. + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; + + /// \brief Deserialize this message from its wire-format representation. + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, BasicAuth* out); +}; + +// A wrapper around arrow.flight.protocol.Empty is not defined + /// \brief A type of action that can be performed with the DoAction RPC. -struct ARROW_FLIGHT_EXPORT ActionType { +struct ARROW_FLIGHT_EXPORT ActionType : public internal::BaseType { /// \brief The name of the action. std::string type; /// \brief A human-readable description of the action. 
std::string description; + ActionType() = default; + + ActionType(std::string type, std::string description) + : type(std::move(type)), description(std::move(description)) {} + std::string ToString() const; bool Equals(const ActionType& other) const; - friend bool operator==(const ActionType& left, const ActionType& right) { - return left.Equals(right); - } - friend bool operator!=(const ActionType& left, const ActionType& right) { - return !(left == right); - } + using SuperT::Deserialize; + using SuperT::SerializeToString; /// \brief Serialize this message to its wire-format representation. - arrow::Result SerializeToString() const; + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(std::string_view serialized); + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, ActionType* out); static const ActionType kCancelFlightInfo; static const ActionType kRenewFlightEndpoint; @@ -191,138 +297,126 @@ struct ARROW_FLIGHT_EXPORT ActionType { }; /// \brief Opaque selection criteria for ListFlights RPC -struct ARROW_FLIGHT_EXPORT Criteria { +struct ARROW_FLIGHT_EXPORT Criteria : public internal::BaseType { /// Opaque criteria expression, dependent on server implementation std::string expression; + Criteria() = default; + Criteria(std::string expression) // NOLINT runtime/explicit + : expression(std::move(expression)) {} + std::string ToString() const; bool Equals(const Criteria& other) const; - friend bool operator==(const Criteria& left, const Criteria& right) { - return left.Equals(right); - } - friend bool operator!=(const Criteria& left, const Criteria& right) { - return !(left == right); - } + using SuperT::Deserialize; + using SuperT::SerializeToString; /// \brief Serialize this message to its wire-format representation. - arrow::Result SerializeToString() const; + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(std::string_view serialized); + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, Criteria* out); }; /// \brief An action to perform with the DoAction RPC -struct ARROW_FLIGHT_EXPORT Action { +struct ARROW_FLIGHT_EXPORT Action : public internal::BaseType { /// The action type std::string type; /// The action content as a Buffer std::shared_ptr body; + Action() = default; + Action(std::string type, std::shared_ptr body) + : type(std::move(type)), body(std::move(body)) {} + std::string ToString() const; bool Equals(const Action& other) const; - friend bool operator==(const Action& left, const Action& right) { - return left.Equals(right); - } - friend bool operator!=(const Action& left, const Action& right) { - return !(left == right); - } + using SuperT::Deserialize; + using SuperT::SerializeToString; /// \brief Serialize this message to its wire-format representation. - arrow::Result SerializeToString() const; + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; /// \brief Deserialize this message from its wire-format representation. 
- static arrow::Result Deserialize(std::string_view serialized); + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, Action* out); }; /// \brief Opaque result returned after executing an action -struct ARROW_FLIGHT_EXPORT Result { +struct ARROW_FLIGHT_EXPORT Result : public internal::BaseType { std::shared_ptr body; + Result() = default; + Result(std::shared_ptr body) // NOLINT runtime/explicit + : body(std::move(body)) {} + std::string ToString() const; bool Equals(const Result& other) const; - friend bool operator==(const Result& left, const Result& right) { - return left.Equals(right); - } - friend bool operator!=(const Result& left, const Result& right) { - return !(left == right); - } + using SuperT::Deserialize; + using SuperT::SerializeToString; /// \brief Serialize this message to its wire-format representation. - arrow::Result SerializeToString() const; + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(std::string_view serialized); + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, Result* out); }; -enum class CancelStatus { - /// The cancellation status is unknown. Servers should avoid using - /// this value (send a kNotCancellable if the requested FlightInfo - /// is not known). Clients can retry the request. - kUnspecified = 0, - /// The cancellation request is complete. Subsequent requests with - /// the same payload may return kCancelled or a kNotCancellable error. - kCancelled = 1, - /// The cancellation request is in progress. The client may retry - /// the cancellation request. - kCancelling = 2, - // The FlightInfo is not cancellable. The client should not retry the - // cancellation request. - kNotCancellable = 3, -}; +/// \brief Schema result returned after a schema request RPC +struct ARROW_FLIGHT_EXPORT SchemaResult : public internal::BaseType { + public: + SchemaResult() = default; + explicit SchemaResult(std::string schema) : raw_schema_(std::move(schema)) {} -/// \brief The result of the CancelFlightInfo action. -struct ARROW_FLIGHT_EXPORT CancelFlightInfoResult { - CancelStatus status; + /// \brief Factory method to construct a SchemaResult. + static arrow::Result> Make(const Schema& schema); + + /// \brief return schema + /// \param[in,out] dictionary_memo for dictionary bookkeeping, will + /// be modified + /// \return Arrow result with the reconstructed Schema + arrow::Result> GetSchema( + ipc::DictionaryMemo* dictionary_memo) const; + + const std::string& serialized_schema() const { return raw_schema_; } std::string ToString() const; - bool Equals(const CancelFlightInfoResult& other) const; + bool Equals(const SchemaResult& other) const; - friend bool operator==(const CancelFlightInfoResult& left, - const CancelFlightInfoResult& right) { - return left.Equals(right); - } - friend bool operator!=(const CancelFlightInfoResult& left, - const CancelFlightInfoResult& right) { - return !(left == right); - } + using SuperT::Deserialize; + using SuperT::SerializeToString; /// \brief Serialize this message to its wire-format representation. - arrow::Result SerializeToString() const; + /// + /// Use `SerializeToString()` if you want a Result-returning version. 
+ arrow::Status SerializeToString(std::string* out) const; /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(std::string_view serialized); -}; - -ARROW_FLIGHT_EXPORT -std::ostream& operator<<(std::ostream& os, CancelStatus status); - -/// \brief message for simple auth -struct ARROW_FLIGHT_EXPORT BasicAuth { - std::string username; - std::string password; - - std::string ToString() const; - bool Equals(const BasicAuth& other) const; - - friend bool operator==(const BasicAuth& left, const BasicAuth& right) { - return left.Equals(right); - } - friend bool operator!=(const BasicAuth& left, const BasicAuth& right) { - return !(left == right); - } + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, SchemaResult* out); - /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(std::string_view serialized); - /// \brief Serialize this message to its wire-format representation. - arrow::Result SerializeToString() const; + private: + std::string raw_schema_; }; /// \brief A request to retrieve or generate a dataset -struct ARROW_FLIGHT_EXPORT FlightDescriptor { +struct ARROW_FLIGHT_EXPORT FlightDescriptor + : public internal::BaseType { enum DescriptorType { UNKNOWN = 0, /// Unused PATH = 1, /// Named path identifying a dataset @@ -330,7 +424,7 @@ struct ARROW_FLIGHT_EXPORT FlightDescriptor { }; /// The descriptor type - DescriptorType type; + DescriptorType type = UNKNOWN; /// Opaque value used to express a command. Should only be defined when type /// is CMD @@ -340,22 +434,33 @@ struct ARROW_FLIGHT_EXPORT FlightDescriptor { /// when type is PATH std::vector path; - bool Equals(const FlightDescriptor& other) const; + FlightDescriptor() = default; + + FlightDescriptor(DescriptorType type, std::string cmd, std::vector path) + : type(type), cmd(std::move(cmd)), path(std::move(path)) {} /// \brief Get a human-readable form of this descriptor. std::string ToString() const; + bool Equals(const FlightDescriptor& other) const; + + using SuperT::Deserialize; + using SuperT::SerializeToString; /// \brief Get the wire-format representation of this type. /// /// Useful when interoperating with non-Flight systems (e.g. REST /// services) that may want to return Flight types. - arrow::Result SerializeToString() const; + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; /// \brief Parse the wire-format representation of this type. /// /// Useful when interoperating with non-Flight systems (e.g. REST /// services) that may want to return Flight types. - static arrow::Result Deserialize(std::string_view serialized); + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. 
+ static arrow::Status Deserialize(std::string_view serialized, FlightDescriptor* out); // Convenience factory functions @@ -366,69 +471,289 @@ struct ARROW_FLIGHT_EXPORT FlightDescriptor { static FlightDescriptor Path(const std::vector& p) { return FlightDescriptor{PATH, "", p}; } - - friend bool operator==(const FlightDescriptor& left, const FlightDescriptor& right) { - return left.Equals(right); - } - friend bool operator!=(const FlightDescriptor& left, const FlightDescriptor& right) { - return !(left == right); - } }; -/// \brief Data structure providing an opaque identifier or credential to use -/// when requesting a data stream with the DoGet RPC -struct ARROW_FLIGHT_EXPORT Ticket { - std::string ticket; +/// \brief The access coordinates for retrieval of a dataset, returned by +/// GetFlightInfo +class ARROW_FLIGHT_EXPORT FlightInfo + : public internal::BaseType> { + public: + struct Data { + std::string schema; + FlightDescriptor descriptor; + std::vector endpoints; + int64_t total_records = -1; + int64_t total_bytes = -1; + bool ordered = false; + std::string app_metadata; + }; - std::string ToString() const; - bool Equals(const Ticket& other) const; + explicit FlightInfo(Data data) : data_(std::move(data)), reconstructed_schema_(false) {} - friend bool operator==(const Ticket& left, const Ticket& right) { - return left.Equals(right); - } - friend bool operator!=(const Ticket& left, const Ticket& right) { - return !(left == right); - } + /// \brief Factory method to construct a FlightInfo. + static arrow::Result Make(const Schema& schema, + const FlightDescriptor& descriptor, + const std::vector& endpoints, + int64_t total_records, int64_t total_bytes, + bool ordered = false, + std::string app_metadata = ""); + + /// \brief Deserialize the Arrow schema of the dataset. Populate any + /// dictionary encoded fields into a DictionaryMemo for + /// bookkeeping + /// \param[in,out] dictionary_memo for dictionary bookkeeping, will + /// be modified + /// \return Arrow result with the reconstructed Schema + arrow::Result> GetSchema( + ipc::DictionaryMemo* dictionary_memo) const; + + const std::string& serialized_schema() const { return data_.schema; } + + /// The descriptor associated with this flight, may not be set + const FlightDescriptor& descriptor() const { return data_.descriptor; } + + /// A list of endpoints associated with the flight (dataset). To consume the + /// whole flight, all endpoints must be consumed + const std::vector& endpoints() const { return data_.endpoints; } + + /// The total number of records (rows) in the dataset. If unknown, set to -1 + int64_t total_records() const { return data_.total_records; } + + /// The total number of bytes in the dataset. If unknown, set to -1 + int64_t total_bytes() const { return data_.total_bytes; } + + /// Whether endpoints are in the same order as the data. + bool ordered() const { return data_.ordered; } + + /// Application-defined opaque metadata + const std::string& app_metadata() const { return data_.app_metadata; } + + using SuperT::Deserialize; + using SuperT::SerializeToString; /// \brief Get the wire-format representation of this type. /// /// Useful when interoperating with non-Flight systems (e.g. REST /// services) that may want to return Flight types. - arrow::Result SerializeToString() const; + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; /// \brief Parse the wire-format representation of this type. 
/// /// Useful when interoperating with non-Flight systems (e.g. REST /// services) that may want to return Flight types. - static arrow::Result Deserialize(std::string_view serialized); -}; + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, + std::unique_ptr* out); -class FlightClient; -class FlightServerBase; + std::string ToString() const; -ARROW_FLIGHT_EXPORT -extern const char* kSchemeGrpc; -ARROW_FLIGHT_EXPORT -extern const char* kSchemeGrpcTcp; -ARROW_FLIGHT_EXPORT -extern const char* kSchemeGrpcUnix; -ARROW_FLIGHT_EXPORT -extern const char* kSchemeGrpcTls; + /// Compare two FlightInfo for equality. This will compare the + /// serialized schema representations, NOT the logical equality of + /// the schemas. + bool Equals(const FlightInfo& other) const; -/// \brief A host location (a URI) -struct ARROW_FLIGHT_EXPORT Location { + private: + Data data_; + mutable std::shared_ptr schema_; + mutable bool reconstructed_schema_; +}; + +/// \brief The information to process a long-running query. +class ARROW_FLIGHT_EXPORT PollInfo + : public internal::BaseType> { public: - /// \brief Initialize a blank location. - Location(); + /// The currently available results so far. + std::unique_ptr info = NULLPTR; + /// The descriptor the client should use on the next try. If unset, + /// the query is complete. + std::optional descriptor = std::nullopt; + /// Query progress. Must be in [0.0, 1.0] but need not be + /// monotonic or nondecreasing. If unknown, do not set. + std::optional progress = std::nullopt; + /// Expiration time for this request. After this passes, the server + /// might not accept the poll descriptor anymore (and the query may + /// be cancelled). This may be updated on a call to PollFlightInfo. + std::optional expiration_time = std::nullopt; - /// \brief Initialize a location by parsing a URI string - static arrow::Result Parse(const std::string& uri_string); + PollInfo() + : info(NULLPTR), + descriptor(std::nullopt), + progress(std::nullopt), + expiration_time(std::nullopt) {} - /// \brief Get the fallback URI. - /// - /// arrow-flight-reuse-connection://? means that a client may attempt to - /// reuse an existing connection to a Flight service to fetch data instead - /// of creating a new connection to one of the other locations listed in a + PollInfo(std::unique_ptr info, std::optional descriptor, + std::optional progress, std::optional expiration_time) + : info(std::move(info)), + descriptor(std::move(descriptor)), + progress(progress), + expiration_time(expiration_time) {} + + PollInfo(const PollInfo& other) + : info(other.info ? std::make_unique(*other.info) : NULLPTR), + descriptor(other.descriptor), + progress(other.progress), + expiration_time(other.expiration_time) {} + PollInfo(PollInfo&& other) noexcept = default; + ~PollInfo() = default; + PollInfo& operator=(const PollInfo& other) { + info = other.info ? std::make_unique(*other.info) : NULLPTR; + descriptor = other.descriptor; + progress = other.progress; + expiration_time = other.expiration_time; + return *this; + } + PollInfo& operator=(PollInfo&& other) = default; + + using SuperT::Deserialize; + using SuperT::SerializeToString; + + /// \brief Get the wire-format representation of this type. + /// + /// Useful when interoperating with non-Flight systems (e.g. REST + /// services) that may want to return Flight types. + /// + /// Use `SerializeToString()` if you want a Result-returning version. 
+ arrow::Status SerializeToString(std::string* out) const; + + /// \brief Parse the wire-format representation of this type. + /// + /// Useful when interoperating with non-Flight systems (e.g. REST + /// services) that may want to return Flight types. + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, + std::unique_ptr* out); + + std::string ToString() const; + + /// Compare two PollInfo for equality. This will compare the + /// serialized schema representations, NOT the logical equality of + /// the schemas. + bool Equals(const PollInfo& other) const; +}; + +/// \brief The request of the CancelFlightInfoRequest action. +struct ARROW_FLIGHT_EXPORT CancelFlightInfoRequest + : public internal::BaseType { + std::unique_ptr info; + + CancelFlightInfoRequest() = default; + CancelFlightInfoRequest(std::unique_ptr info) // NOLINT runtime/explicit + : info(std::move(info)) {} + + std::string ToString() const; + bool Equals(const CancelFlightInfoRequest& other) const; + + using SuperT::Deserialize; + using SuperT::SerializeToString; + + /// \brief Serialize this message to its wire-format representation. + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; + + /// \brief Deserialize this message from its wire-format representation. + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, + CancelFlightInfoRequest* out); +}; + +enum class CancelStatus { + /// The cancellation status is unknown. Servers should avoid using + /// this value (send a kNotCancellable if the requested FlightInfo + /// is not known). Clients can retry the request. + kUnspecified = 0, + /// The cancellation request is complete. Subsequent requests with + /// the same payload may return kCancelled or a kNotCancellable error. + kCancelled = 1, + /// The cancellation request is in progress. The client may retry + /// the cancellation request. + kCancelling = 2, + // The FlightInfo is not cancellable. The client should not retry the + // cancellation request. + kNotCancellable = 3, +}; + +/// \brief The result of the CancelFlightInfo action. +struct ARROW_FLIGHT_EXPORT CancelFlightInfoResult + : public internal::BaseType { + CancelStatus status = CancelStatus::kUnspecified; + + CancelFlightInfoResult() = default; + CancelFlightInfoResult(CancelStatus status) // NOLINT runtime/explicit + : status(status) {} + + std::string ToString() const; + bool Equals(const CancelFlightInfoResult& other) const; + + using SuperT::Deserialize; + using SuperT::SerializeToString; + + /// \brief Serialize this message to its wire-format representation. + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; + + /// \brief Deserialize this message from its wire-format representation. + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. 
+ static arrow::Status Deserialize(std::string_view serialized, + CancelFlightInfoResult* out); +}; + +ARROW_FLIGHT_EXPORT +std::ostream& operator<<(std::ostream& os, CancelStatus status); + +/// \brief Data structure providing an opaque identifier or credential to use +/// when requesting a data stream with the DoGet RPC +struct ARROW_FLIGHT_EXPORT Ticket : public internal::BaseType { + std::string ticket; + + Ticket() = default; + Ticket(std::string ticket) // NOLINT runtime/explicit + : ticket(std::move(ticket)) {} + + std::string ToString() const; + bool Equals(const Ticket& other) const; + + using SuperT::Deserialize; + using SuperT::SerializeToString; + + /// \brief Get the wire-format representation of this type. + /// + /// Useful when interoperating with non-Flight systems (e.g. REST + /// services) that may want to return Flight types. + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; + + /// \brief Parse the wire-format representation of this type. + /// + /// Useful when interoperating with non-Flight systems (e.g. REST + /// services) that may want to return Flight types. + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, Ticket* out); +}; + +/// \brief A host location (a URI) +struct ARROW_FLIGHT_EXPORT Location : public internal::BaseType { + public: + /// \brief Initialize a blank location. + Location(); + + /// \brief Initialize a location by parsing a URI string + static arrow::Result Parse(const std::string& uri_string); + + /// \brief Get the fallback URI. + /// + /// arrow-flight-reuse-connection://? means that a client may attempt to + /// reuse an existing connection to a Flight service to fetch data instead + /// of creating a new connection to one of the other locations listed in a /// FlightEndpoint response. static const Location& ReuseConnection(); @@ -456,20 +781,25 @@ struct ARROW_FLIGHT_EXPORT Location { static arrow::Result ForScheme(const std::string& scheme, const std::string& host, const int port); - /// \brief Get a representation of this URI as a string. - std::string ToString() const; - /// \brief Get the scheme of this URI. std::string scheme() const; + /// \brief Get a representation of this URI as a string. + std::string ToString() const; bool Equals(const Location& other) const; - friend bool operator==(const Location& left, const Location& right) { - return left.Equals(right); - } - friend bool operator!=(const Location& left, const Location& right) { - return !(left == right); - } + using SuperT::Deserialize; + using SuperT::SerializeToString; + + /// \brief Serialize this message to its wire-format representation. + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; + + /// \brief Deserialize this message from its wire-format representation. + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. 
+  static arrow::Status Deserialize(std::string_view serialized, Location* out);

 private:
  friend class FlightClient;
@@ -479,7 +809,7 @@ struct ARROW_FLIGHT_EXPORT Location {

 /// \brief A flight ticket and list of locations where the ticket can be
 /// redeemed
-struct ARROW_FLIGHT_EXPORT FlightEndpoint {
+struct ARROW_FLIGHT_EXPORT FlightEndpoint : public internal::BaseType<FlightEndpoint> {
   /// Opaque ticket identifier; use with DoGet RPC
   Ticket ticket;
@@ -496,47 +826,60 @@ struct ARROW_FLIGHT_EXPORT FlightEndpoint {
   /// Opaque Application-defined metadata
   std::string app_metadata;

+  FlightEndpoint() = default;
+  FlightEndpoint(Ticket ticket, std::vector<Location> locations,
+                 std::optional<Timestamp> expiration_time, std::string app_metadata)
+      : ticket(std::move(ticket)),
+        locations(std::move(locations)),
+        expiration_time(expiration_time),
+        app_metadata(std::move(app_metadata)) {}
+
   std::string ToString() const;
   bool Equals(const FlightEndpoint& other) const;

-  friend bool operator==(const FlightEndpoint& left, const FlightEndpoint& right) {
-    return left.Equals(right);
-  }
-  friend bool operator!=(const FlightEndpoint& left, const FlightEndpoint& right) {
-    return !(left == right);
-  }
+  using SuperT::Deserialize;
+  using SuperT::SerializeToString;

   /// \brief Serialize this message to its wire-format representation.
-  arrow::Result<std::string> SerializeToString() const;
+  ///
+  /// Use `SerializeToString()` if you want a Result-returning version.
+  arrow::Status SerializeToString(std::string* out) const;

   /// \brief Deserialize this message from its wire-format representation.
-  static arrow::Result<FlightEndpoint> Deserialize(std::string_view serialized);
+  ///
+  /// Use `Deserialize(serialized)` if you want a Result-returning version.
+  static arrow::Status Deserialize(std::string_view serialized, FlightEndpoint* out);
 };

 /// \brief The request of the RenewFlightEndpoint action.
-struct ARROW_FLIGHT_EXPORT RenewFlightEndpointRequest {
+struct ARROW_FLIGHT_EXPORT RenewFlightEndpointRequest
+    : public internal::BaseType<RenewFlightEndpointRequest> {
   FlightEndpoint endpoint;

+  RenewFlightEndpointRequest() = default;
+  explicit RenewFlightEndpointRequest(FlightEndpoint endpoint)
+      : endpoint(std::move(endpoint)) {}
+
   std::string ToString() const;
   bool Equals(const RenewFlightEndpointRequest& other) const;

-  friend bool operator==(const RenewFlightEndpointRequest& left,
-                         const RenewFlightEndpointRequest& right) {
-    return left.Equals(right);
-  }
-  friend bool operator!=(const RenewFlightEndpointRequest& left,
-                         const RenewFlightEndpointRequest& right) {
-    return !(left == right);
-  }
+  using SuperT::Deserialize;
+  using SuperT::SerializeToString;

   /// \brief Serialize this message to its wire-format representation.
-  arrow::Result<std::string> SerializeToString() const;
+  ///
+  /// Use `SerializeToString()` if you want a Result-returning version.
+  arrow::Status SerializeToString(std::string* out) const;

   /// \brief Deserialize this message from its wire-format representation.
-  static arrow::Result<RenewFlightEndpointRequest> Deserialize(
-      std::string_view serialized);
+  ///
+  /// Use `Deserialize(serialized)` if you want a Result-returning version.
+  static arrow::Status Deserialize(std::string_view serialized,
+                                   RenewFlightEndpointRequest* out);
 };

+// FlightData in Flight.proto maps to FlightPayload here.
+
 /// \brief Staging data structure for messages about to be put on the wire
 ///
 /// This structure corresponds to FlightData in the protocol.
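The same SuperT pattern repeats for every message type in this header, so one sketch covers the whole API change. The following round trip is illustrative only, written against the signatures visible in this diff rather than taken from the patch; ARROW_RETURN_NOT_OK is Arrow's usual early-return macro:

    // Round-trip a FlightEndpoint through its wire format.
    #include <string>

    #include "arrow/flight/types.h"
    #include "arrow/status.h"

    arrow::Status RoundTripEndpoint(const arrow::flight::FlightEndpoint& endpoint) {
      std::string wire;
      // New Status-returning overload; the Result-returning form remains
      // reachable through the inherited `using SuperT::SerializeToString`.
      ARROW_RETURN_NOT_OK(endpoint.SerializeToString(&wire));
      arrow::flight::FlightEndpoint decoded;
      ARROW_RETURN_NOT_OK(arrow::flight::FlightEndpoint::Deserialize(wire, &decoded));
      return decoded.Equals(endpoint) ? arrow::Status::OK()
                                      : arrow::Status::Invalid("round-trip mismatch");
    }

Keeping both spellings lets callers that interoperate with non-Flight systems use the Status-plus-out-parameter form while existing Result-based call sites stay source-compatible.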
@@ -545,241 +888,57 @@ struct ARROW_FLIGHT_EXPORT FlightPayload { std::shared_ptr app_metadata; ipc::IpcPayload ipc_message; + FlightPayload() = default; + FlightPayload(std::shared_ptr descriptor, std::shared_ptr app_metadata, + ipc::IpcPayload ipc_message) + : descriptor(std::move(descriptor)), + app_metadata(std::move(app_metadata)), + ipc_message(std::move(ipc_message)) {} + /// \brief Check that the payload can be written to the wire. Status Validate() const; }; -/// \brief Schema result returned after a schema request RPC -struct ARROW_FLIGHT_EXPORT SchemaResult { - public: - SchemaResult() = default; - explicit SchemaResult(std::string schema) : raw_schema_(std::move(schema)) {} +// A wrapper around arrow.flight.protocol.PutResult is not defined - /// \brief Factory method to construct a SchemaResult. - static arrow::Result> Make(const Schema& schema); +// Session management messages - /// \brief return schema - /// \param[in,out] dictionary_memo for dictionary bookkeeping, will - /// be modified - /// \return Arrow result with the reconstructed Schema - arrow::Result> GetSchema( - ipc::DictionaryMemo* dictionary_memo) const; - - const std::string& serialized_schema() const { return raw_schema_; } - - std::string ToString() const; - bool Equals(const SchemaResult& other) const; - - friend bool operator==(const SchemaResult& left, const SchemaResult& right) { - return left.Equals(right); - } - friend bool operator!=(const SchemaResult& left, const SchemaResult& right) { - return !(left == right); - } - - /// \brief Serialize this message to its wire-format representation. - arrow::Result SerializeToString() const; - - /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(std::string_view serialized); - - private: - std::string raw_schema_; -}; - -/// \brief The access coordinates for retrieval of a dataset, returned by -/// GetFlightInfo -class ARROW_FLIGHT_EXPORT FlightInfo { - public: - struct Data { - std::string schema; - FlightDescriptor descriptor; - std::vector endpoints; - int64_t total_records = -1; - int64_t total_bytes = -1; - bool ordered = false; - std::string app_metadata; - }; - - explicit FlightInfo(Data data) : data_(std::move(data)), reconstructed_schema_(false) {} - - /// \brief Factory method to construct a FlightInfo. - static arrow::Result Make(const Schema& schema, - const FlightDescriptor& descriptor, - const std::vector& endpoints, - int64_t total_records, int64_t total_bytes, - bool ordered = false, - std::string app_metadata = ""); - - /// \brief Deserialize the Arrow schema of the dataset. Populate any - /// dictionary encoded fields into a DictionaryMemo for - /// bookkeeping - /// \param[in,out] dictionary_memo for dictionary bookkeeping, will - /// be modified - /// \return Arrow result with the reconstructed Schema - arrow::Result> GetSchema( - ipc::DictionaryMemo* dictionary_memo) const; - - const std::string& serialized_schema() const { return data_.schema; } - - /// The descriptor associated with this flight, may not be set - const FlightDescriptor& descriptor() const { return data_.descriptor; } - - /// A list of endpoints associated with the flight (dataset). To consume the - /// whole flight, all endpoints must be consumed - const std::vector& endpoints() const { return data_.endpoints; } - - /// The total number of records (rows) in the dataset. If unknown, set to -1 - int64_t total_records() const { return data_.total_records; } - - /// The total number of bytes in the dataset. 
If unknown, set to -1 - int64_t total_bytes() const { return data_.total_bytes; } - - /// Whether endpoints are in the same order as the data. - bool ordered() const { return data_.ordered; } - - /// Application-defined opaque metadata - const std::string& app_metadata() const { return data_.app_metadata; } - - /// \brief Get the wire-format representation of this type. - /// - /// Useful when interoperating with non-Flight systems (e.g. REST - /// services) that may want to return Flight types. - arrow::Result SerializeToString() const; - - /// \brief Parse the wire-format representation of this type. - /// - /// Useful when interoperating with non-Flight systems (e.g. REST - /// services) that may want to return Flight types. - static arrow::Result> Deserialize( - std::string_view serialized); - - std::string ToString() const; - - /// Compare two FlightInfo for equality. This will compare the - /// serialized schema representations, NOT the logical equality of - /// the schemas. - bool Equals(const FlightInfo& other) const; - - friend bool operator==(const FlightInfo& left, const FlightInfo& right) { - return left.Equals(right); - } - friend bool operator!=(const FlightInfo& left, const FlightInfo& right) { - return !(left == right); - } - - private: - Data data_; - mutable std::shared_ptr schema_; - mutable bool reconstructed_schema_; -}; - -/// \brief The information to process a long-running query. -class ARROW_FLIGHT_EXPORT PollInfo { - public: - /// The currently available results so far. - std::unique_ptr info = NULLPTR; - /// The descriptor the client should use on the next try. If unset, - /// the query is complete. - std::optional descriptor = std::nullopt; - /// Query progress. Must be in [0.0, 1.0] but need not be - /// monotonic or nondecreasing. If unknown, do not set. - std::optional progress = std::nullopt; - /// Expiration time for this request. After this passes, the server - /// might not accept the poll descriptor anymore (and the query may - /// be cancelled). This may be updated on a call to PollFlightInfo. - std::optional expiration_time = std::nullopt; - - PollInfo() - : info(NULLPTR), - descriptor(std::nullopt), - progress(std::nullopt), - expiration_time(std::nullopt) {} - - explicit PollInfo(std::unique_ptr info, - std::optional descriptor, - std::optional progress, - std::optional expiration_time) - : info(std::move(info)), - descriptor(std::move(descriptor)), - progress(progress), - expiration_time(expiration_time) {} - - // Must not be explicit; to declare one we must declare all ("rule of five") - PollInfo(const PollInfo& other) // NOLINT(runtime/explicit) - : info(other.info ? std::make_unique(*other.info) : NULLPTR), - descriptor(other.descriptor), - progress(other.progress), - expiration_time(other.expiration_time) {} - PollInfo(PollInfo&& other) noexcept = default; // NOLINT(runtime/explicit) - ~PollInfo() = default; - PollInfo& operator=(const PollInfo& other) { - info = other.info ? std::make_unique(*other.info) : NULLPTR; - descriptor = other.descriptor; - progress = other.progress; - expiration_time = other.expiration_time; - return *this; - } - PollInfo& operator=(PollInfo&& other) = default; - - /// \brief Get the wire-format representation of this type. - /// - /// Useful when interoperating with non-Flight systems (e.g. REST - /// services) that may want to return Flight types. - arrow::Result SerializeToString() const; - - /// \brief Parse the wire-format representation of this type. 
-  ///
-  /// Useful when interoperating with non-Flight systems (e.g. REST
-  /// services) that may want to return Flight types.
-  static arrow::Result<std::unique_ptr<PollInfo>> Deserialize(
-      std::string_view serialized);
-
-  std::string ToString() const;
-
-  /// Compare two PollInfo for equality. This will compare the
-  /// serialized schema representations, NOT the logical equality of
-  /// the schemas.
-  bool Equals(const PollInfo& other) const;
+/// \brief Variant supporting all possible value types for {Set,Get}SessionOptions
+///
+/// By convention, an attempt to set a valueless (std::monostate) SessionOptionValue
+/// should attempt to unset or clear the named option value on the server.
+using SessionOptionValue = std::variant<std::monostate, std::string, bool, int64_t,
+                                        double, std::vector<std::string>>;
+std::ostream& operator<<(std::ostream& os, const SessionOptionValue& v);
-  friend bool operator==(const PollInfo& left, const PollInfo& right) {
-    return left.Equals(right);
-  }
-  friend bool operator!=(const PollInfo& left, const PollInfo& right) {
-    return !(left == right);
-  }
-};
+/// \brief A request to set a set of session options by name/value.
+struct ARROW_FLIGHT_EXPORT SetSessionOptionsRequest
+    : public internal::BaseType<SetSessionOptionsRequest> {
+  std::map<std::string, SessionOptionValue> session_options;
-/// \brief The request of the CancelFlightInfoRequest action.
-struct ARROW_FLIGHT_EXPORT CancelFlightInfoRequest {
-  std::unique_ptr<FlightInfo> info;
+  SetSessionOptionsRequest() = default;
+  explicit SetSessionOptionsRequest(
+      std::map<std::string, SessionOptionValue> session_options)
+      : session_options(std::move(session_options)) {}

   std::string ToString() const;
-  bool Equals(const CancelFlightInfoRequest& other) const;
+  bool Equals(const SetSessionOptionsRequest& other) const;

-  friend bool operator==(const CancelFlightInfoRequest& left,
-                         const CancelFlightInfoRequest& right) {
-    return left.Equals(right);
-  }
-  friend bool operator!=(const CancelFlightInfoRequest& left,
-                         const CancelFlightInfoRequest& right) {
-    return !(left == right);
-  }
+  using SuperT::Deserialize;
+  using SuperT::SerializeToString;

   /// \brief Serialize this message to its wire-format representation.
-  arrow::Result<std::string> SerializeToString() const;
+  ///
+  /// Use `SerializeToString()` if you want a Result-returning version.
+  arrow::Status SerializeToString(std::string* out) const;

   /// \brief Deserialize this message from its wire-format representation.
-  static arrow::Result<CancelFlightInfoRequest> Deserialize(std::string_view serialized);
+  ///
+  /// Use `Deserialize(serialized)` if you want a Result-returning version.
+  static arrow::Status Deserialize(std::string_view serialized,
+                                   SetSessionOptionsRequest* out);
 };

-/// \brief Variant supporting all possible value types for {Set,Get}SessionOptions
-///
-/// By convention, an attempt to set a valueless (std::monostate) SessionOptionValue
-/// should attempt to unset or clear the named option value on the server.
-using SessionOptionValue = std::variant<std::monostate, std::string, bool, int64_t,
-                                        double, std::vector<std::string>>;
-
 /// \brief The result of setting a session option.
 enum class SetSessionOptionErrorValue : int8_t {
   /// \brief The status of setting the option is unknown.
@@ -797,54 +956,9 @@ enum class SetSessionOptionErrorValue : int8_t {
 std::string ToString(const SetSessionOptionErrorValue& error_value);
 std::ostream& operator<<(std::ostream& os,
                          const SetSessionOptionErrorValue& error_value);
-/// \brief The result of closing a session.
-enum class CloseSessionStatus : int8_t {
-  // \brief The session close status is unknown.
-  //
-  // Servers should avoid using this value (send a NOT_FOUND error if the requested
-  // session is not known). Clients can retry the request.
- kUnspecified, - // \brief The session close request is complete. - // - // Subsequent requests with the same session produce a NOT_FOUND error. - kClosed, - // \brief The session close request is in progress. - // - // The client may retry the request. - kClosing, - // \brief The session is not closeable. - // - // The client should not retry the request. - kNotClosable -}; -std::string ToString(const CloseSessionStatus& status); -std::ostream& operator<<(std::ostream& os, const CloseSessionStatus& status); - -/// \brief A request to set a set of session options by name/value. -struct ARROW_FLIGHT_EXPORT SetSessionOptionsRequest { - std::map session_options; - - std::string ToString() const; - bool Equals(const SetSessionOptionsRequest& other) const; - - friend bool operator==(const SetSessionOptionsRequest& left, - const SetSessionOptionsRequest& right) { - return left.Equals(right); - } - friend bool operator!=(const SetSessionOptionsRequest& left, - const SetSessionOptionsRequest& right) { - return !(left == right); - } - - /// \brief Serialize this message to its wire-format representation. - arrow::Result SerializeToString() const; - - /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(std::string_view serialized); -}; - /// \brief The result(s) of setting session option(s). -struct ARROW_FLIGHT_EXPORT SetSessionOptionsResult { +struct ARROW_FLIGHT_EXPORT SetSessionOptionsResult + : public internal::BaseType { struct Error { SetSessionOptionErrorValue value; @@ -859,113 +973,152 @@ struct ARROW_FLIGHT_EXPORT SetSessionOptionsResult { std::map errors; + SetSessionOptionsResult() = default; + SetSessionOptionsResult(std::map errors) // NOLINT runtime/explicit + : errors(std::move(errors)) {} + std::string ToString() const; bool Equals(const SetSessionOptionsResult& other) const; - friend bool operator==(const SetSessionOptionsResult& left, - const SetSessionOptionsResult& right) { - return left.Equals(right); - } - friend bool operator!=(const SetSessionOptionsResult& left, - const SetSessionOptionsResult& right) { - return !(left == right); - } + using SuperT::Deserialize; + using SuperT::SerializeToString; /// \brief Serialize this message to its wire-format representation. - arrow::Result SerializeToString() const; + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(std::string_view serialized); + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, + SetSessionOptionsResult* out); }; /// \brief A request to get current session options. -struct ARROW_FLIGHT_EXPORT GetSessionOptionsRequest { +struct ARROW_FLIGHT_EXPORT GetSessionOptionsRequest + : public internal::BaseType { + GetSessionOptionsRequest() = default; + std::string ToString() const; bool Equals(const GetSessionOptionsRequest& other) const; - friend bool operator==(const GetSessionOptionsRequest& left, - const GetSessionOptionsRequest& right) { - return left.Equals(right); - } - friend bool operator!=(const GetSessionOptionsRequest& left, - const GetSessionOptionsRequest& right) { - return !(left == right); - } + using SuperT::Deserialize; + using SuperT::SerializeToString; /// \brief Serialize this message to its wire-format representation. 
- arrow::Result SerializeToString() const; + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(std::string_view serialized); + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, + GetSessionOptionsRequest* out); }; /// \brief The current session options. -struct ARROW_FLIGHT_EXPORT GetSessionOptionsResult { +struct ARROW_FLIGHT_EXPORT GetSessionOptionsResult + : public internal::BaseType { std::map session_options; + GetSessionOptionsResult() = default; + GetSessionOptionsResult( // NOLINT runtime/explicit + std::map session_options) + : session_options(std::move(session_options)) {} + std::string ToString() const; bool Equals(const GetSessionOptionsResult& other) const; - friend bool operator==(const GetSessionOptionsResult& left, - const GetSessionOptionsResult& right) { - return left.Equals(right); - } - friend bool operator!=(const GetSessionOptionsResult& left, - const GetSessionOptionsResult& right) { - return !(left == right); - } + using SuperT::Deserialize; + using SuperT::SerializeToString; /// \brief Serialize this message to its wire-format representation. - arrow::Result SerializeToString() const; + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(std::string_view serialized); + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, + GetSessionOptionsResult* out); }; /// \brief A request to close the open client session. -struct ARROW_FLIGHT_EXPORT CloseSessionRequest { +struct ARROW_FLIGHT_EXPORT CloseSessionRequest + : public internal::BaseType { + CloseSessionRequest() = default; + std::string ToString() const; bool Equals(const CloseSessionRequest& other) const; - friend bool operator==(const CloseSessionRequest& left, - const CloseSessionRequest& right) { - return left.Equals(right); - } - friend bool operator!=(const CloseSessionRequest& left, - const CloseSessionRequest& right) { - return !(left == right); - } + using SuperT::Deserialize; + using SuperT::SerializeToString; /// \brief Serialize this message to its wire-format representation. - arrow::Result SerializeToString() const; + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(std::string_view serialized); + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, CloseSessionRequest* out); +}; + +/// \brief The result of closing a session. +enum class CloseSessionStatus : int8_t { + // \brief The session close status is unknown. + // + // Servers should avoid using this value (send a NOT_FOUND error if the requested + // session is not known). Clients can retry the request. + kUnspecified, + // \brief The session close request is complete. + // + // Subsequent requests with the same session produce a NOT_FOUND error. 
+  kClosed,
+  // \brief The session close request is in progress.
+  //
+  // The client may retry the request.
+  kClosing,
+  // \brief The session is not closeable.
+  //
+  // The client should not retry the request.
+  kNotClosable
+};
+std::string ToString(const CloseSessionStatus& status);
+std::ostream& operator<<(std::ostream& os, const CloseSessionStatus& status);

 /// \brief The result of attempting to close the client session.
-struct ARROW_FLIGHT_EXPORT CloseSessionResult {
+struct ARROW_FLIGHT_EXPORT CloseSessionResult
+    : public internal::BaseType<CloseSessionResult> {
   CloseSessionStatus status;

+  CloseSessionResult() = default;
+  CloseSessionResult(CloseSessionStatus status)  // NOLINT runtime/explicit
+      : status(status) {}
+
   std::string ToString() const;
   bool Equals(const CloseSessionResult& other) const;

-  friend bool operator==(const CloseSessionResult& left,
-                         const CloseSessionResult& right) {
-    return left.Equals(right);
-  }
-  friend bool operator!=(const CloseSessionResult& left,
-                         const CloseSessionResult& right) {
-    return !(left == right);
-  }
+  using SuperT::Deserialize;
+  using SuperT::SerializeToString;

   /// \brief Serialize this message to its wire-format representation.
-  arrow::Result<std::string> SerializeToString() const;
+  ///
+  /// Use `SerializeToString()` if you want a Result-returning version.
+  arrow::Status SerializeToString(std::string* out) const;

   /// \brief Deserialize this message from its wire-format representation.
-  static arrow::Result<CloseSessionResult> Deserialize(std::string_view serialized);
+  ///
+  /// Use `Deserialize(serialized)` if you want a Result-returning version.
+  static arrow::Status Deserialize(std::string_view serialized, CloseSessionResult* out);
 };

+//------------------------------------------------------------
+
 /// \brief An iterator to FlightInfo instances returned by ListFlights.
class ARROW_FLIGHT_EXPORT FlightListing { public: diff --git a/cpp/src/parquet/encryption/encryption_internal.cc b/cpp/src/parquet/encryption/encryption_internal.cc index 6168dd2a9bd61..99d1707f4a8d4 100644 --- a/cpp/src/parquet/encryption/encryption_internal.cc +++ b/cpp/src/parquet/encryption/encryption_internal.cc @@ -469,23 +469,20 @@ AesDecryptor::AesDecryptorImpl::AesDecryptorImpl(ParquetCipher::type alg_id, int } } -AesEncryptor* AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, bool metadata, - std::vector* all_encryptors) { - return Make(alg_id, key_len, metadata, true /*write_length*/, all_encryptors); +std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, + bool metadata) { + return Make(alg_id, key_len, metadata, true /*write_length*/); } -AesEncryptor* AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, bool metadata, - bool write_length, - std::vector* all_encryptors) { +std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, + bool metadata, bool write_length) { if (ParquetCipher::AES_GCM_V1 != alg_id && ParquetCipher::AES_GCM_CTR_V1 != alg_id) { std::stringstream ss; ss << "Crypto algorithm " << alg_id << " is not supported"; throw ParquetException(ss.str()); } - AesEncryptor* encryptor = new AesEncryptor(alg_id, key_len, metadata, write_length); - if (all_encryptors != nullptr) all_encryptors->push_back(encryptor); - return encryptor; + return std::make_unique(alg_id, key_len, metadata, write_length); } AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int key_len, bool metadata, diff --git a/cpp/src/parquet/encryption/encryption_internal.h b/cpp/src/parquet/encryption/encryption_internal.h index a9a17f1ab98e3..c874b137ad1ad 100644 --- a/cpp/src/parquet/encryption/encryption_internal.h +++ b/cpp/src/parquet/encryption/encryption_internal.h @@ -52,12 +52,11 @@ class PARQUET_EXPORT AesEncryptor { explicit AesEncryptor(ParquetCipher::type alg_id, int key_len, bool metadata, bool write_length = true); - static AesEncryptor* Make(ParquetCipher::type alg_id, int key_len, bool metadata, - std::vector* all_encryptors); + static std::unique_ptr Make(ParquetCipher::type alg_id, int key_len, + bool metadata); - static AesEncryptor* Make(ParquetCipher::type alg_id, int key_len, bool metadata, - bool write_length, - std::vector* all_encryptors); + static std::unique_ptr Make(ParquetCipher::type alg_id, int key_len, + bool metadata, bool write_length); ~AesEncryptor(); diff --git a/cpp/src/parquet/encryption/encryption_internal_nossl.cc b/cpp/src/parquet/encryption/encryption_internal_nossl.cc index 2f6cdc8200016..2cce83915d7e5 100644 --- a/cpp/src/parquet/encryption/encryption_internal_nossl.cc +++ b/cpp/src/parquet/encryption/encryption_internal_nossl.cc @@ -72,14 +72,15 @@ void AesDecryptor::WipeOut() { ThrowOpenSSLRequiredException(); } AesDecryptor::~AesDecryptor() {} -AesEncryptor* AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, bool metadata, - std::vector* all_encryptors) { +std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, + bool metadata) { + ThrowOpenSSLRequiredException(); return NULLPTR; } -AesEncryptor* AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, bool metadata, - bool write_length, - std::vector* all_encryptors) { +std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, + bool metadata, bool write_length) { + ThrowOpenSSLRequiredException(); return NULLPTR; } @@ -91,6 +92,7 @@ AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int 
key_len, bool metadat std::shared_ptr AesDecryptor::Make( ParquetCipher::type alg_id, int key_len, bool metadata, std::vector>* all_decryptors) { + ThrowOpenSSLRequiredException(); return NULLPTR; } diff --git a/cpp/src/parquet/encryption/internal_file_decryptor.cc b/cpp/src/parquet/encryption/internal_file_decryptor.cc index a900a4d2eb094..fae5ce1f7a809 100644 --- a/cpp/src/parquet/encryption/internal_file_decryptor.cc +++ b/cpp/src/parquet/encryption/internal_file_decryptor.cc @@ -27,7 +27,7 @@ namespace parquet { Decryptor::Decryptor(std::shared_ptr aes_decryptor, const std::string& key, const std::string& file_aad, const std::string& aad, ::arrow::MemoryPool* pool) - : aes_decryptor_(aes_decryptor), + : aes_decryptor_(std::move(aes_decryptor)), key_(key), file_aad_(file_aad), aad_(aad), @@ -156,9 +156,9 @@ std::shared_ptr InternalFileDecryptor::GetFooterDecryptor( } footer_metadata_decryptor_ = std::make_shared( - aes_metadata_decryptor, footer_key, file_aad_, aad, pool_); - footer_data_decryptor_ = - std::make_shared(aes_data_decryptor, footer_key, file_aad_, aad, pool_); + std::move(aes_metadata_decryptor), footer_key, file_aad_, aad, pool_); + footer_data_decryptor_ = std::make_shared(std::move(aes_data_decryptor), + footer_key, file_aad_, aad, pool_); if (metadata) return footer_metadata_decryptor_; return footer_data_decryptor_; diff --git a/cpp/src/parquet/encryption/internal_file_encryptor.cc b/cpp/src/parquet/encryption/internal_file_encryptor.cc index a423cc678cccb..285c2100be813 100644 --- a/cpp/src/parquet/encryption/internal_file_encryptor.cc +++ b/cpp/src/parquet/encryption/internal_file_encryptor.cc @@ -53,8 +53,15 @@ InternalFileEncryptor::InternalFileEncryptor(FileEncryptionProperties* propertie void InternalFileEncryptor::WipeOutEncryptionKeys() { properties_->WipeOutEncryptionKeys(); - for (auto const& i : all_encryptors_) { - i->WipeOut(); + for (auto const& i : meta_encryptor_) { + if (i != nullptr) { + i->WipeOut(); + } + } + for (auto const& i : data_encryptor_) { + if (i != nullptr) { + i->WipeOut(); + } } } @@ -136,7 +143,7 @@ InternalFileEncryptor::InternalFileEncryptor::GetColumnEncryptor( return encryptor; } -int InternalFileEncryptor::MapKeyLenToEncryptorArrayIndex(int key_len) { +int InternalFileEncryptor::MapKeyLenToEncryptorArrayIndex(int key_len) const { if (key_len == 16) return 0; else if (key_len == 24) @@ -151,8 +158,7 @@ encryption::AesEncryptor* InternalFileEncryptor::GetMetaAesEncryptor( int key_len = static_cast(key_size); int index = MapKeyLenToEncryptorArrayIndex(key_len); if (meta_encryptor_[index] == nullptr) { - meta_encryptor_[index].reset( - encryption::AesEncryptor::Make(algorithm, key_len, true, &all_encryptors_)); + meta_encryptor_[index] = encryption::AesEncryptor::Make(algorithm, key_len, true); } return meta_encryptor_[index].get(); } @@ -162,8 +168,7 @@ encryption::AesEncryptor* InternalFileEncryptor::GetDataAesEncryptor( int key_len = static_cast(key_size); int index = MapKeyLenToEncryptorArrayIndex(key_len); if (data_encryptor_[index] == nullptr) { - data_encryptor_[index].reset( - encryption::AesEncryptor::Make(algorithm, key_len, false, &all_encryptors_)); + data_encryptor_[index] = encryption::AesEncryptor::Make(algorithm, key_len, false); } return data_encryptor_[index].get(); } diff --git a/cpp/src/parquet/encryption/internal_file_encryptor.h b/cpp/src/parquet/encryption/internal_file_encryptor.h index 41ffc6fd51943..91b6e9fe5aa2f 100644 --- a/cpp/src/parquet/encryption/internal_file_encryptor.h +++ 
b/cpp/src/parquet/encryption/internal_file_encryptor.h @@ -88,8 +88,6 @@ class InternalFileEncryptor { std::shared_ptr footer_signing_encryptor_; std::shared_ptr footer_encryptor_; - std::vector all_encryptors_; - // Key must be 16, 24 or 32 bytes in length. Thus there could be up to three // types of meta_encryptors and data_encryptors. std::unique_ptr meta_encryptor_[3]; @@ -105,7 +103,7 @@ class InternalFileEncryptor { encryption::AesEncryptor* GetDataAesEncryptor(ParquetCipher::type algorithm, size_t key_len); - int MapKeyLenToEncryptorArrayIndex(int key_len); + int MapKeyLenToEncryptorArrayIndex(int key_len) const; }; } // namespace parquet diff --git a/cpp/src/parquet/metadata.cc b/cpp/src/parquet/metadata.cc index 4ea3b05340d71..139793219df90 100644 --- a/cpp/src/parquet/metadata.cc +++ b/cpp/src/parquet/metadata.cc @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include #include @@ -29,6 +31,7 @@ #include "arrow/io/memory.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" +#include "arrow/util/pcg_random.h" #include "parquet/encryption/encryption_internal.h" #include "parquet/encryption/internal_file_decryptor.h" #include "parquet/exception.h" @@ -599,6 +602,49 @@ std::vector RowGroupMetaData::sorting_columns() const { return impl_->sorting_columns(); } +// Replace string data with random-generated uppercase characters +static void Scrub(std::string* s) { + static ::arrow::random::pcg64 rng; + std::uniform_int_distribution<> caps(65, 90); + for (auto& c : *s) c = caps(rng); +} + +// Replace potentially sensitive metadata with random data +static void Scrub(format::FileMetaData* md) { + for (auto& s : md->schema) { + Scrub(&s.name); + } + for (auto& r : md->row_groups) { + for (auto& c : r.columns) { + Scrub(&c.file_path); + if (c.__isset.meta_data) { + auto& m = c.meta_data; + for (auto& p : m.path_in_schema) Scrub(&p); + for (auto& kv : m.key_value_metadata) { + Scrub(&kv.key); + Scrub(&kv.value); + } + Scrub(&m.statistics.max_value); + Scrub(&m.statistics.min_value); + Scrub(&m.statistics.min); + Scrub(&m.statistics.max); + } + + if (c.crypto_metadata.__isset.ENCRYPTION_WITH_COLUMN_KEY) { + auto& m = c.crypto_metadata.ENCRYPTION_WITH_COLUMN_KEY; + for (auto& p : m.path_in_schema) Scrub(&p); + Scrub(&m.key_metadata); + } + Scrub(&c.encrypted_column_metadata); + } + } + for (auto& kv : md->key_value_metadata) { + Scrub(&kv.key); + Scrub(&kv.value); + } + Scrub(&md->footer_signing_key_metadata); +} + // file metadata class FileMetaData::FileMetaDataImpl { public: @@ -651,9 +697,9 @@ class FileMetaData::FileMetaDataImpl { std::string key = file_decryptor_->GetFooterKey(); std::string aad = encryption::CreateFooterAad(file_decryptor_->file_aad()); - auto aes_encryptor = encryption::AesEncryptor::Make( - file_decryptor_->algorithm(), static_cast(key.size()), true, - false /*write_length*/, nullptr); + auto aes_encryptor = encryption::AesEncryptor::Make(file_decryptor_->algorithm(), + static_cast(key.size()), + true, false /*write_length*/); std::shared_ptr encrypted_buffer = AllocateBuffer( file_decryptor_->pool(), aes_encryptor->CiphertextLength(serialized_len)); @@ -662,7 +708,6 @@ class FileMetaData::FileMetaDataImpl { encrypted_buffer->mutable_span_as()); // Delete AES encryptor object. It was created only to verify the footer signature. 
aes_encryptor->WipeOut(); - delete aes_encryptor; return 0 == memcmp(encrypted_buffer->data() + encrypted_len - encryption::kGcmTagLength, tag, encryption::kGcmTagLength); @@ -822,6 +867,21 @@ class FileMetaData::FileMetaDataImpl { return out; } + std::string SerializeUnencrypted(bool scrub, bool debug) const { + auto md = *metadata_; + if (scrub) Scrub(&md); + if (debug) { + std::ostringstream ss; + md.printTo(ss); + return ss.str(); + } else { + ThriftSerializer serializer; + std::string out; + serializer.SerializeToString(&md, &out); + return out; + } + } + void set_file_decryptor(std::shared_ptr file_decryptor) { file_decryptor_ = std::move(file_decryptor); } @@ -993,6 +1053,10 @@ std::shared_ptr FileMetaData::Subset( return impl_->Subset(row_groups); } +std::string FileMetaData::SerializeUnencrypted(bool scrub, bool debug) const { + return impl_->SerializeUnencrypted(scrub, debug); +} + void FileMetaData::WriteTo(::arrow::io::OutputStream* dst, const std::shared_ptr& encryptor) const { return impl_->WriteTo(dst, encryptor); diff --git a/cpp/src/parquet/metadata.h b/cpp/src/parquet/metadata.h index 9fc30df58e0d3..e02d2e7c852f0 100644 --- a/cpp/src/parquet/metadata.h +++ b/cpp/src/parquet/metadata.h @@ -396,6 +396,13 @@ class PARQUET_EXPORT FileMetaData { /// FileMetaData. std::shared_ptr Subset(const std::vector& row_groups) const; + /// \brief Serialize metadata unencrypted as string + /// + /// \param[in] scrub whether to remove sensitive information from the metadata. + /// \param[in] debug whether to serialize the metadata as Thrift (if false) or + /// debug text (if true). + std::string SerializeUnencrypted(bool scrub, bool debug) const; + private: friend FileMetaDataBuilder; friend class SerializedFile; diff --git a/cpp/tools/parquet/CMakeLists.txt b/cpp/tools/parquet/CMakeLists.txt index 81ab49421d0f6..87c3254607589 100644 --- a/cpp/tools/parquet/CMakeLists.txt +++ b/cpp/tools/parquet/CMakeLists.txt @@ -16,7 +16,7 @@ # under the License. if(PARQUET_BUILD_EXECUTABLES) - set(PARQUET_TOOLS parquet-dump-schema parquet-reader parquet-scan) + set(PARQUET_TOOLS parquet-dump-footer parquet-dump-schema parquet-reader parquet-scan) foreach(TOOL ${PARQUET_TOOLS}) string(REGEX REPLACE "-" "_" TOOL_SOURCE ${TOOL}) diff --git a/cpp/tools/parquet/parquet_dump_footer.cc b/cpp/tools/parquet/parquet_dump_footer.cc new file mode 100644 index 0000000000000..4dd7476bc8ea3 --- /dev/null +++ b/cpp/tools/parquet/parquet_dump_footer.cc @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
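+// Editor's sketch, not part of the upstream file: the logic below relies on
+// the Parquet footer layout. A file ends with the Thrift-serialized
+// FileMetaData, its length as 4 little-endian bytes, then the magic "PAR1":
+//
+//   [ column chunks ... | FileMetaData (metadata_len bytes) | len (4, LE) | "PAR1" ]
+//
+// Reading the last 8 bytes is therefore enough to locate the entire footer,
+// which is why DoIt() starts with a single opportunistic read of the tail.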
+
+#include <cstdint>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <optional>
+
+#include "arrow/filesystem/filesystem.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/ubsan.h"
+#include "parquet/metadata.h"
+
+namespace parquet {
+namespace {
+uint32_t ReadLE32(const void* p) {
+  uint32_t x = ::arrow::util::SafeLoadAs<uint32_t>(static_cast<const uint8_t*>(p));
+  return ::arrow::bit_util::FromLittleEndian(x);
+}
+
+void AppendLE32(uint32_t v, std::string* out) {
+  v = ::arrow::bit_util::ToLittleEndian(v);
+  out->append(reinterpret_cast<const char*>(&v), sizeof(v));
+}
+
+int DoIt(std::string in, bool scrub, bool debug, std::string out) {
+  std::string path;
+  auto fs = ::arrow::fs::FileSystemFromUriOrPath(in, &path).ValueOrDie();
+  auto file = fs->OpenInputFile(path).ValueOrDie();
+  int64_t file_len = file->GetSize().ValueOrDie();
+  if (file_len < 8) {
+    std::cerr << "File too short: " << in << "\n";
+    return 3;
+  }
+  // First do an opportunistic read of up to 1 MiB to try to get the entire footer.
+  int64_t tail_len = std::min(file_len, int64_t{1} << 20);
+  std::string tail;
+  tail.resize(tail_len);
+  char* data = tail.data();
+  file->ReadAt(file_len - tail_len, tail_len, data).ValueOrDie();
+  if (auto magic = ReadLE32(data + tail_len - 4); magic != ReadLE32("PAR1")) {
+    std::cerr << "Not a Parquet file: " << in << "\n";
+    return 4;
+  }
+  uint32_t metadata_len = ReadLE32(data + tail_len - 8);
+  if (tail_len >= metadata_len + 8) {
+    // The footer is entirely in the initial read. Trim to size.
+    tail = tail.substr(tail_len - (metadata_len + 8));
+  } else {
+    // The footer is larger than the initial read; re-read with the exact size.
+    if (metadata_len > file_len) {
+      std::cerr << "File too short: " << in << "\n";
+      return 5;
+    }
+    tail_len = metadata_len + 8;
+    tail.resize(tail_len);
+    data = tail.data();
+    file->ReadAt(file_len - tail_len, tail_len, data).ValueOrDie();
+  }
+  auto md = FileMetaData::Make(tail.data(), &metadata_len);
+  std::string ser = md->SerializeUnencrypted(scrub, debug);
+  if (!debug) {
+    AppendLE32(static_cast<uint32_t>(ser.size()), &ser);
+    ser.append("PAR1", 4);
+  }
+  std::optional<std::ofstream> fout;
+  if (!out.empty()) fout.emplace(out, std::ios::out);
+  std::ostream& os = fout ? *fout : std::cout;
+  if (!os.write(ser.data(), ser.size())) {
+    std::cerr << "Failed to write to output file: " << out << "\n";
+    return 6;
+  }
+
+  return 0;
+}
+}  // namespace
+}  // namespace parquet
+
+static int PrintHelp() {
+  std::cerr << R"(Usage: parquet-dump-footer
+  -h|--help      Print help and exit
+  --no-scrub     Do not scrub potentially confidential metadata
+  --debug        Output a text representation of the footer for inspection
+  --in           Input file (required): must be a URI or an absolute local path
+  --out          Output file (optional, default stdout)
+
+  Dump the footer of a Parquet file to stdout or to a file, optionally with
+  potentially confidential metadata scrubbed.
+)";
+  return 1;
+}
+
+int main(int argc, char** argv) {
+  bool scrub = true;
+  bool debug = false;
+  std::string in;
+  std::string out;
+  for (int i = 1; i < argc; i++) {
+    char* arg = argv[i];
+    if (!std::strcmp(arg, "-h") || !std::strcmp(arg, "--help")) {
+      return PrintHelp();
+    } else if (!std::strcmp(arg, "--no-scrub")) {
+      scrub = false;
+    } else if (!std::strcmp(arg, "--debug")) {
+      debug = true;
+    } else if (!std::strcmp(arg, "--in")) {
+      if (i + 1 >= argc) return PrintHelp();
+      in = argv[++i];
+    } else if (!std::strcmp(arg, "--out")) {
+      if (i + 1 >= argc) return PrintHelp();
+      out = argv[++i];
+    } else {
+      // Unknown option.
+ return PrintHelp(); + } + } + if (in.empty()) return PrintHelp(); + + return parquet::DoIt(in, scrub, debug, out); +} diff --git a/csharp/README.md b/csharp/README.md index b36eb899db2d5..663aaf8ab243c 100644 --- a/csharp/README.md +++ b/csharp/README.md @@ -129,7 +129,8 @@ for currently available features. - Types - Tensor - Arrays - - Large Arrays + - Large Arrays. There are large array types provided to help with interoperability with other libraries, + but these do not support buffers larger than 2 GiB and an exception will be raised if trying to import an array that is too large. - Large Binary - Large List - Large String diff --git a/csharp/src/Apache.Arrow.Flight.AspNetCore/Apache.Arrow.Flight.AspNetCore.csproj b/csharp/src/Apache.Arrow.Flight.AspNetCore/Apache.Arrow.Flight.AspNetCore.csproj index 2dd1d9d8f98e2..ac1f8c9bae77a 100644 --- a/csharp/src/Apache.Arrow.Flight.AspNetCore/Apache.Arrow.Flight.AspNetCore.csproj +++ b/csharp/src/Apache.Arrow.Flight.AspNetCore/Apache.Arrow.Flight.AspNetCore.csproj @@ -5,7 +5,7 @@ - + diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj b/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj index ee6d42c8d17fc..1870888184906 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj +++ b/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj @@ -5,7 +5,7 @@ - + diff --git a/csharp/src/Apache.Arrow.Flight/Apache.Arrow.Flight.csproj b/csharp/src/Apache.Arrow.Flight/Apache.Arrow.Flight.csproj index 5030d37cdb16d..5334f877873e4 100644 --- a/csharp/src/Apache.Arrow.Flight/Apache.Arrow.Flight.csproj +++ b/csharp/src/Apache.Arrow.Flight/Apache.Arrow.Flight.csproj @@ -6,8 +6,8 @@ - - + + diff --git a/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs b/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs index 67c4b21a2e531..bd06c3a1b8b14 100644 --- a/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs +++ b/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs @@ -53,18 +53,24 @@ public static IArrowArray BuildArray(ArrayData data) return new StringArray(data); case ArrowTypeId.StringView: return new StringViewArray(data); + case ArrowTypeId.LargeString: + return new LargeStringArray(data); case ArrowTypeId.FixedSizedBinary: return new FixedSizeBinaryArray(data); case ArrowTypeId.Binary: return new BinaryArray(data); case ArrowTypeId.BinaryView: return new BinaryViewArray(data); + case ArrowTypeId.LargeBinary: + return new LargeBinaryArray(data); case ArrowTypeId.Timestamp: return new TimestampArray(data); case ArrowTypeId.List: return new ListArray(data); case ArrowTypeId.ListView: return new ListViewArray(data); + case ArrowTypeId.LargeList: + return new LargeListArray(data); case ArrowTypeId.Map: return new MapArray(data); case ArrowTypeId.Struct: diff --git a/csharp/src/Apache.Arrow/Arrays/LargeBinaryArray.cs b/csharp/src/Apache.Arrow/Arrays/LargeBinaryArray.cs new file mode 100644 index 0000000000000..9eddbedab54ed --- /dev/null +++ b/csharp/src/Apache.Arrow/Arrays/LargeBinaryArray.cs @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Apache.Arrow.Types; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Runtime.CompilerServices; + +namespace Apache.Arrow; + +public class LargeBinaryArray : Array, IReadOnlyList, ICollection +{ + public LargeBinaryArray(ArrayData data) + : base(data) + { + data.EnsureDataType(ArrowTypeId.LargeBinary); + data.EnsureBufferCount(3); + } + + public LargeBinaryArray(ArrowTypeId typeId, ArrayData data) + : base(data) + { + data.EnsureDataType(typeId); + data.EnsureBufferCount(3); + } + + public LargeBinaryArray(IArrowType dataType, int length, + ArrowBuffer valueOffsetsBuffer, + ArrowBuffer dataBuffer, + ArrowBuffer nullBitmapBuffer, + int nullCount = 0, int offset = 0) + : this(new ArrayData(dataType, length, nullCount, offset, + new[] { nullBitmapBuffer, valueOffsetsBuffer, dataBuffer })) + { } + + public override void Accept(IArrowArrayVisitor visitor) => Accept(this, visitor); + + public ArrowBuffer ValueOffsetsBuffer => Data.Buffers[1]; + + public ArrowBuffer ValueBuffer => Data.Buffers[2]; + + public ReadOnlySpan ValueOffsets => ValueOffsetsBuffer.Span.CastTo().Slice(Offset, Length + 1); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int GetValueLength(int index) + { + if (index < 0 || index >= Length) + { + throw new ArgumentOutOfRangeException(nameof(index)); + } + if (!IsValid(index)) + { + return 0; + } + + ReadOnlySpan offsets = ValueOffsets; + return checked((int)(offsets[index + 1] - offsets[index])); + } + + /// + /// Get the collection of bytes, as a read-only span, at a given index in the array. + /// + /// + /// Note that this method cannot reliably identify null values, which are indistinguishable from empty byte + /// collection values when seen in the context of this method's return type of . + /// Use the method or the overload instead + /// to reliably determine null values. + /// + /// Index at which to get bytes. + /// Returns a object. + /// If the index is negative or beyond the length of the array. + /// + public ReadOnlySpan GetBytes(int index) => GetBytes(index, out _); + + /// + /// Get the collection of bytes, as a read-only span, at a given index in the array. + /// + /// Index at which to get bytes. + /// Set to if the value at the given index is null. + /// Returns a object. + /// If the index is negative or beyond the length of the array. + /// + public ReadOnlySpan GetBytes(int index, out bool isNull) + { + if (index < 0 || index >= Length) + { + throw new ArgumentOutOfRangeException(nameof(index)); + } + + isNull = IsNull(index); + + if (isNull) + { + // Note that `return null;` is valid syntax, but would be misleading as `null` in the context of a span + // is actually returned as an empty span. 
+ return ReadOnlySpan.Empty; + } + + var offset = checked((int)ValueOffsets[index]); + return ValueBuffer.Span.Slice(offset, GetValueLength(index)); + } + + int IReadOnlyCollection.Count => Length; + + byte[] IReadOnlyList.this[int index] => GetBytes(index).ToArray(); + + IEnumerator IEnumerable.GetEnumerator() + { + for (int index = 0; index < Length; index++) + { + yield return GetBytes(index).ToArray(); + } + } + + IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)this).GetEnumerator(); + + int ICollection.Count => Length; + bool ICollection.IsReadOnly => true; + void ICollection.Add(byte[] item) => throw new NotSupportedException("Collection is read-only."); + bool ICollection.Remove(byte[] item) => throw new NotSupportedException("Collection is read-only."); + void ICollection.Clear() => throw new NotSupportedException("Collection is read-only."); + + bool ICollection.Contains(byte[] item) + { + for (int index = 0; index < Length; index++) + { + if (GetBytes(index).SequenceEqual(item)) + return true; + } + + return false; + } + + void ICollection.CopyTo(byte[][] array, int arrayIndex) + { + for (int srcIndex = 0, destIndex = arrayIndex; srcIndex < Length; srcIndex++, destIndex++) + { + array[destIndex] = GetBytes(srcIndex).ToArray(); + } + } +} diff --git a/csharp/src/Apache.Arrow/Arrays/LargeListArray.cs b/csharp/src/Apache.Arrow/Arrays/LargeListArray.cs new file mode 100644 index 0000000000000..6e37aa4c63536 --- /dev/null +++ b/csharp/src/Apache.Arrow/Arrays/LargeListArray.cs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
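+// Editor's note, not part of the upstream file: large-list offsets are 64-bit
+// on the wire, but this implementation addresses memory through 32-bit spans,
+// so offset arithmetic goes through checked() and any slot reaching past
+// int.MaxValue throws OverflowException instead of silently truncating:
+//
+//   long width = offsets[i + 1] - offsets[i];
+//   int length = checked((int)width);  // throws rather than wrapping
+//
+// This matches the README caveat that the C# large array types do not support
+// buffers over 2 GiB.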
diff --git a/csharp/src/Apache.Arrow/Arrays/LargeListArray.cs b/csharp/src/Apache.Arrow/Arrays/LargeListArray.cs
new file mode 100644
index 0000000000000..6e37aa4c63536
--- /dev/null
+++ b/csharp/src/Apache.Arrow/Arrays/LargeListArray.cs
@@ -0,0 +1,97 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using Apache.Arrow.Types;
+
+namespace Apache.Arrow
+{
+    public class LargeListArray : Array
+    {
+        public IArrowArray Values { get; }
+
+        public ArrowBuffer ValueOffsetsBuffer => Data.Buffers[1];
+
+        public ReadOnlySpan<long> ValueOffsets => ValueOffsetsBuffer.Span.CastTo<long>().Slice(Offset, Length + 1);
+
+        public LargeListArray(IArrowType dataType, int length,
+            ArrowBuffer valueOffsetsBuffer, IArrowArray values,
+            ArrowBuffer nullBitmapBuffer, int nullCount = 0, int offset = 0)
+            : this(new ArrayData(dataType, length, nullCount, offset,
+                new[] { nullBitmapBuffer, valueOffsetsBuffer }, new[] { values.Data }),
+                values)
+        {
+        }
+
+        public LargeListArray(ArrayData data)
+            : this(data, ArrowArrayFactory.BuildArray(data.Children[0]))
+        {
+        }
+
+        private LargeListArray(ArrayData data, IArrowArray values) : base(data)
+        {
+            data.EnsureBufferCount(2);
+            data.EnsureDataType(ArrowTypeId.LargeList);
+            Values = values;
+        }
+
+        public override void Accept(IArrowArrayVisitor visitor) => Accept(this, visitor);
+
+        public int GetValueLength(int index)
+        {
+            if (index < 0 || index >= Length)
+            {
+                throw new ArgumentOutOfRangeException(nameof(index));
+            }
+
+            if (IsNull(index))
+            {
+                return 0;
+            }
+
+            ReadOnlySpan<long> offsets = ValueOffsets;
+            return checked((int)(offsets[index + 1] - offsets[index]));
+        }
+
+        public IArrowArray GetSlicedValues(int index)
+        {
+            if (index < 0 || index >= Length)
+            {
+                throw new ArgumentOutOfRangeException(nameof(index));
+            }
+
+            if (IsNull(index))
+            {
+                return null;
+            }
+
+            if (!(Values is Array array))
+            {
+                return default;
+            }
+
+            return array.Slice(checked((int)ValueOffsets[index]), GetValueLength(index));
+        }
+
+        protected override void Dispose(bool disposing)
+        {
+            if (disposing)
+            {
+                Values?.Dispose();
+            }
+            base.Dispose(disposing);
+        }
+    }
+}
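
GetSlicedValues is the list-specific accessor: it returns the child-array slice backing one list entry, or null for a null entry. A short sketch against a hypothetical largeListArray holding [[0, 1, 2], null] (the tests below build exactly this shape):

    var first = (Int32Array)largeListArray.GetSlicedValues(0);  // child slice [0, 1, 2]
    var second = largeListArray.GetSlicedValues(1);             // null entry => null
    int firstLength = largeListArray.GetValueLength(0);         // 3
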
diff --git a/csharp/src/Apache.Arrow/Arrays/LargeStringArray.cs b/csharp/src/Apache.Arrow/Arrays/LargeStringArray.cs
new file mode 100644
index 0000000000000..2a65b828acfa1
--- /dev/null
+++ b/csharp/src/Apache.Arrow/Arrays/LargeStringArray.cs
@@ -0,0 +1,113 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+using System.Text;
+using Apache.Arrow.Types;
+
+namespace Apache.Arrow;
+
+public class LargeStringArray : LargeBinaryArray, IReadOnlyList<string>, ICollection<string>
+{
+    public static readonly Encoding DefaultEncoding = StringArray.DefaultEncoding;
+
+    public LargeStringArray(ArrayData data)
+        : base(ArrowTypeId.LargeString, data) { }
+
+    public LargeStringArray(int length,
+        ArrowBuffer valueOffsetsBuffer,
+        ArrowBuffer dataBuffer,
+        ArrowBuffer nullBitmapBuffer,
+        int nullCount = 0, int offset = 0)
+        : this(new ArrayData(LargeStringType.Default, length, nullCount, offset,
+            new[] { nullBitmapBuffer, valueOffsetsBuffer, dataBuffer }))
+    { }
+
+    public override void Accept(IArrowArrayVisitor visitor) => Accept(this, visitor);
+
+    /// <summary>
+    /// Get the string value at the given index
+    /// </summary>
+    /// <param name="index">Input index</param>
+    /// <param name="encoding">Optional: the string encoding, default is UTF8</param>
+    /// <returns>The string object at the given index</returns>
+    public string GetString(int index, Encoding encoding = default)
+    {
+        encoding ??= DefaultEncoding;
+
+        ReadOnlySpan<byte> bytes = GetBytes(index, out bool isNull);
+
+        if (isNull)
+        {
+            return null;
+        }
+
+        if (bytes.Length == 0)
+        {
+            return string.Empty;
+        }
+
+        unsafe
+        {
+            fixed (byte* data = &MemoryMarshal.GetReference(bytes))
+            {
+                return encoding.GetString(data, bytes.Length);
+            }
+        }
+    }
+
+    int IReadOnlyCollection<string>.Count => Length;
+
+    string IReadOnlyList<string>.this[int index] => GetString(index);
+
+    IEnumerator<string> IEnumerable<string>.GetEnumerator()
+    {
+        for (int index = 0; index < Length; index++)
+        {
+            yield return GetString(index);
+        }
+    }
+
+    IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable<string>)this).GetEnumerator();
+
+    int ICollection<string>.Count => Length;
+    bool ICollection<string>.IsReadOnly => true;
+    void ICollection<string>.Add(string item) => throw new NotSupportedException("Collection is read-only.");
+    bool ICollection<string>.Remove(string item) => throw new NotSupportedException("Collection is read-only.");
+    void ICollection<string>.Clear() => throw new NotSupportedException("Collection is read-only.");
+
+    bool ICollection<string>.Contains(string item)
+    {
+        for (int index = 0; index < Length; index++)
+        {
+            if (GetString(index) == item)
+                return true;
+        }
+
+        return false;
+    }
+
+    void ICollection<string>.CopyTo(string[] array, int arrayIndex)
+    {
+        for (int srcIndex = 0, destIndex = arrayIndex; srcIndex < Length; srcIndex++, destIndex++)
+        {
+            array[destIndex] = GetString(srcIndex);
+        }
+    }
+}
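
GetString is the only string-specific accessor; unlike the span-returning GetBytes, it can distinguish a null slot from an empty string. Sketch, assuming a hypothetical largeStringArray holding ["abc", null, ""]:

    string a = largeStringArray.GetString(0);  // "abc"
    string b = largeStringArray.GetString(1);  // null  (validity bit unset)
    string c = largeStringArray.GetString(2);  // ""    (valid but empty)
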
diff --git a/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs b/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs
index abe02dcbb591f..68b67f3d7c620 100644
--- a/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs
+++ b/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs
@@ -162,6 +162,10 @@ private ArrayData GetAsArrayData(CArrowArray* cArray, IArrowType type)
                 case ArrowTypeId.BinaryView:
                     buffers = ImportByteArrayViewBuffers(cArray);
                     break;
+                case ArrowTypeId.LargeString:
+                case ArrowTypeId.LargeBinary:
+                    buffers = ImportLargeByteArrayBuffers(cArray);
+                    break;
                 case ArrowTypeId.List:
                     children = ProcessListChildren(cArray, ((ListType)type).ValueDataType);
                     buffers = ImportListBuffers(cArray);
@@ -170,6 +174,10 @@ private ArrayData GetAsArrayData(CArrowArray* cArray, IArrowType type)
                     children = ProcessListChildren(cArray, ((ListViewType)type).ValueDataType);
                     buffers = ImportListViewBuffers(cArray);
                     break;
+                case ArrowTypeId.LargeList:
+                    children = ProcessListChildren(cArray, ((LargeListType)type).ValueDataType);
+                    buffers = ImportLargeListBuffers(cArray);
+                    break;
                 case ArrowTypeId.FixedSizeList:
                     children = ProcessListChildren(cArray, ((FixedSizeListType)type).ValueDataType);
                     buffers = ImportFixedSizeListBuffers(cArray);
@@ -313,6 +321,42 @@ private ArrowBuffer[] ImportByteArrayViewBuffers(CArrowArray* cArray)
             return buffers;
         }

+        private ArrowBuffer[] ImportLargeByteArrayBuffers(CArrowArray* cArray)
+        {
+            if (cArray->n_buffers != 3)
+            {
+                throw new InvalidOperationException("Large byte arrays are expected to have exactly three buffers");
+            }
+
+            const int maxLength = int.MaxValue / 8 - 1;
+            if (cArray->length > maxLength)
+            {
+                throw new OverflowException(
+                    $"Cannot import large byte array. Array length {cArray->length} " +
+                    $"is greater than the maximum supported large byte array length ({maxLength})");
+            }
+
+            int length = (int)cArray->length;
+            int offsetsLength = (length + 1) * 8;
+            long* offsets = (long*)cArray->buffers[1];
+            Debug.Assert(offsets != null);
+            long valuesLength = offsets[length];
+
+            if (valuesLength > int.MaxValue)
+            {
+                throw new OverflowException(
+                    $"Cannot import large byte array. Data length {valuesLength} " +
+                    $"is greater than the maximum supported large byte array data length ({int.MaxValue})");
+            }
+
+            ArrowBuffer[] buffers = new ArrowBuffer[3];
+            buffers[0] = ImportValidityBuffer(cArray);
+            buffers[1] = ImportCArrayBuffer(cArray, 1, offsetsLength);
+            buffers[2] = ImportCArrayBuffer(cArray, 2, (int)valuesLength);
+
+            return buffers;
+        }
+
         private ArrowBuffer[] ImportListBuffers(CArrowArray* cArray)
         {
             if (cArray->n_buffers != 2)
@@ -348,6 +392,31 @@ private ArrowBuffer[] ImportListViewBuffers(CArrowArray* cArray)
             return buffers;
         }

+        private ArrowBuffer[] ImportLargeListBuffers(CArrowArray* cArray)
+        {
+            if (cArray->n_buffers != 2)
+            {
+                throw new InvalidOperationException("Large list arrays are expected to have exactly two buffers");
+            }
+
+            const int maxLength = int.MaxValue / 8 - 1;
+            if (cArray->length > maxLength)
+            {
+                throw new OverflowException(
+                    $"Cannot import large list array.
Array length {cArray->length} " + + $"is greater than the maximum supported large list array length ({maxLength})"); + } + + int length = (int)cArray->length; + int offsetsLength = (length + 1) * 8; + + ArrowBuffer[] buffers = new ArrowBuffer[2]; + buffers[0] = ImportValidityBuffer(cArray); + buffers[1] = ImportCArrayBuffer(cArray, 1, offsetsLength); + + return buffers; + } + private ArrowBuffer[] ImportFixedSizeListBuffers(CArrowArray* cArray) { if (cArray->n_buffers != 1) diff --git a/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs b/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs index 3bb7134af3ba9..92d48a2d70880 100644 --- a/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs +++ b/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs @@ -168,8 +168,10 @@ private static string GetFormat(IArrowType datatype) // Binary case BinaryType _: return "z"; case BinaryViewType _: return "vz"; + case LargeBinaryType _: return "Z"; case StringType _: return "u"; case StringViewType _: return "vu"; + case LargeStringType _: return "U"; case FixedSizeBinaryType binaryType: return $"w:{binaryType.ByteWidth}"; // Date @@ -199,6 +201,7 @@ private static string GetFormat(IArrowType datatype) // Nested case ListType _: return "+l"; case ListViewType _: return "+vl"; + case LargeListType _: return "+L"; case FixedSizeListType fixedListType: return $"+w:{fixedListType.ListSize}"; case StructType _: return "+s"; @@ -208,7 +211,7 @@ private static string GetFormat(IArrowType datatype) case DictionaryType dictionaryType: return GetFormat(dictionaryType.IndexType); default: throw new NotImplementedException($"Exporting {datatype.Name} not implemented"); - }; + } } private static long GetFlags(IArrowType datatype, bool nullable = true) diff --git a/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs b/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs index f1acc007bcef7..94177184dea00 100644 --- a/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs +++ b/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs @@ -165,7 +165,7 @@ public ArrowType GetAsType() } // Special handling for nested types - if (format == "+l" || format == "+vl") + if (format == "+l" || format == "+vl" || format == "+L") { if (_cSchema->n_children != 1) { @@ -180,7 +180,13 @@ public ArrowType GetAsType() Field childField = childSchema.GetAsField(); - return format[1] == 'v' ? 
new ListViewType(childField) : new ListType(childField);
+            return format[1] switch
+            {
+                'l' => new ListType(childField),
+                'v' => new ListViewType(childField),
+                'L' => new LargeListType(childField),
+                _ => throw new InvalidDataException($"Invalid format for list: '{format}'"),
+            };
         }
         else if (format == "+s")
         {
@@ -304,10 +310,10 @@ public ArrowType GetAsType()
                 // Binary data
                 "z" => BinaryType.Default,
                 "vz" => BinaryViewType.Default,
-                //"Z" => new LargeBinaryType() // Not yet implemented
+                "Z" => LargeBinaryType.Default,
                 "u" => StringType.Default,
                 "vu" => StringViewType.Default,
-                //"U" => new LargeStringType(), // Not yet implemented
+                "U" => LargeStringType.Default,
                 // Date and time
                 "tdD" => Date32Type.Default,
                 "tdm" => Date64Type.Default,
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
index a37c501072f4b..7e766677f8b28 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
@@ -291,6 +291,8 @@ private ArrayData LoadField(
                     break;
                 case ArrowTypeId.String:
                 case ArrowTypeId.Binary:
+                case ArrowTypeId.LargeString:
+                case ArrowTypeId.LargeBinary:
                 case ArrowTypeId.ListView:
                     buffers = 3;
                     break;
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
index 5583a58487bf5..12a2a17cf04e2 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
@@ -132,7 +132,13 @@ protected ReadResult ReadMessage()

             Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff));

-            int bodyLength = checked((int)message.BodyLength);
+            if (message.BodyLength > int.MaxValue)
+            {
+                throw new OverflowException(
+                    $"Arrow IPC message body length ({message.BodyLength}) is larger than " +
+                    $"the maximum supported message size ({int.MaxValue})");
+            }
+            int bodyLength = (int)message.BodyLength;

             IMemoryOwner<byte> bodyBuffOwner = _allocator.Allocate(bodyLength);
             Memory<byte> bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength);
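
The guard above exists because the managed reader materializes a message body into a single Memory<byte>, and .NET spans and arrays are indexed with Int32. Illustrative arithmetic only:

    long bodyLength = 3_000_000_000;  // e.g. a ~3 GB body reported by a message header
    // Old path: checked((int)bodyLength) threw a bare OverflowException.
    // New path: the reader throws an OverflowException naming the body length
    // and the int.MaxValue ceiling before attempting any allocation.
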
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
index c66569afeba85..eaa8471fa7bd3 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
@@ -57,11 +57,14 @@ private class ArrowRecordBatchFlatBufferBuilder : IArrowArrayVisitor,
             IArrowArrayVisitor,
             IArrowArrayVisitor,
+            IArrowArrayVisitor<LargeListArray>,
             IArrowArrayVisitor,
             IArrowArrayVisitor,
             IArrowArrayVisitor,
+            IArrowArrayVisitor<LargeStringArray>,
             IArrowArrayVisitor,
             IArrowArrayVisitor,
+            IArrowArrayVisitor<LargeBinaryArray>,
             IArrowArrayVisitor,
             IArrowArrayVisitor,
             IArrowArrayVisitor,
             IArrowArrayVisitor,
@@ -199,6 +202,28 @@ public void Visit(ListViewArray array)
             VisitArray(values);
         }

+        public void Visit(LargeListArray array)
+        {
+            _buffers.Add(CreateBitmapBuffer(array.NullBitmapBuffer, array.Offset, array.Length));
+            _buffers.Add(CreateBuffer(GetZeroBasedLongValueOffsets(array.ValueOffsetsBuffer, array.Offset, array.Length)));
+
+            int valuesOffset = 0;
+            int valuesLength = 0;
+            if (array.Length > 0)
+            {
+                valuesOffset = checked((int)array.ValueOffsets[0]);
+                valuesLength = checked((int)array.ValueOffsets[array.Length] - valuesOffset);
+            }
+
+            var values = array.Values;
+            if (valuesOffset > 0 || valuesLength < values.Length)
+            {
+                values = ArrowArrayFactory.Slice(values, valuesOffset, valuesLength);
+            }
+
+            VisitArray(values);
+        }
+
         public void Visit(FixedSizeListArray array)
         {
             _buffers.Add(CreateBitmapBuffer(array.NullBitmapBuffer, array.Offset, array.Length));
@@ -214,6 +239,8 @@ public void Visit(FixedSizeListArray array)

         public void Visit(StringViewArray array) => Visit(array as BinaryViewArray);

+        public void Visit(LargeStringArray array) => Visit(array as LargeBinaryArray);
+
         public void Visit(BinaryArray array)
         {
             _buffers.Add(CreateBitmapBuffer(array.NullBitmapBuffer, array.Offset, array.Length));
@@ -242,6 +269,22 @@ public void Visit(BinaryViewArray array)
             VariadicCounts.Add(array.DataBufferCount);
         }

+        public void Visit(LargeBinaryArray array)
+        {
+            _buffers.Add(CreateBitmapBuffer(array.NullBitmapBuffer, array.Offset, array.Length));
+            _buffers.Add(CreateBuffer(GetZeroBasedLongValueOffsets(array.ValueOffsetsBuffer, array.Offset, array.Length)));
+
+            int valuesOffset = 0;
+            int valuesLength = 0;
+            if (array.Length > 0)
+            {
+                valuesOffset = checked((int)array.ValueOffsets[0]);
+                valuesLength = checked((int)array.ValueOffsets[array.Length]) - valuesOffset;
+            }
+
+            _buffers.Add(CreateSlicedBuffer(array.ValueBuffer, valuesOffset, valuesLength));
+        }
+
         public void Visit(FixedSizeBinaryArray array)
         {
             var itemSize = ((FixedSizeBinaryType)array.Data.DataType).ByteWidth;
@@ -327,6 +370,39 @@ private ArrowBuffer GetZeroBasedValueOffsets(ArrowBuffer valueOffsetsBuffer, int
             }
         }

+        private ArrowBuffer GetZeroBasedLongValueOffsets(ArrowBuffer valueOffsetsBuffer, int arrayOffset, int arrayLength)
+        {
+            var requiredBytes = CalculatePaddedBufferLength(checked(sizeof(long) * (arrayLength + 1)));
+
+            if (arrayOffset != 0)
+            {
+                // Array has been sliced, so we need to shift and adjust the offsets
+                var originalOffsets = valueOffsetsBuffer.Span.CastTo<long>().Slice(arrayOffset, arrayLength + 1);
+                var firstOffset = arrayLength > 0 ? originalOffsets[0] : 0L;
+
+                var newValueOffsetsBuffer = _allocator.Allocate(requiredBytes);
+                var newValueOffsets = newValueOffsetsBuffer.Memory.Span.CastTo<long>();
+
+                for (int i = 0; i < arrayLength + 1; ++i)
+                {
+                    newValueOffsets[i] = originalOffsets[i] - firstOffset;
+                }
+
+                return new ArrowBuffer(newValueOffsetsBuffer);
+            }
+            else if (valueOffsetsBuffer.Length > requiredBytes)
+            {
+                // Array may have been sliced but the offset is zero,
+                // so we can truncate the existing offsets
+                return new ArrowBuffer(valueOffsetsBuffer.Memory.Slice(0, requiredBytes));
+            }
+            else
+            {
+                // Use the full buffer
+                return valueOffsetsBuffer;
+            }
+        }
+
         private (ArrowBuffer Buffer, int minOffset, int maxEnd) GetZeroBasedListViewOffsets(ListViewArray array)
         {
             if (array.Length == 0)
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs
index 473e18968f8cb..adc229a051227 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs
@@ -57,6 +57,7 @@ class TypeVisitor : IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
+           IArrowTypeVisitor<LargeBinaryType>,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
@@ -65,9 +66,11 @@ class TypeVisitor : IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
+           IArrowTypeVisitor<LargeListType>,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
+           IArrowTypeVisitor<LargeStringType>,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
@@ -120,6 +123,14 @@ public void Visit(BinaryViewType type)
                 Flatbuf.Type.BinaryView, offset);
         }

+        public void Visit(LargeBinaryType type)
+        {
+            Flatbuf.LargeBinary.StartLargeBinary(Builder);
+            Offset<Flatbuf.LargeBinary> offset = Flatbuf.LargeBinary.EndLargeBinary(Builder);
+            Result = FieldType.Build(
+                Flatbuf.Type.LargeBinary, offset);
+        }
+
         public void Visit(ListType type)
         {
             Flatbuf.List.StartList(Builder);
@@ -136,6 +147,14 @@ public void Visit(ListViewType type)
                 Flatbuf.ListView.EndListView(Builder));
         }

+        public void Visit(LargeListType type)
+        {
+            Flatbuf.LargeList.StartLargeList(Builder);
+            Result = FieldType.Build(
+                Flatbuf.Type.LargeList,
+                Flatbuf.LargeList.EndLargeList(Builder));
+        }
+
         public void Visit(FixedSizeListType type)
         {
             Result = FieldType.Build(
@@ -166,6 +185,14 @@ public void Visit(StringViewType type)
                 Flatbuf.Type.Utf8View, offset);
         }

+        public void Visit(LargeStringType type)
+        {
+            Flatbuf.LargeUtf8.StartLargeUtf8(Builder);
+            Offset<Flatbuf.LargeUtf8> offset = Flatbuf.LargeUtf8.EndLargeUtf8(Builder);
+            Result = FieldType.Build(
+                Flatbuf.Type.LargeUtf8, offset);
+        }
+
         public void Visit(TimestampType type)
         {
             StringOffset timezoneStringOffset = default;
@@ -363,7 +390,7 @@ private static Flatbuf.IntervalUnit ToFlatBuffer(Types.IntervalUnit unit)
                 Types.IntervalUnit.DayTime => Flatbuf.IntervalUnit.DAY_TIME,
                 Types.IntervalUnit.MonthDayNanosecond => Flatbuf.IntervalUnit.MONTH_DAY_NANO,
                 _ => throw new ArgumentException($"unsupported interval unit <{unit}>", nameof(unit))
-            }; ;
+            };
         }
     }
 }
diff --git a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs
index 0e6f330aef091..8e15632c517e1 100644
--- a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs
+++ b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs
@@ -186,6 +186,8 @@ private static Types.IArrowType GetFieldArrowType(Flatbuf.Field field, Field[] c
                 return Types.StringType.Default;
             case Flatbuf.Type.Utf8View:
                 return Types.StringViewType.Default;
+            case Flatbuf.Type.LargeUtf8:
+                return Types.LargeStringType.Default;
             case
Flatbuf.Type.FixedSizeBinary: Flatbuf.FixedSizeBinary fixedSizeBinaryMetadata = field.Type().Value; return new Types.FixedSizeBinaryType(fixedSizeBinaryMetadata.ByteWidth); @@ -193,6 +195,8 @@ private static Types.IArrowType GetFieldArrowType(Flatbuf.Field field, Field[] c return Types.BinaryType.Default; case Flatbuf.Type.BinaryView: return Types.BinaryViewType.Default; + case Flatbuf.Type.LargeBinary: + return Types.LargeBinaryType.Default; case Flatbuf.Type.List: if (childFields == null || childFields.Length != 1) { @@ -205,6 +209,12 @@ private static Types.IArrowType GetFieldArrowType(Flatbuf.Field field, Field[] c throw new InvalidDataException($"List view type must have exactly one child."); } return new Types.ListViewType(childFields[0]); + case Flatbuf.Type.LargeList: + if (childFields == null || childFields.Length != 1) + { + throw new InvalidDataException($"Large list type must have exactly one child."); + } + return new Types.LargeListType(childFields[0]); case Flatbuf.Type.FixedSizeList: if (childFields == null || childFields.Length != 1) { diff --git a/csharp/src/Apache.Arrow/Types/IArrowType.cs b/csharp/src/Apache.Arrow/Types/IArrowType.cs index cf520391fe1e6..7a3159a1bbccd 100644 --- a/csharp/src/Apache.Arrow/Types/IArrowType.cs +++ b/csharp/src/Apache.Arrow/Types/IArrowType.cs @@ -53,6 +53,9 @@ public enum ArrowTypeId BinaryView, StringView, ListView, + LargeList, + LargeBinary, + LargeString, } public interface IArrowType diff --git a/csharp/src/Apache.Arrow/Types/LargeBinaryType.cs b/csharp/src/Apache.Arrow/Types/LargeBinaryType.cs new file mode 100644 index 0000000000000..e22c333824480 --- /dev/null +++ b/csharp/src/Apache.Arrow/Types/LargeBinaryType.cs @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace Apache.Arrow.Types; + +public class LargeBinaryType: ArrowType +{ + public static readonly LargeBinaryType Default = new LargeBinaryType(); + + public override ArrowTypeId TypeId => ArrowTypeId.LargeBinary; + + public override string Name => "large_binary"; + + public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); +} diff --git a/csharp/src/Apache.Arrow/Types/LargeListType.cs b/csharp/src/Apache.Arrow/Types/LargeListType.cs new file mode 100644 index 0000000000000..2fe8166972931 --- /dev/null +++ b/csharp/src/Apache.Arrow/Types/LargeListType.cs @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace Apache.Arrow.Types +{ + public sealed class LargeListType : NestedType + { + public override ArrowTypeId TypeId => ArrowTypeId.LargeList; + + public override string Name => "large_list"; + + public Field ValueField => Fields[0]; + + public IArrowType ValueDataType => Fields[0].DataType; + + public LargeListType(Field valueField) + : base(valueField) { } + + public LargeListType(IArrowType valueDataType) + : this(new Field("item", valueDataType, true)) { } + + public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); + } +} diff --git a/csharp/src/Apache.Arrow/Types/LargeStringType.cs b/csharp/src/Apache.Arrow/Types/LargeStringType.cs new file mode 100644 index 0000000000000..8698ca4747a0e --- /dev/null +++ b/csharp/src/Apache.Arrow/Types/LargeStringType.cs @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+namespace Apache.Arrow.Types;
+
+public sealed class LargeStringType : ArrowType
+{
+    public static readonly LargeStringType Default = new LargeStringType();
+
+    public override ArrowTypeId TypeId => ArrowTypeId.LargeString;
+
+    public override string Name => "large_utf8";
+
+    public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor);
+}
diff --git a/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs b/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs
index 7232f74b8bec6..c9e44b8d2f491 100644
--- a/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs
+++ b/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs
@@ -177,8 +177,10 @@ private static IArrowType ToArrowType(JsonArrowType type, Field[] children)
             "decimal" => ToDecimalArrowType(type),
             "binary" => BinaryType.Default,
             "binaryview" => BinaryViewType.Default,
+            "largebinary" => LargeBinaryType.Default,
             "utf8" => StringType.Default,
             "utf8view" => StringViewType.Default,
+            "largeutf8" => LargeStringType.Default,
             "fixedsizebinary" => new FixedSizeBinaryType(type.ByteWidth),
             "date" => ToDateArrowType(type),
             "time" => ToTimeArrowType(type),
@@ -188,6 +190,7 @@ private static IArrowType ToArrowType(JsonArrowType type, Field[] children)
             "timestamp" => ToTimestampArrowType(type),
             "list" => ToListArrowType(type, children),
             "listview" => ToListViewArrowType(type, children),
+            "largelist" => ToLargeListArrowType(type, children),
             "fixedsizelist" => ToFixedSizeListArrowType(type, children),
             "struct" => ToStructArrowType(type, children),
             "union" => ToUnionArrowType(type, children),
@@ -303,6 +306,11 @@ private static IArrowType ToListViewArrowType(JsonArrowType type, Field[] childr
             return new ListViewType(children[0]);
         }

+        private static IArrowType ToLargeListArrowType(JsonArrowType type, Field[] children)
+        {
+            return new LargeListType(children[0]);
+        }
+
         private static IArrowType ToFixedSizeListArrowType(JsonArrowType type, Field[] children)
         {
             return new FixedSizeListType(children[0], type.ListSize);
@@ -461,11 +469,14 @@ private class ArrayCreator : IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
+           IArrowTypeVisitor<LargeStringType>,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
+           IArrowTypeVisitor<LargeBinaryType>,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
+           IArrowTypeVisitor<LargeListType>,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
@@ -696,6 +707,24 @@ public void Visit(StringViewType type)
             Array = new StringViewArray(arrayData);
         }

+        public void Visit(LargeStringType type)
+        {
+            ArrowBuffer validityBuffer = GetValidityBuffer(out int nullCount);
+            ArrowBuffer offsetBuffer = GetLargeOffsetBuffer();
+
+            var json = JsonFieldData.Data.GetRawText();
+            string[] values = JsonSerializer.Deserialize<string[]>(json, s_options);
+
+            ArrowBuffer.Builder<byte> valueBuilder = new ArrowBuffer.Builder<byte>();
+            foreach (string value in values)
+            {
+                valueBuilder.Append(Encoding.UTF8.GetBytes(value));
+            }
+            ArrowBuffer valueBuffer = valueBuilder.Build(default);
+
+            Array = new LargeStringArray(JsonFieldData.Count, offsetBuffer, valueBuffer, validityBuffer, nullCount);
+        }
+
         public void Visit(BinaryType type)
         {
             ArrowBuffer validityBuffer = GetValidityBuffer(out int nullCount);
@@ -747,6 +776,25 @@ public void Visit(BinaryViewType type)
             Array = new BinaryViewArray(arrayData);
         }

+        public void Visit(LargeBinaryType type)
+        {
+            ArrowBuffer validityBuffer = GetValidityBuffer(out int nullCount);
+            ArrowBuffer offsetBuffer = GetLargeOffsetBuffer();
+
+            var json = JsonFieldData.Data.GetRawText();
+            string[] values = JsonSerializer.Deserialize<string[]>(json, s_options);
+
+            ArrowBuffer.Builder<byte> valueBuilder = new ArrowBuffer.Builder<byte>();
+            foreach (string value in values)
+            {
+                valueBuilder.Append(ConvertHexStringToByteArray(value));
+            }
+            ArrowBuffer valueBuffer = valueBuilder.Build(default);
+
+            ArrayData arrayData = new ArrayData(type, JsonFieldData.Count, nullCount, 0, new[] { validityBuffer, offsetBuffer, valueBuffer });
+            Array = new LargeBinaryArray(arrayData);
+        }
+
         public void Visit(FixedSizeBinaryType type)
         {
             ArrowBuffer validityBuffer = GetValidityBuffer(out int nullCount);
@@ -796,6 +844,21 @@ public void Visit(ListViewType type)
             Array = new ListViewArray(arrayData);
         }

+        public void Visit(LargeListType type)
+        {
+            ArrowBuffer validityBuffer = GetValidityBuffer(out int nullCount);
+            ArrowBuffer offsetBuffer = GetLargeOffsetBuffer();
+
+            var data = JsonFieldData;
+            JsonFieldData = data.Children[0];
+            type.ValueDataType.Accept(this);
+            JsonFieldData = data;
+
+            ArrayData arrayData = new ArrayData(type, JsonFieldData.Count, nullCount, 0,
+                new[] { validityBuffer, offsetBuffer }, new[] { Array.Data });
+            Array = new LargeListArray(arrayData);
+        }
+
         public void Visit(FixedSizeListType type)
         {
             ArrowBuffer validityBuffer = GetValidityBuffer(out int nullCount);
@@ -975,6 +1038,13 @@ private ArrowBuffer GetOffsetBuffer()
             return valueOffsets.Build(default);
         }

+        private ArrowBuffer GetLargeOffsetBuffer()
+        {
+            ArrowBuffer.Builder<long> valueOffsets = new ArrowBuffer.Builder<long>(JsonFieldData.Offset.Count);
+            valueOffsets.AppendRange(JsonFieldData.LongOffset);
+            return valueOffsets.Build(default);
+        }
+
         private ArrowBuffer GetSizeBuffer()
         {
             ArrowBuffer.Builder<int> valueSizes = new ArrowBuffer.Builder<int>(JsonFieldData.Size.Count);
@@ -1039,6 +1109,12 @@ public IEnumerable<int> IntOffset
             get { return Offset.Select(GetInt); }
         }

+        [JsonIgnore]
+        public IEnumerable<long> LongOffset
+        {
+            get { return Offset.Select(GetLong); }
+        }
+
         [JsonIgnore]
         public IEnumerable<int> IntSize
         {
@@ -1056,6 +1132,18 @@ static int GetInt(JsonNode node)
                 return int.Parse(node.GetValue<string>());
             }
         }
+
+        static long GetLong(JsonNode node)
+        {
+            try
+            {
+                return node.GetValue<long>();
+            }
+            catch
+            {
+                return long.Parse(node.GetValue<string>());
+            }
+        }
     }

     public class JsonView
diff --git a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs
index 5c33d1fd43986..85f7b75f931ef 100644
--- a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs
+++ b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs
@@ -95,12 +95,15 @@ private class ArrayComparer : IArrowArrayVisitor,
            IArrowArrayVisitor,
            IArrowArrayVisitor,
+           IArrowArrayVisitor<LargeListArray>,
            IArrowArrayVisitor,
            IArrowArrayVisitor,
            IArrowArrayVisitor,
+           IArrowArrayVisitor<LargeStringArray>,
            IArrowArrayVisitor,
            IArrowArrayVisitor,
            IArrowArrayVisitor,
+           IArrowArrayVisitor<LargeBinaryArray>,
            IArrowArrayVisitor,
            IArrowArrayVisitor,
            IArrowArrayVisitor,
            IArrowArrayVisitor,
@@ -144,14 +147,17 @@ public ArrayComparer(IArrowArray expectedArray, bool strictCompare)
         public void Visit(MonthDayNanosecondIntervalArray array) => CompareArrays(array);
         public void Visit(ListArray array) => CompareArrays(array);
         public void Visit(ListViewArray array) => CompareArrays(array);
+        public void Visit(LargeListArray array) => CompareArrays(array);
         public void Visit(FixedSizeListArray array) => CompareArrays(array);
         public void Visit(FixedSizeBinaryArray array) => CompareArrays(array);
         public void Visit(Decimal128Array array) => CompareArrays(array);
         public void Visit(Decimal256Array array) => CompareArrays(array);
         public void Visit(StringArray array) => CompareBinaryArrays<StringArray>(array);
         public void Visit(StringViewArray array) => CompareVariadicArrays<StringViewArray>(array);
+        public void Visit(LargeStringArray array) => CompareLargeBinaryArrays<LargeStringArray>(array);
         public void Visit(BinaryArray array) => CompareBinaryArrays<BinaryArray>(array);
         public void Visit(BinaryViewArray array) => CompareVariadicArrays<BinaryViewArray>(array);
+        public void Visit(LargeBinaryArray array) => CompareLargeBinaryArrays<LargeBinaryArray>(array);

         public void Visit(StructArray array)
         {
@@ -276,6 +282,40 @@ private void CompareBinaryArrays<T>(BinaryArray actualArray)
             }
         }

+        private void CompareLargeBinaryArrays<T>(LargeBinaryArray actualArray)
+            where T : IArrowArray
+        {
+            Assert.IsAssignableFrom<T>(_expectedArray);
+            Assert.IsAssignableFrom<T>(actualArray);
+
+            var expectedArray = (LargeBinaryArray)_expectedArray;
+
+            actualArray.Data.DataType.Accept(_arrayTypeComparer);
+
+            Assert.Equal(expectedArray.Length, actualArray.Length);
+            Assert.Equal(expectedArray.NullCount, actualArray.NullCount);
+
+            CompareValidityBuffer(
+                expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer,
+                expectedArray.Offset, actualArray.NullBitmapBuffer, actualArray.Offset);
+
+            if (_strictCompare)
+            {
+                Assert.Equal(expectedArray.Offset, actualArray.Offset);
+                Assert.True(expectedArray.ValueOffsetsBuffer.Span.SequenceEqual(actualArray.ValueOffsetsBuffer.Span));
+                Assert.True(expectedArray.ValueBuffer.Span.Slice(0, expectedArray.Length).SequenceEqual(actualArray.ValueBuffer.Span.Slice(0, actualArray.Length)));
+            }
+            else
+            {
+                for (int i = 0; i < expectedArray.Length; i++)
+                {
+                    Assert.True(
+                        expectedArray.GetBytes(i).SequenceEqual(actualArray.GetBytes(i)),
+                        $"LargeBinaryArray values do not match at index {i}.");
+                }
+            }
+        }
+
         private void CompareVariadicArrays<T>(BinaryViewArray actualArray)
             where T : IArrowArray
         {
@@ -469,6 +509,44 @@ private void CompareArrays(ListViewArray actualArray)
             }
         }

+        private void CompareArrays(LargeListArray actualArray)
+        {
+            Assert.IsAssignableFrom<LargeListArray>(_expectedArray);
+            LargeListArray expectedArray = (LargeListArray)_expectedArray;
+
+            actualArray.Data.DataType.Accept(_arrayTypeComparer);
+
+            Assert.Equal(expectedArray.Length, actualArray.Length);
+            Assert.Equal(expectedArray.NullCount, actualArray.NullCount);
+
+            CompareValidityBuffer(
+                expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer,
+                expectedArray.Offset, actualArray.NullBitmapBuffer, actualArray.Offset);
+
+            if (_strictCompare)
+            {
+                Assert.Equal(expectedArray.Offset, actualArray.Offset);
+                Assert.True(expectedArray.ValueOffsetsBuffer.Span.SequenceEqual(actualArray.ValueOffsetsBuffer.Span));
+                actualArray.Values.Accept(new ArrayComparer(expectedArray.Values, _strictCompare));
+            }
+            else
+            {
+                for (int i = 0; i < actualArray.Length; ++i)
+                {
+                    if (expectedArray.IsNull(i))
+                    {
+                        Assert.True(actualArray.IsNull(i));
+                    }
+                    else
+                    {
+                        var expectedList = expectedArray.GetSlicedValues(i);
+                        var actualList = actualArray.GetSlicedValues(i);
+                        actualList.Accept(new ArrayComparer(expectedList, _strictCompare));
+                    }
+                }
+            }
+        }
+
         private void CompareArrays(FixedSizeListArray actualArray)
         {
             Assert.IsAssignableFrom<FixedSizeListArray>(_expectedArray);
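
These comparers run in IPC round-trip tests. A hedged sketch of such a round trip, assuming the existing ArrowStreamWriter/ArrowStreamReader APIs and a hypothetical largeStringArray (error handling omitted):

    var field = new Field("text", LargeStringType.Default, nullable: true);
    var schema = new Schema(new[] { field }, metadata: null);
    var batch = new RecordBatch(schema, new IArrowArray[] { largeStringArray }, largeStringArray.Length);

    using var stream = new MemoryStream();
    using (var writer = new ArrowStreamWriter(stream, schema, leaveOpen: true))
    {
        writer.WriteRecordBatch(batch);
        writer.WriteEnd();
    }

    stream.Position = 0;
    using var reader = new ArrowStreamReader(stream);
    RecordBatch roundTripped = reader.ReadNextRecordBatch();
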
diff --git a/csharp/test/Apache.Arrow.Tests/LargeBinaryArrayTests.cs b/csharp/test/Apache.Arrow.Tests/LargeBinaryArrayTests.cs
new file mode 100644
index 0000000000000..4ee1f1d0e0ffa
--- /dev/null
+++ b/csharp/test/Apache.Arrow.Tests/LargeBinaryArrayTests.cs
@@ -0,0 +1,95 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Collections.Generic;
+using Apache.Arrow.Types;
+using Xunit;
+
+namespace Apache.Arrow.Tests;
+
+public class LargeBinaryArrayTests
+{
+    [Fact]
+    public void GetBytesReturnsCorrectValue()
+    {
+        var byteArrays = new byte[][]
+        {
+            new byte[] {0, 1, 2, 255},
+            new byte[] {3, 4, 5},
+            new byte[] {},
+            null,
+            new byte[] {254, 253, 252},
+        };
+        var array = BuildArray(byteArrays);
+
+        Assert.Equal(array.Length, byteArrays.Length);
+        for (var i = 0; i < byteArrays.Length; ++i)
+        {
+            var byteSpan = array.GetBytes(i, out var isNull);
+            var byteArray = isNull ? null : byteSpan.ToArray();
+            Assert.Equal(byteArrays[i], byteArray);
+        }
+    }
+
+    [Fact]
+    public void GetBytesChecksForOffsetOverflow()
+    {
+        var valueBuffer = new ArrowBuffer.Builder<byte>();
+        var offsetBuffer = new ArrowBuffer.Builder<long>();
+        var validityBuffer = new ArrowBuffer.BitmapBuilder();
+
+        offsetBuffer.Append(0);
+        offsetBuffer.Append((long)int.MaxValue + 1);
+        validityBuffer.Append(true);
+
+        var array = new LargeBinaryArray(
+            LargeBinaryType.Default, length: 1,
+            offsetBuffer.Build(), valueBuffer.Build(), validityBuffer.Build(),
+            validityBuffer.UnsetBitCount);
+
+        Assert.Throws<OverflowException>(() => array.GetBytes(0));
+    }
+
+    private static LargeBinaryArray BuildArray(IReadOnlyCollection<byte[]> byteArrays)
+    {
+        var valueBuffer = new ArrowBuffer.Builder<byte>();
+        var offsetBuffer = new ArrowBuffer.Builder<long>();
+        var validityBuffer = new ArrowBuffer.BitmapBuilder();
+
+        long offset = 0;
+        offsetBuffer.Append(offset);
+        foreach (var bytes in byteArrays)
+        {
+            if (bytes == null)
+            {
+                validityBuffer.Append(false);
+                offsetBuffer.Append(offset);
+            }
+            else
+            {
+                valueBuffer.Append(bytes);
+                offset += bytes.Length;
+                offsetBuffer.Append(offset);
+                validityBuffer.Append(true);
+            }
+        }
+
+        return new LargeBinaryArray(
+            LargeBinaryType.Default, byteArrays.Count,
+            offsetBuffer.Build(), valueBuffer.Build(), validityBuffer.Build(),
+            validityBuffer.UnsetBitCount);
+    }
+}
diff --git a/csharp/test/Apache.Arrow.Tests/LargeListArrayTests.cs b/csharp/test/Apache.Arrow.Tests/LargeListArrayTests.cs
new file mode 100644
index 0000000000000..1d35a8ffd62c5
--- /dev/null
+++ b/csharp/test/Apache.Arrow.Tests/LargeListArrayTests.cs
@@ -0,0 +1,105 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Linq;
+using Apache.Arrow.Types;
+using Xunit;
+
+namespace Apache.Arrow.Tests;
+
+public class LargeListArrayTests
+{
+    [Fact]
+    public void GetSlicedValuesReturnsCorrectValues()
+    {
+        var values = new int?[][]
+        {
+            new int?[] {0, 1, 2},
+            System.Array.Empty<int?>(),
+            null,
+            new int?[] {3, 4, null, 6},
+        };
+
+        var array = BuildArray(values);
+
+        Assert.Equal(values.Length, array.Length);
+        for (int i = 0; i < values.Length; ++i)
+        {
+            Assert.Equal(values[i] == null, array.IsNull(i));
+            var arrayItem = (Int32Array)array.GetSlicedValues(i);
+            if (values[i] == null)
+            {
+                Assert.Null(arrayItem);
+            }
+            else
+            {
+                Assert.Equal(values[i], arrayItem.ToArray());
+            }
+        }
+    }
+
+    [Fact]
+    public void GetSlicedValuesChecksForOffsetOverflow()
+    {
+        var valuesArray = new Int32Array.Builder().Build();
+        var offsetBuffer = new ArrowBuffer.Builder<long>();
+        var validityBuffer = new ArrowBuffer.BitmapBuilder();
+
+        offsetBuffer.Append(0);
+        offsetBuffer.Append((long)int.MaxValue + 1);
+        validityBuffer.Append(true);
+
+        var array = new LargeListArray(
+            new LargeListType(new Int32Type()), length: 1,
+            offsetBuffer.Build(), valuesArray, validityBuffer.Build(),
+            validityBuffer.UnsetBitCount);
+
+        Assert.Throws<OverflowException>(() => array.GetSlicedValues(0));
+    }
+
+    private static LargeListArray BuildArray(int?[][] values)
+    {
+        var valuesBuilder = new Int32Array.Builder();
+        var offsetBuffer = new ArrowBuffer.Builder<long>();
+        var validityBuffer = new ArrowBuffer.BitmapBuilder();
+
+        long offset = 0;
+        offsetBuffer.Append(offset);
+        foreach (var listValue in values)
+        {
+            if (listValue == null)
+            {
+                validityBuffer.Append(false);
+                offsetBuffer.Append(offset);
+            }
+            else
+            {
+                foreach (var value in listValue)
+                {
+                    valuesBuilder.Append(value);
+                }
+                offset += listValue.Length;
+                offsetBuffer.Append(offset);
+                validityBuffer.Append(true);
+            }
+        }
+
+        return new LargeListArray(
+            new LargeListType(new Int32Type()), values.Length,
+            offsetBuffer.Build(), valuesBuilder.Build(), validityBuffer.Build(),
+            validityBuffer.UnsetBitCount);
+    }
+}
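
Both overflow tests rely on the same invariant: offsets are 64-bit on the wire, but managed buffers are still Span-indexed, so any offset past int.MaxValue must surface as an OverflowException via a checked cast rather than silently truncating:

    long wireOffset = (long)int.MaxValue + 1;
    int managedOffset = checked((int)wireOffset);  // throws OverflowException
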
diff --git a/csharp/test/Apache.Arrow.Tests/LargeStringArrayTests.cs b/csharp/test/Apache.Arrow.Tests/LargeStringArrayTests.cs
new file mode 100644
index 0000000000000..aba97ba338c75
--- /dev/null
+++ b/csharp/test/Apache.Arrow.Tests/LargeStringArrayTests.cs
@@ -0,0 +1,91 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Collections.Generic;
+using Xunit;
+
+namespace Apache.Arrow.Tests;
+
+public class LargeStringArrayTests
+{
+    [Fact]
+    public void GetStringReturnsCorrectValue()
+    {
+        var strings = new string[]
+        {
+            "abc",
+            "defg",
+            "",
+            null,
+            "123",
+        };
+        var array = BuildArray(strings);
+
+        Assert.Equal(array.Length, strings.Length);
+        for (var i = 0; i < strings.Length; ++i)
+        {
+            Assert.Equal(strings[i], array.GetString(i));
+        }
+    }
+
+    [Fact]
+    public void GetStringChecksForOffsetOverflow()
+    {
+        var valueBuffer = new ArrowBuffer.Builder<byte>();
+        var offsetBuffer = new ArrowBuffer.Builder<long>();
+        var validityBuffer = new ArrowBuffer.BitmapBuilder();
+
+        offsetBuffer.Append(0);
+        offsetBuffer.Append((long)int.MaxValue + 1);
+        validityBuffer.Append(true);
+
+        var array = new LargeStringArray(
+            length: 1, offsetBuffer.Build(), valueBuffer.Build(), validityBuffer.Build(),
+            validityBuffer.UnsetBitCount);
+
+        Assert.Throws<OverflowException>(() => array.GetString(0));
+    }
+
+    private static LargeStringArray BuildArray(IReadOnlyCollection<string> strings)
+    {
+        var valueBuffer = new ArrowBuffer.Builder<byte>();
+        var offsetBuffer = new ArrowBuffer.Builder<long>();
+        var validityBuffer = new ArrowBuffer.BitmapBuilder();
+
+        long offset = 0;
+        offsetBuffer.Append(offset);
+        foreach (var value in strings)
+        {
+            if (value == null)
+            {
+                validityBuffer.Append(false);
+                offsetBuffer.Append(offset);
+            }
+            else
+            {
+                var bytes = LargeStringArray.DefaultEncoding.GetBytes(value);
+                valueBuffer.Append(bytes);
+                offset += bytes.Length;
+                offsetBuffer.Append(offset);
+                validityBuffer.Append(true);
+            }
+        }
+
+        return new LargeStringArray(
+            strings.Count, offsetBuffer.Build(), valueBuffer.Build(), validityBuffer.Build(),
+            validityBuffer.UnsetBitCount);
+    }
+}
diff --git a/csharp/test/Apache.Arrow.Tests/TableTests.cs b/csharp/test/Apache.Arrow.Tests/TableTests.cs
index 83c88265d172b..35fbe7cba68f1 100644
--- a/csharp/test/Apache.Arrow.Tests/TableTests.cs
+++ b/csharp/test/Apache.Arrow.Tests/TableTests.cs
@@ -63,9 +63,9 @@ public void TestTableFromRecordBatches()
         Table table1 = Table.TableFromRecordBatches(recordBatch1.Schema, recordBatches);
         Assert.Equal(20, table1.RowCount);
 #if NET5_0_OR_GREATER
-        Assert.Equal(35, table1.ColumnCount);
+        Assert.Equal(38, table1.ColumnCount);
 #else
-        Assert.Equal(34, table1.ColumnCount);
+        Assert.Equal(37, table1.ColumnCount);
 #endif
         Assert.Equal("ChunkedArray: Length=20, DataType=list", table1.Column(0).Data.ToString());
diff --git a/csharp/test/Apache.Arrow.Tests/TestData.cs b/csharp/test/Apache.Arrow.Tests/TestData.cs
index 3ea42ee0fbcb7..36969766aeae0 100644
--- a/csharp/test/Apache.Arrow.Tests/TestData.cs
+++ b/csharp/test/Apache.Arrow.Tests/TestData.cs
@@ -49,6 +49,7 @@ void AddField(Field field)
         {
             AddField(CreateField(new ListType(Int64Type.Default), i));
             AddField(CreateField(new ListViewType(Int64Type.Default), i));
+            AddField(CreateField(new LargeListType(Int64Type.Default), i));
             AddField(CreateField(BooleanType.Default, i));
             AddField(CreateField(UInt8Type.Default, i));
             AddField(CreateField(Int8Type.Default, i));
@@ -84,6 +85,8 @@ void AddField(Field field)
             AddField(CreateField(new UnionType(new[] { CreateField(StringType.Default, i), CreateField(Int32Type.Default, i) }, new[] { 0, 1 }, UnionMode.Sparse), i));
             AddField(CreateField(new UnionType(new[] { CreateField(StringType.Default, i), CreateField(Int32Type.Default, i) }, new[] { 0, 1 }, UnionMode.Dense), -i));
             AddField(CreateField(new DictionaryType(Int32Type.Default, StringType.Default, false), i));
+            AddField(CreateField(new LargeBinaryType(), i));
+            AddField(CreateField(new LargeStringType(), i));
         }

         Schema schema = builder.Build();
@@ -144,8 +147,10 @@ private class ArrayCreator : IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
+           IArrowTypeVisitor<LargeStringType>,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
+           IArrowTypeVisitor<LargeListType>,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
@@ -154,6 +159,7 @@ private class ArrayCreator : IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
+           IArrowTypeVisitor<LargeBinaryType>,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
            IArrowTypeVisitor,
@@ -335,6 +341,45 @@ public void Visit(StringViewType type)
             Array = builder.Build();
         }

+        public void Visit(LargeStringType type)
+        {
+            var str = "hello";
+            var valueBuffer = new ArrowBuffer.Builder<byte>();
+            var offsetBuffer = new ArrowBuffer.Builder<long>();
+            var validityBuffer = new ArrowBuffer.BitmapBuilder();
+
+            long offset = 0;
+            offsetBuffer.Append(offset);
+
+            for (var i = 0; i < Length; i++)
+            {
+                switch (i % 3)
+                {
+                    case 0:
+                        offsetBuffer.Append(offset);
+                        validityBuffer.Append(false);
+                        break;
+                    case 1:
+                        valueBuffer.Append(LargeStringArray.DefaultEncoding.GetBytes(str));
+                        offset += str.Length;
+                        offsetBuffer.Append(offset);
+                        validityBuffer.Append(true);
+                        break;
+                    case 2:
+                        valueBuffer.Append(LargeStringArray.DefaultEncoding.GetBytes(str + str));
+                        offset += str.Length * 2;
+                        offsetBuffer.Append(offset);
+                        validityBuffer.Append(true);
+                        break;
+                }
+            }
+
+            var validity = validityBuffer.UnsetBitCount > 0 ? validityBuffer.Build() : ArrowBuffer.Empty;
+            Array = new LargeStringArray(
+                Length, offsetBuffer.Build(), valueBuffer.Build(), validity,
+                validityBuffer.UnsetBitCount);
+        }
+
         public void Visit(ListType type)
         {
             var builder = new ListArray.Builder(type.ValueField).Reserve(Length);
@@ -379,6 +424,37 @@ public void Visit(ListViewType type)
             Array = builder.Build();
         }

+        public void Visit(LargeListType type)
+        {
+            var valueBuilder = new Int64Array.Builder().Reserve(Length * 3 / 2);
+            var offsetBuffer = new ArrowBuffer.Builder<long>();
+            var validityBuffer = new ArrowBuffer.BitmapBuilder();
+
+            offsetBuffer.Append(0);
+
+            for (var i = 0; i < Length; i++)
+            {
+                if (i % 10 == 2)
+                {
+                    offsetBuffer.Append(valueBuilder.Length);
+                    validityBuffer.Append(false);
+                }
+                else
+                {
+                    var listLength = i % 4;
+                    valueBuilder.AppendRange(Enumerable.Range(i, listLength).Select(x => (long)x));
+                    offsetBuffer.Append(valueBuilder.Length);
+                    validityBuffer.Append(true);
+                }
+            }
+
+            var validity = validityBuffer.UnsetBitCount > 0 ? validityBuffer.Build() : ArrowBuffer.Empty;
+            Array = new LargeListArray(
+                new LargeListType(new Int64Type()), Length,
+                offsetBuffer.Build(), valueBuilder.Build(), validity,
+                validityBuffer.UnsetBitCount);
+        }
+
         public void Visit(FixedSizeListType type)
         {
             var builder = new FixedSizeListArray.Builder(type.ValueField, type.ListSize).Reserve(Length);
@@ -554,6 +630,48 @@ public void Visit(BinaryViewType type)
             Array = builder.Build();
         }

+        public void Visit(LargeBinaryType type)
+        {
+            ReadOnlySpan<byte> shortData = new[] { (byte)0, (byte)1, (byte)2, (byte)3, (byte)4, (byte)5, (byte)6, (byte)7, (byte)8, (byte)9 };
+            ReadOnlySpan<byte> longData = new[]
+            {
+                (byte)0, (byte)1, (byte)2, (byte)3, (byte)4, (byte)5, (byte)6, (byte)7, (byte)8, (byte)9,
+                (byte)10, (byte)11, (byte)12, (byte)13, (byte)14, (byte)15, (byte)16, (byte)17, (byte)18, (byte)19
+            };
+            var valueBuffer = new ArrowBuffer.Builder<byte>();
+            var offsetBuffer = new ArrowBuffer.Builder<long>();
+            var validityBuffer = new ArrowBuffer.BitmapBuilder();
+
+            offsetBuffer.Append(0L);
+
+            for (var i = 0; i < Length; i++)
+            {
+                switch (i % 3)
+                {
+                    case 0:
+                        offsetBuffer.Append(valueBuffer.Length);
+                        validityBuffer.Append(false);
+                        break;
+                    case 1:
+                        valueBuffer.Append(shortData);
+                        offsetBuffer.Append(valueBuffer.Length);
+                        validityBuffer.Append(true);
+                        break;
+                    case 2:
+                        valueBuffer.Append(longData);
+                        offsetBuffer.Append(valueBuffer.Length);
+                        validityBuffer.Append(true);
+                        break;
+                }
+            }
+
+            var validity = validityBuffer.UnsetBitCount > 0 ? validityBuffer.Build() : ArrowBuffer.Empty;
+            Array = new LargeBinaryArray(
+                LargeBinaryType.Default, Length,
+                offsetBuffer.Build(), valueBuffer.Build(), validity,
+                validityBuffer.UnsetBitCount);
+        }
+
         public void Visit(FixedSizeBinaryType type)
         {
             ArrowBuffer.Builder<byte> valueBuilder = new ArrowBuffer.Builder<byte>();
diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py
index b51f3d876f820..47310c905a9ff 100644
--- a/dev/archery/archery/integration/datagen.py
+++ b/dev/archery/archery/integration/datagen.py
@@ -1872,8 +1872,7 @@ def _temp_path():
         generate_primitive_case([17, 20], name='primitive'),
         generate_primitive_case([0, 0, 0], name='primitive_zerolength'),

-        generate_primitive_large_offsets_case([17, 20])
-        .skip_tester('C#'),
+        generate_primitive_large_offsets_case([17, 20]),

         generate_null_case([10, 0]),

@@ -1906,7 +1905,6 @@ def _temp_path():
         generate_recursive_nested_case(),

         generate_nested_large_offsets_case()
-        .skip_tester('C#')
         .skip_tester('JS'),

         generate_unions_case(),
diff --git a/dev/archery/archery/integration/tester_java.py b/dev/archery/archery/integration/tester_java.py
index 9b14c6939cde8..8d207d3393730 100644
--- a/dev/archery/archery/integration/tester_java.py
+++ b/dev/archery/archery/integration/tester_java.py
@@ -46,6 +46,7 @@ def load_version_from_pom():
 _JAVA_OPTS = [
     "-Dio.netty.tryReflectionSetAccessible=true",
     "-Darrow.struct.conflict.policy=CONFLICT_APPEND",
+    "--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED",
     # GH-39113: avoid failures accessing files in `/tmp/hsperfdata_...`
     "-XX:-UsePerfData",
 ]
@@ -88,24 +89,13 @@ def setup_jpype():
     import jpype
     jar_path = f"{_ARROW_TOOLS_JAR}:{_ARROW_C_DATA_JAR}"
     # XXX Didn't manage to tone down the logging level here (DEBUG -> INFO)
-    java_opts = _JAVA_OPTS[:]
-    proc = subprocess.run(
-        ['java', '--add-opens'],
-        stderr=subprocess.PIPE,
-        stdout=subprocess.PIPE,
-        text=True)
-    if 'Unrecognized option: --add-opens' not in proc.stderr:
-        # Java 9+
-        java_opts.append(
'--add-opens=java.base/java.nio=' - 'org.apache.arrow.memory.core,ALL-UNNAMED') jpype.startJVM(jpype.getDefaultJVMPath(), "-Djava.class.path=" + jar_path, # This flag is too heavy for IPC and Flight tests "-Darrow.memory.debug.allocator=true", # Reduce internal use of signals by the JVM "-Xrs", - *java_opts) + *_JAVA_OPTS) class _CDataBase: @@ -253,20 +243,9 @@ class JavaTester(Tester): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # Detect whether we're on Java 8 or Java 9+ self._java_opts = _JAVA_OPTS[:] - proc = subprocess.run( - ['java', '--add-opens'], - stderr=subprocess.PIPE, - stdout=subprocess.PIPE, - text=True) - if 'Unrecognized option: --add-opens' not in proc.stderr: - # Java 9+ - self._java_opts.append( - '--add-opens=java.base/java.nio=' - 'org.apache.arrow.memory.core,ALL-UNNAMED') - self._java_opts.append( - '--add-reads=org.apache.arrow.flight.core=ALL-UNNAMED') + self._java_opts.append( + '--add-reads=org.apache.arrow.flight.core=ALL-UNNAMED') def _run(self, arrow_path=None, json_path=None, command='VALIDATE'): cmd = ( diff --git a/dev/archery/archery/lang/java.py b/dev/archery/archery/lang/java.py index bc169adf647bc..f447b352e6a6c 100644 --- a/dev/archery/archery/lang/java.py +++ b/dev/archery/archery/lang/java.py @@ -34,8 +34,11 @@ def __init__(self, jar, *args, **kwargs): class JavaConfiguration: - def __init__(self, + REQUIRED_JAVA_OPTIONS = [ + "--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED", + ] + def __init__(self, # toolchain java_home=None, java_options=None, # build & benchmark @@ -43,6 +46,13 @@ def __init__(self, self.java_home = java_home self.java_options = java_options + if self.java_options is None: + self.java_options = " ".join(self.REQUIRED_JAVA_OPTIONS) + else: + for option in self.REQUIRED_JAVA_OPTIONS: + if option not in self.java_options: + self.java_options += " " + option + self.build_extras = list(build_extras) if build_extras else [] self.benchmark_extras = list( benchmark_extras) if benchmark_extras else [] @@ -63,7 +73,7 @@ def environment(self): env["JAVA_HOME"] = self.java_home if self.java_options: - env["JAVA_OPTIONS"] = self.java_options + env["JDK_JAVA_OPTIONS"] = self.java_options return env diff --git a/dev/conbench_envs/README.md b/dev/conbench_envs/README.md index 509dc5c0c9537..7fab503974805 100644 --- a/dev/conbench_envs/README.md +++ b/dev/conbench_envs/README.md @@ -99,16 +99,16 @@ Here are steps how `@ursabot` benchmark builds use `benchmarks.env` and `hooks.s ### 2. Install Arrow dependencies for Java sudo su - apt-get install openjdk-8-jdk + apt-get install openjdk-11-jdk apt-get install maven Verify that you have at least these versions of `java`, `javac` and `maven`: # java -version - openjdk version "1.8.0_292" + openjdk version "11.0.22" 2024-01-16 .. # javac -version - javac 1.8.0_292 + javac 11.0.22 ... # mvn -version Apache Maven 3.6.3 diff --git a/dev/release/post-12-bump-versions-test.rb b/dev/release/post-12-bump-versions-test.rb index 2bd14587461cc..f31e1a3122814 100644 --- a/dev/release/post-12-bump-versions-test.rb +++ b/dev/release/post-12-bump-versions-test.rb @@ -358,8 +358,15 @@ def test_version_post_tag def test_deb_package_names omit_on_release_branch unless bump_type.nil? 
current_commit = git_current_commit - stdout = bump_versions("DEB_PACKAGE_NAMES") - changes = parse_patch(git("log", "-p", "#{current_commit}..")) + stdout = bump_versions("VERSION_POST_TAG", "DEB_PACKAGE_NAMES") + log = git("log", "-p", "#{current_commit}..") + # Remove a commit for VERSION_POST_TAG + if log.scan(/^commit/).size == 1 + log = "" + else + log.gsub!(/\A(commit.*?)^commit .*\z/um, "\\1") + end + changes = parse_patch(log) sampled_changes = changes.collect do |change| first_hunk = change[:hunks][0] first_removed_line = first_hunk.find { |line| line.start_with?("-") } diff --git a/dev/release/post-12-bump-versions.sh b/dev/release/post-12-bump-versions.sh index 422821a66bde5..bf40f4ce5c4ea 100755 --- a/dev/release/post-12-bump-versions.sh +++ b/dev/release/post-12-bump-versions.sh @@ -40,6 +40,7 @@ fi version=$1 next_version=$2 next_version_snapshot="${next_version}-SNAPSHOT" +current_version_before_bump="$(current_version)" case "${version}" in *.0.0) @@ -64,7 +65,7 @@ if [ ${BUMP_VERSION_POST_TAG} -gt 0 ]; then fi if [ ${BUMP_DEB_PACKAGE_NAMES} -gt 0 ] && \ - [ "${next_version}" != "$(current_version)" ]; then + [ "${next_version}" != "${current_version_before_bump}" ]; then update_deb_package_names "${version}" "${next_version}" fi diff --git a/dev/release/setup-rhel-rebuilds.sh b/dev/release/setup-rhel-rebuilds.sh index dc190d2d2426e..e8861a19f35b7 100755 --- a/dev/release/setup-rhel-rebuilds.sh +++ b/dev/release/setup-rhel-rebuilds.sh @@ -35,7 +35,7 @@ dnf -y install \ cmake \ git \ gobject-introspection-devel \ - java-1.8.0-openjdk-devel \ + java-11-openjdk-devel \ libcurl-devel \ llvm-devel \ llvm-toolset \ @@ -55,3 +55,5 @@ npm install -g yarn python3 -m ensurepip --upgrade alternatives --set python /usr/bin/python3 +alternatives --set java java-11-openjdk.$(uname -i) +alternatives --set javac java-11-openjdk.$(uname -i) diff --git a/dev/release/utils-prepare.sh b/dev/release/utils-prepare.sh index 760a7f404a74d..6ba8b22a06e89 100644 --- a/dev/release/utils-prepare.sh +++ b/dev/release/utils-prepare.sh @@ -88,7 +88,6 @@ update_versions() { # versions-maven-plugin:set-scm-tag does not update the whole reactor. Invoking separately mvn versions:set-scm-tag -DnewTag=apache-arrow-${version} -DgenerateBackupPoms=false -pl :arrow-java-root mvn versions:set-scm-tag -DnewTag=apache-arrow-${version} -DgenerateBackupPoms=false -pl :arrow-bom - mvn versions:set-scm-tag -DnewTag=apache-arrow-${version} -DgenerateBackupPoms=false -pl :arrow-maven-plugins fi git add "pom.xml" git add "**/pom.xml" diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index 2f4b203f217af..6a36109dc2fc1 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -21,7 +21,7 @@ # Requirements # - Ruby >= 2.3 # - Maven >= 3.8.7 -# - JDK >=8 +# - JDK >= 11 # - gcc >= 4.8 # - Node.js >= 18 # - Go >= 1.21 diff --git a/dev/tasks/java-jars/github.yml b/dev/tasks/java-jars/github.yml index 9493be05be6ee..ba988f893148f 100644 --- a/dev/tasks/java-jars/github.yml +++ b/dev/tasks/java-jars/github.yml @@ -83,7 +83,7 @@ jobs: - { runs_on: ["macos-13"], arch: "x86_64"} - { runs_on: ["macos-14"], arch: "aarch_64" } env: - MACOSX_DEPLOYMENT_TARGET: "10.15" + MACOSX_DEPLOYMENT_TARGET: "14.0" steps: {{ macros.github_checkout_arrow()|indent }} - name: Set up Python @@ -140,6 +140,12 @@ jobs: brew uninstall protobuf brew bundle --file=arrow/java/Brewfile + + # We want to use the bundled googletest for static linking. 
Since + # both BUNDLED and brew options are enabled, it could cause a conflict + # when there is a version mismatch. + # We uninstall googletest to ensure using the bundled googletest. + brew uninstall googletest - name: Build C++ libraries env: {{ macros.github_set_sccache_envvars()|indent(8) }} diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index 2eb361047fc62..5c8a7c4990d7a 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -98,6 +98,8 @@ groups: vcpkg: - test-*vcpkg* + - wheel-* + - java-jars integration: - test-*dask* @@ -745,9 +747,6 @@ tasks: - arrow-jdbc-{no_rc_snapshot_version}-tests.jar - arrow-jdbc-{no_rc_snapshot_version}.jar - arrow-jdbc-{no_rc_snapshot_version}.pom - - arrow-maven-plugins-{no_rc_snapshot_version}-cyclonedx.json - - arrow-maven-plugins-{no_rc_snapshot_version}-cyclonedx.xml - - arrow-maven-plugins-{no_rc_snapshot_version}.pom - arrow-memory-core-{no_rc_snapshot_version}-cyclonedx.json - arrow-memory-core-{no_rc_snapshot_version}-cyclonedx.xml - arrow-memory-core-{no_rc_snapshot_version}-javadoc.jar @@ -843,12 +842,6 @@ tasks: - flight-sql-jdbc-driver-{no_rc_snapshot_version}-tests.jar - flight-sql-jdbc-driver-{no_rc_snapshot_version}.jar - flight-sql-jdbc-driver-{no_rc_snapshot_version}.pom - - module-info-compiler-maven-plugin-{no_rc_snapshot_version}-cyclonedx.json - - module-info-compiler-maven-plugin-{no_rc_snapshot_version}-cyclonedx.xml - - module-info-compiler-maven-plugin-{no_rc_snapshot_version}-javadoc.jar - - module-info-compiler-maven-plugin-{no_rc_snapshot_version}-sources.jar - - module-info-compiler-maven-plugin-{no_rc_snapshot_version}.jar - - module-info-compiler-maven-plugin-{no_rc_snapshot_version}.pom ############################## NuGet packages ############################### @@ -1067,6 +1060,15 @@ tasks: UBUNTU: 20.04 image: ubuntu-cpp-bundled + test-ubuntu-24.04-cpp-gcc-13-bundled: + ci: github + template: docker-tests/github.linux.yml + params: + env: + UBUNTU: 24.04 + GCC_VERSION: 13 + image: ubuntu-cpp-bundled + test-ubuntu-24.04-cpp: ci: github template: docker-tests/github.linux.yml @@ -1549,9 +1551,7 @@ tasks: image: conda-python-hdfs {% endfor %} -{% for python_version, spark_version, test_pyarrow_only, numpy_version, jdk_version in [("3.8", "v3.5.0", "false", "latest", "8"), - ("3.10", "v3.5.0", "false", "1.23", "8"), - ("3.11", "master", "false", "latest", "17")] %} +{% for python_version, spark_version, test_pyarrow_only, numpy_version, jdk_version in [("3.11", "master", "false", "latest", "17")] %} test-conda-python-{{ python_version }}-spark-{{ spark_version }}: ci: github template: docker-tests/github.linux.yml diff --git a/docker-compose.yml b/docker-compose.yml index fa248d59037d3..cf22324f7cfb4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1202,7 +1202,7 @@ services: build: args: base: ${REPO}:${ARCH}-python-${PYTHON}-wheel-manylinux-2014-vcpkg-${VCPKG} - java: 1.8.0 + java: 11 context: . dockerfile: ci/docker/java-jni-manylinux-201x.dockerfile cache_from: @@ -1747,7 +1747,7 @@ services: # docker-compose run java # Parameters: # MAVEN: 3.9.5 - # JDK: 8, 11, 17, 21 + # JDK: 11, 17, 21 image: ${ARCH}/maven:${MAVEN}-eclipse-temurin-${JDK} shm_size: *shm-size volumes: &java-volumes diff --git a/docs/source/developers/java/building.rst b/docs/source/developers/java/building.rst index 82053e901186c..3904841de9c5a 100644 --- a/docs/source/developers/java/building.rst +++ b/docs/source/developers/java/building.rst @@ -32,7 +32,7 @@ Arrow Java uses the `Maven `_ build system. 
Building requires:

-* JDK 8+
+* JDK 11+
 * Maven 3+

 .. note::
@@ -321,6 +321,54 @@ Building Java JNI Modules
      -Darrow.c.jni.dist.dir=/java-dist/lib/ \
      -Parrow-jni clean install

+Testing
+=======
+
+By default, Maven uses the same Java version to both build the code and run the tests.
+
+It is also possible to use a different JDK version for the tests. This requires Maven
+toolchains to be configured beforehand; a test property then selects the JDK.
+
+Configuring Maven toolchains
+----------------------------
+
+To use a JDK version for testing, first register it in the Maven ``toolchains.xml``
+configuration file, usually located under ``${HOME}/.m2``, by adding a snippet like
+the following to it:
+
+   .. code-block::
+
+      <?xml version="1.0" encoding="UTF-8"?>
+      <toolchains>
+         [...]
+         <toolchain>
+            <type>jdk</type>
+            <provides>
+               <version>21</version>
+               <vendor>temurin</vendor>
+            </provides>
+            <configuration>
+               <jdkHome>path/to/jdk/home</jdkHome>
+            </configuration>
+         </toolchain>
+         [...]
+      </toolchains>
+
+Testing with a specific JDK
+---------------------------
+
+To run Arrow tests with a specific JDK version, use the ``arrow.test.jdk-version`` property.
+
+For example, to run Arrow tests with JDK 17, use the following snippet:
+
+   .. code-block::
+
+      $ cd arrow/java
+      $ mvn -Darrow.test.jdk-version=17 clean verify
+
 IDE Configuration
 =================
@@ -335,7 +383,6 @@ Arrow repository, and update the following settings:
   right click the directory, and select Mark Directory as > Generated Sources
   Root. There is no need to mark other generated sources directories, as only
   the ``vector`` module generates sources.
-* For JDK 8, disable the ``error-prone`` profile to build the project successfully.
 * For JDK 11, due to an `IntelliJ bug `__,
   you must go into Settings > Build, Execution, Deployment > Compiler >
   Java Compiler and disable
@@ -538,3 +585,40 @@ Installing Manually

 .. _builds@arrow.apache.org: https://lists.apache.org/list.html?builds@arrow.apache.org
 .. _GitHub Nightly: https://github.com/ursacomputing/crossbow/releases/tag/nightly-packaging-2022-07-30-0-github-java-jars
+
+Installing Staging Packages
+===========================
+
+.. warning::
+    These packages are not official releases. Use them at your own risk.
+
+Arrow staging builds are created when a Release Candidate (RC) is being prepared.
+This allows users to test the RC in their applications before voting on the release.
+
+Installing from Apache Staging
+------------------------------
+
+1. Look up the next version number for the Arrow libraries used.
+
+2. Add the Apache Staging repository to the Maven/Gradle project.
+
+   .. code-block:: xml
+
+      <properties>
+        <arrow.version>9.0.0</arrow.version>
+      </properties>
+      ...
+      <repositories>
+        <repository>
+          <id>arrow-apache-staging</id>
+          <url>https://repository.apache.org/content/repositories/staging</url>
+        </repository>
+      </repositories>
+      ...
+      <dependencies>
+        <dependency>
+          <groupId>org.apache.arrow</groupId>
+          <artifactId>arrow-vector</artifactId>
+          <version>${arrow.version}</version>
+        </dependency>
+      </dependencies>
+      ...
diff --git a/docs/source/format/CDataInterface/PyCapsuleInterface.rst b/docs/source/format/CDataInterface/PyCapsuleInterface.rst
index d38ba2822da46..f4f6b54849e77 100644
--- a/docs/source/format/CDataInterface/PyCapsuleInterface.rst
+++ b/docs/source/format/CDataInterface/PyCapsuleInterface.rst
@@ -303,7 +303,6 @@ function accepts an object implementing one of these protocols.
 .. code-block:: python

     from typing import Tuple, Protocol
-    from typing_extensions import Self

     class ArrowSchemaExportable(Protocol):
         def __arrow_c_schema__(self) -> object: ...
diff --git a/docs/source/format/Columnar.rst b/docs/source/format/Columnar.rst
index 7ae0c2b4bdbd8..c5f822f41643f 100644
--- a/docs/source/format/Columnar.rst
+++ b/docs/source/format/Columnar.rst
@@ -1656,8 +1656,8 @@ the Arrow spec.

 .. _Message.fbs: https://github.com/apache/arrow/blob/main/format/Message.fbs
 ..
_File.fbs: https://github.com/apache/arrow/blob/main/format/File.fbs .. _least-significant bit (LSB) numbering: https://en.wikipedia.org/wiki/Bit_numbering -.. _Intel performance guide: https://software.intel.com/en-us/articles/practical-intel-avx-optimization-on-2nd-generation-intel-core-processors +.. _Intel performance guide: https://web.archive.org/web/20151101074635/https://software.intel.com/en-us/articles/practical-intel-avx-optimization-on-2nd-generation-intel-core-processors .. _Endianness: https://en.wikipedia.org/wiki/Endianness -.. _SIMD: https://software.intel.com/en-us/cpp-compiler-developer-guide-and-reference-introduction-to-the-simd-data-layout-templates +.. _SIMD: https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-8/simd-data-layout-templates.html .. _Parquet: https://parquet.apache.org/docs/ .. _UmbraDB: https://db.in.tum.de/~freitag/papers/p29-neumann-cidr20.pdf diff --git a/docs/source/java/flight_sql_jdbc_driver.rst b/docs/source/java/flight_sql_jdbc_driver.rst index f95c2ac755d97..0224cc3235652 100644 --- a/docs/source/java/flight_sql_jdbc_driver.rst +++ b/docs/source/java/flight_sql_jdbc_driver.rst @@ -28,7 +28,7 @@ This driver can be used with any database that implements Flight SQL. Installation and Requirements ============================= -The driver is compatible with JDK 8+. On JDK 9+, the following JVM +The driver is compatible with JDK 11+. Note that the following JVM parameter is required: .. code-block:: shell diff --git a/docs/source/java/install.rst b/docs/source/java/install.rst index dc6a55c87fcd6..c238690c6b930 100644 --- a/docs/source/java/install.rst +++ b/docs/source/java/install.rst @@ -29,10 +29,10 @@ Java modules are regularly built and tested on macOS and Linux distributions. Java Compatibility ================== -Java modules are compatible with JDK 8 and above. Currently, JDK versions -8, 11, 17, and 21 are tested in CI. The latest JDK is also tested in CI. +Java modules are compatible with JDK 11 and above. Currently, JDK versions +11, 17, 21, and latest are tested in CI. -When using Java 9 or later, some JDK internals must be exposed by +Note that some JDK internals must be exposed by adding ``--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED`` to the ``java`` command: .. code-block:: shell @@ -40,7 +40,7 @@ adding ``--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED # Directly on the command line $ java --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED -jar ... # Indirectly via environment variables - $ env _JAVA_OPTIONS="--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED" java -jar ... + $ env JDK_JAVA_OPTIONS="--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED" java -jar ... Otherwise, you may see errors like ``module java.base does not "opens java.nio" to unnamed module`` or ``module java.base does not "opens @@ -58,7 +58,7 @@ Modifying the command above for Flight: # Directly on the command line $ java --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED -jar ... # Indirectly via environment variables - $ env _JAVA_OPTIONS="--add-reads=org.apache.arrow.flight.core=ALL-UNNAMED --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED" java -jar ... + $ env JDK_JAVA_OPTIONS="--add-reads=org.apache.arrow.flight.core=ALL-UNNAMED --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED" java -jar ... 
Otherwise, you may see errors like ``java.lang.IllegalAccessError: superclass access check failed: class org.apache.arrow.flight.ArrowMessage$ArrowBufRetainingCompositeByteBuf (in module org.apache.arrow.flight.core) @@ -67,12 +67,13 @@ org.apache.arrow.flight.core does not read unnamed module ...`` Finally, if you are using arrow-dataset, you'll also need to report that JDK internals need to be exposed. Modifying the command above for arrow-memory: + .. code-block:: shell # Directly on the command line $ java --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED -jar ... # Indirectly via environment variables - $ env _JAVA_OPTIONS="--add-opens=java.base/java.nio=org.apache.arrow.dataset,org.apache.arrow.memory.core,ALL-UNNAMED" java -jar ... + $ env JDK_JAVA_OPTIONS="--add-opens=java.base/java.nio=org.apache.arrow.dataset,org.apache.arrow.memory.core,ALL-UNNAMED" java -jar ... Otherwise you may see errors such as ``java.lang.RuntimeException: java.lang.reflect.InaccessibleObjectException: Unable to make static void java.nio.Bits.reserveMemory(long,long) accessible: module @@ -215,7 +216,7 @@ Or they can be added via environment variable, for example when executing your c .. code-block:: - _JAVA_OPTIONS="--add-opens=java.base/java.nio=ALL-UNNAMED" mvn exec:java -Dexec.mainClass="YourMainCode" + JDK_JAVA_OPTIONS="--add-opens=java.base/java.nio=ALL-UNNAMED" mvn exec:java -Dexec.mainClass="YourMainCode" Installing from Source ====================== diff --git a/docs/source/status.rst b/docs/source/status.rst index 266381175608a..c232aa280befb 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -62,11 +62,11 @@ Data Types +-------------------+-------+-------+-------+----+-------+-------+-------+-------+-----------+ | Binary | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | +-------------------+-------+-------+-------+----+-------+-------+-------+-------+-----------+ -| Large Binary | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | ✓ | +| Large Binary | ✓ | ✓ | ✓ | ✓ | \(4) | ✓ | ✓ | | ✓ | +-------------------+-------+-------+-------+----+-------+-------+-------+-------+-----------+ | Utf8 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | +-------------------+-------+-------+-------+----+-------+-------+-------+-------+-----------+ -| Large Utf8 | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | ✓ | +| Large Utf8 | ✓ | ✓ | ✓ | ✓ | \(4) | ✓ | ✓ | | ✓ | +-------------------+-------+-------+-------+----+-------+-------+-------+-------+-----------+ | Binary View | ✓ | | ✓ | | ✓ | | | | | +-------------------+-------+-------+-------+----+-------+-------+-------+-------+-----------+ @@ -85,7 +85,7 @@ Data Types +-------------------+-------+-------+-------+----+-------+-------+-------+-------+-----------+ | List | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ✓ | +-------------------+-------+-------+-------+----+-------+-------+-------+-------+-----------+ -| Large List | ✓ | ✓ | ✓ | | | ✓ | ✓ | | ✓ | +| Large List | ✓ | ✓ | ✓ | | \(4) | ✓ | ✓ | | ✓ | +-------------------+-------+-------+-------+----+-------+-------+-------+-------+-----------+ | List View | ✓ | | ✓ | | ✓ | | | | | +-------------------+-------+-------+-------+----+-------+-------+-------+-------+-----------+ @@ -125,6 +125,8 @@ Notes: * \(1) Casting to/from Float16 in Java is not supported. * \(2) Float16 support in C# is only available when targeting .NET 6+. 
* \(3) Nested dictionaries not supported +* \(4) C# large array types are provided to help with interoperability with other libraries, + but these do not support buffers larger than 2 GiB and an exception will be raised if trying to import an array that is too large. .. seealso:: The :ref:`format_columnar` and the diff --git a/format/Flight.proto b/format/Flight.proto index 4963e8c09ae47..2187a51ed48f4 100644 --- a/format/Flight.proto +++ b/format/Flight.proto @@ -208,24 +208,6 @@ message Action { bytes body = 2; } -/* - * The request of the CancelFlightInfo action. - * - * The request should be stored in Action.body. - */ -message CancelFlightInfoRequest { - FlightInfo info = 1; -} - -/* - * The request of the RenewFlightEndpoint action. - * - * The request should be stored in Action.body. - */ -message RenewFlightEndpointRequest { - FlightEndpoint endpoint = 1; -} - /* * An opaque result returned after executing an action. */ @@ -233,36 +215,6 @@ message Result { bytes body = 1; } -/* - * The result of a cancel operation. - * - * This is used by CancelFlightInfoResult.status. - */ -enum CancelStatus { - // The cancellation status is unknown. Servers should avoid using - // this value (send a NOT_FOUND error if the requested query is - // not known). Clients can retry the request. - CANCEL_STATUS_UNSPECIFIED = 0; - // The cancellation request is complete. Subsequent requests with - // the same payload may return CANCELLED or a NOT_FOUND error. - CANCEL_STATUS_CANCELLED = 1; - // The cancellation request is in progress. The client may retry - // the cancellation request. - CANCEL_STATUS_CANCELLING = 2; - // The query is not cancellable. The client should not retry the - // cancellation request. - CANCEL_STATUS_NOT_CANCELLABLE = 3; -} - -/* - * The result of the CancelFlightInfo action. - * - * The result should be stored in Result.body. - */ -message CancelFlightInfoResult { - CancelStatus status = 1; -} - /* * Wrap the result of a getSchema call */ @@ -423,6 +375,64 @@ message PollInfo { google.protobuf.Timestamp expiration_time = 4; } +/* + * The request of the CancelFlightInfo action. + * + * The request should be stored in Action.body. + */ +message CancelFlightInfoRequest { + FlightInfo info = 1; +} + +/* + * The result of a cancel operation. + * + * This is used by CancelFlightInfoResult.status. + */ +enum CancelStatus { + // The cancellation status is unknown. Servers should avoid using + // this value (send a NOT_FOUND error if the requested query is + // not known). Clients can retry the request. + CANCEL_STATUS_UNSPECIFIED = 0; + // The cancellation request is complete. Subsequent requests with + // the same payload may return CANCELLED or a NOT_FOUND error. + CANCEL_STATUS_CANCELLED = 1; + // The cancellation request is in progress. The client may retry + // the cancellation request. + CANCEL_STATUS_CANCELLING = 2; + // The query is not cancellable. The client should not retry the + // cancellation request. + CANCEL_STATUS_NOT_CANCELLABLE = 3; +} + +/* + * The result of the CancelFlightInfo action. + * + * The result should be stored in Result.body. + */ +message CancelFlightInfoResult { + CancelStatus status = 1; +} + +/* + * An opaque identifier that the service can use to retrieve a particular + * portion of a stream. + * + * Tickets are meant to be single use. It is an error/application-defined + * behavior to reuse a ticket. 
+ */ +message Ticket { + bytes ticket = 1; +} + +/* + * A location where a Flight service will accept retrieval of a particular + * stream given a ticket. + */ +message Location { + string uri = 1; +} + /* * A particular stream or split associated with a flight. */ @@ -475,22 +485,12 @@ message FlightEndpoint { } /* - * A location where a Flight service will accept retrieval of a particular - * stream given a ticket. - */ -message Location { - string uri = 1; -} - -/* - * An opaque identifier that the service can use to retrieve a particular - * portion of a stream. + * The request of the RenewFlightEndpoint action. * - * Tickets are meant to be single use. It is an error/application-defined - * behavior to reuse a ticket. + * The request should be stored in Action.body. */ -message Ticket { - bytes ticket = 1; +message RenewFlightEndpointRequest { + FlightEndpoint endpoint = 1; } /* diff --git a/go/go.mod b/go/go.mod index 1c730cc87709b..43c2c41b69eca 100644 --- a/go/go.mod +++ b/go/go.mod @@ -47,9 +47,9 @@ require ( require ( github.com/google/uuid v1.6.0 - github.com/hamba/avro/v2 v2.22.1 + github.com/hamba/avro/v2 v2.23.0 github.com/huandu/xstrings v1.4.0 - github.com/substrait-io/substrait-go v0.4.2 + github.com/substrait-io/substrait-go v0.5.0 github.com/tidwall/sjson v1.2.5 ) diff --git a/go/go.sum b/go/go.sum index 6ce51c83350a0..a96f0a3797c74 100644 --- a/go/go.sum +++ b/go/go.sum @@ -43,8 +43,8 @@ github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbu github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/hamba/avro/v2 v2.22.1 h1:q1rAbfJsrbMaZPDLQvwUQMfQzp6H+hGXvckmU/lXemk= -github.com/hamba/avro/v2 v2.22.1/go.mod h1:HOeTrE3kvWnBAgsufqhAzDDV5gvS0QXs65Z6BHfGgbg= +github.com/hamba/avro/v2 v2.23.0 h1:DYWz6UqNCi21JflaZlcwNfW+rK+D/CwnrWWJtfmO4vw= +github.com/hamba/avro/v2 v2.23.0/go.mod h1:7vDfy/2+kYCE8WUHoj2et59GTv0ap7ptktMXu0QHePI= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= @@ -99,8 +99,8 @@ github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/substrait-io/substrait-go v0.4.2 h1:buDnjsb3qAqTaNbOR7VKmNgXf4lYQxWEcnSGUWBtmN8= -github.com/substrait-io/substrait-go v0.4.2/go.mod h1:qhpnLmrcvAnlZsUyPXZRqldiHapPTXC3t7xFgDi3aQg= +github.com/substrait-io/substrait-go v0.5.0 h1:8sYsoqcrzoNpThPyot1CQpwF6OokxvplLUQJTGlKws4= +github.com/substrait-io/substrait-go v0.5.0/go.mod h1:Co7ko6iIjdqCGcN3LfkKWPVlxONkNZem9omWAGIaOrQ= github.com/tidwall/gjson v1.14.2 h1:6BBkirS0rAHjumnjHF6qgy5d2YAJ1TLIaFE2lzfOLqo= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= diff --git a/go/parquet/file/file_reader_test.go b/go/parquet/file/file_reader_test.go index 7d20bbe1006f8..547ec475c2720 100644 --- 
a/go/parquet/file/file_reader_test.go +++ b/go/parquet/file/file_reader_test.go @@ -18,6 +18,7 @@ package file_test import ( "bytes" + "context" "crypto/rand" "encoding/binary" "fmt" @@ -26,6 +27,8 @@ import ( "path" "testing" + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/internal/utils" "github.com/apache/arrow/go/v18/parquet" @@ -35,6 +38,7 @@ import ( format "github.com/apache/arrow/go/v18/parquet/internal/gen-go/parquet" "github.com/apache/arrow/go/v18/parquet/internal/thrift" "github.com/apache/arrow/go/v18/parquet/metadata" + "github.com/apache/arrow/go/v18/parquet/pqarrow" "github.com/apache/arrow/go/v18/parquet/schema" libthrift "github.com/apache/thrift/lib/go/thrift" "github.com/stretchr/testify/assert" @@ -582,3 +586,61 @@ func TestByteStreamSplitEncodingFileRead(t *testing.T) { }) } } + +func TestDeltaBinaryPackedMultipleBatches(t *testing.T) { + size := 10 + batchSize := size / 2 // write 2 batches + + // Define the schema for the test data + fields := []arrow.Field{ + {Name: "int64", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + } + schema := arrow.NewSchema(fields, nil) + + // Create a record batch with the test data + b := array.NewRecordBuilder(memory.DefaultAllocator, schema) + defer b.Release() + + for i := 0; i < size; i++ { + b.Field(0).(*array.Int64Builder).Append(int64(i)) + } + rec := b.NewRecord() + defer rec.Release() + + // Write the data to Parquet using the file writer + props := parquet.NewWriterProperties( + parquet.WithDictionaryDefault(false), + parquet.WithEncoding(parquet.Encodings.DeltaBinaryPacked)) + writerProps := pqarrow.DefaultWriterProps() + + var buf bytes.Buffer + pw, err := pqarrow.NewFileWriter(schema, &buf, props, writerProps) + require.NoError(t, err) + require.NoError(t, pw.Write(rec)) + require.NoError(t, pw.Close()) + + // Read the data back from the Parquet file + reader, err := file.NewParquetReader(bytes.NewReader(buf.Bytes())) + require.NoError(t, err) + defer reader.Close() + + pr, err := pqarrow.NewFileReader(reader, pqarrow.ArrowReadProperties{BatchSize: int64(batchSize)}, memory.DefaultAllocator) + require.NoError(t, err) + + rr, err := pr.GetRecordReader(context.Background(), nil, nil) + require.NoError(t, err) + + totalRows := 0 + for rr.Next() { + rec := rr.Record() + for i := 0; i < int(rec.NumRows()); i++ { + col := rec.Column(0).(*array.Int64) + + val := col.Value(i) + require.Equal(t, val, int64(totalRows+i)) + } + totalRows += int(rec.NumRows()) + } + + require.Equalf(t, size, totalRows, "Expected %d rows, but got %d rows", size, totalRows) +} diff --git a/go/parquet/internal/encoding/delta_bit_packing.go b/go/parquet/internal/encoding/delta_bit_packing.go index ca1ed14511f43..ac91953a7f903 100644 --- a/go/parquet/internal/encoding/delta_bit_packing.go +++ b/go/parquet/internal/encoding/delta_bit_packing.go @@ -19,9 +19,9 @@ package encoding import ( "bytes" "errors" + "fmt" "math" "math/bits" - "reflect" "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/memory" @@ -32,7 +32,7 @@ import ( // see the deltaBitPack encoder for a description of the encoding format that is // used for delta-bitpacking. 
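The comment above points at the delta bit-packing format: a page stores the first value verbatim, then blocks of deltas from which the block's minimum delta is subtracted so the remainders pack into few bits. A toy sketch of that arithmetic, ignoring the real block/mini-block layout and the zigzag varint headers:

```go
package main

import (
	"fmt"
	"math"
	"math/bits"
)

func main() {
	values := []int64{7, 5, 3, 1, 2, 3, 4, 5}
	first := values[0] // stored verbatim in the header

	// Reduce to consecutive deltas and find the block minimum.
	deltas := make([]int64, len(values)-1)
	minDelta := int64(math.MaxInt64)
	for i := 1; i < len(values); i++ {
		deltas[i-1] = values[i] - values[i-1]
		if deltas[i-1] < minDelta {
			minDelta = deltas[i-1]
		}
	}

	// Subtracting minDelta makes every delta non-negative, so each one
	// needs only enough bits for the largest remainder.
	width := 0
	for i := range deltas {
		deltas[i] -= minDelta
		if w := bits.Len64(uint64(deltas[i])); w > width {
			width = w
		}
	}
	fmt.Println(first, minDelta, width, deltas)
	// Output: 7 -2 2 [0 0 0 3 3 3 3]
}
```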
-type deltaBitPackDecoder struct { +type deltaBitPackDecoder[T int32 | int64] struct { decoder mem memory.Allocator @@ -52,18 +52,20 @@ type deltaBitPackDecoder struct { totalValues uint64 lastVal int64 + + miniBlockValues []T } // returns the number of bytes read so far -func (d *deltaBitPackDecoder) bytesRead() int64 { +func (d *deltaBitPackDecoder[T]) bytesRead() int64 { return d.bitdecoder.CurOffset() } -func (d *deltaBitPackDecoder) Allocator() memory.Allocator { return d.mem } +func (d *deltaBitPackDecoder[T]) Allocator() memory.Allocator { return d.mem } // SetData sets the bytes and the expected number of values to decode // into the decoder, updating the decoder and allowing it to be reused. -func (d *deltaBitPackDecoder) SetData(nvalues int, data []byte) error { +func (d *deltaBitPackDecoder[T]) SetData(nvalues int, data []byte) error { // set our data into the underlying decoder for the type if err := d.decoder.SetData(nvalues, data); err != nil { return err @@ -103,7 +105,7 @@ func (d *deltaBitPackDecoder) SetData(nvalues int, data []byte) error { } // initialize a block to decode -func (d *deltaBitPackDecoder) initBlock() error { +func (d *deltaBitPackDecoder[T]) initBlock() error { // first we grab the min delta value that we'll start from var ok bool if d.minDelta, ok = d.bitdecoder.GetZigZagVlqInt(); !ok { @@ -126,16 +128,9 @@ func (d *deltaBitPackDecoder) initBlock() error { return nil } -// DeltaBitPackInt32Decoder decodes Int32 values which are packed using the Delta BitPacking algorithm. -type DeltaBitPackInt32Decoder struct { - *deltaBitPackDecoder - - miniBlockValues []int32 -} - -func (d *DeltaBitPackInt32Decoder) unpackNextMini() error { +func (d *deltaBitPackDecoder[T]) unpackNextMini() error { if d.miniBlockValues == nil { - d.miniBlockValues = make([]int32, 0, int(d.valsPerMini)) + d.miniBlockValues = make([]T, 0, int(d.valsPerMini)) } else { d.miniBlockValues = d.miniBlockValues[:0] } @@ -149,7 +144,7 @@ func (d *DeltaBitPackInt32Decoder) unpackNextMini() error { } d.lastVal += int64(delta) + int64(d.minDelta) - d.miniBlockValues = append(d.miniBlockValues, int32(d.lastVal)) + d.miniBlockValues = append(d.miniBlockValues, T(d.lastVal)) } d.miniBlockIdx++ return nil @@ -157,15 +152,15 @@ func (d *DeltaBitPackInt32Decoder) unpackNextMini() error { // Decode retrieves min(remaining values, len(out)) values from the data and returns the number // of values actually decoded and any errors encountered. 
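The hunk above folds the two hand-written decoders into one type parameterized over `int32 | int64`; the `miniBlockValues []T` field is exactly what previously forced the duplication. The pattern in isolation, as a minimal sketch:

```go
package main

import "fmt"

// One generic type replaces two near-identical concrete decoders.
// Deltas accumulate in int64 and are narrowed to T exactly once,
// at the append site.
type miniDecoder[T int32 | int64] struct {
	lastVal int64
	values  []T
}

func (d *miniDecoder[T]) push(delta int64) {
	d.lastVal += delta
	d.values = append(d.values, T(d.lastVal))
}

func main() {
	d32 := &miniDecoder[int32]{lastVal: 7}
	d32.push(-2)
	d32.push(-2)
	fmt.Println(d32.values) // [5 3]

	d64 := &miniDecoder[int64]{lastVal: 7}
	d64.push(1)
	fmt.Println(d64.values) // [8]
}
```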
-func (d *DeltaBitPackInt32Decoder) Decode(out []int32) (int, error) { - max := shared_utils.Min(len(out), int(d.totalValues)) +func (d *deltaBitPackDecoder[T]) Decode(out []T) (int, error) { + max := shared_utils.Min(len(out), int(d.nvals)) if max == 0 { return 0, nil } out = out[:max] if !d.usedFirst { // starting value to calculate deltas against - out[0] = int32(d.lastVal) + out[0] = T(d.lastVal) out = out[1:] d.usedFirst = true } @@ -198,7 +193,7 @@ func (d *DeltaBitPackInt32Decoder) Decode(out []int32) (int, error) { } // DecodeSpaced is like Decode, but the result is spaced out appropriately based on the passed in bitmap -func (d *DeltaBitPackInt32Decoder) DecodeSpaced(out []int32, nullCount int, validBits []byte, validBitsOffset int64) (int, error) { +func (d *deltaBitPackDecoder[T]) DecodeSpaced(out []T, nullCount int, validBits []byte, validBitsOffset int64) (int, error) { toread := len(out) - nullCount values, err := d.Decode(out[:toread]) if err != nil { @@ -211,101 +206,23 @@ func (d *DeltaBitPackInt32Decoder) DecodeSpaced(out []int32, nullCount int, vali return spacedExpand(out, nullCount, validBits, validBitsOffset), nil } -// Type returns the physical parquet type that this decoder decodes, in this case Int32 -func (DeltaBitPackInt32Decoder) Type() parquet.Type { - return parquet.Types.Int32 -} - -// DeltaBitPackInt64Decoder decodes a delta bit packed int64 column of data. -type DeltaBitPackInt64Decoder struct { - *deltaBitPackDecoder - - miniBlockValues []int64 -} - -func (d *DeltaBitPackInt64Decoder) unpackNextMini() error { - if d.miniBlockValues == nil { - d.miniBlockValues = make([]int64, 0, int(d.valsPerMini)) - } else { - d.miniBlockValues = d.miniBlockValues[:0] - } - - d.deltaBitWidth = d.deltaBitWidths.Bytes()[int(d.miniBlockIdx)] - d.currentMiniBlockVals = d.valsPerMini - - for j := 0; j < int(d.valsPerMini); j++ { - delta, ok := d.bitdecoder.GetValue(int(d.deltaBitWidth)) - if !ok { - return errors.New("parquet: eof exception") - } - - d.lastVal += int64(delta) + d.minDelta - d.miniBlockValues = append(d.miniBlockValues, d.lastVal) - } - d.miniBlockIdx++ - return nil -} - -// Decode retrieves min(remaining values, len(out)) values from the data and returns the number -// of values actually decoded and any errors encountered. 
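One wrinkle of the generic rewrite shows up in `Type()` above: Go cannot switch on a type parameter directly, so the receiver is boxed into `any` and matched against each instantiation. Reduced to essentials:

```go
package main

import "fmt"

type decoder[T int32 | int64] struct{}

// typeName recovers a runtime tag from the compile-time instantiation
// by type-switching on the boxed receiver, as Type() does above.
func (d *decoder[T]) typeName() string {
	switch any(d).(type) {
	case *decoder[int32]:
		return "INT32"
	case *decoder[int64]:
		return "INT64"
	default:
		panic("unsupported instantiation")
	}
}

func main() {
	fmt.Println((&decoder[int32]{}).typeName()) // INT32
	fmt.Println((&decoder[int64]{}).typeName()) // INT64
}
```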
-func (d *DeltaBitPackInt64Decoder) Decode(out []int64) (int, error) { - max := shared_utils.Min(len(out), d.nvals) - if max == 0 { - return 0, nil - } - - out = out[:max] - if !d.usedFirst { - out[0] = d.lastVal - out = out[1:] - d.usedFirst = true - } - - var err error - for len(out) > 0 { - if d.currentBlockVals == 0 { - err = d.initBlock() - if err != nil { - return 0, err - } - } - if d.currentMiniBlockVals == 0 { - err = d.unpackNextMini() - } - - if err != nil { - return 0, err - } - - start := int(d.valsPerMini - d.currentMiniBlockVals) - numCopied := copy(out, d.miniBlockValues[start:]) - - out = out[numCopied:] - d.currentBlockVals -= uint32(numCopied) - d.currentMiniBlockVals -= uint32(numCopied) +// Type returns the underlying physical type this decoder works with +func (dec *deltaBitPackDecoder[T]) Type() parquet.Type { + switch v := any(dec).(type) { + case *deltaBitPackDecoder[int32]: + return parquet.Types.Int32 + case *deltaBitPackDecoder[int64]: + return parquet.Types.Int64 + default: + panic(fmt.Sprintf("deltaBitPackDecoder is not supported for type: %T", v)) } - d.nvals -= max - return max, nil -} - -// Type returns the physical parquet type that this decoder decodes, in this case Int64 -func (DeltaBitPackInt64Decoder) Type() parquet.Type { - return parquet.Types.Int64 } -// DecodeSpaced is like Decode, but the result is spaced out appropriately based on the passed in bitmap -func (d DeltaBitPackInt64Decoder) DecodeSpaced(out []int64, nullCount int, validBits []byte, validBitsOffset int64) (int, error) { - toread := len(out) - nullCount - values, err := d.Decode(out[:toread]) - if err != nil { - return values, err - } - if values != toread { - return values, errors.New("parquet: number of values / definition levels read did not match") - } +// DeltaBitPackInt32Decoder decodes Int32 values which are packed using the Delta BitPacking algorithm. +type DeltaBitPackInt32Decoder = deltaBitPackDecoder[int32] - return spacedExpand(out, nullCount, validBits, validBitsOffset), nil -} +// DeltaBitPackInt64Decoder decodes Int64 values which are packed using the Delta BitPacking algorithm. +type DeltaBitPackInt64Decoder = deltaBitPackDecoder[int64] const ( // block size must be a multiple of 128 @@ -333,7 +250,7 @@ const ( // // Sets aside bytes at the start of the internal buffer where the header will be written, // and only writes the header when FlushValues is called before returning it. -type deltaBitPackEncoder struct { +type deltaBitPackEncoder[T int32 | int64] struct { encoder bitWriter *utils.BitWriter @@ -348,7 +265,7 @@ type deltaBitPackEncoder struct { } // flushBlock flushes out a finished block for writing to the underlying encoder -func (enc *deltaBitPackEncoder) flushBlock() { +func (enc *deltaBitPackEncoder[T]) flushBlock() { if len(enc.deltas) == 0 { return } @@ -400,9 +317,8 @@ func (enc *deltaBitPackEncoder) flushBlock() { // putInternal is the implementation for actually writing data which must be // integral data as int, int8, int32, or int64. 
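The exported names survive this refactor because they become type aliases (note the `=`) rather than new defined types: `DeltaBitPackInt32Decoder` and `deltaBitPackDecoder[int32]` are the same type, so existing callers compile unchanged. A small sketch of the distinction:

```go
package main

type decoder[T int32 | int64] struct{ nvals int }

// Alias: Int32Decoder and decoder[int32] are one and the same type.
type Int32Decoder = decoder[int32]

// A defined type (without '=') would be distinct and require explicit
// conversions at every call site:
// type Int32Decoder decoder[int32]

func main() {
	var d Int32Decoder
	var g decoder[int32] = d // assignable with no conversion: same type
	_ = g
}
```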
-func (enc *deltaBitPackEncoder) putInternal(data interface{}) { - v := reflect.ValueOf(data) - if v.Len() == 0 { +func (enc *deltaBitPackEncoder[T]) Put(in []T) { + if len(in) == 0 { return } @@ -412,16 +328,16 @@ func (enc *deltaBitPackEncoder) putInternal(data interface{}) { enc.numMiniBlocks = defaultNumMiniBlocks enc.miniBlockSize = defaultNumValuesPerMini - enc.firstVal = v.Index(0).Int() + enc.firstVal = int64(in[0]) enc.currentVal = enc.firstVal idx = 1 enc.bitWriter = utils.NewBitWriter(enc.sink) } - enc.totalVals += uint64(v.Len()) - for ; idx < v.Len(); idx++ { - val := v.Index(idx).Int() + enc.totalVals += uint64(len(in)) + for ; idx < len(in); idx++ { + val := int64(in[idx]) enc.deltas = append(enc.deltas, val-enc.currentVal) enc.currentVal = val if len(enc.deltas) == int(enc.blockSize) { @@ -432,7 +348,7 @@ func (enc *deltaBitPackEncoder) putInternal(data interface{}) { // FlushValues flushes any remaining data and returns the finished encoded buffer // or returns nil and any error encountered during flushing. -func (enc *deltaBitPackEncoder) FlushValues() (Buffer, error) { +func (enc *deltaBitPackEncoder[T]) FlushValues() (Buffer, error) { if enc.bitWriter != nil { // write any remaining values enc.flushBlock() @@ -465,7 +381,7 @@ func (enc *deltaBitPackEncoder) FlushValues() (Buffer, error) { } // EstimatedDataEncodedSize returns the current amount of data actually flushed out and written -func (enc *deltaBitPackEncoder) EstimatedDataEncodedSize() int64 { +func (enc *deltaBitPackEncoder[T]) EstimatedDataEncodedSize() int64 { if enc.bitWriter == nil { return 0 } @@ -473,56 +389,33 @@ func (enc *deltaBitPackEncoder) EstimatedDataEncodedSize() int64 { return int64(enc.bitWriter.Written()) } -// DeltaBitPackInt32Encoder is an encoder for the delta bitpacking encoding for int32 data. -type DeltaBitPackInt32Encoder struct { - *deltaBitPackEncoder -} - -// Put writes the values from the provided slice of int32 to the encoder -func (enc DeltaBitPackInt32Encoder) Put(in []int32) { - enc.putInternal(in) -} - -// PutSpaced takes a slice of int32 along with a bitmap that describes the nulls and an offset into the bitmap +// PutSpaced takes a slice of values along with a bitmap that describes the nulls and an offset into the bitmap // in order to write spaced data to the encoder. -func (enc DeltaBitPackInt32Encoder) PutSpaced(in []int32, validBits []byte, validBitsOffset int64) { +func (enc *deltaBitPackEncoder[T]) PutSpaced(in []T, validBits []byte, validBitsOffset int64) { buffer := memory.NewResizableBuffer(enc.mem) - buffer.Reserve(arrow.Int32Traits.BytesRequired(len(in))) + dt := arrow.GetDataType[T]().(arrow.FixedWidthDataType) + buffer.Reserve(dt.Bytes() * len(in)) defer buffer.Release() - data := arrow.Int32Traits.CastFromBytes(buffer.Buf()) + data := arrow.GetData[T](buffer.Buf()) nvalid := spacedCompress(in, data, validBits, validBitsOffset) enc.Put(data[:nvalid]) } -// Type returns the underlying physical type this encoder works with, in this case Int32 -func (DeltaBitPackInt32Encoder) Type() parquet.Type { - return parquet.Types.Int32 -} - -// DeltaBitPackInt32Encoder is an encoder for the delta bitpacking encoding for int32 data. 
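The removed `putInternal` walked its input through `reflect.Value`, paying per-element dynamic dispatch; the generic `Put` above iterates the slice directly, which also lets the `reflect` import go. A side-by-side sketch of the difference:

```go
package main

import (
	"fmt"
	"reflect"
)

// Old style: accept interface{}, reflect over it element by element.
func sumReflect(data interface{}) int64 {
	v := reflect.ValueOf(data)
	var s int64
	for i := 0; i < v.Len(); i++ {
		s += v.Index(i).Int() // dynamic dispatch per element
	}
	return s
}

// New style: the compiler knows the element type; a plain loop suffices.
func sumGeneric[T int32 | int64](in []T) int64 {
	var s int64
	for _, x := range in {
		s += int64(x)
	}
	return s
}

func main() {
	fmt.Println(sumReflect([]int32{1, 2, 3}))  // 6
	fmt.Println(sumGeneric([]int64{1, 2, 3})) // 6
}
```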
-type DeltaBitPackInt64Encoder struct { - *deltaBitPackEncoder -} - -// Put writes the values from the provided slice of int64 to the encoder -func (enc DeltaBitPackInt64Encoder) Put(in []int64) { - enc.putInternal(in) +// Type returns the underlying physical type this encoder works with +func (dec *deltaBitPackEncoder[T]) Type() parquet.Type { + switch v := any(dec).(type) { + case *deltaBitPackEncoder[int32]: + return parquet.Types.Int32 + case *deltaBitPackEncoder[int64]: + return parquet.Types.Int64 + default: + panic(fmt.Sprintf("deltaBitPackEncoder is not supported for type: %T", v)) + } } -// PutSpaced takes a slice of int64 along with a bitmap that describes the nulls and an offset into the bitmap -// in order to write spaced data to the encoder. -func (enc DeltaBitPackInt64Encoder) PutSpaced(in []int64, validBits []byte, validBitsOffset int64) { - buffer := memory.NewResizableBuffer(enc.mem) - buffer.Reserve(arrow.Int64Traits.BytesRequired(len(in))) - defer buffer.Release() +// DeltaBitPackInt32Encoder is an encoder for the delta bitpacking encoding for Int32 data. +type DeltaBitPackInt32Encoder = deltaBitPackEncoder[int32] - data := arrow.Int64Traits.CastFromBytes(buffer.Buf()) - nvalid := spacedCompress(in, data, validBits, validBitsOffset) - enc.Put(data[:nvalid]) -} - -// Type returns the underlying physical type this encoder works with, in this case Int64 -func (DeltaBitPackInt64Encoder) Type() parquet.Type { - return parquet.Types.Int64 -} +// DeltaBitPackInt64Encoder is an encoder for the delta bitpacking encoding for Int64 data. +type DeltaBitPackInt64Encoder = deltaBitPackEncoder[int64] diff --git a/go/parquet/internal/encoding/delta_byte_array.go b/go/parquet/internal/encoding/delta_byte_array.go index e7990f0dacbe8..62c8d08999972 100644 --- a/go/parquet/internal/encoding/delta_byte_array.go +++ b/go/parquet/internal/encoding/delta_byte_array.go @@ -53,11 +53,14 @@ func (enc *DeltaByteArrayEncoder) EstimatedDataEncodedSize() int64 { func (enc *DeltaByteArrayEncoder) initEncoders() { enc.prefixEncoder = &DeltaBitPackInt32Encoder{ - deltaBitPackEncoder: &deltaBitPackEncoder{encoder: newEncoderBase(enc.encoding, nil, enc.mem)}} + encoder: newEncoderBase(enc.encoding, nil, enc.mem), + } enc.suffixEncoder = &DeltaLengthByteArrayEncoder{ newEncoderBase(enc.encoding, nil, enc.mem), &DeltaBitPackInt32Encoder{ - deltaBitPackEncoder: &deltaBitPackEncoder{encoder: newEncoderBase(enc.encoding, nil, enc.mem)}}} + encoder: newEncoderBase(enc.encoding, nil, enc.mem), + }, + } } // Type returns the underlying physical type this operates on, in this case ByteArrays only @@ -160,9 +163,9 @@ func (d *DeltaByteArrayDecoder) Allocator() memory.Allocator { return d.mem } // blocks of suffix data in order to initialize the decoder. 
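`PutSpaced` above densifies before encoding: values whose validity bit is set are compacted to the front of a scratch buffer, and only that prefix reaches `Put`. A sketch of the compaction with a `[]bool` standing in for Arrow's packed validity bitmap (the real `spacedCompress` reads a bitmap at a bit offset):

```go
package main

import "fmt"

// spacedCompress packs the valid values to the front of a fresh slice.
func spacedCompress[T any](in []T, valid []bool) []T {
	out := make([]T, 0, len(in))
	for i, v := range in {
		if valid[i] {
			out = append(out, v)
		}
	}
	return out
}

func main() {
	in := []int32{7, 0, 3, 0, 5}
	valid := []bool{true, false, true, false, true}
	fmt.Println(spacedCompress(in, valid)) // [7 3 5]
}
```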
func (d *DeltaByteArrayDecoder) SetData(nvalues int, data []byte) error { prefixLenDec := DeltaBitPackInt32Decoder{ - deltaBitPackDecoder: &deltaBitPackDecoder{ - decoder: newDecoderBase(d.encoding, d.descr), - mem: d.mem}} + decoder: newDecoderBase(d.encoding, d.descr), + mem: d.mem, + } if err := prefixLenDec.SetData(nvalues, data); err != nil { return err diff --git a/go/parquet/internal/encoding/delta_length_byte_array.go b/go/parquet/internal/encoding/delta_length_byte_array.go index b72960fe438ad..87c48d574ed68 100644 --- a/go/parquet/internal/encoding/delta_length_byte_array.go +++ b/go/parquet/internal/encoding/delta_length_byte_array.go @@ -110,9 +110,9 @@ func (d *DeltaLengthByteArrayDecoder) Allocator() memory.Allocator { return d.me // followed by the rest of the byte array data immediately after. func (d *DeltaLengthByteArrayDecoder) SetData(nvalues int, data []byte) error { dec := DeltaBitPackInt32Decoder{ - deltaBitPackDecoder: &deltaBitPackDecoder{ - decoder: newDecoderBase(d.encoding, d.descr), - mem: d.mem}} + decoder: newDecoderBase(d.encoding, d.descr), + mem: d.mem, + } if err := dec.SetData(nvalues, data); err != nil { return err diff --git a/go/parquet/internal/encoding/encoding_benchmarks_test.go b/go/parquet/internal/encoding/encoding_benchmarks_test.go index 95c0b3861bc05..2ca414eec6b90 100644 --- a/go/parquet/internal/encoding/encoding_benchmarks_test.go +++ b/go/parquet/internal/encoding/encoding_benchmarks_test.go @@ -634,3 +634,48 @@ func BenchmarkByteStreamSplitDecodingFixedLenByteArray(b *testing.B) { }) } } + +func BenchmarkDeltaBinaryPackedEncodingInt32(b *testing.B) { + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + values := make([]int32, sz) + for idx := range values { + values[idx] = 64 + } + encoder := encoding.NewEncoder(parquet.Types.Int32, parquet.Encodings.DeltaBinaryPacked, + false, nil, memory.DefaultAllocator).(encoding.Int32Encoder) + b.ResetTimer() + b.SetBytes(int64(len(values) * arrow.Int32SizeBytes)) + for n := 0; n < b.N; n++ { + encoder.Put(values) + buf, _ := encoder.FlushValues() + buf.Release() + } + }) + } +} + +func BenchmarkDeltaBinaryPackedDecodingInt32(b *testing.B) { + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + output := make([]int32, sz) + values := make([]int32, sz) + for idx := range values { + values[idx] = 64 + } + encoder := encoding.NewEncoder(parquet.Types.Int32, parquet.Encodings.DeltaBinaryPacked, + false, nil, memory.DefaultAllocator).(encoding.Int32Encoder) + encoder.Put(values) + buf, _ := encoder.FlushValues() + defer buf.Release() + + decoder := encoding.NewDecoder(parquet.Types.Int32, parquet.Encodings.DeltaBinaryPacked, nil, memory.DefaultAllocator) + b.ResetTimer() + b.SetBytes(int64(len(values) * arrow.Int32SizeBytes)) + for n := 0; n < b.N; n++ { + decoder.SetData(sz, buf.Bytes()) + decoder.(encoding.Int32Decoder).Decode(output) + } + }) + } +} diff --git a/go/parquet/internal/encoding/typed_encoder.gen.go b/go/parquet/internal/encoding/typed_encoder.gen.go index 3a960e2c62332..e67c976adc042 100644 --- a/go/parquet/internal/encoding/typed_encoder.gen.go +++ b/go/parquet/internal/encoding/typed_encoder.gen.go @@ -86,8 +86,9 @@ func (int32EncoderTraits) Encoder(e format.Encoding, useDict bool, descr *schema case format.Encoding_PLAIN: return &PlainInt32Encoder{encoder: newEncoderBase(e, descr, mem)} case format.Encoding_DELTA_BINARY_PACKED: - return DeltaBitPackInt32Encoder{&deltaBitPackEncoder{ - 
encoder: newEncoderBase(e, descr, mem)}} + return &DeltaBitPackInt32Encoder{ + encoder: newEncoderBase(e, descr, mem), + } case format.Encoding_BYTE_STREAM_SPLIT: return &ByteStreamSplitInt32Encoder{PlainInt32Encoder: PlainInt32Encoder{encoder: newEncoderBase(e, descr, mem)}} default: @@ -118,10 +119,9 @@ func (int32DecoderTraits) Decoder(e parquet.Encoding, descr *schema.Column, useD mem = memory.DefaultAllocator } return &DeltaBitPackInt32Decoder{ - deltaBitPackDecoder: &deltaBitPackDecoder{ - decoder: newDecoderBase(format.Encoding(e), descr), - mem: mem, - }} + decoder: newDecoderBase(format.Encoding(e), descr), + mem: mem, + } case parquet.Encodings.ByteStreamSplit: return &ByteStreamSplitInt32Decoder{decoder: newDecoderBase(format.Encoding(e), descr)} default: @@ -327,8 +327,9 @@ func (int64EncoderTraits) Encoder(e format.Encoding, useDict bool, descr *schema case format.Encoding_PLAIN: return &PlainInt64Encoder{encoder: newEncoderBase(e, descr, mem)} case format.Encoding_DELTA_BINARY_PACKED: - return DeltaBitPackInt64Encoder{&deltaBitPackEncoder{ - encoder: newEncoderBase(e, descr, mem)}} + return &DeltaBitPackInt64Encoder{ + encoder: newEncoderBase(e, descr, mem), + } case format.Encoding_BYTE_STREAM_SPLIT: return &ByteStreamSplitInt64Encoder{PlainInt64Encoder: PlainInt64Encoder{encoder: newEncoderBase(e, descr, mem)}} default: @@ -359,10 +360,9 @@ func (int64DecoderTraits) Decoder(e parquet.Encoding, descr *schema.Column, useD mem = memory.DefaultAllocator } return &DeltaBitPackInt64Decoder{ - deltaBitPackDecoder: &deltaBitPackDecoder{ - decoder: newDecoderBase(format.Encoding(e), descr), - mem: mem, - }} + decoder: newDecoderBase(format.Encoding(e), descr), + mem: mem, + } case parquet.Encodings.ByteStreamSplit: return &ByteStreamSplitInt64Decoder{decoder: newDecoderBase(format.Encoding(e), descr)} default: @@ -1306,7 +1306,8 @@ func (byteArrayEncoderTraits) Encoder(e format.Encoding, useDict bool, descr *sc return &DeltaLengthByteArrayEncoder{ encoder: newEncoderBase(e, descr, mem), lengthEncoder: &DeltaBitPackInt32Encoder{ - &deltaBitPackEncoder{encoder: newEncoderBase(e, descr, mem)}}, + encoder: newEncoderBase(e, descr, mem), + }, } case format.Encoding_DELTA_BYTE_ARRAY: return &DeltaByteArrayEncoder{ diff --git a/go/parquet/internal/encoding/typed_encoder.gen.go.tmpl b/go/parquet/internal/encoding/typed_encoder.gen.go.tmpl index 079c1aad6bd3f..601d90712baa6 100644 --- a/go/parquet/internal/encoding/typed_encoder.gen.go.tmpl +++ b/go/parquet/internal/encoding/typed_encoder.gen.go.tmpl @@ -79,15 +79,17 @@ func ({{.lower}}EncoderTraits) Encoder(e format.Encoding, useDict bool, descr *s {{- end}} {{- if or (eq .Name "Int32") (eq .Name "Int64")}} case format.Encoding_DELTA_BINARY_PACKED: - return DeltaBitPack{{.Name}}Encoder{&deltaBitPackEncoder{ - encoder: newEncoderBase(e, descr, mem)}} + return &DeltaBitPack{{.Name}}Encoder{ + encoder: newEncoderBase(e, descr, mem), + } {{- end}} {{- if eq .Name "ByteArray"}} case format.Encoding_DELTA_LENGTH_BYTE_ARRAY: return &DeltaLengthByteArrayEncoder{ encoder: newEncoderBase(e, descr, mem), lengthEncoder: &DeltaBitPackInt32Encoder{ - &deltaBitPackEncoder{encoder: newEncoderBase(e, descr, mem)}}, + encoder: newEncoderBase(e, descr, mem), + }, } case format.Encoding_DELTA_BYTE_ARRAY: return &DeltaByteArrayEncoder{ @@ -135,10 +137,9 @@ func ({{.lower}}DecoderTraits) Decoder(e parquet.Encoding, descr *schema.Column, mem = memory.DefaultAllocator } return &DeltaBitPack{{.Name}}Decoder{ - deltaBitPackDecoder: &deltaBitPackDecoder{ - decoder: 
newDecoderBase(format.Encoding(e), descr), - mem: mem, - }} + decoder: newDecoderBase(format.Encoding(e), descr), + mem: mem, + } {{- end}} {{- if eq .Name "ByteArray"}} case parquet.Encodings.DeltaLengthByteArray: diff --git a/go/parquet/pqarrow/file_reader.go b/go/parquet/pqarrow/file_reader.go index 208ac9ceebadf..a2e84d9ce2795 100755 --- a/go/parquet/pqarrow/file_reader.go +++ b/go/parquet/pqarrow/file_reader.go @@ -18,6 +18,7 @@ package pqarrow import ( "context" + "errors" "fmt" "io" "sync" @@ -375,6 +376,10 @@ func (fr *FileReader) ReadRowGroups(ctx context.Context, indices, rowGroups []in data.data.Release() } + // if the context is in error, but we haven't set an error yet, then it means that the parent context + // was cancelled. In this case, we should exit early as some columns may not have been read yet. + err = errors.Join(err, ctx.Err()) + if err != nil { // if we encountered an error, consume any waiting data on the channel // so the goroutines don't leak and so memory can get cleaned up. we already diff --git a/go/parquet/pqarrow/file_reader_test.go b/go/parquet/pqarrow/file_reader_test.go index b7d178f8644de..fe5a4547a775c 100644 --- a/go/parquet/pqarrow/file_reader_test.go +++ b/go/parquet/pqarrow/file_reader_test.go @@ -167,6 +167,29 @@ func TestArrowReaderAdHocReadFloat16s(t *testing.T) { } } +func TestArrowReaderCanceledContext(t *testing.T) { + dataDir := getDataDir() + + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + filename := filepath.Join(dataDir, "int32_decimal.parquet") + require.FileExists(t, filename) + + rdr, err := file.OpenParquetFile(filename, false, file.WithReadProps(parquet.NewReaderProperties(mem))) + require.NoError(t, err) + defer rdr.Close() + arrowRdr, err := pqarrow.NewFileReader(rdr, pqarrow.ArrowReadProperties{}, mem) + require.NoError(t, err) + + // create a canceled context + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err = arrowRdr.ReadTable(ctx) + require.ErrorIs(t, err, context.Canceled) +} + func TestRecordReaderParallel(t *testing.T) { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(t, 0) diff --git a/go/parquet/pqarrow/file_writer.go b/go/parquet/pqarrow/file_writer.go index 891b757f5eb51..539c544829e3b 100644 --- a/go/parquet/pqarrow/file_writer.go +++ b/go/parquet/pqarrow/file_writer.go @@ -246,7 +246,7 @@ func (fw *FileWriter) Write(rec arrow.Record) error { } } fw.colIdx = 0 - return nil + return fw.rgw.Close() } // WriteTable writes an arrow table to the underlying file using chunkSize to determine diff --git a/go/parquet/pqarrow/file_writer_test.go b/go/parquet/pqarrow/file_writer_test.go index 25ef3879e7811..5b807389a3eb1 100644 --- a/go/parquet/pqarrow/file_writer_test.go +++ b/go/parquet/pqarrow/file_writer_test.go @@ -55,7 +55,11 @@ func TestFileWriterRowGroupNumRows(t *testing.T) { numRows, err := writer.RowGroupNumRows() require.NoError(t, err) assert.Equal(t, 4, numRows) + + // Make sure that row group stats are up-to-date immediately after writing + bytesWritten := writer.RowGroupTotalBytesWritten() require.NoError(t, writer.Close()) + require.Equal(t, bytesWritten, writer.RowGroupTotalBytesWritten()) } func TestFileWriterNumRows(t *testing.T) { diff --git a/go/parquet/pqarrow/helpers.go b/go/parquet/pqarrow/helpers.go index 800cd84192005..237de4366c03e 100644 --- a/go/parquet/pqarrow/helpers.go +++ b/go/parquet/pqarrow/helpers.go @@ -38,6 +38,8 @@ func releaseArrayData(data []arrow.ArrayData) { func 
releaseColumns(columns []arrow.Column) { for _, col := range columns { - col.Release() + if col.Data() != nil { // data can be nil due to the way columns are constructed in ReadRowGroups + col.Release() + } } } diff --git a/java/.mvn/extensions.xml b/java/.mvn/extensions.xml index d6e80695e22d0..716e2f9e81c35 100644 --- a/java/.mvn/extensions.xml +++ b/java/.mvn/extensions.xml @@ -23,7 +23,7 @@ com.gradle develocity-maven-extension - 1.21.5 + 1.21.6 com.gradle diff --git a/java/README.md b/java/README.md index 25e35c10973e9..9f1b1c63c8f41 100644 --- a/java/README.md +++ b/java/README.md @@ -85,7 +85,7 @@ variable are set, the system property takes precedence. ## Java Properties - * For Java 9 or later, should set `-Dio.netty.tryReflectionSetAccessible=true`. + * `-Dio.netty.tryReflectionSetAccessible=true` should be set. This fixes `java.lang.UnsupportedOperationException: sun.misc.Unsafe or java.nio.DirectByteBuffer.(long, int) not available`. thrown by Netty. * To support duplicate fields in a `StructVector` enable `-Darrow.struct.conflict.policy=CONFLICT_APPEND`. Duplicate fields are ignored (`CONFLICT_REPLACE`) by default and overwritten. To support different policies for diff --git a/java/adapter/jdbc/pom.xml b/java/adapter/jdbc/pom.xml index 875334af4526d..124cc535c25bf 100644 --- a/java/adapter/jdbc/pom.xml +++ b/java/adapter/jdbc/pom.xml @@ -59,7 +59,7 @@ under the License. com.h2database h2 - 2.2.224 + 2.3.230 test @@ -82,7 +82,6 @@ under the License. com.fasterxml.jackson.core jackson-annotations - test @@ -93,24 +92,30 @@ under the License. - - - jdk11+ - - [11,] - - - - - org.apache.maven.plugins - maven-surefire-plugin - - --add-reads=org.apache.arrow.adapter.jdbc=com.fasterxml.jackson.dataformat.yaml --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED -Duser.timezone=UTC + + + + org.apache.maven.plugins + maven-dependency-plugin + + + analyze + verify + + + com.fasterxml.jackson.core:jackson-annotations + - - - - - - + + + + + org.apache.maven.plugins + maven-surefire-plugin + + --add-reads=org.apache.arrow.adapter.jdbc=com.fasterxml.jackson.dataformat.yaml --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED -Duser.timezone=UTC + + + + diff --git a/java/adapter/jdbc/src/main/java/module-info.java b/java/adapter/jdbc/src/main/java/module-info.java index 5b59ce768472a..04977222c1530 100644 --- a/java/adapter/jdbc/src/main/java/module-info.java +++ b/java/adapter/jdbc/src/main/java/module-info.java @@ -20,6 +20,7 @@ exports org.apache.arrow.adapter.jdbc; exports org.apache.arrow.adapter.jdbc.binder; + requires com.fasterxml.jackson.annotation; requires com.fasterxml.jackson.databind; requires java.sql; requires jdk.unsupported; diff --git a/java/bom/pom.xml b/java/bom/pom.xml index e51906cd77e35..fe3264102144b 100644 --- a/java/bom/pom.xml +++ b/java/bom/pom.xml @@ -23,7 +23,7 @@ under the License. org.apache apache - 31 + 33 org.apache.arrow @@ -79,18 +79,10 @@ under the License. - 1.8 - 1.8 - 3.12.0 - 3.2.5 - 0.16.1 - 3.7.1 - 3.12.1 - 3.6.1 - 3.2.4 - 3.2.2 - 3.6.3 - 3.5.0 + 11 + 11 + 11 + 11 diff --git a/java/compression/pom.xml b/java/compression/pom.xml index 79105dbfccda5..8774f7cabde94 100644 --- a/java/compression/pom.xml +++ b/java/compression/pom.xml @@ -55,7 +55,7 @@ under the License. com.github.luben zstd-jni - 1.5.6-3 + 1.5.6-4 diff --git a/java/dataset/pom.xml b/java/dataset/pom.xml index c5c7468ccee84..74071a6c305ad 100644 --- a/java/dataset/pom.xml +++ b/java/dataset/pom.xml @@ -165,6 +165,7 @@ under the License. 
test + @@ -179,6 +180,7 @@ under the License. maven-surefire-plugin + --add-reads=org.apache.arrow.dataset=com.fasterxml.jackson.databind --add-opens=java.base/java.nio=org.apache.arrow.dataset,org.apache.arrow.memory.core,ALL-UNNAMED false ${project.basedir}/../../testing/data @@ -202,24 +204,4 @@ under the License. - - - - jdk11+ - - [11,] - - - - - org.apache.maven.plugins - maven-surefire-plugin - - --add-reads=org.apache.arrow.dataset=com.fasterxml.jackson.databind --add-opens=java.base/java.nio=org.apache.arrow.dataset,org.apache.arrow.memory.core,ALL-UNNAMED - - - - - - diff --git a/java/flight/flight-core/pom.xml b/java/flight/flight-core/pom.xml index c00bba5e6c763..be3c191654a58 100644 --- a/java/flight/flight-core/pom.xml +++ b/java/flight/flight-core/pom.xml @@ -32,6 +32,8 @@ under the License. 1 + + --add-opens=org.apache.arrow.flight.core/org.apache.arrow.flight.perf.impl=protobuf.java --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED @@ -144,11 +146,13 @@ under the License. test + maven-surefire-plugin + --add-opens=org.apache.arrow.flight.core/org.apache.arrow.flight.perf.impl=protobuf.java --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED false ${project.basedir}/../../../testing/data @@ -198,27 +202,4 @@ under the License. - - - - jdk11+ - - [11,] - - - - - org.apache.maven.plugins - maven-surefire-plugin - - --add-opens=org.apache.arrow.flight.core/org.apache.arrow.flight.perf.impl=protobuf.java --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED - - ${project.basedir}/../../../testing/data - - - - - - - diff --git a/java/flight/flight-core/src/main/java/module-info.java b/java/flight/flight-core/src/main/java/module-info.java index f6bf5b73b0972..e668fe6149fb9 100644 --- a/java/flight/flight-core/src/main/java/module-info.java +++ b/java/flight/flight-core/src/main/java/module-info.java @@ -31,6 +31,7 @@ requires io.grpc.netty; requires io.grpc.protobuf; requires io.grpc.stub; + requires io.netty.buffer; requires io.netty.common; requires io.netty.handler; requires io.netty.transport; @@ -38,5 +39,6 @@ requires org.apache.arrow.memory.core; requires org.apache.arrow.vector; requires protobuf.java; + requires protobuf.java.util; requires org.slf4j; } diff --git a/java/flight/flight-integration-tests/pom.xml b/java/flight/flight-integration-tests/pom.xml index 97bce0c6ed5e3..a154062ba814d 100644 --- a/java/flight/flight-integration-tests/pom.xml +++ b/java/flight/flight-integration-tests/pom.xml @@ -69,19 +69,29 @@ under the License. - maven-assembly-plugin - - - jar-with-dependencies - - + maven-shade-plugin make-assembly - single + shade package + + false + true + jar-with-dependencies + + + + **/module-info.class + + + + + + + diff --git a/java/flight/flight-sql-jdbc-core/pom.xml b/java/flight/flight-sql-jdbc-core/pom.xml index 4833d30dbc33f..502d866fcc0bd 100644 --- a/java/flight/flight-sql-jdbc-core/pom.xml +++ b/java/flight/flight-sql-jdbc-core/pom.xml @@ -132,10 +132,8 @@ under the License. 
- com.google.code.findbugs - jsr305 - 3.0.2 - compile + org.checkerframework + checker-qual diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java index 845f5372d3f74..0e9c79a0907a5 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -29,7 +29,6 @@ import java.util.Map; import java.util.Optional; import java.util.Set; -import javax.annotation.Nullable; import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils; import org.apache.arrow.flight.CallOption; import org.apache.arrow.flight.CallStatus; @@ -61,6 +60,7 @@ import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.calcite.avatica.Meta.StatementType; +import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/java/flight/flight-sql-jdbc-driver/pom.xml b/java/flight/flight-sql-jdbc-driver/pom.xml index 524b9cd4f8aae..148319e5d9d64 100644 --- a/java/flight/flight-sql-jdbc-driver/pom.xml +++ b/java/flight/flight-sql-jdbc-driver/pom.xml @@ -59,6 +59,7 @@ under the License. maven-failsafe-plugin + default-it integration-test verify @@ -161,6 +162,7 @@ under the License. META-INF/native/libio_grpc_netty* META-INF/native/io_grpc_netty_shaded* **/*.proto + **/module-info.class diff --git a/java/flight/flight-sql/pom.xml b/java/flight/flight-sql/pom.xml index 9c8c5df07fb78..c9c589d202ac6 100644 --- a/java/flight/flight-sql/pom.xml +++ b/java/flight/flight-sql/pom.xml @@ -32,6 +32,8 @@ under the License. 1 + + --add-reads=org.apache.arrow.flight.sql=org.slf4j --add-reads=org.apache.arrow.flight.core=ALL-UNNAMED --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED @@ -87,7 +89,7 @@ under the License. org.apache.derby derby - 10.14.2.0 + 10.15.2.0 test @@ -120,24 +122,4 @@ under the License. 
true - - - - jdk11+ - - [11,] - - - - - org.apache.maven.plugins - maven-surefire-plugin - - --add-reads=org.apache.arrow.flight.sql=org.slf4j --add-reads=org.apache.arrow.flight.core=ALL-UNNAMED --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED - - - - - - diff --git a/java/flight/flight-sql/src/main/java/module-info.java b/java/flight/flight-sql/src/main/java/module-info.java index 5514d5b870afd..cb3835117daf6 100644 --- a/java/flight/flight-sql/src/main/java/module-info.java +++ b/java/flight/flight-sql/src/main/java/module-info.java @@ -25,5 +25,6 @@ requires org.apache.arrow.flight.core; requires org.apache.arrow.memory.core; requires org.apache.arrow.vector; + requires org.apache.commons.cli; requires protobuf.java; } diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java index 8387834947283..2eb74adc5bc0e 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java @@ -123,10 +123,10 @@ protected static void setUpExpectedResultsMap() { Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE), "Apache Derby"); GET_SQL_INFO_EXPECTED_RESULTS_MAP.put( Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION_VALUE), - "10.14.2.0 - (1828579)"); + "10.15.2.0 - (1873585)"); GET_SQL_INFO_EXPECTED_RESULTS_MAP.put( Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_ARROW_VERSION_VALUE), - "10.14.2.0 - (1828579)"); + "10.15.2.0 - (1873585)"); GET_SQL_INFO_EXPECTED_RESULTS_MAP.put( Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE), "false"); GET_SQL_INFO_EXPECTED_RESULTS_MAP.put( diff --git a/java/flight/pom.xml b/java/flight/pom.xml index 851f44d7bf19e..55511eba82b3a 100644 --- a/java/flight/pom.xml +++ b/java/flight/pom.xml @@ -37,17 +37,4 @@ under the License. 
flight-sql-jdbc-driver flight-integration-tests - - - - pin-mockito-jdk8 - - 1.8 - - - 4.11.0 - 5.2.0 - - - diff --git a/java/maven/module-info-compiler-maven-plugin/pom.xml b/java/maven/module-info-compiler-maven-plugin/pom.xml deleted file mode 100644 index 77184d35b5ac7..0000000000000 --- a/java/maven/module-info-compiler-maven-plugin/pom.xml +++ /dev/null @@ -1,124 +0,0 @@ - - - - 4.0.0 - - org.apache.arrow.maven.plugins - arrow-maven-plugins - 18.0.0-SNAPSHOT - - module-info-compiler-maven-plugin - maven-plugin - - Module Info Compiler Maven Plugin - - https://arrow.apache.org - - - ${maven.version} - - - - 3.8.7 - - - - - org.glavo - module-info-compiler - 2.0 - - - org.apache.maven - maven-plugin-api - ${maven.version} - provided - - - org.apache.maven - maven-core - ${maven.version} - provided - - - org.apache.maven - maven-artifact - ${maven.version} - provided - - - org.apache.maven - maven-model - ${maven.version} - provided - - - org.apache.maven.plugin-tools - maven-plugin-annotations - ${maven.plugin.tools.version} - provided - - - - - - - - com.gradle - develocity-maven-extension - - - - - - arrow-git.properties - - - - - - - - - - - org.apache.maven.plugins - maven-plugin-plugin - - true - - - - mojo-descriptor - - descriptor - - - - help-goal - - helpmojo - - - - - - - diff --git a/java/maven/module-info-compiler-maven-plugin/src/main/java/org/apache/arrow/maven/plugins/BaseModuleInfoCompilerPlugin.java b/java/maven/module-info-compiler-maven-plugin/src/main/java/org/apache/arrow/maven/plugins/BaseModuleInfoCompilerPlugin.java deleted file mode 100644 index 4fc8fc46e6bcc..0000000000000 --- a/java/maven/module-info-compiler-maven-plugin/src/main/java/org/apache/arrow/maven/plugins/BaseModuleInfoCompilerPlugin.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.maven.plugins; - -import java.io.File; -import java.io.IOException; -import java.io.InputStreamReader; -import java.io.OutputStream; -import java.io.Reader; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.List; -import java.util.Optional; -import org.apache.maven.plugin.AbstractMojo; -import org.apache.maven.plugin.MojoExecutionException; -import org.glavo.mic.ModuleInfoCompiler; - -/** Compiles the first module-info.java file in the project purely syntactically. 
*/ -public abstract class BaseModuleInfoCompilerPlugin extends AbstractMojo { - protected abstract List getSourceRoots(); - - protected abstract boolean skip(); - - protected abstract String getOutputDirectory(); - - @Override - public void execute() throws MojoExecutionException { - if (skip()) { - getLog().info("Skipping module-info-compiler-maven-plugin"); - return; - } - - Optional moduleInfoFile = findFirstModuleInfo(getSourceRoots()); - if (moduleInfoFile.isPresent()) { - // The compiled module-info.class file goes into target/classes/module-info/main - Path outputDir = Paths.get(getOutputDirectory()); - - outputDir.toFile().mkdirs(); - Path targetPath = outputDir.resolve("module-info.class"); - - // Invoke the compiler, - ModuleInfoCompiler compiler = new ModuleInfoCompiler(); - try (Reader reader = - new InputStreamReader( - Files.newInputStream(moduleInfoFile.get().toPath()), StandardCharsets.UTF_8); - OutputStream output = Files.newOutputStream(targetPath)) { - compiler.compile(reader, output); - getLog().info("Successfully wrote module-info.class file."); - } catch (IOException ex) { - throw new MojoExecutionException("Error compiling module-info.java", ex); - } - } else { - getLog().info("No module-info.java file found. module-info.class file was not generated."); - } - } - - /** Finds the first module-info.java file in the set of source directories. */ - private Optional findFirstModuleInfo(List sourceDirectories) { - if (sourceDirectories == null) { - return Optional.empty(); - } - - return sourceDirectories.stream() - .map(Paths::get) - .map( - sourcePath -> - sourcePath.toFile().listFiles(file -> file.getName().equals("module-info.java"))) - .filter(matchingFiles -> matchingFiles != null && matchingFiles.length != 0) - .map(matchingFiles -> matchingFiles[0]) - .findAny(); - } -} diff --git a/java/maven/module-info-compiler-maven-plugin/src/main/java/org/apache/arrow/maven/plugins/ModuleInfoCompilerPlugin.java b/java/maven/module-info-compiler-maven-plugin/src/main/java/org/apache/arrow/maven/plugins/ModuleInfoCompilerPlugin.java deleted file mode 100644 index e66a475dbf8be..0000000000000 --- a/java/maven/module-info-compiler-maven-plugin/src/main/java/org/apache/arrow/maven/plugins/ModuleInfoCompilerPlugin.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.maven.plugins; - -import java.util.ArrayList; -import java.util.List; -import org.apache.maven.plugins.annotations.LifecyclePhase; -import org.apache.maven.plugins.annotations.Mojo; -import org.apache.maven.plugins.annotations.Parameter; -import org.apache.maven.project.MavenProject; - -/** A maven plugin for compiler module-info files in main code with JDK8. 
*/ -@Mojo(name = "compile", defaultPhase = LifecyclePhase.COMPILE) -public class ModuleInfoCompilerPlugin extends BaseModuleInfoCompilerPlugin { - - @Parameter( - defaultValue = "${project.compileSourceRoots}", - property = "compileSourceRoots", - required = true) - private final List compileSourceRoots = new ArrayList<>(); - - @Parameter(defaultValue = "false", property = "skip", required = false) - private boolean skip = false; - - @Parameter(defaultValue = "${project}", readonly = true, required = true) - private MavenProject project; - - @Override - protected List getSourceRoots() { - return compileSourceRoots; - } - - @Override - protected boolean skip() { - return skip; - } - - @Override - protected String getOutputDirectory() { - return project.getBuild().getOutputDirectory(); - } -} diff --git a/java/maven/module-info-compiler-maven-plugin/src/main/java/org/apache/arrow/maven/plugins/ModuleInfoTestCompilerPlugin.java b/java/maven/module-info-compiler-maven-plugin/src/main/java/org/apache/arrow/maven/plugins/ModuleInfoTestCompilerPlugin.java deleted file mode 100644 index f18ac9faac735..0000000000000 --- a/java/maven/module-info-compiler-maven-plugin/src/main/java/org/apache/arrow/maven/plugins/ModuleInfoTestCompilerPlugin.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.maven.plugins; - -import java.util.List; -import org.apache.maven.plugins.annotations.LifecyclePhase; -import org.apache.maven.plugins.annotations.Mojo; -import org.apache.maven.plugins.annotations.Parameter; -import org.apache.maven.project.MavenProject; - -/** A maven plugin for compiler module-info files in unit tests with JDK8. 
*/ -@Mojo(name = "testCompile", defaultPhase = LifecyclePhase.TEST_COMPILE) -public class ModuleInfoTestCompilerPlugin extends BaseModuleInfoCompilerPlugin { - - @Parameter(defaultValue = "false", property = "skip", required = false) - private boolean skip = false; - - @Parameter(defaultValue = "${project}", readonly = true, required = true) - private MavenProject project; - - @Override - protected List getSourceRoots() { - return project.getTestCompileSourceRoots(); - } - - @Override - protected boolean skip() { - return skip; - } - - @Override - protected String getOutputDirectory() { - return project.getBuild().getTestOutputDirectory(); - } -} diff --git a/java/maven/pom.xml b/java/maven/pom.xml deleted file mode 100644 index d342b629358dd..0000000000000 --- a/java/maven/pom.xml +++ /dev/null @@ -1,419 +0,0 @@ - - - - 4.0.0 - - - org.apache - apache - 31 - - - - org.apache.arrow.maven.plugins - arrow-maven-plugins - 18.0.0-SNAPSHOT - pom - - Arrow Maven Plugins - https://arrow.apache.org/ - - - - Developer List - dev-subscribe@arrow.apache.org - dev-unsubscribe@arrow.apache.org - dev@arrow.apache.org - https://lists.apache.org/list.html?dev@arrow.apache.org - - - Commits List - commits-subscribe@arrow.apache.org - commits-unsubscribe@arrow.apache.org - commits@arrow.apache.org - https://lists.apache.org/list.html?commits@arrow.apache.org - - - Issues List - issues-subscribe@arrow.apache.org - issues-unsubscribe@arrow.apache.org - https://lists.apache.org/list.html?issues@arrow.apache.org - - - GitHub List - github-subscribe@arrow.apache.org - github-unsubscribe@arrow.apache.org - https://lists.apache.org/list.html?github@arrow.apache.org - - - - - module-info-compiler-maven-plugin - - - - scm:git:https://github.com/apache/arrow.git - scm:git:https://github.com/apache/arrow.git - main - https://github.com/apache/arrow/tree/${project.scm.tag} - - - - GitHub - https://github.com/apache/arrow/issues - - - - true - - 1.8 - 1.8 - 3.13.1 - 3.2.5 - 0.16.1 - 3.7.1 - 3.12.1 - 3.6.1 - 3.2.4 - 3.2.2 - 3.6.3 - 3.5.0 - - - - - - - com.diffplug.spotless - spotless-maven-plugin - 2.30.0 - - - pl.project13.maven - git-commit-id-plugin - 4.9.10 - - - org.cyclonedx - cyclonedx-maven-plugin - 2.8.0 - - - org.codehaus.mojo - versions-maven-plugin - 2.17.0 - - - - - - org.apache.rat - apache-rat-plugin - - false - - **/dependency-reduced-pom.xml - **/*.log - **/*.css - **/*.js - **/*.md - **/*.eps - **/*.json - **/*.seq - **/*.parquet - **/*.sql - **/arrow-git.properties - **/*.csv - **/*.csvh - **/*.csvh-test - **/*.tsv - **/*.txt - **/*.ssv - **/arrow-*.conf - **/.buildpath - **/*.proto - **/*.fmpp - **/target/** - **/*.tdd - **/*.project - **/TAGS - **/*.checkstyle - **/.classpath - **/.factorypath - **/.settings/** - .*/** - **/*.patch - **/*.pb.cc - **/*.pb.h - **/*.linux - **/client/build/** - **/*.tbl - **/*.iml - **/flight.properties - **/*.idea/** - - - - - rat-checks - - check - - validate - - - - - - org.apache.maven.plugins - maven-jar-plugin - - - **/logging.properties - **/logback-test.xml - **/logback.out.xml - **/logback.xml - - - - org.apache.arrow - ${username} - https://arrow.apache.org/ - - - - - - - test-jar - - - true - - - - - - - org.apache.maven.plugins - maven-compiler-plugin - - 2048m - true - - - - maven-enforcer-plugin - - - avoid_bad_dependencies - - enforce - - verify - - - - - commons-logging - javax.servlet:servlet-api - org.mortbay.jetty:servlet-api - org.mortbay.jetty:servlet-api-2.5 - log4j:log4j - - - - - - - - - pl.project13.maven - git-commit-id-plugin - - dd.MM.yyyy 
'@' HH:mm:ss z - false - false - true - false - - false - false - 7 - -dirty - true - - - - - for-jars - - revision - - true - - target/classes/arrow-git.properties - - - - for-source-tarball - - revision - - false - - ./arrow-git.properties - - - - - - - org.apache.maven.plugins - maven-checkstyle-plugin - - ../dev/checkstyle/checkstyle.xml - ../dev/license/asf-java.license - ../dev/checkstyle/suppressions.xml - true - UTF-8 - true - ${checkstyle.failOnViolation} - ${checkstyle.failOnViolation} - warning - xml - ${project.build.directory}/test/checkstyle-errors.xml - false - - - - com.puppycrawl.tools - checkstyle - 8.29 - - - org.slf4j - jcl-over-slf4j - 2.0.13 - - - - - validate - - check - - validate - - - - - org.cyclonedx - cyclonedx-maven-plugin - - - - makeBom - - package - - - - - org.apache.maven.plugins - maven-project-info-reports-plugin - - - org.apache.maven.plugins - maven-site-plugin - - - com.diffplug.spotless - spotless-maven-plugin - - - - ${maven.multiModuleProjectDirectory}/dev/license/asf-xml.license - (<configuration|<project) - - - - - - 1.7 - - - - ${maven.multiModuleProjectDirectory}/dev/license/asf-java.license - package - - - - - - spotless-check - - check - - - - - - - - - - - org.apache.maven.plugins - maven-project-info-reports-plugin - - - org.apache.maven.plugins - maven-site-plugin - - - - - - - apache-release - - - - org.apache.maven.plugins - maven-assembly-plugin - - - source-release-assembly - - - true - - - - - - - - - diff --git a/java/memory/memory-core/pom.xml b/java/memory/memory-core/pom.xml index 95ef16aaa1cfe..9b24cee032023 100644 --- a/java/memory/memory-core/pom.xml +++ b/java/memory/memory-core/pom.xml @@ -30,11 +30,12 @@ under the License. Arrow Memory - Core Core off-heap memory management libraries for Arrow ValueVectors. + + + --add-reads=org.apache.arrow.memory.core=ch.qos.logback.classic --add-opens=java.base/java.lang.reflect=org.apache.arrow.memory.core --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED + + - - com.google.code.findbugs - jsr305 - org.slf4j slf4j-api @@ -47,111 +48,67 @@ under the License. 
org.checkerframework checker-qual + + com.google.errorprone + error_prone_annotations + + + org.apache.maven.plugins + maven-compiler-plugin + + + -Xmaxerrs + + 10000 + -Xmaxwarns + 10000 + -AskipDefs=.*Test + + -AatfDoNotCache + + + + + org.checkerframework + checker + ${checker.framework.version} + + + + org.apache.maven.plugins maven-surefire-plugin + --add-reads=org.apache.arrow.memory.core=ch.qos.logback.classic --add-opens=java.base/java.lang.reflect=org.apache.arrow.memory.core --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED **/TestOpens.java + + + + opens-tests + + test + + test + + + + + + **/TestOpens.java + + + + - - - - jdk11+ - - [11,] - - - - - org.apache.maven.plugins - maven-surefire-plugin - - --add-reads=org.apache.arrow.memory.core=ch.qos.logback.classic --add-opens=java.base/java.lang.reflect=org.apache.arrow.memory.core --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED - - - **/TestOpens.java - - - - - - - - opens-tests - - - [16,] - - - - - org.apache.maven.plugins - maven-surefire-plugin - - - opens-tests - - test - - test - - - - - - **/TestOpens.java - - - - - - - - - - - checkerframework-jdk11+ - - [11,] - - - - - org.apache.maven.plugins - maven-compiler-plugin - - - -Xmaxerrs - - 10000 - -Xmaxwarns - 10000 - -AskipDefs=.*Test - - -AatfDoNotCache - - - - - org.checkerframework - checker - ${checker.framework.version} - - - - - - - - diff --git a/java/memory/memory-core/src/main/java/module-info.java b/java/memory/memory-core/src/main/java/module-info.java index 52fcb52d014a5..0a607bdf2f43a 100644 --- a/java/memory/memory-core/src/main/java/module-info.java +++ b/java/memory/memory-core/src/main/java/module-info.java @@ -22,7 +22,10 @@ exports org.apache.arrow.memory.util.hash; exports org.apache.arrow.util; + requires java.compiler; requires transitive jdk.unsupported; - requires jsr305; + requires static org.checkerframework.checker.qual; + requires static org.immutables.value.annotations; + requires static com.google.errorprone.annotations; requires org.slf4j; } diff --git a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/Accountant.java b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/Accountant.java index 5a31f4cd1914a..5d052c2cdeeec 100644 --- a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/Accountant.java +++ b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/Accountant.java @@ -17,7 +17,6 @@ package org.apache.arrow.memory; import java.util.concurrent.atomic.AtomicLong; -import javax.annotation.concurrent.ThreadSafe; import org.apache.arrow.util.Preconditions; import org.checkerframework.checker.nullness.qual.Nullable; @@ -25,7 +24,6 @@ * Provides a concurrent way to manage account for memory usage without locking. Used as basis for * Allocators. All operations are threadsafe (except for close). */ -@ThreadSafe class Accountant implements AutoCloseable { /** The parent allocator. 
*/ diff --git a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BaseAllocator.java b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BaseAllocator.java index 3f4426d2c36e5..dd6375e910b92 100644 --- a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BaseAllocator.java +++ b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BaseAllocator.java @@ -16,6 +16,8 @@ */ package org.apache.arrow.memory; +import com.google.errorprone.annotations.FormatMethod; +import com.google.errorprone.annotations.FormatString; import java.util.Collection; import java.util.Collections; import java.util.HashSet; @@ -539,9 +541,8 @@ public String toVerboseString() { return sb.toString(); } - /* Remove @SuppressWarnings after fixing https://github.com/apache/arrow/issues/41951 */ - @SuppressWarnings("FormatStringAnnotation") - private void hist(String noteFormat, Object... args) { + @FormatMethod + private void hist(@FormatString String noteFormat, Object... args) { if (historicalLog != null) { historicalLog.recordEvent(noteFormat, args); } diff --git a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/util/HistoricalLog.java b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/util/HistoricalLog.java index 659ddde28df9b..5b1bdd8b7244c 100644 --- a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/util/HistoricalLog.java +++ b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/util/HistoricalLog.java @@ -16,6 +16,8 @@ */ package org.apache.arrow.memory.util; +import com.google.errorprone.annotations.FormatMethod; +import com.google.errorprone.annotations.FormatString; import java.util.ArrayDeque; import java.util.Arrays; import java.util.Deque; @@ -42,9 +44,8 @@ public class HistoricalLog { * object instance is best. * @param args for the format string, or nothing if none are required */ - @SuppressWarnings("FormatStringAnnotation") - /* Remove @SuppressWarnings after fixing https://github.com/apache/arrow/issues/41951 */ - public HistoricalLog(final String idStringFormat, Object... args) { + @FormatMethod + public HistoricalLog(@FormatString final String idStringFormat, Object... args) { this(Integer.MAX_VALUE, idStringFormat, args); } @@ -65,9 +66,8 @@ public HistoricalLog(final String idStringFormat, Object... args) { * object instance is best. * @param args for the format string, or nothing if none are required */ - @SuppressWarnings("AnnotateFormatMethod") - public HistoricalLog(final int limit, final String idStringFormat, Object... args) { - // Remove @SuppressWarnings after fixing https://github.com/apache/arrow/issues/41951 + @FormatMethod + public HistoricalLog(final int limit, @FormatString final String idStringFormat, Object... args) { this.limit = limit; this.idString = String.format(idStringFormat, args); this.firstEvent = null; @@ -80,9 +80,8 @@ public HistoricalLog(final int limit, final String idStringFormat, Object... arg * @param noteFormat {@link String#format} format string that describes the event * @param args for the format string, or nothing if none are required */ - @SuppressWarnings("AnnotateFormatMethod") - public synchronized void recordEvent(final String noteFormat, Object... args) { - // Remove @SuppressWarnings after fixing https://github.com/apache/arrow/issues/41951 + @FormatMethod + public synchronized void recordEvent(@FormatString final String noteFormat, Object... 
args) { final String note = String.format(noteFormat, args); final Event event = new Event(note); if (firstEvent == null) { diff --git a/java/memory/memory-core/src/test/java/org/apache/arrow/memory/TestOpens.java b/java/memory/memory-core/src/test/java/org/apache/arrow/memory/TestOpens.java index 756aa2919789b..b5e0a71e7ee0e 100644 --- a/java/memory/memory-core/src/test/java/org/apache/arrow/memory/TestOpens.java +++ b/java/memory/memory-core/src/test/java/org/apache/arrow/memory/TestOpens.java @@ -18,12 +18,15 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.condition.JRE.JAVA_16; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledForJreRange; public class TestOpens { /** Instantiating the RootAllocator should poke MemoryUtil and fail. */ @Test + @EnabledForJreRange(min = JAVA_16) public void testMemoryUtilFailsLoudly() { // This test is configured by Maven to run WITHOUT add-opens. So this should fail on JDK16+ // (where JEP396 means that add-opens is required to access JDK internals). @@ -44,6 +47,6 @@ public void testMemoryUtilFailsLoudly() { break; } } - assertTrue(found, "Expected exception as not thrown"); + assertTrue(found, "Expected exception was not thrown"); } } diff --git a/java/memory/memory-netty/pom.xml b/java/memory/memory-netty/pom.xml index e29ca3a4d053c..f2d4d2d0fe3bc 100644 --- a/java/memory/memory-netty/pom.xml +++ b/java/memory/memory-netty/pom.xml @@ -78,6 +78,7 @@ under the License. maven-failsafe-plugin + default-it integration-test verify diff --git a/java/performance/pom.xml b/java/performance/pom.xml index 0dfc26b469ce2..f6d3a26b4f352 100644 --- a/java/performance/pom.xml +++ b/java/performance/pom.xml @@ -75,7 +75,7 @@ under the License. com.h2database h2 - 2.2.224 + 2.3.230 runtime diff --git a/java/pom.xml b/java/pom.xml index 4ce0c1981d295..997257c71b6e9 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -23,7 +23,7 @@ under the License. org.apache apache - 31 + 33 org.apache.arrow @@ -65,7 +65,6 @@ under the License. - maven bom format memory @@ -97,8 +96,8 @@ under the License. 5.10.3 2.0.13 33.2.1-jre - 4.1.110.Final - 1.63.0 + 4.1.112.Final + 1.65.0 3.25.1 2.17.2 3.4.0 @@ -107,30 +106,24 @@ under the License. 2 true - 9+181-r4173-1 - 2.28.0 + 2.29.2 5.11.0 5.2.0 3.45.0 none -Xdoclint:none + + --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED - 1.8 - 1.8 - 3.12.0 - 3.2.5 - 0.16.1 - 3.7.1 - 3.12.1 - 3.6.1 - 3.2.4 + 11 + 11 + 11 + 11 3.2.2 - 3.6.3 - 3.5.0 @@ -155,9 +148,10 @@ under the License. ${dep.fbs.version} - com.google.code.findbugs - jsr305 - 3.0.2 + com.google.errorprone + error_prone_annotations + ${error_prone_core.version} + provided org.slf4j @@ -177,7 +171,7 @@ under the License. org.assertj assertj-core - 3.26.0 + 3.26.3 test @@ -273,13 +267,13 @@ under the License. org.mockito mockito-junit-jupiter - 2.25.1 + 5.12.0 test ch.qos.logback logback-classic - 1.3.14 + 1.4.14 test @@ -298,8 +292,6 @@ under the License. maven-compiler-plugin true - **/module-info.java - **/module-info.java false @@ -313,6 +305,7 @@ under the License. maven-surefire-plugin + ${surefire.add-opens.argLine} true true ${forkCount} @@ -325,11 +318,13 @@ under the License. which in turn can cause OOM. --> 1048576 + false maven-failsafe-plugin + ${surefire.add-opens.argLine} ${project.build.directory} true @@ -444,13 +439,9 @@ under the License. 
**/module-info.java + arrow-memory-netty-buffer-patch,arrow-memory-netty,flight-sql-jdbc-core,flight-integration-tests,arrow-performance - - org.apache.arrow.maven.plugins - module-info-compiler-maven-plugin - ${project.version} - com.gradle develocity-maven-extension @@ -491,6 +482,7 @@ under the License. com.google.protobuf:protoc:${dep.protobuf-bom.version}:exe:${os.detected.classifier} grpc-java io.grpc:protoc-gen-grpc-java:${dep.grpc-bom.version}:exe:${os.detected.classifier} + @generated=omit @@ -779,24 +771,6 @@ under the License. - - org.apache.arrow.maven.plugins - module-info-compiler-maven-plugin - - - default-compile - - compile - - - - default-testCompile - - testCompile - - - - org.apache.maven.plugins maven-project-info-reports-plugin @@ -856,6 +830,7 @@ under the License. **/module-info.java + arrow-memory-netty-buffer-patch,arrow-memory-netty,flight-sql-jdbc-core,flight-integration-tests,arrow-performance @@ -917,56 +892,13 @@ under the License. - error-prone-jdk8 + error-prone - 1.8 - - !m2e.version - - - - - - org.apache.maven.plugins - maven-compiler-plugin - - - -XDcompilePolicy=simple - -Xplugin:ErrorProne - -J-Xbootclasspath/p:${settings.localRepository}/com/google/errorprone/javac/${errorprone.javac.version}/javac-${errorprone.javac.version}.jar - - - - com.google.errorprone - error_prone_core - - 2.10.0 - - - - - - - - - - error-prone-jdk11+ - - [11,] !m2e.version @@ -1003,30 +935,6 @@ under the License. - - jdk11+ - - [11,] - - - - - org.apache.maven.plugins - maven-surefire-plugin - - --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED - - - - org.apache.maven.plugins - maven-failsafe-plugin - - --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED - - - - - code-coverage @@ -1359,5 +1267,59 @@ under the License. + + + + cross-jdk-testing + + + arrow.test.jdk-version + + + + + + maven-enforcer-plugin + + + check-jdk-version-property + + enforce + + validate + + + + arrow.test.jdk-version + "JDK version used for test must be specified." + ^\d{2,} + "JDK version used for test must be 11, 17, 21, ..." + + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + ${arrow.test.jdk-version} + + + + + org.apache.maven.plugins + maven-failsafe-plugin + + + ${arrow.test.jdk-version} + + + + + + diff --git a/java/tools/pom.xml b/java/tools/pom.xml index b69d24786cb14..94566495dff19 100644 --- a/java/tools/pom.xml +++ b/java/tools/pom.xml @@ -59,7 +59,7 @@ under the License. ch.qos.logback logback-classic - 1.3.14 + 1.4.14 test
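
The `--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED` argLine that now sits directly in the surefire/failsafe configuration (rather than behind the deleted `jdk11+` profiles) is what keeps Arrow's memory layer able to reflect into `java.nio` internals on JDK 16+, where JEP 396 turns illegal reflective access into a hard error. A minimal sketch of that class of access — `AddOpensDemo` and the `address` field lookup are illustrative only, not Arrow's actual `MemoryUtil` code:

    import java.lang.reflect.Field;
    import java.nio.Buffer;
    import java.nio.ByteBuffer;

    public class AddOpensDemo {
      public static void main(String[] args) throws Exception {
        ByteBuffer direct = ByteBuffer.allocateDirect(16);
        // "address" is a private field of java.nio.Buffer. On JDK 16+ the
        // setAccessible call throws InaccessibleObjectException unless the
        // JVM was started with --add-opens=java.base/java.nio=ALL-UNNAMED.
        Field address = Buffer.class.getDeclaredField("address");
        address.setAccessible(true);
        System.out.println("direct buffer address: " + address.getLong(direct));
      }
    }

Run with the add-opens flag and the lookup succeeds; without it, JDK 16+ throws, which is exactly the behavior the `TestOpens` change above relies on.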
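
With the JDK 8 baseline gone, `javac` can compile `module-info.java` directly, which appears to be why the custom `module-info-compiler-maven-plugin` (a purely syntactic compiler that worked around JDK 8's lack of module support) is deleted outright rather than replaced. The memory-core descriptor also moves its annotation-only dependencies to `requires static`, making them mandatory at compile time but optional at run time. A sketch under an assumed module name — `com.example.memory` is hypothetical, but the `requires` targets are the real module names from the diff:

    // Hypothetical descriptor illustrating "requires static":
    module com.example.memory {
      exports com.example.memory;

      requires org.slf4j;                                 // needed at run time
      requires static org.checkerframework.checker.qual;  // annotations only
      requires static com.google.errorprone.annotations;  // annotations only
    }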
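
Replacing `com.google.code.findbugs:jsr305` with `org.checkerframework:checker-qual` is more than a dependency swap: jsr305 puts its annotations in `javax.annotation`, a package that collides with other artifacts under the module system, whereas checker-qual's `@Nullable` is a `TYPE_USE` annotation in a cleanly named module. A minimal before/after sketch (`Lookup` is illustrative):

    // Before: import javax.annotation.Nullable;   (jsr305)
    // After:
    import org.checkerframework.checker.nullness.qual.Nullable;

    public class Lookup {
      // As a TYPE_USE annotation, @Nullable attaches to the return type
      // itself and is understood by the Checker Framework's nullness
      // checker, which the memory-core compiler configuration now runs.
      public @Nullable String find(String key) {
        return key.isEmpty() ? null : key;
      }
    }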
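
The `@SuppressWarnings` removals in `BaseAllocator` and `HistoricalLog` work because Error Prone's `@FormatMethod`/`@FormatString` (from the now `provided`-scoped `error_prone_annotations` artifact) tell the checker which parameter carries the printf-style format string, so argument mismatches surface as compile-time errors instead of being suppressed. A minimal sketch:

    import com.google.errorprone.annotations.FormatMethod;
    import com.google.errorprone.annotations.FormatString;

    public class Logs {
      @FormatMethod
      static void log(@FormatString String fmt, Object... args) {
        System.out.println(String.format(fmt, args));
      }

      public static void main(String[] args) {
        log("allocated %d bytes", 1024);       // OK
        // log("allocated %d bytes", "many");  // rejected when built with Error Prone
      }
    }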
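
`TestOpens` used to depend on Maven profile activation (`[16,]`) to run only on JDKs that enforce strong encapsulation; the new `@EnabledForJreRange(min = JAVA_16)` moves that gate into JUnit 5 itself, so on older JREs the test is reported as skipped rather than run against the wrong expectations. A minimal sketch (`StrongEncapsulationTest` is illustrative):

    import static org.junit.jupiter.api.condition.JRE.JAVA_16;

    import org.junit.jupiter.api.Test;
    import org.junit.jupiter.api.condition.EnabledForJreRange;

    class StrongEncapsulationTest {
      @Test
      @EnabledForJreRange(min = JAVA_16) // skipped, not failed, on JDK < 16
      void reflectiveAccessFailsWithoutAddOpens() {
        // On JDK 16+ (JEP 396), setAccessible on java.base internals throws
        // InaccessibleObjectException unless --add-opens was supplied.
      }
    }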