From 3fa7e5a71ae1bf4ef16bba026e738059fdf316a3 Mon Sep 17 00:00:00 2001 From: Kai Huang Date: Wed, 7 Sep 2022 20:35:01 +0800 Subject: [PATCH 1/2] support decimal --- python/dllib/src/bigdl/dllib/utils/utils.py | 4 +++ .../learn/ray/tf/test_tf_ray_estimator.py | 31 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/python/dllib/src/bigdl/dllib/utils/utils.py b/python/dllib/src/bigdl/dllib/utils/utils.py index c26bbe988b1..256949b2b4d 100644 --- a/python/dllib/src/bigdl/dllib/utils/utils.py +++ b/python/dllib/src/bigdl/dllib/utils/utils.py @@ -194,6 +194,8 @@ def _is_scalar_type(dtype, accept_str_col=False): return True if isinstance(dtype, df_types.TimestampType): return True + if isinstance(dtype, df_types.DecimalType): + return True if accept_str_col and isinstance(dtype, df_types.StringType): return True return False @@ -212,6 +214,8 @@ def convert_for_cols(row, cols): result.append(np.array(row[name]).astype('datetime64[ns]')) elif isinstance(feature_type, df_types.IntegerType): result.append(np.array(row[name]).astype(np.int32)) + elif isinstance(feature_type, df_types.DecimalType): + result.append(np.array(row[name]).astype(np.float64)) else: result.append(np.array(row[name])) elif isinstance(feature_type, df_types.ArrayType): diff --git a/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py b/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py index 624e1ccb064..36162725343 100644 --- a/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py +++ b/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py @@ -348,6 +348,37 @@ def test_dataframe(self): label_cols=["label"]) trainer.predict(df, feature_cols=["feature"]).collect() + def test_dataframe_decimal_input(self): + + from pyspark.sql.types import StructType, StructField, IntegerType, FloatType + from pyspark.sql.functions import col + from bigdl.orca import OrcaContext + + spark = OrcaContext.get_spark_session() + schema = StructType([ + StructField("feature", FloatType(), True), + StructField("label", IntegerType(), True) + ]) + data = [(30.2222, 1), (40.0, 0), (15.1, 1), + (-2.456, 1), (3.21, 0), (11.28, 1)] + df = spark.createDataFrame(data=data, schema=schema) + df = df.withColumn("feature", col("feature").cast("decimal(38,2)")) + res = df.collect() + features = [x[0] for x in res] + + config = { + "lr": 0.8 + } + trainer = Estimator.from_keras( + model_creator=model_creator, + verbose=True, + config=config, + workers_per_node=2) + + trainer.fit(df, epochs=1, batch_size=4, steps_per_epoch=25, + feature_cols=["feature"], + label_cols=["label"]) + def test_dataframe_with_empty_partition(self): from bigdl.orca import OrcaContext sc = OrcaContext.get_spark_context() From 72692d6c775844d5f083234b8deaa82bc7b89c59 Mon Sep 17 00:00:00 2001 From: Kai Huang Date: Wed, 7 Sep 2022 20:41:17 +0800 Subject: [PATCH 2/2] remove --- .../orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py b/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py index 36162725343..71e82a08639 100644 --- a/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py +++ b/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py @@ -363,8 +363,6 @@ def test_dataframe_decimal_input(self): (-2.456, 1), (3.21, 0), (11.28, 1)] df = spark.createDataFrame(data=data, schema=schema) df = df.withColumn("feature", col("feature").cast("decimal(38,2)")) - res = df.collect() - features = [x[0] for x in res] config = { "lr": 0.8