[Bugfix] Command-R Max Model Length (vllm-project#3727)
ywang96 authored Mar 29, 2024
1 parent a333ba8 commit 4a86f7a
Showing 1 changed file with 22 additions and 9 deletions: vllm/config.py
@@ -765,15 +765,20 @@ def _get_and_verify_max_len(
         "max_seq_len",
         # ChatGLM2
         "seq_length",
+        # Command-R
+        "model_max_length",
         # Others
         "max_sequence_length",
         "max_seq_length",
         "seq_len",
     ]
+    max_len_key = None
     for key in possible_keys:
-        max_len_key = getattr(hf_config, key, None)
-        if max_len_key is not None:
-            derived_max_model_len = min(derived_max_model_len, max_len_key)
+        max_len = getattr(hf_config, key, None)
+        if max_len is not None:
+            max_len_key = key if max_len < derived_max_model_len \
+                else max_len_key
+            derived_max_model_len = min(derived_max_model_len, max_len)
     if derived_max_model_len == float("inf"):
         if max_model_len is not None:
             # If max_model_len is specified, we use it.
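The new loop derives the context length as the minimum over all recognized config keys, while remembering which key supplied the binding limit so the error message in the second hunk can name it. Below is a minimal sketch of that selection logic, reproducing the new loop with a toy SimpleNamespace standing in for the HF config; the 8192/131072 values are illustrative, in the spirit of a Command-R-style config.json where model_max_length is far larger than max_position_embeddings:

```python
from types import SimpleNamespace

# Toy stand-in for hf_config; values are illustrative, roughly mirroring
# a Command-R-style config.json.
hf_config = SimpleNamespace(max_position_embeddings=8192,
                            model_max_length=131072)

possible_keys = ["max_position_embeddings", "model_max_length"]

derived_max_model_len = float("inf")
max_len_key = None
for key in possible_keys:
    max_len = getattr(hf_config, key, None)
    if max_len is not None:
        # Track which key currently supplies the binding (smallest) limit.
        max_len_key = key if max_len < derived_max_model_len else max_len_key
        derived_max_model_len = min(derived_max_model_len, max_len)

print(max_len_key, derived_max_model_len)
# -> max_position_embeddings 8192
```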
@@ -799,10 +804,18 @@ def _get_and_verify_max_len(
     if max_model_len is None:
         max_model_len = derived_max_model_len
     elif max_model_len > derived_max_model_len:
-        raise ValueError(
-            f"User-specified max_model_len ({max_model_len}) is greater than "
-            f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
-            " in model's config.json). This may lead to incorrect model "
-            "outputs or CUDA errors. Make sure the value is correct and "
-            "within the model context size.")
+        # Some models might have a separate key for specifying model_max_length
+        # that will be bigger than derived_max_model_len. We compare user input
+        # with model_max_length and allow this override when it's smaller.
+        model_max_length = getattr(hf_config, "model_max_length", None)
+        if model_max_length is not None and max_model_len <= model_max_length:
+            pass
+        else:
+            raise ValueError(
+                f"User-specified max_model_len ({max_model_len}) is greater "
+                "than the derived max_model_len "
+                f"({max_len_key}={derived_max_model_len} or model_max_length="
+                f"{model_max_length} in model's config.json). This may lead "
+                "to incorrect model outputs or CUDA errors. Make sure the "
+                "value is correct and within the model context size.")
     return int(max_model_len)
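With this relaxed check, a user-specified max_model_len above the derived limit no longer raises, as long as it stays within the config's model_max_length. A minimal, self-contained sketch of that validation follows; check_user_max_len is a hypothetical helper written for illustration, not vLLM's actual API:

```python
from types import SimpleNamespace

def check_user_max_len(max_model_len, derived_max_model_len, hf_config):
    """Hypothetical helper sketching the relaxed validation above."""
    if max_model_len is None:
        return int(derived_max_model_len)
    if max_model_len > derived_max_model_len:
        model_max_length = getattr(hf_config, "model_max_length", None)
        if model_max_length is None or max_model_len > model_max_length:
            raise ValueError(
                f"max_model_len ({max_model_len}) exceeds the derived limit "
                f"({derived_max_model_len}) and is not covered by "
                f"model_max_length ({model_max_length}).")
    return int(max_model_len)

# Command-R-style toy config: derived limit 8192, model_max_length 131072.
hf_config = SimpleNamespace(model_max_length=131072)
print(check_user_max_len(100000, 8192, hf_config))  # allowed: 100000
print(check_user_max_len(None, 8192, hf_config))    # default: 8192
# check_user_max_len(200000, 8192, hf_config) would raise ValueError.
```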
