
[Hardware][Nvidia] Enable support for Pascal GPUs #4290

Closed

Conversation

cduk

cduk commented Apr 23, 2024

[Hardware][Nvidia] This PR enables support for Pascal generation cards. This is tested as working with P100 and P40 cards.

cduk changed the title from "Enable support for Pascal GPUs" to "[Hardware][Nvidia] Enable support for Pascal GPUs" on Apr 23, 2024
@mgoin
Collaborator

mgoin commented Apr 23, 2024

Looks reasonable to me! It would be worth noting how this affects the size of the wheel. FYI the last release's wheel is already reaching ~100MB https://pypi.org/project/vllm/#files

@youkaichao
Member

It will definitely enlarge the wheel size, so we will not officially support it. We are limited by PyPI's package size limit of 100 MB: pypi/support#3792.

@mgoin
Collaborator

mgoin commented Apr 23, 2024

Hey @youkaichao, I wasn't aware of this PyPI package size limit - I know of some packages on PyPI that are much larger than 100 MB. For instance, how does torch have wheels larger than 700 MB? https://pypi.org/project/torch/#files

(screenshot: torch wheel files on PyPI)

@youkaichao
Member

PyTorch has approval from PyPI. Our request for vLLM is not approved - actually, no response yet.

@jasonacox
Contributor

@cduk Thanks for sharing this. I know Pascal is on the last rung of the CUDA ladder, but it would be good to offer this up as an option for older Pascal GPUs.

Until the wheel size issue is fixed and this PR can be reconsidered, here is a pascal.sh script that edits the build files to add Pascal GPU support. It requires the user to build the project from source, but at least it provides a path.

# Download and add Pascal Support
git clone https://github.com/vllm-project/vllm.git
cd vllm
./pascal.sh

# You can now build from source with Pascal GPU support:
pip install -e .

# or build the Docker image with:
DOCKER_BUILDKIT=1 docker build . --target vllm-openai --tag vllm/vllm-openai

pascal.sh

#!/bin/bash
#
# This script adds Pascal GPU support to vLLM by adding 6.0, 6.1 and 6.2 
# GPU architectures to the build files CMakeLists.txt and Dockerfile
#

# Ask user for confirmation
read -p "This script will add Pascal GPU support to vLLM. Continue? [y/N] " -n 1 -r
echo
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
    echo "Exiting..."
    exit 1
fi
echo
echo "Adding Pascal GPU support..."

# Update CMakeLists.txt and Dockerfile
echo " - Updating CMakeLists.txt"
cuda_supported_archs="6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.9;9.0"
sed -i.orig "s/set(CUDA_SUPPORTED_ARCHS \"7.0;7.5;8.0;8.6;8.9;9.0\")/set(CUDA_SUPPORTED_ARCHS \"$cuda_supported_archs\")/g" CMakeLists.txt

echo " - Updating Dockerfile"
torch_cuda_arch_list="6.0 6.1 6.2 7.0 7.5 8.0 8.6 8.9 9.0+PTX"
sed -i.orig "s/ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'/ARG torch_cuda_arch_list='$torch_cuda_arch_list'/g" Dockerfile

cat <<EOF
You can now build from source with Pascal GPU support:
    pip install -e .
Or build the Docker image with:
    DOCKER_BUILDKIT=1 docker build . --target vllm-openai --tag vllm/vllm-openai
EOF
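
A quick sanity check after building (a minimal sketch, assuming only that PyTorch is importable in the build environment): confirm the card really is Pascal and that the installed PyTorch build ships a compatible sm_60 kernel image.

import torch

# Compute capability of the first visible GPU, e.g. (6, 0) for P100 or (6, 1) for P40.
print(torch.cuda.get_device_capability(0))

# Architectures the installed PyTorch build was compiled for; Pascal needs an
# sm_60 entry here (sm_60 cubins also run on 6.1/6.2 devices).
print(torch.cuda.get_arch_list())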

@cduk
Author

cduk commented Apr 24, 2024

Thanks for explaining the wheel size issue.

@sasha0552
Contributor

sasha0552 commented Apr 25, 2024

I really appreciate the effort to add Pascal support to the vLLM main branch. As someone who uses vLLM with Pascal GPUs, I want to add a few comments on this.

First, I think compute capability 6.2 can be excluded to reduce the wheel size a bit, since it is a Tegra/Jetson/DRIVE CC; other Tegra/Jetson/DRIVE CCs such as 7.2 and 8.7 are not supported either, so 6.2 can be dropped as well.

Support table

(image: CUDA compute capability support table)

Secondly, Triton (used by PyTorch) needs a patch as well.
Sometimes vLLM crashes because of Triton, as described in triton-lang/triton#2780 (there is a patch in that issue, but the Triton developers apparently never applied it). I haven't yet determined exactly when the crashes occur: test batching with the OpenAI client works fine, but using another application causes the crash.
Update: another one showed up: LLVM ERROR: Cannot select: intrinsic %llvm.nvvm.shfl.sync.bfly.i32.

Crash log
INFO 04-25 08:04:31 metrics.py:229] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 2.1 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 1.7%, CPU KV cache usage: 0.0%
INFO 04-25 08:04:36 metrics.py:229] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 2.1 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 1.7%, CPU KV cache usage: 0.0%
INFO 04-25 08:04:49 metrics.py:229] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.4 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%
INFO:     x.x.x.x:xxxxx - "POST /v1/chat/completions HTTP/1.1" 200 OK
Unsupported conversion from f16 to f16

UNREACHABLE executed at /project/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp:816!

P40 performance with vLLM

As a side note, I get ~500 t/s when batching 256 simultaneous requests on a single P40 with Llama 3 8B.
That is about 2 t/s per request.
In comparison, I get 500 t/s pp and 15 t/s tg using llama.cpp with an fp16 model, and 700 t/s pp & 30 t/s tg with a Q8_0 model, although that is not batch processing.
pp = prompt processing, "prompt throughput"
tg = token generation, "generation throughput"

(My P40s are usually power-limited to 125 W, so the screenshot shows ~400 t/s.)
(screenshot: vLLM throughput on a P40)

@cduk
Author

cduk commented Apr 25, 2024

As a side note, I get ~500 t/s when batching 256 simultaneous requests on a single P40 with Llama 3 8B. Or 2 t/s per request.

What quantization are you using? I get 15 t/s for 7B at both FP16 and Q8 (llama.cpp Q8 performance is higher, at 25 tok/s).

vLLM Q4 performance is bad on the P40 due to FP16 code paths and the P40's terrible FP16 performance - I get maybe 3 tok/s with a 14B Q4 model in vLLM, compared to llama.cpp, which gets 19 tok/s.

I need to check the code to see if performance can be increased by avoiding FP16 calculations (upcasting to FP32 instead).

EDIT: all performance figures are for a P40 @ 125 W power limit
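
(To illustrate the idea being discussed - this is only a minimal PyTorch sketch of the upcast pattern, not vLLM's actual kernels: keep the weights in FP16 to save memory, but do the matmul in FP32, which on Pascal's weak FP16 units can be much faster than native half-precision math.)

import torch

x = torch.randn(1, 4096, device="cuda", dtype=torch.float16)
w = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

# Native FP16 matmul: very slow on a P40, whose FP16 rate is a small fraction of FP32.
y_native = x @ w

# Upcast pattern: store in FP16, compute in FP32, cast the result back to FP16.
y_upcast = (x.float() @ w.float()).half()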

@sasha0552
Contributor

sasha0552 commented Apr 25, 2024

@cduk
If you are talking about vLLM, I am using the unquantized model (fp16). --dtype float16 and --dtype float32 don't make much difference: I can run vLLM on two GPUs with tensor parallelism in fp32 and get the same ~500 t/s (2 t/s per request) as with --dtype float16 on a single GPU. So either the vLLM code path is effectively fp16 and 500 t/s is fp16 performance (in which case an fp32 implementation should be much faster), or vLLM is smart enough to automatically cast to fp32.
Update: when using TP, fp16 is even better than fp32.

All of the following is with a power limit of 125 W unless otherwise noted.

vLLM float16 commandline
export MODEL=".../hub/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e5e23bbe8e749ef0efcf16cad411a7d23bd23298"
export MODEL_NAME="meta-llama/Meta-Llama-3-8B-Instruct"

env CUDA_VISIBLE_DEVICES=1,0                    \
    HF_HUB_DISABLE_TELEMETRY=1                  \
    HF_HUB_OFFLINE=1                            \
      venv/bin/python3                          \
        -m vllm.entrypoints.openai.api_server   \
        --disable-log-requests                  \
        --enable-prefix-caching                 \
        --dtype float16                         \
        --model "$MODEL"                        \
        --served-model-name "$MODEL_NAME"
vLLM float16 performance
INFO 04-25 14:34:55 metrics.py:229] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%
INFO 04-25 14:35:05 metrics.py:229] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%
INFO 04-25 14:35:15 metrics.py:229] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%
INFO 04-25 14:35:21 metrics.py:229] Avg prompt throughput: 215.8 tokens/s, Avg generation throughput: 16.6 tokens/s, Running: 100 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%
INFO 04-25 14:35:27 metrics.py:229] Avg prompt throughput: 378.6 tokens/s, Avg generation throughput: 76.9 tokens/s, Running: 256 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 12.5%, CPU KV cache usage: 0.0%
INFO 04-25 14:35:32 metrics.py:229] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 403.6 tokens/s, Running: 256 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 12.6%, CPU KV cache usage: 0.0%
INFO 04-25 14:35:37 metrics.py:229] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 393.7 tokens/s, Running: 256 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 12.6%, CPU KV cache usage: 0.0%

400 t/s

vLLM float16 commandline, tp=2
export MODEL=".../hub/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e5e23bbe8e749ef0efcf16cad411a7d23bd23298"
export MODEL_NAME="meta-llama/Meta-Llama-3-8B-Instruct"

env CUDA_VISIBLE_DEVICES=1,0                    \
    HF_HUB_DISABLE_TELEMETRY=1                  \
    HF_HUB_OFFLINE=1                            \
      venv/bin/python3                          \
        -m vllm.entrypoints.openai.api_server   \
        --disable-log-requests                  \
        --enable-prefix-caching                 \
        --dtype float16                         \
        --model "$MODEL"                        \
        --served-model-name "$MODEL_NAME"       \
        --tensor-parallel-size 2
vLLM float16 performance, tp=2
INFO 04-25 14:53:01 metrics.py:229] Avg prompt throughput: 11.6 tokens/s, Avg generation throughput: 0.9 tokens/s, Running: 6 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%
INFO 04-25 14:53:06 metrics.py:229] Avg prompt throughput: 647.9 tokens/s, Avg generation throughput: 100.9 tokens/s, Running: 256 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 2.1%, CPU KV cache usage: 0.0%
INFO 04-25 14:53:11 metrics.py:229] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 618.4 tokens/s, Running: 256 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 2.2%, CPU KV cache usage: 0.0%

650 t/s!

vLLM float16 performance, tp=2, powerlimit = 250W
INFO 04-25 14:57:27 metrics.py:229] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%
INFO 04-25 14:57:32 metrics.py:229] Avg prompt throughput: 7.5 tokens/s, Avg generation throughput: 0.6 tokens/s, Running: 3 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%
INFO 04-25 14:57:37 metrics.py:229] Avg prompt throughput: 622.6 tokens/s, Avg generation throughput: 241.7 tokens/s, Running: 256 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 2.1%, CPU KV cache usage: 0.0%
INFO 04-25 14:57:43 metrics.py:229] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 696.1 tokens/s, Running: 256 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 2.1%, CPU KV cache usage: 0.0%

700 t/s!

vLLM float32 commandline, tp=2
export MODEL=".../hub/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e5e23bbe8e749ef0efcf16cad411a7d23bd23298"
export MODEL_NAME="meta-llama/Meta-Llama-3-8B-Instruct"

env CUDA_VISIBLE_DEVICES=1,0                    \
    HF_HUB_DISABLE_TELEMETRY=1                  \
    HF_HUB_OFFLINE=1                            \
      venv/bin/python3                          \
        -m vllm.entrypoints.openai.api_server   \
        --disable-log-requests                  \
        --enable-prefix-caching                 \
        --dtype float32                         \
        --model "$MODEL"                        \
        --served-model-name "$MODEL_NAME"       \
        --tensor-parallel-size 2
vLLM float32 performance, tp=2
INFO 04-25 14:49:00 metrics.py:229] Avg prompt throughput: 5.9 tokens/s, Avg generation throughput: 0.5 tokens/s, Running: 5 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.1%, CPU KV cache usage: 0.0%
INFO 04-25 14:49:05 metrics.py:229] Avg prompt throughput: 642.1 tokens/s, Avg generation throughput: 49.4 tokens/s, Running: 256 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.1%, CPU KV cache usage: 0.0%
INFO 04-25 14:49:10 metrics.py:229] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 499.8 tokens/s, Running: 256 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 13.1%, CPU KV cache usage: 0.0%
INFO 04-25 14:49:16 metrics.py:229] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 489.1 tokens/s, Running: 256 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 13.2%, CPU KV cache usage: 0.0%

500 t/s

Client
import time
import threading
import random
import string
from openai import OpenAI

client = OpenAI(
    base_url="http://x.x.x.x:8000/v1",
    api_key="token-abc123",
)

def func(name):
  start_time = time.time()

  client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[
      {"role": "user", "content": "Hello!"}
    ]
  )

  end_time = time.time()

  print(f"[{name}] Received response in {end_time - start_time}s")

def main():
  for i in range(256):
    threading.Thread(target=func, args=(f"Thread-{i}",)).start()

if __name__ == "__main__":
  main()

Speaking of llama.cpp, I tested fp16 (right after conversion) and Q8_0. For example, Q8_0:

llama.cpp Q8_0 performance
% CUDA_VISIBLE_DEVICES=1 build/bin/llama-bench --model .../Meta-Llama-3-8B-Instruct-Q8_0.gguf
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   yes
ggml_cuda_init: CUDA_USE_TENSOR_CORES: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: Tesla P40, compute capability 6.1, VMM: yes
| model                          |       size |     params | backend    | ngl | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------- | ---------------: |
| llama 8B Q8_0                  |   7.95 GiB |     8.03 B | CUDA       |  99 | pp 512     |  561.38 ± 109.23 |
| llama 8B Q8_0                  |   7.95 GiB |     8.03 B | CUDA       |  99 | tg 128     |     31.17 ± 0.08 |

build: 784e11de (2725)

@sasha0552
Contributor

I found that --enable-prefix-caching causes Triton errors such as UNREACHABLE executed: Unsupported conversion from f16 to f16 and LLVM ERROR: Cannot select: intrinsic %llvm.nvvm.shfl.sync.bfly.i32.
Can anyone reproduce this? As far as I know, the following conditions must be met (a minimal reproduction sketch is included below):

  1. Pascal GPU
  2. --enable-prefix-caching
  3. 2000+ token prompt

Backtrace of the second error (not very useful, but it shows that it's related to Triton): sasha0552/triton#1
Should I create a separate issue about this, or is this not planned to be fixed on unsupported platforms?
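
(For reference, a minimal reproduction sketch of the conditions above. The server is assumed to have been started with --enable-prefix-caching, and the prompt text here is arbitrary filler that is merely long enough to exceed ~2000 tokens.)

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")

# Arbitrary filler, long enough to exceed ~2000 tokens.
long_prompt = "Summarize the following text. " + ("lorem ipsum dolor sit amet " * 500)

# Send the same long prompt twice, one after the other; on Pascal with
# --enable-prefix-caching this is the pattern that has triggered the Triton crash.
for i in range(2):
    response = client.chat.completions.create(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        messages=[{"role": "user", "content": long_prompt}],
        max_tokens=16,
    )
    print(i, response.choices[0].message.content)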

Also @cduk, in case you didn't notice, you closed the PR with that force push because the branches diverged. You should create a new PR, I think (as far as I know, GitHub doesn't allow you to reopen a PR in this case, even if you force-push again with the correct base branch).

@jasonacox
Contributor

jasonacox commented Apr 27, 2024

I can open a new PR as a placeholder in the hope that the >100 MB request is granted. I'll only add 6.0 and 6.1, as you mention, @cduk. Also, I do see that PyTorch now only ships sm_60 for Pascal:

>>> torch.__version__
'2.2.1+cu121'
>>> torch.cuda.get_arch_list()
['sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90']
>>> 

Pascal Architecture

  • (+) SM60 or SM_60, compute_60 – Quadro GP100, Tesla P100, DGX-1 (Generic Pascal)
  • (+) SM61 or SM_61, compute_61 – GTX 1080, GTX 1070, GTX 1060, GTX 1050, GT 1030 (GP108), GT 1010 (GP108), Titan Xp, Tesla P40, Tesla P4, discrete GPU on the NVIDIA Drive PX2
  • (-) SM62 or SM_62, compute_62 – integrated GPU on the NVIDIA Drive PX2, Tegra (Jetson) TX2

@sasha0552 I'm running on 4 x P100s on CUDA 12.2 with context prompts up to 24k tokens without issue (could get closer to 32k but haven't tried). Average TPS across 10 concurrent threads: 208.3 - individual threads: min TPS 20.8, max TPS 20.8 (#963)

I'm using docker images and running with:

# build
DOCKER_BUILDKIT=1 docker build . --target vllm-openai --tag vllm-openai --no-cache

 # run
docker run -d \
    --shm-size=10.24gb \
    --gpus '"device=0,1,2,3"' \
    -v /data/models:/root/.cache/huggingface \
    --env "HF_TOKEN=xyz" \
    -p 8000:8000 \
    --restart unless-stopped \
    --name vllm-openai \
    vllm-openai \
    --host 0.0.0.0 \
    --model=mistralai/Mistral-7B-Instruct-v0.1 \
    --enforce-eager \
    --dtype=float \
    --gpu-memory-utilization 0.95 \
    --tensor-parallel-size=4

@sasha0552
Contributor

sasha0552 commented Apr 27, 2024

@jasonacox you don't run with the --enable-prefix-caching option, which is what causes problems (at least for me). As I understand it, this option caches prompts in the KV cache, so that subsequent requests can be processed without the prompt-processing ("prompt throughput") step (useful for iteratively extended large prompts).
Also, 20 t/s per request is very impressive. Can you also test 256 requests in parallel? I've provided the script I'm using above, in the "Client" spoiler.
I'll look at the code later to try to get this performance on the P40s.
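
(To make the caching behaviour concrete, a small sketch using vLLM's offline API - it assumes the engine accepts enable_prefix_caching as a keyword argument, the Python-API counterpart of the --enable-prefix-caching flag. Two prompts share a long prefix; with prefix caching, the second generation can reuse the cached KV blocks of that prefix instead of recomputing them.)

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    dtype="float16",
    enable_prefix_caching=True,  # counterpart of --enable-prefix-caching
)

shared_prefix = "You are a helpful assistant. Here is a long document: ..."
params = SamplingParams(max_tokens=64)

# The prefill of shared_prefix is computed once, cached, and reused by the
# second request instead of being recomputed.
print(llm.generate([shared_prefix + " Summarize it."], params)[0].outputs[0].text)
print(llm.generate([shared_prefix + " List the key points."], params)[0].outputs[0].text)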

@jasonacox
Contributor

jasonacox commented Apr 27, 2024

you don't run with the --enable-prefix-caching option

You're right! I missed that - I'm sorry about that. I have limited memory headroom on these P100s, so I wasn't optimizing for cache, just concurrent sessions. Also, I'm running Mistral rather than Llama-3. I modified your script to compute the tokens per second without having to look at the vLLM logs/stats, and to iterate through different numbers of concurrent threads. Please check my math.

EDIT: Updated the docker run command; the load test now starts with 1 thread:

docker run -d \
    --gpus all \
    --shm-size=10.24gb \
    -v /data/models:/root/.cache/huggingface \
    -p 8000:8000 \
    --restart unless-stopped \
    --name $CONTAINER \
    vllm-openai \
    --host 0.0.0.0 \
    --model=mistralai/Mistral-7B-Instruct-v0.1 \
    --dtype=float \
    --worker-use-ray \
    --gpu-memory-utilization 0.95 \
    --disable-log-stats \
    --tensor-parallel-size=4
loadtest.py

import time
import threading
from openai import OpenAI

# Constants
SHORT_REPORT = True

# Globals
stats = {}
client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="token-abc123",
)

def func(name):
    start_time = time.time()

    response = client.chat.completions.create(
        model="mistralai/Mistral-7B-Instruct-v0.1",
        messages=[
            {"role": "user", "content": "Hello!"}
        ]
    )

    end_time = time.time()
    response_time = end_time - start_time
    stats[name] = {}
    stats[name]["response_time"] = response_time
    stats[name]["tokens"] = int(response.usage.completion_tokens)
    print(f" - [{name}] Received response in {end_time - start_time}s")

def main(num_threads=256):
    for i in range(num_threads):
        threading.Thread(target=func, args=(f"Thread-{i}",)).start()
    # Wait for threads to finish
    for thread in threading.enumerate():
        if thread != threading.current_thread():
            thread.join()

report = {}
if __name__ == "__main__":
    print("Starting load test...")
    for i in (1, 8, 16, 32, 64, 128, 256, 512):
        print(f"Running {i} threads...")
        main_start = time.time()
        main(i)
        main_end = time.time()

        # Compute Stats
        total_response_time = sum(stats[name]['response_time'] for name in stats)
        total_tokens = sum(stats[name]['tokens'] for name in stats)
        average_response_time = total_response_time / len(stats)
        tokens_per_second = total_tokens / (main_end - main_start)
        tokens_per_thread = total_tokens / len(stats)
        report[i] = f"Total TPS: {tokens_per_second:.2f} - Average Thread TPS: {tokens_per_thread / average_response_time:.2f}"
        print("")
        
    print("Load test complete.")
    print("Results:")
    for threads, result in report.items():
        print(f"Threads: {threads} - {result}")
    print("Done.")

On the 4 x Tesla P100 (250W) system (no cache):

Threads: 1 - Total TPS: 28.76 - Average Thread TPS: 28.77
Threads: 8 - Total TPS: 124.51 - Average Thread TPS: 20.55
Threads: 16 - Total TPS: 175.89 - Average Thread TPS: 21.20
Threads: 32 - Total TPS: 229.26 - Average Thread TPS: 15.53
Threads: 64 - Total TPS: 256.15 - Average Thread TPS: 7.66
Threads: 128 - Total TPS: 309.02 - Average Thread TPS: 4.49
Threads: 256 - Total TPS: 358.97 - Average Thread TPS: 2.58
Threads: 512 - Total TPS: 308.25 - Average Thread TPS: 1.47

Graph: The peak total TPS might show up if I added a few more data points (more threads). The interesting part is the per-thread TPS, which seems to find optimal performance at 16 concurrent threads and starts to dip after that. (graph: total and per-thread TPS vs. thread count)

@sasha0552
Contributor

sasha0552 commented Apr 28, 2024

@jasonacox
As I understand it, --enable-prefix-caching does not increase memory consumption. It just stores the prompts in unused space in the kv cache, so that they can be reused when needed (or simply discarded if there is no room for a new prompt). So can you try this option and let me know if there are any crashes? I'll submit an issue if this can be reproduced on the P100 as well.

As for your code, it looks good. I downloaded mistral to test it on the P40.

1x Tesla P40 (250W, `--dtype float16`):

Threads: 8 - Total TPS: 25.63 - Average Thread TPS: 4.96
Threads: 16 - Total TPS: 20.62 - Average Thread TPS: 6.76
Threads: 32 - Total TPS: 65.73 - Average Thread TPS: 4.99
Threads: 64 - Total TPS: 70.26 - Average Thread TPS: 3.74
Threads: 128 - Total TPS: 161.11 - Average Thread TPS: 2.38
Threads: 256 - Total TPS: 166.03 - Average Thread TPS: 1.25
Threads: 512 - Total TPS: 231.96 - Average Thread TPS: 0.75

2x Tesla P40 (250W, `--dtype float32 --tensor-parallel-size 2`):

Threads: 8 - Total TPS: 41.81 - Average Thread TPS: 7.81
Threads: 16 - Total TPS: 79.21 - Average Thread TPS: 7.43
Threads: 32 - Total TPS: 82.54 - Average Thread TPS: 6.15
Threads: 64 - Total TPS: 145.50 - Average Thread TPS: 4.39
Threads: 128 - Total TPS: 162.96 - Average Thread TPS: 2.34
Threads: 256 - Total TPS: 237.33 - Average Thread TPS: 1.49
Threads: 512 - Total TPS: 219.46 - Average Thread TPS: 0.90

2x Tesla P40 (250W, `--dtype float16 --tensor-parallel-size 2`):

Threads: 1 - Total TPS: 17.15 - Average Thread TPS: 17.16
Threads: 8 - Total TPS: 54.83 - Average Thread TPS: 8.76
Threads: 16 - Total TPS: 72.03 - Average Thread TPS: 7.65
Threads: 32 - Total TPS: 109.18 - Average Thread TPS: 7.80
Threads: 64 - Total TPS: 124.72 - Average Thread TPS: 4.64
Threads: 128 - Total TPS: 215.55 - Average Thread TPS: 3.06
Threads: 256 - Total TPS: 287.08 - Average Thread TPS: 1.74
Threads: 512 - Total TPS: 306.58 - Average Thread TPS: 1.02

Not too bad, I guess (considering you have twice as much compute). vLLM (or maybe PyTorch) definitely takes the P40's capabilities into account - otherwise those numbers would be at least 64 times worse than yours.
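
(As a rough way to see how much of that theoretical FP16 penalty actually shows up, a small PyTorch micro-benchmark sketch - this is not what vLLM does internally, and the absolute numbers will vary with clocks and power limit.)

import time
import torch

def bench(dtype, n=4096, iters=20):
    a = torch.randn(n, n, device="cuda", dtype=dtype)
    b = torch.randn(n, n, device="cuda", dtype=dtype)
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        a @ b
    torch.cuda.synchronize()
    return iters * 2 * n ** 3 / (time.time() - t0) / 1e12  # rough TFLOP/s

print("fp32:", bench(torch.float32))
print("fp16:", bench(torch.float16))  # dramatically lower on GP102 (P40)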


@jasonacox
Contributor

I'm getting an error trying to activate --enable-prefix-caching:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/api_server.py", line 159, in <module>
    engine = AsyncLLMEngine.from_engine_args(
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 361, in from_engine_args
    engine = cls(
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 319, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 437, in _init_engine
    return engine_class(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 214, in __init__
    self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
  File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 265, in __init__
    self.block_manager = BlockSpaceManagerImpl(
  File "/usr/local/lib/python3.10/dist-packages/vllm/core/block_manager_v1.py", line 218, in __init__
    raise NotImplementedError(
NotImplementedError: Sliding window is not allowed with prefix caching enabled!

It seems to be an issue with mistral: #3781 - Are you using Llama-3?

@sasha0552
Contributor

Yes, I am using Llama-3. Can you test that, or mistralai/Mistral-7B-Instruct-v0.2? In v0.2 there is no sliding window; instead, there is a native 32k context.

@jasonacox
Contributor

@sasha0552 Does mistralai/Mistral-7B-Instruct-v0.2 need a chat template or a different stop token in vLLM? I'm getting non-stop generation for each test.

@sasha0552
Contributor

sasha0552 commented Apr 28, 2024

There is a slight difference in the chat template between versions, but I don't think it matters. (image: chat template diff)
You can try copying the v0.1 chat template anyway.
What kind of tokens are you getting? Is there an [INST] or an EOS? Anyway, as a test you can limit the number of output tokens with max_tokens, like this:

    response = client.chat.completions.create(
        model="mistralai/Mistral-7B-Instruct-v0.1",
        messages=[
            {"role": "user", "content": "Hello!"}
        ],
        max_tokens=128,
    )

@jasonacox
Contributor

jasonacox commented Apr 28, 2024

You're right. It is working as it should, just incredibly chatty! Holy smokes is that annoying!! 😂 The max_tokens limit makes it tolerable.

  • No errors - I am using --enable-prefix-caching but I'm not seeing any errors. I scanned the log for "Unsupported" and "fp16" just in case I missed anything unusual.
  • Results with dtype=float32
    Threads: 1 - Total TPS: 36.28 - Average Thread TPS: 36.28
    Threads: 8 - Total TPS: 183.36 - Average Thread TPS: 23.95
    Threads: 16 - Total TPS: 306.30 - Average Thread TPS: 21.07
    Threads: 32 - Total TPS: 559.59 - Average Thread TPS: 18.11
    Threads: 64 - Total TPS: 682.53 - Average Thread TPS: 10.95
    Threads: 128 - Total TPS: 567.58 - Average Thread TPS: 5.70
    Threads: 256 - Total TPS: 567.70 - Average Thread TPS: 3.40
    Threads: 512 - Total TPS: 574.43 - Average Thread TPS: 1.95
  • Results with dtype=float16
    Threads: 1 - Total TPS: 42.93 - Average Thread TPS: 42.95
    Threads: 8 - Total TPS: 122.13 - Average Thread TPS: 15.81
    Threads: 16 - Total TPS: 223.63 - Average Thread TPS: 14.22
    Threads: 32 - Total TPS: 554.17 - Average Thread TPS: 17.65
    Threads: 64 - Total TPS: 723.73 - Average Thread TPS: 11.63
    Threads: 128 - Total TPS: 590.22 - Average Thread TPS: 5.92
    Threads: 256 - Total TPS: 641.44 - Average Thread TPS: 3.77
    Threads: 512 - Total TPS: 599.08 - Average Thread TPS: 2.11
docker command
LLM=mistralai/Mistral-7B-Instruct-v0.2
MODEL=mistralai/Mistral-7B-Instruct-v0.2
docker run -d \
    --gpus all \
    --shm-size=10.24gb \
    -v /data/models:/root/.cache/huggingface \
    -p 8000:8000 \
    --env "HF_TOKEN=x" \
    --restart unless-stopped \
    --name vllm-openai \
    vllm-openai \
    --host 0.0.0.0 \
    --model=$MODEL \
    --served-model-name $LLM \
    --dtype=float16 \
    --enable-prefix-caching \
    --gpu-memory-utilization 0.95 \
    --disable-log-stats \
    --worker-use-ray \
    --tensor-parallel-size=4

@sasha0552
Contributor

sasha0552 commented Apr 28, 2024

Can you try a long prompt (2000+ tokens)? No need for batch processing or performance measurement, just send the same long prompt twice (without parallelism, one after the other). In my case it causes a crash. And thank you for your patience and testing.

@sasha0552
Contributor

sasha0552 commented Apr 28, 2024

I accidentally got a stack trace from Python; it's happening here:

_fwd_kernel[grid](

Something is wrong with this function:
@triton.jit
def _fwd_kernel(

I'll look into this tomorrow and send a PR if I fix it (I have no experience with PyTorch/Triton internals, so this should be a good exercise for me). First of all, I will find the exact line of code that generates this intrinsic.

Stacktrace
LLVM ERROR: Cannot select: intrinsic %llvm.nvvm.shfl.sync.bfly.i32
*** SIGABRT received at time=1714270500 on cpu 5 ***
...
Fatal Python error: Aborted

Stack (most recent call first):
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 200 in llir_to_ptx
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 381 in <lambda>
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 543 in compile
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 532 in run
  File "/mnt/ml/vllm/vllm/attention/ops/prefix_prefill.py", line 708 in context_attention_fwd
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115 in decorate_context
  File "/mnt/ml/vllm/vllm/attention/ops/paged_attn.py", line 177 in forward_prefix
  File "/mnt/ml/vllm/vllm/attention/backends/xformers.py", line 237 in forward
  File "/mnt/ml/vllm/vllm/attention/layer.py", line 48 in forward
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520 in _call_impl
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511 in _wrapped_call_impl
  File "/mnt/ml/vllm/vllm/model_executor/models/llama.py", line 166 in forward
  ...
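
(For context on where that intrinsic comes from - a hypothetical minimal example, not the vLLM kernel: Triton block reductions such as tl.max/tl.sum are lowered through warp shuffle instructions, which is the shfl.sync.bfly selection that fails here on Pascal. Compiling even a tiny reduction kernel like the one below may reproduce the same LLVM error on an unpatched Triton.)

import torch
import triton
import triton.language as tl

@triton.jit
def _block_max(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # Cross-lane reduction; lowered via warp shuffles
    # (llvm.nvvm.shfl.sync.bfly.i32 on the NVIDIA backend).
    tl.store(out_ptr, tl.max(x, axis=0))

x = torch.randn(1024, device="cuda")
out = torch.empty(1, device="cuda", dtype=x.dtype)
_block_max[(1,)](x, out, BLOCK=1024)
print(out.item())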

@jasonacox
Contributor

Wow! Nice, @sasha0552 - you may want to open a new issue for this instead of us working it in this poor zombie PR. 😉

@sasha0552
Contributor

@jasonacox I've created an issue - #4438.

@elabz

elabz commented Jun 7, 2024

What ended up happening with Pascal support, does anyone know? Is it still deemed not worthwhile? I'm just assuming that, given that it's not working in 0.4.3.

@sasha0552
Contributor

@elabz It's still not merged; the new PR is #4409.

@cduk
Author

cduk commented Jun 8, 2024

@elabz It is not yet merged, but hopefully with the PyPI decision it can be merged soon. Until then, I have a Pascal-enabled fork which supports it: https://github.com/cduk/vllm-pascal
