AssertionError: Prefix caching is currently not supported with sliding window attention #3355

Closed
Chenghao-Jia opened this issue Mar 12, 2024 · 9 comments · Fixed by #3373

@Chenghao-Jia

code

# Import modules
from vllm import LLM, SamplingParams
import json
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

model = "/home/chenghao/workspace/OralCalculation/model/1-grade/Qwen1.5-7B-Chat-4bit/SFT_24-03-13"

llm = LLM(model=model, enforce_eager=True, trust_remote_code=True, gpu_memory_utilization=0.95, dtype='auto', quantization='gptq', max_model_len=9000)

# Define the data and sampling parameters
sampling_params = SamplingParams(temperature=0.8)
file = "/home/chenghao/workspace/OralCalculation/data/1-grade/OralCalculation_1_new_test.json"
# for file in os.listdir(files):
with open(file, 'r', encoding='utf-8') as f:
    datas = json.loads(f.read())
    prompts = [data['system'] + '\n\n' + data['instruction'] + '\n' + data['input'] for data in datas]
    labels = [data['output'] for data in datas]
    
    prefix = datas[0]['system'] + '\n\n' + datas[0]['instruction'] + '\n'
    # -1 because the last token may change when the prompt is concatenated.
    prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
    print(prefix_pos)
    
    # Inference
    # The llm.generate call batches all prompts and dispatches the batch as soon as resources allow.
    # The prefix is only cached after the first batch finishes, so we call generate once to compute and cache it.
    outputs = llm.generate(prompts[0],
                        sampling_params,
                        prefix_pos=[prefix_pos])

    # Subsequent batches can reuse the cached prefix
    outputs = llm.generate(prompts,
                        sampling_params,
                        prefix_pos=[prefix_pos] * len(prompts))
    
    preds = [output.outputs[0].text for output in outputs]

error

(vllm) chenghao@auc:~/workspace/OralCalculation/vllm$ python predict.py
WARNING 03-12 22:15:40 config.py:193] gptq quantization is not fully optimized yet. The speed can be slower than non-quantized models.
INFO 03-12 22:15:40 llm_engine.py:87] Initializing an LLM engine with config: model='/home/chenghao/workspace/OralCalculation/model/1-grade/Qwen1.5-7B-Chat-4bit/SFT_24-03-13', tokenizer='/home/chenghao/workspace/OralCalculation/model/1-grade/Qwen1.5-7B-Chat-4bit/SFT_24-03-13', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=9000, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=gptq, enforce_eager=True, kv_cache_dtype=auto, device_config=cuda, seed=0)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
INFO 03-12 22:15:45 llm_engine.py:357] # GPU blocks: 1927, # CPU blocks: 512
546
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  4.03it/s]
Processed prompts:   0%|                                                                                        | 0/289 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/sshfs/chenghao/OralCalculation/vllm/predict.py", line 37, in <module>
    outputs = llm.generate(prompts,
  File "/home/chenghao/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 182, in generate
    return self._run_engine(use_tqdm)
  File "/home/chenghao/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 208, in _run_engine
    step_outputs = self.llm_engine.step()
  File "/home/chenghao/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 838, in step
    all_outputs = self._run_workers(
  File "/home/chenghao/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 1041, in _run_workers
    driver_worker_output = getattr(self.driver_worker,
  File "/home/chenghao/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/chenghao/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/worker/worker.py", line 223, in execute_model
    output = self.model_runner.execute_model(seq_group_metadata_list,
  File "/home/chenghao/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/chenghao/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 571, in execute_model
    lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)
  File "/home/chenghao/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 490, in prepare_input_tensors
    lora_requests) = self._prepare_prompt(seq_group_metadata_list)
  File "/home/chenghao/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 193, in _prepare_prompt
    assert prefix_len == 0, (
AssertionError: Prefix caching is currently not supported with sliding window attention
@Chenghao-Jia
Author

But when I predict just two examples, it finishes fine:

import time
import datetime
import os

from vllm import LLM
from vllm import SamplingParams


def test_prefix(llm = None, sampling_params=None, prompts=None, prompt_token_ids=None, prefix_len=None, save_file=None, detile=True):
    assert prompts != None or prompt_token_ids != None, f"prompt and tokens can't both be None"
    if isinstance(prefix_len, int):
        prefix_len = [prefix_len]
        assert len(prompts) % len(prefix_len) == 0, f"len of prompts must be multiple of len of prefix_len"
    print("------start generating------")
    start_time = time.time()
    # whether use Prefix
    if prefix_len != None:
        # start inference
        if prompt_token_ids != None:
            outputs = llm.generate(prompt_token_ids=prompt_token_ids,
                                   sampling_params=sampling_params,
                                   prefix_pos=prefix_len * (len(prompts) // len(prefix_len)))
        else:
            outputs = llm.generate(prompts=prompts,
                                   sampling_params=sampling_params,
                                   prefix_pos=prefix_len * (len(prompts) // len(prefix_len)))
            print("len(prompts) // len(prefix_len):" + str(len(prompts) // len(prefix_len)))
            print("prefix_len:" + str(prefix_len))
            print("prefix_pos:" + str(prefix_len * (len(prompts) // len(prefix_len))))
    else:
        outputs = llm.generate(prompts, sampling_params=sampling_params)

    end_time = time.time()
    print(f"cost time {end_time - start_time}")

    if save_file != None:
        print("saving output......")
        for index, output in enumerate(outputs):
            if detile == True:
                print(output, file=save_file)
            else:
                print(output.outputs[0].text, file=save_file)
        print(f"output saved in {save_file.name} {datetime.datetime.now()}")


# You need to change the parameters below
# set gpus
os.environ['CUDA_VISIBLE_DEVICES']="3"
tensor_parallel_size = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
# set inference model
# Replace this with your model path
model = "/home/chenghao/workspace/OralCalculation/model/1-grade/Qwen1.5-1.8B-Chat/SFT_24-03-12"
# Create an LLM.
llm = LLM(model=model, tokenizer_mode='auto', trust_remote_code=True, tensor_parallel_size=tensor_parallel_size)
# get prompts
prompts = ["这是一个 Prefix 功能使用的示例,因为 Prefix 的存储以物理块为单位,所以 Prompt 的长度需要至少大于等于一个物理块,这是第一句话",
           "这是一个 Prefix 功能使用的示例,因为 Prefix 的存储以物理块为单位,所以 Prompt 的长度需要至少大于等于一个物理块,这是第二句话"]
# prompt_token_ids = llm.tokenizer(prompts)
# set SamplingParams
sampling_params = SamplingParams(temperature=0)

# prefix_len is a list the same length as prompts, giving the prefix length for each prompt; use None if there is no prefix
with open("output.txt", 'w') as f:
    test_prefix(llm=llm, prompts=prompts, prefix_len=[16, 32], save_file=f, sampling_params=sampling_params, detile=False)
(vllm) chenghao@auc:~/workspace/OralCalculation/vllm$ python ~/workspace/test.py
INFO 03-12 22:48:32 llm_engine.py:87] Initializing an LLM engine with config: model='/home/chenghao/workspace/OralCalculation/model/1-grade/Qwen1.5-1.8B-Chat/SFT_24-03-12', tokenizer='/home/chenghao/workspace/OralCalculation/model/1-grade/Qwen1.5-1.8B-Chat/SFT_24-03-12', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
INFO 03-12 22:48:37 llm_engine.py:357] # GPU blocks: 5475, # CPU blocks: 1365
INFO 03-12 22:48:38 model_runner.py:684] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 03-12 22:48:38 model_runner.py:688] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 03-12 22:48:43 model_runner.py:756] Graph capturing finished in 5 secs.
------start generating------
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 17.86it/s]
len(prompts) // len(prefix_len):1
prefix_len:[16, 32]
prefix_pos:[16, 32]
cost time 0.12456130981445312
saving output......
output saved in output.txt 2024-03-12 22:48:43.667264

@Limingxing00

I met the same problem. Besides, I'd like to confirm two questions:

1. Can the prefix cache speed up decoder-only frameworks such as Qwen-1.5?
2. If I want to further speed up Qwen-1.5, do you have any other suggestions?

@carrey-feng

> I met the same problem. Besides, I'd like to confirm two questions:
>
> 1. Can the prefix cache speed up decoder-only frameworks such as Qwen-1.5?
> 2. If I want to further speed up Qwen-1.5, do you have any other suggestions?

In my scenario it sped up by more than 4x; my batch size is 100.

@carrey-feng

carrey-feng commented Mar 13, 2024

Yesterday I encountered and resolved the same problem as well. However, I only submitted my PR today, not realizing someone had already fixed it. This community is incredibly active.

@Limingxing00

> > I met the same problem. Besides, I'd like to confirm two questions:
> >
> > 1. Can the prefix cache speed up decoder-only frameworks such as Qwen-1.5?
> > 2. If I want to further speed up Qwen-1.5, do you have any other suggestions?
>
> In my scenario it sped up by more than 4x; my batch size is 100.

Thanks for your reply.
The batch size is 1 in my setting, and my experiment sped up by about 2x with vLLM, but prefix caching doesn't work for me. Did prefix caching work for you? How do you use it? Can you share the related code?

@cadedaniel
Collaborator

We don't have prefix caching support yet for sliding window. Either #3377 or #3373 can fix the failure here; we can add support for sliding window + prefix caching in the future.

@cadedaniel cadedaniel self-assigned this Mar 13, 2024
@carrey-feng

> > > I met the same problem. Besides, I'd like to confirm two questions:
> > >
> > > 1. Can the prefix cache speed up decoder-only frameworks such as Qwen-1.5?
> > > 2. If I want to further speed up Qwen-1.5, do you have any other suggestions?
> >
> > In my scenario it sped up by more than 4x; my batch size is 100.
>
> Thanks for your reply. The batch size is 1 in my setting, and my experiment sped up by about 2x with vLLM, but prefix caching doesn't work for me. Did prefix caching work for you? How do you use it? Can you share the related code?

Just add more data to the prompts list. The first prompt in the prompts list will be quite slow, but subsequent ones will be much faster, which is not entirely consistent with the description in the official example code comments. It seems that the warm-up during the first run did not take effect as illustrated in the example.
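
For reference, a minimal sketch of how one might check that the cached prefix is actually being reused, following the same `prefix_pos` API used in the code above (the model path and prompts are placeholders; the shared prefix must be identical across prompts):

```python
import time
from vllm import LLM, SamplingParams

# Placeholder model path; any model without sliding window attention should work here.
llm = LLM(model="path/to/your-model", trust_remote_code=True)
sampling_params = SamplingParams(temperature=0)

# All prompts share the same prefix so its KV blocks can be cached and reused.
prefix = "shared system prompt and instruction\n"
prompts = [prefix + f"question {i}" for i in range(100)]
# -1 because the last prefix token may change once the rest of the prompt is appended.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1

# Warm-up call: computes and caches the prefix.
t0 = time.time()
llm.generate(prompts[:1], sampling_params, prefix_pos=[prefix_pos])
print(f"warm-up took {time.time() - t0:.2f}s")

# Batched call: subsequent prompts should reuse the cached prefix and run faster per prompt.
t0 = time.time()
llm.generate(prompts, sampling_params, prefix_pos=[prefix_pos] * len(prompts))
print(f"batch took {time.time() - t0:.2f}s")
```

If the cache is working, the per-prompt time of the second call should drop well below the warm-up time.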

@chenxu2048
Contributor

Hi, @Chenghao-Jia, @Limingxing00, @a516072575.

FYI, as a temporary fix, we can modify the config.json in Qwen1.5 and set "sliding_window" to null until the PRs are merged.

{
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 32768,
  "max_window_layers": 28,
  "model_type": "qwen2",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "rms_norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "sliding_window": null, // replace with null here
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.37.0",
  "use_cache": true,
  "use_sliding_window": false, // turn off sliding_window
  "vocab_size": 151936
}

Qwen2Config has a non-None default value for sliding_window, so we cannot simply remove "sliding_window" from config.json.
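
A sketch of applying the same change programmatically with `transformers`, assuming a local checkpoint directory (the path is a placeholder; this rewrites `config.json` in place, so keep a backup):

```python
from transformers import AutoConfig

# Placeholder path to the local Qwen1.5 checkpoint directory that contains config.json.
model_path = "path/to/Qwen1.5-7B-Chat"

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
config.sliding_window = None        # same effect as writing "sliding_window": null
config.use_sliding_window = False   # turn off sliding window attention
config.save_pretrained(model_path)  # rewrites config.json in place
```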

@mridulBanik112

Hello @chenxu2048, what's the path of this config.json file?
