From c2ecf64d7c6306e1cfcd18e05226267128911802 Mon Sep 17 00:00:00 2001 From: dding3 Date: Wed, 14 Sep 2022 09:30:27 -0700 Subject: [PATCH] support standard scaler for shards (#5716) * support standard scaler --- .../orca/src/bigdl/orca/data/transformer.py | 42 +++++++++++++++++++ .../bigdl/orca/data/test_spark_backend.py | 14 +++++++ 2 files changed, 56 insertions(+) diff --git a/python/orca/src/bigdl/orca/data/transformer.py b/python/orca/src/bigdl/orca/data/transformer.py index 7db51994868..b5236ba19f1 100644 --- a/python/orca/src/bigdl/orca/data/transformer.py +++ b/python/orca/src/bigdl/orca/data/transformer.py @@ -21,6 +21,7 @@ from bigdl.orca.data import SparkXShards from bigdl.orca import OrcaContext from pyspark.ml.feature import MinMaxScaler as SparkMinMaxScaler +from pyspark.ml.feature import StandardScaler as SparkStandardScaler from pyspark.ml.feature import VectorAssembler as SparkVectorAssembler from pyspark.ml import Pipeline as SparkPipeline @@ -354,3 +355,44 @@ def transform(self, shard): scaledData = self.scalerModel.transform(df) data_shards = spark_df_to_pd_sparkxshards(scaledData) return data_shards + + +class StandardScaler: + def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None): + self.withMean = withMean + self.withStd = withStd + self.inputCol = inputCol + self.outputCol = outputCol + self.scaler = None + self.scalerModel = None + if inputCol: + self.__createScaler__() + + def __createScaler__(self): + invalidInputError(self.inputCol, "inputColumn cannot be empty") + invalidInputError(self.outputCol, "outputColumn cannot be empty") + + vecOutputCol = str(uuid.uuid1()) + "x_vec" + assembler = SparkVectorAssembler(inputCols=[self.inputCol], outputCol=vecOutputCol) + scaler = SparkStandardScaler(withMean=self.withMean, withStd=self.withStd, + inputCol=vecOutputCol, outputCol=self.outputCol) + self.scaler = SparkPipeline(stages=[assembler, scaler]) + + def setInputOutputCol(self, inputCol, outputCol): + self.inputCol = inputCol + self.outputCol = outputCol + self.__createScaler__() + + def fit_transform(self, shard): + df = shard.to_spark_df() + self.scalerModel = self.scaler.fit(df) + scaledData = self.scalerModel.transform(df) + data_shards = spark_df_to_pd_sparkxshards(scaledData) + return data_shards + + def transform(self, shard): + invalidInputError(self.scalerModel, "Please call fit_transform first") + df = shard.to_spark_df() + scaledData = self.scalerModel.transform(df) + data_shards = spark_df_to_pd_sparkxshards(scaledData) + return data_shards diff --git a/python/orca/test/bigdl/orca/data/test_spark_backend.py b/python/orca/test/bigdl/orca/data/test_spark_backend.py index ebae7b9ddc6..d3723b3667e 100644 --- a/python/orca/test/bigdl/orca/data/test_spark_backend.py +++ b/python/orca/test/bigdl/orca/data/test_spark_backend.py @@ -26,6 +26,7 @@ from bigdl.dllib.nncontext import * from bigdl.orca.data.image import write_tfrecord, read_tfrecord from bigdl.orca.data.utils import * +from bigdl.orca.data.transformer import * class TestSparkBackend(TestCase): @@ -231,6 +232,19 @@ def test_spark_df_to_shards(self): df = spark.read.csv(file_path) data_shards = spark_df_to_pd_sparkxshards(df) + def test_minmaxscale_shards(self): + file_path = os.path.join(self.resource_path, "orca/data/csv") + data_shard = bigdl.orca.data.pandas.read_csv(file_path) + scale = MinMaxScaler(inputCol=["sale_price"], outputCol="sale_price_scaled") + transformed_data_shard = scale.fit_transform(data_shard) + + def test_standardscale_shards(self): + file_path = os.path.join(self.resource_path, "orca/data/csv") + + data_shard = bigdl.orca.data.pandas.read_csv(file_path) + scale = StandardScaler(inputCol="sale_price", outputCol="sale_price_scaled") + transformed_data_shard = scale.fit_transform(data_shard) + if __name__ == "__main__": pytest.main([__file__])