diff --git a/dictionary_learning/dictionary.py b/dictionary_learning/dictionary.py index cf90f67..b817917 100644 --- a/dictionary_learning/dictionary.py +++ b/dictionary_learning/dictionary.py @@ -49,7 +49,7 @@ def from_pretrained( device: Device to load the model to **kwargs: Additional arguments passed to loading function """ - model = PyTorchModelHubMixin.from_pretrained(cls, path, **kwargs) + model = super(Dictionary, cls).from_pretrained(path, **kwargs) if device is not None: model.to(device) if dtype is not None: