From 60902401d0176f33416c9dfcbfd8b192f32fa1a8 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Thu, 11 Aug 2022 16:24:02 -0400 Subject: [PATCH] ARROW-17390: [Go] Add union scalar types --- go/arrow/scalar/nested.go | 192 +++++++++++++++++++++++ go/arrow/scalar/scalar.go | 81 ++++++++-- go/arrow/scalar/scalar_test.go | 273 +++++++++++++++++++++++++++++++++ 3 files changed, 531 insertions(+), 15 deletions(-) diff --git a/go/arrow/scalar/nested.go b/go/arrow/scalar/nested.go index 756e383f5a7b6..2d106e50711d3 100644 --- a/go/arrow/scalar/nested.go +++ b/go/arrow/scalar/nested.go @@ -520,3 +520,195 @@ func (s *Dictionary) GetEncodedValue() (Scalar, error) { func (s *Dictionary) value() interface{} { return s.Value.Index.value() } + +type Union interface { + Scalar + ChildValue() Scalar + Release() +} + +type SparseUnion struct { + scalar + + TypeCode arrow.UnionTypeCode + Value []Scalar + ChildID int +} + +func (s *SparseUnion) equals(rhs Scalar) bool { + right := rhs.(*SparseUnion) + return Equals(s.ChildValue(), right.ChildValue()) +} + +func (s *SparseUnion) value() interface{} { return s.ChildValue() } + +func (s *SparseUnion) String() string { + dt := s.Type.(*arrow.SparseUnionType) + val := s.ChildValue() + return "union{" + dt.Fields()[dt.ChildIDs()[s.TypeCode]].String() + " = " + val.String() + "}" +} + +func (s *SparseUnion) Release() { + for _, v := range s.Value { + if v, ok := v.(Releasable); ok { + v.Release() + } + } +} + +func (s *SparseUnion) Validate() (err error) { + dt := s.Type.(*arrow.SparseUnionType) + if len(dt.Fields()) != len(s.Value) { + return fmt.Errorf("sparse union scalar value had %d fields but type has %d fields", len(dt.Fields()), len(s.Value)) + } + + if s.TypeCode < 0 || int(s.TypeCode) >= len(dt.ChildIDs()) || dt.ChildIDs()[s.TypeCode] == arrow.InvalidUnionChildID { + return fmt.Errorf("%s scalar has invalid type code %d", dt, s.TypeCode) + } + + for i, f := range dt.Fields() { + v := s.Value[i] + if !arrow.TypeEqual(f.Type, v.DataType()) { + return fmt.Errorf("%s value for field %s had incorrect type of %s", dt, f, v.DataType()) + } + if err = v.Validate(); err != nil { + return err + } + } + return +} + +func (s *SparseUnion) ValidateFull() (err error) { + dt := s.Type.(*arrow.SparseUnionType) + if len(dt.Fields()) != len(s.Value) { + return fmt.Errorf("sparse union scalar value had %d fields but type has %d fields", len(dt.Fields()), len(s.Value)) + } + + if s.TypeCode < 0 || int(s.TypeCode) >= len(dt.ChildIDs()) || dt.ChildIDs()[s.TypeCode] == arrow.InvalidUnionChildID { + return fmt.Errorf("%s scalar has invalid type code %d", dt, s.TypeCode) + } + + for i, f := range dt.Fields() { + v := s.Value[i] + if !arrow.TypeEqual(f.Type, v.DataType()) { + return fmt.Errorf("%s value for field %s had incorrect type of %s", dt, f, v.DataType()) + } + if err = v.ValidateFull(); err != nil { + return err + } + } + return +} + +func (s *SparseUnion) CastTo(to arrow.DataType) (Scalar, error) { + if !s.Valid { + return MakeNullScalar(to), nil + } + + switch to.ID() { + case arrow.STRING: + return NewStringScalar(s.String()), nil + case arrow.LARGE_STRING: + return NewLargeStringScalar(s.String()), nil + } + + return nil, fmt.Errorf("cannot cast non-nil union to type other than string") +} + +func (s *SparseUnion) ChildValue() Scalar { return s.Value[s.ChildID] } + +func NewSparseUnionScalar(val []Scalar, code arrow.UnionTypeCode, dt *arrow.SparseUnionType) *SparseUnion { + ret := &SparseUnion{ + scalar: scalar{dt, true}, + TypeCode: code, + Value: val, + ChildID: dt.ChildIDs()[code], + } + ret.Valid = ret.Value[ret.ChildID].IsValid() + return ret +} + +func NewSparseUnionScalarFromValue(val Scalar, idx int, dt *arrow.SparseUnionType) *SparseUnion { + code := dt.TypeCodes()[idx] + values := make([]Scalar, len(dt.Fields())) + for i, f := range dt.Fields() { + if i == idx { + values[i] = val + } else { + values[i] = MakeNullScalar(f.Type) + } + } + return NewSparseUnionScalar(values, code, dt) +} + +type DenseUnion struct { + scalar + + TypeCode arrow.UnionTypeCode + Value Scalar +} + +func (s *DenseUnion) equals(rhs Scalar) bool { + right := rhs.(*DenseUnion) + return Equals(s.Value, right.Value) +} + +func (s *DenseUnion) value() interface{} { return s.ChildValue() } + +func (s *DenseUnion) String() string { + dt := s.Type.(*arrow.DenseUnionType) + return "union{" + dt.Fields()[dt.ChildIDs()[s.TypeCode]].String() + " = " + s.Value.String() + "}" +} + +func (s *DenseUnion) Release() { + if v, ok := s.Value.(Releasable); ok { + v.Release() + } +} + +func (s *DenseUnion) Validate() (err error) { + dt := s.Type.(*arrow.DenseUnionType) + if s.TypeCode < 0 || int(s.TypeCode) >= len(dt.ChildIDs()) || dt.ChildIDs()[s.TypeCode] == arrow.InvalidUnionChildID { + return fmt.Errorf("%s scalar has invalid type code %d", dt, s.TypeCode) + } + fieldType := dt.Fields()[dt.ChildIDs()[s.TypeCode]].Type + if !arrow.TypeEqual(fieldType, s.Value.DataType()) { + return fmt.Errorf("%s scalar with type code %d should have an underlying value of type %s, got %s", + s.Type, s.TypeCode, fieldType, s.Value.DataType()) + } + return s.Value.Validate() +} + +func (s *DenseUnion) ValidateFull() error { + dt := s.Type.(*arrow.DenseUnionType) + if s.TypeCode < 0 || int(s.TypeCode) >= len(dt.ChildIDs()) || dt.ChildIDs()[s.TypeCode] == arrow.InvalidUnionChildID { + return fmt.Errorf("%s scalar has invalid type code %d", dt, s.TypeCode) + } + fieldType := dt.Fields()[dt.ChildIDs()[s.TypeCode]].Type + if !arrow.TypeEqual(fieldType, s.Value.DataType()) { + return fmt.Errorf("%s scalar with type code %d should have an underlying value of type %s, got %s", + s.Type, s.TypeCode, fieldType, s.Value.DataType()) + } + return s.Value.ValidateFull() +} + +func (s *DenseUnion) CastTo(to arrow.DataType) (Scalar, error) { + if !s.Valid { + return MakeNullScalar(to), nil + } + + switch to.ID() { + case arrow.STRING: + return NewStringScalar(s.String()), nil + case arrow.LARGE_STRING: + return NewLargeStringScalar(s.String()), nil + } + + return nil, fmt.Errorf("cannot cast non-nil union to type other than string") +} + +func (s *DenseUnion) ChildValue() Scalar { return s.Value } + +func NewDenseUnionScalar(v Scalar, code arrow.UnionTypeCode, dt *arrow.DenseUnionType) *DenseUnion { + return &DenseUnion{scalar: scalar{dt, v.IsValid()}, TypeCode: code, Value: v} +} diff --git a/go/arrow/scalar/scalar.go b/go/arrow/scalar/scalar.go index 5edc98584b5b3..7ae8b03473480 100644 --- a/go/arrow/scalar/scalar.go +++ b/go/arrow/scalar/scalar.go @@ -466,10 +466,6 @@ func MakeNullScalar(dt arrow.DataType) Scalar { return makeNullFn[byte(dt.ID()&0x3f)](dt) } -func unsupportedScalarType(dt arrow.DataType) Scalar { - panic("unsupported scalar data type: " + dt.ID().String()) -} - func invalidScalarType(dt arrow.DataType) Scalar { panic("invalid scalar type: " + dt.ID().String()) } @@ -516,17 +512,33 @@ func init() { arrow.DECIMAL128: func(dt arrow.DataType) Scalar { return &Decimal128{scalar: scalar{dt, false}} }, arrow.LIST: func(dt arrow.DataType) Scalar { return &List{scalar: scalar{dt, false}} }, arrow.STRUCT: func(dt arrow.DataType) Scalar { return &Struct{scalar: scalar{dt, false}} }, - arrow.SPARSE_UNION: unsupportedScalarType, - arrow.DENSE_UNION: unsupportedScalarType, - arrow.DICTIONARY: func(dt arrow.DataType) Scalar { return NewNullDictScalar(dt) }, - arrow.LARGE_STRING: func(dt arrow.DataType) Scalar { return &LargeString{&String{&Binary{scalar: scalar{dt, false}}}} }, - arrow.LARGE_BINARY: func(dt arrow.DataType) Scalar { return &LargeBinary{&Binary{scalar: scalar{dt, false}}} }, - arrow.LARGE_LIST: func(dt arrow.DataType) Scalar { return &LargeList{&List{scalar: scalar{dt, false}}} }, - arrow.DECIMAL256: func(dt arrow.DataType) Scalar { return &Decimal256{scalar: scalar{dt, false}} }, - arrow.MAP: func(dt arrow.DataType) Scalar { return &Map{&List{scalar: scalar{dt, false}}} }, - arrow.EXTENSION: func(dt arrow.DataType) Scalar { return &Extension{scalar: scalar{dt, false}} }, - arrow.FIXED_SIZE_LIST: func(dt arrow.DataType) Scalar { return &FixedSizeList{&List{scalar: scalar{dt, false}}} }, - arrow.DURATION: func(dt arrow.DataType) Scalar { return &Duration{scalar: scalar{dt, false}} }, + arrow.SPARSE_UNION: func(dt arrow.DataType) Scalar { + typ := dt.(*arrow.SparseUnionType) + if len(typ.Fields()) == 0 { + panic("cannot make scalar of empty union type") + } + values := make([]Scalar, len(typ.Fields())) + for i, f := range typ.Fields() { + values[i] = MakeNullScalar(f.Type) + } + return NewSparseUnionScalar(values, typ.TypeCodes()[0], typ) + }, + arrow.DENSE_UNION: func(dt arrow.DataType) Scalar { + typ := dt.(*arrow.DenseUnionType) + if len(typ.Fields()) == 0 { + panic("cannot make scalar of empty union type") + } + return NewDenseUnionScalar(MakeNullScalar(typ.Fields()[0].Type), typ.TypeCodes()[0], typ) + }, + arrow.DICTIONARY: func(dt arrow.DataType) Scalar { return NewNullDictScalar(dt) }, + arrow.LARGE_STRING: func(dt arrow.DataType) Scalar { return &LargeString{&String{&Binary{scalar: scalar{dt, false}}}} }, + arrow.LARGE_BINARY: func(dt arrow.DataType) Scalar { return &LargeBinary{&Binary{scalar: scalar{dt, false}}} }, + arrow.LARGE_LIST: func(dt arrow.DataType) Scalar { return &LargeList{&List{scalar: scalar{dt, false}}} }, + arrow.DECIMAL256: func(dt arrow.DataType) Scalar { return &Decimal256{scalar: scalar{dt, false}} }, + arrow.MAP: func(dt arrow.DataType) Scalar { return &Map{&List{scalar: scalar{dt, false}}} }, + arrow.EXTENSION: func(dt arrow.DataType) Scalar { return &Extension{scalar: scalar{dt, false}} }, + arrow.FIXED_SIZE_LIST: func(dt arrow.DataType) Scalar { return &FixedSizeList{&List{scalar: scalar{dt, false}}} }, + arrow.DURATION: func(dt arrow.DataType) Scalar { return &Duration{scalar: scalar{dt, false}} }, // invalid data types to fill out array size 2^6 - 1 63: invalidScalarType, } @@ -646,6 +658,39 @@ func GetScalar(arr arrow.Array, idx int) (Scalar, error) { scalar.Value.Dict = arr.Dictionary() scalar.Value.Dict.Retain() return scalar, nil + case *array.SparseUnion: + var err error + typeCode := arr.TypeCode(idx) + children := make([]Scalar, arr.NumFields()) + defer func() { + if err != nil { + for _, c := range children { + if c == nil { + break + } + + if v, ok := c.(Releasable); ok { + v.Release() + } + } + } + }() + + for i := range arr.UnionType().Fields() { + if children[i], err = GetScalar(arr.Field(i), idx); err != nil { + return nil, err + } + } + return NewSparseUnionScalar(children, typeCode, arr.UnionType().(*arrow.SparseUnionType)), nil + case *array.DenseUnion: + typeCode := arr.TypeCode(idx) + child := arr.Field(arr.ChildID(idx)) + offset := arr.ValueOffset(idx) + value, err := GetScalar(child, int(offset)) + if err != nil { + return nil, err + } + return NewDenseUnionScalar(value, typeCode, arr.UnionType().(*arrow.DenseUnionType)), nil } return nil, fmt.Errorf("cannot create scalar from array of type %s", arr.DataType()) @@ -902,6 +947,12 @@ func Hash(seed maphash.Seed, s Scalar) uint64 { return valueHash(s.Value.Days) & valueHash(s.Value.Milliseconds) case *MonthDayNanoInterval: return valueHash(s.Value.Months) & valueHash(s.Value.Days) & valueHash(s.Value.Nanoseconds) + case *SparseUnion: + // typecode is ignored when comparing for equality, so don't hash it either + out ^= Hash(seed, s.Value[s.ChildID]) + case *DenseUnion: + // typecode is ignored when comparing equality, so don't hash it either + out ^= Hash(seed, s.Value) case PrimitiveScalar: h.Write(s.Data()) hash() diff --git a/go/arrow/scalar/scalar_test.go b/go/arrow/scalar/scalar_test.go index 22f3bee20cb91..7b05cf456887a 100644 --- a/go/arrow/scalar/scalar_test.go +++ b/go/arrow/scalar/scalar_test.go @@ -1143,3 +1143,276 @@ func TestDictionaryScalarValidateErrors(t *testing.T) { assert.Error(t, invalid.ValidateFull()) } } + +func checkGetValidUnionScalar(t *testing.T, arr arrow.Array, idx int, expected, expectedValue scalar.Scalar) { + s, err := scalar.GetScalar(arr, idx) + assert.NoError(t, err) + assert.NoError(t, s.ValidateFull()) + assert.True(t, scalar.Equals(expected, s)) + + assert.True(t, s.IsValid()) + assert.True(t, scalar.Equals(s.(scalar.Union).ChildValue(), expectedValue), s, expectedValue) +} + +func checkGetNullUnionScalar(t *testing.T, arr arrow.Array, idx int) { + s, err := scalar.GetScalar(arr, idx) + assert.NoError(t, err) + assert.True(t, scalar.Equals(scalar.MakeNullScalar(arr.DataType()), s)) + assert.False(t, s.IsValid()) + assert.False(t, s.(scalar.Union).ChildValue().IsValid()) +} + +func makeSparseUnionScalar(ty *arrow.SparseUnionType, val scalar.Scalar, idx int) scalar.Scalar { + return scalar.NewSparseUnionScalarFromValue(val, idx, ty) +} + +func makeDenseUnionScalar(ty *arrow.DenseUnionType, val scalar.Scalar, idx int) scalar.Scalar { + return scalar.NewDenseUnionScalar(val, ty.TypeCodes()[idx], ty) +} + +func makeSpecificNullScalar(dt arrow.UnionType, idx int) scalar.Scalar { + switch dt.Mode() { + case arrow.SparseMode: + values := make([]scalar.Scalar, len(dt.Fields())) + for i, f := range dt.Fields() { + values[i] = scalar.MakeNullScalar(f.Type) + } + return scalar.NewSparseUnionScalar(values, dt.TypeCodes()[idx], dt.(*arrow.SparseUnionType)) + case arrow.DenseMode: + code := dt.TypeCodes()[idx] + value := scalar.MakeNullScalar(dt.Fields()[idx].Type) + return scalar.NewDenseUnionScalar(value, code, dt.(*arrow.DenseUnionType)) + } + return nil +} + +type UnionScalarSuite struct { + suite.Suite + + mode arrow.UnionMode + dt arrow.DataType + unionType arrow.UnionType + alpha, beta, two, three scalar.Scalar + unionAlpha, unionBeta, unionTwo, unionThree scalar.Scalar + unionOtherTwo, unionStringNull, unionNumberNull scalar.Scalar +} + +func (s *UnionScalarSuite) scalarFromValue(idx int, val scalar.Scalar) scalar.Scalar { + switch s.mode { + case arrow.SparseMode: + return makeSparseUnionScalar(s.dt.(*arrow.SparseUnionType), val, idx) + case arrow.DenseMode: + return makeDenseUnionScalar(s.dt.(*arrow.DenseUnionType), val, idx) + } + return nil +} + +func (s *UnionScalarSuite) specificNull(idx int) scalar.Scalar { + return makeSpecificNullScalar(s.unionType, idx) +} + +func (s *UnionScalarSuite) SetupTest() { + s.dt = arrow.UnionOf(s.mode, []arrow.Field{ + {Name: "string", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "number", Type: arrow.PrimitiveTypes.Uint64, Nullable: true}, + {Name: "other_number", Type: arrow.PrimitiveTypes.Uint64, Nullable: true}, + }, []arrow.UnionTypeCode{3, 42, 43}) + + s.unionType = s.dt.(arrow.UnionType) + + s.alpha = scalar.MakeScalar("alpha") + s.beta = scalar.MakeScalar("beta") + s.two = scalar.MakeScalar(uint64(2)) + s.three = scalar.MakeScalar(uint64(3)) + + s.unionAlpha = s.scalarFromValue(0, s.alpha) + s.unionBeta = s.scalarFromValue(0, s.beta) + s.unionTwo = s.scalarFromValue(1, s.two) + s.unionOtherTwo = s.scalarFromValue(2, s.two) + s.unionThree = s.scalarFromValue(1, s.three) + s.unionStringNull = s.specificNull(0) + s.unionNumberNull = s.specificNull(1) +} + +func (s *UnionScalarSuite) TestValidate() { + s.NoError(s.unionAlpha.ValidateFull()) + s.NoError(s.unionAlpha.Validate()) + s.NoError(s.unionBeta.ValidateFull()) + s.NoError(s.unionBeta.Validate()) + s.NoError(s.unionTwo.ValidateFull()) + s.NoError(s.unionTwo.Validate()) + s.NoError(s.unionOtherTwo.ValidateFull()) + s.NoError(s.unionOtherTwo.Validate()) + s.NoError(s.unionThree.ValidateFull()) + s.NoError(s.unionThree.Validate()) + s.NoError(s.unionStringNull.ValidateFull()) + s.NoError(s.unionStringNull.Validate()) + s.NoError(s.unionNumberNull.ValidateFull()) + s.NoError(s.unionNumberNull.Validate()) +} + +func (s *UnionScalarSuite) setTypeCode(sc scalar.Scalar, c arrow.UnionTypeCode) { + switch sc := sc.(type) { + case *scalar.SparseUnion: + sc.TypeCode = c + case *scalar.DenseUnion: + sc.TypeCode = c + } +} + +func (s *UnionScalarSuite) setIsValid(sc scalar.Scalar, v bool) { + switch sc := sc.(type) { + case *scalar.SparseUnion: + sc.Valid = v + case *scalar.DenseUnion: + sc.Valid = v + } +} + +func (s *UnionScalarSuite) TestValidateErrors() { + // type code doesn't exist + sc := s.scalarFromValue(0, s.alpha) + + // invalid type code + s.setTypeCode(sc, 0) + s.Error(sc.Validate()) + s.Error(sc.ValidateFull()) + + s.setIsValid(sc, false) + s.Error(sc.Validate()) + s.Error(sc.ValidateFull()) + + s.setTypeCode(sc, -42) + s.setIsValid(sc, true) + s.Error(sc.Validate()) + s.Error(sc.ValidateFull()) + + s.setIsValid(sc, false) + s.Error(sc.Validate()) + s.Error(sc.ValidateFull()) + + // type code doesn't correspond to child type + if sc, ok := sc.(*scalar.DenseUnion); ok { + sc.TypeCode = 42 + sc.Valid = true + s.Error(sc.Validate()) + s.Error(sc.ValidateFull()) + + sc = s.scalarFromValue(2, s.two).(*scalar.DenseUnion) + sc.TypeCode = 3 + s.Error(sc.Validate()) + s.Error(sc.ValidateFull()) + } + + // underlying value has invalid utf8 + sc = s.scalarFromValue(0, scalar.NewStringScalar("\xff")) + s.NoError(sc.Validate()) + s.Error(sc.ValidateFull()) +} + +func (s *UnionScalarSuite) TestEquals() { + // differing values + s.False(scalar.Equals(s.unionAlpha, s.unionBeta)) + s.False(scalar.Equals(s.unionTwo, s.unionThree)) + // differing validities + s.False(scalar.Equals(s.unionAlpha, s.unionStringNull)) + // differing types + s.False(scalar.Equals(s.unionAlpha, s.unionTwo)) + s.False(scalar.Equals(s.unionAlpha, s.unionOtherTwo)) + // type codes don't count when comparing union scalars: the underlying + // values are identical even though their provenance is different + s.True(scalar.Equals(s.unionTwo, s.unionOtherTwo)) + s.True(scalar.Equals(s.unionStringNull, s.unionNumberNull)) +} + +func (s *UnionScalarSuite) TestMakeNullScalar() { + sc := scalar.MakeNullScalar(s.dt) + s.True(arrow.TypeEqual(s.dt, sc.DataType())) + s.False(sc.IsValid()) + + // the first child field is chosen arbitrarily for the purposes of + // making a null scalar + switch s.mode { + case arrow.DenseMode: + asDense := sc.(*scalar.DenseUnion) + s.EqualValues(3, asDense.TypeCode) + s.False(asDense.Value.IsValid()) + case arrow.SparseMode: + asSparse := sc.(*scalar.SparseUnion) + s.EqualValues(3, asSparse.TypeCode) + s.False(asSparse.Value[asSparse.ChildID].IsValid()) + } +} + +type SparseUnionSuite struct { + UnionScalarSuite +} + +func (s *SparseUnionSuite) SetupSuite() { + s.mode = arrow.SparseMode +} + +func (s *SparseUnionSuite) TestGetScalar() { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(s.T(), 0) + + children := make([]arrow.Array, 3) + children[0], _, _ = array.FromJSON(mem, arrow.BinaryTypes.String, strings.NewReader(`["alpha", "", "beta", null, "gamma"]`)) + defer children[0].Release() + children[1], _, _ = array.FromJSON(mem, arrow.PrimitiveTypes.Uint64, strings.NewReader(`[1, 2, 11, 22, null]`)) + defer children[1].Release() + children[2], _, _ = array.FromJSON(mem, arrow.PrimitiveTypes.Uint64, strings.NewReader(`[100, 101, 102, 103, 104]`)) + defer children[2].Release() + + typeIDs, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int8, strings.NewReader(`[3, 42, 3, 3, 42]`)) + defer typeIDs.Release() + + arr := array.NewSparseUnion(s.dt.(*arrow.SparseUnionType), 5, children, typeIDs.Data().Buffers()[1], 0) + defer arr.Release() + + checkGetValidUnionScalar(s.T(), arr, 0, s.unionAlpha, s.alpha) + checkGetValidUnionScalar(s.T(), arr, 1, s.unionTwo, s.two) + checkGetValidUnionScalar(s.T(), arr, 2, s.unionBeta, s.beta) + checkGetNullUnionScalar(s.T(), arr, 3) + checkGetNullUnionScalar(s.T(), arr, 4) +} + +type DenseUnionSuite struct { + UnionScalarSuite +} + +func (s *DenseUnionSuite) SetupSuite() { + s.mode = arrow.DenseMode +} + +func (s *DenseUnionSuite) TestGetScalar() { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(s.T(), 0) + + children := make([]arrow.Array, 3) + children[0], _, _ = array.FromJSON(mem, arrow.BinaryTypes.String, strings.NewReader(`["alpha", "beta", null]`)) + defer children[0].Release() + children[1], _, _ = array.FromJSON(mem, arrow.PrimitiveTypes.Uint64, strings.NewReader(`[2, 3]`)) + defer children[1].Release() + children[2], _, _ = array.FromJSON(mem, arrow.PrimitiveTypes.Uint64, strings.NewReader(`[]`)) + defer children[2].Release() + + typeIDs, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int8, strings.NewReader(`[3, 42, 3, 3, 42]`)) + defer typeIDs.Release() + offsets, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int32, strings.NewReader(`[0, 0, 1, 2, 1]`)) + defer offsets.Release() + + arr := array.NewDenseUnion(s.dt.(*arrow.DenseUnionType), 5, children, typeIDs.Data().Buffers()[1], offsets.Data().Buffers()[1], 0) + defer arr.Release() + + checkGetValidUnionScalar(s.T(), arr, 0, s.unionAlpha, s.alpha) + checkGetValidUnionScalar(s.T(), arr, 1, s.unionTwo, s.two) + checkGetValidUnionScalar(s.T(), arr, 2, s.unionBeta, s.beta) + checkGetNullUnionScalar(s.T(), arr, 3) + checkGetValidUnionScalar(s.T(), arr, 4, s.unionThree, s.three) +} + +func TestUnionScalars(t *testing.T) { + suite.Run(t, new(SparseUnionSuite)) + suite.Run(t, new(DenseUnionSuite)) +}