This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Loading saved model fails #104

Open
VladimirShitov opened this issue Jun 19, 2024 · 3 comments

Comments

@VladimirShitov
Collaborator

Hi! I am following the new tutorial in the scvi-tools documentation. Everything works fine but I want to save the model and later reuse it to avoid training it each time. I do that by running:
model.save("models/mrvi_no_nuissanse", overwrite=True)

Then I try to load the saved model by running either:
model.load("models/mrvi_no_nuissanse")
or:
model.load("models/mrvi_no_nuissanse", adata)

However, in both cases, the following error is raised:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
07_mrvi_analysis.ipynb Cell 19 line 1
----> 1 model.load("models/mrvi_no_nuissanse", adata)

File /lib/python3.10/site-packages/scvi/model/base/_base_model.py:693, in BaseModelClass.load(cls, dir_path, adata, accelerator, device, prefix, backup_url)
    680 load_adata = adata is None
    681 _, _, device = parse_device_args(
    682     accelerator=accelerator,
    683     devices=device,
    684     return_device="torch",
    685     validate_single_device=True,
    686 )
    688 (
    689     attr_dict,
    690     var_names,
    691     model_state_dict,
    692     new_adata,
--> 693 ) = _load_saved_files(
    694     dir_path,
    695     load_adata,
    696     map_location=device,
    697     prefix=prefix,
    698     backup_url=backup_url,
    699 )
    700 adata = new_adata if new_adata is not None else adata
    702 _validate_var_names(adata, var_names)

File /lib/python3.10/site-packages/scvi/model/base/_save_load.py:71, in _load_saved_files(dir_path, load_adata, prefix, map_location, backup_url)
     69 try:
     70     _download(backup_url, dir_path, model_file_name)
---> 71     model = torch.load(model_path, map_location=map_location)
     72 except FileNotFoundError as exc:
     73     raise ValueError(
     74         f"Failed to load model file at {model_path}. "
     75         "If attempting to load a saved model from <v0.15.0, please use the util function "
     76         "`convert_legacy_save` to convert to an updated format."
     77     ) from exc

File /lib/python3.10/site-packages/torch/serialization.py:1025, in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
   1023             except RuntimeError as e:
   1024                 raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
-> 1025         return _load(opened_zipfile,
   1026                      map_location,
   1027                      pickle_module,
   1028                      overall_storage=overall_storage,
   1029                      **pickle_load_args)
   1030 if mmap:
   1031     f_name = "" if not isinstance(f, str) else f"{f}, "

File /lib/python3.10/site-packages/torch/serialization.py:1442, in _load(zip_file, map_location, pickle_module, pickle_file, overall_storage, **pickle_load_args)
   1439         return super().find_class(mod_name, name)
   1441 # Load the data (which may in turn use `persistent_load` to load tensors)
-> 1442 data_file = io.BytesIO(zip_file.get_record(pickle_file))
   1444 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
   1445 unpickler.persistent_load = persistent_load

RuntimeError: PytorchStreamReader failed locating file data.pkl: file not found
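
For anyone debugging the same message: `torch.load` reads the file as a zip archive and looks for a `data.pkl` record inside it, so the error above suggests the saved file is a zip archive that lacks that record. The following is a minimal diagnostic sketch using only the standard library. The helper name `inspect_checkpoint` is hypothetical, and the file name `model.pt` inside the save directory is an assumption, not something shown in the traceback.

```python
import zipfile
from pathlib import Path


def inspect_checkpoint(path):
    """Return (is_zip, has_data_pkl, record_names) for a saved model file.

    Hypothetical helper: it only lists the archive contents and does not
    unpickle anything.
    """
    path = Path(path)
    if not zipfile.is_zipfile(path):
        # Not a zip archive at all, so it was not written in torch's zip format.
        return False, False, []
    with zipfile.ZipFile(path) as archive:
        names = archive.namelist()
    # torch's zip format stores the pickle as "<prefix>/data.pkl".
    has_data_pkl = any(name.rsplit("/", 1)[-1] == "data.pkl" for name in names)
    return True, has_data_pkl, names


# Example call (the file name "model.pt" is an assumption):
# print(inspect_checkpoint("models/mrvi_no_nuissanse/model.pt"))
```

If the second value is `False`, the file on disk was probably written in a different format than the loader expects, which points at the save step rather than the load call.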

I guess it can be fixed by saving the data along with the model:
model.save("models/mrvi_no_nuissanse", overwrite=True, save_anndata=True)

However, that would require storing an unnecessary copy of a large dataset. Could you please provide instructions on how to save and load the model without it? They would also fit nicely in the tutorial.
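
To make the request concrete, here is a minimal sketch of the save/reload pattern being asked for. It uses a hypothetical stand-in class (`ToyModel`), not scvi-tools itself: only the learned state is written to disk, and `load` is a classmethod (as the signature `BaseModelClass.load(cls, dir_path, adata, ...)` in the traceback indicates) that builds a new object from the saved state plus data supplied by the caller. It illustrates the pattern only and does not claim to fix the error above.

```python
import json
import tempfile
from pathlib import Path


class ToyModel:
    """Hypothetical stand-in for a model class; not part of scvi-tools."""

    def __init__(self, data, weights=None):
        self.data = data
        self.weights = weights
        self.is_trained = weights is not None

    def train(self):
        # Pretend training: the "weights" are just the column means.
        n = len(self.data)
        self.weights = [sum(col) / n for col in zip(*self.data)]
        self.is_trained = True

    def save(self, dir_path, overwrite=False):
        dir_path = Path(dir_path)
        if dir_path.exists() and not overwrite:
            raise FileExistsError(dir_path)
        dir_path.mkdir(parents=True, exist_ok=True)
        # Only the learned state is written; the data stays where it already lives.
        (dir_path / "state.json").write_text(json.dumps({"weights": self.weights}))

    @classmethod
    def load(cls, dir_path, data):
        state = json.loads((Path(dir_path) / "state.json").read_text())
        # A new instance is returned, so the caller must keep the return value.
        return cls(data, weights=state["weights"])


with tempfile.TemporaryDirectory() as tmp:
    data = [[1.0, 2.0], [3.0, 4.0]]
    model = ToyModel(data)
    model.train()
    model.save(Path(tmp) / "toy_model", overwrite=True)
    # Call load on the class and assign the result.
    restored = ToyModel.load(Path(tmp) / "toy_model", data)

print(restored.is_trained, restored.weights)  # True [2.0, 3.0]
```

With scvi-tools the analogous call would presumably be `restored = MRVI.load("models/mrvi_no_nuissanse", adata=adata)`, assuming the model class is imported under the name `MRVI`. Calling `load` on an existing instance without assigning the result leaves that instance unchanged.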

@VladimirShitov
Collaborator Author

VladimirShitov commented Jun 19, 2024

Setting save_anndata=True did help :) But it is rather inefficient when the adata is already saved elsewhere. Hopefully there is another way to do it.

@VladimirShitov
Collaborator Author

Ok, but then something weird happens... The model appears not to be trained even though "Training status: Trained" is printed. Also, a training epoch starts after loading, and the loss is quite high (as if the model were untrained).

image

@justjhong
Collaborator

Hi @VladimirShitov, we found a similar issue in the up-to-date code in scvi-tools (scverse/scvi-tools#2813). Could you try using the version there and see if it addresses this problem?
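
When comparing behaviour across versions, it helps to report exactly which packages are installed. A small sketch using the standard library's `importlib.metadata` follows. The helper name `installed_version` is hypothetical, and the list of packages is only a guess at what is relevant here.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Distribution names as published on PyPI; adjust to your environment.
for name in ("scvi-tools", "torch", "jax"):
    print(f"{name}: {installed_version(name)}")
```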
