From 6fe518acc8970352e7cf90c7f126b0206f5b0b16 Mon Sep 17 00:00:00 2001
From: SONG Ge <38711238+sgwhat@users.noreply.github.com>
Date: Thu, 1 Sep 2022 10:48:55 +0800
Subject: [PATCH] [Orca] Refactor `model_dir` as an option in tf2 pyspark
 estimator (#5541)

* refactor model_dir as an option

* modify ut to test non model_dir

* update coding format

* update return results

* add support in load api

* move model_dir from ut
---
 .../src/bigdl/orca/learn/tf2/estimator.py     |   3 -
 .../bigdl/orca/learn/tf2/pyspark_estimator.py |  13 +-
 .../src/bigdl/orca/learn/tf2/spark_runner.py  |  12 +-
 .../learn/ray/tf/test_tf_spark_estimator.py   | 119 ++++++++----------
 4 files changed, 66 insertions(+), 81 deletions(-)

diff --git a/python/orca/src/bigdl/orca/learn/tf2/estimator.py b/python/orca/src/bigdl/orca/learn/tf2/estimator.py
index 9d6ad451cbe..226aa9505f1 100644
--- a/python/orca/src/bigdl/orca/learn/tf2/estimator.py
+++ b/python/orca/src/bigdl/orca/learn/tf2/estimator.py
@@ -72,9 +72,6 @@ def from_keras(*,
         if cpu_binding:
             invalidInputError(False,
                               "cpu_binding should not be True when using spark backend")
-        if not model_dir:
-            invalidInputError(False,
-                              "Please specify model directory when using spark backend")
         from bigdl.orca.learn.tf2.pyspark_estimator import SparkTFEstimator
         return SparkTFEstimator(model_creator=model_creator, config=config, verbose=verbose,
diff --git a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
index 213d930def3..34dda1ca363 100644
--- a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
+++ b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
@@ -214,7 +214,8 @@ def transform_func(iter, init_param, param):
         res = self.workerRDD.barrier().mapPartitions(
             lambda iter: transform_func(iter, init_params, params)).collect()

-        if self.model_dir:
+        if self.model_dir is not None:
+            result = res
             try:
                 temp_dir = tempfile.mkdtemp()
                 get_remote_file_to_local(os.path.join(self.model_dir, "state.pkl"),
                                          temp_dir)
                 import pickle
                 with open(os.path.join(temp_dir, "state.pkl"), 'rb') as f:
                     state = pickle.load(f)
                     self.model_weights = state['weights']
             finally:
                 shutil.rmtree(temp_dir)
+        else:
+            result = res[0]
+            self.model_weights = res[1]

-        return res[0]
+        return result[0]

     def evaluate(self, data, batch_size=32, num_steps=None, verbose=1,
                  sample_weight=None, callbacks=None, data_config=None,
@@ -489,7 +493,7 @@ def save(self,
            saving to SavedModel.
""" # get current model - if exists(self._model_saved_path): + if self.model_dir is not None and exists(self._model_saved_path): model = load_model(self._model_saved_path) else: model = self.get_model() @@ -513,7 +517,8 @@ def load(self, filepath, custom_objects=None, compile=True): model = load_model(filepath, custom_objects=custom_objects, compile=compile) self.model_weights = model.get_weights() # update remote model - save_model(model, self._model_saved_path, save_format="h5", filemode=0o666) + if self.model_dir is not None: + save_model(model, self._model_saved_path, save_format="h5", filemode=0o666) def get_model(self): """ diff --git a/python/orca/src/bigdl/orca/learn/tf2/spark_runner.py b/python/orca/src/bigdl/orca/learn/tf2/spark_runner.py index 2277e73cc0a..040529a195b 100644 --- a/python/orca/src/bigdl/orca/learn/tf2/spark_runner.py +++ b/python/orca/src/bigdl/orca/learn/tf2/spark_runner.py @@ -271,7 +271,7 @@ def distributed_train_func(self, data_creator, config, epochs=1, verbose=1, runs a training epoch and updates the model parameters """ with self.strategy.scope(): - if exists(self._model_saved_path): + if self.model_dir is not None and exists(self._model_saved_path): # for continous training model = load_model(self._model_saved_path) else: @@ -336,7 +336,6 @@ def step(self, data_creator, epochs=1, batch_size=32, verbose=1, validation_steps=validation_steps, validation_freq=validation_freq ) - weights = model.get_weights() if history is None: stats = {} else: @@ -345,14 +344,19 @@ def step(self, data_creator, epochs=1, batch_size=32, verbose=1, if self.model_dir is not None: save_model(model, self._model_saved_path, save_format="h5") model_state = { - "weights": weights, + "weights": model.get_weights(), "optimizer_weights": model.optimizer.get_weights() } save_pkl(model_state, os.path.join(self.model_dir, "state.pkl")) + else: + weights = model.get_weights() if self.need_to_log_to_driver: LogMonitor.stop_log_monitor(self.log_path, self.logger_thread, self.thread_stop) - return [stats] + if self.model_dir is not None: + return [stats] + else: + return [stats], weights else: temp_dir = tempfile.mkdtemp() try: diff --git a/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_spark_estimator.py b/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_spark_estimator.py index b463c590d8f..5945b731043 100644 --- a/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_spark_estimator.py +++ b/python/orca/test/bigdl/orca/learn/ray/tf/test_tf_spark_estimator.py @@ -80,8 +80,7 @@ def test_dataframe(self): verbose=True, config=config, workers_per_node=2, - backend="spark", - model_dir=temp_dir) + backend="spark") res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25, feature_cols=["feature"], @@ -126,8 +125,7 @@ def test_dataframe_with_empty_partition(self): verbose=True, config=config, workers_per_node=3, - backend="spark", - model_dir=temp_dir) + backend="spark") res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25, feature_cols=["feature"], @@ -178,8 +176,7 @@ def model_creator(config): verbose=True, config=config, workers_per_node=1, - backend="spark", - model_dir=temp_dir) + backend="spark") res = trainer.fit(data=xshards, epochs=5, batch_size=4, steps_per_epoch=25, feature_cols=["user", "item"], label_cols=["label"]) @@ -214,8 +211,7 @@ def test_checkpoint_weights(self): verbose=True, config=config, workers_per_node=2, - backend="spark", - model_dir=temp_dir) + backend="spark") callbacks = [ tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(temp_dir, 
"ckpt_{epoch}"), @@ -260,8 +256,7 @@ def test_checkpoint_weights_h5(self): verbose=True, config=config, workers_per_node=2, - backend="spark", - model_dir=temp_dir) + backend="spark") callbacks = [ tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(temp_dir, "ckpt_weights.h5"), @@ -303,35 +298,29 @@ def test_dataframe_shard_size(self): "lr": 0.2 } - try: - temp_dir = tempfile.mkdtemp() - - trainer = Estimator.from_keras( - model_creator=model_creator, - verbose=True, - config=config, - workers_per_node=2, - backend="spark", - model_dir=temp_dir) - - res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25, - validation_data=val_df, - validation_steps=2, - feature_cols=["feature"], - label_cols=["label"]) - - res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25, - feature_cols=["feature"], - label_cols=["label"]) - - res = trainer.evaluate(val_df, batch_size=4, num_steps=25, feature_cols=["feature"], - label_cols=["label"]) - print("validation result: ", res) - - res = trainer.predict(df, feature_cols=["feature"]).collect() - print("predict result: ", res) - finally: - shutil.rmtree(temp_dir) + trainer = Estimator.from_keras( + model_creator=model_creator, + verbose=True, + config=config, + workers_per_node=2, + backend="spark") + + res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25, + validation_data=val_df, + validation_steps=2, + feature_cols=["feature"], + label_cols=["label"]) + + res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25, + feature_cols=["feature"], + label_cols=["label"]) + + res = trainer.evaluate(val_df, batch_size=4, num_steps=25, feature_cols=["feature"], + label_cols=["label"]) + print("validation result: ", res) + + res = trainer.predict(df, feature_cols=["feature"]).collect() + print("predict result: ", res) OrcaContext._shard_size = None def test_dataframe_different_train_val(self): @@ -351,32 +340,26 @@ def test_dataframe_different_train_val(self): "lr": 0.2 } - try: - temp_dir = tempfile.mkdtemp() + trainer = Estimator.from_keras( + model_creator=model_creator, + verbose=True, + config=config, + workers_per_node=2, + backend="spark") - trainer = Estimator.from_keras( - model_creator=model_creator, - verbose=True, - config=config, - workers_per_node=2, - backend="spark", - model_dir=temp_dir) + res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25, + validation_data=val_df, + validation_steps=2, + feature_cols=["feature"], + label_cols=["label"]) - res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25, - validation_data=val_df, - validation_steps=2, - feature_cols=["feature"], - label_cols=["label"]) - - res = trainer.evaluate(val_df, batch_size=4, num_steps=25, feature_cols=["feature"], - label_cols=["label"]) - print("validation result: ", res) + res = trainer.evaluate(val_df, batch_size=4, num_steps=25, feature_cols=["feature"], + label_cols=["label"]) + print("validation result: ", res) - res = trainer.predict(df, feature_cols=["feature"]).collect() - print("predict result: ", res) - trainer.shutdown() - finally: - shutil.rmtree(temp_dir) + res = trainer.predict(df, feature_cols=["feature"]).collect() + print("predict result: ", res) + trainer.shutdown() def test_tensorboard(self): sc = OrcaContext.get_spark_context() @@ -399,8 +382,7 @@ def test_tensorboard(self): verbose=True, config=config, workers_per_node=2, - backend="spark", - model_dir=temp_dir) + backend="spark") callbacks = [ tf.keras.callbacks.TensorBoard(log_dir=os.path.join(temp_dir, "train_log"), @@ -460,8 +442,7 @@ def 
@@ -460,8 +442,7 @@ def test_checkpoint_model(self):
             verbose=True,
             config=config,
             workers_per_node=2,
-            backend="spark",
-            model_dir=temp_dir)
+            backend="spark")

         callbacks = [
             tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(temp_dir,
                                                                      "ckpt_{epoch}"),
@@ -517,8 +498,7 @@ def test_save_load_model_h5(self):
             verbose=True,
             config=config,
             workers_per_node=2,
-            backend="spark",
-            model_dir=temp_dir)
+            backend="spark")

         res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
                           feature_cols=["feature"],
@@ -567,8 +547,7 @@ def test_save_load_model_savemodel(self):
             verbose=True,
             config=config,
             workers_per_node=2,
-            backend="spark",
-            model_dir=temp_dir)
+            backend="spark")

         res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
                           feature_cols=["feature"],
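
With this patch applied, `model_dir` becomes optional for the spark backend instead of required. Below is a minimal usage sketch in the spirit of the updated tests; it assumes `model_creator`, `config`, and a Spark DataFrame `df` with "feature"/"label" columns defined as in test_tf_spark_estimator.py, and the public `Estimator` import path used elsewhere in Orca (both are assumptions, not part of this patch):

from bigdl.orca.learn.tf2 import Estimator

# model_dir is now optional with backend="spark". Without it, each worker
# returns ([stats], weights) from step(), and the driver keeps the weights
# in memory (self.model_weights) instead of reading them back from
# <model_dir>/state.pkl.
trainer = Estimator.from_keras(model_creator=model_creator,
                               config=config,
                               workers_per_node=2,
                               backend="spark")

stats = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
                    feature_cols=["feature"], label_cols=["label"])

# save()/load() keep working without model_dir; load() now re-uploads the
# model to self._model_saved_path only when a model_dir was configured.
trainer.save("/tmp/tf2_model.h5")  # hypothetical path for illustration
trainer.shutdown()

When `model_dir` is set, the previous behavior is unchanged: workers persist state.pkl under `model_dir`, and the driver restores the weights from there after fit().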