Skip to content

Commit

Permalink
Orca PyTorch Estimator input torch loss instance (#3003)
Browse files Browse the repository at this point in the history
* loss instance

* update
  • Loading branch information
hkvision authored and yangw1234 committed Sep 27, 2021
1 parent e50ffab commit 43cf9ff
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ def __init__(

# todo remove ray_ctx to run on workers
ray_ctx = RayContext.get()
if not (callable(model_creator) and callable(optimizer_creator)):
import types
if not (isinstance(model_creator, types.FunctionType) and
isinstance(optimizer_creator, types.FunctionType)): # Torch model is also callable.
raise ValueError(
"Must provide a callable model_creator and optimizer_creator")
"Must provide a function for both model_creator and optimizer_creator")

self.model_creator = model_creator
self.optimizer_creator = optimizer_creator
Expand Down
10 changes: 6 additions & 4 deletions python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,12 @@ def _create_loss(self):
if not self.loss_creator:
return
logger.debug("Creating loss.")
if inspect.isclass(self.loss_creator) and issubclass(
self.loss_creator, torch.nn.modules.loss._Loss):
self.criterion = self.loss_creator()
else:
if isinstance(self.loss_creator, torch.nn.modules.loss._Loss):
self.criterion = self.loss_creator
else: # Torch loss is also callable.
import types
assert isinstance(self.loss_creator, types.FunctionType), \
"Must provide a torch loss instance or a loss_creator function"
self.criterion = self.loss_creator(self.config)

def _create_schedulers_if_available(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_train(self):
estimator = Estimator.from_torch(
model=model_creator,
optimizer=optimizer_creator,
loss=nn.MSELoss,
loss=nn.MSELoss(),
scheduler_creator=scheduler_creator,
config={
"lr": 1e-2, # used in optimizer_creator
Expand All @@ -56,7 +56,7 @@ def test_save_and_restore(self):
estimator1 = Estimator.from_torch(
model=model_creator,
optimizer=optimizer_creator,
loss=nn.MSELoss,
loss=nn.MSELoss(),
scheduler_creator=scheduler_creator,
config={
"lr": 1e-2, # used in optimizer_creator
Expand All @@ -75,7 +75,7 @@ def test_save_and_restore(self):
estimator2 = Estimator.from_torch(
model=model_creator,
optimizer=optimizer_creator,
loss=nn.MSELoss,
loss=nn.MSELoss(),
scheduler_creator=scheduler_creator,
config={
"lr": 1e-2, # used in optimizer_creator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class TestPyTorchTrainer(TestCase):
def test_linear(self):
estimator = Estimator.from_torch(model=get_model,
optimizer=get_optimizer,
loss=nn.MSELoss,
loss=nn.MSELoss(),
config={"lr": 1e-2, "hidden_size": 1,
"batch_size": 128},
backend="pytorch")
Expand Down

0 comments on commit 43cf9ff

Please sign in to comment.