expose workers_per_node in pytorch estimator #2763

Merged · 2 commits · Aug 25, 2020
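This PR threads a new workers_per_node argument through Estimator.from_torch for the horovod backend, and fixes model retrieval so that get_model() resolves on the wrapper directly. A minimal usage sketch based on the diffs below (the creator functions are placeholders as in the example script, and the import path is assumed from this PR's file layout):

from zoo.orca.learn.pytorch.estimator import Estimator
import torch.nn as nn

# model_creator, optimizer_creator and train_data_creator are user-defined
# functions, as in the example script changed below; workers_per_node is the
# parameter this PR exposes.
estimator = Estimator.from_torch(model=model_creator,
                                 optimizer=optimizer_creator,
                                 loss=nn.MSELoss,
                                 workers_per_node=2,
                                 backend="horovod")
stats = estimator.fit(train_data_creator, epochs=5)
model = estimator.get_model()  # no longer estimator.estimator.get_model()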
1 change: 0 additions & 1 deletion pyzoo/zoo/automl/regression/xgbregressor_predictor.py
@@ -21,7 +21,6 @@
 import zipfile
 import os
 import shutil
-import ray
 
 from zoo.automl.search.abstract import *
 from zoo.automl.search.RayTuneSearchEngine import RayTuneSearchEngine
13 changes: 9 additions & 4 deletions pyzoo/zoo/examples/orca/learn/horovod/pytorch_estimator.py
@@ -77,12 +77,13 @@ def validation_data_creator(config):
     return validation_loader
 
 
-def train_example():
+def train_example(workers_per_node):
     estimator = Estimator.from_torch(
         model=model_creator,
         optimizer=optimizer_creator,
         loss=nn.MSELoss,
         scheduler_creator=scheduler_creator,
+        workers_per_node=workers_per_node,
         config={
             "lr": 1e-2,  # used in optimizer_creator
             "hidden_size": 1,  # used in model_creator
@@ -96,7 +97,7 @@ def train_example():
     print("validation stats: {}".format(val_stats))
 
     # retrieve the model
-    model = estimator.estimator.get_model()
+    model = estimator.get_model()
     print("trained weight: % .2f, bias: % .2f" % (
         model.weight.item(), model.bias.item()))
 
@@ -128,6 +129,10 @@ def train_example():
     parser.add_argument("--object_store_memory", type=str, default="4g",
                         help="The memory to store data on local."
                              "You can change it depending on your own cluster setting.")
+    parser.add_argument("--workers_per_node", type=int, default=1,
+                        help="The number of workers to run on each node")
+    parser.add_argument("--local_cores", type=int, default=4,
+                        help="The number of cores while running on local mode")
 
     args = parser.parse_args()
     if args.hadoop_conf:
@@ -145,9 +150,9 @@ def train_example():
             object_store_memory=args.object_store_memory)
         ray_ctx.init()
     else:
-        sc = init_spark_on_local()
+        sc = init_spark_on_local(cores=args.local_cores)
         ray_ctx = RayContext(
             sc=sc,
             object_store_memory=args.object_store_memory)
         ray_ctx.init()
-    train_example()
+    train_example(workers_per_node=args.workers_per_node)
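For reference, here is what the updated local-mode path now does with the two new flags, written out as straight-line calls (flag values are illustrative, and the import paths are assumed from the Analytics Zoo codebase):

from zoo import init_spark_on_local
from zoo.ray import RayContext

# Equivalent of: python pytorch_estimator.py --local_cores 8 --workers_per_node 2
sc = init_spark_on_local(cores=8)  # --local_cores now controls the local Spark cores
ray_ctx = RayContext(sc=sc, object_store_memory="4g")
ray_ctx.init()
train_example(workers_per_node=2)  # --workers_per_node is forwarded to Estimator.from_torch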
10 changes: 7 additions & 3 deletions pyzoo/zoo/orca/learn/pytorch/estimator.py
@@ -55,6 +55,7 @@ def from_torch(*,
                config=None,
                scheduler_step_freq="batch",
                use_tqdm=False,
+               workers_per_node=1,
                backend="horovod"):
     if backend == "horovod":
         return PyTorchHorovodEstimatorWrapper(model_creator=model,
@@ -65,7 +66,8 @@ def from_torch(*,
                                               initialization_hook=initialization_hook,
                                               config=config,
                                               scheduler_step_freq=scheduler_step_freq,
-                                              use_tqdm=use_tqdm)
+                                              use_tqdm=use_tqdm,
+                                              workers_per_node=workers_per_node)
     elif backend == "bigdl":
         return PytorchSparkEstimatorWrapper(model=model,
                                             loss=loss,
@@ -87,7 +89,8 @@ def __init__(self,
                  initialization_hook=None,
                  config=None,
                  scheduler_step_freq="batch",
-                 use_tqdm=False):
+                 use_tqdm=False,
+                 workers_per_node=1):
         from zoo.orca.learn.pytorch.pytorch_horovod_estimator import PyTorchHorovodEstimator
         self.estimator = PyTorchHorovodEstimator(model_creator=model_creator,
                                                  optimizer_creator=optimizer_creator,
@@ -97,7 +100,8 @@ def __init__(self,
                                                  initialization_hook=initialization_hook,
                                                  config=config,
                                                  scheduler_step_freq=scheduler_step_freq,
-                                                 use_tqdm=use_tqdm)
+                                                 use_tqdm=use_tqdm,
+                                                 workers_per_node=workers_per_node)
 
     def fit(self, data, epochs=1, profile=False, reduce_results=True, info=None):
         """