[Orca] Refactor model_dir as an option in tf2 pyspark estimator (intel-analytics#5541)

* refactor model_dir as an option

* modify ut to test non model_dir

* update coding format

* update return results

* add support in load api

* move model_dir from ut
sgwhat authored and ForJadeForest committed Sep 20, 2022
1 parent e87a3b4 commit 6fe518a
Showing 4 changed files with 66 additions and 81 deletions.
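With this change, `model_dir` becomes optional for the Spark backend: when it is omitted, the trained weights are shipped back to the driver through the collected RDD instead of through a `state.pkl` file on shared storage, as the diffs below show. A minimal usage sketch (the `model_creator` body, config values, and import path are illustrative, inferred from the test file in this commit rather than spelled out by it):

```python
import tensorflow as tf
from bigdl.orca.learn.tf2 import Estimator

def model_creator(config):
    # Built and compiled on each worker; any compiled Keras model works.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
    model.compile(optimizer=tf.keras.optimizers.SGD(config["lr"]), loss="mse")
    return model

# model_dir can now be omitted with backend="spark"; before this commit,
# from_keras raised invalidInputError when it was missing.
trainer = Estimator.from_keras(model_creator=model_creator,
                               config={"lr": 0.2},
                               workers_per_node=2,
                               backend="spark")
```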
python/orca/src/bigdl/orca/learn/tf2/estimator.py (3 changes: 0 additions & 3 deletions)
@@ -72,9 +72,6 @@ def from_keras(*,
         if cpu_binding:
             invalidInputError(False,
                               "cpu_binding should not be True when using spark backend")
-        if not model_dir:
-            invalidInputError(False,
-                              "Please specify model directory when using spark backend")
         from bigdl.orca.learn.tf2.pyspark_estimator import SparkTFEstimator
         return SparkTFEstimator(model_creator=model_creator,
                                 config=config, verbose=verbose,
python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py (13 changes: 9 additions & 4 deletions)
@@ -214,7 +214,8 @@ def transform_func(iter, init_param, param):
         res = self.workerRDD.barrier().mapPartitions(
             lambda iter: transform_func(iter, init_params, params)).collect()

-        if self.model_dir:
+        if self.model_dir is not None:
+            result = res
             try:
                 temp_dir = tempfile.mkdtemp()
                 get_remote_file_to_local(os.path.join(self.model_dir, "state.pkl"),
@@ -225,8 +226,11 @@ def transform_func(iter, init_param, param):
                 self.model_weights = state['weights']
             finally:
                 shutil.rmtree(temp_dir)
+        else:
+            result = res[0]
+            self.model_weights = res[1]

-        return res[0]
+        return result[0]

     def evaluate(self, data, batch_size=32, num_steps=None, verbose=1,
                  sample_weight=None, callbacks=None, data_config=None,
@@ -489,7 +493,7 @@ def save(self,
             saving to SavedModel.
         """
         # get current model
-        if exists(self._model_saved_path):
+        if self.model_dir is not None and exists(self._model_saved_path):
             model = load_model(self._model_saved_path)
         else:
             model = self.get_model()
@@ -513,7 +517,8 @@ def load(self, filepath, custom_objects=None, compile=True):
         model = load_model(filepath, custom_objects=custom_objects, compile=compile)
         self.model_weights = model.get_weights()
         # update remote model
-        save_model(model, self._model_saved_path, save_format="h5", filemode=0o666)
+        if self.model_dir is not None:
+            save_model(model, self._model_saved_path, save_format="h5", filemode=0o666)

     def get_model(self):
         """
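With the two guards above, `save()` and `load()` no longer assume a `model_dir`-backed remote copy: `save()` falls back to `get_model()` and `load()` skips the sync to `self._model_saved_path` when `model_dir` is unset. A usage sketch continuing from the `trainer` created earlier (the file path, `df`, and fit arguments are illustrative):

```python
# df is an illustrative Spark DataFrame with "feature" and "label" columns.
trainer.fit(df, epochs=1, batch_size=4, steps_per_epoch=5,
            feature_cols=["feature"], label_cols=["label"])

# Both calls now work without a model_dir: save() serializes the model
# rebuilt from the driver-held weights, and load() restores the weights
# without writing a remote copy.
trainer.save("/tmp/orca_model.h5")
trainer.load("/tmp/orca_model.h5")
```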
python/orca/src/bigdl/orca/learn/tf2/spark_runner.py (12 changes: 8 additions & 4 deletions)
@@ -271,7 +271,7 @@ def distributed_train_func(self, data_creator, config, epochs=1, verbose=1,
         runs a training epoch and updates the model parameters
         """
         with self.strategy.scope():
-            if exists(self._model_saved_path):
+            if self.model_dir is not None and exists(self._model_saved_path):
                 # for continous training
                 model = load_model(self._model_saved_path)
             else:
@@ -336,7 +336,6 @@ def step(self, data_creator, epochs=1, batch_size=32, verbose=1,
                 validation_steps=validation_steps,
                 validation_freq=validation_freq
             )
-            weights = model.get_weights()
             if history is None:
                 stats = {}
             else:
@@ -345,14 +344,19 @@ def step(self, data_creator, epochs=1, batch_size=32, verbose=1,
             if self.model_dir is not None:
                 save_model(model, self._model_saved_path, save_format="h5")
                 model_state = {
-                    "weights": weights,
+                    "weights": model.get_weights(),
                     "optimizer_weights": model.optimizer.get_weights()
                 }
                 save_pkl(model_state, os.path.join(self.model_dir, "state.pkl"))
+            else:
+                weights = model.get_weights()

             if self.need_to_log_to_driver:
                 LogMonitor.stop_log_monitor(self.log_path, self.logger_thread, self.thread_stop)
-            return [stats]
+            if self.model_dir is not None:
+                return [stats]
+            else:
+                return [stats], weights
         else:
             temp_dir = tempfile.mkdtemp()
             try:
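The worker-side change above defines the return contract that the driver-side `fit()` hunk unpacks: with a `model_dir`, `step()` still returns `[stats]` and the weights travel through `state.pkl`; without one, it returns the tuple `[stats], weights`, so after `mapPartitions(...).collect()` the driver reads the stats list at `res[0]` and the weights at `res[1]`. A self-contained toy sketch of that contract (plain Python, no Spark; the helper name and values are hypothetical):

```python
# Toy stand-in for the tail of step() after this change.
def step_tail(stats, weights, model_dir):
    if model_dir is not None:
        # Weights were persisted to state.pkl under model_dir; only
        # the per-epoch stats ride back through the RDD.
        return [stats]
    # No model_dir: ship the weights back alongside the stats.
    return [stats], weights

res = step_tail({"loss": 0.1}, [0.5, 0.25], model_dir=None)
stats_list, weights = res[0], res[1]
assert stats_list == [{"loss": 0.1}] and weights == [0.5, 0.25]
```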
python/orca/test/bigdl/orca/learn/ray/tf/test_tf_spark_estimator.py (119 changes: 49 additions & 70 deletions)
@@ -80,8 +80,7 @@ def test_dataframe(self):
             verbose=True,
             config=config,
             workers_per_node=2,
-            backend="spark",
-            model_dir=temp_dir)
+            backend="spark")

         res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
                           feature_cols=["feature"],
@@ -126,8 +125,7 @@ def test_dataframe_with_empty_partition(self):
             verbose=True,
             config=config,
             workers_per_node=3,
-            backend="spark",
-            model_dir=temp_dir)
+            backend="spark")

         res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
                           feature_cols=["feature"],
@@ -178,8 +176,7 @@ def model_creator(config):
             verbose=True,
             config=config,
             workers_per_node=1,
-            backend="spark",
-            model_dir=temp_dir)
+            backend="spark")

         res = trainer.fit(data=xshards, epochs=5, batch_size=4, steps_per_epoch=25,
                           feature_cols=["user", "item"], label_cols=["label"])
@@ -214,8 +211,7 @@ def test_checkpoint_weights(self):
             verbose=True,
             config=config,
             workers_per_node=2,
-            backend="spark",
-            model_dir=temp_dir)
+            backend="spark")

         callbacks = [
             tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(temp_dir, "ckpt_{epoch}"),
@@ -260,8 +256,7 @@ def test_checkpoint_weights_h5(self):
             verbose=True,
             config=config,
             workers_per_node=2,
-            backend="spark",
-            model_dir=temp_dir)
+            backend="spark")

         callbacks = [
             tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(temp_dir, "ckpt_weights.h5"),
@@ -303,35 +298,29 @@ def test_dataframe_shard_size(self):
             "lr": 0.2
         }

-        try:
-            temp_dir = tempfile.mkdtemp()
-
-            trainer = Estimator.from_keras(
-                model_creator=model_creator,
-                verbose=True,
-                config=config,
-                workers_per_node=2,
-                backend="spark",
-                model_dir=temp_dir)
-
-            res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
-                              validation_data=val_df,
-                              validation_steps=2,
-                              feature_cols=["feature"],
-                              label_cols=["label"])
-
-            res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
-                              feature_cols=["feature"],
-                              label_cols=["label"])
-
-            res = trainer.evaluate(val_df, batch_size=4, num_steps=25, feature_cols=["feature"],
-                                   label_cols=["label"])
-            print("validation result: ", res)
-
-            res = trainer.predict(df, feature_cols=["feature"]).collect()
-            print("predict result: ", res)
-        finally:
-            shutil.rmtree(temp_dir)
+        trainer = Estimator.from_keras(
+            model_creator=model_creator,
+            verbose=True,
+            config=config,
+            workers_per_node=2,
+            backend="spark")
+
+        res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
+                          validation_data=val_df,
+                          validation_steps=2,
+                          feature_cols=["feature"],
+                          label_cols=["label"])
+
+        res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
+                          feature_cols=["feature"],
+                          label_cols=["label"])
+
+        res = trainer.evaluate(val_df, batch_size=4, num_steps=25, feature_cols=["feature"],
+                               label_cols=["label"])
+        print("validation result: ", res)
+
+        res = trainer.predict(df, feature_cols=["feature"]).collect()
+        print("predict result: ", res)
         OrcaContext._shard_size = None

     def test_dataframe_different_train_val(self):
@@ -351,32 +340,26 @@ def test_dataframe_different_train_val(self):
             "lr": 0.2
         }

-        try:
-            temp_dir = tempfile.mkdtemp()
+        trainer = Estimator.from_keras(
+            model_creator=model_creator,
+            verbose=True,
+            config=config,
+            workers_per_node=2,
+            backend="spark")

-            trainer = Estimator.from_keras(
-                model_creator=model_creator,
-                verbose=True,
-                config=config,
-                workers_per_node=2,
-                backend="spark",
-                model_dir=temp_dir)
+        res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
+                          validation_data=val_df,
+                          validation_steps=2,
+                          feature_cols=["feature"],
+                          label_cols=["label"])

-            res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
-                              validation_data=val_df,
-                              validation_steps=2,
-                              feature_cols=["feature"],
-                              label_cols=["label"])
-
-            res = trainer.evaluate(val_df, batch_size=4, num_steps=25, feature_cols=["feature"],
-                                   label_cols=["label"])
-            print("validation result: ", res)
+        res = trainer.evaluate(val_df, batch_size=4, num_steps=25, feature_cols=["feature"],
+                               label_cols=["label"])
+        print("validation result: ", res)

-            res = trainer.predict(df, feature_cols=["feature"]).collect()
-            print("predict result: ", res)
-            trainer.shutdown()
-        finally:
-            shutil.rmtree(temp_dir)
+        res = trainer.predict(df, feature_cols=["feature"]).collect()
+        print("predict result: ", res)
+        trainer.shutdown()

     def test_tensorboard(self):
         sc = OrcaContext.get_spark_context()
@@ -399,8 +382,7 @@ def test_tensorboard(self):
             verbose=True,
             config=config,
             workers_per_node=2,
-            backend="spark",
-            model_dir=temp_dir)
+            backend="spark")

         callbacks = [
             tf.keras.callbacks.TensorBoard(log_dir=os.path.join(temp_dir, "train_log"),
@@ -460,8 +442,7 @@ def test_checkpoint_model(self):
             verbose=True,
             config=config,
             workers_per_node=2,
-            backend="spark",
-            model_dir=temp_dir)
+            backend="spark")

         callbacks = [
             tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(temp_dir, "ckpt_{epoch}"),
@@ -517,8 +498,7 @@ def test_save_load_model_h5(self):
             verbose=True,
             config=config,
             workers_per_node=2,
-            backend="spark",
-            model_dir=temp_dir)
+            backend="spark")

         res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
                           feature_cols=["feature"],
@@ -567,8 +547,7 @@ def test_save_load_model_savemodel(self):
             verbose=True,
             config=config,
             workers_per_node=2,
-            backend="spark",
-            model_dir=temp_dir)
+            backend="spark")

         res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
                           feature_cols=["feature"],
