[Orca] Refactor tf2 pyspark estimator save model #5425

Merged (6 commits) on Aug 24, 2022

14 changes: 7 additions & 7 deletions python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
@@ -28,11 +28,12 @@

 from bigdl.dllib.utils.common import get_node_and_core_number
 from bigdl.dllib.utils.file_utils import enable_multi_fs_load, enable_multi_fs_save, \
-    get_remote_file_to_local, is_local_path, get_remote_files_with_prefix_to_local, \
-    append_suffix, put_local_file_to_remote, put_local_files_with_prefix_to_remote
+    is_local_path, append_suffix
 from bigdl.dllib.utils.utils import get_node_ip

-from bigdl.orca.data.file import is_file, exists
+from bigdl.orca.data.file import is_file, exists, get_remote_file_to_local, \
+    get_remote_files_with_prefix_to_local, put_local_file_to_remote, \
+    put_local_files_with_prefix_to_remote
 from bigdl.orca.learn.tf2.spark_runner import SparkRunner
 from bigdl.orca.learn.utils import find_free_port, find_ip_and_free_port
 from bigdl.orca.learn.utils import maybe_dataframe_to_xshards, dataframe_to_xshards, \
@@ -217,8 +218,7 @@ def transform_func(iter, init_param, param):
             try:
                 temp_dir = tempfile.mkdtemp()
                 get_remote_file_to_local(os.path.join(self.model_dir, "state.pkl"),
-                                         os.path.join(temp_dir, "state.pkl"),
-                                         over_write=True)
+                                         os.path.join(temp_dir, "state.pkl"))
                 import pickle
                 with open(os.path.join(temp_dir, "state.pkl"), 'rb') as f:
                     state = pickle.load(f)
@@ -492,7 +492,7 @@ def save(self,
         if exists(self._model_saved_path):
             model = load_model(self._model_saved_path)
         else:
-            model = self.model_creator(self.config)
+            model = self.get_model()
         # save model
         save_model(model, filepath, overwrite=overwrite, include_optimizer=include_optimizer,
                    save_format=save_format, signatures=signatures, options=options)
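
For context on the save() change above: when nothing has been written to self._model_saved_path yet, save() now obtains the model through the estimator's get_model() helper instead of calling self.model_creator(self.config) directly. Below is a minimal sketch of the save/load round trip this code path supports, mirroring the new tests in the test-file diff that follows; the model_creator body, the local Orca context setup, and the import paths are illustrative assumptions rather than part of this PR.

import os
import tempfile

import tensorflow as tf
from bigdl.orca import init_orca_context
from bigdl.orca.learn.tf2 import Estimator


def model_creator(config):
    # A tiny single-layer Keras model, just enough to exercise save()/load().
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
    model.compile(optimizer=tf.keras.optimizers.SGD(config["lr"]), loss="mse")
    return model


init_orca_context(cluster_mode="local")
temp_dir = tempfile.mkdtemp()

est = Estimator.from_keras(model_creator=model_creator,
                           config={"lr": 0.2},
                           workers_per_node=2,
                           backend="spark",
                           model_dir=temp_dir)

# Nothing has been trained or persisted yet, so save() falls back to
# get_model() (the changed branch above) for the model instance to serialize.
est.save(os.path.join(temp_dir, "model.h5"))      # Keras HDF5 file
est.save(os.path.join(temp_dir, "saved_model"))   # TensorFlow SavedModel directory

# Either artifact can be loaded back into the same estimator later.
est.load(os.path.join(temp_dir, "saved_model"))
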
@@ -495,7 +495,7 @@ def test_checkpoint_model(self):
         finally:
             shutil.rmtree(temp_dir)

-    def test_save_load_model(self):
+    def test_save_load_model_h5(self):
         sc = OrcaContext.get_spark_context()
         rdd = sc.range(0, 100)
         spark = OrcaContext.get_spark_session()
@@ -527,22 +527,71 @@ def test_save_load_model(self):

print("start saving")
trainer.save(os.path.join(temp_dir, "a.h5"))

res = trainer.evaluate(df, batch_size=4, num_steps=25, feature_cols=["feature"],
label_cols=["label"])
print("validation result: ", res)

before_res = trainer.predict(df, feature_cols=["feature"]).collect()
expect_res = np.concatenate([part["prediction"] for part in before_res])

trainer.load(os.path.join(temp_dir, "a.h5"))
trainer.save(os.path.join(temp_dir, "saved_model"))
trainer.load(os.path.join(temp_dir, "saved_model"))
# continous training
res = trainer.fit(df, epochs=10, batch_size=4, steps_per_epoch=25,

# continous predicting
after_res = trainer.predict(df, feature_cols=["feature"]).collect()
pred_res = np.concatenate([part["prediction"] for part in after_res])

assert np.array_equal(expect_res, pred_res)
finally:
shutil.rmtree(temp_dir)

def test_save_load_model_savemodel(self):
sc = OrcaContext.get_spark_context()
rdd = sc.range(0, 100)
spark = OrcaContext.get_spark_session()

from pyspark.ml.linalg import DenseVector
df = rdd.map(lambda x: (DenseVector(np.random.randn(1, ).astype(np.float)),
int(np.random.randint(0, 2, size=())))).toDF(["feature", "label"])

config = {
"lr": 0.2
}

try:
temp_dir = tempfile.mkdtemp()

trainer = Estimator.from_keras(
model_creator=model_creator,
verbose=True,
config=config,
workers_per_node=2,
backend="spark",
model_dir=temp_dir)

res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
feature_cols=["feature"],
label_cols=["label"],
validation_data=df,
validation_steps=1,
initial_epoch=5)
validation_steps=1)

print("start saving")
trainer.save(os.path.join(temp_dir, "saved_model"))

res = trainer.evaluate(df, batch_size=4, num_steps=25, feature_cols=["feature"],
label_cols=["label"])
print("validation result: ", res)

res = trainer.predict(df, feature_cols=["feature"]).collect()
print("predict result: ", res)
before_res = trainer.predict(df, feature_cols=["feature"]).collect()
expect_res = np.concatenate([part["prediction"] for part in before_res])

trainer.load(os.path.join(temp_dir, "saved_model"))

# continous predicting
after_res = trainer.predict(df, feature_cols=["feature"]).collect()
pred_res = np.concatenate([part["prediction"] for part in after_res])

assert np.array_equal(expect_res, pred_res)
finally:
shutil.rmtree(temp_dir)
