
Unable to load LoRA fine-tuned LLM from HF (AssertionError) #3404

Open
oscar-martin opened this issue Mar 14, 2024 · 6 comments

@oscar-martin

oscar-martin commented Mar 14, 2024

Following the docs about Using LoRA Adapters, I am hitting an AssertionError. My code:

from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

lora_path = snapshot_download(repo_id="<my repo id>")

llm = LLM(
        model="mistralai/Mistral-7B-v0.1",
        tokenizer="<my tokenizer>",
        enable_lora=True)

sampling_params = SamplingParams(
    temperature=0,
    max_tokens=256,
    stop=["<|endcontext|>"]
)

prompts = [
    "<|begincontext|><|user|>I'm hungry. Find places to eat please.<|system|>Sure thing. Which city would you like to eat in?<|user|>Let's go with Foster City please.<|system|>Sure. What kind of food are you hungry for?<|user|>Spicy Indian sound really good.<|system|>One moment. I found a great restaurant called Pastries N Chaat in Foster City.<|user|>Give me other suggestions as well<|system|>How about, Tabla Indian Restaurant in Foster City?<|user|>Can you find out if they are average priced?<|system|>sure. The price range would be inexpensive.<|user|>Perfect. That works<|system|>Should I reserve for you?<|beginlastuserutterance|>Yes, go ahead and do that.<|endlastuserutterance|><|endcontext|>"
]

outputs = llm.generate(
    prompts,
    sampling_params,
    lora_request=LoRARequest("lora_adapter", 1, lora_path)
)
print(outputs)

The error:

...
INFO 03-14 11:33:38 model_runner.py:756] Graph capturing finished in 7 secs.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/vllm/lora/worker_manager.py", line 139, in _load_lora
    lora = self._lora_model_cls.from_local_checkpoint(
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/vllm/lora/models.py", line 227, in from_local_checkpoint
    return cls.from_lora_tensors(
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/vllm/lora/models.py", line 148, in from_lora_tensors
    module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/vllm/lora/utils.py", line 33, in parse_fine_tuned_lora_name
    assert parts[-2] == "lora_A" or parts[-2] == "lora_B"
AssertionError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ubuntu/vllm/vllm-lora-check.py", line 22, in <module>
    outputs = llm.generate(
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 182, in generate
    return self._run_engine(use_tqdm)
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 208, in _run_engine
    step_outputs = self.llm_engine.step()
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 838, in step
    all_outputs = self._run_workers(
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 1041, in _run_workers
    driver_worker_output = getattr(self.driver_worker,
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/vllm/worker/worker.py", line 223, in execute_model
    output = self.model_runner.execute_model(seq_group_metadata_list,
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 574, in execute_model
    self.set_active_loras(lora_requests, lora_mapping)
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 660, in set_active_loras
    self.lora_manager.set_active_loras(lora_requests, lora_mapping)
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/vllm/lora/worker_manager.py", line 112, in set_active_loras
    self._apply_loras(lora_requests)
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/vllm/lora/worker_manager.py", line 224, in _apply_loras
    self.add_lora(lora)
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/vllm/lora/worker_manager.py", line 231, in add_lora
    lora = self._load_lora(lora_request)
  File "/home/ubuntu/vllm/.venv/lib/python3.10/site-packages/vllm/lora/worker_manager.py", line 150, in _load_lora
    raise RuntimeError(
RuntimeError: Loading lora /home/ubuntu/.cache/huggingface/hub/models--bla-bla/snapshots/316a6e3610eedf49d5cb04b4670942f425401ee9 failed

By comparing my adapter with the model used in the aforementioned documentation, I realized mine is "exporting" a couple of tensors (found in the adapter_model.safetensors file) that the vLLM code does not expect to be there, namely:

  • base_model.model.lm_head.base_layer.weight, and
  • base_model.model.model.embed_tokens.base_layer.weight.

This code will crash whenever a weight tensor's name does not come from a LoRA layer: judging by the traceback, the assert requires a lora_A/lora_B component in every tensor name.

In the model used for the documentation, all tensors contain 'lora' in their names.
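A quick way to double-check this (a minimal sketch, assuming the adapter_model.safetensors file sits inside the lora_path directory downloaded in the script above):

import safetensors.torch

# lora_path is the directory returned by snapshot_download() above
tensors = safetensors.torch.load_file(f"{lora_path}/adapter_model.safetensors")

for name in tensors:
    mark = "" if "lora" in name else "   <-- not a LoRA tensor, trips the assert"
    print(name + mark)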

I am pretty new to this and followed this fine-tuning guide.

The question is: how can I "fix" this issue? Is the problem related to the fine-tuning guide, perhaps because the LoRAConfig is not correct or because of the way the model is persisted? Or is it instead related to vLLM?

Thanks!

@geknow

geknow commented Mar 18, 2024

I also encountered this problem. The tensor names look like this: [screenshot of the adapter's tensor names]

@sagar-deepscribe

I also encountered the same error. It happens because the peft library also saves the base embedding layers when save() is called (see #2816 and https://github.com/huggingface/peft/blob/8dd45b75d7eabe7ee94ecb6a19d552f2aa5e98c6/src/peft/utils/save_and_load.py#L175).
This apparently is not supported in vLLM. If you are not training with new special tokens and your base embeddings have not been updated, you can just remove the base-layer weights. I used the following code:

import safetensors.torch

lora_path = 'YOUR_ADAPTER_PATH'  # path to the adapter's .safetensors file
tensors = safetensors.torch.load_file(lora_path)

# collect every tensor whose name does not belong to a LoRA layer
nonlora_keys = []
for k in list(tensors.keys()):
    if "lora" not in k:
        nonlora_keys.append(k)

print(nonlora_keys)  # just take a look at what they are

# drop the base-layer weights and save the cleaned adapter
for k in nonlora_keys:
    del tensors[k]

safetensors.torch.save_file(tensors, 'NEW_ADAPTER_PATH')

@oscar-martin
Author

Thanks @sagar-deepscribe! In my case I do need new special tokens, but this is good stuff for me to learn from.

@tsvisab

tsvisab commented Mar 31, 2024

(Quoting @sagar-deepscribe's workaround and code above.)

The adapter path here is usually the file named "adapter_model.safetensors" inside the adapter directory.
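In other words (a minimal sketch; adapter_dir is a hypothetical placeholder for the directory written by peft, while vLLM's LoRARequest is still given the directory itself, as in the original script):

import os
import safetensors.torch

adapter_dir = "path/to/my-lora-adapter"  # hypothetical placeholder directory
adapter_file = os.path.join(adapter_dir, "adapter_model.safetensors")

# safetensors load/save operate on the file itself
tensors = safetensors.torch.load_file(adapter_file)
lora_only = {k: v for k, v in tensors.items() if "lora" in k}
safetensors.torch.save_file(lora_only, adapter_file)  # overwrite in place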

@stas00
Contributor

stas00 commented May 3, 2024

thanks, @sagar-deepscribe - that was helpful.

Here is your code, slightly edited and with copy-and-paste instructions to run it:

cat << EOT > vllm-lora-convert.py
import sys
import safetensors.torch

src, dst = sys.argv[-2:]

tensors = safetensors.torch.load_file(f"{src}/adapter_model.safetensors")

non_lora_keys = [k for k in tensors.keys() if "lora" not in k]

print("splitting non-lora keys into a separate file")
print("lora keys: ", tensors.keys())
print("non-lora keys: ", non_lora_keys)

non_lora_tensors = {k:tensors.pop(k) for k in non_lora_keys}

safetensors.torch.save_file(tensors, f"{dst}/adapter_model.safetensors")
safetensors.torch.save_file(non_lora_tensors, f"{dst}/rest.safetensors")
EOT
dir=unwrapped_model # edit to the dir with lora weights and config files
cp -r $dir $dir-vllm
python vllm-lora-convert.py $dir $dir-vllm
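After that, the converted directory can be passed to vLLM exactly as in the original report (a minimal sketch reusing the placeholder names from above, not a definitive recipe):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# "unwrapped_model-vllm" is the $dir-vllm directory produced by the script above
llm = LLM(model="mistralai/Mistral-7B-v0.1", enable_lora=True)

outputs = llm.generate(
    ["Hello"],
    SamplingParams(temperature=0, max_tokens=32),
    lora_request=LoRARequest("lora_adapter", 1, "unwrapped_model-vllm"),
)
print(outputs)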


This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

github-actions bot added the stale label Oct 29, 2024