
[Bug]: CohereForAI/c4ai-command-r-v01 : ValueError: User-specified max_model_len (131072) is greater than the derived max_model_len (None=8192 in model's config.json). This may lead to incorrect model outputs or CUDA errors. Make sure the value is correct and within the model > #3676

Closed
pseudotensor opened this issue Mar 28, 2024 · 10 comments · Fixed by #3727
Labels
bug Something isn't working

Comments

@pseudotensor

Your current environment

Head of main, after the various Cohere updates/fixes.

Setup:

sudo apt update
sudo apt install libnccl2 libnccl-dev

wget https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run
sudo sh cuda_12.1.0_530.30.02_linux.run
sudo chmod -R a+rwx /usr/local/

export CUDA_HOME=/usr/local/cuda-12.1
export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu121"

conda create -n vllm_cuda12.1 -y
conda activate vllm_cuda12.1
conda install python=3.10 -y

pip install git+https://github.com/vllm-project/vllm.git
pip install hf_transfer
pip install tiktoken accelerate flash_attn

export HF_HUB_ENABLE_HF_TRANSFER=1
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/lib64:$HOME/extras/CUPTI/lib64
export PATH=$PATH:$CUDA_HOME/bin

export CUDA_VISIBLE_DEVICES="0,1"
python -m vllm.entrypoints.openai.api_server --port=5005 --host=0.0.0.0 --model CohereForAI/c4ai-command-r-v01 --seed 1234 --tensor-parallel-size=2 --max-num-batched-tokens=131072 --max-log-len=100  --max-model-len 131072
# --trust-remote-code

I have to comment out --trust-remote-code due to a bug in their model repo: there is a PR for registering the model name that isn't merged yet.
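
A quick way to see the limit vLLM will derive, before launching, is to read the Hub config directly. A minimal sketch, assuming transformers is installed; the printed values are the ones discussed later in this thread:

from transformers import AutoConfig

# Inspect the fields vLLM derives its length limit from. Loading this model's
# config may need trust_remote_code=True or a recent transformers release.
cfg = AutoConfig.from_pretrained("CohereForAI/c4ai-command-r-v01",
                                 trust_remote_code=True)
print(getattr(cfg, "max_position_embeddings", None))  # 8192 at the time of this issue
print(getattr(cfg, "model_max_length", None))         # 131072 once added to config.json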

🐛 Describe the bug

INFO 03-28 08:14:18 api_server.py:147] vLLM API server version 0.3.3
INFO 03-28 08:14:18 api_server.py:148] args: Namespace(host='0.0.0.0', port=5005, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, served_model_name=Non>
Traceback (most recent call last):
  File "/home/fsuser/miniconda3/envs/vllm_cuda12.1/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/fsuser/miniconda3/envs/vllm_cuda12.1/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/fsuser/miniconda3/envs/vllm_cuda12.1/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 156, in <module>
    engine = AsyncLLMEngine.from_engine_args(engine_args)
  File "/home/fsuser/miniconda3/envs/vllm_cuda12.1/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 327, in from_engine_args
    engine_configs = engine_args.create_engine_configs()
  File "/home/fsuser/miniconda3/envs/vllm_cuda12.1/lib/python3.10/site-packages/vllm/engine/arg_utils.py", line 362, in create_engine_configs
    model_config = ModelConfig(
  File "/home/fsuser/miniconda3/envs/vllm_cuda12.1/lib/python3.10/site-packages/vllm/config.py", line 124, in __init__
    self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
  File "/home/fsuser/miniconda3/envs/vllm_cuda12.1/lib/python3.10/site-packages/vllm/config.py", line 791, in _get_and_verify_max_len
    raise ValueError(
ValueError: User-specified max_model_len (131072) is greater than the derived max_model_len (None=8192 in model's config.json). This may lead to incorrect model outputs or CUDA errors. Make sure the value is correct and within the model >
pseudotensor added the bug (Something isn't working) label Mar 28, 2024
@youkaichao
Member

cc @zeppombal

@mwbyeon

mwbyeon commented Mar 28, 2024

https://huggingface.co/CohereForAI/c4ai-command-r-v01/blob/9c33b0976099d0f406f0d007613676fe42b78e3b/config.json#L16

derived_max_model_len is limited to 8192 because max_position_embeddings is 8192 in config.json

vllm/vllm/config.py

Lines 749 to 772 in 14ccd94

def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
) -> int:
    """Get and verify the model's maximum length."""
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    for key in possible_keys:
        max_len_key = getattr(hf_config, key, None)
        if max_len_key is not None:
            derived_max_model_len = min(derived_max_model_len, max_len_key)

If you want to change max_model_len to 131072, you must also change max_position_embeddings to the same value.

But I can't think of a good way to solve this.
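
For anyone who just needs it running today, a crude sketch of the manual workaround implied above (the local path is an assumption, and editing the config is exactly what the rest of this thread argues should not be necessary):

import json

# Raise max_position_embeddings in a *local* copy of the model so the derived
# limit matches the requested 131072. Point config_path at wherever the
# snapshot was downloaded.
config_path = "/path/to/local/c4ai-command-r-v01/config.json"

with open(config_path) as f:
    cfg = json.load(f)

cfg["max_position_embeddings"] = 131072  # must match --max-model-len

with open(config_path, "w") as f:
    json.dump(cfg, f, indent=2)

# Then launch vLLM with --model pointing at the local directory instead of the Hub id.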

@pseudotensor
Author

But I thought they use rope scaling, so that should be accounted for in vLLM. I don't think the embedding size should need to change.

@mwbyeon

mwbyeon commented Mar 28, 2024

@pseudotensor that makes sense.

@ywang96
Member

ywang96 commented Mar 29, 2024

@pseudotensor If you look at the code right below - it'll scale the derived max model length by the factor if it exists in the config.json, but people may not want to use a rope-scaled model all the time.

vllm/vllm/config.py

Lines 786 to 796 in 14ccd94

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        assert "factor" in rope_scaling
        scaling_factor = rope_scaling["factor"]
        if rope_scaling["type"] == "yarn":
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]
        derived_max_model_len *= scaling_factor

    if max_model_len is None:
        max_model_len = derived_max_model_len

@saurabhdash What's your opinion on this? I guess we should add args to EngineArgs to allow custom rope settings, correct? Please ignore - I misunderstood the issue here.
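
To make the quoted snippet concrete, a tiny worked example with hypothetical values (the command-r config at the time had no rope_scaling entry, which is why the derivation stays at 8192):

# Hypothetical values for illustration only: an 8192-position model with a
# linear rope_scaling factor of 16 would derive 8192 * 16 = 131072.
derived_max_model_len = 8192
rope_scaling = {"type": "linear", "factor": 16.0}  # not present in the real config

if rope_scaling is not None:
    scaling_factor = rope_scaling["factor"]
    if rope_scaling["type"] == "yarn":
        derived_max_model_len = rope_scaling["original_max_position_embeddings"]
    derived_max_model_len *= scaling_factor

print(derived_max_model_len)  # 131072.0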

@pseudotensor
Author

But using less than the maximum should be controlled by passing model_max_length to vLLM. I'm not aware of any other model that fails in this way with rope scaling.

@ywang96
Member

ywang96 commented Mar 29, 2024

But using less than the maximum should be controlled by passing model_max_length to vLLM. I'm not aware of any other model that fails in this way with rope scaling.

You're right - I did a bit more research and found https://huggingface.co/CohereForAI/c4ai-command-r-v01/discussions/12, and it seems that "model_max_length" was added after the discussion in that thread.

This is indeed a bug we should fix then - although going forward I'm not sure if we should just take model_max_length as the true context window whenever we see it in the config. I will make a PR for this.
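
Roughly the idea (a sketch only, not necessarily what the PR ends up doing): if config.json carries model_max_length, treat it as another candidate when deriving the limit, so a config with max_position_embeddings=8192 and model_max_length=131072 no longer rejects --max-model-len 131072 outright.

from typing import Optional

# Hypothetical helper for illustration, not vLLM's actual implementation:
# fold model_max_length into the derivation when the config provides it.
def derive_max_len(max_position_embeddings: int,
                   model_max_length: Optional[int]) -> int:
    if model_max_length is not None:
        return max(max_position_embeddings, model_max_length)
    return max_position_embeddings

print(derive_max_len(8192, 131072))  # 131072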

@pseudotensor
Author

At least from that discussion, a low default is OK, but it shouldn't require editing the model's config.json in order to go up to its maximum.

Although normally I think vLLM always uses the maximum by default unless you make it smaller, which is different from what that person said HF does.

@ywang96
Member

ywang96 commented Mar 29, 2024

@pseudotensor Yep - I had the same thoughts in #3727. Please take a look, thanks!

@saurabhdash

The model_max_length was added to support llama.cpp (ggerganov/llama.cpp#6033 (comment)).
