From 19a9a1b76f882f267552ec954c42083bef4c8851 Mon Sep 17 00:00:00 2001 From: Junwei Deng <35031544+TheaperDeng@users.noreply.github.com> Date: Fri, 9 Sep 2022 09:57:26 +0800 Subject: [PATCH] Chronos: fix spark 3.1 bug in xshards unscale (#5689) --- .../bigdl/chronos/data/experimental/test_xshardstsdataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/chronos/test/bigdl/chronos/data/experimental/test_xshardstsdataset.py b/python/chronos/test/bigdl/chronos/data/experimental/test_xshardstsdataset.py index 05dd2958d16..d587c9ad72a 100644 --- a/python/chronos/test/bigdl/chronos/data/experimental/test_xshardstsdataset.py +++ b/python/chronos/test/bigdl/chronos/data/experimental/test_xshardstsdataset.py @@ -217,8 +217,8 @@ def test_xshardstsdataset_scale_unscale(self): scalers = [{0: StandardScaler(), 1: StandardScaler()}] df = pd.read_csv(os.path.join(self.resource_path, "multiple.csv")) for scaler in scalers: - shards_multiple = read_csv(os.path.join(self.resource_path, "multiple.csv")) - shards_multiple_test = read_csv(os.path.join(self.resource_path, "multiple.csv")) + shards_multiple = read_csv(os.path.join(self.resource_path, "multiple.csv"), dtype={"id": np.int64}) + shards_multiple_test = read_csv(os.path.join(self.resource_path, "multiple.csv"), dtype={"id": np.int64}) tsdata = XShardsTSDataset.from_xshards(shards_multiple, dt_col="datetime", target_col="value",