[Orca] Refactor model_dir as an option in tf2 pyspark estimator #5541

Merged
13 commits merged on Sep 1, 2022
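Summary of the change: with the spark backend, model_dir is no longer a required argument of Estimator.from_keras. When it is omitted, the trained weights are shipped back to the driver together with the training stats instead of being staged through a remote directory. A minimal usage sketch (the toy model below and the package-level Estimator import are illustrative assumptions, not part of this PR; df stands for a Spark DataFrame with "feature"/"label" columns as in the tests further down):

import tensorflow as tf
from bigdl.orca.learn.tf2 import Estimator  # assumed import path

def model_creator(config):
    # toy compiled model; any compiled tf.keras model factory works here
    model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(1,)),
                                 tf.keras.layers.Dense(1)])
    model.compile(optimizer=tf.keras.optimizers.SGD(config["lr"]), loss="mse")
    return model

# model_dir is now optional for backend="spark"
trainer = Estimator.from_keras(model_creator=model_creator,
                               config={"lr": 0.2},
                               workers_per_node=2,
                               backend="spark")

stats = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
                    feature_cols=["feature"], label_cols=["label"])

# a shared model_dir can still be supplied to stage weights through remote storage:
# trainer = Estimator.from_keras(..., backend="spark", model_dir="hdfs://...")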
3 changes: 0 additions & 3 deletions python/orca/src/bigdl/orca/learn/tf2/estimator.py
@@ -72,9 +72,6 @@ def from_keras(*,
if cpu_binding:
invalidInputError(False,
"cpu_binding should not be True when using spark backend")
if not model_dir:
invalidInputError(False,
"Please specify model directory when using spark backend")
from bigdl.orca.learn.tf2.pyspark_estimator import SparkTFEstimator
return SparkTFEstimator(model_creator=model_creator,
config=config, verbose=verbose,
13 changes: 9 additions & 4 deletions python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
@@ -214,7 +214,8 @@ def transform_func(iter, init_param, param):
res = self.workerRDD.barrier().mapPartitions(
lambda iter: transform_func(iter, init_params, params)).collect()

if self.model_dir:
if self.model_dir is not None:
result = res
try:
temp_dir = tempfile.mkdtemp()
get_remote_file_to_local(os.path.join(self.model_dir, "state.pkl"),
@@ -225,8 +226,11 @@ def transform_func(iter, init_param, param):
self.model_weights = state['weights']
finally:
shutil.rmtree(temp_dir)
else:
result = res[0]
self.model_weights = res[1]

return res[0]
return result[0]

def evaluate(self, data, batch_size=32, num_steps=None, verbose=1,
sample_weight=None, callbacks=None, data_config=None,
@@ -489,7 +493,7 @@ def save(self,
saving to SavedModel.
"""
# get current model
if exists(self._model_saved_path):
if self.model_dir is not None and exists(self._model_saved_path):
model = load_model(self._model_saved_path)
else:
model = self.get_model()
@@ -513,7 +517,8 @@ def load(self, filepath, custom_objects=None, compile=True):
model = load_model(filepath, custom_objects=custom_objects, compile=compile)
self.model_weights = model.get_weights()
# update remote model
save_model(model, self._model_saved_path, save_format="h5", filemode=0o666)
if self.model_dir is not None:
save_model(model, self._model_saved_path, save_format="h5", filemode=0o666)

def get_model(self):
"""
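The net effect of the save/load edits above: the remote copy at self._model_saved_path is only read or refreshed when a model_dir was supplied, so a plain local save/load round trip keeps working without one. A short sketch under the same assumptions as the example near the top (the /tmp path is hypothetical):

# after trainer.fit(...) with backend="spark" and no model_dir
trainer.save("/tmp/keras_model.h5")   # no remote copy exists, so save() rebuilds the model from get_model()
trainer.load("/tmp/keras_model.h5")   # restores the driver-held weights; nothing is written to a remote dir
model = trainer.get_model()           # a regular tf.keras model carrying the trained weights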
12 changes: 8 additions & 4 deletions python/orca/src/bigdl/orca/learn/tf2/spark_runner.py
@@ -271,7 +271,7 @@ def distributed_train_func(self, data_creator, config, epochs=1, verbose=1,
runs a training epoch and updates the model parameters
"""
with self.strategy.scope():
if exists(self._model_saved_path):
if self.model_dir is not None and exists(self._model_saved_path):
# for continuous training
model = load_model(self._model_saved_path)
else:
@@ -336,7 +336,6 @@ def step(self, data_creator, epochs=1, batch_size=32, verbose=1,
validation_steps=validation_steps,
validation_freq=validation_freq
)
weights = model.get_weights()
if history is None:
stats = {}
else:
@@ -345,14 +344,19 @@ def step(self, data_creator, epochs=1, batch_size=32, verbose=1,
if self.model_dir is not None:
save_model(model, self._model_saved_path, save_format="h5")
model_state = {
"weights": weights,
"weights": model.get_weights(),
"optimizer_weights": model.optimizer.get_weights()
}
save_pkl(model_state, os.path.join(self.model_dir, "state.pkl"))
else:
weights = model.get_weights()

if self.need_to_log_to_driver:
LogMonitor.stop_log_monitor(self.log_path, self.logger_thread, self.thread_stop)
return [stats]
if self.model_dir is not None:
return [stats]
else:
return [stats], weights
else:
temp_dir = tempfile.mkdtemp()
try:
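Taken together with the fit() change in pyspark_estimator.py above, the worker-to-driver contract is now: when a model_dir is set, the worker persists {"weights", "optimizer_weights"} to <model_dir>/state.pkl and only [stats] comes back through the barrier RDD; when it is not, the worker returns [stats], weights and the driver unpacks both from the collected result. A condensed, simplified sketch of the two sides (names follow the diff; this is not new code):

# worker side (spark_runner.py step(), simplified)
if self.model_dir is not None:
    save_model(model, self._model_saved_path, save_format="h5")
    save_pkl({"weights": model.get_weights(),
              "optimizer_weights": model.optimizer.get_weights()},
             os.path.join(self.model_dir, "state.pkl"))
    return [stats]
else:
    return [stats], model.get_weights()

# driver side (pyspark_estimator.py fit(), simplified)
res = self.workerRDD.barrier().mapPartitions(
    lambda it: transform_func(it, init_params, params)).collect()
if self.model_dir is not None:
    result = res                                   # weights are then read from <model_dir>/state.pkl
else:
    result, self.model_weights = res[0], res[1]    # weights came back with the collected result
return result[0]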
119 changes: 49 additions & 70 deletions python/orca/test/bigdl/orca/learn/ray/tf/test_tf_spark_estimator.py
@@ -80,8 +80,7 @@ def test_dataframe(self):
verbose=True,
config=config,
workers_per_node=2,
backend="spark",
model_dir=temp_dir)
backend="spark")

res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
feature_cols=["feature"],
@@ -126,8 +125,7 @@ def test_dataframe_with_empty_partition(self):
verbose=True,
config=config,
workers_per_node=3,
backend="spark",
model_dir=temp_dir)
backend="spark")

res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
feature_cols=["feature"],
@@ -178,8 +176,7 @@ def model_creator(config):
verbose=True,
config=config,
workers_per_node=1,
backend="spark",
model_dir=temp_dir)
backend="spark")

res = trainer.fit(data=xshards, epochs=5, batch_size=4, steps_per_epoch=25,
feature_cols=["user", "item"], label_cols=["label"])
@@ -214,8 +211,7 @@ def test_checkpoint_weights(self):
verbose=True,
config=config,
workers_per_node=2,
backend="spark",
model_dir=temp_dir)
backend="spark")

callbacks = [
tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(temp_dir, "ckpt_{epoch}"),
@@ -260,8 +256,7 @@ def test_checkpoint_weights_h5(self):
verbose=True,
config=config,
workers_per_node=2,
backend="spark",
model_dir=temp_dir)
backend="spark")

callbacks = [
tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(temp_dir, "ckpt_weights.h5"),
@@ -303,35 +298,29 @@ def test_dataframe_shard_size(self):
"lr": 0.2
}

try:
temp_dir = tempfile.mkdtemp()

trainer = Estimator.from_keras(
model_creator=model_creator,
verbose=True,
config=config,
workers_per_node=2,
backend="spark",
model_dir=temp_dir)

res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
validation_data=val_df,
validation_steps=2,
feature_cols=["feature"],
label_cols=["label"])

res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
feature_cols=["feature"],
label_cols=["label"])

res = trainer.evaluate(val_df, batch_size=4, num_steps=25, feature_cols=["feature"],
label_cols=["label"])
print("validation result: ", res)

res = trainer.predict(df, feature_cols=["feature"]).collect()
print("predict result: ", res)
finally:
shutil.rmtree(temp_dir)
trainer = Estimator.from_keras(
model_creator=model_creator,
verbose=True,
config=config,
workers_per_node=2,
backend="spark")

res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
validation_data=val_df,
validation_steps=2,
feature_cols=["feature"],
label_cols=["label"])

res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
feature_cols=["feature"],
label_cols=["label"])

res = trainer.evaluate(val_df, batch_size=4, num_steps=25, feature_cols=["feature"],
label_cols=["label"])
print("validation result: ", res)

res = trainer.predict(df, feature_cols=["feature"]).collect()
print("predict result: ", res)
OrcaContext._shard_size = None

def test_dataframe_different_train_val(self):
Expand All @@ -351,32 +340,26 @@ def test_dataframe_different_train_val(self):
"lr": 0.2
}

try:
temp_dir = tempfile.mkdtemp()
trainer = Estimator.from_keras(
model_creator=model_creator,
verbose=True,
config=config,
workers_per_node=2,
backend="spark")

trainer = Estimator.from_keras(
model_creator=model_creator,
verbose=True,
config=config,
workers_per_node=2,
backend="spark",
model_dir=temp_dir)
res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
validation_data=val_df,
validation_steps=2,
feature_cols=["feature"],
label_cols=["label"])

res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
validation_data=val_df,
validation_steps=2,
feature_cols=["feature"],
label_cols=["label"])

res = trainer.evaluate(val_df, batch_size=4, num_steps=25, feature_cols=["feature"],
label_cols=["label"])
print("validation result: ", res)
res = trainer.evaluate(val_df, batch_size=4, num_steps=25, feature_cols=["feature"],
label_cols=["label"])
print("validation result: ", res)

res = trainer.predict(df, feature_cols=["feature"]).collect()
print("predict result: ", res)
trainer.shutdown()
finally:
shutil.rmtree(temp_dir)
res = trainer.predict(df, feature_cols=["feature"]).collect()
print("predict result: ", res)
trainer.shutdown()

def test_tensorboard(self):
sc = OrcaContext.get_spark_context()
@@ -399,8 +382,7 @@ def test_tensorboard(self):
verbose=True,
config=config,
workers_per_node=2,
backend="spark",
model_dir=temp_dir)
backend="spark")

callbacks = [
tf.keras.callbacks.TensorBoard(log_dir=os.path.join(temp_dir, "train_log"),
Expand Down Expand Up @@ -460,8 +442,7 @@ def test_checkpoint_model(self):
verbose=True,
config=config,
workers_per_node=2,
backend="spark",
model_dir=temp_dir)
backend="spark")

callbacks = [
tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(temp_dir, "ckpt_{epoch}"),
@@ -517,8 +498,7 @@ def test_save_load_model_h5(self):
verbose=True,
config=config,
workers_per_node=2,
backend="spark",
model_dir=temp_dir)
backend="spark")

res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
feature_cols=["feature"],
@@ -567,8 +547,7 @@ def test_save_load_model_savemodel(self):
verbose=True,
config=config,
workers_per_node=2,
backend="spark",
model_dir=temp_dir)
backend="spark")

res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
feature_cols=["feature"],