diff --git a/python/orca/src/bigdl/orca/learn/pytorch/estimator.py b/python/orca/src/bigdl/orca/learn/pytorch/estimator.py index 98e2876c301..0e36e11c7d3 100644 --- a/python/orca/src/bigdl/orca/learn/pytorch/estimator.py +++ b/python/orca/src/bigdl/orca/learn/pytorch/estimator.py @@ -55,6 +55,7 @@ def from_torch(*, config=None, scheduler_step_freq="batch", use_tqdm=False, + workers_per_node=1, backend="horovod"): if backend == "horovod": return PyTorchHorovodEstimatorWrapper(model_creator=model, @@ -65,7 +66,8 @@ def from_torch(*, initialization_hook=initialization_hook, config=config, scheduler_step_freq=scheduler_step_freq, - use_tqdm=use_tqdm) + use_tqdm=use_tqdm, + workers_per_node=workers_per_node) elif backend == "bigdl": return PytorchSparkEstimatorWrapper(model=model, loss=loss, @@ -87,7 +89,8 @@ def __init__(self, initialization_hook=None, config=None, scheduler_step_freq="batch", - use_tqdm=False): + use_tqdm=False, + workers_per_node=1): from zoo.orca.learn.pytorch.pytorch_horovod_estimator import PyTorchHorovodEstimator self.estimator = PyTorchHorovodEstimator(model_creator=model_creator, optimizer_creator=optimizer_creator, @@ -97,7 +100,8 @@ def __init__(self, initialization_hook=initialization_hook, config=config, scheduler_step_freq=scheduler_step_freq, - use_tqdm=use_tqdm) + use_tqdm=use_tqdm, + workers_per_node=workers_per_node) def fit(self, data, epochs=1, profile=False, reduce_results=True, info=None): """