From de4a1cfa0831be203fd3ffeabb5f8843be455eee Mon Sep 17 00:00:00 2001 From: Xin Qiu Date: Thu, 13 Aug 2020 05:52:30 +0800 Subject: [PATCH] fix to pytorch (#2716) --- .../orca/src/bigdl/orca/torch/torch_model.py | 2 +- .../orca/test/bigdl/orca/torch/test_torch.py | 24 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/python/orca/src/bigdl/orca/torch/torch_model.py b/python/orca/src/bigdl/orca/torch/torch_model.py index 020fb4058b2..cbf8e39ba10 100644 --- a/python/orca/src/bigdl/orca/torch/torch_model.py +++ b/python/orca/src/bigdl/orca/torch/torch_model.py @@ -68,7 +68,7 @@ def to_pytorch(self): # set weights m = CloudPickleSerializer.loads(CloudPickleSerializer, self.module_bytes) w = torch.Tensor(new_weight[0]) - torch.nn.utils.vector_to_parameters(w, m.parameters()) + torch.nn.utils.vector_to_parameters(w, trainable_param(m)) # set named buffers new_extra_params = callZooFunc(self.bigdl_type, "getModuleExtraParameters", self.value) diff --git a/python/orca/test/bigdl/orca/torch/test_torch.py b/python/orca/test/bigdl/orca/torch/test_torch.py index 883ec6fb20c..7faa9c819e4 100644 --- a/python/orca/test/bigdl/orca/torch/test_torch.py +++ b/python/orca/test/bigdl/orca/torch/test_torch.py @@ -120,6 +120,30 @@ def forward(self, x): exported_model = az_model.to_pytorch() assert len((list(exported_model.named_buffers()))) != 0 + def test_freezed_model_to_pytorch(self): + class SimpleTorchModel(nn.Module): + def __init__(self): + super(SimpleTorchModel, self).__init__() + self.dense1 = nn.Linear(2, 4) + self.bn1 = torch.nn.BatchNorm1d(4) + self.dense2 = nn.Linear(4, 1) + for p in self.dense1.parameters(): + p.requires_grad_(False) + + def forward(self, x): + x = self.dense1(x) + x = self.bn1(x) + x = torch.sigmoid(self.dense2(x)) + return x + + torch_model = SimpleTorchModel() + az_model = TorchModel.from_pytorch(torch_model) + dummy_input = torch.ones(16, 2) + zoo_result = az_model.forward(dummy_input.numpy()) + + exported_model = az_model.to_pytorch() + assert len((list(exported_model.named_buffers()))) != 0 + def test_train_model_with_bn(self): class SimpleTorchModel(nn.Module): def __init__(self):