From 84ce18eb00af4c89a25648faa0b88bb1e2f5428e Mon Sep 17 00:00:00 2001 From: Kai Huang Date: Wed, 28 Oct 2020 15:03:49 +0800 Subject: [PATCH 1/2] loss instance --- .../zoo/orca/learn/ray/pytorch/test_pytorch_estimator.py | 6 +++--- .../zoo/orca/learn/ray/pytorch/test_pytorch_trainer.py | 2 +- .../zoo/examples/orca/learn/horovod/pytorch_estimator.py | 2 +- pyzoo/zoo/orca/learn/pytorch/torch_runner.py | 8 +++++--- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pyzoo/test/zoo/orca/learn/ray/pytorch/test_pytorch_estimator.py b/pyzoo/test/zoo/orca/learn/ray/pytorch/test_pytorch_estimator.py index 63e72eb65e6..10cc93d2590 100644 --- a/pyzoo/test/zoo/orca/learn/ray/pytorch/test_pytorch_estimator.py +++ b/pyzoo/test/zoo/orca/learn/ray/pytorch/test_pytorch_estimator.py @@ -32,7 +32,7 @@ def test_train(self): estimator = Estimator.from_torch( model=model_creator, optimizer=optimizer_creator, - loss=nn.MSELoss, + loss=nn.MSELoss(), scheduler_creator=scheduler_creator, config={ "lr": 1e-2, # used in optimizer_creator @@ -56,7 +56,7 @@ def test_save_and_restore(self): estimator1 = Estimator.from_torch( model=model_creator, optimizer=optimizer_creator, - loss=nn.MSELoss, + loss=nn.MSELoss(), scheduler_creator=scheduler_creator, config={ "lr": 1e-2, # used in optimizer_creator @@ -75,7 +75,7 @@ def test_save_and_restore(self): estimator2 = Estimator.from_torch( model=model_creator, optimizer=optimizer_creator, - loss=nn.MSELoss, + loss=nn.MSELoss(), scheduler_creator=scheduler_creator, config={ "lr": 1e-2, # used in optimizer_creator diff --git a/pyzoo/test/zoo/orca/learn/ray/pytorch/test_pytorch_trainer.py b/pyzoo/test/zoo/orca/learn/ray/pytorch/test_pytorch_trainer.py index 87e01342987..6570fe30e91 100644 --- a/pyzoo/test/zoo/orca/learn/ray/pytorch/test_pytorch_trainer.py +++ b/pyzoo/test/zoo/orca/learn/ray/pytorch/test_pytorch_trainer.py @@ -69,7 +69,7 @@ class TestPyTorchTrainer(TestCase): def test_linear(self): estimator = Estimator.from_torch(model=get_model, optimizer=get_optimizer, - loss=nn.MSELoss, + loss=nn.MSELoss(), config={"lr": 1e-2, "hidden_size": 1, "batch_size": 128}, backend="pytorch") diff --git a/pyzoo/zoo/examples/orca/learn/horovod/pytorch_estimator.py b/pyzoo/zoo/examples/orca/learn/horovod/pytorch_estimator.py index a2801c686f3..cbd4b8161cc 100644 --- a/pyzoo/zoo/examples/orca/learn/horovod/pytorch_estimator.py +++ b/pyzoo/zoo/examples/orca/learn/horovod/pytorch_estimator.py @@ -80,7 +80,7 @@ def train_example(workers_per_node): estimator = Estimator.from_torch( model=model_creator, optimizer=optimizer_creator, - loss=nn.MSELoss, + loss=nn.MSELoss(), scheduler_creator=scheduler_creator, workers_per_node=workers_per_node, config={ diff --git a/pyzoo/zoo/orca/learn/pytorch/torch_runner.py b/pyzoo/zoo/orca/learn/pytorch/torch_runner.py index f4bf90e6dc4..c14a7fa1fd2 100644 --- a/pyzoo/zoo/orca/learn/pytorch/torch_runner.py +++ b/pyzoo/zoo/orca/learn/pytorch/torch_runner.py @@ -95,10 +95,12 @@ def _create_loss(self): if not self.loss_creator: return logger.debug("Creating loss.") - if inspect.isclass(self.loss_creator) and issubclass( - self.loss_creator, torch.nn.modules.loss._Loss): - self.criterion = self.loss_creator() + if isinstance(self.loss_creator, torch.nn.modules.loss._Loss): + self.criterion = self.loss_creator else: + import types + assert isinstance(self.loss_creator, types.FunctionType), \ + "Must provide a torch loss instance or a loss_creator function" self.criterion = self.loss_creator(self.config) def _create_schedulers_if_available(self): From 430c4d4f0f3d138aff9cdad3194c5fec97423a58 Mon Sep 17 00:00:00 2001 From: Kai Huang Date: Wed, 28 Oct 2020 15:09:26 +0800 Subject: [PATCH 2/2] update --- pyzoo/zoo/orca/learn/pytorch/pytorch_ray_estimator.py | 6 ++++-- pyzoo/zoo/orca/learn/pytorch/torch_runner.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pyzoo/zoo/orca/learn/pytorch/pytorch_ray_estimator.py b/pyzoo/zoo/orca/learn/pytorch/pytorch_ray_estimator.py index 4487d11ca1e..da2d4e51cef 100644 --- a/pyzoo/zoo/orca/learn/pytorch/pytorch_ray_estimator.py +++ b/pyzoo/zoo/orca/learn/pytorch/pytorch_ray_estimator.py @@ -69,9 +69,11 @@ def __init__( # todo remove ray_ctx to run on workers ray_ctx = RayContext.get() - if not (callable(model_creator) and callable(optimizer_creator)): + import types + if not (isinstance(model_creator, types.FunctionType) and + isinstance(optimizer_creator, types.FunctionType)): # Torch model is also callable. raise ValueError( - "Must provide a callable model_creator and optimizer_creator") + "Must provide a function for both model_creator and optimizer_creator") self.model_creator = model_creator self.optimizer_creator = optimizer_creator diff --git a/pyzoo/zoo/orca/learn/pytorch/torch_runner.py b/pyzoo/zoo/orca/learn/pytorch/torch_runner.py index c14a7fa1fd2..65635ed1b26 100644 --- a/pyzoo/zoo/orca/learn/pytorch/torch_runner.py +++ b/pyzoo/zoo/orca/learn/pytorch/torch_runner.py @@ -97,7 +97,7 @@ def _create_loss(self): logger.debug("Creating loss.") if isinstance(self.loss_creator, torch.nn.modules.loss._Loss): self.criterion = self.loss_creator - else: + else: # Torch loss is also callable. import types assert isinstance(self.loss_creator, types.FunctionType), \ "Must provide a torch loss instance or a loss_creator function"