Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix Mistral v0.3 Weight Loading #5005

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/models/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

MODELS = [
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.3",
]


Expand Down
17 changes: 15 additions & 2 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from vllm.model_executor.model_loader.utils import (get_model_architecture,
set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf, filter_files_not_needed_for_inference,
download_safetensors_index_file_from_hf, download_weights_from_hf,
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
Expand Down Expand Up @@ -188,7 +189,19 @@ def _prepare_weights(self, model_name_or_path: str,
use_safetensors = True
break

if not use_safetensors:
if use_safetensors:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if not is_local:
download_safetensors_index_file_from_hf(
model_name_or_path, self.load_config.download_dir,
revision)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder)
else:
hf_weights_files = filter_files_not_needed_for_inference(
hf_weights_files)

Expand Down
64 changes: 63 additions & 1 deletion vllm/model_executor/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
import huggingface_hub.constants
import numpy as np
import torch
from huggingface_hub import HfFileSystem, snapshot_download
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import LoadConfig, ModelConfig
from vllm.logger import init_logger
Expand Down Expand Up @@ -211,6 +212,67 @@ def download_weights_from_hf(
return hf_folder


def download_safetensors_index_file_from_hf(
model_name_or_path: str,
cache_dir: Optional[str],
revision: Optional[str] = None,
) -> None:
"""Download hf safetensors index file from Hugging Face Hub.

Args:
model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model
weights. If None, will use HF defaults.
revision (Optional[str]): The revision of the model.
"""
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
try:
# Download the safetensors index file.
hf_hub_download(
repo_id=model_name_or_path,
filename=SAFE_WEIGHTS_INDEX_NAME,
cache_dir=cache_dir,
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
)
# If file not found on remote or locally, we should not fail since
# only some models will have SAFE_WEIGHTS_INDEX_NAME.
except huggingface_hub.utils.EntryNotFoundError:
logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
except huggingface_hub.utils.LocalEntryNotFoundError:
logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME)


# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the SAFE_WEIGHTS_INDEX_NAME to
# look up which safetensors files should be used.
def filter_duplicate_safetensors_files(hf_weights_files: List[str],
hf_folder: str) -> List[str]:
# model.safetensors.index.json is a mapping from keys in the
# torch state_dict to safetensors file holding that weight.
index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
if not os.path.isfile(index_file_name):
return hf_weights_files

# Iterate through the weight_map (weight_name: safetensors files)
# to identify weights that we should use.
with open(index_file_name) as index_file:
weight_map = json.load(index_file)["weight_map"]
weight_files_in_index = set()
for weight_name in weight_map:
weight_files_in_index.add(
os.path.join(hf_folder, weight_map[weight_name]))
# Filter out any fields that are not found in the index file.
hf_weights_files = [
f for f in hf_weights_files if f in weight_files_in_index
]
return hf_weights_files


def filter_files_not_needed_for_inference(
hf_weights_files: List[str]) -> List[str]:
"""
Expand Down
Loading