From 7ec735654652a6ac3a923bde8325de6f01c50f87 Mon Sep 17 00:00:00 2001
From: Kai Huang
Date: Thu, 8 Sep 2022 11:01:54 +0800
Subject: [PATCH] Support decimal for Spark DataFrame (#5671)

* support decimal

* remove
---
 python/dllib/src/bigdl/dllib/utils/utils.py  |  4 +++
 .../learn/ray/tf/test_tf_ray_estimator.py    | 29 +++++++++++++++++++
 2 files changed, 33 insertions(+)

diff --git a/python/dllib/src/bigdl/dllib/utils/utils.py b/python/dllib/src/bigdl/dllib/utils/utils.py
index c26bbe988b1..256949b2b4d 100644
--- a/python/dllib/src/bigdl/dllib/utils/utils.py
+++ b/python/dllib/src/bigdl/dllib/utils/utils.py
@@ -194,6 +194,8 @@ def _is_scalar_type(dtype, accept_str_col=False):
         return True
     if isinstance(dtype, df_types.TimestampType):
         return True
+    if isinstance(dtype, df_types.DecimalType):
+        return True
     if accept_str_col and isinstance(dtype, df_types.StringType):
         return True
     return False
@@ -212,6 +214,8 @@ def convert_for_cols(row, cols):
                 result.append(np.array(row[name]).astype('datetime64[ns]'))
             elif isinstance(feature_type, df_types.IntegerType):
                 result.append(np.array(row[name]).astype(np.int32))
+            elif isinstance(feature_type, df_types.DecimalType):
+                result.append(np.array(row[name]).astype(np.float64))
             else:
                 result.append(np.array(row[name]))
         elif isinstance(feature_type, df_types.ArrayType):
diff --git a/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py b/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py
index 624e1ccb064..71e82a08639 100644
--- a/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py
+++ b/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py
@@ -348,6 +348,35 @@ def test_dataframe(self):
                     label_cols=["label"])
         trainer.predict(df, feature_cols=["feature"]).collect()
 
+    def test_dataframe_decimal_input(self):
+
+        from pyspark.sql.types import StructType, StructField, IntegerType, FloatType
+        from pyspark.sql.functions import col
+        from bigdl.orca import OrcaContext
+
+        spark = OrcaContext.get_spark_session()
+        schema = StructType([
+            StructField("feature", FloatType(), True),
+            StructField("label", IntegerType(), True)
+        ])
+        data = [(30.2222, 1), (40.0, 0), (15.1, 1),
+                (-2.456, 1), (3.21, 0), (11.28, 1)]
+        df = spark.createDataFrame(data=data, schema=schema)
+        df = df.withColumn("feature", col("feature").cast("decimal(38,2)"))
+
+        config = {
+            "lr": 0.8
+        }
+        trainer = Estimator.from_keras(
+            model_creator=model_creator,
+            verbose=True,
+            config=config,
+            workers_per_node=2)
+
+        trainer.fit(df, epochs=1, batch_size=4, steps_per_epoch=25,
+                    feature_cols=["feature"],
+                    label_cols=["label"])
+
     def test_dataframe_with_empty_partition(self):
         from bigdl.orca import OrcaContext
         sc = OrcaContext.get_spark_context()