Skip to content

Commit

Permalink
expose workers_per_node in pytorch estimator (intel-analytics#2763)
Browse files (browse the repository at this point in the history)
* expose workers_per_node

* remove `import ray` to fix random Jenkins failures
  • Loading branch information
shanyu-sys committed Aug 25, 2020
1 parent 7cac9fa commit 6cac410
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions python/orca/example/learn/horovod/pytorch_estimator.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -77,12 +77,13 @@ def validation_data_creator(config):
return validation_loader


def train_example():
def train_example(workers_per_node):
estimator = Estimator.from_torch(
model=model_creator,
optimizer=optimizer_creator,
loss=nn.MSELoss,
scheduler_creator=scheduler_creator,
workers_per_node=workers_per_node,
config={
"lr": 1e-2, # used in optimizer_creator
"hidden_size": 1, # used in model_creator
Expand All @@ -96,7 +97,7 @@ def train_example():
print("validation stats: {}".format(val_stats))

# retrieve the model
model = estimator.estimator.get_model()
model = estimator.get_model()
print("trained weight: % .2f, bias: % .2f" % (
model.weight.item(), model.bias.item()))

Expand Down Expand Up @@ -128,6 +129,10 @@ def train_example():
parser.add_argument("--object_store_memory", type=str, default="4g",
help="The memory to store data on local."
"You can change it depending on your own cluster setting.")
parser.add_argument("--workers_per_node", type=int, default=1,
help="The number of workers to run on each node")
parser.add_argument("--local_cores", type=int, default=4,
help="The number of cores while running on local mode")

args = parser.parse_args()
if args.hadoop_conf:
Expand All @@ -145,9 +150,9 @@ def train_example():
object_store_memory=args.object_store_memory)
ray_ctx.init()
else:
sc = init_spark_on_local()
sc = init_spark_on_local(cores=args.local_cores)
ray_ctx = RayContext(
sc=sc,
object_store_memory=args.object_store_memory)
ray_ctx.init()
train_example()
train_example(workers_per_node=args.workers_per_node)

0 comments on commit 6cac410

Please sign in to comment.