diff --git a/vllm/config.py b/vllm/config.py index 265cfa56c04fc..62f1d70079648 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -765,15 +765,20 @@ def _get_and_verify_max_len( "max_seq_len", # ChatGLM2 "seq_length", + # Command-R + "model_max_length", # Others "max_sequence_length", "max_seq_length", "seq_len", ] + max_len_key = None for key in possible_keys: - max_len_key = getattr(hf_config, key, None) - if max_len_key is not None: - derived_max_model_len = min(derived_max_model_len, max_len_key) + max_len = getattr(hf_config, key, None) + if max_len is not None: + max_len_key = key if max_len < derived_max_model_len \ + else max_len_key + derived_max_model_len = min(derived_max_model_len, max_len) if derived_max_model_len == float("inf"): if max_model_len is not None: # If max_model_len is specified, we use it. @@ -799,10 +804,18 @@ def _get_and_verify_max_len( if max_model_len is None: max_model_len = derived_max_model_len elif max_model_len > derived_max_model_len: - raise ValueError( - f"User-specified max_model_len ({max_model_len}) is greater than " - f"the derived max_model_len ({max_len_key}={derived_max_model_len}" - " in model's config.json). This may lead to incorrect model " - "outputs or CUDA errors. Make sure the value is correct and " - "within the model context size.") + # Some models might have a separate key for specifying model_max_length + # that will be bigger than derived_max_model_len. We compare user input + # with model_max_length and allow this override when it's smaller. + model_max_length = getattr(hf_config, "model_max_length", None) + if model_max_length is not None and max_model_len <= model_max_length: + pass + else: + raise ValueError( + f"User-specified max_model_len ({max_model_len}) is greater " + "than the derived max_model_len " + f"({max_len_key}={derived_max_model_len} or model_max_length=" + f"{model_max_length} in model's config.json). This may lead " + "to incorrect model outputs or CUDA errors. Make sure the " + "value is correct and within the model context size.") return int(max_model_len)