Add code-revision config argument for Hugging Face Hub #2892

Merged · 2 commits · Feb 18, 2024

Changes from all commits:
8 changes: 7 additions & 1 deletion vllm/config.py
@@ -44,6 +44,9 @@ class ModelConfig:
         revision: The specific model version to use. It can be a branch name,
             a tag name, or a commit id. If unspecified, will use the default
             version.
+        code_revision: The specific revision to use for the model code on
+            Hugging Face Hub. It can be a branch name, a tag name, or a
+            commit id. If unspecified, will use the default version.
         tokenizer_revision: The specific tokenizer version to use. It can be a
             branch name, a tag name, or a commit id. If unspecified, will use
             the default version.
@@ -70,6 +73,7 @@ def __init__(
         dtype: Union[str, torch.dtype],
         seed: int,
         revision: Optional[str] = None,
+        code_revision: Optional[str] = None,
         tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
@@ -84,6 +88,7 @@ def __init__(
         self.load_format = load_format
         self.seed = seed
         self.revision = revision
+        self.code_revision = code_revision
         self.tokenizer_revision = tokenizer_revision
         self.quantization = quantization
         self.enforce_eager = enforce_eager
@@ -103,7 +108,8 @@ def __init__(
             self.download_dir = model_path
             self.tokenizer = model_path
 
-        self.hf_config = get_config(self.model, trust_remote_code, revision)
+        self.hf_config = get_config(self.model, trust_remote_code, revision,
+                                    code_revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self.max_model_len = _get_and_verify_max_len(self.hf_config,
                                                      max_model_len)
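For context, the two revisions are separate knobs at the `transformers` level as well: `revision` selects which files (weights, config) to download, while `code_revision` selects which revision of the repository supplies the modeling code when `trust_remote_code=True`. A minimal sketch of the distinction, using a hypothetical repo id:

```python
from transformers import AutoConfig

# Hypothetical repo id; any model whose modeling code lives on the Hub
# (i.e. one that requires trust_remote_code=True) behaves the same way.
config = AutoConfig.from_pretrained(
    "some-org/remote-code-model",
    trust_remote_code=True,
    revision="v1.0",       # branch/tag/commit for the weights and config
    code_revision="main",  # branch/tag/commit for the modeling code
)
```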
21 changes: 14 additions & 7 deletions vllm/engine/arg_utils.py
@@ -32,6 +32,7 @@ class EngineArgs:
     max_paddings: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
+    code_revision: Optional[str] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     enforce_eager: bool = False
@@ -75,6 +76,13 @@ def add_cli_args(
             help='the specific model version to use. It can be a branch '
             'name, a tag name, or a commit id. If unspecified, will use '
             'the default version.')
+        parser.add_argument(
+            '--code-revision',
+            type=str,
+            default=None,
+            help='the specific revision to use for the model code on '
+            'Hugging Face Hub. It can be a branch name, a tag name, or a '
+            'commit id. If unspecified, will use the default version.')
         parser.add_argument(
             '--tokenizer-revision',
             type=str,
@@ -279,13 +287,12 @@ def create_engine_configs(
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
                DeviceConfig, Optional[LoRAConfig]]:
         device_config = DeviceConfig(self.device)
-        model_config = ModelConfig(self.model, self.tokenizer,
-                                   self.tokenizer_mode, self.trust_remote_code,
-                                   self.download_dir, self.load_format,
-                                   self.dtype, self.seed, self.revision,
-                                   self.tokenizer_revision, self.max_model_len,
-                                   self.quantization, self.enforce_eager,
-                                   self.max_context_len_to_capture)
+        model_config = ModelConfig(
+            self.model, self.tokenizer, self.tokenizer_mode,
+            self.trust_remote_code, self.download_dir, self.load_format,
+            self.dtype, self.seed, self.revision, self.code_revision,
+            self.tokenizer_revision, self.max_model_len, self.quantization,
+            self.enforce_eager, self.max_context_len_to_capture)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
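Since `EngineArgs` is a dataclass, the new field can also be set programmatically, mirroring what `--code-revision` does on the command line. A sketch with illustrative identifiers (the repo id and revision values below are hypothetical):

```python
from vllm.engine.arg_utils import EngineArgs

# Equivalent to passing --revision / --code-revision as CLI flags.
engine_args = EngineArgs(
    model="some-org/remote-code-model",  # hypothetical repo id
    trust_remote_code=True,
    revision="v1.0",       # pins the weights/config
    code_revision="main",  # pins the Hub-hosted modeling code
)
configs = engine_args.create_engine_configs()
```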
12 changes: 9 additions & 3 deletions vllm/transformers_utils/config.py
@@ -16,10 +16,14 @@
 
 def get_config(model: str,
                trust_remote_code: bool,
-               revision: Optional[str] = None) -> PretrainedConfig:
+               revision: Optional[str] = None,
+               code_revision: Optional[str] = None) -> PretrainedConfig:
     try:
         config = AutoConfig.from_pretrained(
-            model, trust_remote_code=trust_remote_code, revision=revision)
+            model,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            code_revision=code_revision)
     except ValueError as e:
         if (not trust_remote_code and
                 "requires you to execute the configuration file" in str(e)):
@@ -33,5 +37,7 @@ def get_config(model: str,
             raise e
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
-        config = config_class.from_pretrained(model, revision=revision)
+        config = config_class.from_pretrained(model,
+                                              revision=revision,
+                                              code_revision=code_revision)
     return config
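Taken together, `get_config` now threads both revisions through to `AutoConfig.from_pretrained` (and to any config class in `_CONFIG_REGISTRY`). A quick usage sketch, again with hypothetical identifiers:

```python
from vllm.transformers_utils.config import get_config

# The repo id and commit hash are illustrative; branch and tag names work too.
config = get_config(
    "some-org/remote-code-model",
    trust_remote_code=True,
    revision="main",
    code_revision="0123abc",  # pin the modeling code to a specific commit
)
print(config.model_type)
```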