Skip to content

Commit

Permalink
optimized topp/topk calculation (#195)
Browse files Browse the repository at this point in the history
## One line description

Use topk instead of sort for topp/topk calculation under certain
conditions (scalar value of p and k).

## Details

Instead of using `k` for topk, we use `_padded_k`, which is strictly
larger than k and monotonically non decreasing.

We need/use `_padded_k > k` for cases where the smallest value of the
topk=k values has some values beyond k, (for example for
[9,8,8,8,7,7,7], with k=3, we have [9,8,8,8], which is 4 instead of 3
values),

To prevent excessive recompilations, anytime we require an expansion of
`_padded_k` we increment with a fixed constant `_increment` (usually
>1), to have a bucketed approach to prevent multiple shapes


### Basic outline

1. perform topk with `_padded_k`
2. find the "kth" value in each row (smallest number that will be in
topk), this is variable `num_duplicates_of_smallest_of_topk`
3. find maximum of number of duplicates, this variable is
`max_num_duplicates_of_smallest_of_topk`
4. check if `_padded_k` is big enough to contain
`max_num_duplicates_of_smallest_of_topk`. if not, then expand
`_padded_k`, and redo the topk again with expanded `_padded_k`
6. maskout the values that are extra in `_padded_k`
7. move to doing topp


## Perf benefit

### Using benchmark_throughput.py

To check benefit of this PR, make following change in
`benchmark_throughput.py`:
```
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index ff33e3dc..3383dea8 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -116,8 +116,9 @@ def run_vllm(
         sampling_params.append(
             SamplingParams(
                 n=n,
-                temperature=0.0 if use_beam_search else 1.0,
-                top_p=1.0,
+                temperature=1.0,  #0.0 if use_beam_search else 1.0,
+                top_p=0.95,
+                top_k=20,
                 use_beam_search=use_beam_search,
                 ignore_eos=True,
                 max_tokens=output_len,

 ```


`VLLM_SKIP_WARMUP=true VLLM_GRAPH_RESERVED_MEM=0.2 VLLM_GRAPH_PROMPT_RATIO=0.8 VLLM_DECODE_BS_BUCKET_MIN=1 VLLM_DECODE_BLOCK_BUCKET_STEP=64 VLLM_DECODE_BLOCK_BUCKET_MIN=64 python benchmark_throughput.py --model /root/sasarkar/llama3-8b/ --device hpu --seed 2024 --backend vllm --num-prompts 100 --dtype bfloat16 --input-len=256 --output-len=512`

in the numbers below there is a **49%** increase in thruput in the case with warmup, and **30%** increase in thruput in the case without warmup


#### with opt + warmup

Processed prompts: 100%|█████████████████████████████████████████████████████████████████████| 100/100 [00:22<00:00,  4.37it/s, est. speed input: 1119.66 toks/s, output: 2239.33 toks/s]
Throughput: 4.37 requests/s, 3354.58 tokens/s


#### with opt + skip warmup

Processed prompts: 100%|██████████████████████████████████████████████████████████████████████| 100/100 [00:46<00:00,  2.17it/s, est. speed input: 556.32 toks/s, output: 1112.63 toks/s]
Throughput: 2.17 requests/s, 1667.89 tokens/s


#### without opt + warmup

Processed prompts: 100%|██████████████████████████████████████████████████████████████████████| 100/100 [00:34<00:00,  2.93it/s, est. speed input: 749.24 toks/s, output: 1498.48 toks/s]
Throughput: 2.92 requests/s, 2245.74 tokens/s


#### without opt + skip warmup

Processed prompts: 100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:59<00:00,  1.67it/s, est. speed input: 428.49 toks/s, output: 856.99 toks/s]
Throughput: 1.67 requests/s, 1284.85 tokens/s

### Using server Client
(Data collected by Peter)
[baseline](https://github.com/HabanaAI/vllm-fork/commits/a7763a7a76b4531ed7907549724df2949d9225bf/)
all numbers on 1.17-495
third column [branch ](https://github.com/HabanaAI/vllm-fork/commits/ae_benchmark_9_10_24/)

| model | TP | baseline HPU thruput   | baseline HPU + this PR thruput | baseline HPU + this PR + other opt | 
| -------- | ------- | ------- | ------- | ------- |
| llama3 8b | 1 | 950  | 1296    | 1306 | 
| llama3 8b | 4 | 1347  | 1969    | 2077 | 
| llama3 70b | 4 | 368  | 394    | 394 | 
| qwen 72b | 4 | 731  | 726    | 815 |


### Without delayed sampling 
On habana_main f858d43
```VLLM_GRAPH_RESERVED_MEM=0.2 VLLM_GRAPH_PROMPT_RATIO=0.8
VLLM_DECODE_BS_BUCKET_MIN=1 VLLM_DECODE_BLOCK_BUCKET_STEP=64
VLLM_DECODE_BLOCK_BUCKET_MIN=64 python benchmark_throughput.py --model
/root/sasarkar/llama3-8b/ --device hpu --seed 2024 --backend vllm
--num-prompts 100 --dtype bfloat16 --input-len=256 --output-len=512```

Without change
Throughput: 3.32 requests/s, 2550.85 tokens/s

With change:
Throughput: 5.17 requests/s, 3967.58 tokens/s




## Extra Notes
1. Works only for "scalar" case, though it might be possible to extend
the basic idea (topk instead of sort) for vector case as well. (Outline
of this is: find max k in topk vector, then perform topk using that,
etc. needs some bucketing possibly to prevent dyn shapes etc)
2. Need an additional check in `_init_sampling_tensors` to determine if
its scalar case. This has a minor perf hit. ideally if someone could
tell us that its a scalar from the top itself...
3. Some tradeoffs can be made, where we use a sufficiently large
padded_k (which is still smaller than vocab size) from the beginning,
and hope that every case lands within that bucket. Cases that wont land
are expected to be very, very rare. For example if padded_k = max(2 * k,
100) is used, and k = say 50, then we need the smallest of the topk
value to repeat 50 times with same probability, which is exceedingly
unlikely. If we trade off this mathematical improbability, then we can
do with only 1 topk op, which might be faster
4. There is a `fliplr` in the code, which could be removed, if we can
compute reverse cumsum. however the formula for reverse cumsum as
expressed [here ](pytorch/pytorch#33520), ` x
+ torch.sum(x, dim=1, keepdims=True) - torch.cumsum(x, dim=1)` is
numerically unstable, because of the addition/subtraction. It works well
enough on ints and large numbers, but not on small probability values.
5. The value of `k` affects the gains we might get from this. For
example in the expt shown above, with k=20, thruput increases from
1284.85 to 1667.89 (30% gain). But if k = 2000, instead of 20,
throughput increases from 1127.34 to 1289.26 (14% gain). Thus the gain %
might decrease with increasing k, as asymptotically topk would probably
converges to sort's performance for large k. However practically k is
pretty small.
6. For larger models, the gains may be less, as they are more device
bound probably
7. Cumsum may be taking long. Maybe try below. [Initial
try](b392ff8)
```
import torch
y = torch.tensor([[1,2,3], [4,5,6]])
mask1 = torch.tensor([[[1,0,0],[1,1,0],[1,1,1]], [[1,0,0],[1,1,0],[1,1,1]]])
torch.sum(y.unsqueeze(1)*mask1,2)
```
or
```
F.conv1d(torch.tensor([[[0,0,0,0,1,2,3,4,5]], [[0,0,0,0,6,7,8,9,10.0]]]), torch.ones([1,1,5], dtype=torch.float32))
```
FIX #xxxx (*link existing issues this PR will resolve*)

**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE
DESCRIPTION ABOVE**

---

<details>
<!-- inside this <details> section, markdown rendering does not work, so
we use raw html here. -->
<summary><b> PR Checklist (Click to Expand) </b></summary>

<p>Thank you for your contribution to vLLM! Before submitting the pull
request, please ensure the PR meets the following criteria. This helps
vLLM maintain the code quality and improve the efficiency of the review
process.</p>

<h3>PR Title and Classification</h3>
<p>Only specific types of PRs will be reviewed. The PR title is prefixed
appropriately to indicate the type of change. Please use one of the
following:</p>
<ul>
    <li><code>[Bugfix]</code> for bug fixes.</li>
<li><code>[CI/Build]</code> for build or continuous integration
improvements.</li>
<li><code>[Doc]</code> for documentation fixes and improvements.</li>
<li><code>[Model]</code> for adding a new model or improving an existing
model. Model name should appear in the title.</li>
<li><code>[Frontend]</code> For changes on the vLLM frontend (e.g.,
OpenAI API server, <code>LLM</code> class, etc.) </li>
<li><code>[Kernel]</code> for changes affecting CUDA kernels or other
compute kernels.</li>
<li><code>[Core]</code> for changes in the core vLLM logic (e.g.,
<code>LLMEngine</code>, <code>AsyncLLMEngine</code>,
<code>Scheduler</code>, etc.)</li>
<li><code>[Hardware][Vendor]</code> for hardware-specific changes.
Vendor name should appear in the prefix (e.g.,
<code>[Hardware][AMD]</code>).</li>
<li><code>[Misc]</code> for PRs that do not fit the above categories.
Please use this sparingly.</li>
</ul>
<p><strong>Note:</strong> If the PR spans more than one category, please
include all relevant prefixes.</p>

<h3>Code Quality</h3>

<p>The PR need to meet the following code quality standards:</p>

<ul>
<li>We adhere to <a
href="https://google.github.io/styleguide/pyguide.html">Google Python
style guide</a> and <a
href="https://google.github.io/styleguide/cppguide.html">Google C++
style guide</a>.</li>
<li>Pass all linter checks. Please use <a
href="https://github.com/vllm-project/vllm/blob/main/format.sh"><code>format.sh</code></a>
to format your code.</li>
<li>The code need to be well-documented to ensure future contributors
can easily understand the code.</li>
<li>Include sufficient tests to ensure the project to stay correct and
robust. This includes both unit tests and integration tests.</li>
<li>Please add documentation to <code>docs/source/</code> if the PR
modifies the user-facing behaviors of vLLM. It helps vLLM user
understand and utilize the new features or changes.</li>
</ul>

<h3>Notes for Large Changes</h3>
<p>Please keep the changes as concise as possible. For major
architectural changes (>500 LOC excluding kernel/data/config/test), we
would expect a GitHub issue (RFC) discussing the technical design and
justification. Otherwise, we will tag it with <code>rfc-required</code>
and might not go through the PR.</p>

<h3>What to Expect for the Reviews</h3>

<p>The goal of the vLLM team is to be a <i>transparent reviewing
machine</i>. We would like to make the review process transparent and
efficient and make sure no contributor feel confused or frustrated.
However, the vLLM team is small, so we need to prioritize some PRs over
others. Here is what you can expect from the review process: </p>

<ul>
<li> After the PR is submitted, the PR will be assigned to a reviewer.
Every reviewer will pick up the PRs based on their expertise and
availability.</li>
<li> After the PR is assigned, the reviewer will provide status update
every 2-3 days. If the PR is not reviewed within 7 days, please feel
free to ping the reviewer or the vLLM team.</li>
<li> After the review, the reviewer will put an <code>
action-required</code> label on the PR if there are changes required.
The contributor should address the comments and ping the reviewer to
re-review the PR.</li>
<li> Please respond to all comments within a reasonable time frame. If a
comment isn't clear or you disagree with a suggestion, feel free to ask
for clarification or discuss the suggestion.
 </li>
</ul>

<h3>Thank You</h3>

<p> Finally, thank you for taking the time to read these guidelines and
for your interest in contributing to vLLM. Your contributions make vLLM
a great tool for everyone! </p>


</details>
  • Loading branch information
michalkuligowski authored Sep 17, 2024
2 parents f4ac1f9 + 2ab316d commit 4c1ca3a
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 3 deletions.
61 changes: 60 additions & 1 deletion tests/samplers/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from transformers import GenerationConfig, GenerationMixin

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.sampler import ApplyToppTopkScalar, Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
Expand Down Expand Up @@ -700,3 +700,62 @@ def test_sampling_params(sampling_params: List[SamplingParams]):

assert tokens1[0] == tokens2[1]
assert tokens1[1] == tokens2[0]


def test_topk_topk_scalar():
obj1 = ApplyToppTopkScalar(2)
assert ApplyToppTopkScalar._padded_k == 0
x = torch.tensor([[9, 9, 8, 8, 8, 8, 7, 7, 7.0],
[10, 10, 9, 9, 9, 8, 5, 5, 5]])

retval1 = obj1(x, p=0.9, k=5)
ninf = -float("inf")
expected1 = torch.tensor([[9., 9., 8., 8., 8., 8., ninf, ninf, ninf],
[10., 10., 9., 9., 9., ninf, ninf, ninf, ninf]])
assert torch.all(retval1 == expected1).item()
assert ApplyToppTopkScalar._padded_k == 9

obj2 = ApplyToppTopkScalar(2)
assert obj2._padded_k == 9

x = torch.tensor([[2, 2, 9, 9, 2, 2, 1, 1, 1.0],
[10, 9, 9, 5, 9, 9, 5, 9, 10]])
retval2 = obj2(x, p=0.9, k=5)
expected2 = torch.tensor(
[[ninf, ninf, 9., 9., ninf, ninf, ninf, ninf, ninf],
[10., ninf, 9., ninf, 9., 9., ninf, 9., 10.]])
assert torch.all(retval2 == expected2).item()
assert obj2._padded_k == 9

retval3 = obj2(x, p=1.0, k=5)
expected3 = torch.tensor([[2., 2., 9., 9., 2., 2., ninf, ninf, ninf],
[10., 9., 9., ninf, 9., 9., ninf, 9., 10.]])

assert torch.all(retval3 == expected3).item()

# this should not be done in general, doing it here for testing purposes
ApplyToppTopkScalar._padded_k = 0
x = torch.tensor([[1, 1, 1, 9, 8, 1, 1, 1, 1.0],
[2, 1, 2, 2, 1, 1, 1, 1, 1]])
obj3 = ApplyToppTopkScalar(2)
retval4 = obj3(x, p=0.9, k=2)
expected4 = torch.tensor(
[[ninf, ninf, ninf, 9., 8., ninf, ninf, ninf, ninf],
[2., ninf, 2., 2., ninf, ninf, ninf, ninf, ninf]])
assert torch.all(retval4 == expected4).item()
assert obj3._padded_k == 4
y = torch.tensor([[8, 8, 8, 9, 8, 1, 1, 1, 1.0],
[2, 1, 2, 2, 1, 1, 1, 1, 1]])
retval5 = obj3(y, p=0.9, k=2)
assert obj3._padded_k == 8
expected5 = torch.tensor([[8., 8., 8., 9., 8., ninf, ninf, ninf, ninf],
[2., ninf, 2., 2., ninf, ninf, ninf, ninf,
ninf]])
assert torch.all(retval5 == expected5).item()
y = torch.tensor([[8, 8, 8, 9, 8, 8, 1, 1, 1.0],
[2, 1, 2, 2, 3, 1, 1, 1, 1]])
retval6 = obj3(y, p=0.9, k=2)
expected6 = torch.tensor([[8., 8., 8., 9., 8., 8., ninf, ninf, ninf],
[2., ninf, 2., 2., 3., ninf, ninf, ninf, ninf]])
assert torch.all(retval6 == expected6).item()
assert obj3._padded_k == 8
112 changes: 110 additions & 2 deletions vllm/model_executor/layers/sampler.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import math
from math import inf
from typing import Dict, List, Optional, Tuple

Expand Down Expand Up @@ -77,6 +78,13 @@ def _init_sampling_tensors(
self._do_penalties = do_penalties
self._do_top_p_top_k = do_top_p_top_k
self._do_min_p = do_min_p
self._top_p_scalar = sampling_tensors.top_ps[0].item()
self._top_k_scalar = sampling_tensors.top_ks[0].item()
scalar_p = torch.all(sampling_tensors.top_ps == self._top_p_scalar)
scalar_k = torch.all(sampling_tensors.top_ks == self._top_k_scalar)
self._scalar_p_and_k = (scalar_p and scalar_k).item()
if self._scalar_p_and_k and self._do_top_p_top_k:
self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)

def forward(
self,
Expand Down Expand Up @@ -122,8 +130,13 @@ def forward(
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

if do_top_p_top_k:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
if self._scalar_p_and_k:
logits = self._apply_top_k_top_p_opt(logits,
self._top_p_scalar,
self._top_k_scalar)
else:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)

if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)
Expand Down Expand Up @@ -198,6 +211,101 @@ def _get_bin_counts_and_mask(
return bin_counts, mask


class ApplyToppTopkScalar():
"""
The original implementation of _apply_top_k_top_p is more general
as it uses vector topp, topk
However in a lot of cases, topp and topk is same for all batch elements
For such "scalar" topp, topk cases, we can use this class
The main optimizations in this class is:
Use topk instead of sort, which is much faster especially for small k.
However just using topk might not suffice in cases as shown below
Consider a tensor: 9 9 8 8 8 8 7 7 7
Topk, with k=5, on this yields 9 9 8 8 8
The value "8" is on the boundary, hence the last "8" gets snipped off
However the original implementation accepts all the "8"s,
so it should output:
9 9 8 8 8 8 (6 values, even though k=5)
To ensure these semantics, we perform topk with _padded_k elements
If we find more boundary elements left over,
then we keep incrementing _padded_k
and in future calls use the expanded value of __padded_k
The increments to _padded_k should be done
with value > 1 to prevent excessive recompilations
due to dynamic shapes (the output shape of the topk)
The main logic of this is in __call__
This is a class instead of a function, just to keep track of
the monotonic non-decreasing state _padded_k
"""
_padded_k = 0

def __init__(self, increment: int):
self._increment = increment

def __call__(self, logits: torch.Tensor, p: float, k: int):
if k > ApplyToppTopkScalar._padded_k:
ApplyToppTopkScalar._padded_k = min(k + self._increment,
logits.shape[1])

vals, idx = torch.topk(logits, k=ApplyToppTopkScalar._padded_k, \
dim=1, sorted=True)

# this "if" checks if we have bucketed so much that
# we have padded k upto shape of logits
if ApplyToppTopkScalar._padded_k != logits.shape[1]:
smallest_of_top_k = vals[:, k - 1]
num_duplicates_of_smallest_of_topk = torch.sum(
logits == smallest_of_top_k.unsqueeze(1), 1)
max_num_duplicates_of_smallest_of_topk = torch.max(
num_duplicates_of_smallest_of_topk).item()

# there are n repeats for a border
# (border meaning the smallest value of the top k).
# we do not know if only 1 or 2 or (n-1)
# of them lie outside the kth border,
# so we choose to conservatively increase by n-1
# when num_duplicates > _padded_k - k
if max_num_duplicates_of_smallest_of_topk - 1 > (
ApplyToppTopkScalar._padded_k - k):
incr = int(
math.ceil((max_num_duplicates_of_smallest_of_topk - 1) /
self._increment) * self._increment)
# this while loop should be traversed at most twice,
# because we dont increment by self._increment and retry
# instead we compute incr in one go
ApplyToppTopkScalar._padded_k = min(
ApplyToppTopkScalar._padded_k + incr, logits.shape[1])

# recompute topk with expanded padded_k
vals, idx = torch.topk(logits, \
k=ApplyToppTopkScalar._padded_k, \
dim=1, sorted=True)

idx = torch.fliplr(idx)
vals = torch.fliplr(vals)

top_k_smallest_val_idx = vals.size(1) - k
top_k_mask = vals[:, top_k_smallest_val_idx].unsqueeze(1)
top_k_mask = vals < top_k_mask
vals.masked_fill_(top_k_mask, -float("inf"))

probs_sort = vals.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum <= (1 - p)
top_p_mask[:, -1] = False
vals.masked_fill_(top_p_mask, -float("inf"))

new_logits = torch.full(logits.shape,
-float("inf"),
device=logits.device)
new_logits.scatter_(1, idx, vals.to(new_logits.dtype))

return new_logits


def _apply_min_tokens_penalty(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
Expand Down

0 comments on commit 4c1ca3a

Please sign in to comment.