Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add UUID canonical extension type
Browse files Browse the repository at this point in the history
rok committed Oct 27, 2023
1 parent 547b240 commit d8bfaf8
Showing 20 changed files with 244 additions and 123 deletions.
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -548,6 +548,7 @@ if(ARROW_JSON)
list(APPEND
ARROW_SRCS
extension/fixed_shape_tensor.cc
extension/uuid_array.cc
json/options.cc
json/chunked_builder.cc
json/chunker.cc
2 changes: 2 additions & 0 deletions cpp/src/arrow/acero/hash_join_node_test.cc
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@
#include "arrow/api.h"
#include "arrow/compute/kernels/row_encoder_internal.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/extension/uuid_array.h"
#include "arrow/testing/extension_type.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/matchers.h"
@@ -47,6 +48,7 @@ using compute::SortIndices;
using compute::SortKey;
using compute::Take;
using compute::internal::RowEncoder;
using extension::uuid;

namespace acero {

3 changes: 3 additions & 0 deletions cpp/src/arrow/acero/util_test.cc
Original file line number Diff line number Diff line change
@@ -21,9 +21,12 @@
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/matchers.h"

#include "arrow/extension/uuid_array.h"

using testing::Eq;

namespace arrow {
using extension::uuid;
namespace acero {

const char* kLeftSuffix = ".left";
23 changes: 13 additions & 10 deletions cpp/src/arrow/c/bridge_test.cc
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@
#include "arrow/c/bridge.h"
#include "arrow/c/helpers.h"
#include "arrow/c/util_internal.h"
#include "arrow/extension/uuid_array.h"
#include "arrow/ipc/json_simple.h"
#include "arrow/memory_pool.h"
#include "arrow/testing/extension_type.h"
@@ -50,6 +51,7 @@

namespace arrow {

using extension::uuid;
using internal::ArrayExportGuard;
using internal::ArrayExportTraits;
using internal::ArrayStreamExportGuard;
@@ -2026,15 +2028,16 @@ TEST_F(TestSchemaImport, Dictionary) {
}

TEST_F(TestSchemaImport, UnregisteredExtension) {
FillPrimitive("w:16");
c_struct_.metadata = kEncodedUuidMetadata.c_str();
auto expected = fixed_size_binary(16);
FillPrimitive(AddChild(), "u");
FillPrimitive("c");
FillDictionary();
c_struct_.metadata = kEncodedDictExtensionMetadata.c_str();
auto expected = dictionary(int8(), utf8());
CheckImport(expected);
}

TEST_F(TestSchemaImport, RegisteredExtension) {
{
ExtensionTypeGuard guard(uuid());
FillPrimitive("w:16");
c_struct_.metadata = kEncodedUuidMetadata.c_str();
auto expected = uuid();
@@ -2160,8 +2163,6 @@ TEST_F(TestSchemaImport, DictionaryError) {
}

TEST_F(TestSchemaImport, ExtensionError) {
ExtensionTypeGuard guard(uuid());

// Storage type doesn't match
FillPrimitive("w:15");
c_struct_.metadata = kEncodedUuidMetadata.c_str();
@@ -3411,7 +3412,10 @@ TEST_F(TestSchemaRoundtrip, Dictionary) {
}

TEST_F(TestSchemaRoundtrip, UnregisteredExtension) {
TestWithTypeFactory(uuid, []() { return fixed_size_binary(16); });
TestWithTypeFactory(complex128, []() {
return struct_({::arrow::field("real", float64(), /*nullable=*/false),
::arrow::field("imag", float64(), /*nullable=*/false)});
});
TestWithTypeFactory(dict_extension_type, []() { return dictionary(int8(), utf8()); });

// Inside nested type
@@ -3420,7 +3424,7 @@ TEST_F(TestSchemaRoundtrip, UnregisteredExtension) {
}

TEST_F(TestSchemaRoundtrip, RegisteredExtension) {
ExtensionTypeGuard guard({uuid(), dict_extension_type(), complex128()});
ExtensionTypeGuard guard({dict_extension_type(), complex128()});
TestWithTypeFactory(uuid);
TestWithTypeFactory(dict_extension_type);
TestWithTypeFactory(complex128);
@@ -3744,7 +3748,7 @@ TEST_F(TestArrayRoundtrip, Dictionary) {
}

TEST_F(TestArrayRoundtrip, RegisteredExtension) {
ExtensionTypeGuard guard({smallint(), complex128(), dict_extension_type(), uuid()});
ExtensionTypeGuard guard({smallint(), complex128(), dict_extension_type()});

TestWithArrayFactory(ExampleSmallint);
TestWithArrayFactory(ExampleUuid);
@@ -3773,7 +3777,6 @@ TEST_F(TestArrayRoundtrip, UnregisteredExtension) {
};

TestWithArrayFactory(ExampleSmallint, StorageExtractor(ExampleSmallint));
TestWithArrayFactory(ExampleUuid, StorageExtractor(ExampleUuid));
TestWithArrayFactory(ExampleComplex128, StorageExtractor(ExampleComplex128));
TestWithArrayFactory(ExampleDictExtension, StorageExtractor(ExampleDictExtension));
}
3 changes: 2 additions & 1 deletion cpp/src/arrow/extension/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -17,8 +17,9 @@

add_arrow_test(test
SOURCES
uuid_array_test.cc
fixed_shape_tensor_test.cc
PREFIX
"arrow-fixed-shape-tensor")
"arrow-canonical-extensions")

arrow_install_all_headers("arrow/extension")
15 changes: 0 additions & 15 deletions cpp/src/arrow/extension/fixed_shape_tensor_test.cc
Original file line number Diff line number Diff line change
@@ -23,7 +23,6 @@
#include "arrow/array/array_primitive.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
#include "arrow/record_batch.h"
#include "arrow/tensor.h"
#include "arrow/testing/gtest_util.h"
@@ -70,20 +69,6 @@ class TestExtensionType : public ::testing::Test {
std::string serialized_;
};

auto RoundtripBatch = [](const std::shared_ptr<RecordBatch>& batch,
std::shared_ptr<RecordBatch>* out) {
ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create());
ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(),
out_stream.get()));

ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish());

io::BufferReader reader(complete_ipc_stream);
std::shared_ptr<RecordBatchReader> batch_reader;
ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader));
ASSERT_OK(batch_reader->ReadNext(out));
};

TEST_F(TestExtensionType, CheckDummyRegistration) {
// We need a registered dummy type at runtime to allow for IPC deserialization
auto registered_type = GetExtensionType("arrow.fixed_shape_tensor");
51 changes: 51 additions & 0 deletions cpp/src/arrow/extension/uuid_array.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/extension_type.h"
#include "arrow/util/logging.h"

#include "arrow/extension/uuid_array.h"

namespace arrow {
namespace extension {

bool UuidType::ExtensionEquals(const ExtensionType& other) const {
return (other.extension_name() == this->extension_name());
}

std::shared_ptr<Array> UuidType::MakeArray(std::shared_ptr<ArrayData> data) const {
DCHECK_EQ(data->type->id(), Type::EXTENSION);
DCHECK_EQ("uuid", static_cast<const ExtensionType&>(*data->type).extension_name());
return std::make_shared<UuidArray>(data);
}

Result<std::shared_ptr<DataType>> UuidType::Deserialize(
std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
if (serialized != "uuid-serialized") {
return Status::Invalid("Type identifier did not match: '", serialized, "'");
}
if (!storage_type->Equals(*fixed_size_binary(16))) {
return Status::Invalid("Invalid storage type for UuidType: ",
storage_type->ToString());
}
return std::make_shared<UuidType>();
}

std::shared_ptr<DataType> uuid() { return std::make_shared<UuidType>(); }

} // namespace extension
} // namespace arrow
51 changes: 51 additions & 0 deletions cpp/src/arrow/extension/uuid_array.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include "arrow/extension_type.h"

namespace arrow {
namespace extension {

class ARROW_EXPORT UuidArray : public ExtensionArray {
public:
using ExtensionArray::ExtensionArray;
};

class ARROW_EXPORT UuidType : public ExtensionType {
public:
UuidType() : ExtensionType(fixed_size_binary(16)) {}

std::string extension_name() const override { return "uuid"; }

bool ExtensionEquals(const ExtensionType& other) const override;

std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;

Result<std::shared_ptr<DataType>> Deserialize(
std::shared_ptr<DataType> storage_type,
const std::string& serialized) const override;

std::string Serialize() const override { return "uuid-serialized"; }
};

ARROW_EXPORT
std::shared_ptr<DataType> uuid();

} // namespace extension
} // namespace arrow
53 changes: 53 additions & 0 deletions cpp/src/arrow/extension/uuid_array_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/extension/uuid_array.h"

#include "arrow/testing/matchers.h"

#include "arrow/array/array_nested.h"
#include "arrow/array/array_primitive.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
#include "arrow/record_batch.h"
#include "arrow/tensor.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/key_value_metadata.h"

#include "arrow/testing/extension_type.h"

namespace arrow {

using extension::uuid;

class TestUuuidExtensionType : public ::testing::Test {};

TEST_F(TestUuuidExtensionType, ExtensionTypeTest) {
auto type = uuid();
ASSERT_EQ(type->id(), Type::EXTENSION);

const auto& ext_type = static_cast<const ExtensionType&>(*type);
std::string serialized = ext_type.Serialize();

ASSERT_OK_AND_ASSIGN(auto deserialized,
ext_type.Deserialize(fixed_size_binary(16), serialized));
ASSERT_TRUE(deserialized->Equals(*type));
ASSERT_FALSE(deserialized->Equals(*fixed_size_binary(16)));
}

} // namespace arrow
10 changes: 6 additions & 4 deletions cpp/src/arrow/extension_type.cc
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@
#ifdef ARROW_JSON
#include "arrow/extension/fixed_shape_tensor.h"
#endif
#include "arrow/extension/uuid_array.h"
#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/util/checked_cast.h"
@@ -146,10 +147,11 @@ static void CreateGlobalRegistry() {

#ifdef ARROW_JSON
// Register canonical extension types
auto ext_type =
checked_pointer_cast<ExtensionType>(extension::fixed_shape_tensor(int64(), {}));

ARROW_CHECK_OK(g_registry->RegisterType(ext_type));
auto ext_types = {extension::fixed_shape_tensor(int64(), {}), extension::uuid()};
for (const auto& ext_type : ext_types) {
ARROW_CHECK_OK(
g_registry->RegisterType(checked_pointer_cast<ExtensionType>(ext_type)));
}
#endif
}

28 changes: 4 additions & 24 deletions cpp/src/arrow/extension_type_test.cc
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@

#include "arrow/array/array_nested.h"
#include "arrow/array/util.h"
#include "arrow/extension/uuid_array.h"
#include "arrow/extension_type.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/options.h"
@@ -41,6 +42,8 @@

namespace arrow {

using extension::uuid;

class Parametric1Array : public ExtensionArray {
public:
using ExtensionArray::ExtensionArray;
@@ -176,16 +179,7 @@ class ExtStructType : public ExtensionType {
std::string Serialize() const override { return "ext-struct-type-unique-code"; }
};

class TestExtensionType : public ::testing::Test {
public:
void SetUp() { ASSERT_OK(RegisterExtensionType(std::make_shared<UuidType>())); }

void TearDown() {
if (GetExtensionType("uuid")) {
ASSERT_OK(UnregisterExtensionType("uuid"));
}
}
};
class TestExtensionType : public ::testing::Test {};

TEST_F(TestExtensionType, ExtensionTypeTest) {
auto type_not_exist = GetExtensionType("uuid-unknown");
@@ -206,20 +200,6 @@ TEST_F(TestExtensionType, ExtensionTypeTest) {
ASSERT_FALSE(deserialized->Equals(*fixed_size_binary(16)));
}

auto RoundtripBatch = [](const std::shared_ptr<RecordBatch>& batch,
std::shared_ptr<RecordBatch>* out) {
ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create());
ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(),
out_stream.get()));

ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish());

io::BufferReader reader(complete_ipc_stream);
std::shared_ptr<RecordBatchReader> batch_reader;
ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader));
ASSERT_OK(batch_reader->ReadNext(out));
};

TEST_F(TestExtensionType, IpcRoundtrip) {
auto ext_arr = ExampleUuid();
auto batch = RecordBatch::Make(schema({field("f0", uuid())}), 4, {ext_arr});
Loading

0 comments on commit d8bfaf8

Please sign in to comment.