import torch
import torch_geometric as tgm  # assumption: "tgm" in the original snippet is torch_geometric


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = tgm.nn.ChebConv(11, 4, K=2, node_dim=1)

    def forward(self, x, adj):
        x = torch.nn.functional.relu(self.conv1(x, adj))
        x = torch.nn.functional.dropout(x, training=self.training)
        return torch.nn.functional.log_softmax(x.flatten(start_dim=1), dim=1)


x = torch.rand((7, 23, 11))  # some random dataset [batch_size, nodes, in_features]
adj = torch.ones(size=(2, 10), dtype=torch.int64)  # edge indices in COO format [2, num_edges]
model = Net()
y = model(x, adj)  # the forward pass works
torch.save(model, 'temp')  # this call raises the PicklingError below
This is just a minimal model with a single ChebConv layer. The model itself works fine, but saving it with torch.save raises the error below.
When I replace the ChebConv layer with a native PyTorch Conv2d layer, everything works fine.
Traceback (most recent call last):
File "/home/diz/PycharmProjects/BloodP/Graph/py_geo.py", line 89, in
torch.save(model,'temp')
File "/home/diz/.virtualenvs/env2/lib/python3.6/site-packages/torch/serialization.py", line 370, in save
_legacy_save(obj, opened_file, pickle_module, pickle_protocol)
File "/home/diz/.virtualenvs/env2/lib/python3.6/site-packages/torch/serialization.py", line 443, in _legacy_save
pickler.dump(obj)
_pickle.PicklingError: Can't pickle typing.Union[torch.Tensor, NoneType]: it's not the same object as typing.Union
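A common way to avoid pickling the module object altogether is to save only its state_dict() (an ordered dictionary mapping parameter and buffer names to tensors) and rebuild the model from its class when loading. The sketch below is not taken from this thread: it uses a hypothetical TinyNet with a native torch.nn.Linear layer so it runs without PyTorch Geometric, and the helper names save_weights and load_weights are illustrative only. The assumption is that the same pattern works for Net, since the typing object in the traceback is presumably held by the layer as Python metadata rather than as a tensor and would therefore not appear in its state dict; this has not been verified against the PyG version in this report.

```python
import io

import torch


class TinyNet(torch.nn.Module):
    """Hypothetical stand-in for Net; uses a native layer so PyG is not required."""

    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(11, 4)

    def forward(self, x):
        return torch.relu(self.lin(x))


def save_weights(model: torch.nn.Module, f) -> None:
    # Only the state dict (parameter/buffer tensors keyed by name) is pickled,
    # not the module object, so other Python attributes are never serialized.
    torch.save(model.state_dict(), f)


def load_weights(f) -> torch.nn.Module:
    # Rebuild the architecture from code, then copy the saved tensors into it.
    model = TinyNet()
    model.load_state_dict(torch.load(f))
    model.eval()
    return model


buffer = io.BytesIO()  # in-memory file; a path such as 'temp' works the same way
original = TinyNet()
save_weights(original, buffer)
buffer.seek(0)
restored = load_weights(buffer)

x = torch.rand(7, 23, 11)
same = torch.equal(original.eval()(x), restored(x))
print("outputs match:", same)
```

The trade-off is that loading requires the model class to be importable and constructed with the same arguments, whereas torch.save(model, ...) stores a pickled reference to the class together with the weights.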
Environment
OS: Ubuntu 16.04 LTS
Python version: 3.6.1 (also tested in 3.6.4)
PyTorch version: '1.5.1+cpu' (also tested in 1.6.0+cpu)
CUDA/cuDNN version: -
GCC version: 5.4.0
Any other relevant information: -
Hi, and thanks for reporting! This is an interesting issue, especially because I'm not able to reproduce it (PyTorch 1.6.0, Python 3.7). It seems related to pickling type information, and there are some existing issues about this, see, e.g., here. I expect it to be a Python 3.6 issue, so could you try upgrading Python to see whether the error disappears?
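The Python 3.6 hypothesis can be checked without PyTorch: typing.Union[torch.Tensor, NoneType] from the traceback is what typing.Optional[torch.Tensor] expands to, so the sketch below simply tries to round-trip an equivalent typing object through pickle. The helper try_pickle is hypothetical, and int stands in for torch.Tensor to keep the script dependency-free. On Python 3.6 this is expected to fail with the same "it's not the same object as typing.Union" PicklingError, while on Python 3.7 and later, where typing generics were reimplemented and became picklable, the round trip should succeed.

```python
import pickle
import sys
from typing import Optional


def try_pickle(obj):
    """Return (True, restored_obj) if obj survives a pickle round trip, else (False, message)."""
    try:
        return True, pickle.loads(pickle.dumps(obj))
    except pickle.PicklingError as exc:
        return False, str(exc)


# int stands in for torch.Tensor; Optional[X] is shorthand for Union[X, None].
ok, result = try_pickle(Optional[int])
print(sys.version_info[:2], ok, result)
```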
🐛 Bug
Can't save models using torch.save() function.