From 39fe6fc45d14accf63b7aefed5a8f1225f6b552a Mon Sep 17 00:00:00 2001 From: Joel Lubinitsky <33523178+joellubi@users.noreply.github.com> Date: Mon, 12 Aug 2024 16:49:57 -0400 Subject: [PATCH] GH-17682: [Go] Bool8 Extension Type Implementation (#43323) ### Rationale for this change Go implementation of #43234 ### What changes are included in this PR? - Go implementation of the `Bool8` extension type - Minor refactor of existing extension builder interfaces ### Are these changes tested? Yes, unit tests and basic read/write benchmarks are included. ### Are there any user-facing changes? - A new extension type is added - Custom extension builders no longer need another builder created and released separately. * GitHub Issue: #17682 Authored-by: Joel Lubinitsky Signed-off-by: Joel Lubinitsky --- go/arrow/array/builder.go | 11 +- go/arrow/array/extension_builder.go | 10 +- go/arrow/extensions/bool8.go | 216 +++++++++++++++ go/arrow/extensions/bool8_test.go | 319 ++++++++++++++++++++++ go/arrow/extensions/extensions_test.go | 105 +++++++ go/internal/types/extension_types.go | 9 +- go/internal/types/extension_types_test.go | 16 +- go/parquet/pqarrow/encode_arrow_test.go | 4 +- 8 files changed, 663 insertions(+), 27 deletions(-) create mode 100644 go/arrow/extensions/bool8.go create mode 100644 go/arrow/extensions/bool8_test.go create mode 100644 go/arrow/extensions/extensions_test.go diff --git a/go/arrow/array/builder.go b/go/arrow/array/builder.go index 6c8ea877a2fb0..1f4d0ea963509 100644 --- a/go/arrow/array/builder.go +++ b/go/arrow/array/builder.go @@ -349,12 +349,13 @@ func NewBuilder(mem memory.Allocator, dtype arrow.DataType) Builder { typ := dtype.(*arrow.LargeListViewType) return NewLargeListViewBuilderWithField(mem, typ.ElemField()) case arrow.EXTENSION: - typ := dtype.(arrow.ExtensionType) - bldr := NewExtensionBuilder(mem, typ) - if custom, ok := typ.(ExtensionBuilderWrapper); ok { - return custom.NewBuilder(bldr) + if custom, ok := dtype.(CustomExtensionBuilder); ok { + return custom.NewBuilder(mem) } - return bldr + if typ, ok := dtype.(arrow.ExtensionType); ok { + return NewExtensionBuilder(mem, typ) + } + panic(fmt.Errorf("arrow/array: invalid extension type: %T", dtype)) case arrow.FIXED_SIZE_LIST: typ := dtype.(*arrow.FixedSizeListType) return NewFixedSizeListBuilderWithField(mem, typ.Len(), typ.ElemField()) diff --git a/go/arrow/array/extension_builder.go b/go/arrow/array/extension_builder.go index a71287faf0e36..9c2ee88056438 100644 --- a/go/arrow/array/extension_builder.go +++ b/go/arrow/array/extension_builder.go @@ -16,8 +16,10 @@ package array -// ExtensionBuilderWrapper is an interface that you need to implement in your custom extension type if you want to provide a customer builder as well. -// See example in ./arrow/internal/testing/types/extension_types.go -type ExtensionBuilderWrapper interface { - NewBuilder(bldr *ExtensionBuilder) Builder +import "github.com/apache/arrow/go/v18/arrow/memory" + +// CustomExtensionBuilder is an interface that custom extension types may implement to provide a custom builder +// instead of the underlying storage type's builder when array.NewBuilder is called with that type. +type CustomExtensionBuilder interface { + NewBuilder(memory.Allocator) Builder } diff --git a/go/arrow/extensions/bool8.go b/go/arrow/extensions/bool8.go new file mode 100644 index 0000000000000..20ab024a2a2fb --- /dev/null +++ b/go/arrow/extensions/bool8.go @@ -0,0 +1,216 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "unsafe" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/apache/arrow/go/v18/internal/json" +) + +// Bool8Type represents a logical boolean that is stored using 8 bits. +type Bool8Type struct { + arrow.ExtensionBase +} + +// NewBool8Type creates a new Bool8Type with the underlying storage type set correctly to Int8. +func NewBool8Type() *Bool8Type { + return &Bool8Type{ExtensionBase: arrow.ExtensionBase{Storage: arrow.PrimitiveTypes.Int8}} +} + +func (b *Bool8Type) ArrayType() reflect.Type { return reflect.TypeOf(Bool8Array{}) } + +func (b *Bool8Type) Deserialize(storageType arrow.DataType, data string) (arrow.ExtensionType, error) { + if !arrow.TypeEqual(storageType, arrow.PrimitiveTypes.Int8) { + return nil, fmt.Errorf("invalid storage type for Bool8Type: %s", storageType.Name()) + } + return NewBool8Type(), nil +} + +func (b *Bool8Type) ExtensionEquals(other arrow.ExtensionType) bool { + return b.ExtensionName() == other.ExtensionName() +} + +func (b *Bool8Type) ExtensionName() string { return "arrow.bool8" } + +func (b *Bool8Type) Serialize() string { return "" } + +func (b *Bool8Type) String() string { return fmt.Sprintf("extension<%s>", b.ExtensionName()) } + +func (*Bool8Type) NewBuilder(mem memory.Allocator) array.Builder { + return NewBool8Builder(mem) +} + +// Bool8Array is logically an array of boolean values but uses +// 8 bits to store values instead of 1 bit as in the native BooleanArray. +type Bool8Array struct { + array.ExtensionArrayBase +} + +func (a *Bool8Array) String() string { + var o strings.Builder + o.WriteString("[") + for i := 0; i < a.Len(); i++ { + if i > 0 { + o.WriteString(" ") + } + switch { + case a.IsNull(i): + o.WriteString(array.NullValueStr) + default: + fmt.Fprintf(&o, "%v", a.Value(i)) + } + } + o.WriteString("]") + return o.String() +} + +func (a *Bool8Array) Value(i int) bool { + return a.Storage().(*array.Int8).Value(i) != 0 +} + +func (a *Bool8Array) BoolValues() []bool { + int8s := a.Storage().(*array.Int8).Int8Values() + return unsafe.Slice((*bool)(unsafe.Pointer(unsafe.SliceData(int8s))), len(int8s)) +} + +func (a *Bool8Array) ValueStr(i int) string { + switch { + case a.IsNull(i): + return array.NullValueStr + default: + return fmt.Sprint(a.Value(i)) + } +} + +func (a *Bool8Array) MarshalJSON() ([]byte, error) { + values := make([]interface{}, a.Len()) + for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { + values[i] = a.Value(i) + } + } + return json.Marshal(values) +} + +func (a *Bool8Array) GetOneForMarshal(i int) interface{} { + if a.IsNull(i) { + return nil + } + return a.Value(i) +} + +// boolToInt8 performs the simple scalar conversion of bool to the canonical int8 +// value for the Bool8Type. +func boolToInt8(v bool) int8 { + var res int8 + if v { + res = 1 + } + return res +} + +// Bool8Builder is a convenience builder for the Bool8 extension type, +// allowing arrays to be built with boolean values rather than the underlying storage type. +type Bool8Builder struct { + *array.ExtensionBuilder +} + +// NewBool8Builder creates a new Bool8Builder, exposing a convenient and efficient interface +// for writing boolean values to the underlying int8 storage array. +func NewBool8Builder(mem memory.Allocator) *Bool8Builder { + return &Bool8Builder{ExtensionBuilder: array.NewExtensionBuilder(mem, NewBool8Type())} +} + +func (b *Bool8Builder) Append(v bool) { + b.ExtensionBuilder.Builder.(*array.Int8Builder).Append(boolToInt8(v)) +} + +func (b *Bool8Builder) UnsafeAppend(v bool) { + b.ExtensionBuilder.Builder.(*array.Int8Builder).UnsafeAppend(boolToInt8(v)) +} + +func (b *Bool8Builder) AppendValueFromString(s string) error { + if s == array.NullValueStr { + b.AppendNull() + return nil + } + + val, err := strconv.ParseBool(s) + if err != nil { + return err + } + + b.Append(val) + return nil +} + +func (b *Bool8Builder) AppendValues(v []bool, valid []bool) { + boolsAsInt8s := unsafe.Slice((*int8)(unsafe.Pointer(unsafe.SliceData(v))), len(v)) + b.ExtensionBuilder.Builder.(*array.Int8Builder).AppendValues(boolsAsInt8s, valid) +} + +func (b *Bool8Builder) UnmarshalOne(dec *json.Decoder) error { + t, err := dec.Token() + if err != nil { + return err + } + + switch v := t.(type) { + case bool: + b.Append(v) + return nil + case string: + return b.AppendValueFromString(v) + case int8: + b.ExtensionBuilder.Builder.(*array.Int8Builder).Append(v) + return nil + case nil: + b.AppendNull() + return nil + default: + return &json.UnmarshalTypeError{ + Value: fmt.Sprint(t), + Type: reflect.TypeOf([]byte{}), + Offset: dec.InputOffset(), + Struct: "Bool8Builder", + } + } +} + +func (b *Bool8Builder) Unmarshal(dec *json.Decoder) error { + for dec.More() { + if err := b.UnmarshalOne(dec); err != nil { + return err + } + } + return nil +} + +var ( + _ arrow.ExtensionType = (*Bool8Type)(nil) + _ array.CustomExtensionBuilder = (*Bool8Type)(nil) + _ array.ExtensionArray = (*Bool8Array)(nil) + _ array.Builder = (*Bool8Builder)(nil) +) diff --git a/go/arrow/extensions/bool8_test.go b/go/arrow/extensions/bool8_test.go new file mode 100644 index 0000000000000..9f7365d1555fb --- /dev/null +++ b/go/arrow/extensions/bool8_test.go @@ -0,0 +1,319 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions_test + +import ( + "bytes" + "fmt" + "strings" + "testing" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" + "github.com/apache/arrow/go/v18/arrow/ipc" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/apache/arrow/go/v18/internal/json" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + MINSIZE = 1024 + MAXSIZE = 65536 +) + +func TestBool8ExtensionBuilder(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + builder := extensions.NewBool8Builder(mem) + defer builder.Release() + + builder.Append(true) + builder.AppendNull() + builder.Append(false) + arr := builder.NewArray() + defer arr.Release() + + arrStr := arr.String() + require.Equal(t, "[true (null) false]", arrStr) + + jsonStr, err := json.Marshal(arr) + require.NoError(t, err) + + arr1, _, err := array.FromJSON(mem, extensions.NewBool8Type(), bytes.NewReader(jsonStr)) + require.NoError(t, err) + defer arr1.Release() + + require.Equal(t, arr, arr1) +} + +func TestBool8ExtensionRecordBuilder(t *testing.T) { + schema := arrow.NewSchema([]arrow.Field{ + {Name: "bool8", Type: extensions.NewBool8Type()}, + }, nil) + + builder := array.NewRecordBuilder(memory.DefaultAllocator, schema) + defer builder.Release() + + builder.Field(0).(*extensions.Bool8Builder).Append(true) + record := builder.NewRecord() + defer record.Release() + + b, err := record.MarshalJSON() + require.NoError(t, err) + require.Equal(t, "[{\"bool8\":true}\n]", string(b)) + + record1, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, bytes.NewReader(b)) + require.NoError(t, err) + defer record1.Release() + + require.Equal(t, record, record1) + + require.NoError(t, builder.UnmarshalJSON([]byte(`{"bool8":true}`))) + record = builder.NewRecord() + defer record.Release() + + require.Equal(t, schema, record.Schema()) + require.Equal(t, true, record.Column(0).(*extensions.Bool8Array).Value(0)) +} + +func TestBool8StringRoundTrip(t *testing.T) { + // 1. create array + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + b := extensions.NewBool8Builder(mem) + b.Append(true) + b.AppendNull() + b.Append(false) + b.AppendNull() + b.Append(true) + + arr := b.NewArray() + defer arr.Release() + + // 2. create array via AppendValueFromString + b1 := extensions.NewBool8Builder(mem) + defer b1.Release() + + for i := 0; i < arr.Len(); i++ { + assert.NoError(t, b1.AppendValueFromString(arr.ValueStr(i))) + } + + arr1 := b1.NewArray() + defer arr1.Release() + + assert.True(t, array.Equal(arr, arr1)) +} + +func TestCompareBool8AndBoolean(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + bool8bldr := extensions.NewBool8Builder(mem) + defer bool8bldr.Release() + + boolbldr := array.NewBooleanBuilder(mem) + defer boolbldr.Release() + + inputVals := []bool{true, false, false, false, true} + inputValidity := []bool{true, false, true, false, true} + + bool8bldr.AppendValues(inputVals, inputValidity) + bool8Arr := bool8bldr.NewExtensionArray().(*extensions.Bool8Array) + defer bool8Arr.Release() + + boolbldr.AppendValues(inputVals, inputValidity) + boolArr := boolbldr.NewBooleanArray() + defer boolArr.Release() + + require.Equal(t, boolArr.Len(), bool8Arr.Len()) + for i := 0; i < boolArr.Len(); i++ { + require.Equal(t, boolArr.Value(i), bool8Arr.Value(i)) + } +} + +func TestReinterpretStorageEqualToValues(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + bool8bldr := extensions.NewBool8Builder(mem) + defer bool8bldr.Release() + + inputVals := []bool{true, false, false, false, true} + inputValidity := []bool{true, false, true, false, true} + + bool8bldr.AppendValues(inputVals, inputValidity) + bool8Arr := bool8bldr.NewExtensionArray().(*extensions.Bool8Array) + defer bool8Arr.Release() + + boolValsCopy := make([]bool, bool8Arr.Len()) + for i := 0; i < bool8Arr.Len(); i++ { + boolValsCopy[i] = bool8Arr.Value(i) + } + + boolValsZeroCopy := bool8Arr.BoolValues() + + require.Equal(t, len(boolValsZeroCopy), len(boolValsCopy)) + for i := range boolValsCopy { + require.Equal(t, boolValsZeroCopy[i], boolValsCopy[i]) + } +} + +func TestBool8TypeBatchIPCRoundTrip(t *testing.T) { + typ := extensions.NewBool8Type() + arrow.RegisterExtensionType(typ) + defer arrow.UnregisterExtensionType(typ.ExtensionName()) + + storage, _, err := array.FromJSON(memory.DefaultAllocator, arrow.PrimitiveTypes.Int8, + strings.NewReader(`[-1, 0, 1, 2, null]`)) + require.NoError(t, err) + defer storage.Release() + + arr := array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + batch := array.NewRecord(arrow.NewSchema([]arrow.Field{{Name: "field", Type: typ, Nullable: true}}, nil), + []arrow.Array{arr}, -1) + defer batch.Release() + + var written arrow.Record + { + var buf bytes.Buffer + wr := ipc.NewWriter(&buf, ipc.WithSchema(batch.Schema())) + require.NoError(t, wr.Write(batch)) + require.NoError(t, wr.Close()) + + rdr, err := ipc.NewReader(&buf) + require.NoError(t, err) + written, err = rdr.Read() + require.NoError(t, err) + written.Retain() + defer written.Release() + rdr.Release() + } + + assert.Truef(t, batch.Schema().Equal(written.Schema()), "expected: %s, got: %s", + batch.Schema(), written.Schema()) + + assert.Truef(t, array.RecordEqual(batch, written), "expected: %s, got: %s", + batch, written) +} + +func BenchmarkWriteBool8Array(b *testing.B) { + bool8bldr := extensions.NewBool8Builder(memory.DefaultAllocator) + defer bool8bldr.Release() + + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + + values := make([]bool, sz) + for idx := range values { + values[idx] = true + } + + b.ResetTimer() + b.SetBytes(int64(sz)) + for n := 0; n < b.N; n++ { + bool8bldr.AppendValues(values, nil) + bool8bldr.NewArray() + } + }) + } +} + +func BenchmarkWriteBooleanArray(b *testing.B) { + boolbldr := array.NewBooleanBuilder(memory.DefaultAllocator) + defer boolbldr.Release() + + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + + values := make([]bool, sz) + for idx := range values { + values[idx] = true + } + + b.ResetTimer() + b.SetBytes(int64(len(values))) + for n := 0; n < b.N; n++ { + boolbldr.AppendValues(values, nil) + boolbldr.NewArray() + } + }) + } +} + +// storage benchmark result at package level to prevent compiler from eliminating the function call +var result []bool + +func BenchmarkReadBool8Array(b *testing.B) { + bool8bldr := extensions.NewBool8Builder(memory.DefaultAllocator) + defer bool8bldr.Release() + + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + + values := make([]bool, sz) + for idx := range values { + values[idx] = true + } + + bool8bldr.AppendValues(values, nil) + bool8Arr := bool8bldr.NewArray().(*extensions.Bool8Array) + defer bool8Arr.Release() + + var r []bool + b.ResetTimer() + b.SetBytes(int64(len(values))) + for n := 0; n < b.N; n++ { + r = bool8Arr.BoolValues() + } + result = r + }) + } +} + +func BenchmarkReadBooleanArray(b *testing.B) { + boolbldr := array.NewBooleanBuilder(memory.DefaultAllocator) + defer boolbldr.Release() + + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + + values := make([]bool, sz) + output := make([]bool, sz) + for idx := range values { + values[idx] = true + } + + boolbldr.AppendValues(values, nil) + boolArr := boolbldr.NewArray().(*array.Boolean) + defer boolArr.Release() + + b.ResetTimer() + b.SetBytes(int64(len(values))) + for n := 0; n < b.N; n++ { + for i := 0; i < boolArr.Len(); i++ { + output[i] = boolArr.Value(i) + } + } + }) + } +} diff --git a/go/arrow/extensions/extensions_test.go b/go/arrow/extensions/extensions_test.go new file mode 100644 index 0000000000000..f56fed5e132f9 --- /dev/null +++ b/go/arrow/extensions/extensions_test.go @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions_test + +import ( + "bytes" + "fmt" + "reflect" + "testing" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/stretchr/testify/require" +) + +// testBool8Type minimally implements arrow.ExtensionType, but importantly does not implement array.CustomExtensionBuilder +// so it will fall back to the storage type's default builder. +type testBool8Type struct { + arrow.ExtensionBase +} + +func newTestBool8Type() *testBool8Type { + return &testBool8Type{ExtensionBase: arrow.ExtensionBase{Storage: arrow.PrimitiveTypes.Int8}} +} + +func (t *testBool8Type) ArrayType() reflect.Type { return reflect.TypeOf(testBool8Array{}) } +func (t *testBool8Type) ExtensionEquals(arrow.ExtensionType) bool { panic("unimplemented") } +func (t *testBool8Type) ExtensionName() string { panic("unimplemented") } +func (t *testBool8Type) Serialize() string { panic("unimplemented") } +func (t *testBool8Type) Deserialize(arrow.DataType, string) (arrow.ExtensionType, error) { + panic("unimplemented") +} + +type testBool8Array struct { + array.ExtensionArrayBase +} + +func TestUnmarshalExtensionTypes(t *testing.T) { + logicalJSON := `[true,null,false,null,true]` + storageJSON := `[1,null,0,null,1]` + + // extensions.Bool8Type implements array.CustomExtensionBuilder so we expect the array to be built with the custom builder + arrCustomBuilder, _, err := array.FromJSON(memory.DefaultAllocator, extensions.NewBool8Type(), bytes.NewBufferString(logicalJSON)) + require.NoError(t, err) + defer arrCustomBuilder.Release() + require.Equal(t, 5, arrCustomBuilder.Len()) + + // testBoolType falls back to the default builder for the storage type, so it cannot deserialize native booleans + _, _, err = array.FromJSON(memory.DefaultAllocator, newTestBool8Type(), bytes.NewBufferString(logicalJSON)) + require.ErrorContains(t, err, "cannot unmarshal true into Go value of type int8") + + // testBoolType must build the array with the native storage type: Int8 + arrDefaultBuilder, _, err := array.FromJSON(memory.DefaultAllocator, newTestBool8Type(), bytes.NewBufferString(storageJSON)) + require.NoError(t, err) + defer arrDefaultBuilder.Release() + require.Equal(t, 5, arrDefaultBuilder.Len()) + + arrBool8, ok := arrCustomBuilder.(*extensions.Bool8Array) + require.True(t, ok) + + arrExt, ok := arrDefaultBuilder.(array.ExtensionArray) + require.True(t, ok) + + // The physical layout of both arrays is identical + require.True(t, array.Equal(arrBool8.Storage(), arrExt.Storage())) +} + +// invalidExtensionType does not fully implement the arrow.ExtensionType interface, even though it embeds arrow.ExtensionBase +type invalidExtensionType struct { + arrow.ExtensionBase +} + +func newInvalidExtensionType() *invalidExtensionType { + return &invalidExtensionType{ExtensionBase: arrow.ExtensionBase{Storage: arrow.BinaryTypes.String}} +} + +func TestInvalidExtensionType(t *testing.T) { + jsonStr := `["one","two","three"]` + typ := newInvalidExtensionType() + + require.PanicsWithError(t, fmt.Sprintf("arrow/array: invalid extension type: %T", typ), func() { + array.FromJSON(memory.DefaultAllocator, typ, bytes.NewBufferString(jsonStr)) + }) +} + +var ( + _ arrow.ExtensionType = (*testBool8Type)(nil) + _ array.ExtensionArray = (*testBool8Array)(nil) +) diff --git a/go/internal/types/extension_types.go b/go/internal/types/extension_types.go index 3c63b36874600..85c64d86bffcb 100644 --- a/go/internal/types/extension_types.go +++ b/go/internal/types/extension_types.go @@ -26,6 +26,7 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/internal/json" "github.com/google/uuid" "golang.org/x/xerrors" @@ -37,8 +38,8 @@ type UUIDBuilder struct { *array.ExtensionBuilder } -func NewUUIDBuilder(builder *array.ExtensionBuilder) *UUIDBuilder { - return &UUIDBuilder{ExtensionBuilder: builder} +func NewUUIDBuilder(mem memory.Allocator) *UUIDBuilder { + return &UUIDBuilder{ExtensionBuilder: array.NewExtensionBuilder(mem, NewUUIDType())} } func (b *UUIDBuilder) Append(v uuid.UUID) { @@ -245,8 +246,8 @@ func (e *UUIDType) ExtensionEquals(other arrow.ExtensionType) bool { return e.ExtensionName() == other.ExtensionName() } -func (*UUIDType) NewBuilder(bldr *array.ExtensionBuilder) array.Builder { - return NewUUIDBuilder(bldr) +func (*UUIDType) NewBuilder(mem memory.Allocator) array.Builder { + return NewUUIDBuilder(mem) } // Parametric1Array is a simple int32 array for use with the Parametric1Type diff --git a/go/internal/types/extension_types_test.go b/go/internal/types/extension_types_test.go index 50abaae3a9e06..65f6353d01be1 100644 --- a/go/internal/types/extension_types_test.go +++ b/go/internal/types/extension_types_test.go @@ -32,12 +32,10 @@ import ( var testUUID = uuid.New() -func TestExtensionBuilder(t *testing.T) { +func TestUUIDExtensionBuilder(t *testing.T) { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(t, 0) - extBuilder := array.NewExtensionBuilder(mem, types.NewUUIDType()) - defer extBuilder.Release() - builder := types.NewUUIDBuilder(extBuilder) + builder := types.NewUUIDBuilder(mem) builder.Append(testUUID) arr := builder.NewArray() defer arr.Release() @@ -52,7 +50,7 @@ func TestExtensionBuilder(t *testing.T) { assert.Equal(t, arr, arr1) } -func TestExtensionRecordBuilder(t *testing.T) { +func TestUUIDExtensionRecordBuilder(t *testing.T) { schema := arrow.NewSchema([]arrow.Field{ {Name: "uuid", Type: types.NewUUIDType()}, }, nil) @@ -72,9 +70,7 @@ func TestUUIDStringRoundTrip(t *testing.T) { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(t, 0) - extBuilder := array.NewExtensionBuilder(mem, types.NewUUIDType()) - defer extBuilder.Release() - b := types.NewUUIDBuilder(extBuilder) + b := types.NewUUIDBuilder(mem) b.Append(uuid.Nil) b.AppendNull() b.Append(uuid.NameSpaceURL) @@ -85,9 +81,7 @@ func TestUUIDStringRoundTrip(t *testing.T) { defer arr.Release() // 2. create array via AppendValueFromString - extBuilder1 := array.NewExtensionBuilder(mem, types.NewUUIDType()) - defer extBuilder1.Release() - b1 := types.NewUUIDBuilder(extBuilder1) + b1 := types.NewUUIDBuilder(mem) defer b1.Release() for i := 0; i < arr.Len(); i++ { diff --git a/go/parquet/pqarrow/encode_arrow_test.go b/go/parquet/pqarrow/encode_arrow_test.go index 9b3419988d6df..16282173a685c 100644 --- a/go/parquet/pqarrow/encode_arrow_test.go +++ b/go/parquet/pqarrow/encode_arrow_test.go @@ -2053,9 +2053,7 @@ func (ps *ParquetIOTestSuite) TestArrowExtensionTypeRoundTrip() { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(ps.T(), 0) - extBuilder := array.NewExtensionBuilder(mem, types.NewUUIDType()) - defer extBuilder.Release() - builder := types.NewUUIDBuilder(extBuilder) + builder := types.NewUUIDBuilder(mem) builder.Append(uuid.New()) arr := builder.NewArray() defer arr.Release()