From db8a50b55db3a7c97e0795fcec60857e3bee8ad5 Mon Sep 17 00:00:00 2001
From: sgwhat
Date: Tue, 16 Aug 2022 15:39:23 +0800
Subject: [PATCH 1/5] fix save model in pyspark est

---
 python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
index 28888073555..1bdf53f6776 100644
--- a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
+++ b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
@@ -482,10 +482,8 @@ def save(self,
                         saving to SavedModel.
         """
         # get current model
-        if exists(self._model_saved_path):
-            model = load_model(self._model_saved_path)
-        else:
-            model = self.model_creator(self.config)
+        model = self.model_creator(self.config)
+        model.set_weights(self.model_weights)
         # save model
         save_model(model, filepath, overwrite=overwrite, include_optimizer=include_optimizer,
                    save_format=save_format, signatures=signatures, options=options)

From 207f68efb151669d708c27f73ae78917b2214ea3 Mon Sep 17 00:00:00 2001
From: sgwhat
Date: Tue, 16 Aug 2022 15:41:09 +0800
Subject: [PATCH 2/5] use a function to replace

---
 python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
index 1bdf53f6776..bd6d91b66d0 100644
--- a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
+++ b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
@@ -482,8 +482,7 @@ def save(self,
                         saving to SavedModel.
         """
         # get current model
-        model = self.model_creator(self.config)
-        model.set_weights(self.model_weights)
+        model = self.get_model()
         # save model
         save_model(model, filepath, overwrite=overwrite, include_optimizer=include_optimizer,
                    save_format=save_format, signatures=signatures, options=options)

From 836248d41e88c1dd76a69ad939224535147d2aeb Mon Sep 17 00:00:00 2001
From: sgwhat
Date: Wed, 17 Aug 2022 10:11:18 +0800
Subject: [PATCH 3/5] add ut

---
 .../learn/ray/tf/test_tf_spark_estimator.py   | 67 ++++++++++++++++---
 1 file changed, 58 insertions(+), 9 deletions(-)

diff --git a/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_spark_estimator.py b/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_spark_estimator.py
index 2d7a3d83034..a7bf62c26f8 100644
--- a/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_spark_estimator.py
+++ b/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_spark_estimator.py
@@ -496,7 +496,7 @@ def test_checkpoint_model(self):
         finally:
             shutil.rmtree(temp_dir)
 
-    def test_save_load_model(self):
+    def test_save_load_model_h5(self):
         sc = OrcaContext.get_spark_context()
         rdd = sc.range(0, 100)
         spark = OrcaContext.get_spark_session()
@@ -528,22 +528,71 @@ def test_save_load_model(self):
             print("start saving")
             trainer.save(os.path.join(temp_dir, "a.h5"))
+
+            res = trainer.evaluate(df, batch_size=4, num_steps=25, feature_cols=["feature"],
+                                   label_cols=["label"])
+            print("validation result: ", res)
+
+            before_res = trainer.predict(df, feature_cols=["feature"]).collect()
+            expect_res = np.concatenate([part["prediction"] for part in before_res])
+
             trainer.load(os.path.join(temp_dir, "a.h5"))
-            trainer.save(os.path.join(temp_dir, "saved_model"))
-            trainer.load(os.path.join(temp_dir, "saved_model"))
-            # continous training
-            res = trainer.fit(df, epochs=10, batch_size=4, steps_per_epoch=25,
+
+            # continuous predicting
+            after_res = trainer.predict(df, feature_cols=["feature"]).collect()
+            pred_res = np.concatenate([part["prediction"] for part in after_res])
+
+            assert np.array_equal(expect_res, pred_res)
+        finally:
+            shutil.rmtree(temp_dir)
+
+    def test_save_load_model_savemodel(self):
+        sc = OrcaContext.get_spark_context()
+        rdd = sc.range(0, 100)
+        spark = OrcaContext.get_spark_session()
+
+        from pyspark.ml.linalg import DenseVector
+        df = rdd.map(lambda x: (DenseVector(np.random.randn(1, ).astype(np.float)),
+                                int(np.random.randint(0, 2, size=())))).toDF(["feature", "label"])
+
+        config = {
+            "lr": 0.2
+        }
+
+        try:
+            temp_dir = tempfile.mkdtemp()
+
+            trainer = Estimator.from_keras(
+                model_creator=model_creator,
+                verbose=True,
+                config=config,
+                workers_per_node=2,
+                backend="spark",
+                model_dir=temp_dir)
+
+            res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
                               feature_cols=["feature"],
                               label_cols=["label"],
                               validation_data=df,
-                              validation_steps=1,
-                              initial_epoch=5)
+                              validation_steps=1)
+
+            print("start saving")
+            trainer.save(os.path.join(temp_dir, "saved_model"))
+
             res = trainer.evaluate(df, batch_size=4, num_steps=25, feature_cols=["feature"],
                                    label_cols=["label"])
             print("validation result: ", res)
-            res = trainer.predict(df, feature_cols=["feature"]).collect()
-            print("predict result: ", res)
+            before_res = trainer.predict(df, feature_cols=["feature"]).collect()
+            expect_res = np.concatenate([part["prediction"] for part in before_res])
+
+            trainer.load(os.path.join(temp_dir, "saved_model"))
+
+            # continuous predicting
+            after_res = trainer.predict(df, feature_cols=["feature"]).collect()
+            pred_res = np.concatenate([part["prediction"] for part in after_res])
+
+            assert np.array_equal(expect_res, pred_res)
         finally:
             shutil.rmtree(temp_dir)

From d2f857a41fd13db4b803e03d29b39b6e83903d54 Mon Sep 17 00:00:00 2001
From: sgwhat
Date: Fri, 19 Aug 2022 17:20:27 +0800
Subject: [PATCH 4/5] use orca.data to replace dllib

---
 .../src/bigdl/orca/learn/tf2/pyspark_estimator.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
index bd6d91b66d0..9a465d9b15a 100644
--- a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
+++ b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
@@ -28,11 +28,12 @@
 from bigdl.dllib.utils.common import get_node_and_core_number
 from bigdl.dllib.utils.file_utils import enable_multi_fs_load, enable_multi_fs_save, \
-    get_remote_file_to_local, is_local_path, get_remote_files_with_prefix_to_local, \
-    append_suffix, put_local_file_to_remote, put_local_files_with_prefix_to_remote
-
+    is_local_path, append_suffix
 from bigdl.dllib.utils.utils import get_node_ip
-from bigdl.orca.data.file import is_file, exists
+
+from bigdl.orca.data.file import is_file, exists, get_remote_file_to_local, \
+    get_remote_files_with_prefix_to_local, put_local_file_to_remote, \
+    put_local_files_with_prefix_to_remote
 from bigdl.orca.learn.tf2.spark_runner import SparkRunner
 from bigdl.orca.learn.utils import find_free_port, find_ip_and_free_port
 from bigdl.orca.learn.utils import maybe_dataframe_to_xshards, dataframe_to_xshards, \
@@ -482,7 +483,10 @@ def save(self,
                         saving to SavedModel.
""" # get current model - model = self.get_model() + if exists(self._model_saved_path): + model = load_model(self._model_saved_path) + else: + model = self.get_model() # save model save_model(model, filepath, overwrite=overwrite, include_optimizer=include_optimizer, save_format=save_format, signatures=signatures, options=options) From a2a7668256f7cd1896c27ff033853a9398b21527 Mon Sep 17 00:00:00 2001 From: sgwhat Date: Mon, 22 Aug 2022 11:22:24 +0800 Subject: [PATCH 5/5] update --- python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py index 9a465d9b15a..295a16fd697 100644 --- a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py +++ b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py @@ -213,8 +213,7 @@ def transform_func(iter, init_param, param): try: temp_dir = tempfile.mkdtemp() get_remote_file_to_local(os.path.join(self.model_dir, "state.pkl"), - os.path.join(temp_dir, "state.pkl"), - over_write=True) + os.path.join(temp_dir, "state.pkl")) import pickle with open(os.path.join(temp_dir, "state.pkl"), 'rb') as f: state = pickle.load(f)