Skip to content

Commit

Permalink
fix: Handles null values in data during GO Feature retrieval (feast-d…
Browse files Browse the repository at this point in the history
…ev#4274)

* fix: Handles null values in data during GO Feature retrieval

Signed-off-by: Bhargav Dodla <[email protected]>

* fix: Fixed formatting issues

Signed-off-by: Bhargav Dodla <[email protected]>

* fix: Fixed linting issues

Signed-off-by: Bhargav Dodla <[email protected]>

---------

Signed-off-by: Bhargav Dodla <[email protected]>
Co-authored-by: Bhargav Dodla <[email protected]>
  • Loading branch information
EXPEbdodla and Bhargav Dodla authored Jun 13, 2024
1 parent 2cdaa4a commit c491e57
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 89 deletions.
146 changes: 84 additions & 62 deletions go/types/typeconversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import (
)

func ProtoTypeToArrowType(sample *types.Value) (arrow.DataType, error) {
if sample.Val == nil {
return nil, nil
}
switch sample.Val.(type) {
case *types.Value_BytesVal:
return arrow.BinaryTypes.Binary, nil
Expand Down Expand Up @@ -91,81 +94,71 @@ func ValueTypeEnumToArrowType(t types.ValueType_Enum) (arrow.DataType, error) {
}

func CopyProtoValuesToArrowArray(builder array.Builder, values []*types.Value) error {
switch fieldBuilder := builder.(type) {
case *array.BooleanBuilder:
for _, v := range values {
fieldBuilder.Append(v.GetBoolVal())
}
case *array.BinaryBuilder:
for _, v := range values {
fieldBuilder.Append(v.GetBytesVal())
}
case *array.StringBuilder:
for _, v := range values {
fieldBuilder.Append(v.GetStringVal())
}
case *array.Int32Builder:
for _, v := range values {
fieldBuilder.Append(v.GetInt32Val())
}
case *array.Int64Builder:
for _, v := range values {
fieldBuilder.Append(v.GetInt64Val())
}
case *array.Float32Builder:
for _, v := range values {
fieldBuilder.Append(v.GetFloatVal())
for _, value := range values {
if value == nil || value.Val == nil {
builder.AppendNull()
continue
}
case *array.Float64Builder:
for _, v := range values {
fieldBuilder.Append(v.GetDoubleVal())
}
case *array.TimestampBuilder:
for _, v := range values {
fieldBuilder.Append(arrow.Timestamp(v.GetUnixTimestampVal()))
}
case *array.ListBuilder:
for _, list := range values {

switch fieldBuilder := builder.(type) {

case *array.BooleanBuilder:
fieldBuilder.Append(value.GetBoolVal())
case *array.BinaryBuilder:
fieldBuilder.Append(value.GetBytesVal())
case *array.StringBuilder:
fieldBuilder.Append(value.GetStringVal())
case *array.Int32Builder:
fieldBuilder.Append(value.GetInt32Val())
case *array.Int64Builder:
fieldBuilder.Append(value.GetInt64Val())
case *array.Float32Builder:
fieldBuilder.Append(value.GetFloatVal())
case *array.Float64Builder:
fieldBuilder.Append(value.GetDoubleVal())
case *array.TimestampBuilder:
fieldBuilder.Append(arrow.Timestamp(value.GetUnixTimestampVal()))
case *array.ListBuilder:
fieldBuilder.Append(true)

switch valueBuilder := fieldBuilder.ValueBuilder().(type) {

case *array.BooleanBuilder:
for _, v := range list.GetBoolListVal().GetVal() {
for _, v := range value.GetBoolListVal().GetVal() {
valueBuilder.Append(v)
}
case *array.BinaryBuilder:
for _, v := range list.GetBytesListVal().GetVal() {
for _, v := range value.GetBytesListVal().GetVal() {
valueBuilder.Append(v)
}
case *array.StringBuilder:
for _, v := range list.GetStringListVal().GetVal() {
for _, v := range value.GetStringListVal().GetVal() {
valueBuilder.Append(v)
}
case *array.Int32Builder:
for _, v := range list.GetInt32ListVal().GetVal() {
for _, v := range value.GetInt32ListVal().GetVal() {
valueBuilder.Append(v)
}
case *array.Int64Builder:
for _, v := range list.GetInt64ListVal().GetVal() {
for _, v := range value.GetInt64ListVal().GetVal() {
valueBuilder.Append(v)
}
case *array.Float32Builder:
for _, v := range list.GetFloatListVal().GetVal() {
for _, v := range value.GetFloatListVal().GetVal() {
valueBuilder.Append(v)
}
case *array.Float64Builder:
for _, v := range list.GetDoubleListVal().GetVal() {
for _, v := range value.GetDoubleListVal().GetVal() {
valueBuilder.Append(v)
}
case *array.TimestampBuilder:
for _, v := range list.GetUnixTimestampListVal().GetVal() {
for _, v := range value.GetUnixTimestampListVal().GetVal() {
valueBuilder.Append(arrow.Timestamp(v))
}
}
default:
return fmt.Errorf("unsupported array builder: %s", builder)
}
default:
return fmt.Errorf("unsupported array builder: %s", builder)
}
return nil
}
Expand Down Expand Up @@ -249,41 +242,68 @@ func ArrowValuesToProtoValues(arr arrow.Array) ([]*types.Value, error) {

switch arr.DataType() {
case arrow.PrimitiveTypes.Int32:
for _, v := range arr.(*array.Int32).Int32Values() {
values = append(values, &types.Value{Val: &types.Value_Int32Val{Int32Val: v}})
for idx := 0; idx < arr.Len(); idx++ {
if arr.IsNull(idx) {
values = append(values, &types.Value{})
} else {
values = append(values, &types.Value{Val: &types.Value_Int32Val{Int32Val: arr.(*array.Int32).Value(idx)}})
}
}
case arrow.PrimitiveTypes.Int64:
for _, v := range arr.(*array.Int64).Int64Values() {
values = append(values, &types.Value{Val: &types.Value_Int64Val{Int64Val: v}})
for idx := 0; idx < arr.Len(); idx++ {
if arr.IsNull(idx) {
values = append(values, &types.Value{})
} else {
values = append(values, &types.Value{Val: &types.Value_Int64Val{Int64Val: arr.(*array.Int64).Value(idx)}})
}
}
case arrow.PrimitiveTypes.Float32:
for _, v := range arr.(*array.Float32).Float32Values() {
values = append(values, &types.Value{Val: &types.Value_FloatVal{FloatVal: v}})
for idx := 0; idx < arr.Len(); idx++ {
if arr.IsNull(idx) {
values = append(values, &types.Value{})
} else {
values = append(values, &types.Value{Val: &types.Value_FloatVal{FloatVal: arr.(*array.Float32).Value(idx)}})
}
}
case arrow.PrimitiveTypes.Float64:
for _, v := range arr.(*array.Float64).Float64Values() {
values = append(values, &types.Value{Val: &types.Value_DoubleVal{DoubleVal: v}})
for idx := 0; idx < arr.Len(); idx++ {
if arr.IsNull(idx) {
values = append(values, &types.Value{})
} else {
values = append(values, &types.Value{Val: &types.Value_DoubleVal{DoubleVal: arr.(*array.Float64).Value(idx)}})
}
}
case arrow.FixedWidthTypes.Boolean:
for idx := 0; idx < arr.Len(); idx++ {
values = append(values,
&types.Value{Val: &types.Value_BoolVal{BoolVal: arr.(*array.Boolean).Value(idx)}})
if arr.IsNull(idx) {
values = append(values, &types.Value{})
} else {
values = append(values, &types.Value{Val: &types.Value_BoolVal{BoolVal: arr.(*array.Boolean).Value(idx)}})
}
}
case arrow.BinaryTypes.Binary:
for idx := 0; idx < arr.Len(); idx++ {
values = append(values,
&types.Value{Val: &types.Value_BytesVal{BytesVal: arr.(*array.Binary).Value(idx)}})
if arr.IsNull(idx) {
values = append(values, &types.Value{})
} else {
values = append(values, &types.Value{Val: &types.Value_BytesVal{BytesVal: arr.(*array.Binary).Value(idx)}})
}
}
case arrow.BinaryTypes.String:
for idx := 0; idx < arr.Len(); idx++ {
values = append(values,
&types.Value{Val: &types.Value_StringVal{StringVal: arr.(*array.String).Value(idx)}})
if arr.IsNull(idx) {
values = append(values, &types.Value{})
} else {
values = append(values, &types.Value{Val: &types.Value_StringVal{StringVal: arr.(*array.String).Value(idx)}})
}
}
case arrow.FixedWidthTypes.Timestamp_s:
for idx := 0; idx < arr.Len(); idx++ {
values = append(values,
&types.Value{Val: &types.Value_UnixTimestampVal{
UnixTimestampVal: int64(arr.(*array.Timestamp).Value(idx))}})
if arr.IsNull(idx) {
values = append(values, &types.Value{})
} else {
values = append(values, &types.Value{Val: &types.Value_UnixTimestampVal{UnixTimestampVal: int64(arr.(*array.Timestamp).Value(idx))}})
}
}
case arrow.Null:
for idx := 0; idx < arr.Len(); idx++ {
Expand All @@ -306,7 +326,9 @@ func ProtoValuesToArrowArray(protoValues []*types.Value, arrowAllocator memory.A
if err != nil {
return nil, err
}
break
if fieldType != nil {
break
}
}
}

Expand Down
30 changes: 27 additions & 3 deletions go/types/typeconversion_test.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,46 @@
package types

import (
"math"
"testing"
"time"

"github.com/apache/arrow/go/v8/arrow/memory"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"

"github.com/feast-dev/feast/go/protos/feast/types"
)

var nil_or_null_val = &types.Value{}

var (
PROTO_VALUES = [][]*types.Value{
{{Val: nil}},
{{Val: nil}, {Val: nil}},
{nil_or_null_val, nil_or_null_val},
{nil_or_null_val, {Val: nil}},
{{Val: &types.Value_Int32Val{10}}, {Val: nil}, nil_or_null_val, {Val: &types.Value_Int32Val{20}}},
{{Val: &types.Value_Int32Val{10}}, nil_or_null_val},
{nil_or_null_val, {Val: &types.Value_Int32Val{20}}},
{{Val: &types.Value_Int32Val{10}}, {Val: &types.Value_Int32Val{20}}},
{{Val: &types.Value_Int64Val{10}}, nil_or_null_val},
{{Val: &types.Value_Int64Val{10}}, {Val: &types.Value_Int64Val{20}}},
{nil_or_null_val, {Val: &types.Value_FloatVal{2.0}}},
{{Val: &types.Value_FloatVal{1.0}}, {Val: &types.Value_FloatVal{2.0}}},
{{Val: &types.Value_FloatVal{1.0}}, {Val: &types.Value_FloatVal{2.0}}, {Val: &types.Value_FloatVal{float32(math.NaN())}}},
{{Val: &types.Value_DoubleVal{1.0}}, {Val: &types.Value_DoubleVal{2.0}}},
{{Val: &types.Value_DoubleVal{1.0}}, {Val: &types.Value_DoubleVal{2.0}}, {Val: &types.Value_DoubleVal{math.NaN()}}},
{{Val: &types.Value_DoubleVal{1.0}}, nil_or_null_val},
{nil_or_null_val, {Val: &types.Value_StringVal{"bbb"}}},
{{Val: &types.Value_StringVal{"aaa"}}, {Val: &types.Value_StringVal{"bbb"}}},
{{Val: &types.Value_BytesVal{[]byte{1, 2, 3}}}, nil_or_null_val},
{{Val: &types.Value_BytesVal{[]byte{1, 2, 3}}}, {Val: &types.Value_BytesVal{[]byte{4, 5, 6}}}},
{nil_or_null_val, {Val: &types.Value_BoolVal{false}}},
{{Val: &types.Value_BoolVal{true}}, {Val: &types.Value_BoolVal{false}}},
{{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}},
{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}},
{{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}, nil_or_null_val},
{{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}, {Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}},
{{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}, {Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}, {Val: &types.Value_UnixTimestampVal{-9223372036854775808}}},

{
{Val: &types.Value_Int32ListVal{&types.Int32List{Val: []int32{0, 1, 2}}}},
Expand Down Expand Up @@ -55,6 +74,11 @@ var (
{Val: &types.Value_UnixTimestampListVal{&types.Int64List{Val: []int64{time.Now().Unix()}}}},
{Val: &types.Value_UnixTimestampListVal{&types.Int64List{Val: []int64{time.Now().Unix()}}}},
},
{
{Val: &types.Value_UnixTimestampListVal{&types.Int64List{Val: []int64{time.Now().Unix(), time.Now().Unix()}}}},
{Val: &types.Value_UnixTimestampListVal{&types.Int64List{Val: []int64{time.Now().Unix(), time.Now().Unix()}}}},
{Val: &types.Value_UnixTimestampListVal{&types.Int64List{Val: []int64{-9223372036854775808, time.Now().Unix()}}}},
},
}
)

Expand Down
Loading

0 comments on commit c491e57

Please sign in to comment.