[Misc] remove peft as dependency for prompt models #8162

Merged · 1 commit · Sep 10, 2024
8 changes: 0 additions & 8 deletions — vllm/config.py

@@ -1558,14 +1558,6 @@ class PromptAdapterConfig:
     prompt_adapter_dtype: Optional[torch.dtype] = None
 
     def __post_init__(self):
-        library_name = 'peft'
-        try:
-            __import__(library_name)
-        except ImportError as e:
-            raise ImportError(
-                f"'{library_name}' is not installed for prompt adapter support."
-                f"Please install it using 'pip install {library_name}'."
-            ) from e
 
         if self.max_prompt_adapters < 1:
             raise ValueError(f"max_prompt_adapters "
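
For context, the deleted block was an import guard: `PromptAdapterConfig.__post_init__` raised an `ImportError` unless `peft` was installed, even though the config itself never used the library. With this change the config can be constructed in an environment without `peft`. A minimal sketch (field values are illustrative; both field names appear in the diffs of this PR):

from vllm.config import PromptAdapterConfig

# With this PR applied, no `peft` import is attempted in __post_init__.
config = PromptAdapterConfig(
    max_prompt_adapters=4,        # validated to be >= 1 in __post_init__
    max_prompt_adapter_token=16,  # illustrative cap on virtual tokens
)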
2 changes: 1 addition & 1 deletion — vllm/prompt_adapter/models.py

@@ -14,6 +14,7 @@
 from vllm.prompt_adapter.layers import (
     VocabParallelEmbeddingWithPromptAdapter)  # yapf: disable
 from vllm.prompt_adapter.layers import PromptAdapterMapping
+from vllm.prompt_adapter.utils import load_peft_weights
 
 logger = logging.getLogger(__name__)
 

@@ -90,7 +91,6 @@ def from_local_checkpoint(
         config: PromptAdapterConfig,
         device: str = "cuda",
     ) -> "PromptAdapterModel":
-        from peft.utils import load_peft_weights
 
         if num_virtual_tokens > config.max_prompt_adapter_token:
             raise ValueError(
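
Together, the two hunks replace a deferred, function-local import of `peft` inside `from_local_checkpoint` with a module-level import of the helper now vendored into vLLM. A sketch of the before/after pattern (the wrapper names here are hypothetical; the real method does more than load weights):

# Before: `peft` had to be installed for prompt adapter checkpoints to load.
def _load_adapter_weights_old(adapter_model_path: str) -> dict:
    from peft.utils import load_peft_weights  # hard runtime dependency
    return load_peft_weights(adapter_model_path)

# After: the vendored copy ships with vLLM itself.
from vllm.prompt_adapter.utils import load_peft_weights

def _load_adapter_weights_new(adapter_model_path: str) -> dict:
    return load_peft_weights(adapter_model_path)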
93 changes: 93 additions & 0 deletions — vllm/prompt_adapter/utils.py (new file)

@@ -0,0 +1,93 @@
# code borrowed from: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/utils/save_and_load.py#L420

import os
from typing import Optional

import torch
from huggingface_hub import file_exists, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from safetensors.torch import load_file as safe_load_file

WEIGHTS_NAME = "adapter_model.bin"
SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"


# Get current device name based on available devices
def infer_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def load_peft_weights(model_id: str,
                      device: Optional[str] = None,
                      **hf_hub_download_kwargs) -> dict:
    r"""
    A helper method to load the PEFT weights from the HuggingFace Hub or locally

    Args:
        model_id (`str`):
            The local path to the adapter weights or the name of the adapter to
            load from the HuggingFace Hub.
        device (`str`):
            The device to load the weights onto.
        hf_hub_download_kwargs (`dict`):
            Additional arguments to pass to the `hf_hub_download` method when
            loading from the HuggingFace Hub.
    """
    path = (os.path.join(model_id, hf_hub_download_kwargs["subfolder"])
            if hf_hub_download_kwargs.get("subfolder", None) is not None else
            model_id)

    if device is None:
        device = infer_device()

    if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)):
        filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME)
        use_safetensors = True
    elif os.path.exists(os.path.join(path, WEIGHTS_NAME)):
        filename = os.path.join(path, WEIGHTS_NAME)
        use_safetensors = False
    else:
        token = hf_hub_download_kwargs.get("token", None)
        if token is None:
            token = hf_hub_download_kwargs.get("use_auth_token", None)

        hub_filename = (os.path.join(hf_hub_download_kwargs["subfolder"],
                                     SAFETENSORS_WEIGHTS_NAME)
                        if hf_hub_download_kwargs.get("subfolder", None)
                        is not None else SAFETENSORS_WEIGHTS_NAME)
        has_remote_safetensors_file = file_exists(
            repo_id=model_id,
            filename=hub_filename,
            revision=hf_hub_download_kwargs.get("revision", None),
            repo_type=hf_hub_download_kwargs.get("repo_type", None),
            token=token,
        )
        use_safetensors = has_remote_safetensors_file

        if has_remote_safetensors_file:
            # Priority 1: load safetensors weights
            filename = hf_hub_download(
                model_id,
                SAFETENSORS_WEIGHTS_NAME,
                **hf_hub_download_kwargs,
            )
        else:
            try:
                filename = hf_hub_download(model_id, WEIGHTS_NAME,
                                           **hf_hub_download_kwargs)
            except EntryNotFoundError:
                raise ValueError(  # noqa: B904
                    f"Can't find weights for {model_id} in {model_id} or \
in the Hugging Face Hub. "
                    f"Please check that the file {WEIGHTS_NAME} or \
{SAFETENSORS_WEIGHTS_NAME} is present at {model_id}.")

    if use_safetensors:
        adapters_weights = safe_load_file(filename, device=device)
    else:
        adapters_weights = torch.load(filename,
                                      map_location=torch.device(device))

    return adapters_weights
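
For reference, a minimal usage sketch of the vendored helper (the local path and Hub repo id below are made up for illustration):

from vllm.prompt_adapter.utils import load_peft_weights

# From a local directory containing adapter_model.safetensors or
# adapter_model.bin:
weights = load_peft_weights("/path/to/my-prompt-adapter", device="cpu")

# Or from the Hugging Face Hub, forwarding hf_hub_download kwargs such
# as `revision`:
weights = load_peft_weights("my-org/my-prompt-adapter", revision="main")

for name, tensor in weights.items():
    print(name, tuple(tensor.shape))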