Skip to content

Commit

Permalink
ARROW-17390: [Go] Add union scalar types (apache#13860)
Browse files Browse the repository at this point in the history
Authored-by: Matt Topol <[email protected]>
Signed-off-by: Matt Topol <[email protected]>
  • Loading branch information
zeroshade authored and ksuarez1423 committed Aug 15, 2022
1 parent a5ea77a commit 15738c6
Show file tree
Hide file tree
Showing 3 changed files with 531 additions and 15 deletions.
192 changes: 192 additions & 0 deletions go/arrow/scalar/nested.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,3 +520,195 @@ func (s *Dictionary) GetEncodedValue() (Scalar, error) {
func (s *Dictionary) value() interface{} {
return s.Value.Index.value()
}

type Union interface {
Scalar
ChildValue() Scalar
Release()
}

type SparseUnion struct {
scalar

TypeCode arrow.UnionTypeCode
Value []Scalar
ChildID int
}

func (s *SparseUnion) equals(rhs Scalar) bool {
right := rhs.(*SparseUnion)
return Equals(s.ChildValue(), right.ChildValue())
}

func (s *SparseUnion) value() interface{} { return s.ChildValue() }

func (s *SparseUnion) String() string {
dt := s.Type.(*arrow.SparseUnionType)
val := s.ChildValue()
return "union{" + dt.Fields()[dt.ChildIDs()[s.TypeCode]].String() + " = " + val.String() + "}"
}

func (s *SparseUnion) Release() {
for _, v := range s.Value {
if v, ok := v.(Releasable); ok {
v.Release()
}
}
}

func (s *SparseUnion) Validate() (err error) {
dt := s.Type.(*arrow.SparseUnionType)
if len(dt.Fields()) != len(s.Value) {
return fmt.Errorf("sparse union scalar value had %d fields but type has %d fields", len(dt.Fields()), len(s.Value))
}

if s.TypeCode < 0 || int(s.TypeCode) >= len(dt.ChildIDs()) || dt.ChildIDs()[s.TypeCode] == arrow.InvalidUnionChildID {
return fmt.Errorf("%s scalar has invalid type code %d", dt, s.TypeCode)
}

for i, f := range dt.Fields() {
v := s.Value[i]
if !arrow.TypeEqual(f.Type, v.DataType()) {
return fmt.Errorf("%s value for field %s had incorrect type of %s", dt, f, v.DataType())
}
if err = v.Validate(); err != nil {
return err
}
}
return
}

func (s *SparseUnion) ValidateFull() (err error) {
dt := s.Type.(*arrow.SparseUnionType)
if len(dt.Fields()) != len(s.Value) {
return fmt.Errorf("sparse union scalar value had %d fields but type has %d fields", len(dt.Fields()), len(s.Value))
}

if s.TypeCode < 0 || int(s.TypeCode) >= len(dt.ChildIDs()) || dt.ChildIDs()[s.TypeCode] == arrow.InvalidUnionChildID {
return fmt.Errorf("%s scalar has invalid type code %d", dt, s.TypeCode)
}

for i, f := range dt.Fields() {
v := s.Value[i]
if !arrow.TypeEqual(f.Type, v.DataType()) {
return fmt.Errorf("%s value for field %s had incorrect type of %s", dt, f, v.DataType())
}
if err = v.ValidateFull(); err != nil {
return err
}
}
return
}

func (s *SparseUnion) CastTo(to arrow.DataType) (Scalar, error) {
if !s.Valid {
return MakeNullScalar(to), nil
}

switch to.ID() {
case arrow.STRING:
return NewStringScalar(s.String()), nil
case arrow.LARGE_STRING:
return NewLargeStringScalar(s.String()), nil
}

return nil, fmt.Errorf("cannot cast non-nil union to type other than string")
}

func (s *SparseUnion) ChildValue() Scalar { return s.Value[s.ChildID] }

func NewSparseUnionScalar(val []Scalar, code arrow.UnionTypeCode, dt *arrow.SparseUnionType) *SparseUnion {
ret := &SparseUnion{
scalar: scalar{dt, true},
TypeCode: code,
Value: val,
ChildID: dt.ChildIDs()[code],
}
ret.Valid = ret.Value[ret.ChildID].IsValid()
return ret
}

func NewSparseUnionScalarFromValue(val Scalar, idx int, dt *arrow.SparseUnionType) *SparseUnion {
code := dt.TypeCodes()[idx]
values := make([]Scalar, len(dt.Fields()))
for i, f := range dt.Fields() {
if i == idx {
values[i] = val
} else {
values[i] = MakeNullScalar(f.Type)
}
}
return NewSparseUnionScalar(values, code, dt)
}

type DenseUnion struct {
scalar

TypeCode arrow.UnionTypeCode
Value Scalar
}

func (s *DenseUnion) equals(rhs Scalar) bool {
right := rhs.(*DenseUnion)
return Equals(s.Value, right.Value)
}

func (s *DenseUnion) value() interface{} { return s.ChildValue() }

func (s *DenseUnion) String() string {
dt := s.Type.(*arrow.DenseUnionType)
return "union{" + dt.Fields()[dt.ChildIDs()[s.TypeCode]].String() + " = " + s.Value.String() + "}"
}

func (s *DenseUnion) Release() {
if v, ok := s.Value.(Releasable); ok {
v.Release()
}
}

func (s *DenseUnion) Validate() (err error) {
dt := s.Type.(*arrow.DenseUnionType)
if s.TypeCode < 0 || int(s.TypeCode) >= len(dt.ChildIDs()) || dt.ChildIDs()[s.TypeCode] == arrow.InvalidUnionChildID {
return fmt.Errorf("%s scalar has invalid type code %d", dt, s.TypeCode)
}
fieldType := dt.Fields()[dt.ChildIDs()[s.TypeCode]].Type
if !arrow.TypeEqual(fieldType, s.Value.DataType()) {
return fmt.Errorf("%s scalar with type code %d should have an underlying value of type %s, got %s",
s.Type, s.TypeCode, fieldType, s.Value.DataType())
}
return s.Value.Validate()
}

func (s *DenseUnion) ValidateFull() error {
dt := s.Type.(*arrow.DenseUnionType)
if s.TypeCode < 0 || int(s.TypeCode) >= len(dt.ChildIDs()) || dt.ChildIDs()[s.TypeCode] == arrow.InvalidUnionChildID {
return fmt.Errorf("%s scalar has invalid type code %d", dt, s.TypeCode)
}
fieldType := dt.Fields()[dt.ChildIDs()[s.TypeCode]].Type
if !arrow.TypeEqual(fieldType, s.Value.DataType()) {
return fmt.Errorf("%s scalar with type code %d should have an underlying value of type %s, got %s",
s.Type, s.TypeCode, fieldType, s.Value.DataType())
}
return s.Value.ValidateFull()
}

func (s *DenseUnion) CastTo(to arrow.DataType) (Scalar, error) {
if !s.Valid {
return MakeNullScalar(to), nil
}

switch to.ID() {
case arrow.STRING:
return NewStringScalar(s.String()), nil
case arrow.LARGE_STRING:
return NewLargeStringScalar(s.String()), nil
}

return nil, fmt.Errorf("cannot cast non-nil union to type other than string")
}

func (s *DenseUnion) ChildValue() Scalar { return s.Value }

func NewDenseUnionScalar(v Scalar, code arrow.UnionTypeCode, dt *arrow.DenseUnionType) *DenseUnion {
return &DenseUnion{scalar: scalar{dt, v.IsValid()}, TypeCode: code, Value: v}
}
81 changes: 66 additions & 15 deletions go/arrow/scalar/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,6 @@ func MakeNullScalar(dt arrow.DataType) Scalar {
return makeNullFn[byte(dt.ID()&0x3f)](dt)
}

func unsupportedScalarType(dt arrow.DataType) Scalar {
panic("unsupported scalar data type: " + dt.ID().String())
}

func invalidScalarType(dt arrow.DataType) Scalar {
panic("invalid scalar type: " + dt.ID().String())
}
Expand Down Expand Up @@ -516,17 +512,33 @@ func init() {
arrow.DECIMAL128: func(dt arrow.DataType) Scalar { return &Decimal128{scalar: scalar{dt, false}} },
arrow.LIST: func(dt arrow.DataType) Scalar { return &List{scalar: scalar{dt, false}} },
arrow.STRUCT: func(dt arrow.DataType) Scalar { return &Struct{scalar: scalar{dt, false}} },
arrow.SPARSE_UNION: unsupportedScalarType,
arrow.DENSE_UNION: unsupportedScalarType,
arrow.DICTIONARY: func(dt arrow.DataType) Scalar { return NewNullDictScalar(dt) },
arrow.LARGE_STRING: func(dt arrow.DataType) Scalar { return &LargeString{&String{&Binary{scalar: scalar{dt, false}}}} },
arrow.LARGE_BINARY: func(dt arrow.DataType) Scalar { return &LargeBinary{&Binary{scalar: scalar{dt, false}}} },
arrow.LARGE_LIST: func(dt arrow.DataType) Scalar { return &LargeList{&List{scalar: scalar{dt, false}}} },
arrow.DECIMAL256: func(dt arrow.DataType) Scalar { return &Decimal256{scalar: scalar{dt, false}} },
arrow.MAP: func(dt arrow.DataType) Scalar { return &Map{&List{scalar: scalar{dt, false}}} },
arrow.EXTENSION: func(dt arrow.DataType) Scalar { return &Extension{scalar: scalar{dt, false}} },
arrow.FIXED_SIZE_LIST: func(dt arrow.DataType) Scalar { return &FixedSizeList{&List{scalar: scalar{dt, false}}} },
arrow.DURATION: func(dt arrow.DataType) Scalar { return &Duration{scalar: scalar{dt, false}} },
arrow.SPARSE_UNION: func(dt arrow.DataType) Scalar {
typ := dt.(*arrow.SparseUnionType)
if len(typ.Fields()) == 0 {
panic("cannot make scalar of empty union type")
}
values := make([]Scalar, len(typ.Fields()))
for i, f := range typ.Fields() {
values[i] = MakeNullScalar(f.Type)
}
return NewSparseUnionScalar(values, typ.TypeCodes()[0], typ)
},
arrow.DENSE_UNION: func(dt arrow.DataType) Scalar {
typ := dt.(*arrow.DenseUnionType)
if len(typ.Fields()) == 0 {
panic("cannot make scalar of empty union type")
}
return NewDenseUnionScalar(MakeNullScalar(typ.Fields()[0].Type), typ.TypeCodes()[0], typ)
},
arrow.DICTIONARY: func(dt arrow.DataType) Scalar { return NewNullDictScalar(dt) },
arrow.LARGE_STRING: func(dt arrow.DataType) Scalar { return &LargeString{&String{&Binary{scalar: scalar{dt, false}}}} },
arrow.LARGE_BINARY: func(dt arrow.DataType) Scalar { return &LargeBinary{&Binary{scalar: scalar{dt, false}}} },
arrow.LARGE_LIST: func(dt arrow.DataType) Scalar { return &LargeList{&List{scalar: scalar{dt, false}}} },
arrow.DECIMAL256: func(dt arrow.DataType) Scalar { return &Decimal256{scalar: scalar{dt, false}} },
arrow.MAP: func(dt arrow.DataType) Scalar { return &Map{&List{scalar: scalar{dt, false}}} },
arrow.EXTENSION: func(dt arrow.DataType) Scalar { return &Extension{scalar: scalar{dt, false}} },
arrow.FIXED_SIZE_LIST: func(dt arrow.DataType) Scalar { return &FixedSizeList{&List{scalar: scalar{dt, false}}} },
arrow.DURATION: func(dt arrow.DataType) Scalar { return &Duration{scalar: scalar{dt, false}} },
// invalid data types to fill out array size 2^6 - 1
63: invalidScalarType,
}
Expand Down Expand Up @@ -646,6 +658,39 @@ func GetScalar(arr arrow.Array, idx int) (Scalar, error) {
scalar.Value.Dict = arr.Dictionary()
scalar.Value.Dict.Retain()
return scalar, nil
case *array.SparseUnion:
var err error
typeCode := arr.TypeCode(idx)
children := make([]Scalar, arr.NumFields())
defer func() {
if err != nil {
for _, c := range children {
if c == nil {
break
}

if v, ok := c.(Releasable); ok {
v.Release()
}
}
}
}()

for i := range arr.UnionType().Fields() {
if children[i], err = GetScalar(arr.Field(i), idx); err != nil {
return nil, err
}
}
return NewSparseUnionScalar(children, typeCode, arr.UnionType().(*arrow.SparseUnionType)), nil
case *array.DenseUnion:
typeCode := arr.TypeCode(idx)
child := arr.Field(arr.ChildID(idx))
offset := arr.ValueOffset(idx)
value, err := GetScalar(child, int(offset))
if err != nil {
return nil, err
}
return NewDenseUnionScalar(value, typeCode, arr.UnionType().(*arrow.DenseUnionType)), nil
}

return nil, fmt.Errorf("cannot create scalar from array of type %s", arr.DataType())
Expand Down Expand Up @@ -902,6 +947,12 @@ func Hash(seed maphash.Seed, s Scalar) uint64 {
return valueHash(s.Value.Days) & valueHash(s.Value.Milliseconds)
case *MonthDayNanoInterval:
return valueHash(s.Value.Months) & valueHash(s.Value.Days) & valueHash(s.Value.Nanoseconds)
case *SparseUnion:
// typecode is ignored when comparing for equality, so don't hash it either
out ^= Hash(seed, s.Value[s.ChildID])
case *DenseUnion:
// typecode is ignored when comparing equality, so don't hash it either
out ^= Hash(seed, s.Value)
case PrimitiveScalar:
h.Write(s.Data())
hash()
Expand Down
Loading

0 comments on commit 15738c6

Please sign in to comment.