Skip to content

Commit

Permalink
Support decimal for Spark DataFrame (#5671)
Browse files Browse the repository at this point in the history
* support decimal

* remove
  • Loading branch information
hkvision authored Sep 8, 2022
1 parent 7aa7617 commit 7ec7356
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/dllib/src/bigdl/dllib/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def _is_scalar_type(dtype, accept_str_col=False):
return True
if isinstance(dtype, df_types.TimestampType):
return True
if isinstance(dtype, df_types.DecimalType):
return True
if accept_str_col and isinstance(dtype, df_types.StringType):
return True
return False
Expand All @@ -212,6 +214,8 @@ def convert_for_cols(row, cols):
result.append(np.array(row[name]).astype('datetime64[ns]'))
elif isinstance(feature_type, df_types.IntegerType):
result.append(np.array(row[name]).astype(np.int32))
elif isinstance(feature_type, df_types.DecimalType):
result.append(np.array(row[name]).astype(np.float64))
else:
result.append(np.array(row[name]))
elif isinstance(feature_type, df_types.ArrayType):
Expand Down
29 changes: 29 additions & 0 deletions python/orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,35 @@ def test_dataframe(self):
label_cols=["label"])
trainer.predict(df, feature_cols=["feature"]).collect()

def test_dataframe_decimal_input(self):
    """End-to-end check that Estimator.fit accepts a Spark DataFrame whose
    feature column is DecimalType (produced here by casting a float column)."""
    from pyspark.sql.types import StructType, StructField, IntegerType, FloatType
    from pyspark.sql.functions import col
    from bigdl.orca import OrcaContext

    session = OrcaContext.get_spark_session()

    # Build a small two-column frame: float feature, integer label.
    frame_schema = StructType([
        StructField("feature", FloatType(), True),
        StructField("label", IntegerType(), True)
    ])
    rows = [(30.2222, 1), (40.0, 0), (15.1, 1),
            (-2.456, 1), (3.21, 0), (11.28, 1)]
    frame = session.createDataFrame(data=rows, schema=frame_schema)

    # Cast the feature column to decimal so training exercises DecimalType handling.
    frame = frame.withColumn("feature", col("feature").cast("decimal(38,2)"))

    estimator = Estimator.from_keras(
        model_creator=model_creator,
        verbose=True,
        config={"lr": 0.8},
        workers_per_node=2)

    estimator.fit(frame, epochs=1, batch_size=4, steps_per_epoch=25,
                  feature_cols=["feature"],
                  label_cols=["label"])

def test_dataframe_with_empty_partition(self):
from bigdl.orca import OrcaContext
sc = OrcaContext.get_spark_context()
Expand Down

0 comments on commit 7ec7356

Please sign in to comment.