From c4396118a99392a7b284b635b80486c94e0c86c8 Mon Sep 17 00:00:00 2001 From: Oleksii Moskalenko Date: Thu, 21 Apr 2022 21:07:49 +0300 Subject: [PATCH] fix: Use timestamp type when converting unixtimestamp feature type to arrow (#2593) * use timestamp Signed-off-by: pyalex * add timezone to type definition Signed-off-by: pyalex --- go/types/typeconversion.go | 24 ++--- .../embedded_go/online_features_service.py | 80 ++--------------- sdk/python/feast/embedded_go/type_map.py | 88 +++++++++++++++++++ 3 files changed, 109 insertions(+), 83 deletions(-) create mode 100644 sdk/python/feast/embedded_go/type_map.py diff --git a/go/types/typeconversion.go b/go/types/typeconversion.go index f42356b454..416eb2ac27 100644 --- a/go/types/typeconversion.go +++ b/go/types/typeconversion.go @@ -40,9 +40,9 @@ func ProtoTypeToArrowType(sample *types.Value) (arrow.DataType, error) { case *types.Value_DoubleListVal: return arrow.ListOf(arrow.PrimitiveTypes.Float64), nil case *types.Value_UnixTimestampVal: - return arrow.FixedWidthTypes.Time32s, nil + return arrow.FixedWidthTypes.Timestamp_s, nil case *types.Value_UnixTimestampListVal: - return arrow.ListOf(arrow.FixedWidthTypes.Time32s), nil + return arrow.ListOf(arrow.FixedWidthTypes.Timestamp_s), nil default: return nil, fmt.Errorf("unsupported proto type in proto to arrow conversion: %s", sample.Val) @@ -80,9 +80,9 @@ func ValueTypeEnumToArrowType(t types.ValueType_Enum) (arrow.DataType, error) { case types.ValueType_DOUBLE_LIST: return arrow.ListOf(arrow.PrimitiveTypes.Float64), nil case types.ValueType_UNIX_TIMESTAMP: - return arrow.FixedWidthTypes.Time32s, nil + return arrow.FixedWidthTypes.Timestamp_s, nil case types.ValueType_UNIX_TIMESTAMP_LIST: - return arrow.ListOf(arrow.FixedWidthTypes.Time32s), nil + return arrow.ListOf(arrow.FixedWidthTypes.Timestamp_s), nil default: return nil, fmt.Errorf("unsupported value type enum in enum to arrow type conversion: %s", t) @@ -119,9 +119,9 @@ func copyProtoValuesToArrowArray(builder array.Builder, values []*types.Value) e for _, v := range values { fieldBuilder.Append(v.GetDoubleVal()) } - case *array.Time32Builder: + case *array.TimestampBuilder: for _, v := range values { - fieldBuilder.Append(arrow.Time32(v.GetUnixTimestampVal())) + fieldBuilder.Append(arrow.Timestamp(v.GetUnixTimestampVal())) } case *array.ListBuilder: for _, list := range values { @@ -157,9 +157,9 @@ func copyProtoValuesToArrowArray(builder array.Builder, values []*types.Value) e for _, v := range list.GetDoubleListVal().GetVal() { valueBuilder.Append(v) } - case *array.Time32Builder: + case *array.TimestampBuilder: for _, v := range list.GetUnixTimestampListVal().GetVal() { - valueBuilder.Append(arrow.Time32(v)) + valueBuilder.Append(arrow.Timestamp(v)) } } } @@ -227,10 +227,10 @@ func ArrowValuesToProtoValues(arr arrow.Array) ([]*types.Value, error) { } values = append(values, &types.Value{Val: &types.Value_BoolListVal{BoolListVal: &types.BoolList{Val: vals}}}) - case arrow.FixedWidthTypes.Time32s: + case arrow.FixedWidthTypes.Timestamp_s: vals := make([]int64, int(offsets[idx])-pos) for j := pos; j < int(offsets[idx]); j++ { - vals[j-pos] = int64(listValues.(*array.Time32).Value(j)) + vals[j-pos] = int64(listValues.(*array.Timestamp).Value(j)) } values = append(values, @@ -278,11 +278,11 @@ func ArrowValuesToProtoValues(arr arrow.Array) ([]*types.Value, error) { values = append(values, &types.Value{Val: &types.Value_StringVal{StringVal: arr.(*array.String).Value(idx)}}) } - case arrow.FixedWidthTypes.Time32s: + case arrow.FixedWidthTypes.Timestamp_s: for idx := 0; idx < arr.Len(); idx++ { values = append(values, &types.Value{Val: &types.Value_UnixTimestampVal{ - UnixTimestampVal: int64(arr.(*array.Time32).Value(idx))}}) + UnixTimestampVal: int64(arr.(*array.Timestamp).Value(idx))}}) } default: return nil, fmt.Errorf("unsupported arrow to proto conversion for type %s", arr.DataType()) diff --git a/sdk/python/feast/embedded_go/online_features_service.py b/sdk/python/feast/embedded_go/online_features_service.py index a007cf8272..410af1d8fe 100644 --- a/sdk/python/feast/embedded_go/online_features_service.py +++ b/sdk/python/feast/embedded_go/online_features_service.py @@ -14,59 +14,17 @@ from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesResponse from feast.protos.feast.types import Value_pb2 from feast.repo_config import RepoConfig +from feast.types import from_value_type from feast.value_type import ValueType from .lib.embedded import DataTable, NewOnlineFeatureService, OnlineFeatureServiceConfig from .lib.go import Slice_string +from .type_map import FEAST_TYPE_TO_ARROW_TYPE, arrow_array_to_array_of_proto if TYPE_CHECKING: from feast.feature_store import FeatureStore -ARROW_TYPE_TO_PROTO_FIELD = { - pa.int32(): "int32_val", - pa.int64(): "int64_val", - pa.float32(): "float_val", - pa.float64(): "double_val", - pa.bool_(): "bool_val", - pa.string(): "string_val", - pa.binary(): "bytes_val", - pa.time32("s"): "unix_timestamp_val", -} - -ARROW_LIST_TYPE_TO_PROTO_FIELD = { - pa.int32(): "int32_list_val", - pa.int64(): "int64_list_val", - pa.float32(): "float_list_val", - pa.float64(): "double_list_val", - pa.bool_(): "bool_list_val", - pa.string(): "string_list_val", - pa.binary(): "bytes_list_val", - pa.time32("s"): "unix_timestamp_list_val", -} - -ARROW_LIST_TYPE_TO_PROTO_LIST_CLASS = { - pa.int32(): Value_pb2.Int32List, - pa.int64(): Value_pb2.Int64List, - pa.float32(): Value_pb2.FloatList, - pa.float64(): Value_pb2.DoubleList, - pa.bool_(): Value_pb2.BoolList, - pa.string(): Value_pb2.StringList, - pa.binary(): Value_pb2.BytesList, - pa.time32("s"): Value_pb2.Int64List, -} - -# used for entity types only -PROTO_TYPE_TO_ARROW_TYPE = { - ValueType.INT32: pa.int32(), - ValueType.INT64: pa.int64(), - ValueType.FLOAT: pa.float32(), - ValueType.DOUBLE: pa.float64(), - ValueType.STRING: pa.string(), - ValueType.BYTES: pa.binary(), -} - - class EmbeddedOnlineFeatureServer: def __init__( self, repo_path: str, repo_config: RepoConfig, feature_store: "FeatureStore" @@ -179,8 +137,10 @@ def _to_arrow(value, type_hint: Optional[ValueType]) -> pa.Array: if isinstance(value, Value_pb2.RepeatedValue): _proto_to_arrow(value) - if type_hint in PROTO_TYPE_TO_ARROW_TYPE: - return pa.array(value, PROTO_TYPE_TO_ARROW_TYPE[type_hint]) + if type_hint: + feast_type = from_value_type(type_hint) + if feast_type in FEAST_TYPE_TO_ARROW_TYPE: + return pa.array(value, FEAST_TYPE_TO_ARROW_TYPE[feast_type]) return pa.array(value) @@ -263,31 +223,9 @@ def record_batch_to_online_response(record_batch): [Value_pb2.Value()] * len(record_batch.columns[idx]) ) else: - if isinstance(field.type, pa.ListType): - proto_list_class = ARROW_LIST_TYPE_TO_PROTO_LIST_CLASS[ - field.type.value_type - ] - proto_field_name = ARROW_LIST_TYPE_TO_PROTO_FIELD[field.type.value_type] - - column = record_batch.columns[idx] - if field.type.value_type == pa.time32("s"): - column = column.cast(pa.list_(pa.int32())) - - for v in column.tolist(): - feature_vector.values.append( - Value_pb2.Value(**{proto_field_name: proto_list_class(val=v)}) - ) - else: - proto_field_name = ARROW_TYPE_TO_PROTO_FIELD[field.type] - - column = record_batch.columns[idx] - if field.type == pa.time32("s"): - column = column.cast(pa.int32()) - - for v in column.tolist(): - feature_vector.values.append( - Value_pb2.Value(**{proto_field_name: v}) - ) + feature_vector.values.extend( + arrow_array_to_array_of_proto(field.type, record_batch.columns[idx]) + ) resp.results.append(feature_vector) resp.metadata.feature_names.val.append(field.name) diff --git a/sdk/python/feast/embedded_go/type_map.py b/sdk/python/feast/embedded_go/type_map.py new file mode 100644 index 0000000000..e70dc3be86 --- /dev/null +++ b/sdk/python/feast/embedded_go/type_map.py @@ -0,0 +1,88 @@ +from typing import List + +import pyarrow as pa +import pytz + +from feast.protos.feast.types import Value_pb2 +from feast.types import Array, PrimitiveFeastType + +PA_TIMESTAMP_TYPE = pa.timestamp("s", tz=pytz.UTC) + +ARROW_TYPE_TO_PROTO_FIELD = { + pa.int32(): "int32_val", + pa.int64(): "int64_val", + pa.float32(): "float_val", + pa.float64(): "double_val", + pa.bool_(): "bool_val", + pa.string(): "string_val", + pa.binary(): "bytes_val", + PA_TIMESTAMP_TYPE: "unix_timestamp_val", +} + +ARROW_LIST_TYPE_TO_PROTO_FIELD = { + pa.int32(): "int32_list_val", + pa.int64(): "int64_list_val", + pa.float32(): "float_list_val", + pa.float64(): "double_list_val", + pa.bool_(): "bool_list_val", + pa.string(): "string_list_val", + pa.binary(): "bytes_list_val", + PA_TIMESTAMP_TYPE: "unix_timestamp_list_val", +} + +ARROW_LIST_TYPE_TO_PROTO_LIST_CLASS = { + pa.int32(): Value_pb2.Int32List, + pa.int64(): Value_pb2.Int64List, + pa.float32(): Value_pb2.FloatList, + pa.float64(): Value_pb2.DoubleList, + pa.bool_(): Value_pb2.BoolList, + pa.string(): Value_pb2.StringList, + pa.binary(): Value_pb2.BytesList, + PA_TIMESTAMP_TYPE: Value_pb2.Int64List, +} + +FEAST_TYPE_TO_ARROW_TYPE = { + PrimitiveFeastType.INT32: pa.int32(), + PrimitiveFeastType.INT64: pa.int64(), + PrimitiveFeastType.FLOAT32: pa.float32(), + PrimitiveFeastType.FLOAT64: pa.float64(), + PrimitiveFeastType.STRING: pa.string(), + PrimitiveFeastType.BYTES: pa.binary(), + PrimitiveFeastType.BOOL: pa.bool_(), + PrimitiveFeastType.UNIX_TIMESTAMP: pa.timestamp("s"), + Array(PrimitiveFeastType.INT32): pa.list_(pa.int32()), + Array(PrimitiveFeastType.INT64): pa.list_(pa.int64()), + Array(PrimitiveFeastType.FLOAT32): pa.list_(pa.float32()), + Array(PrimitiveFeastType.FLOAT64): pa.list_(pa.float64()), + Array(PrimitiveFeastType.STRING): pa.list_(pa.string()), + Array(PrimitiveFeastType.BYTES): pa.list_(pa.binary()), + Array(PrimitiveFeastType.BOOL): pa.list_(pa.bool_()), + Array(PrimitiveFeastType.UNIX_TIMESTAMP): pa.list_(pa.timestamp("s")), +} + + +def arrow_array_to_array_of_proto( + arrow_type: pa.DataType, arrow_array: pa.Array +) -> List[Value_pb2.Value]: + values = [] + if isinstance(arrow_type, pa.ListType): + proto_list_class = ARROW_LIST_TYPE_TO_PROTO_LIST_CLASS[arrow_type.value_type] + proto_field_name = ARROW_LIST_TYPE_TO_PROTO_FIELD[arrow_type.value_type] + + if arrow_type.value_type == PA_TIMESTAMP_TYPE: + arrow_array = arrow_array.cast(pa.list_(pa.int64())) + + for v in arrow_array.tolist(): + values.append( + Value_pb2.Value(**{proto_field_name: proto_list_class(val=v)}) + ) + else: + proto_field_name = ARROW_TYPE_TO_PROTO_FIELD[arrow_type] + + if arrow_type == PA_TIMESTAMP_TYPE: + arrow_array = arrow_array.cast(pa.int64()) + + for v in arrow_array.tolist(): + values.append(Value_pb2.Value(**{proto_field_name: v})) + + return values