Can't save models. #1541

Closed
theekshanadis opened this issue Aug 18, 2020 · 2 comments

Comments

@theekshanadis

🐛 Bug

Can't save models using the torch.save() function.

To Reproduce

```python
import torch
import torch_geometric as tgm  # `tgm` is used below as the alias for torch_geometric


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = tgm.nn.ChebConv(11, 4, K=2, node_dim=1)

    def forward(self, x, adj):
        x = torch.nn.functional.relu(self.conv1(x, adj))
        x = torch.nn.functional.dropout(x, training=self.training)
        return torch.nn.functional.log_softmax(x.flatten(start_dim=1), dim=1)


x = torch.rand((7, 23, 11))  # random input of shape [batch_size, num_nodes, in_features]
adj = torch.ones(size=(2, 10), dtype=torch.int64)  # edge_index in COO format, shape [2, num_edges]

model = Net()
y = model(x, adj)  # the forward pass works

torch.save(model, 'temp')  # raises _pickle.PicklingError on Python 3.6
```
  1. This is just a minimal model with a single ChebConv layer. The forward pass works fine.
  2. However, when I try to save the model with torch.save(), it raises the error below.
  3. When I replace the ChebConv layer with a native torch layer (torch.nn.Conv2d), everything works fine.

```
Traceback (most recent call last):
  File "/home/diz/PycharmProjects/BloodP/Graph/py_geo.py", line 89, in <module>
    torch.save(model,'temp')
  File "/home/diz/.virtualenvs/env2/lib/python3.6/site-packages/torch/serialization.py", line 370, in save
    _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
  File "/home/diz/.virtualenvs/env2/lib/python3.6/site-packages/torch/serialization.py", line 443, in _legacy_save
    pickler.dump(obj)
_pickle.PicklingError: Can't pickle typing.Union[torch.Tensor, NoneType]: it's not the same object as typing.Union
```

Environment

  • OS: Ubuntu 16.04 LTS
  • Python version: 3.6.1 (also tested with 3.6.4)
  • PyTorch version: '1.5.1+cpu' (also tested with 1.6.0+cpu)
  • CUDA/cuDNN version: -
  • GCC version: 5.4.0
  • Any other relevant information: -

Additional context

@rusty1s
Member

rusty1s commented Aug 19, 2020

Hi and thanks for reporting! This is an interesting issue, especially because I'm not able to reproduce it (PyTorch 1.6.0, Python 3.7). This seems related to pickling type information, and there are some issues regarding this, see, e.g., here. I expect it to be a Python 3.6 issue, so maybe you can try to upgrade Python to see if this issue disappears?
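The pickling failure can be illustrated without PyTorch. When an object's `__reduce__` returns a string, pickle stores the object by reference as the module-level global of that name, and raises `PicklingError` if that global turns out to be a different object. The wording of the traceback ("it's not the same object as typing.Union") suggests that on Python 3.6 the subscripted `Union[torch.Tensor, NoneType]` reduced to the bare name `Union` in this way; that is an inference from the error message, not something confirmed in this thread. The sketch below is a minimal stdlib-only illustration that uses a hypothetical `fake_typing` module and a hypothetical `SpecialForm` class as stand-ins for the real `typing` internals.

```python
import pickle
import sys
import types

# Hypothetical stand-in for the `typing` module, registered so pickle can import it by name.
fake_typing = types.ModuleType("fake_typing")
sys.modules["fake_typing"] = fake_typing


class SpecialForm:
    """Toy imitation of a typing special form such as Union (not the real implementation)."""

    def __init__(self, args=()):
        self.args = args

    def __getitem__(self, args):
        # Subscripting builds a brand-new object, so Union[...] is not Union itself.
        return SpecialForm(args)

    def __reduce__(self):
        # Returning a string tells pickle to store this object by reference,
        # as the global `<__module__>.Union`, instead of serializing its contents.
        return "Union"


# Make instances report fake_typing as their module, then publish the bare form there.
SpecialForm.__module__ = "fake_typing"
fake_typing.Union = SpecialForm()

# The bare form round-trips: the global fake_typing.Union *is* this object.
restored = pickle.loads(pickle.dumps(fake_typing.Union))

# A subscripted form also reduces to "Union", but that global is a different object,
# so pickle refuses to save it and raises PicklingError.
error_message = None
try:
    pickle.dumps(fake_typing.Union[(int, type(None))])
except pickle.PicklingError as err:
    error_message = str(err)

print(error_message)
```

The printed message has the same "it's not the same object as ..." shape as the one in the traceback. Newer Python versions give subscripted `typing` objects their own pickling support, which would be consistent with upgrading Python making the error disappear.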

@theekshanadis
Author

Hi, thank you for the response. I upgraded Python to 3.7.1, and now everything works fine.
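If upgrading Python is not an option, a common alternative is to save only the model's `state_dict()` (its parameter and buffer tensors) instead of pickling the whole module object. This is the approach the PyTorch documentation recommends for saving models. Because the state dict holds tensors rather than the layer objects themselves, it should avoid pickling whatever `typing` objects a layer keeps as attributes, although this has not been verified against the exact ChebConv setup above. The sketch assumes a hypothetical `TinyNet` with a `torch.nn.Linear` layer in place of ChebConv, so that it runs without torch_geometric, and uses an in-memory buffer in place of the `'temp'` file.

```python
import io

import torch


class TinyNet(torch.nn.Module):
    # Hypothetical stand-in for the issue's Net; Linear replaces ChebConv
    # so this sketch runs without torch_geometric installed.
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(11, 4)

    def forward(self, x):
        return self.lin(x)


model = TinyNet()

# Save only the parameters/buffers (an OrderedDict of tensors), not the module object.
buffer = io.BytesIO()  # in-memory file; a path such as 'temp.pt' works the same way
torch.save(model.state_dict(), buffer)

# To restore, rebuild the architecture in code, then load the saved tensors into it.
buffer.seek(0)
restored = TinyNet()
restored.load_state_dict(torch.load(buffer))
restored.eval()
```

The trade-off is that the model class must be importable and instantiated in code before calling `load_state_dict()`, because the saved file no longer contains the module object itself.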
