
[BUG] - TCNModel - save/load does not work anymore - darts [0.32.0] #2630

Open
forklife opened this issue Dec 22, 2024 · 1 comment
Labels
bug Something isn't working triage Issue waiting for triaging

Comments

@forklife

Describe the bug
save/load no longer works when using `TCNModel` with `weight_norm=True`; PyTorch asks for `model.state_dict()`.

This is probably related to the following update:
Replaced the deprecated torch.nn.utils.weight_norm function with torch.nn.utils.parametrizations.weight_norm. #2593 by Saeed Foroutan.
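The two functions store the weight-normalized parameters under different names. The deprecated `torch.nn.utils.weight_norm` keeps them as `weight_g` (magnitude) and `weight_v` (direction), while `torch.nn.utils.parametrizations.weight_norm` keeps them as `parametrizations.weight.original0` (g) and `parametrizations.weight.original1` (v). A checkpoint written before #2593 can therefore contain parameter names that a model built after it does not expect. The sketch below only illustrates that mapping on a plain dict; `rename_weight_norm_keys` is a hypothetical helper, not a Darts or PyTorch API, and the key names are assumed for illustration. Recent PyTorch versions may already apply a similar compatibility rename when loading a `state_dict`.

```python
# Hypothetical helper (not part of Darts or PyTorch): rename state_dict keys
# written by the deprecated torch.nn.utils.weight_norm to the names used by
# torch.nn.utils.parametrizations.weight_norm. Works on a plain dict, so it
# runs without torch installed. Assumes the default parameter name "weight".

OLD_TO_NEW_SUFFIX = {
    "weight_g": "parametrizations.weight.original0",  # magnitude g
    "weight_v": "parametrizations.weight.original1",  # direction v
}


def rename_weight_norm_keys(state_dict: dict) -> dict:
    """Return a new dict with old-style weight_norm keys renamed."""
    renamed = {}
    for key, value in state_dict.items():
        prefix, _, last = key.rpartition(".")
        if last in OLD_TO_NEW_SUFFIX:
            new_last = OLD_TO_NEW_SUFFIX[last]
            key = f"{prefix}.{new_last}" if prefix else new_last
        renamed[key] = value
    return renamed


# Hypothetical keys, for illustration only (not read from a real TCNModel checkpoint).
old = {
    "res_blocks.0.conv1.weight_g": "g-tensor",
    "res_blocks.0.conv1.weight_v": "v-tensor",
    "res_blocks.0.conv1.bias": "bias-tensor",
}
new = rename_weight_norm_keys(old)
print(sorted(new))
```

Keys that do not end in `weight_g` or `weight_v` (such as the bias) pass through unchanged.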

@forklife forklife added bug Something isn't working triage Issue waiting for triaging labels Dec 22, 2024
@dennisbader
Collaborator

dennisbader commented Dec 22, 2024

Hi @forklife, I could not reproduce this issue with the code below:

With darts 0.31.0:

```python
import numpy as np

from darts.models import TCNModel
from darts.datasets import AirPassengersDataset

series = AirPassengersDataset().load().astype(np.float32)
# input_chunk_length=12, output_chunk_length=11
model = TCNModel(12, 11, weight_norm=True)
model.fit(series, epochs=1)
model.save("model.pt")
```

With darts 0.32.0 (works without issues):

```python
from darts.models import TCNModel

# load the model that was saved with darts 0.31.0
model = TCNModel.load("model.pt")
model.predict(n=11)
```

Could you provide a minimal reproducible example?

Also, cross-version loading issues can usually be fixed by loading only the model's weights in the newer version and then saving the model again.

See our torch model user guide, which describes how to load only the weights.

In Darts 0.32.0 (example for loading from a manual save):

```python
from darts.models import TCNModel

# create model with identical input parameters that were used in version `0.31.0`
model = TCNModel(12, 11, weight_norm=True)

# load from a manual save
model.load_weights("model.pt")

# manually store it again
model.save("model.pt")
```

After that, you should be able to load the model normally without any issues: `model = TCNModel.load("model.pt")`.
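If loading the weights still fails, a list of the mismatching parameter names is a useful part of a minimal reproducible example. The sketch below assumes you have already obtained two collections of names, for example the keys of the saved `state_dict` and the keys of the `state_dict` of a freshly created model. It mirrors the `missing_keys` / `unexpected_keys` that PyTorch's `load_state_dict(..., strict=False)` reports; `diff_state_dict_keys` is a hypothetical helper, not a Darts API, and the key names are illustrative assumptions.

```python
# Hypothetical diagnostic (not a Darts API): compare the parameter names stored
# in a checkpoint with the names a freshly created model expects.


def diff_state_dict_keys(checkpoint_keys, model_keys):
    """Return (missing, unexpected) as sorted lists.

    missing:    names the model expects but the checkpoint lacks
    unexpected: names in the checkpoint that the model does not know
    """
    checkpoint_keys, model_keys = set(checkpoint_keys), set(model_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Illustrative key names only (assumed, not read from a real checkpoint).
saved = ["conv1.weight_g", "conv1.weight_v", "conv1.bias"]
expected = [
    "conv1.parametrizations.weight.original0",
    "conv1.parametrizations.weight.original1",
    "conv1.bias",
]
missing, unexpected = diff_state_dict_keys(saved, expected)
print("missing:", missing)
print("unexpected:", unexpected)
```

Empty lists on both sides mean the parameter names match and the problem lies elsewhere (for example, in tensor shapes or constructor arguments).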
