[Orca] Refactor tf2 pyspark estimator save model (intel-analytics#5425)
* fix save model in pyspark estimator

* use a function (get_model) to replace the direct model_creator call

* use orca.data file utilities to replace the dllib ones
sgwhat authored and ForJadeForest committed Sep 20, 2022
1 parent 5ed94c3 commit 43f276a
Showing 2 changed files with 65 additions and 16 deletions.
14 changes: 7 additions & 7 deletions python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
@@ -28,11 +28,12 @@
 from bigdl.dllib.utils.common import get_node_and_core_number
 from bigdl.dllib.utils.file_utils import enable_multi_fs_load, enable_multi_fs_save, \
-    get_remote_file_to_local, is_local_path, get_remote_files_with_prefix_to_local, \
-    append_suffix, put_local_file_to_remote, put_local_files_with_prefix_to_remote
-
+    is_local_path, append_suffix
 from bigdl.dllib.utils.utils import get_node_ip
-from bigdl.orca.data.file import is_file, exists
-
+from bigdl.orca.data.file import is_file, exists, get_remote_file_to_local, \
+    get_remote_files_with_prefix_to_local, put_local_file_to_remote, \
+    put_local_files_with_prefix_to_remote
 from bigdl.orca.learn.tf2.spark_runner import SparkRunner
 from bigdl.orca.learn.utils import find_free_port, find_ip_and_free_port
 from bigdl.orca.learn.utils import maybe_dataframe_to_xshards, dataframe_to_xshards, \
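The relocated helpers are plain copy utilities between the local filesystem and remote storage such as HDFS. A minimal sketch of the round trip, assuming a reachable remote directory; the paths are hypothetical and the argument order follows the calls visible in this diff rather than a verified API reference:

    import os
    import tempfile

    from bigdl.orca.data.file import exists, get_remote_file_to_local, \
        put_local_file_to_remote

    model_dir = "hdfs://namenode:8020/tmp/orca_model"  # hypothetical remote dir
    local_state = os.path.join(tempfile.mkdtemp(), "state.pkl")

    # Fetch one remote file to the local filesystem; note that the refactored
    # call in the next hunk drops the dllib-era over_write=True argument.
    if exists(os.path.join(model_dir, "state.pkl")):
        get_remote_file_to_local(os.path.join(model_dir, "state.pkl"), local_state)

    # ... modify local_state ...

    # Push the local copy back to remote storage.
    put_local_file_to_remote(local_state, os.path.join(model_dir, "state.pkl"))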
@@ -217,8 +218,7 @@ def transform_func(iter, init_param, param):
         try:
             temp_dir = tempfile.mkdtemp()
             get_remote_file_to_local(os.path.join(self.model_dir, "state.pkl"),
-                                     os.path.join(temp_dir, "state.pkl"),
-                                     over_write=True)
+                                     os.path.join(temp_dir, "state.pkl"))
             import pickle
             with open(os.path.join(temp_dir, "state.pkl"), 'rb') as f:
                 state = pickle.load(f)
@@ -492,7 +492,7 @@ def save(self,
         if exists(self._model_saved_path):
             model = load_model(self._model_saved_path)
         else:
-            model = self.model_creator(self.config)
+            model = self.get_model()
         # save model
         save_model(model, filepath, overwrite=overwrite, include_optimizer=include_optimizer,
                    save_format=save_format, signatures=signatures, options=options)
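save() now obtains its fallback model through a single accessor instead of invoking model_creator inline, so every code path that needs the current model goes through one place. The accessor itself is not part of this diff; a minimal hypothetical sketch consistent with the branch it replaces:

    # Hypothetical sketch only; the committed get_model() is not shown here.
    def get_model(self):
        # Build a fresh model from the user-supplied creator function and the
        # stored config, matching what the replaced line did inline.
        return self.model_creator(self.config)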
67 changes: 58 additions & 9 deletions (estimator test file; path not shown)
@@ -495,7 +495,7 @@ def test_checkpoint_model(self):
         finally:
             shutil.rmtree(temp_dir)
 
-    def test_save_load_model(self):
+    def test_save_load_model_h5(self):
         sc = OrcaContext.get_spark_context()
         rdd = sc.range(0, 100)
         spark = OrcaContext.get_spark_session()
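The rename pins this test to the HDF5 path ("a.h5"), while the new test added in the next hunk covers the SavedModel path separately. The split mirrors the standard Keras convention that the save_model call in the first file forwards through its save_format argument: the filepath extension (or an explicit save_format) selects the on-disk format. A standalone illustration of that convention in plain TensorFlow 2:

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
    model.compile(optimizer="sgd", loss="mse")

    model.save("a.h5")          # single HDF5 file, inferred from the .h5 suffix
    model.save("saved_model")   # SavedModel directory, the TF2 default format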
@@ -527,22 +527,71 @@ def test_save_load_model(self):
 
             print("start saving")
             trainer.save(os.path.join(temp_dir, "a.h5"))
             res = trainer.evaluate(df, batch_size=4, num_steps=25, feature_cols=["feature"],
                                    label_cols=["label"])
             print("validation result: ", res)
 
+            before_res = trainer.predict(df, feature_cols=["feature"]).collect()
+            expect_res = np.concatenate([part["prediction"] for part in before_res])
+
             trainer.load(os.path.join(temp_dir, "a.h5"))
-            trainer.save(os.path.join(temp_dir, "saved_model"))
-            trainer.load(os.path.join(temp_dir, "saved_model"))
-            # continous training
-            res = trainer.fit(df, epochs=10, batch_size=4, steps_per_epoch=25,
-                              feature_cols=["feature"],
-                              label_cols=["label"],
-                              validation_data=df,
-                              validation_steps=1,
-                              initial_epoch=5)
-
-            res = trainer.predict(df, feature_cols=["feature"]).collect()
-            print("predict result: ", res)
+
+            # continous predicting
+            after_res = trainer.predict(df, feature_cols=["feature"]).collect()
+            pred_res = np.concatenate([part["prediction"] for part in after_res])
+
+            assert np.array_equal(expect_res, pred_res)
         finally:
             shutil.rmtree(temp_dir)
 
+    def test_save_load_model_savemodel(self):
+        sc = OrcaContext.get_spark_context()
+        rdd = sc.range(0, 100)
+        spark = OrcaContext.get_spark_session()
+
+        from pyspark.ml.linalg import DenseVector
+        df = rdd.map(lambda x: (DenseVector(np.random.randn(1, ).astype(np.float)),
+                                int(np.random.randint(0, 2, size=())))).toDF(["feature", "label"])
+
+        config = {
+            "lr": 0.2
+        }
+
+        try:
+            temp_dir = tempfile.mkdtemp()
+
+            trainer = Estimator.from_keras(
+                model_creator=model_creator,
+                verbose=True,
+                config=config,
+                workers_per_node=2,
+                backend="spark",
+                model_dir=temp_dir)
+
+            res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
+                              feature_cols=["feature"],
+                              label_cols=["label"],
+                              validation_data=df,
+                              validation_steps=1)
+
+            print("start saving")
+            trainer.save(os.path.join(temp_dir, "saved_model"))
+
+            res = trainer.evaluate(df, batch_size=4, num_steps=25, feature_cols=["feature"],
+                                   label_cols=["label"])
+            print("validation result: ", res)
+
+            before_res = trainer.predict(df, feature_cols=["feature"]).collect()
+            expect_res = np.concatenate([part["prediction"] for part in before_res])
+
+            trainer.load(os.path.join(temp_dir, "saved_model"))
+
+            # continous predicting
+            after_res = trainer.predict(df, feature_cols=["feature"]).collect()
+            pred_res = np.concatenate([part["prediction"] for part in after_res])
+
+            assert np.array_equal(expect_res, pred_res)
+        finally:
+            shutil.rmtree(temp_dir)
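Both tests check the same invariant the refactor is meant to guarantee: predictions taken before save() and after load() agree exactly. A condensed, hypothetical standalone version of that round trip, assuming a local Orca context (model_creator here is a placeholder for any function returning a compiled tf.keras model):

    import os
    import tempfile

    import numpy as np
    from bigdl.orca import init_orca_context, stop_orca_context, OrcaContext
    from bigdl.orca.learn.tf2 import Estimator
    from pyspark.ml.linalg import DenseVector

    def model_creator(config):
        import tensorflow as tf
        model = tf.keras.Sequential(
            [tf.keras.layers.Dense(1, input_shape=(1,), activation="sigmoid")])
        model.compile(optimizer=tf.keras.optimizers.SGD(config["lr"]),
                      loss="binary_crossentropy")
        return model

    sc = init_orca_context(cluster_mode="local", cores=4)
    spark = OrcaContext.get_spark_session()  # ensures a SparkSession for toDF

    df = sc.range(0, 100).map(
        lambda x: (DenseVector(np.random.randn(1, )),
                   int(np.random.randint(0, 2, size=())))).toDF(["feature", "label"])

    temp_dir = tempfile.mkdtemp()
    est = Estimator.from_keras(model_creator=model_creator, config={"lr": 0.2},
                               workers_per_node=2, backend="spark",
                               model_dir=temp_dir)
    est.fit(df, epochs=1, batch_size=4, steps_per_epoch=25,
            feature_cols=["feature"], label_cols=["label"])

    est.save(os.path.join(temp_dir, "saved_model"))  # SavedModel directory
    before = np.concatenate([p["prediction"] for p in
                             est.predict(df, feature_cols=["feature"]).collect()])
    est.load(os.path.join(temp_dir, "saved_model"))
    after = np.concatenate([p["prediction"] for p in
                            est.predict(df, feature_cols=["feature"]).collect()])
    assert np.array_equal(before, after)

    stop_orca_context()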
