diff --git a/sdk/python/tests/data/data_creator.py b/sdk/python/tests/data/data_creator.py index e5355b40bb..752eae37c1 100644 --- a/sdk/python/tests/data/data_creator.py +++ b/sdk/python/tests/data/data_creator.py @@ -60,6 +60,13 @@ def get_feature_values_for_dtype( "float": [1.0, None, 3.0, 4.0, 5.0], "string": ["1", None, "3", "4", "5"], "bool": [True, None, False, True, False], + "datetime": [ + datetime(1980, 1, 1), + None, + datetime(1981, 1, 1), + datetime(1982, 1, 1), + datetime(1982, 1, 1), + ], } non_list_val = dtype_map[dtype] if is_list: diff --git a/sdk/python/tests/integration/registration/test_universal_types.py b/sdk/python/tests/integration/registration/test_universal_types.py index c007d56c35..54a7de37ca 100644 --- a/sdk/python/tests/integration/registration/test_universal_types.py +++ b/sdk/python/tests/integration/registration/test_universal_types.py @@ -1,4 +1,5 @@ import logging +import re from dataclasses import dataclass from datetime import datetime, timedelta from typing import List @@ -28,6 +29,7 @@ def populate_test_configs(offline: bool): (ValueType.INT64, "int64"), (ValueType.STRING, "float"), (ValueType.STRING, "bool"), + (ValueType.INT32, "datetime"), ] configs: List[TypeTestConfig] = [] for test_repo_config in FULL_REPO_CONFIGS: @@ -232,6 +234,7 @@ def test_feature_get_online_features_types_match(online_types_test_fixtures): "float": float, "string": str, "bool": bool, + "datetime": int, } expected_dtype = feature_list_dtype_to_expected_online_response_value_type[ config.feature_dtype @@ -258,6 +261,8 @@ def create_feature_view( value_type = ValueType.FLOAT_LIST elif feature_dtype == "bool": value_type = ValueType.BOOL_LIST + elif feature_dtype == "datetime": + value_type = ValueType.UNIX_TIMESTAMP_LIST else: if feature_dtype == "int32": value_type = ValueType.INT32 @@ -267,6 +272,8 @@ def create_feature_view( value_type = ValueType.FLOAT elif feature_dtype == "bool": value_type = ValueType.BOOL + elif feature_dtype == "datetime": + value_type = ValueType.UNIX_TIMESTAMP return driver_feature_view(data_source, name=name, value_type=value_type,) @@ -281,6 +288,7 @@ def assert_expected_historical_feature_types( "float": (pd.api.types.is_float_dtype,), "string": (pd.api.types.is_string_dtype,), "bool": (pd.api.types.is_bool_dtype, pd.api.types.is_object_dtype), + "datetime": (pd.api.types.is_datetime64_any_dtype,), } dtype_checkers = feature_dtype_to_expected_historical_feature_dtype[feature_dtype] assert any( @@ -307,6 +315,7 @@ def assert_feature_list_types( bool, np.bool_, ), # Can be `np.bool_` if from `np.array` rather that `list` + "datetime": np.datetime64, } expected_dtype = feature_list_dtype_to_expected_historical_feature_list_dtype[ feature_dtype @@ -328,22 +337,23 @@ def assert_expected_arrow_types( historical_features_arrow = historical_features.to_arrow() print(historical_features_arrow) feature_list_dtype_to_expected_historical_feature_arrow_type = { - "int32": "int64", - "int64": "int64", - "float": "double", - "string": "string", - "bool": "bool", + "int32": r"int64", + "int64": r"int64", + "float": r"double", + "string": r"string", + "bool": r"bool", + "datetime": r"timestamp\[.+\]", } arrow_type = feature_list_dtype_to_expected_historical_feature_arrow_type[ feature_dtype ] if feature_is_list: - assert ( - str(historical_features_arrow.schema.field_by_name("value").type) - == f"list" + assert re.match( + f"list", + str(historical_features_arrow.schema.field_by_name("value").type), ) else: - assert ( - str(historical_features_arrow.schema.field_by_name("value").type) - == arrow_type + assert re.match( + arrow_type, + str(historical_features_arrow.schema.field_by_name("value").type), )