Describe the bug
Save/load no longer works when using TCNModel with weight_norm=True; PyTorch asks for model.state_dict().
This is probably related to the following update:
Replaced the deprecated torch.nn.utils.weight_norm function with torch.nn.utils.parametrizations.weight_norm. #2593 by Saeed Foroutan.
Hi @forklife, I could not reproduce this issue with the code below:
With darts 0.31.0:
import numpy as np
from darts.models import TCNModel
from darts.datasets import AirPassengersDataset
series = AirPassengersDataset().load().astype(np.float32)
model = TCNModel(12, 11, weight_norm=True)
model.fit(series, epochs=1)
model.save("model.pt")
With darts 0.32.0 (works without issues):
from darts.models import TCNModel
model = TCNModel.load("model.pt")
model.predict(n=11)
Could you provide a minimal reproducible example?
Also, what usually fixes cross-version loading issues is loading only the model's weights in the newer version and then saving the model again.
See our torch model guide that describes how to load the weights only.
In Darts 0.32.0 (example for loading from a manual save):
from darts.models import TCNModel
# create model with identical input parameters that were used in version `0.31.0`
model = TCNModel(12, 11, weight_norm=True)
# load from a manual save
model.load_weights("model.pt")
# manually store it again
model.save("model.pt")
After that, you should be able to load the model normally without any issues using model.load("model.pt").
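For background on why the change mentioned in the report can affect saved weights: according to the PyTorch documentation, the deprecated torch.nn.utils.weight_norm stores the magnitude and direction as weight_g and weight_v, while torch.nn.utils.parametrizations.weight_norm exposes them as parametrizations.weight.original0 and parametrizations.weight.original1. The snippet below is only a sketch of that key renaming on a plain dictionary. The helper rename_weight_norm_keys and the example key names are hypothetical and are not part of Darts or PyTorch; it also assumes weight norm was applied to the default parameter name "weight". It is not needed for the load_weights workaround above.

```python
# Hypothetical helper (not a Darts or PyTorch API): illustrates how state-dict
# keys differ between the legacy and the parametrized weight_norm.
# Assumption: weight norm was applied to the default parameter name "weight".
LEGACY_TO_PARAMETRIZED = {
    "weight_g": "parametrizations.weight.original0",  # magnitude
    "weight_v": "parametrizations.weight.original1",  # direction
}


def rename_weight_norm_keys(state_dict):
    """Return a copy of state_dict with legacy weight_norm key suffixes renamed."""
    renamed = {}
    for key, value in state_dict.items():
        # Split "some.module.weight_g" into ("some.module", ".", "weight_g").
        prefix, _, suffix = key.rpartition(".")
        new_suffix = LEGACY_TO_PARAMETRIZED.get(suffix)
        if new_suffix is None:
            # Not a legacy weight_norm entry: keep the key unchanged.
            renamed[key] = value
        else:
            new_key = f"{prefix}.{new_suffix}" if prefix else new_suffix
            renamed[new_key] = value
    return renamed


if __name__ == "__main__":
    # Made-up key names for illustration only; real TCNModel keys may differ.
    old_keys = {
        "block.conv1.weight_g": "magnitude tensor",
        "block.conv1.weight_v": "direction tensor",
        "block.conv1.bias": "bias tensor",
    }
    for name in rename_weight_norm_keys(old_keys):
        print(name)
```

Comparing the keys of a checkpoint saved in 0.31.0 against model.state_dict().keys() in 0.32.0 in this way can help confirm whether a loading error comes from this renaming.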