[Orca] Refactor model_dir as optional in pytorch pyspark estimator (#5738)

* refactor model_dir as optional in pytorch estimator

* update ut to test the case without model_dir
sgwhat authored Sep 14, 2022
1 parent 8817715 commit affe548
Showing 3 changed files with 30 additions and 33 deletions.
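With this change the spark backend no longer requires a model_dir. A minimal usage sketch of the new behaviour, based on the updated tests below (the model/optimizer/data creator helpers and the local cluster settings are illustrative, not part of this commit):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from bigdl.orca import init_orca_context, stop_orca_context
from bigdl.orca.learn.pytorch import Estimator


def model_creator(config):
    # Hypothetical model; any nn.Module works here.
    return nn.Sequential(nn.Linear(50, 1), nn.Sigmoid())


def optimizer_creator(model, config):
    return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-2))


def train_data_creator(config, batch_size):
    # Synthetic data; follows the (config, batch_size) creator convention used in the tests.
    x = torch.randn(1000, 50)
    y = (x.sum(dim=1, keepdim=True) > 0).float()
    return DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=True)


init_orca_context(cluster_mode="local", cores=4)

# model_dir is now optional for backend="spark"; the trained state dict is shipped
# back through the collected worker results instead of a shared directory.
estimator = Estimator.from_torch(model=model_creator,
                                 optimizer=optimizer_creator,
                                 loss=nn.BCELoss(),
                                 config={"lr": 1e-2},
                                 workers_per_node=2,
                                 backend="spark")
stats = estimator.fit(train_data_creator, epochs=2, batch_size=32)
estimator.shutdown()
stop_orca_context()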
@@ -114,9 +114,7 @@ def __init__(
             invalidInputError(False,
                               "If a loss_creator is not provided, you must "
                               "provide a custom training operator.")
-        if not model_dir:
-            invalidInputError(False,
-                              "Please specify model directory when using spark backend")
+
         self.model_dir = model_dir
 
         self.model_creator = model_creator
@@ -298,8 +296,12 @@ def transform_func(iter, init_param, param):
         res = self.workerRDD.barrier().mapPartitions(
             lambda iter: transform_func(iter, init_params, params)).collect()
 
-        self.state_dict = PyTorchPySparkEstimator._get_state_dict_from_remote(self.model_dir)
-        worker_stats = res
+        if self.model_dir is not None:
+            self.state_dict = PyTorchPySparkEstimator._get_state_dict_from_remote(self.model_dir)
+            worker_stats = res
+        else:
+            self.state_dict = res[0]
+            worker_stats = res[1]
 
         epoch_stats = list(map(list, zip(*worker_stats)))
         if reduce_results:
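Restated outside the estimator, the driver-side contract this hunk introduces looks roughly like the sketch below (unpack_fit_result and load_from_dir are illustrative names, not library functions):

def unpack_fit_result(res, model_dir, load_from_dir):
    # With a model_dir the state dict is reloaded from shared storage and the whole
    # collected result is worker stats; without one, the collected result carries the
    # state dict at index 0 and the worker stats at index 1.
    if model_dir is not None:
        return load_from_dir(model_dir), res
    return res[0], res[1]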
@@ -84,7 +84,6 @@ def __init__(self,
         self.mode = mode
         self.backend = backend
         self.cluster_info = cluster_info
-        invalidInputError(model_dir, "model_dir cannot be null")
         self.model_dir = model_dir
         self.log_to_driver = log_to_driver
 
@@ -142,9 +141,13 @@ def train_epochs(self, data_creator, epochs=1, batch_size=32, profile=False,
             LogMonitor.stop_log_monitor(self.log_path, self.logger_thread, self.thread_stop)
 
         if self.rank == 0:
-            save_pkl(state_dict, os.path.join(self.model_dir, "state.pkl"))
+            if self.model_dir is not None:
+                save_pkl(state_dict, os.path.join(self.model_dir, "state.pkl"))
 
-        return [stats_list]
+        if self.model_dir is not None:
+            return [stats_list]
+        else:
+            return state_dict, [stats_list]
 
     def validate(self, data_creator, batch_size=32, num_steps=None, profile=False,
                  info=None, wrap_dataloader=None):
@@ -174,7 +174,7 @@ def get_optimizer(model, config):
 
 
 def get_estimator(workers_per_node=1, model_fn=get_model, sync_stats=False,
-                  log_level=logging.INFO, model_dir=None):
+                  log_level=logging.INFO):
     estimator = Estimator.from_torch(model=model_fn,
                                      optimizer=get_optimizer,
                                      loss=nn.BCELoss(),
@@ -183,8 +183,7 @@ def get_estimator(workers_per_node=1, model_fn=get_model, sync_stats=False,
                                      workers_per_node=workers_per_node,
                                      backend="spark",
                                      sync_stats=sync_stats,
-                                     log_level=log_level,
-                                     model_dir=model_dir)
+                                     log_level=log_level)
     return estimator
 
 
@@ -196,7 +195,7 @@ def tearDown(self) -> None:
         shutil.rmtree(self.model_dir)
 
     def test_data_creator_convergence(self):
-        estimator = get_estimator(workers_per_node=2, model_dir=self.model_dir)
+        estimator = get_estimator(workers_per_node=2)
         start_val_stats = estimator.evaluate(val_data_loader, batch_size=64)
         print(start_val_stats)
         train_stats = estimator.fit(train_data_loader, epochs=4, batch_size=128,
@@ -218,7 +217,7 @@ def test_data_creator_convergence(self):
     def test_spark_xshards(self):
         from bigdl.dllib.nncontext import init_nncontext
         from bigdl.orca.data import SparkXShards
-        estimator = get_estimator(workers_per_node=1, model_dir=self.model_dir)
+        estimator = get_estimator(workers_per_node=1)
         sc = init_nncontext()
         x_rdd = sc.parallelize(np.random.rand(4000, 1, 50).astype(np.float32))
         # torch 1.7.1+ requires target size same as output size, which is (batch, 1)
@@ -248,7 +247,7 @@ def test_dataframe_train_eval(self):
 
         df = spark.createDataFrame(data=data, schema=schema)
 
-        estimator = get_estimator(workers_per_node=2, model_dir=self.model_dir)
+        estimator = get_estimator(workers_per_node=2)
         estimator.fit(df, batch_size=4, epochs=2,
                       validation_data=df,
                       feature_cols=["feature"],
@@ -273,7 +272,7 @@ def test_dataframe_shard_size_train_eval(self):
         ])
         df = spark.createDataFrame(data=data, schema=schema)
 
-        estimator = get_estimator(workers_per_node=2, model_dir=self.model_dir)
+        estimator = get_estimator(workers_per_node=2)
         estimator.fit(df, batch_size=4, epochs=2,
                       feature_cols=["feature"],
                       label_cols=["label"])
@@ -295,7 +294,7 @@ def test_partition_num_less_than_workers(self):
 
         df = spark.createDataFrame(data=data, schema=schema)
 
-        estimator = get_estimator(workers_per_node=2, model_dir=self.model_dir)
+        estimator = get_estimator(workers_per_node=2)
         assert df.rdd.getNumPartitions() < estimator.num_workers
 
         estimator.fit(df, batch_size=4, epochs=2,
@@ -316,8 +315,7 @@ def test_dataframe_predict(self):
         ).toDF(["feature", "label"])
 
         estimator = get_estimator(workers_per_node=2,
-                                  model_fn=lambda config: IdentityNet(),
-                                  model_dir=self.model_dir)
+                                  model_fn=lambda config: IdentityNet())
         result = estimator.predict(df, batch_size=4,
                                    feature_cols=["feature"])
         expr = "sum(cast(feature <> to_array(prediction) as int)) as error"
@@ -331,8 +329,7 @@ def test_xshards_predict_save_load(self):
         shards = SparkXShards(shards)
 
         estimator = get_estimator(workers_per_node=2,
-                                  model_fn=lambda config: IdentityNet(),
-                                  model_dir=self.model_dir)
+                                  model_fn=lambda config: IdentityNet())
         result_shards = estimator.predict(shards, batch_size=4)
         result_before = np.concatenate([shard["prediction"] for shard in result_shards.collect()])
         expected_result = np.concatenate([shard["x"] for shard in result_shards.collect()])
@@ -388,8 +385,7 @@ def test_multiple_inputs_model(self):
         df = spark.createDataFrame(data=data, schema=schema)
 
         estimator = get_estimator(workers_per_node=2,
-                                  model_fn=lambda config: MultiInputNet(),
-                                  model_dir=self.model_dir)
+                                  model_fn=lambda config: MultiInputNet())
         estimator.fit(df, batch_size=4, epochs=2,
                       validation_data=df,
                       feature_cols=["f1", "f2"],
@@ -438,8 +434,7 @@ def get_optimizer(model, config):
                                          config={},
                                          workers_per_node=2,
                                          backend="spark",
-                                         sync_stats=False,
-                                         model_dir=self.model_dir)
+                                         sync_stats=False)
 
         stats = estimator.fit(df, batch_size=4, epochs=2,
                               validation_data=df,
@@ -466,8 +461,7 @@ def test_checkpoint_callback(self):
         df = spark.createDataFrame(data=data, schema=schema)
         df = df.cache()
 
-        estimator = get_estimator(workers_per_node=2, model_dir=self.model_dir,
-                                  log_level=logging.DEBUG)
+        estimator = get_estimator(workers_per_node=2, log_level=logging.DEBUG)
 
         callbacks = [
             ModelCheckpoint(filepath=os.path.join(self.model_dir, "test-{epoch}"),
@@ -487,8 +481,8 @@ def test_checkpoint_callback(self):
         latest_checkpoint_path = Estimator.latest_checkpoint(self.model_dir)
         assert os.path.isfile(latest_checkpoint_path)
         estimator.shutdown()
-        new_estimator = get_estimator(workers_per_node=2, model_dir=self.model_dir,
-                                      log_level=logging.DEBUG)
+
+        new_estimator = get_estimator(workers_per_node=2, log_level=logging.DEBUG)
         new_estimator.load_checkpoint(latest_checkpoint_path)
         eval_after = new_estimator.evaluate(df, batch_size=4,
                                             feature_cols=["feature"],
@@ -514,8 +508,7 @@ def test_manual_ckpt(self):
         df = spark.createDataFrame(data=data, schema=schema)
         df = df.cache()
 
-        estimator = get_estimator(workers_per_node=2, model_dir=self.model_dir,
-                                  log_level=logging.DEBUG)
+        estimator = get_estimator(workers_per_node=2, log_level=logging.DEBUG)
         estimator.fit(df, batch_size=4, epochs=epochs,
                       feature_cols=["feature"],
                       label_cols=["label"])
@@ -528,8 +521,7 @@ def test_manual_ckpt(self):
         ckpt_file = os.path.join(temp_dir, "manual.ckpt")
         estimator.save_checkpoint(ckpt_file)
         estimator.shutdown()
-        new_estimator = get_estimator(workers_per_node=2, model_dir=self.model_dir,
-                                      log_level=logging.DEBUG)
+        new_estimator = get_estimator(workers_per_node=2, log_level=logging.DEBUG)
         new_estimator.load_checkpoint(ckpt_file)
         eval_after = new_estimator.evaluate(df, batch_size=4,
                                             feature_cols=["feature"],
@@ -540,7 +532,7 @@ def test_manual_ckpt(self):
         shutil.rmtree(temp_dir)
 
     def test_custom_callback(self):
-        estimator = get_estimator(workers_per_node=2, model_dir=self.model_dir)
+        estimator = get_estimator(workers_per_node=2)
         callbacks = [CustomCallback()]
         estimator.fit(train_data_loader, epochs=4, batch_size=128,
                       validation_data=val_data_loader, callbacks=callbacks)
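Manual checkpointing keeps working without a model_dir; a sketch of the round trip mirrored from test_manual_ckpt above (df and get_estimator are the test fixtures, the temp path is illustrative):

import os
import tempfile

temp_dir = tempfile.mkdtemp()
ckpt_file = os.path.join(temp_dir, "manual.ckpt")

estimator = get_estimator(workers_per_node=2)
estimator.fit(df, batch_size=4, epochs=2,
              feature_cols=["feature"], label_cols=["label"])
estimator.save_checkpoint(ckpt_file)   # explicit save, independent of model_dir
estimator.shutdown()

new_estimator = get_estimator(workers_per_node=2)
new_estimator.load_checkpoint(ckpt_file)
eval_after = new_estimator.evaluate(df, batch_size=4,
                                    feature_cols=["feature"], label_cols=["label"])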
