
[Bug]: Batched Multi-LoRA inference failure with random length dataset #237

Closed
tae-su-kim opened this issue Sep 4, 2024 · 10 comments · Fixed by #339
Labels: bug (Something isn't working)

Comments


tae-su-kim commented Sep 4, 2024

Anything you want to discuss about vllm.

Environment:

SynapseAI 1.17.0
vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest
habana_main branch

The current implementation of batched multi-LoRA suffers from a RuntimeError in online serving scenarios (e.g. the OpenAI API server). The bug can be reproduced with the following script:

Server:

VLLM_SKIP_WARMUP=true python -m vllm.entrypoints.openai.api_server \
    --model /models/Meta-Llama-3-8B-Instruct \
    --block-size 128 \
    --max-model-len 2048 \
    --max-loras 2 \
    --max-lora-rank 8 \
    --enable-lora \
    --lora-modules lora-1=/models/Gaudi_LoRA_Llama-3-8B-Instruct lora-2=/models/Gaudi_LoRA_Llama-3-8B-Instruct

Send requests with any dataset and LoRA pattern to the API server, with the number of requests >= 8. Below is an example command line for our benchmark script (https://github.com/SqueezeBits/vllm-fork/tree/benchmark).

python benchmarks/benchmark_sqzb.py \
    --tokenizer /models/Meta-Llama-3-8B-Instruct/ \
    --dataset benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json \
    --max-input-len 1024 \
    --max-output-len 8 \
    --mimic-throughput-sample \
    --lora-pattern ,lora-1,lora-2 \
    -n 16

Then the following error occurs:

(Abbreviated)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1544, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/codes/vllm/model_executor/models/llama.py", line 429, in forward
    model_output = self.model(input_ids, positions, kv_caches,
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1535, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1585, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/workspace/codes/vllm/model_executor/models/llama.py", line 314, in forward
    hidden_states = self.get_input_embeddings(input_ids)
  File "/workspace/codes/vllm/model_executor/models/llama.py", line 299, in get_input_embeddings
    return self.embed_tokens(input_ids)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1535, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1585, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/workspace/codes/vllm/lora/layers.py", line 330, in forward
    indices = self.embeddings_indices[1][:embedding_len].view_as(x)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

This error happens because the number of tokens in a prefill batch can become larger than max_num_batched_tokens. As discussed in PR #109, the current prefill scheduler may let a prefill batch exceed max_num_batched_tokens after padding, while the index tensors for LoRA in LoRAModelManager are designed to support only up to max_num_batched_tokens:

self.base_indices = torch.empty(self.max_num_batched_tokens,
                                dtype=torch.long,
                                device=get_device())
self.sampler_indices = torch.empty(self.max_num_batched_tokens,
                                   dtype=torch.long,
                                   device=get_device())
self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
                                          dtype=torch.long,
                                          device=get_device())
self.embeddings_indices = torch.empty(2,
                                      self.max_num_batched_tokens,
                                      dtype=torch.long,
                                      device=get_device())
self.long_lora_indices = torch.empty(self.max_num_batched_tokens,
                                     dtype=torch.long,
                                     device=get_device())

This causes L330 in vllm/lora/layers.py to fail on view_as(x).

Suggested solutions are either (1) to merge PR #109, or (2) to increase the size of embeddings_indices and the other index tensors to the maximum number of padded prefill tokens possible under the max_num_batched_tokens constraint.
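To make the failure mode concrete, here is a minimal, self-contained sketch with made-up numbers (a hypothetical sequence-length bucket, not the real scheduler or LoRA layer code): the unpadded batch fits within max_num_batched_tokens, but after padding it no longer fits in the preallocated index row, so the view fails.

```python
import torch

# Hypothetical numbers for illustration only.
max_num_batched_tokens = 2048
seq_lens = [120, 250, 300, 180, 400, 310, 260, 200]  # sum = 2020 <= 2048, so the batch is admitted
bucket_len = 512                                      # assumed padding bucket
padded_input_ids = torch.zeros(len(seq_lens), bucket_len, dtype=torch.long)

# The index buffer is sized for max_num_batched_tokens, like embeddings_indices above.
embeddings_indices = torch.empty(2, max_num_batched_tokens, dtype=torch.long)

embedding_len = padded_input_ids.numel()              # 8 * 512 = 4096 > 2048
try:
    embeddings_indices[1][:embedding_len].view_as(padded_input_ids)
except RuntimeError as e:
    # Fails because the slice holds only 2048 elements; the message differs from the
    # stride error above, but the root cause (padded tokens > buffer size) is the same.
    print(e)
```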

vivekgoe self-assigned this Sep 4, 2024

vivekgoe commented Sep 4, 2024

@tae-su-kim Please add --max-num-batched-tokens to the server command, with a value >= max_num_seqs * max_seq_len, e.g. for num_requests = 16 and max-input-len 1024, the value should be >= 16384. With this additional parameter added to the server command, I was able to execute the commands you shared without errors.

For testing larger batch sizes (> 32) you will run into assert or OOM issues. Please use #223, which has optimizations that enable execution with larger batch sizes.

vivekgoe assigned tae-su-kim and unassigned vivekgoe Sep 4, 2024
@tae-su-kim (Author)

@vivekgoe Thanks for the fast response.
Setting max_num_batched_tokens >= max_num_seqs * max_seq_len can indeed prevent the padding issue, but it causes significant effective prefill throughput degradation in certain cases. For example, consider 256 prefills with lengths [128, 128, ..., 128, 2048] batched at the same time: after padding, the batch becomes [2048, 2048, ..., 2048, 2048], inflating roughly 34.7K actual tokens to 524K padded tokens.
If the design choice is to keep the scheduler intact, I think an assertion that max_num_batched_tokens >= max_num_seqs * max_seq_len should (at least) be added when enable_lora is True, e.g. along the lines of the sketch below.
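A rough sketch of what such a guard could look like (illustrative only, not the actual vLLM argument-validation code):

```python
def validate_lora_batching(enable_lora: bool,
                           max_num_batched_tokens: int,
                           max_num_seqs: int,
                           max_model_len: int) -> None:
    # Illustrative guard only: with LoRA enabled, the padded prefill batch can
    # grow up to max_num_seqs * max_model_len tokens, so the preallocated LoRA
    # index tensors must cover at least that many tokens.
    if enable_lora:
        required = max_num_seqs * max_model_len
        assert max_num_batched_tokens >= required, (
            f"enable_lora requires max_num_batched_tokens >= "
            f"max_num_seqs * max_model_len ({required}), "
            f"got {max_num_batched_tokens}")
```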

I will definitely check #223 for further benchmarks. Thanks for the heads-up!


vivekgoe commented Sep 4, 2024

@tae-su-kim Adding an assertion to check max_num_batched_tokens >= max_num_seqs * max_seq_len when enable_lora = True is a good suggestion; we will add it. Regarding the possibility of changing the scheduler to handle this better, we will discuss it and get back to you.

vivekgoe added the bug label Sep 4, 2024

JHLEE17 commented Sep 11, 2024

I encountered the following errors while conducting experiments in the same environment as @tae-su-kim. There are two notable issues: First, errors occur when the number of requests increases. Second, even with the same number of requests, setting max-out-len=1 results in an error.

Server script:

VLLM_SKIP_WARMUP=true python -m vllm.entrypoints.openai.api_server      \
--model /scratch-1/models/Meta-Llama-3.1-8B-Instruct      \
--block-size 128      \
--max-model-len 2048      \
--enable-lora     \
--max-loras 2     \
--max-lora-rank 8     \
--lora-modules lora-1=/scratch-1/models/Gaudi_LoRA_Llama-3-8B-Instruct \
lora-2=/scratch-1/models/Gaudi_LoRA_Llama-3-8B-Instruct     \
--enforce-eager --max-num-seq <max_num_seq>

Client script:

python benchmarks/benchmark_sqzb.py     \
--tokenizer /scratch-1/models/Meta-Llama-3.1-8B-Instruct     \
--dataset /scratch-1/datasets/dynamic_sonnet_llama3/dynamic_sonnet_llama_3_prefix_256_max_1024_1024_sampled.parquet     \
--max-input-len 1024     \
--max-output-len <max_output_len>     \
--lora-pattern ,lora-1,lora-2    \
-n <#requests>

Results:

1. max-num-seq=128, max-out-len=1024:
   - #requests 1024: works fine

2. max-num-seq=128, max-out-len=1:
   - #requests 1024: error
   - #requests 512: works fine

3. max-num-seq=256, max-out-len=1024:
   - #requests 1024: error
   - #requests 200: error
   - #requests 128: works fine

Error Message

...
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/utils.py", line 311, in producer
async for item in iterator:
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 785, in generate
async for output in self._process_request(
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 901, in _process_request
raise e
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 897, in _process_request
async for request_output in stream:
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 93, in anext
raise result
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/utils.py", line 311, in producer
async for item in iterator:
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 785, in generate
async for output in self._process_request(
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 901, in _process_request
raise e
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 897, in _process_request
async for request_output in stream:
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 93, in anext
raise result
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/utils.py", line 311, in producer
async for item in iterator:
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 785, in generate
async for output in self._process_request(
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 901, in _process_request
raise e
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 897, in _process_request
async for request_output in stream:
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 93, in anext
raise result
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 46, in _log_task_completion
return_value = task.result()
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 650, in run_engine_loop
result = task.result()
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 593, in engine_step
request_outputs = await self.engine.step_async(virtual_engine)
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/engine/async_llm_engine.py", line 253, in step_async
output = await self.model_executor.execute_model_async(
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/executor/habana_executor.py", line 206, in execute_model_async
output = await make_async(self.driver_worker.execute_model
File "/home/sdp/miniforge3/envs/jongho/lib/python3.10/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/worker/worker_base.py", line 272, in execute_model
output = self.model_runner.execute_model(
File "/home/sdp/miniforge3/envs/jongho/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/worker/habana_model_runner.py", line 1837, in execute_model
output = self.model.sample(
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/worker/habana_model_runner.py", line 298, in sample
return self.model.sample(*args, **kwargs)
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/model_executor/models/llama.py", line 447, in sample
next_tokens = self.sampler(logits, sampling_metadata)
File "/home/sdp/miniforge3/envs/jongho/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1535, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/sdp/miniforge3/envs/jongho/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1585, in _call_impl
result = forward_call(*args, **kwargs)
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/model_executor/layers/sampler.py", line 138, in forward
sample_results, maybe_sampled_tokens_tensor = _sample(
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/model_executor/layers/sampler.py", line 711, in _sample
return _sample_with_torch(
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/model_executor/layers/sampler.py", line 592, in _sample_with_torch
sample_results = _greedy_sample(seq_groups, greedy_samples)
File "/home/sdp/works/jh/sqzb/vllm-fork/vllm/model_executor/layers/sampler.py", line 336, in _greedy_sample
samples_lst = samples.tolist()
RuntimeError: [Rank:0] FATAL ERROR :: MODULE:PT_LAZY Error, ValidateSyncInputTensors tensor_data is empty. Tensorid:4611686018431193860 QueueStatus:ThreadPool m_tasks size: 0 irValue:id_109408485_model/hpu__input
INFO: 127.0.0.1:35830 - "POST /v1/completions HTTP/1.1" 500 Internal Server Error

Any insights or suggestions to resolve these would be greatly appreciated!

@vivekgoe

@JHLEE17 Please share the commit which you used for above. Will check and get back.


JHLEE17 commented Sep 11, 2024

> @JHLEE17 Please share the commit which you used for above. Will check and get back.

@vivekgoe Oh, I missed it. I'm currently working with commit 53f96b7


SanjuCSudhakaran commented Sep 13, 2024

Cases 1 and 2 [batch size 128]: A device out-of-memory issue is observed with the enforce-eager flag and warmup enabled. We have reported the issue internally and have started looking into it with priority.

If the enforce-eager flag is removed and the warmup run is enabled, the test runs without any issues. You could continue your experiments with this configuration.

Case 3 [batch size 256]: For higher batch sizes like 256, we have also observed the device out-of-memory issue, and debugging is in progress.

Note: To get better performance, set max-num-batched-tokens to max-num-seqs * max-model-len.

The command used to run the server is shared below (vllm-fork head f858d43).

python -m vllm.entrypoints.openai.api_server \
    --model /scratch-1/models/Meta-Llama-3.1-8B-Instruct  \
    --block-size 128 \
    --max-model-len 2048 \
    --enable-lora \
    --max-loras 2 \
    --max-lora-rank 8 \
    --lora-modules lora-1=/scratch-1/models/Gaudi_LoRA_Llama-3-8B-Instruct lora-2=/scratch-1/models/Gaudi_LoRA_Llama-3-8B-Instruct \
    --max-num-seqs 128 \
    --max-num-batched-tokens 262144  # max-num-seqs * max-model-len

@SanjuCSudhakaran

For cases 1 and 2, the profile run was underestimating the expected memory usage. Setting VLLM_PROMPT_BS_BUCKET_MAX to 128 fixes the OOM issue.

The command used to run the server without HPU graphs is shared below (vllm-fork head 4c1ca3a).

VLLM_SKIP_WARMUP=true \
VLLM_PROMPT_BS_BUCKET_MAX=128 \
python -m vllm.entrypoints.openai.api_server \
    --model /scratch-1/models/Meta-Llama-3.1-8B-Instruct  \
    --block-size 128 \
    --max-model-len 2048 \
    --enforce-eager \
    --enable-lora \
    --max-loras 2 \
    --max-lora-rank 8 \
    --lora-modules lora-1=/scratch-1/models/Gaudi_LoRA_Llama-3-8B-Instruct lora-2=/scratch-1/models/Gaudi_LoRA_Llama-3-8B-Instruct \
    --max-num-seqs 128 \
    --max-num-batched-tokens 262144  # max-num-seqs * max-model-len


JHLEE17 commented Sep 19, 2024

I’ve successfully tested with --max-input-len 1024 (--max-model-len 2048) and confirmed that it works well not only when VLLM_PROMPT_BS_BUCKET_MAX and --max-num-seqs are set to 128, but also when both are set to 256. Thanks for your help! However, when I increase --max-input-len to 2048 (--max-model-len 3072) and the number of requests exceeds 128, I encounter an error that seems to be OOM-related, similar to what I reported earlier. It seems there may be a memory-management issue when using multi-LoRA. Is there a potential fix for this issue?

You can reproduce the results with these commands. Tested on vllm-fork head 53f96b7 & 35a4a98 (latest)
Server:

VLLM_SKIP_WARMUP=true \
VLLM_PROMPT_BS_BUCKET_MAX=128 \
python -m vllm.entrypoints.openai.api_server \
    --model /scratch-1/models/Meta-Llama-3.1-8B-Instruct  \
    --block-size 128 \
    --max-model-len 3072 \
    --enforce-eager \
    --enable-lora \
    --max-loras 2 \
    --max-lora-rank 8 \
    --lora-modules lora-1=/scratch-1/models/Gaudi_LoRA_Llama-3-8B-Instruct lora-2=/scratch-1/models/Gaudi_LoRA_Llama-3-8B-Instruct \
    --max-num-seqs 128 \
    --max-num-batched-tokens 393216  #max-num-seqs * max-model-len

Client:

python benchmarks/benchmark_sqzb.py     \
--tokenizer /scratch-1/models/Meta-Llama-3.1-8B-Instruct     \
--dataset /scratch-1/datasets/dynamic_sonnet_llama3/dynamic_sonnet_llama_3_prefix_512_max_2048_1024_sampled.parquet     \
--max-input-len 2048     \
--max-output-len 1024     \  # or 1
--lora-pattern ,lora-1,lora-2    \
-n 1024

@SanjuCSudhakaran

I could reproduce the device OOM issue locally with the given configuration.

The same issue is also observed without LoRA, and the related debugging is in progress.

michalkuligowski pushed a commit that referenced this issue Sep 27, 2024
…th LoRA (#343)

This PR has the following fixes:

- Increase size of indices tensors used to maintain multi-lora state
information from max_num_batched_tokens to 3*max_num_batched_tokens.
This increase is done to provide buffer for padding done in batch &
sequence dimensions.

- Move logic to remove padding from lora_logits from execute_model()
back to Class LogitsProcessorWithLoRA, this is done to fix race
condition caused by updating multi-lora state information directly.

FIX #237
huijjj pushed a commit to SqueezeBits/vllm-fork that referenced this issue Sep 27, 2024
…th LoRA (HabanaAI#339)

This PR has the following fixes:
- Increase size of indices tensors used to maintain multi-lora state
information from max_num_batched_tokens to 3*max_num_batched_tokens.
This increase is done to provide buffer for padding done in batch &
sequence dimensions.
- Move logic to remove padding from lora_logits from execute_model()
back to Class LogitsProcessorWithLoRA, this is done to fix race
condition caused by updating multi-lora state information directly.

FIX HabanaAI#237
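
For reference, a rough sketch in the spirit of the first fix described in the commit messages above (not the actual patch): the multi-LoRA index tensors are allocated with a 3x buffer to absorb padding in the batch and sequence dimensions.

```python
import torch

def get_device():
    # Placeholder for the HPU device helper used in the real code; "cpu" keeps the sketch runnable.
    return "cpu"

class LoRAIndexBuffers:
    """Illustrative container for the resized multi-LoRA index tensors."""

    def __init__(self, max_num_batched_tokens: int):
        # 3 * max_num_batched_tokens instead of max_num_batched_tokens,
        # providing headroom for batch- and sequence-dimension padding.
        buffer_len = 3 * max_num_batched_tokens
        self.base_indices = torch.empty(buffer_len, dtype=torch.long,
                                        device=get_device())
        self.sampler_indices = torch.empty(buffer_len, dtype=torch.long,
                                           device=get_device())
        self.sampler_indices_padded = torch.empty(buffer_len, dtype=torch.long,
                                                  device=get_device())
        self.embeddings_indices = torch.empty(2, buffer_len, dtype=torch.long,
                                              device=get_device())
        self.long_lora_indices = torch.empty(buffer_len, dtype=torch.long,
                                             device=get_device())
```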