From 80e254834de9c3c34eaca02d8880e952b3daf344 Mon Sep 17 00:00:00 2001 From: James Whedbee Date: Wed, 20 Mar 2024 16:05:03 -0500 Subject: [PATCH 01/10] [Bugfix] Fix ROCm support in CMakeLists.txt (#3534) --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 150fcebeb8878..66842e6845edd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,7 +51,7 @@ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") # Ensure the 'nvcc' command is in the PATH find_program(NVCC_EXECUTABLE nvcc) -if (NOT NVCC_EXECUTABLE) +if (CUDA_FOUND AND NOT NVCC_EXECUTABLE) message(FATAL_ERROR "nvcc not found") endif() From 426ec4ec6711b4180538cd56b9f6b856e5276a1f Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 20 Mar 2024 14:45:08 -0700 Subject: [PATCH 02/10] [1/n] Triton sampling kernel (#3186) Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- tests/kernels/test_rand.py | 51 +++ tests/kernels/test_sampler.py | 196 ++++++++++ tests/samplers/test_sampler.py | 6 +- vllm/model_executor/layers/ops/__init__.py | 0 vllm/model_executor/layers/ops/rand.py | 157 ++++++++ vllm/model_executor/layers/ops/sample.py | 405 +++++++++++++++++++++ vllm/model_executor/layers/sampler.py | 109 +++++- vllm/model_executor/sampling_metadata.py | 129 ++++++- vllm/sequence.py | 3 + vllm/worker/model_runner.py | 40 +- 10 files changed, 1072 insertions(+), 24 deletions(-) create mode 100644 tests/kernels/test_rand.py create mode 100644 tests/kernels/test_sampler.py create mode 100644 vllm/model_executor/layers/ops/__init__.py create mode 100644 vllm/model_executor/layers/ops/rand.py create mode 100644 vllm/model_executor/layers/ops/sample.py diff --git a/tests/kernels/test_rand.py b/tests/kernels/test_rand.py new file mode 100644 index 0000000000000..3b9d0d732acf5 --- /dev/null +++ b/tests/kernels/test_rand.py @@ -0,0 +1,51 @@ +import torch +import pytest +import random + +from vllm.model_executor.layers.ops.rand import seeded_uniform +from vllm.model_executor.utils import set_random_seed + + +@pytest.mark.parametrize("dtype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("use_3d", [True, False]) +def test_seeded_uniform(dtype: torch.dtype, use_3d: bool): + device = "cuda" + for seed in range(512): + set_random_seed(seed) + rows = random.randint(1, 512) + cols = random.randint(1, 64000) + if use_3d: + third_dim = random.randint(2, 10) + dims = [rows, third_dim, cols] + else: + dims = [rows, cols] + seeds = torch.randint(torch.iinfo(torch.long).min, + torch.iinfo(torch.long).max, (rows, ), + device=device) + + # Test that the same seed produces the same output + out = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) + out2 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) + torch.testing.assert_close(out, out2) + # del to save memory + del out2 + + out3 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) + torch.testing.assert_close(out, out3) + # del to save memory + del out3 + + # Initialize out tensor with garbage to ensure that it is overwritten + out_with_tensor = seeded_uniform( + *dims, + out=torch.full( + (*dims, ), + -1, + dtype=dtype, + device=device, + ), + seeds=seeds, + dtype=dtype, + ) + torch.testing.assert_close(out, out_with_tensor) diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py new file mode 100644 index 0000000000000..5f8c51fb074f4 --- /dev/null +++ b/tests/kernels/test_sampler.py @@ -0,0 +1,196 @@ +import gc + 
+import torch +import pytest +import triton +import triton.language as tl + +from vllm.model_executor.layers.ops.sample import ( + _uniform_to_exponential, sample, get_num_triton_sampler_splits, + MAX_TRITON_N_COLS) +from vllm.model_executor.utils import set_random_seed +from vllm.model_executor.sampling_metadata import SamplingTensors + +SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size +MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100 + + +@pytest.fixture(autouse=True) +def _cleanup(): + yield + gc.collect() + torch.cuda.empty_cache() + + +@triton.jit +def _uniform_to_exponential_kernel(input, output, n: tl.constexpr): + idx = tl.arange(0, n) + x = tl.load(input + idx) + y = _uniform_to_exponential(x) + tl.store(output + idx, y) + + +def test_uniform_to_exponential(): + """Test that we can convert uniform to exponential without div by 0.""" + input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps], + dtype=torch.float32, + device="cuda") + output = torch.zeros(input.shape, dtype=torch.float32, device="cuda") + _uniform_to_exponential_kernel[(1, )](input, output, 2) + assert torch.all(torch.isfinite(output)) + assert torch.all(output > 0) + assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output)) + + +@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) +@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) +@pytest.mark.parametrize("modify_greedy_probs", [True, False]) +@pytest.mark.parametrize("seed", [1337]) +@pytest.mark.parametrize("vocab_size", + [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) +@pytest.mark.parametrize("save_logprobs", [True, False]) +def test_sample_decoding_only(random_sampling, max_best_of, + modify_greedy_probs, seed, vocab_size, + save_logprobs): + set_random_seed(seed) + bs = 8 + probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") + for i in range(bs): + probs[i, i * (vocab_size // bs)] = 1.0 + logprobs = torch.rand_like(probs) + sample_indices = torch.arange(bs, dtype=torch.long, device="cuda") + n_splits = get_num_triton_sampler_splits(probs.shape[1]) + if random_sampling == "mixed": + random_sampling_mask = (torch.rand( + (1, bs), device="cuda") < 0.5).expand(n_splits, bs) + elif random_sampling: + random_sampling_mask = torch.ones((n_splits, bs), + dtype=torch.bool, + device="cuda") + else: + random_sampling_mask = torch.zeros((n_splits, bs), + dtype=torch.bool, + device="cuda") + + seeds = torch.randint(1, + torch.iinfo(torch.long).max, (n_splits, bs), + device="cuda").mul_(random_sampling_mask) + sampled_tokens, sampled_logprobs, sampled_modified_probs = sample( + probs=probs, + logprobs=logprobs, + sample_indices=sample_indices, + seeds=seeds, + max_best_of=max_best_of, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + _save_modified_probs=True) + assert sampled_tokens.shape == (bs, max_best_of) + for i in range(bs): + assert torch.all(sampled_tokens[i] == i * (vocab_size // bs)) + request_uses_random_sampling = random_sampling_mask[0, i] + if modify_greedy_probs and not request_uses_random_sampling: + # If we are modifying greedy probs and the request is greedy, + # we want to make sure the probs tensor is modified in place + assert torch.allclose( + probs[i][sampled_tokens[i]], + torch.full_like(probs[i][sampled_tokens[i]], 1.0)) + assert torch.sum(probs[i]) == 1.0 + assert torch.allclose( + sampled_modified_probs[i][0], + torch.full_like(sampled_modified_probs[i][0], 1.0)) + elif request_uses_random_sampling: + # If the request is random, we 
want to make sure + # sampled_modified_probs tensor has noise added + # (and thus is different from probs tensor) + assert not torch.allclose(sampled_modified_probs[i][0], + probs[i][sampled_tokens[i]]) + elif not request_uses_random_sampling: + # If the request is greedy and we are not modifying greedy probs, + # we want to make sure sampled_modified_probs tensor is the same as + # the probs tensor. + assert torch.allclose(sampled_modified_probs[i][0], + probs[i][sampled_tokens[i]]) + + if save_logprobs: + assert sampled_logprobs.shape == (bs, max_best_of) + for i in range(bs): + for best_of in range(max_best_of): + assert torch.all(sampled_logprobs[i] == logprobs[i][ + sampled_tokens[i, best_of]]) + else: + assert sampled_logprobs is None + + +@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) +@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) +@pytest.mark.parametrize("modify_greedy_probs", [True, False]) +@pytest.mark.parametrize("seed", [1337]) +@pytest.mark.parametrize("vocab_size", + [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) +def test_sample_prompt_logprobs(random_sampling, max_best_of, + modify_greedy_probs, seed, vocab_size): + set_random_seed(seed) + prompt_sizes = [16, 32, 64, 128] * 2 + samples = 8 + bs = samples + sum(prompt_sizes) + probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") + for i in range(bs): + probs[i, i * (vocab_size // bs)] = 1.0 + logprobs = torch.rand_like(probs) + sample_indices = torch.tensor(prompt_sizes, + dtype=torch.long, + device="cuda").cumsum_(0) + n_splits = get_num_triton_sampler_splits(probs.shape[1]) + if random_sampling == "mixed": + random_sampling_mask = torch.rand( + (n_splits, samples), device="cuda") < 0.5 + elif random_sampling: + random_sampling_mask = torch.ones((n_splits, samples), + dtype=torch.bool, + device="cuda") + else: + random_sampling_mask = torch.zeros((n_splits, samples), + dtype=torch.bool, + device="cuda") + + seeds = torch.randint(1, + torch.iinfo(torch.long).max, (n_splits, samples), + device="cuda").mul_(random_sampling_mask) + sampled_tokens, sampled_logprobs, _ = sample( + probs=probs, + logprobs=logprobs, + sample_indices=sample_indices, + seeds=seeds, + max_best_of=max_best_of, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=True) + assert sampled_tokens.shape == (samples, max_best_of) + assert sampled_logprobs.shape == (samples, max_best_of) + for i, t in enumerate(sample_indices): + assert torch.all(sampled_tokens[i] == t * (vocab_size // bs)) + for best_of in range(max_best_of): + assert torch.all(sampled_logprobs[i] == logprobs[sample_indices[i]] + [sampled_tokens[i, best_of]]) + + +@pytest.mark.parametrize("seed", list(range(16))) +def test_get_sequence_seeds(seed): + """Ensure that we get a different child seed from base + seed + extra entropy""" + starting_seed = seed + seq_seed = None + extra_entropy = 1 + for i in range(512): + new_seq_seed = SamplingTensors._get_sequence_seeds(starting_seed, + i, + seeds_to_generate=1, + is_greedy=False)[0] + new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds( + starting_seed, + i, + extra_entropy, + seeds_to_generate=1, + is_greedy=False)[0] + assert new_seq_seed_extra_entropy != new_seq_seed + assert seq_seed != new_seq_seed + seq_seed = new_seq_seed diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 1bc8703d1a8e0..b0c6e1c09eebc 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -302,11 +302,11 @@ def 
test_sampler_logits_processors(seed: int, device: str): batch_size = random.randint(1, 256) input_tensor, _, sampler, model_runner = _prepare_test(batch_size) - # This sample logits processor gives infinite score to the i-th token, + # This sample logits processor gives maximum score to the i-th token, # where i is the length of the input sequence. # We therefore expect the output token sequence to be [0, 1, 2, ...] def pick_ith(token_ids, logits): - logits[len(token_ids)] = float("inf") + logits[len(token_ids)] = torch.finfo(logits.dtype).max return logits seq_group_metadata_list = [] @@ -385,7 +385,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): sample_probs = None - def mock_sample(probs, logprobs, sampling_metadata): + def mock_sample(probs, *args, **kwargs): nonlocal sample_probs sample_probs = probs return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs] diff --git a/vllm/model_executor/layers/ops/__init__.py b/vllm/model_executor/layers/ops/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/ops/rand.py b/vllm/model_executor/layers/ops/rand.py new file mode 100644 index 0000000000000..5b4b7a153351f --- /dev/null +++ b/vllm/model_executor/layers/ops/rand.py @@ -0,0 +1,157 @@ +import torch +import triton +import triton.language as tl + +from typing import Optional, Union + + +def seeded_uniform( + *size, + seeds: torch.Tensor, + out: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str]] = None, + pin_memory: Optional[bool] = False, +) -> torch.Tensor: + """Similar to torch.rand, but allows for seeds to be set per row. + + seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d. + If it is 3d, the additional seeds needed will be derived automatically + in a deterministic fashion: + [ + row 0: [columns_with_seed_0], [columns_with_seed0^1], ... + ] + """ + n_dims = len(size) + + if n_dims > 3: + raise ValueError("seeded_uniform only supports up to 3D tensors") + + if out is None: + out = torch.empty(*size, + dtype=dtype, + device=device, + pin_memory=pin_memory) + elif out.shape != size: + raise ValueError("shape of out and size must be the same") + + if n_dims == 3: + n_rows, n_3d, n_cols = out.shape + stride_row = out.stride(0) + stride_3d = out.stride(1) + elif n_dims == 2: + n_rows, n_cols = out.shape + n_3d = 1 + stride_row = out.stride(0) + stride_3d = 1 + else: + n_cols = out.shape[0] + n_rows = 1 + n_3d = 1 + stride_row = 1 + stride_3d = 1 + + if seeds.ndim != 1: + raise ValueError("seeds must be a 1D tensor") + + if seeds.numel() != n_rows: + raise ValueError( + "seeds must have the same number of elements as out has rows") + + # The philox PRNG Triton uses generates 4 random numbers at once. + # Therefore, the most efficient use of it is to divide the + # block size by 4, and then save the generated random numbers to + # each of the 4 slices of the tensor. + full_block_size = triton.next_power_of_2(n_cols) + philox_block_size = max(full_block_size // 4, 1) + n_slices = full_block_size // philox_block_size + num_warps = 4 + # Manual tuning. This seems to give best performance on A100 for + # simple kernels like this. 
+ if philox_block_size >= 8192: + num_warps = 32 + elif philox_block_size >= 4096: + num_warps = 16 + elif philox_block_size >= 2048: + num_warps = 8 + + _seeded_uniform_triton[(n_rows, n_3d)]( + out, + seeds, + stride_row, + stride_3d, + seeds.stride(0), + n_rows, + n_3d, + n_cols, + n_slices=n_slices, + num_warps=num_warps, + block_size=philox_block_size, + ) + return out + + +@triton.jit +def _seeded_uniform_triton( + out_ptr: torch.Tensor, + seed_ptr: torch.Tensor, + out_row_stride: int, + out_3d_stride: int, + seed_row_stride: int, + n_rows: int, + n_3d: int, + n_cols: int, + n_slices: tl.constexpr, + block_size: tl.constexpr, +): + """ + Generate a random float32 number in [0, 1) for each element in the output + tensor. The random numbers in a row generated using the seed for that row. + + Args: + out_ptr: The output tensor. + seed_ptr: The per-row seeds to use for random number generation. + out_row_stride: The stride between rows of the output tensor. + out_3d_stride: The stride between 3D slices of the output tensor. + seed_row_stride: The stride between rows of the seed tensor. + n_rows: The number of rows in the output tensor. + n_3d: The size of second dimension of the output tensor, + if output tensor is 3D. + n_cols: The number of columns in the output tensor. + n_slices: The number of philox outputs to use. + """ + tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4") + + # Get the row index. + row_idx = tl.program_id(axis=0) + three_d_idx = tl.program_id(axis=1) + + philox_offsets = tl.arange(0, block_size) + # Get the seed for the current element. + seed = tl.load(seed_ptr + row_idx * seed_row_stride) + if three_d_idx > 0: + seed ^= three_d_idx + # Generate random numbers in [0, 1). + out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets) + + output_row_start_ptr = (out_ptr + row_idx * out_row_stride + + three_d_idx * out_3d_stride) + out1_offsets = philox_offsets + tl.store(output_row_start_ptr + out1_offsets, + out1, + mask=out1_offsets < n_cols) + if n_slices > 1: + out2_offsets = tl.arange(block_size, block_size * 2) + tl.store(output_row_start_ptr + out2_offsets, + out2, + mask=out2_offsets < n_cols) + if n_slices > 2: + out3_offsets = tl.arange(block_size * 2, block_size * 3) + tl.store(output_row_start_ptr + out3_offsets, + out3, + mask=out3_offsets < n_cols) + if n_slices > 3: + out4_offsets = tl.arange(block_size * 3, block_size * 4) + tl.store(output_row_start_ptr + out4_offsets, + out4, + mask=out4_offsets < n_cols) diff --git a/vllm/model_executor/layers/ops/sample.py b/vllm/model_executor/layers/ops/sample.py new file mode 100644 index 0000000000000..0077317282204 --- /dev/null +++ b/vllm/model_executor/layers/ops/sample.py @@ -0,0 +1,405 @@ +import math +from typing import Tuple, Optional + +import torch +import triton +import triton.language as tl + +from vllm.model_executor.layers.ops.rand import seeded_uniform + +_EPS = 1e-6 + +# This is a hardcoded limit in Triton (max block size). +MAX_TRITON_N_COLS = 131072 + + +def get_num_triton_sampler_splits(n_cols: int) -> int: + """Get the number of splits to use for Triton sampling. + + Triton has a limit on the number of columns it can handle, so we need to + split the tensor and call the kernel multiple times if it's too large. 
+ """ + return math.ceil(n_cols / MAX_TRITON_N_COLS) + + +def _multi_split_sample( + probs: torch.Tensor, + seeds: torch.Tensor, + n_splits: int, + sampled_tokens_size: Tuple[int, int], + sampled_logprobs_size: Tuple[int, int], + sample_indices: torch.Tensor, + *, + logprobs: Optional[torch.Tensor] = None, + modify_greedy_probs: bool = False, + save_logprobs: bool = False, +): + """Sample tokens where vocab size is split into multiple parts + (too large for Triton otherwise).""" + assert seeds.ndim == 2 and seeds.shape[0] == n_splits + split_probs = probs.tensor_split(n_splits, 1) + split_logprobs = logprobs.tensor_split(n_splits, 1) + sampled_tokens_tmp = [ + torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device) + for _ in range(n_splits) + ] + sampled_logprobs_tmp = [ + torch.empty(sampled_logprobs_size, + dtype=probs.dtype, + device=probs.device) for _ in range(n_splits) + ] + # We are purposefuly using sampled_tokens_size as we need to always + # save modified probs in this case. + sampled_modified_probs_tmp = [ + torch.empty(sampled_tokens_size, + dtype=probs.dtype, + device=probs.device) for _ in range(n_splits) + ] + for i in range(n_splits): + n_samples = sample_indices.shape[0] + n_cols = split_probs[i].shape[1] + n_best = sampled_tokens_tmp[i].shape[1] + uniform_noise = seeded_uniform(n_samples, + n_best, + n_cols, + seeds=seeds[i].flatten(), + device=split_probs[i].device, + dtype=split_probs[i].dtype) + # TODO(yard1): See if we can remove the contiguous() calls. + # Will need kernel support. + _sample( + split_probs[i].contiguous(), + split_logprobs[i].contiguous(), + sample_indices, + sampled_tokens_tmp[i], + sampled_logprobs_tmp[i], + sampled_modified_probs_tmp[i], + seeds[i], + uniform_noise, + modify_greedy_probs=False, + save_logprobs=save_logprobs, + save_modified_probs=True, + ) + if i > 0: + # Add offset to sampled tokens + sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1]) + sampled_tokens = torch.stack(sampled_tokens_tmp) + sampled_modified_probs = torch.stack(sampled_modified_probs_tmp) + # Reduce the results from the splits. + sampled_modified_probs, indices = torch.max(sampled_modified_probs, + dim=0, + keepdim=True) + sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0) + if save_logprobs: + sampled_logprobs = torch.stack(sampled_logprobs_tmp) + sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0) + else: + sampled_logprobs = None + sampled_modified_probs = sampled_modified_probs.squeeze(0) + + if modify_greedy_probs: + # We need to modify the greedy probs for the sampled tokens. + # We can't do this in the kernel as we need to know the + # sampled tokens. + probs.fill_(0.0) + probs.scatter_(1, sampled_tokens, 1.0) + + return (sampled_tokens, sampled_logprobs, sampled_modified_probs) + + +def sample( + probs: torch.Tensor, + seeds: torch.Tensor, + *, + max_best_of: int = 1, + sample_indices: Optional[torch.Tensor] = None, + logprobs: Optional[torch.Tensor] = None, + modify_greedy_probs: bool = False, + save_logprobs: bool = False, + _save_modified_probs: bool = False, # pylint: disable=invalid-name +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """Sample tokens from probs. with per-sequence seeds. + + Can sample from a subset of sequences through sample_indices. + + Args: + probs: Probabilities to sample from. + shape = [batch_size, vocab_size] + seeds: Per-sequence seed values. 
+ shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)] + max_best_of: Number of samples to generate per sequence. + Sequence seed will be incremented by 1 each time. + sample_indices: Indices of sequences to sample from. + If not provided, will sample from all sequences. + shape = [n] + logprobs: Log-probabilities of the sampled tokens. + Only used for saving the logprobs if save_logprobs is True. + shape = [batch_size, vocab_size] + modify_greedy_probs: Whether to modify the greedy probabilities + for speculative sampling (sampled token = 1.0, + everything else = 0.0). + save_logprobs: Whether to save the log-probabilities of the + sampled tokens to a tensor. + _save_modified_probs: Whether to save the modified probabilities + (including gumbel noise) of the sampled tokens to a tensor. + DOES NOT include the modification done by modify_greedy_probs + (because we want to use the unmodified probs to pick the best + split in case of multi-split sampling). + This is exposed only for testing. + + Returns: + sampled_tokens: shape = [n, max_best_of] + sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None + sampled_modified_probs: shape = [n, max_best_of] + if save_modified_probs else None + """ + if sample_indices is None: + sample_indices = torch.arange(0, probs.shape[0], device=probs.device) + + sampled_tokens_size = (sample_indices.size(0), max_best_of) + if save_logprobs: + if logprobs is None: + raise ValueError( + "logprobs tensor must be provided if save_logprobs is True") + sampled_logprobs_size = sampled_tokens_size + else: + # Empty tensors to invoke the kernel + sampled_logprobs_size = (0, 0) + logprobs = probs + + if _save_modified_probs: + sampled_modified_probs_size = sampled_tokens_size + else: + # Empty tensors to invoke the kernel + sampled_modified_probs_size = (0, 0) + + # If the number of columns in probs is too large for Triton to handle, + # we split the tensor and sample from each split separately, and then + # do an argmax+gather to combine the results. 
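A minimal standalone sketch of the reduction described in the comment above (tensor shapes and values are invented for illustration, not taken from the kernel): each split produces its own best token and noise-modified probability, and the winner across splits is picked with a max over dim 0 followed by a gather.

import torch

n_splits, n_samples, max_best_of, split_width = 2, 4, 1, 8
# Per-split candidate tokens (already offset into the full vocab) and their
# noise-modified probabilities, stacked along dim 0 as in _multi_split_sample.
split_tokens = torch.stack([
    torch.randint(i * split_width, (i + 1) * split_width,
                  (n_samples, max_best_of))
    for i in range(n_splits)
])
split_probs = torch.rand(n_splits, n_samples, max_best_of)
# Keep the split whose modified probability is largest for each sample ...
best_probs, idx = torch.max(split_probs, dim=0, keepdim=True)
# ... and gather the corresponding token id from that same split.
final_tokens = split_tokens.gather(0, idx).squeeze(0)  # [n_samples, max_best_of]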
+ n_splits = get_num_triton_sampler_splits(probs.shape[1]) + if n_splits > 1: + (sampled_tokens, sampled_logprobs, + sampled_modified_probs) = _multi_split_sample( + probs, + seeds, + n_splits, + sampled_tokens_size, + sampled_logprobs_size, + sample_indices, + logprobs=logprobs, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs) + else: + sampled_tokens = torch.empty(sampled_tokens_size, + dtype=torch.long, + device=probs.device) + sampled_logprobs = torch.empty(sampled_logprobs_size, + dtype=probs.dtype, + device=probs.device) + sampled_modified_probs = torch.empty(sampled_modified_probs_size, + dtype=probs.dtype, + device=probs.device) + n_samples = sample_indices.shape[0] + n_cols = probs.shape[1] + uniform_noise = seeded_uniform(n_samples, + max_best_of, + n_cols, + seeds=seeds.flatten(), + device=probs.device, + dtype=probs.dtype) + + _sample( + probs, + logprobs, + sample_indices, + sampled_tokens, + sampled_logprobs, + sampled_modified_probs, + seeds, + uniform_noise, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + save_modified_probs=_save_modified_probs, + ) + return (sampled_tokens, sampled_logprobs if save_logprobs else None, + sampled_modified_probs if _save_modified_probs else None) + + +def _sample(probs: torch.Tensor, + logprobs: torch.Tensor, + sample_indices: torch.Tensor, + output_samples: torch.Tensor, + output_logprobs: torch.Tensor, + output_modified_probs: torch.Tensor, + seeds: torch.Tensor, + uniform_noise: torch.Tensor, + *, + modify_greedy_probs: bool = False, + save_logprobs: bool = True, + save_modified_probs: bool = False) -> torch.Tensor: + """Sample tokens from probs. + + Args: + probs [batch_size, vocab_size]: probs to sample from. + logprobs [batch_size, vocab_size]: logprobs (used when + save_logprobsis True). + sample_indices [n]: Indices of the samples to use for each row of probs. + output_samples [n, n_best]: Output tensor to store samples in. + output_logprobs [n, n_best]: Output tensor to store logprobs in. + output_modified_probs [n, n_best]: Output tensor to store + probs of chosen tokens in (modified with noise). + seeds [n]: Seeds to use for sampling. If the seed is 0, we use + greedy sampling. Note this is ONLY used for determining + whether to use random sampling or not. The actual random + noise should be passed as uniform_noise. + uniform_noise [batch_size, n_best, vocab_size]: Uniform + noise to use for random sampling (will be converted + to exponential gumbel noise by the kernel). + modify_greedy_probs: If True, we modify the probs tensor in-place + to encode the sampling method used for each row. This is used + in speculative decoding. Only applies in greedy decoding. + save_logprobs: If True, we save the logprobs of the sampled tokens + in the output_logprobs tensor. + save_modified_probs: If True, we save the modified probs (with noise) + of the sampled tokens in the output_modified_probs tensor. + DOES NOT include the modification done by modify_greedy_probs + (because we want to use the unmodified probs to pick the best + split in case of multi-split sampling). + """ + n_samples = sample_indices.shape[0] + n_cols = probs.shape[1] + n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1 + + # The block size is the smallest power of two greater than the number of + # columns in probs + block_size = triton.next_power_of_2(n_cols) + num_warps = 4 + # Manual tuning. This seems to give best performance on A100 for + # simple kernels like this. 
+ if block_size >= 8192: + num_warps = 32 + elif block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + + # Enqueue kernel. The 1D launch grid is simple: we have one kernel + # instance per row of the probs matrix + _sample_triton[(n_samples, n_best)]( + sample_indices, + output_samples, + output_logprobs, + output_modified_probs, + probs, + logprobs, + seeds, + uniform_noise, + output_samples.stride(0), + probs.stride(0), + uniform_noise.stride(0), + uniform_noise.stride(1) if n_best > 1 else 1, + n_samples, + n_cols, + n_best, + num_warps=num_warps, + block_size=block_size, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + save_modified_probs=save_modified_probs, + ) + return output_samples, output_logprobs, output_modified_probs + + +@triton.jit +def _uniform_to_exponential(uniform_noise): + """Convert uniform samples to exponential samples.""" + # tl.rand returns values in [0, 1), so we clamp lower bound + # to _EPS to avoid log(0) and thus division by 0 later + lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype) + uniform_noise = tl.maximum(uniform_noise, lb) + # Use the inversion method to turn uniform samples + # into exponential samples + exponential_noise = -tl.log(uniform_noise) + return exponential_noise + + +@triton.jit +def _sample_triton( + sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor, + output_logprobs_ptr: torch.Tensor, + output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor, + logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor, + uniform_noise_ptr: torch.Tensor, output_row_stride: int, + probs_row_stride: int, uniform_noise_row_stride: int, + uniform_noise_best_stride: int, n_samples: int, n_cols: int, + n_best: int, block_size: tl.constexpr, + modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr, + save_modified_probs: tl.constexpr): + # The rows are independent, so we parallelize across those + sample_idx = tl.program_id(0) + best_idx = tl.program_id(1) + + # Load the row index from DRAM + row_idx = tl.load(sample_indices_ptr + sample_idx) + seed = tl.load(seeds_ptr + sample_idx) + uses_random_sampling = seed != 0 + + # The stride represents how much we need to increase the + # pointer to advance 1 row + row_start_ptr = probs_ptr + row_idx * probs_row_stride + + # The block size is the next power of two greater than n_cols, + # so we can fit each row in a single block + col_offsets = tl.arange(0, block_size) + + # Load the row into SRAM, using a mask since block_size may be > than n_cols + row = tl.load(row_start_ptr + col_offsets, + mask=col_offsets < n_cols, + other=float("-inf")) + + if uses_random_sampling: + uniform_noise_start_ptr = (uniform_noise_ptr + + sample_idx * uniform_noise_row_stride + + best_idx * uniform_noise_best_stride) + uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets, + mask=col_offsets < n_cols, + other=0.5) + exponential_noise = _uniform_to_exponential(uniform_noise) + row /= exponential_noise + + sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True) + # clamp sampled token to n_cols - 1 + # this should not be necessary, but we do it + # just in case + if sampled_token >= n_cols: + sampled_token = n_cols - 1 + # Write back output to DRAM + output_row_start_ptr = (output_ptr + sample_idx * output_row_stride + + best_idx) + tl.store(output_row_start_ptr, sampled_token) + + if modify_greedy_probs: # noqa + if not uses_random_sampling: + # Set the probability of the sampled token to 1, all other + # tokens to zero. 
This is used in speculative decoding where + # the sampling method must be encoded within the sampled + # probability distributions. + row = tl.where(col_offsets == sampled_token, 1.0, 0.0) + tl.store(row_start_ptr + col_offsets, + row, + mask=col_offsets < n_cols) + + if save_modified_probs: + output_row_start_ptr = (output_modified_probs_ptr + + sample_idx * output_row_stride + best_idx) + tl.store(output_row_start_ptr, sampled_value) + + if save_logprobs: + # Load the row into SRAM, using a mask since block_size + # may be > than n_cols + sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride + + sampled_token) + # Write back output to DRAM + output_row_start_ptr = (output_logprobs_ptr + + sample_idx * output_row_stride + best_idx) + tl.store(output_row_start_ptr, sampled_logprob) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 4377b845df628..1fab1e734e1d7 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -12,6 +12,7 @@ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceData, SequenceGroupOutput, SequenceOutput) +from vllm.model_executor.layers.ops.sample import (sample as sample_triton) from vllm.utils import is_neuron @@ -114,7 +115,8 @@ def forward( logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. - sample_results = _sample(probs, logprobs, sampling_metadata) + sample_results = _sample(probs, logprobs, sampling_metadata, + sampling_tensors) # Get the logprobs query results. prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) @@ -375,7 +377,7 @@ def _multinomial( return probs.div_(q).argmax(dim=1).view(-1, num_samples) -def _sample( +def _sample_with_torch( probs: torch.Tensor, logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, @@ -394,7 +396,7 @@ def _sample( # Counterintiutively, having two loops here is actually faster. # The first loop can run without waiting on GPU<->CPU sync. 
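As a rough illustration of the two-loop pattern noted above (a standalone sketch with invented tensors, assuming a CUDA device is available; this is not the sampler code itself): the first loop only enqueues device work, and the host-blocking conversion to Python lists is deferred to a second pass.

import torch

per_type_probs = {
    "greedy": torch.rand(4, 32, device="cuda"),
    "random": torch.rand(4, 32, device="cuda"),
}
# First pass: launch GPU work only; argmax returns a CUDA tensor and does not
# force the host to wait for the kernel to finish.
pending = {name: probs.argmax(dim=-1) for name, probs in per_type_probs.items()}
# Second pass: .tolist() copies results back to the CPU, which is where the
# GPU<->CPU synchronization actually happens.
results = {name: ids.tolist() for name, ids in pending.items()}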
for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type] + sample_indices = categorized_sample_indices[sampling_type][:, 0] num_tokens = len(sample_indices) if num_tokens == 0: continue @@ -407,17 +409,19 @@ def _sample( greedy_samples = torch.argmax(logprobs[sample_indices.long()], dim=-1) elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - max_best_of = 1 + max_best_of_in_batch = 1 for seq_group, is_prompt in zip(seq_groups, is_prompts): if is_prompt: _, sampling_params = seq_group - max_best_of = max(max_best_of, sampling_params.best_of) + max_best_of_in_batch = max(max_best_of_in_batch, + sampling_params.best_of) seeded_args = {} if sampling_type == SamplingType.RANDOM else { "seq_groups": seq_groups, "generators": sampling_metadata.generators, } multinomial_samples[sampling_type] = _multinomial( - probs[sample_indices.long()], max_best_of, **seeded_args) + probs[sample_indices.long()], max_best_of_in_batch, + **seeded_args) elif sampling_type == SamplingType.BEAM: beam_search_logprobs = logprobs[sample_indices] else: @@ -448,6 +452,99 @@ def _sample( return sample_results +def _sample_with_triton_kernel( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, +) -> List[Tuple[List[int], List[int]]]: + categorized_seq_group_ids = {t: [] for t in SamplingType} + categorized_sample_indices = sampling_metadata.categorized_sample_indices + for i, seq_group in enumerate(sampling_metadata.seq_groups): + _, sampling_params = seq_group + sampling_type = sampling_params.sampling_type + categorized_seq_group_ids[sampling_type].append(i) + + sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} + sample_metadata = {} + max_best_of_in_batch = 1 + + # Counterintiutively, having two loops here is actually faster. + # The first loop can run without waiting on GPU<->CPU sync. + for sampling_type in SamplingType: + sample_indices = categorized_sample_indices[sampling_type][:, 0] + sampled_token_indices = categorized_sample_indices[sampling_type][:, 1] + num_tokens = len(sample_indices) + if num_tokens == 0: + continue + seq_group_ids = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] + is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] + sample_metadata[sampling_type] = (seq_group_ids, seq_groups, + is_prompts, sample_indices, + sampled_token_indices) + if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM, + SamplingType.RANDOM_SEED): + for seq_group, is_prompt in zip(seq_groups, is_prompts): + if is_prompt: + _, sampling_params = seq_group + max_best_of_in_batch = max(max_best_of_in_batch, + sampling_params.best_of) + elif sampling_type == SamplingType.BEAM: + beam_search_logprobs = logprobs[sample_indices] + else: + raise ValueError(f"Unsupported sampling type: {sampling_type}") + + sampled_tokens, _, _ = sample_triton( + probs=probs, + seeds=sampling_tensors.sampling_seeds, + max_best_of=max_best_of_in_batch, + sample_indices=sampling_tensors.sample_indices, + logprobs=logprobs, + # don't save logprobs because we have logic for that below + # TODO: use this instead of the CPU-based logic below + save_logprobs=False, + ) + + # GPU<->CPU sync happens in the loop below. 
+ + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + (seq_group_ids, seq_groups, is_prompts, sample_indices, + sampled_token_indices) = sample_metadata[sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample( + seq_groups, sampled_tokens[sampled_token_indices][:, 0]) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + sample_results = _random_sample( + seq_groups, is_prompts, sampled_tokens[sampled_token_indices]) + elif sampling_type == SamplingType.BEAM: + sample_results = _beam_search_sample(seq_groups, is_prompts, + sampling_metadata.seq_data, + beam_search_logprobs) + sample_results_dict.update(zip(seq_group_ids, sample_results)) + + sample_results = [ + sample_results_dict[i] + for i in range(len(sampling_metadata.seq_groups)) + ] + return sample_results + + +def _sample( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, +) -> List[Tuple[List[int], List[int]]]: + return _sample_with_torch(probs, logprobs, sampling_metadata) + + # TODO: Enable once Triton kernel & associated code is faster. + # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata, + # sampling_tensors) + + def _get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index b23f0170a6ca5..7d08feb3fee1c 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -2,12 +2,16 @@ from typing import Dict, List, Optional, Tuple import torch +import random from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData from vllm.utils import in_wsl, is_neuron +from vllm.model_executor.layers.ops.sample import ( + get_num_triton_sampler_splits) _SAMPLING_EPS = 1e-5 +_SEED_0_REPLACEMENT = 3403598558 class SamplingMetadata: @@ -67,14 +71,28 @@ class SamplingTensors: presence_penalties: torch.Tensor frequency_penalties: torch.Tensor repetition_penalties: torch.Tensor + sampling_seeds: torch.Tensor + sample_indices: torch.Tensor + extra_seeds: Optional[torch.Tensor] prompt_tokens: torch.Tensor output_tokens: torch.Tensor @classmethod def from_sampling_metadata( - cls, sampling_metadata: "SamplingMetadata", vocab_size: int, - device: torch.device, - dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]: + cls, + sampling_metadata: "SamplingMetadata", + vocab_size: int, + device: torch.device, + dtype: torch.dtype, + *, + extra_seeds_to_generate: int = 0, + extra_entropy: Optional[Tuple[int, ...]] = None + ) -> Tuple["SamplingTensors", bool, bool, bool]: + """ + extra_seeds_to_generate: extra seeds to generate using the + user-defined seed for each sequence. + extra_entropy: extra entropy to use when generating seeds. + """ prompt_tokens: List[List[int]] = [] output_tokens: List[List[int]] = [] top_ks: List[int] = [] @@ -84,9 +102,18 @@ def from_sampling_metadata( presence_penalties: List[float] = [] frequency_penalties: List[float] = [] repetition_penalties: List[float] = [] + sampling_seeds: List[int] = [] + sample_indices: List[int] = [] + prompt_best_of: List[int] = [] do_penalties = False do_top_p_top_k = False do_min_p = False + + # We need one base seed per Triton slice. 
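As a quick numeric check of the seed count computed on the next line (the vocab sizes are just examples): each sequence needs one base seed per Triton split of the vocabulary, plus any extra_seeds_to_generate requested by the caller.

import math

MAX_TRITON_N_COLS = 131072  # Triton column limit defined in ops/sample.py

def num_splits(vocab_size: int) -> int:
    # Same formula as get_num_triton_sampler_splits above.
    return math.ceil(vocab_size / MAX_TRITON_N_COLS)

extra_seeds_to_generate = 0
assert extra_seeds_to_generate + num_splits(32_000) == 1         # one base seed
assert extra_seeds_to_generate + num_splits(131_072 + 100) == 2  # two base seeds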
+ seeds_to_generate = (extra_seeds_to_generate + + get_num_triton_sampler_splits(vocab_size)) + + sample_indices_start_idx = 0 for i, seq_group in enumerate(sampling_metadata.seq_groups): seq_ids, sampling_params = seq_group temperature = sampling_params.temperature @@ -95,6 +122,10 @@ def from_sampling_metadata( r = sampling_params.repetition_penalty top_p = sampling_params.top_p min_p = sampling_params.min_p + seed = sampling_params.seed + + is_greedy = sampling_params.sampling_type == SamplingType.GREEDY + # k should not be greater than the vocab size. top_k = min(sampling_params.top_k, vocab_size) top_k = vocab_size if top_k == -1 else top_k @@ -112,6 +143,7 @@ def from_sampling_metadata( or abs(f) >= _SAMPLING_EPS or abs(r - 1.0) >= _SAMPLING_EPS): do_penalties = True + if (i < sampling_metadata.num_prompts and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get @@ -138,10 +170,34 @@ def from_sampling_metadata( frequency_penalties += [f] * len(seq_ids) repetition_penalties += [r] * len(seq_ids) + is_prompt = i < sampling_metadata.num_prompts + if is_prompt: + prompt_best_of.append(sampling_params.best_of) + prompt_len = sampling_metadata.prompt_lens[i] + + if sampling_params.prompt_logprobs is not None: + # NOTE: the sampling position is the last token + # in the prompt + sample_indices_start_idx += prompt_len - 1 + for seq_id in seq_ids: + seq_data = sampling_metadata.seq_data[seq_id] + extra_entropy = extra_entropy or () + seq_seeds = cls._get_sequence_seeds( + seed, + seq_data.get_len(), + *extra_entropy, + seq_id, + seeds_to_generate=seeds_to_generate, + is_greedy=is_greedy) + sampling_seeds.append(seq_seeds) + sample_indices.append(sample_indices_start_idx) + sample_indices_start_idx += 1 + sampling_tensors = SamplingTensors.from_lists( temperatures, top_ps, top_ks, min_ps, presence_penalties, - frequency_penalties, repetition_penalties, prompt_tokens, - output_tokens, vocab_size, device, dtype) + frequency_penalties, repetition_penalties, sampling_seeds, + sample_indices, prompt_tokens, output_tokens, vocab_size, + extra_seeds_to_generate, device, dtype) return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) @classmethod @@ -150,9 +206,10 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], presence_penalties: List[float], frequency_penalties: List[float], repetition_penalties: List[float], + sampling_seeds: List[int], sample_indices: List[int], prompt_tokens: List[List[int]], output_tokens: List[List[int]], vocab_size: int, - device: torch.device, + extra_seeds_to_generate: int, device: torch.device, dtype: torch.dtype) -> "SamplingTensors": # Note that the performance will be very bad without # pinned memory. @@ -210,6 +267,12 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], dtype=torch.int, pin_memory=pin_memory, ) + sample_indices_t = torch.tensor( + sample_indices, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) prompt_tensor = torch.tensor( prompt_padded_tokens, device="cpu", @@ -222,8 +285,28 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], dtype=torch.long, pin_memory=pin_memory, ) + # need to transpose and make contiguous to + # copy the tensor correctly. + # [batch_size, n_seeds] -> [n_seeds, batch_size] + sampling_seeds_t = torch.tensor( + sampling_seeds, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ).T.contiguous() + # Because the memory is pinned, we can do non-blocking # transfer to device. 
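A minimal sketch of the pinned-memory pattern relied on here (tensor contents are invented; assumes a CUDA device): pinning the CPU-side tensor is what makes the non_blocking copy genuinely asynchronous.

import torch

cpu_seeds = torch.tensor([3, 1, 4, 1, 5], dtype=torch.long, pin_memory=True)
# Non-blocking host-to-device copy; safe to issue because the source is pinned.
gpu_seeds = cpu_seeds.to(device="cuda", non_blocking=True)
# Synchronize (or rely on stream ordering) before reading the copied values.
torch.cuda.current_stream().synchronize()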
+ + # How many seeds the sample operation itself will need. + num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate + sampling_seeds_gpu = sampling_seeds_t.to(device=device, + non_blocking=True) + extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:] + if not extra_seeds_gpu.numel(): + extra_seeds_gpu = None + sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds] + return cls( temperatures=temperatures_t.to(device=device, non_blocking=True), top_ps=top_ps_t.to(device=device, non_blocking=True), @@ -237,4 +320,38 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], non_blocking=True), prompt_tokens=prompt_tensor.to(device=device, non_blocking=True), output_tokens=output_tensor.to(device=device, non_blocking=True), + sampling_seeds=sampling_seeds_gpu, + sample_indices=sample_indices_t.to(device=device, + non_blocking=True), + extra_seeds=extra_seeds_gpu, ) + + @staticmethod + def _get_sequence_seeds( + seed: int, + *extra_entropy: int, + seeds_to_generate: int, + is_greedy: bool, + ): + """Get `seeds_to_generate` child seeds from `seed` and extra entropy.""" + if not is_greedy: + if seed is None: + randint_fn = random.randint + else: + generator = random.Random(str((seed, ) + extra_entropy)) + randint_fn = generator.randint + lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max + # If the user/random sets seed = 0 but request should + # have sampling, we need to change it to something + # else. We use a constant in that case. + # This way we don't need to create and load a bool + # matrix in the sampling kernel, which reduces CPU + # overhead and latency. + seq_seeds = [ + randint_fn(lo, hi) or _SEED_0_REPLACEMENT + for _ in range(seeds_to_generate) + ] + else: + # For the kernel, seed == 0 means greedy decoding. 
+ seq_seeds = [0] * seeds_to_generate + return seq_seeds diff --git a/vllm/sequence.py b/vllm/sequence.py index 4a002edaf580f..ff96dd306791c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -242,6 +242,9 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self.data.get_token_ids() + def get_prompt_token_ids(self) -> List[int]: + return self.data.get_prompt_token_ids() + def get_last_token_id(self) -> int: return self.data.get_last_token_id() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 27213887ed265..7e25311fa2268 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -408,6 +408,7 @@ def _prepare_sample( selected_token_start_idx = 0 categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices_start_idx = 0 + categorized_sampled_token_indices_start_idx = 0 pin_memory = not self.in_wsl and not self.device_config.is_neuron max_subquery_len = max(subquery_lens) if subquery_lens else 1 @@ -425,9 +426,12 @@ def _prepare_sample( categorized_sample_indices_start_idx += subquery_len - 1 categorized_sample_indices[ - sampling_params.sampling_type].append( - categorized_sample_indices_start_idx) + sampling_params.sampling_type].append([ + categorized_sample_indices_start_idx, + categorized_sampled_token_indices_start_idx + ]) categorized_sample_indices_start_idx += 1 + categorized_sampled_token_indices_start_idx += 1 if sampling_params.prompt_logprobs is not None: selected_token_indices.extend( @@ -449,9 +453,17 @@ def _prepare_sample( categorized_sample_indices[ sampling_params.sampling_type].extend( - range(categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + num_seqs)) + zip( + range( + categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + + num_seqs), + range( + categorized_sampled_token_indices_start_idx, + categorized_sampled_token_indices_start_idx + + num_seqs))) categorized_sample_indices_start_idx += num_seqs + categorized_sampled_token_indices_start_idx += num_seqs if sampling_params.seed is not None: generators.append(seq_group_metadata.state.generator) @@ -459,12 +471,14 @@ def _prepare_sample( selected_token_indices = _async_h2d(selected_token_indices, dtype=torch.long, target_device=self.device, - pin_memory=pin_memory) + pin_memory=not self.in_wsl) + categorized_sample_indices = { - t: _async_h2d(seq_ids, - dtype=torch.int, - target_device=self.device, - pin_memory=pin_memory) + t: _maybe_expand_dim( + _async_h2d(seq_ids, + dtype=torch.int, + target_device=self.device, + pin_memory=pin_memory), 2, 2) for t, seq_ids in categorized_sample_indices.items() } @@ -884,3 +898,11 @@ def _async_h2d( ) -> torch.Tensor: t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") return t.to(device=target_device, non_blocking=True) + + +def _maybe_expand_dim(tensor: torch.Tensor, + target_dims: int, + size: int = 1) -> torch.Tensor: + if tensor.ndim < target_dims: + tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim))) + return tensor From 6e435de766c7749b214b637ac58570a221006c95 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 21 Mar 2024 06:46:05 +0900 Subject: [PATCH 03/10] [1/n][Chunked Prefill] Refactor input query shapes (#3236) --- .buildkite/test-pipeline.yaml | 4 +- .../test_basic_correctness.py | 4 +- tests/core/test_scheduler.py | 18 +- tests/lora/test_worker.py | 2 +- tests/spec_decode/test_multi_step_worker.py | 4 +- tests/worker/test_model_runner.py | 161 +++++++++++- vllm/config.py | 3 - 
vllm/core/scheduler.py | 13 +- vllm/engine/arg_utils.py | 8 +- vllm/engine/llm_engine.py | 1 - vllm/model_executor/input_metadata.py | 82 +++++- vllm/model_executor/layers/activation.py | 4 +- .../layers/attention/attention.py | 3 +- .../layers/attention/backends/flash_attn.py | 46 +++- .../layers/attention/backends/xformers.py | 232 ++++++++++------- .../layers/attention/ops/paged_attn.py | 9 +- vllm/model_executor/layers/sampler.py | 1 - vllm/worker/model_runner.py | 239 +++++++++++------- 18 files changed, 575 insertions(+), 259 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6ae351130f203..17f4c33670821 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -47,7 +47,7 @@ steps: - pytest -v -s prefix_caching - label: Samplers Test - command: pytest -v -s samplers --forked + command: pytest -v -s samplers - label: Worker Test command: pytest -v -s worker @@ -56,7 +56,7 @@ steps: command: pytest -v -s spec_decode - label: LoRA Test %N - command: pytest -v -s lora --forked --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 - label: Metrics Test diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index fe67e0f2f4808..da0176306b4ee 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -13,6 +13,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("enforce_eager", [False, True]) def test_models( hf_runner, vllm_runner, @@ -20,12 +21,13 @@ def test_models( model: str, dtype: str, max_tokens: int, + enforce_eager: bool, ) -> None: hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype) + vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index ebfeb8ba04812..397101fa86104 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -10,7 +10,7 @@ def test_scheduler_add_seq_group(): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1, 256) + scheduler_config = SchedulerConfig(100, 64, 1) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -26,7 +26,7 @@ def test_scheduler_add_seq_group(): def test_scheduler_abort_seq_group(): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1, 256) + scheduler_config = SchedulerConfig(100, 64, 1) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -50,7 +50,7 @@ def test_scheduler_schedule_simple(): block_size = 4 num_seq_group = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len, 256) + scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -64,10 +64,10 @@ def test_scheduler_schedule_simple(): running.append(seq_group) # Schedule seq groups prompts. 
+ num_tokens = block_size * num_seq_group seq_group_meta, out = scheduler.schedule() assert set(out.scheduled_seq_groups) == set(running) - assert out.num_batched_tokens == num_seq_group * seq_group.get_seqs( - )[0].get_len() + assert out.num_batched_tokens == num_tokens assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == num_seq_group @@ -84,7 +84,7 @@ def test_scheduler_schedule_simple(): def test_scheduler_schedule_preempt_abort(): block_size = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, 2, max_model_len, 256) + scheduler_config = SchedulerConfig(64, 2, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 2 cache_config.num_gpu_blocks = 2 @@ -99,7 +99,7 @@ def test_scheduler_schedule_preempt_abort(): # Schedule seq groups prompts. seq_group_meta, out = scheduler.schedule() assert out.scheduled_seq_groups == [seq_group_a, seq_group_b] - assert out.num_batched_tokens == seq_group_a.get_seqs()[0].get_len() * 2 + assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == 2 @@ -124,7 +124,7 @@ def test_scheduler_schedule_preempt_abort(): scheduler.abort_seq_group("1") seq_group_meta, out = scheduler.schedule() assert out.scheduled_seq_groups == [seq_group_b] - assert out.num_batched_tokens == seq_group_b.get_seqs()[0].get_len() + assert out.num_batched_tokens == 5 # 4 prompt + 1 generation. assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == 1 @@ -136,7 +136,7 @@ def test_scheduler_max_seqs(): num_seq_group = 4 max_seq_group = 2 max_model_len = 16 - scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len, 256) + scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 31a7c716afbf2..e4538de35169b 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -25,7 +25,7 @@ def test_worker_apply_lora(sql_lora_files): revision=None, ), parallel_config=ParallelConfig(1, 1, False), - scheduler_config=SchedulerConfig(32, 32, 32, 256), + scheduler_config=SchedulerConfig(32, 32, 32), device_config=DeviceConfig("cuda"), local_rank=0, rank=0, diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 45b43ec59ee8f..5f788549d44d0 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -92,8 +92,8 @@ def test_same_output_for_single_step(): num_gpu_blocks, seed, ) - multi_step_worker.model_runner = worker.model_runner - multi_step_worker.cache_engine = worker.cache_engine + # multi_step_worker.model_runner = worker.model_runner + # multi_step_worker.cache_engine = worker.cache_engine num_steps = 1 diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index f44895a728c7e..44b22c2bd8a21 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,8 +1,13 @@ import random import torch +from vllm.config import ModelConfig from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import ModelRunner +from vllm.worker.model_runner import 
ModelRunner, _BATCH_SIZE_ALIGNMENT + + +def get_aligned_size(batch_size: int, alignment: int): + return ((batch_size + alignment - 1) // alignment * alignment) def test_prepare_prompt(): @@ -12,6 +17,7 @@ def test_prepare_prompt(): batch_size = random.randint(1, 256) prompt_lens = [] seq_group_metadata_list = [] + block_tables = {0: [1]} for i in range(batch_size): # make sure all tokens fit into one block prompt_len = i % (model_runner.block_size - 1) + 1 @@ -23,26 +29,165 @@ def test_prepare_prompt(): is_prompt=True, seq_data={0: SequenceData(seq_data)}, sampling_params=SamplingParams(temperature=0), - block_tables={0: [1]}, + block_tables=block_tables, )) expected_selected_token_indices = [] selected_token_start_idx = 0 - max_seq_len = max(prompt_lens) for prompt_len in prompt_lens: expected_selected_token_indices.append(selected_token_start_idx + prompt_len - 1) - selected_token_start_idx += max_seq_len - input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = ( - model_runner._prepare_prompt(seq_group_metadata_list)) + selected_token_start_idx += prompt_len + (input_tokens, input_positions, input_metadata, return_prompt_lens, _, _, + _, _) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens + + # Verify input metadata is correct for prompts. + device = model_runner.device + assert input_metadata.is_prompt is True + assert torch.allclose(input_metadata.prompt_lens_tensor, + torch.tensor(prompt_lens, device=device)) + assert input_metadata.prompt_lens == prompt_lens + assert input_metadata.num_prompt_tokens == sum(prompt_lens) + assert input_metadata.num_generation_tokens == 0 + assert input_metadata.max_seq_len == max(prompt_lens) + + # Test subquery start locs. + start_idx = 0 + start_loc = [start_idx] + for prompt_len in prompt_lens: + start_idx += prompt_len + start_loc.append(start_idx) + assert torch.allclose( + input_metadata.subquery_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + + # Test seq start locs. Note that for normal prefill it is + # equivalent to subquery_start_loc. + start_idx = 0 + seq_start_loc = [start_idx] + for prompt_len in prompt_lens: + start_idx += prompt_len + seq_start_loc.append(start_idx) + + assert torch.allclose( + input_metadata.seq_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + assert input_metadata.max_context_len is None + assert torch.allclose( + input_metadata.context_lens, + torch.zeros(input_metadata.context_lens.shape[0], + dtype=torch.int, + device=device)) + + expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) + assert torch.allclose(input_metadata.block_tables, expected) + # Cuda graph should not be used for prerill. 
+ assert input_metadata.use_cuda_graph is False + assert input_metadata.kv_cache_dtype == "auto" + + assert input_tokens.shape == (sum(prompt_lens), ) + assert input_positions.shape == (sum(prompt_lens), ) + torch.testing.assert_close(input_tokens, input_positions) + sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - assert input_tokens.shape == (batch_size, max_seq_len) - assert input_positions.shape == (batch_size, max_seq_len) + assert input_tokens.shape == (sum(prompt_lens), ) + assert input_positions.shape == (sum(prompt_lens), ) + actual = sampling_metadata.selected_token_indices + expected = torch.tensor(expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype) + torch.testing.assert_close(actual, expected) + torch.testing.assert_close(input_tokens, input_positions) + + actual = sampling_metadata.selected_token_indices + expected = torch.tensor(expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype) + torch.testing.assert_close(actual, expected) + + +def test_prepare_decode_cuda_graph(): + model_config = ModelConfig( + "facebook/opt-125m", + "facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + enforce_eager=False, + ) + model_runner = ModelRunner(model_config, None, None, None, None) + model_runner.set_block_size(16) + + batch_size = random.randint(1, 256) + prompt_lens = [] + seq_group_metadata_list = [] + for i in range(batch_size): + # make sure all tokens fit into one block + prompt_len = i % (model_runner.block_size - 1) + 1 + prompt_lens.append(prompt_len) + seq_data = list(range(prompt_len)) + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: SequenceData(seq_data)}, + sampling_params=SamplingParams(temperature=0), + block_tables={0: [1]}, + )) + + input_tokens, input_positions, input_metadata, _, _, _ = ( + model_runner._prepare_decode(seq_group_metadata_list)) + + # Verify input metadata is correct for prompts. + device = model_runner.device + assert input_metadata.is_prompt is False + assert input_metadata.prompt_lens is None + assert input_metadata.num_prompt_tokens == 0 + assert input_metadata.num_generation_tokens == (get_aligned_size( + len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT)) + assert input_metadata.max_seq_len is None + assert input_metadata.subquery_start_loc is None + assert input_metadata.seq_start_loc is None + assert input_metadata.max_context_len == max(prompt_lens) + assert torch.allclose( + input_metadata.context_lens[:len(prompt_lens)], + torch.tensor(prompt_lens, dtype=torch.int, device=device)) + + # block table's first index corresponds to each batch, meaning in + # decoding it is each token. + assert input_metadata.block_tables.shape[0] == len(input_tokens) + # Block table's second dim correspondsd to each token's block number. + # It is padded up to + assert input_metadata.block_tables.shape[1] == ( + model_runner.get_max_block_per_batch()) + # Cuda graph should not be used for prerill. 
+ assert input_metadata.use_cuda_graph is True + assert input_metadata.kv_cache_dtype == "auto" + + assert input_tokens.shape == (get_aligned_size( + len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), ) + assert input_positions.shape == (get_aligned_size( + len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), ) torch.testing.assert_close(input_tokens, input_positions) + # Verify Sampling + expected_selected_token_indices = [] + selected_token_start_idx = 0 + for prompt_len in prompt_lens: + expected_selected_token_indices.append(selected_token_start_idx) + selected_token_start_idx += 1 + sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, diff --git a/vllm/config.py b/vllm/config.py index 51ae66e2375ab..b769ecdce8808 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -535,7 +535,6 @@ class SchedulerConfig: iteration. max_model_len: Maximum length of a sequence (including prompt and generated text). - max_paddings: Maximum number of paddings to be added to a batch. """ def __init__( @@ -543,7 +542,6 @@ def __init__( max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, - max_paddings: int, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -553,7 +551,6 @@ def __init__( self.max_num_batched_tokens = max(max_model_len, 2048) self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len - self.max_paddings = max_paddings self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c3f93a2928df5..be55e8520a55f 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -173,12 +173,12 @@ def _schedule(self) -> SchedulerOutputs: curr_loras = set( seq_group.lora_int_id for seq_group in self.running) if self.lora_enabled else None - seq_lens: List[int] = [] # Optimization: We do not sort the waiting queue since the preempted # sequence groups are added to the front and the new sequence groups # are added to the back. leftover_waiting_sequences = deque() + num_batched_tokens = 0 while self.waiting: seq_group = self.waiting[0] waiting_seqs = seq_group.get_seqs( @@ -223,8 +223,7 @@ def _schedule(self) -> SchedulerOutputs: continue # If the number of batched tokens exceeds the limit, stop. 
- new_seq_lens = seq_lens + [num_prompt_tokens] - num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) + num_batched_tokens += num_prompt_tokens if (num_batched_tokens > self.scheduler_config.max_num_batched_tokens): break @@ -236,11 +235,6 @@ def _schedule(self) -> SchedulerOutputs: self.scheduler_config.max_num_seqs): break - num_paddings = num_batched_tokens - sum(new_seq_lens) - if num_paddings > self.scheduler_config.max_paddings: - break - seq_lens = new_seq_lens - if lora_int_id > 0: curr_loras.add(lora_int_id) self.waiting.popleft() @@ -255,8 +249,7 @@ def _schedule(self) -> SchedulerOutputs: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=scheduled, prompt_run=True, - num_batched_tokens=len(seq_lens) * - max(seq_lens) if seq_lens else 0, + num_batched_tokens=num_batched_tokens, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3e146d2e6c0c4..94c80f4284067 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -31,7 +31,6 @@ class EngineArgs: gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 - max_paddings: int = 256 max_logprobs: int = 5 # OpenAI default value disable_log_stats: bool = False revision: Optional[str] = None @@ -213,10 +212,6 @@ def add_cli_args( type=int, default=EngineArgs.max_num_seqs, help='maximum number of sequences per iteration') - parser.add_argument('--max-paddings', - type=int, - default=EngineArgs.max_paddings, - help='maximum number of paddings in a batch') parser.add_argument( '--max-logprobs', type=int, @@ -347,8 +342,7 @@ def create_engine_configs( ), self.ray_workers_use_nsight) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, - model_config.max_model_len, - self.max_paddings) + model_config.max_model_len) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 71798ab7d17c0..2280481cca9cb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -561,7 +561,6 @@ def _process_model_outputs( # Log stats. if self.log_stats: self.stat_logger.log(self._get_stats(scheduler_outputs)) - return request_outputs def step(self) -> List[RequestOutput]: diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 01bba70ac10a8..35245865fb1b1 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -1,36 +1,92 @@ from dataclasses import dataclass, fields -from typing import Optional, Any, Dict +from typing import Optional, List, Any, Dict import torch +from xformers.ops.fmha.attn_bias import AttentionBias @dataclass class InputMetadata: """Metadata for input sequences. Used in PagedAttention. - Args: - prompt_lens: Lengths of prompts. - slot_mapping: The address to write the new KV to of each token. - max_context_len: The maximum context length. - context_lens: the length of attention context for each sequence. - block_tables: The block tables. (Seq id -> list of physical block) - kv_cache_dtype: Data type to store kv cache. + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. 
""" - + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. is_prompt: bool + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor - prompt_lens: Optional[torch.Tensor] - max_seq_len: Optional[int] - start_loc: Optional[torch.Tensor] + # (batch_size,). The prompt length per sequence. None if it is a decoding. + prompt_lens: Optional[List[int]] + # prompt_lens stored as a tensor. + prompt_lens_tensor: Optional[torch.Tensor] + # The number of prompt tokens. Doesn't include padding. + num_prompt_tokens: int + # The number of generation tokens. Doesn't include padding. + num_generation_tokens: int + """ + Definition of context_len, subquery_len, and seqlen. + |---------- N-1 iteration --------| + |---------------- N iteration ---------------------| + |- tokenA -|......................|-- newTokens ---| + |---------- context_len ----------| + |-------------------- seqlen ----------------------| + |- subquery_len -| + + WARNING: context_len has different definition depending on if it is + prefill vs decoding. When it is prefill, it doesn't include new + tokens. When it is for decoding, it includes a new token. + """ + + # Maximum subquery length in the batch. + max_subquery_len: Optional[int] + # Maximum context length in the batch. max_context_len: Optional[int] + # FIXME: It is for flash attn. + # Maximum sequence length in the batch. + max_seq_len: Optional[int] + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + subquery_start_loc: Optional[torch.Tensor] + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,). The length of context (tokens stored in KV cache) per + # sequence. WARNING: When it is a prefill request, it doesn't include new + # tokens. When it is for decoding, it includes a new token. context_lens: Optional[torch.Tensor] + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. block_tables: Optional[torch.Tensor] + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. use_cuda_graph: bool kv_cache_dtype: str def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. # will not appear in the __repr__ and __init__ - self.attn_bias = None + self.attn_bias: Optional[List[AttentionBias]] = None + + # Cuda graph is only used for decoding now. 
+ if self.use_cuda_graph: + assert self.num_prompt_tokens == 0 def asdict_zerocopy(self) -> Dict[str, Any]: """Similar to dataclasses.asdict, but avoids deepcopying.""" diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 3eb73ee109f50..f569a5a49cbdf 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -20,8 +20,8 @@ class SiluAndMul(nn.Module): The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2. Shapes: - x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) - return: (batch_size, seq_len, d) or (num_tokens, d) + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) """ def _forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 4b63b9eaf59a7..ae598b029a007 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -17,11 +17,12 @@ class Attention(nn.Module): This class takes query, key, and value tensors as input. The input tensors can either contain prompt tokens or generation tokens. + The class does the following: 1. Store the input key and value tensors in the KV cache. 2. Perform (multi-head/multi-query/grouped-query) attention. - 3. Return the output tensor. + 3. Output the output tensor. """ def __init__( diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index 58ccd461b993e..9ce5851f3650d 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -1,7 +1,7 @@ """Attention layer with Flash and PagedAttention.""" from typing import List, Optional -from flash_attn import flash_attn_func +from flash_attn import flash_attn_varlen_func import torch from vllm.model_executor.input_metadata import InputMetadata @@ -10,6 +10,21 @@ class FlashAttentionBackend: + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ def __init__( self, @@ -52,18 +67,18 @@ def forward( """Forward pass with FlashAttention and PagedAttention. Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size] input_metadata: metadata for the inputs. 
Returns: - shape = [batch_size, seq_len, num_heads * head_size] + shape = [num_tokens, num_heads * head_size] """ - batch_size, seq_len, hidden_size = query.shape + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -82,13 +97,16 @@ def forward( if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): # normal attention - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) - output = flash_attn_func( - query, - key, - value, + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=input_metadata.seq_start_loc, + cu_seqlens_k=input_metadata.seq_start_loc, + max_seqlen_q=input_metadata.max_seq_len, + max_seqlen_k=input_metadata.max_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, @@ -118,4 +136,4 @@ def forward( ) # Reshape the output tensor. - return output.view(batch_size, seq_len, hidden_size) + return output.view(num_tokens, hidden_size) diff --git a/vllm/model_executor/layers/attention/backends/xformers.py b/vllm/model_executor/layers/attention/backends/xformers.py index bad2a648b6703..f0ef9fac9aaa4 100644 --- a/vllm/model_executor/layers/attention/backends/xformers.py +++ b/vllm/model_executor/layers/attention/backends/xformers.py @@ -14,6 +14,21 @@ class XFormersBackend: + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens --------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ def __init__( self, @@ -55,19 +70,18 @@ def forward( """Forward pass with xFormers and PagedAttention. Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size] input_metadata: metadata for the inputs. Returns: - shape = [batch_size, seq_len, num_heads * head_size] + shape = [num_tokens, num_heads * head_size] """ - batch_size, seq_len, hidden_size = query.shape - # Reshape the query, key, and value tensors. + num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) @@ -82,9 +96,10 @@ def forward( if input_metadata.is_prompt: # Prompt run. + # key_cache and value_cache are None when it is a profiling run. + # block tables are empty if the prompt has never been computed. 
if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): - # normal attention if self.num_kv_heads != self.num_heads: # As of Nov 2023, xformers only supports MHA. For MQA/GQA, # project the key and value tensors to the desired number of @@ -103,61 +118,33 @@ def forward( self.num_queries_per_kv, value.shape[-1]) - # Set attention bias if not provided. This typically happens at - # the very attention layer of every iteration. - # FIXME(woosuk): This is a hack. - if input_metadata.attn_bias is None: - if self.alibi_slopes is None: - attn_bias = BlockDiagonalCausalMask.from_seqlens( - [seq_len] * batch_size) - if self.sliding_window is not None: - attn_bias = attn_bias.make_local_attention( - self.sliding_window) - input_metadata.attn_bias = attn_bias - else: - input_metadata.attn_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, batch_size, - seq_len, query.dtype) - if self.use_ref_attention: - output = _ref_masked_attention( - query, - key, - value, - self.num_heads, - self.num_kv_heads, - self.head_size, - self.scale, - ) + print("ref attention used.") + output = torch.empty_like(query) + start = 0 + for _, prompt_len in enumerate(input_metadata.prompt_lens): + end = start + prompt_len + out = _ref_masked_attention( + query[None, start:end], + key[None, start:end], + value[None, start:end], + self.num_heads, + self.num_kv_heads, + self.head_size, + self.scale, + ) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out) + start += prompt_len + # Using view got RuntimeError: view size is not compatible # with input tensor's size and stride (at least one # dimension spans across two contiguous subspaces). # Use reshape instead. - return output.reshape(batch_size, seq_len, hidden_size) - - # TODO(woosuk): Too many view operations. Let's try to reduce - # them in the future for code readability. - if self.alibi_slopes is None: - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - else: - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) - - out = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=input_metadata.attn_bias, - p=0.0, - scale=self.scale, - op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if - (is_hip()) else None, - ) - output = out.view_as(query) + return output.reshape(num_tokens, hidden_size) + output = self._run_memory_efficient_xformer_forward( + query, key, value, input_metadata) else: # prefix-enabled attention output = PagedAttentionImpl.forward_prefix( @@ -182,41 +169,117 @@ def forward( ) # Reshape the output tensor. - return output.view(batch_size, seq_len, hidden_size) + return output.view(-1, self.num_heads * self.head_size) + + def _run_memory_efficient_xformer_forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + """Attention for 1D query of multiple prompts. Multiple prompt + tokens are flattened in to `query` input. + + Args: + output: shape = [num_prompt_tokens, num_heads, head_size] + query: shape = [num_prompt_tokens, num_heads, head_size] + key: shape = [num_prompt_tokens, num_kv_heads, head_size] + value: shape = [num_prompt_tokens, num_kv_heads, head_size] + input_metadata: metadata for paged attention. + """ + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. 
+ # FIXME(woosuk): This is a hack. + if input_metadata.attn_bias is None: + if self.alibi_slopes is None: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + input_metadata.prompt_lens) + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + input_metadata.attn_bias = [attn_bias] + else: + input_metadata.attn_bias = _make_alibi_bias( + self.alibi_slopes, self.num_kv_heads, query.dtype, + input_metadata) + + op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if ( + is_hip()) else None + # No alibi slopes. + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. + if self.alibi_slopes is None: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias[0], + p=0.0, + scale=self.scale, + op=op) + + return out.view_as(query) + + # Attention with alibi slopes. + # FIXME(woosuk): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. + output = torch.empty_like(query) + start = 0 + for i, prompt_len in enumerate(input_metadata.prompt_lens): + end = start + prompt_len + out = xops.memory_efficient_attention_forward( + query[None, start:end], + key[None, start:end], + value[None, start:end], + attn_bias=input_metadata.attn_bias[i], + p=0.0, + scale=self.scale, + op=op) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out.squeeze(0)) + start += prompt_len + return output def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, - batch_size: int, - seq_len: int, dtype: torch.dtype, + input_metadata: InputMetadata, ) -> LowerTriangularMaskWithTensorBias: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - bias = bias[None, :] - bias[:, None] - - # When using custom attention bias, xformers requires the bias to - # be sliced from a tensor whose length is a multiple of 8. - padded_len = (seq_len + 7) // 8 * 8 - num_heads = alibi_slopes.shape[0] - bias = torch.empty( - batch_size, - num_heads, - seq_len, - padded_len, - device=alibi_slopes.device, - dtype=dtype, - )[:, :, :, :seq_len].copy_(bias) - bias.mul_(alibi_slopes[:, None, None]) - if num_heads != num_kv_heads: - bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) - attn_bias = LowerTriangularMaskWithTensorBias(bias) - return attn_bias + attn_biases = [] + for prompt_len in input_metadata.prompt_lens: + bias = torch.arange(prompt_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(prompt_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + # Calculate a matrix where each element represents ith element- jth + # element. 
+ bias = bias[None, :] - bias[:, None] + + padded_len = (prompt_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + prompt_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :prompt_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) + + return attn_biases def _check_use_ref_attention() -> bool: @@ -239,7 +302,6 @@ def _ref_masked_attention( query = query.view(-1, num_heads, head_size) key = key.view(-1, num_kv_heads, head_size) value = value.view(-1, num_kv_heads, head_size) - seq_len, _, _ = query.shape attn_mask = torch.triu(torch.ones(seq_len, seq_len, diff --git a/vllm/model_executor/layers/attention/ops/paged_attn.py b/vllm/model_executor/layers/attention/ops/paged_attn.py index c5a9618c2395b..3105ba37b9832 100644 --- a/vllm/model_executor/layers/attention/ops/paged_attn.py +++ b/vllm/model_executor/layers/attention/ops/paged_attn.py @@ -128,11 +128,12 @@ def forward_prefix( output, key_cache, value_cache, - input_metadata.block_tables, # [BS, max_block_per_request] - input_metadata.start_loc, - input_metadata.prompt_lens, + input_metadata.block_tables, + # subquery_start_loc is (batch_size + 1,) + input_metadata.subquery_start_loc[:-1], + input_metadata.prompt_lens_tensor, input_metadata.context_lens, - input_metadata.max_seq_len, + input_metadata.max_subquery_len, alibi_slopes, ) return output diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 1fab1e734e1d7..ac8336ca0f9ad 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -128,7 +128,6 @@ def _prune_hidden_states( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) return hidden_states.index_select(0, sampling_metadata.selected_token_indices) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7e25311fa2268..cfccbbb20adc5 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -28,9 +28,12 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] _PAD_SLOT_ID = -1 LORA_WARMUP_RANK = 8 -# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. +_BATCH_SIZE_ALIGNMENT = 8 +# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. # NOTE: _get_graph_batch_size needs to be updated if this list is changed. 
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) +] class ModelRunner: @@ -107,8 +110,7 @@ def load_model(self) -> None: ), "Model does not have embedding_padding_modules" self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens + - self.scheduler_config.max_paddings, self.vocab_size, + self.scheduler_config.max_num_batched_tokens, self.vocab_size, self.lora_config, self.device, self.model.embedding_modules, self.model.embedding_padding_modules) self.model = self.lora_manager.create_lora_manager(self.model) @@ -116,10 +118,13 @@ def load_model(self) -> None: def set_block_size(self, block_size: int) -> None: self.block_size = block_size - max_num_blocks = (self.max_context_len_to_capture + block_size - - 1) // block_size self.graph_block_tables = np.zeros( - (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) + (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), + dtype=np.int32) + + def get_max_block_per_batch(self) -> int: + block_size = self.block_size + return (self.max_context_len_to_capture + block_size - 1) // block_size def _prepare_prompt( self, @@ -127,9 +132,9 @@ def _prepare_prompt( ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], List[int], List[int], Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - slot_mapping: List[List[int]] = [] + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() @@ -158,16 +163,18 @@ def _prepare_prompt( computed_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) + context_len = computed_len else: prefix_block_tables.append([]) + context_len = 0 # actual prompt lens - context_lens.append(computed_len) + context_lens.append(context_len) subquery_lens.append(prompt_len - computed_len) - input_tokens.append(prompt_tokens) + input_tokens.extend(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.append( + input_positions.extend( list(range(computed_len, computed_len + len(prompt_tokens)))) lora_id = seq_group_metadata.lora_int_id @@ -175,7 +182,7 @@ def _prepare_prompt( if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping.append([lora_id] * (prompt_len - computed_len)) + lora_index_mapping += [lora_id] * (prompt_len - computed_len) lora_prompt_mapping.extend( [lora_id] * (prompt_len - computed_len @@ -184,11 +191,10 @@ def _prepare_prompt( if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. - slot_mapping.append([_PAD_SLOT_ID] * prompt_len) + slot_mapping.extend([_PAD_SLOT_ID] * prompt_len) continue # Compute the slot mapping. - slot_mapping.append([]) block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # where start_idx is max(0, prompt_len - sliding_window). 
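For readers following the slot-mapping rewrite here (continued in the next hunk), a minimal self-contained sketch of the per-prompt computation is given below. It is not part of the patch; the helper name and example values are illustrative, while the _PAD_SLOT_ID sentinel and the sliding-window masking mirror the logic in this part of _prepare_prompt.

# --- illustrative sketch, not part of the diff ---
from typing import List, Optional

_PAD_SLOT_ID = -1


def prompt_slot_mapping(block_table: List[int],
                        prompt_len: int,
                        block_size: int,
                        computed_len: int = 0,
                        sliding_window: Optional[int] = None) -> List[int]:
    """Return one flat slot id per prompt token (no padding to a max length)."""
    # Tokens before start_idx fall outside the sliding window and are masked.
    start_idx = 0
    if sliding_window is not None:
        start_idx = max(0, prompt_len - sliding_window)

    slots: List[int] = []
    for i in range(computed_len, prompt_len):
        if i < start_idx:
            slots.append(_PAD_SLOT_ID)
            continue
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slots.append(block_number * block_size + block_offset)
    return slots


# A 6-token prompt stored in physical blocks 7 and 9 with block_size=4
# maps to slots [28, 29, 30, 31, 36, 37].
assert prompt_slot_mapping([7, 9], prompt_len=6,
                           block_size=4) == [28, 29, 30, 31, 36, 37]
# --- end of sketch ---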
@@ -203,35 +209,30 @@ def _prepare_prompt( start_idx = max(0, prompt_len - self.sliding_window) for i in range(computed_len, prompt_len): if i < start_idx: - slot_mapping[-1].append(_PAD_SLOT_ID) + slot_mapping.append(_PAD_SLOT_ID) continue block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset - slot_mapping[-1].append(slot) - - max_prompt_len = max(subquery_lens) - assert max_prompt_len > 0 - input_tokens = _make_tensor_with_pad(input_tokens, - max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - input_positions = _make_tensor_with_pad(input_positions, - max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - slot_mapping = _make_tensor_with_pad(slot_mapping, - max_prompt_len, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device) - lora_index_mapping = [ - _pad_to_max(mapping, max_prompt_len, pad=0) - for mapping in lora_index_mapping - ] + slot_mapping.append(slot) + + max_subquery_len = max(subquery_lens) + max_seq_len = max(prompt_lens) + num_prompt_tokens = len(input_tokens) + assert max_subquery_len > 0 + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + lora_index_mapping = lora_index_mapping + context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -244,22 +245,45 @@ def _prepare_prompt( dtype=torch.int, device=self.device, ) - start_loc_tensor = torch.arange(0, - len(prompt_lens) * max_prompt_len, - max_prompt_len, - dtype=torch.long, - device=self.device) + + # Query length can be shorter than key (i.e., prompt) when prefill + # is chunked or prefix cached. 
+ subquery_lens_tensor = torch.tensor(subquery_lens, + dtype=torch.long, + device=self.device) + subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + prompt_lens_tensor = torch.tensor(prompt_lens, dtype=torch.long, device=self.device) + seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + torch.cumsum(subquery_lens_tensor, + dim=0, + dtype=subquery_start_loc.dtype, + out=subquery_start_loc[1:]) + + torch.cumsum(prompt_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) input_metadata = InputMetadata( is_prompt=True, slot_mapping=slot_mapping, - prompt_lens=prompt_lens_tensor, - max_seq_len=max_prompt_len, - start_loc=start_loc_tensor, + prompt_lens=prompt_lens, + prompt_lens_tensor=prompt_lens_tensor, + num_prompt_tokens=num_prompt_tokens, + num_generation_tokens=0, + max_subquery_len=max_subquery_len, max_context_len=None, + max_seq_len=max_seq_len, + subquery_start_loc=subquery_start_loc, + seq_start_loc=seq_start_loc, context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, @@ -275,9 +299,9 @@ def _prepare_decode( ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - slot_mapping: List[List[int]] = [] + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] context_lens: List[int] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[int] = [] @@ -296,11 +320,11 @@ def _prepare_decode( for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] generation_token = seq_data.get_last_token_id() - input_tokens.append([generation_token]) + input_tokens.append(generation_token) seq_len = seq_data.get_len() position = seq_len - 1 - input_positions.append([position]) + input_positions.append(position) context_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) @@ -310,8 +334,8 @@ def _prepare_decode( block_number = block_table[position // self.block_size] block_offset = position % self.block_size slot = block_number * self.block_size + block_offset - slot_mapping.append([slot]) - lora_index_mapping.append([lora_id]) + slot_mapping.append(slot) + lora_index_mapping.append(lora_id) lora_prompt_mapping.append(lora_id) if self.sliding_window is not None: @@ -320,6 +344,9 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + # vLLM uses cuda graph only for decoding requests. + # See `capture_model` API for more details. + # For decoding requests, batch_size == input_tokens. batch_size = len(input_tokens) max_context_len = max(context_lens) use_captured_graph = ( @@ -327,38 +354,37 @@ def _prepare_decode( and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and max_context_len <= self.max_context_len_to_capture) if use_captured_graph: - # Pad the input tokens, positions, and slot mapping to match the - # batch size of the captured graph. 
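The subquery_start_loc and seq_start_loc tensors introduced just above can be hard to read in diff form, so a small stand-alone illustration follows. It is not part of the patch; only torch is required and the example lengths are made up, but the cumsum-into-a-slice pattern matches the new _prepare_prompt code.

# --- illustrative sketch, not part of the diff ---
import torch

subquery_lens = [4, 6, 2]   # query tokens per sequence in this batch
prompt_lens = [4, 6, 2]     # full prompt lengths (equal here: no prefix caching)

# (batch_size + 1,) tensors with a leading zero, filled via an in-place cumsum.
subquery_start_loc = torch.zeros(len(subquery_lens) + 1, dtype=torch.int32)
torch.cumsum(torch.tensor(subquery_lens, dtype=torch.long),
             dim=0,
             dtype=subquery_start_loc.dtype,
             out=subquery_start_loc[1:])

seq_start_loc = torch.zeros(len(prompt_lens) + 1, dtype=torch.int32)
torch.cumsum(torch.tensor(prompt_lens, dtype=torch.long),
             dim=0,
             dtype=seq_start_loc.dtype,
             out=seq_start_loc[1:])

print(subquery_start_loc.tolist())  # [0, 4, 10, 12]
print(seq_start_loc.tolist())       # [0, 4, 10, 12]
# --- end of sketch ---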
graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size for _ in range(graph_batch_size - batch_size): - input_tokens.append([]) - input_positions.append([]) - slot_mapping.append([]) + input_tokens.append(0) + input_positions.append(0) + slot_mapping.append(_PAD_SLOT_ID) context_lens.append(1) block_tables.append([]) + lora_index_mapping.append(0) batch_size = graph_batch_size - input_tokens = _make_tensor_with_pad(input_tokens, - max_len=1, - pad=0, - dtype=torch.long, - device=self.device) - input_positions = _make_tensor_with_pad(input_positions, - max_len=1, - pad=0, - dtype=torch.long, - device=self.device) - slot_mapping = _make_tensor_with_pad(slot_mapping, - max_len=1, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device) + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) context_lens = torch.tensor(context_lens, dtype=torch.int, device=self.device) if use_captured_graph: + # When using cuda-graph all these tensors should be + # padded. + assert context_lens.shape[0] == input_tokens.shape[0] + assert context_lens.shape[0] == input_positions.shape[0] + assert context_lens.shape[0] == slot_mapping.shape[0] + # The shape of graph_block_tables is # [max batch size, max context len // block size]. input_block_tables = self.graph_block_tables[:batch_size] @@ -377,17 +403,18 @@ def _prepare_decode( device=self.device, ) - lora_index_mapping = [ - _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping - ] - input_metadata = InputMetadata( is_prompt=False, slot_mapping=slot_mapping, prompt_lens=None, - max_seq_len=None, - start_loc=None, + prompt_lens_tensor=None, + num_prompt_tokens=0, + num_generation_tokens=len(input_tokens), + max_subquery_len=None, max_context_len=max_context_len, + max_seq_len=None, + subquery_start_loc=None, + seq_start_loc=None, context_lens=context_lens, block_tables=block_tables, use_cuda_graph=use_captured_graph, @@ -411,7 +438,6 @@ def _prepare_sample( categorized_sampled_token_indices_start_idx = 0 pin_memory = not self.in_wsl and not self.device_config.is_neuron - max_subquery_len = max(subquery_lens) if subquery_lens else 1 for i, seq_group_metadata in enumerate(seq_group_metadata_list): seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params @@ -439,7 +465,7 @@ def _prepare_sample( selected_token_start_idx + subquery_len - 1)) selected_token_indices.append(selected_token_start_idx + subquery_len - 1) - selected_token_start_idx += max_subquery_len + selected_token_start_idx += subquery_len if sampling_params.seed is not None: seq_group_metadata.state.generator = torch.Generator( @@ -521,11 +547,8 @@ def prepare_input_tensors( subquery_lens) if self.lora_config: - flat_lora_index_mapping = [ - item for sublist in lora_index_mapping for item in sublist - ] lora_mapping = LoRAMapping( - flat_lora_index_mapping, + lora_index_mapping, lora_prompt_mapping, ) else: @@ -679,6 +702,18 @@ def list_loras(self) -> Set[int]: @torch.inference_mode() def capture_model(self, kv_caches: List[KVCache]) -> None: + """Cuda graph capture a model. + + Note that CUDA graph's performance gain is negligible if number + of batched tokens are larger than 200. 
And since CUDA graph + requires fixed sized tensors, supporting large/variable batch + size requires high GPU memory overhead. Thus, vLLM only captures + decoding requests. Mixed batch (chunked prefill + decoding) or + prefill requests are not captured. + + Since it is used for decoding-only, it assumes there's only 1 token + per sequence in the batch. + """ # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never # deleted before the CUDA graphs. self.cupy_nccl_backend = cupy_utils.get_nccl_backend() @@ -697,10 +732,9 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # Prepare dummy inputs. These will be reused for all batch sizes. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) - input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda() - input_positions = torch.zeros(max_batch_size, 1, - dtype=torch.long).cuda() - slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda() + input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() + input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() + slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() @@ -726,9 +760,14 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: is_prompt=False, slot_mapping=slot_mapping[:batch_size], prompt_lens=None, - max_seq_len=None, - start_loc=None, + prompt_lens_tensor=None, + num_prompt_tokens=0, + num_generation_tokens=batch_size, + max_subquery_len=None, max_context_len=self.max_context_len_to_capture, + max_seq_len=None, + subquery_start_loc=None, + seq_start_loc=None, context_lens=context_lens[:batch_size], block_tables=block_tables[:batch_size], use_cuda_graph=True, @@ -845,7 +884,6 @@ def forward( non_blocking=True) self.input_buffers["block_tables"].copy_(input_metadata.block_tables, non_blocking=True) - # Run the graph. self.graph.replay() @@ -877,17 +915,28 @@ def _make_tensor_with_pad( dtype: torch.dtype, device: Optional[Union[str, torch.device]], ) -> torch.Tensor: + """Make a padded tensor of a 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] return torch.tensor(padded_x, dtype=dtype, device=device) def _get_graph_batch_size(batch_size: int) -> int: + """Returns the padded batch size given actual batch size. + + Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, + 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... 
+ """ if batch_size <= 2: return batch_size elif batch_size <= 4: return 4 else: - return (batch_size + 7) // 8 * 8 + return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // + _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) def _async_h2d( From f1c0fc391909e55fce5f109893f3c483f69a091f Mon Sep 17 00:00:00 2001 From: Roy Date: Thu, 21 Mar 2024 07:25:01 +0800 Subject: [PATCH 04/10] Migrate `logits` computation and gather to `model_runner` (#3233) --- .buildkite/test-pipeline.yaml | 3 + tests/lora/conftest.py | 7 +- tests/lora/test_layers.py | 66 ++++++----- tests/samplers/test_sampler.py | 95 ++++------------ tests/test_logits_processor.py | 94 ++++++++++++++++ vllm/lora/layers.py | 20 ++-- vllm/lora/models.py | 13 ++- .../model_executor/layers/logits_processor.py | 106 ++++++++++++++++++ vllm/model_executor/layers/sampler.py | 81 +------------ vllm/model_executor/models/baichuan.py | 15 ++- vllm/model_executor/models/bloom.py | 15 ++- vllm/model_executor/models/chatglm.py | 15 ++- vllm/model_executor/models/deepseek.py | 15 ++- vllm/model_executor/models/falcon.py | 15 ++- vllm/model_executor/models/gemma.py | 15 ++- vllm/model_executor/models/gpt2.py | 14 ++- vllm/model_executor/models/gpt_bigcode.py | 15 ++- vllm/model_executor/models/gpt_j.py | 15 ++- vllm/model_executor/models/gpt_neox.py | 15 ++- vllm/model_executor/models/internlm2.py | 15 ++- vllm/model_executor/models/llama.py | 18 ++- vllm/model_executor/models/mixtral.py | 16 ++- vllm/model_executor/models/mixtral_quant.py | 15 ++- vllm/model_executor/models/mpt.py | 15 ++- vllm/model_executor/models/neuron/llama.py | 15 ++- vllm/model_executor/models/neuron/mistral.py | 15 ++- vllm/model_executor/models/olmo.py | 15 ++- vllm/model_executor/models/opt.py | 15 ++- vllm/model_executor/models/orion.py | 15 ++- vllm/model_executor/models/phi.py | 16 ++- vllm/model_executor/models/qwen.py | 15 ++- vllm/model_executor/models/qwen2.py | 24 ++-- vllm/model_executor/models/stablelm.py | 15 ++- vllm/model_executor/models/starcoder2.py | 16 ++- vllm/worker/model_runner.py | 9 +- 35 files changed, 577 insertions(+), 306 deletions(-) create mode 100644 tests/test_logits_processor.py create mode 100644 vllm/model_executor/layers/logits_processor.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 17f4c33670821..6d052d0f7f4a4 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -49,6 +49,9 @@ steps: - label: Samplers Test command: pytest -v -s samplers +- label: LogitsProcessor Test + command: pytest -v -s test_logits_processor.py + - label: Worker Test command: pytest -v -s worker diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 30a8ad03c8ada..38560c251696a 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -13,6 +13,7 @@ import vllm from vllm.config import LoRAConfig from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.model_loader import get_model from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -85,7 +86,8 @@ def dummy_model() -> nn.Module: ("outact", nn.Sigmoid()), # Special handling for lm_head & sampler ("lm_head", ParallelLMHead(512, 10)), - ("sampler", Sampler(512)) + ("logits_processor", LogitsProcessor(512)), + ("sampler", Sampler()) ])) model.config = MagicMock() return model @@ -110,7 +112,8 @@ def dummy_model_gate_up() -> nn.Module: ("outact", nn.Sigmoid()), # Special handling for lm_head & 
sampler ("lm_head", ParallelLMHead(512, 10)), - ("sampler", Sampler(512)) + ("logits_processor", LogitsProcessor(512)), + ("sampler", Sampler()) ])) model.config = MagicMock() return model diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 46f054c5b84ef..7dfc3952016f5 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -13,14 +13,14 @@ QKVParallelLinearWithLora, VocabParallelEmbeddingWithLoRA, RowParallelLinearWithLoRA, - SamplerWithLoRA, + LogitsProcessorWithLoRA, LoRAMapping, BaseLayerWithLoRA, ) from vllm.lora.models import (LoRALayerWeights, convert_mapping, PackedLoRALayerWeights) from vllm.config import LoRAConfig -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear, @@ -394,7 +394,7 @@ def create_random_embedding_layer(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_lm_head_sampler(dist_init, num_loras, device) -> None: +def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: torch.set_default_device(device) max_loras = 8 @@ -402,28 +402,29 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None: max_lora_rank=8, lora_dtype=torch.float16) - def create_random_sampler_layer(): + def _pretest(): linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size, 1024, 32000) linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data[:, 32000:] = 0 - sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000) - lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype, - linear.weight.device) - lora_sampler.create_lora_weights(max_loras, lora_config) + logits_processor = LogitsProcessor( + 32000 + lora_config.lora_extra_vocab_size, 32000) + lora_logits_processor = LogitsProcessorWithLoRA( + logits_processor, 1024, linear.weight.dtype, linear.weight.device) + lora_logits_processor.create_lora_weights(max_loras, lora_config) - return linear, sampler, lora_sampler + return linear, logits_processor, lora_logits_processor for i in range(10): set_random_seed(i) id_to_index = get_random_id_to_index(num_loras, max_loras) - linear, sampler, lora_sampler = create_random_sampler_layer() + linear, logits_processor, lora_logits_processor = _pretest() # NOTE: all the generated loras share the same embeddings tensor. lora_dict, _ = populate_loras( id_to_index, - layer=lora_sampler, + layer=lora_logits_processor, layer_weights=linear.weight, generate_embeddings_tensor=1024, ) @@ -447,34 +448,37 @@ def create_random_sampler_layer(): 32000, lora_config.lora_extra_vocab_size, ) - lora_sampler.set_mapping(*mapping_info, ) + lora_logits_processor.set_mapping(*mapping_info, ) - lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), - embedding=linear.weight, - embedding_bias=None) + lora_result = lora_logits_processor._get_logits( + hidden_states=torch.cat(inputs), + embedding=linear.weight, + embedding_bias=None) original_weight = linear.weight.clone() - linear.weight[sampler.org_vocab_size:sampler.org_vocab_size + + linear.weight[logits_processor. 
+ org_vocab_size:logits_processor.org_vocab_size + embeddings_tensor_len] = embeddings_tensor - sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size + logits_processor.org_vocab_size = (32000 + + lora_config.lora_extra_vocab_size) expected_results = [] for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] - result = sampler._get_logits(hidden_states=input_, - embedding=linear.weight, - embedding_bias=None) + result = logits_processor._get_logits(hidden_states=input_, + embedding=linear.weight, + embedding_bias=None) result[:, 32000 + embeddings_tensor_len:] = float("-inf") result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) - sampler.org_vocab_size = 32000 + logits_processor.org_vocab_size = 32000 # Check that resetting the lora weights succeeds for slot_idx in range(max_loras): - lora_sampler.reset_lora(slot_idx) + lora_logits_processor.reset_lora(slot_idx) inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=[0], @@ -488,14 +492,16 @@ def create_random_sampler_layer(): mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, 32000, lora_config.lora_extra_vocab_size) - lora_sampler.set_mapping(*mapping_info, ) - - lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), - embedding=original_weight, - embedding_bias=None)[:, :32000] - expected_result = sampler._get_logits(hidden_states=torch.cat(inputs), - embedding=original_weight, - embedding_bias=None) + lora_logits_processor.set_mapping(*mapping_info, ) + + lora_result = lora_logits_processor._get_logits( + hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None)[:, :32000] + expected_result = logits_processor._get_logits( + hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None) rtol, atol = TOLERANCES[lora_result.dtype] assert torch.allclose(lora_result, diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index b0c6e1c09eebc..92aec831d02e2 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -15,17 +15,12 @@ class MockLogitsSampler(Sampler): - def __init__(self, vocab_size: int, fake_logits: torch.Tensor): - super().__init__(vocab_size=vocab_size) + def __init__(self, fake_logits: torch.Tensor): + super().__init__() self.fake_logits = fake_logits def forward(self, *args, **kwargs): - with patch( - "vllm.model_executor.layers.sampler._prune_hidden_states", - lambda x, y: x), patch( - "vllm.model_executor.layers.sampler.Sampler._get_logits", - lambda *args, **kwargs: self.fake_logits): - return super().forward(*args, **kwargs) + return super().forward(*args, **kwargs) def _prepare_test( @@ -36,7 +31,7 @@ def _prepare_test( fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=input_tensor.dtype) - sampler = MockLogitsSampler(32000, fake_logits) + sampler = MockLogitsSampler(fake_logits) model_runner = ModelRunner(None, None, None, None, None) return input_tensor, fake_logits, sampler, model_runner @@ -70,9 +65,7 @@ def _do_sample( sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - return sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) + return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -85,8 +78,8 @@ def test_sampler_all_greedy(seed: int, device: str): batch_size) 
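The test changes in this hunk reflect the larger refactor in this patch: logits are produced by a separate LogitsProcessor, and the Sampler consumes logits directly. The sketch below illustrates that two-stage flow in plain torch; it is not vLLM code, and the shapes and scale value are arbitrary.

# --- illustrative sketch, not part of the diff ---
import torch

torch.manual_seed(0)
num_tokens, hidden_size, vocab_size = 3, 8, 16

hidden_states = torch.randn(num_tokens, hidden_size)
lm_head_weight = torch.randn(vocab_size, hidden_size)

# Stage 1 (logits computation, now owned by the logits processor):
# project the selected hidden states onto the vocabulary and apply a scale.
scale = 1.0
logits = scale * (hidden_states @ lm_head_weight.t())

# Stage 2 (sampling): the sampler only ever sees logits.
probs = torch.softmax(logits, dim=-1)
next_tokens = torch.argmax(probs, dim=-1)  # greedy decoding for illustration
print(next_tokens.tolist())
# --- end of sketch ---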
sampling_params = SamplingParams(temperature=0) - sampler_output = _do_sample(batch_size, input_tensor, sampler, - model_runner, sampling_params) + sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampling_params) expected = torch.argmax(fake_logits, dim=-1) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: @@ -111,8 +104,8 @@ def test_sampler_all_random(seed: int, device: str): temperature=1.0, n=random.randint(1, 10), ) - sampler_output = _do_sample(batch_size, input_tensor, sampler, - model_runner, sampling_params) + sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampling_params) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: @@ -127,8 +120,7 @@ def test_sampler_all_random_seed(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + _, fake_logits, sampler, model_runner = _prepare_test(batch_size) for i in range(batch_size): fake_logits[i, i] = 1e2 @@ -138,8 +130,8 @@ def test_sampler_all_random_seed(seed: int, device: str): n=random.randint(1, 10), seed=random.randint(0, 10000), ) - sampler_output = _do_sample(batch_size, input_tensor, sampler, - model_runner, sampling_params) + sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampling_params) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: @@ -154,18 +146,17 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + _, fake_logits, sampler, model_runner = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=1.0, n=random.randint(1, 10), seed=random.randint(0, 10000), ) - first_sampler_output = _do_sample(batch_size, input_tensor, sampler, + first_sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params) - second_sampler_output = _do_sample(batch_size, input_tensor, sampler, + second_sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params) assert first_sampler_output == second_sampler_output @@ -179,15 +170,14 @@ def test_sampler_all_beam(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, _, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler, model_runner = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=0, best_of=2, use_beam_search=True, ) - _do_sample(batch_size, input_tensor, sampler, model_runner, - sampling_params) + _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params) # no assertion here as I am not sure how to determine whether # the outputs are expected - in other words, this just tests # whether there are no exceptions in the sampler @@ -246,8 +236,7 @@ def test_sampler_mixed(seed: int, device: str): def test_sampling(model_runner: ModelRunner): sampling_metadata = model_runner._prepare_sample( seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - sampler_output = sampler(embedding=None, - hidden_states=input_tensor, + sampler_output = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) for i, 
(sequence_output, metadata) in enumerate( @@ -294,48 +283,6 @@ def test_sampling(model_runner: ModelRunner): del model_runner -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_logits_processors(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - input_tensor, _, sampler, model_runner = _prepare_test(batch_size) - - # This sample logits processor gives maximum score to the i-th token, - # where i is the length of the input sequence. - # We therefore expect the output token sequence to be [0, 1, 2, ...] - def pick_ith(token_ids, logits): - logits[len(token_ids)] = torch.finfo(logits.dtype).max - return logits - - seq_group_metadata_list = [] - prompt_lens = [] - for i in range(batch_size): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0, - logits_processors=[pick_ith]), - block_tables={0: [1]}, - )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) - sampler_output = sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) - for _, sequence_output in enumerate(sampler_output): - for idx, nth_output in enumerate(sequence_output.samples): - assert nth_output.output_token == idx - - del model_runner - - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) def test_sampler_top_k_top_p(seed: int, device: str): @@ -352,7 +299,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): size=(batch_size, vocab_size), device=input_tensor.device, dtype=input_tensor.dtype) - sampler = MockLogitsSampler(32000, fake_logits) + sampler = MockLogitsSampler(fake_logits) model_runner = ModelRunner(None, None, None, None, None) generation_model = GenerationMixin() @@ -391,9 +338,7 @@ def mock_sample(probs, *args, **kwargs): return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs] with patch("vllm.model_executor.layers.sampler._sample", mock_sample): - sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) + sampler(logits=fake_logits, sampling_metadata=sampling_metadata) hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone()) hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) assert torch.allclose(hf_probs, sample_probs, atol=1e-5) diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py new file mode 100644 index 0000000000000..fe321520114f7 --- /dev/null +++ b/tests/test_logits_processor.py @@ -0,0 +1,94 @@ +import random +from typing import Tuple +from unittest.mock import patch + +import pytest +import torch + +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.utils import set_random_seed +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.worker.model_runner import ModelRunner + + +class MockLogitsProcessor(LogitsProcessor): + + def __init__(self, vocab_size: int, scale: float, + fake_logits: torch.Tensor): + super().__init__(vocab_size=vocab_size, scale=scale) + self.fake_logits = fake_logits.clone() + + def forward(self, *args, **kwargs): + with patch( + "vllm.model_executor.layers.logits_processor._prune_hidden_states", + 
lambda x, y: x + ), patch( + "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits", + lambda *args, **kwargs: self.fake_logits): + return super().forward(*args, **kwargs) + + +def _prepare_test( + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]: + vocab_size = 32000 + input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) + fake_logits = torch.full((batch_size, vocab_size), + 1e-2, + dtype=input_tensor.dtype) + logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) + model_runner = ModelRunner(None, None, None, None, None) + return input_tensor, fake_logits, logits_processor, model_runner + + +RANDOM_SEEDS = list(range(128)) +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_logits_processors(seed: int, device: str): + set_random_seed(seed) + torch.set_default_device(device) + batch_size = random.randint(1, 256) + input_tensor, fake_logits, logits_processor, model_runner = _prepare_test( + batch_size) + + # This sample logits processor gives infinite score to the i-th token, + # where i is the length of the input sequence. + # We therefore expect the output token sequence to be [0, 1, 2, ...] + def pick_ith(token_ids, logits): + logits[len(token_ids)] = float("inf") + return logits + + seq_group_metadata_list = [] + prompt_lens = [] + for i in range(batch_size): + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: SequenceData([1, 2, 3])}, + sampling_params=SamplingParams(temperature=0, + logits_processors=[pick_ith]), + block_tables={0: [1]}, + )) + prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + + sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens) + logits_processor_output = logits_processor( + embedding=None, + hidden_states=input_tensor, + sampling_metadata=sampling_metadata) + + assert torch.isinf(logits_processor_output[:, 0]).all() + + fake_logits *= logits_processor.scale + assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1], + 1e-4) + + del model_runner diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 99e6cdeee6364..f6cd1390d4bce 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -10,7 +10,6 @@ from vllm.config import LoRAConfig from vllm.lora.punica import add_lora, add_lora_slice, bgmv -from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, @@ -20,6 +19,7 @@ RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( @@ -783,11 +783,11 @@ def weight(self): return self.base_layer.weight -class SamplerWithLoRA(BaseLayerWithLoRA): +class LogitsProcessorWithLoRA(BaseLayerWithLoRA): def __init__( self, - base_layer: Sampler, + base_layer: LogitsProcessor, hidden_size: int, dtype: torch.dtype, device: torch.device, @@ -806,6 +806,10 @@ def logits_as_hidden_states(self): def vocab_size(self): return self.base_layer.vocab_size + @property + def scale(self): + return 
self.base_layer.scale + @property def org_vocab_size(self): return self.base_layer.org_vocab_size @@ -968,14 +972,14 @@ def from_layer( return layer -def from_layer_sampler( - layer: Sampler, +def from_layer_logits_processor( + layer: LogitsProcessor, lm_head: ParallelLMHead, max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, -) -> SamplerWithLoRA: - ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype, - lm_head.weight.device) +) -> LogitsProcessorWithLoRA: + ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, + lm_head.weight.dtype, lm_head.weight.device) ret.create_lora_weights(max_loras, lora_config, model_config) return ret diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 6fe07b69b3203..d1bac7617e1d4 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -14,7 +14,7 @@ from vllm.utils import LRUCache, in_wsl from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer, - from_layer_sampler) + from_layer_logits_processor) from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule @@ -421,11 +421,14 @@ def _create_lora_modules(self): self.model.config)) # (yard1): TODO make this more robust if "lm_head" in module_name: - sampler_module = self.model.get_submodule("sampler") + logits_processor_module = self.model.get_submodule( + "logits_processor") new_module = replace_submodule( - self.model, "sampler", - from_layer_sampler(sampler_module, module, self.lora_slots, - self.lora_config, self.model.config)) + self.model, "logits_processor", + from_layer_logits_processor(logits_processor_module, + module, self.lora_slots, + self.lora_config, + self.model.config)) self.register_module(module_name, new_module) self._register_packed_modules(module_name) new_module.set_mapping(self.base_indices, self.sampler_indices, diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py new file mode 100644 index 0000000000000..baa113c342c28 --- /dev/null +++ b/vllm/model_executor/layers/logits_processor.py @@ -0,0 +1,106 @@ +"""A layer that compute logits from hidden_stats.""" +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.utils import is_neuron + +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_gather) +from vllm.model_executor.sampling_metadata import SamplingMetadata + + +class LogitsProcessor(nn.Module): + """Process logits and apply logits processors from sampling metadata. + + This layer does the following: + 1. Gather logits from model hidden_states. + 2. Scale logits if needed. + 3. Apply logits processors (if any). + """ + + def __init__(self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: Optional[float] = 1.0) -> None: + """ + Args: + scale: A scaling factor to apply to the logits. + """ + super().__init__() + self.scale = scale + self.vocab_size = vocab_size + # Transformers-neuronx generate outputs as logits directly. + self.logits_as_hidden_states = is_neuron() + # original vocabulary size (without LoRA). 
+ self.org_vocab_size = org_vocab_size or vocab_size + + def forward( + self, + embedding: torch.Tensor, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + embedding_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.logits_as_hidden_states: + logits = hidden_states + else: + hidden_states = _prune_hidden_states(hidden_states, + sampling_metadata) + + # Get the logits for the next tokens. + logits = self._get_logits(hidden_states, embedding, embedding_bias) + + if logits is not None: + logits *= self.scale + + # Apply logits processors (if any). + logits = _apply_logits_processors(logits, sampling_metadata) + + return logits + + def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_gather(logits) + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[:, :self.org_vocab_size] + return logits + + +def _prune_hidden_states( + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + return hidden_states.index_select(0, + sampling_metadata.selected_token_indices) + + +def _apply_logits_processors( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + logits_row_idx = 0 + found_logits_processors = False + for seq_ids, sampling_params in sampling_metadata.seq_groups: + logits_processors = sampling_params.logits_processors + if logits_processors: + found_logits_processors = True + for seq_id in seq_ids: + logits_row = logits[logits_row_idx] + token_ids = sampling_metadata.seq_data[seq_id].output_token_ids + for logits_processor in logits_processors: + logits_row = logits_processor(token_ids, logits_row) + logits[logits_row_idx] = logits_row + logits_row_idx += 1 + else: + logits_row_idx += len(seq_ids) + if found_logits_processors: + assert logits_row_idx == logits.shape[0] + return logits diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index ac8336ca0f9ad..63e494586efb5 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -4,8 +4,6 @@ import torch import torch.nn as nn -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_gather) from vllm.model_executor.sampling_metadata import (SamplingMetadata, SamplingTensors) from vllm.sampling_params import SamplingParams, SamplingType @@ -13,7 +11,6 @@ SamplerOutput, SequenceData, SequenceGroupOutput, SequenceOutput) from vllm.model_executor.layers.ops.sample import (sample as sample_triton) -from vllm.utils import is_neuron class Sampler(nn.Module): @@ -31,58 +28,14 @@ class Sampler(nn.Module): parameters (e.g., sampling method, temperature, top-p, top-k, etc.). """ - def __init__(self, - vocab_size: int, - org_vocab_size: Optional[int] = None) -> None: - super().__init__() - self.vocab_size = vocab_size - # Transformers-neuronx generate outputs as logits directly. - self.logits_as_hidden_states = is_neuron() - # original vocabulary size (without LoRA). 
- self.org_vocab_size = org_vocab_size or vocab_size - - def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, - embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: - # Get the logits for the next tokens. - logits = torch.matmul(hidden_states, embedding.t()) - if embedding_bias is not None: - logits += embedding_bias - logits = tensor_model_parallel_gather(logits) - # Remove paddings in vocab (if any). - if logits is not None: - logits = logits[:, :self.org_vocab_size] - return logits - def forward( self, - embedding: torch.Tensor, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, - embedding_bias: Optional[torch.Tensor] = None, ) -> Optional[SamplerOutput]: - # Get the hidden states that we use for sampling. - if self.logits_as_hidden_states: - logits = hidden_states - else: - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) - - # Get the logits for the next tokens. - logits = self._get_logits(hidden_states, embedding, embedding_bias) - - # Only perform sampling in the driver worker. - # Note: `_get_logits` is still distributed across TP workers because - # the `embedding` weight is distributed across TP workers. - # TODO(zhuohan): Change the get_logits part to a separate stage. - if not sampling_metadata.perform_sampling: - return None - assert logits is not None _, vocab_size = logits.shape - # Apply logits processors (if any). - logits = _apply_logits_processors(logits, sampling_metadata) - # Prepare sampling tensors with pinned memory to avoid blocking. (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) = SamplingTensors.from_sampling_metadata( @@ -124,14 +77,6 @@ def forward( prompt_logprobs, sample_logprobs) -def _prune_hidden_states( - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - return hidden_states.index_select(0, - sampling_metadata.selected_token_indices) - - def _get_bin_counts_and_mask( tokens: torch.Tensor, vocab_size: int, @@ -149,30 +94,6 @@ def _get_bin_counts_and_mask( return bin_counts, mask -def _apply_logits_processors( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - logits_row_idx = 0 - found_logits_processors = False - for seq_ids, sampling_params in sampling_metadata.seq_groups: - logits_processors = sampling_params.logits_processors - if logits_processors: - found_logits_processors = True - for seq_id in seq_ids: - logits_row = logits[logits_row_idx] - token_ids = sampling_metadata.seq_data[seq_id].output_token_ids - for logits_processor in logits_processors: - logits_row = logits_processor(token_ids, logits_row) - logits[logits_row_idx] = logits_row - logits_row_idx += 1 - else: - logits_row_idx += len(seq_ids) - if found_logits_processors: - assert logits_row_idx == logits.shape[0] - return logits - - def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, output_tokens_tensor: torch.Tensor, presence_penalties: torch.Tensor, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index cbf472750e294..968b9ebba87b2 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -34,6 +34,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( 
VocabParallelEmbedding, ParallelLMHead) @@ -295,7 +296,8 @@ def __init__(self, self.linear_method = linear_method self.model = BaiChuanModel(config, position_embedding, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -308,13 +310,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 0548b2b140b1b..851c475206661 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -30,6 +30,7 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -273,7 +274,8 @@ def __init__( self.linear_method = linear_method self.transformer = BloomModel(config, linear_method) self.lm_head_weight = self.transformer.word_embeddings.weight - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -286,13 +288,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 1c5dcfacaff2b..15e7de03b61f1 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -17,6 +17,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -332,7 +333,8 @@ def __init__( self.linear_method = linear_method self.transformer = ChatGLMModel(config, linear_method) self.lm_head_weight = self.transformer.output_layer.weight - self.sampler = Sampler(config.padded_vocab_size) + self.logits_processor = LogitsProcessor(config.padded_vocab_size) + self.sampler = Sampler() def forward( self, @@ -345,13 +347,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: 
SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 13c080cb02774..eff93e706f5dc 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -38,6 +38,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -372,7 +373,8 @@ def __init__( self.linear_method = linear_method self.model = DeepseekModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -385,13 +387,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: Optional[torch.Tensor], + logits: Optional[torch.Tensor], sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 3c148be5b10f4..7626dbe62293f 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -34,6 +34,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -373,7 +374,8 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -390,13 +392,18 @@ def forward( ) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 386a36cf492d6..fd3dbe798cd8e 100644 --- 
a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -30,6 +30,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -281,7 +282,8 @@ def __init__( self.config = config self.linear_method = linear_method self.model = GemmaModel(config, linear_method) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -295,13 +297,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.model.embed_tokens.weight, + hidden_states, sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.model.embed_tokens.weight, - hidden_states, sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 3f7b21e5a4133..263727cac19ff 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -30,6 +30,7 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -216,7 +217,8 @@ def __init__( self.linear_method = linear_method self.transformer = GPT2Model(config, linear_method) self.lm_head_weight = self.transformer.wte.weight - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -229,12 +231,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, + next_tokens = self.sampler(self.lm_head_weight, logits, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 5c30d47d93e36..65caabae60daa 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -31,6 +31,7 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -237,7 +238,8 @@ def __init__( self.linear_method = linear_method self.transformer = GPTBigCodeModel(config, linear_method) self.lm_head_weight = self.transformer.wte.weight - self.sampler = Sampler(config.vocab_size) + self.logits_processor = 
LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -250,13 +252,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 93dce7b67a7a5..c956a12f3e46e 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -30,6 +30,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -224,7 +225,8 @@ def __init__( config.n_embd, bias=True, ) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -237,13 +239,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata, self.lm_head.bias) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata, self.lm_head.bias) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 98107350e60b9..db2173936e7d9 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -30,6 +30,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -238,7 +239,8 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -251,13 +253,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.embed_out.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.embed_out.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def 
load_weights(self, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 7b2215ef4bda5..93026fc01f0f0 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -14,6 +14,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -250,7 +251,8 @@ def __init__( self.linear_method = linear_method self.model = InternLM2Model(config, linear_method) self.output = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -263,13 +265,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.output.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.output.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 4c163dfdab537..757b75129845c 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -37,6 +37,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) @@ -325,7 +326,11 @@ def __init__( # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) - self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = Sampler() def forward( self, @@ -338,13 +343,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index d47834e519697..68a3a298444ae 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -37,6 +37,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import 
LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) @@ -369,7 +370,9 @@ def __init__( # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) - self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -382,13 +385,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: Optional[torch.Tensor], + logits: Optional[torch.Tensor], sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 25c7f1978c0dc..b4dfc439d50e9 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -39,6 +39,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -344,7 +345,8 @@ def __init__( self.linear_method = linear_method self.model = MixtralModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -357,13 +359,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: Optional[torch.Tensor], + logits: Optional[torch.Tensor], sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 16ecac3d0529a..7a2568817858c 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -13,6 +13,7 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -259,7 +260,8 @@ def __init__( self.transformer = MPTModel(config, linear_method) self.lm_head_weight = self.transformer.wte.weight - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -272,13 +274,18 @@ def forward( 
input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/neuron/llama.py b/vllm/model_executor/models/neuron/llama.py index e2856da99d9b1..32c43c4944fac 100644 --- a/vllm/model_executor/models/neuron/llama.py +++ b/vllm/model_executor/models/neuron/llama.py @@ -7,6 +7,7 @@ from transformers import LlamaConfig from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput @@ -25,7 +26,8 @@ def __init__( self.config = config self.linear_method = linear_method self.model = None - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -45,13 +47,18 @@ def forward( start_ids=seq_ids.flatten()) return logits + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.model.chkpt_model.lm_head, + hidden_states, sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.model.chkpt_model.lm_head, - hidden_states, sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/neuron/mistral.py b/vllm/model_executor/models/neuron/mistral.py index a302cce30abab..24fc0fa0aacab 100755 --- a/vllm/model_executor/models/neuron/mistral.py +++ b/vllm/model_executor/models/neuron/mistral.py @@ -6,6 +6,7 @@ from transformers import MistralConfig from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput @@ -26,7 +27,8 @@ def __init__( self.linear_method = linear_method self.model = None self.lm_head = None - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -48,13 +50,18 @@ def forward( start_ids=seq_ids) return logits + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.model.chkpt_model.lm_head, + hidden_states, sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.model.chkpt_model.lm_head, - hidden_states, sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def 
load_weights(self, diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 2b0a420e82faf..19f2be6da8ed3 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -51,6 +51,7 @@ RowParallelLinear, ) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -336,7 +337,8 @@ def __init__(self, self.lm_head_weight = (self.model.transformer.wte.weight if config.weight_tying else self.model.transformer.ff_out.weight) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -353,13 +355,18 @@ def forward( ) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights( diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 782f43ce265bd..a12f63b58f52b 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -31,6 +31,7 @@ QKVParallelLinear, ReplicatedLinear, RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -292,7 +293,8 @@ def __init__( self.linear_method = linear_method self.model = OPTModel(config, linear_method) self.lm_head_weight = self.model.decoder.embed_tokens.weight - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -305,13 +307,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 6039b1cdc3534..86428e320e0f7 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -18,6 +18,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -256,7 +257,8 @@ def __init__( self.linear_method = linear_method self.model = 
OrionModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -269,13 +271,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 039dc7a9b7675..ef70c823dc905 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -49,6 +49,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -240,7 +241,8 @@ def __init__(self, self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, bias=True) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -254,14 +256,18 @@ def forward( return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata, self.lm_head.bias) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - head = self.lm_head - next_tokens = self.sampler(head.weight, hidden_states, - sampling_metadata, head.bias) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index d4d5a4e8bb9a5..61ac2c6c605c6 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -19,6 +19,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -230,7 +231,8 @@ def __init__( self.linear_method = linear_method self.transformer = QWenModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -243,13 +245,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( 
self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 12e0feddcb7f1..6698f01b7c701 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -37,6 +37,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -300,11 +301,15 @@ def __init__( self.linear_method = linear_method self.model = Qwen2Model(config, linear_method) - if not config.tie_word_embeddings: + if config.tie_word_embeddings: + self.lm_head_weight = self.model.embed_tokens.weight + else: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.lm_head_weight = self.lm_head.weight - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -317,17 +322,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - if self.config.tie_word_embeddings: - lm_head_weight = self.model.embed_tokens.weight - else: - lm_head_weight = self.lm_head.weight - next_tokens = self.sampler(lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index c66f327beee7a..7624ca89ee670 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -33,6 +33,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -238,7 +239,8 @@ def __init__( self.linear_method = linear_method self.model = StableLMEpochModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -251,13 +253,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, 
hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index cfbb1bdb7909e..e418951a633ab 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -32,6 +32,7 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) @@ -254,7 +255,9 @@ def __init__(self, padding_size=DEFAULT_VOCAB_PADDING_SIZE, ) self.lm_head_weight = self.lm_head.weight - self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -267,13 +270,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: Optional[torch.Tensor], + logits: Optional[torch.Tensor], sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index cfccbbb20adc5..347b9380f1113 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -613,9 +613,16 @@ def execute_model( input_metadata=input_metadata, ) + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Only perform sampling in the driver worker. + if not sampling_metadata.perform_sampling: + return None + # Sample the next token. 
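
The hunk above (continuing with the sample call just below) reorders ModelRunner.execute_model so that logits are computed on every tensor-parallel worker, since the lm_head weight is sharded across them, while sampling runs only on the driver worker. A rough, self-contained sketch of that control flow, where _StubModel and _StubMetadata are invented stand-ins for illustration rather than vLLM classes:

import torch
from dataclasses import dataclass


@dataclass
class _StubMetadata:
    # Stand-in for SamplingMetadata; only the flag used below is modeled.
    perform_sampling: bool


class _StubModel:
    # Stand-in for a vLLM model exposing the split interface.
    def compute_logits(self, hidden_states, sampling_metadata):
        return hidden_states * 2.0  # placeholder for the lm_head projection

    def sample(self, logits, sampling_metadata):
        return torch.argmax(logits, dim=-1)  # placeholder for the Sampler


def execute_model_sketch(model, hidden_states, sampling_metadata):
    # 1. Every worker turns hidden states into logits.
    logits = model.compute_logits(hidden_states, sampling_metadata)
    # 2. Non-driver workers stop here and return None.
    if not sampling_metadata.perform_sampling:
        return None
    # 3. Only the driver samples the next tokens from the logits.
    return model.sample(logits=logits, sampling_metadata=sampling_metadata)


hidden = torch.randn(2, 8)
print(execute_model_sketch(_StubModel(), hidden, _StubMetadata(perform_sampling=True)))
print(execute_model_sketch(_StubModel(), hidden, _StubMetadata(perform_sampling=False)))  # None
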
output = self.model.sample( - hidden_states=hidden_states, + logits=logits, sampling_metadata=sampling_metadata, ) return output From 523e30ea0c5abcb447763dcd9a77b54d5c5f3239 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 20 Mar 2024 17:59:52 -0700 Subject: [PATCH 05/10] [BugFix] Hot fix in setup.py for neuron build (#3537) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 67575a0e04bf0..47cac5996f816 100644 --- a/setup.py +++ b/setup.py @@ -168,7 +168,7 @@ def build_extensions(self) -> None: def _is_cuda() -> bool: - return torch.version.cuda is not None + return torch.version.cuda is not None and not _is_neuron() def _is_hip() -> bool: From 6ebd02bdef1eb08f9a7a11253a26cd49b5fb6d2d Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 21 Mar 2024 07:20:04 +0100 Subject: [PATCH 06/10] [PREFIX CACHING FOLLOW UP] OrderedDict-based evictor (#3431) Co-authored-by: rsnm2 Co-authored-by: Luka --- vllm/core/evictor.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/vllm/core/evictor.py b/vllm/core/evictor.py index 9f401cba3fbea..92515468a8a1f 100644 --- a/vllm/core/evictor.py +++ b/vllm/core/evictor.py @@ -1,5 +1,5 @@ import enum -from typing import Dict +from typing import OrderedDict from abc import ABC, abstractmethod, abstractproperty from vllm.block import PhysicalTokenBlock @@ -58,27 +58,26 @@ class LRUEvictor(Evictor): """ def __init__(self): - self.free_table: Dict[int, PhysicalTokenBlock] = {} + self.free_table: OrderedDict[int, PhysicalTokenBlock] = OrderedDict() def __contains__(self, block_hash: int) -> bool: return block_hash in self.free_table - # TODO: The performance of this evict function can be optimized further. def evict(self) -> PhysicalTokenBlock: if len(self.free_table) == 0: raise ValueError("No usable cache memory left") - free_blocks = self.free_table.values() - # Get evicted block - evicted_block: PhysicalTokenBlock = next(iter(free_blocks)) - - for block in free_blocks: - if (block.last_accessed < evicted_block.last_accessed - or block.last_accessed == evicted_block.last_accessed and - block.num_hashed_tokens > evicted_block.num_hashed_tokens): + evicted_block = next(iter(self.free_table.values())) + # The blocks with the lowest timestamps should be placed consecutively + # at the start of OrderedDict. Loop through all these blocks to + # find the one with maximum number of hashed tokens. 
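
The comment above states the policy that the loop continuing below implements: free blocks sit in an OrderedDict with the least recently used entries at the front, and within the leading run of equally old blocks the one caching the most hashed tokens is evicted. A small self-contained sketch of the same policy, with a simplified _Block standing in for PhysicalTokenBlock:

from collections import OrderedDict
from dataclasses import dataclass


@dataclass
class _Block:
    # Simplified stand-in for PhysicalTokenBlock.
    block_hash: int
    last_accessed: float
    num_hashed_tokens: int


def evict_sketch(free_table: "OrderedDict[int, _Block]") -> _Block:
    if not free_table:
        raise ValueError("No usable cache memory left")
    # Blocks are assumed to be inserted in order of increasing last_accessed,
    # so the first entry has the lowest timestamp.
    evicted = next(iter(free_table.values()))
    # Scan only the leading run of equally old blocks and prefer the one
    # that caches the most hashed tokens.
    for block in free_table.values():
        if block.last_accessed > evicted.last_accessed:
            break
        if block.num_hashed_tokens > evicted.num_hashed_tokens:
            evicted = block
    free_table.pop(evicted.block_hash)
    return evicted


# Blocks 1 and 2 share the oldest timestamp; block 2 caches more hashed
# tokens, so it is evicted first.
table = OrderedDict(
    (b.block_hash, b)
    for b in (_Block(1, 0.0, 4), _Block(2, 0.0, 9), _Block(3, 1.0, 16)))
assert evict_sketch(table).block_hash == 2

Because only the leading run of oldest blocks is scanned before the break, eviction no longer has to walk the entire free table as the previous dict-based implementation did.
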
+ for _, block in self.free_table.items(): + if evicted_block.last_accessed < block.last_accessed: + break + if evicted_block.num_hashed_tokens < block.num_hashed_tokens: evicted_block = block - del self.free_table[evicted_block.block_hash] + self.free_table.pop(evicted_block.block_hash) evicted_block.computed = False return evicted_block @@ -91,7 +90,7 @@ def remove(self, block_hash: int) -> PhysicalTokenBlock: raise ValueError( "Attempting to remove block that's not in the evictor") block: PhysicalTokenBlock = self.free_table[block_hash] - del self.free_table[block_hash] + self.free_table.pop(block_hash) return block @property From 3bbff9e5ab964cf04897cebfc5e886a1113fef01 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 21 Mar 2024 17:49:06 +0900 Subject: [PATCH 07/10] Fix 1D query issue from `_prune_hidden_states` (#3539) --- vllm/model_executor/layers/logits_processor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index baa113c342c28..e9d2a2708c1bb 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -77,7 +77,6 @@ def _prune_hidden_states( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) return hidden_states.index_select(0, sampling_metadata.selected_token_indices) From 4c07dd28c0ef8642735222e077935b55f4c98017 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan <136452006+grandiose-pizza@users.noreply.github.com> Date: Thu, 21 Mar 2024 13:45:24 +0400 Subject: [PATCH 08/10] =?UTF-8?q?[=F0=9F=9A=80=20Ready=20to=20be=20merged]?= =?UTF-8?q?=20Added=20support=20for=20Jais=20models=20(#3183)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 1 + docs/source/models/supported_models.rst | 6 +- vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/gpt2.py | 3 +- vllm/model_executor/models/jais.py | 351 ++++++++++++++++++++ vllm/transformers_utils/config.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/jais.py | 234 +++++++++++++ 8 files changed, 596 insertions(+), 3 deletions(-) create mode 100644 vllm/model_executor/models/jais.py create mode 100644 vllm/transformers_utils/configs/jais.py diff --git a/README.md b/README.md index f57c3f7862ed1..9d3f742225ea8 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) - InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) - InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.) +- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.) - LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.) 
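The OrderedDict-based `evict()` above relies on insertion order standing in for recency and breaks ties on `num_hashed_tokens`. A minimal, self-contained sketch of the same policy, not vLLM code; `Block` and its timestamps are hypothetical stand-ins for `PhysicalTokenBlock` and its `last_accessed` / `num_hashed_tokens` fields:

```python
from collections import OrderedDict
from dataclasses import dataclass


@dataclass
class Block:
    """Hypothetical stand-in for vllm.block.PhysicalTokenBlock."""
    block_hash: int
    last_accessed: float
    num_hashed_tokens: int


def evict(free_table: "OrderedDict[int, Block]") -> Block:
    # Blocks are appended to the table when freed, so the entries with the
    # lowest timestamps sit at the front. Walk that prefix and keep the
    # block caching the most hashed tokens.
    if not free_table:
        raise ValueError("No usable cache memory left")
    evicted = next(iter(free_table.values()))
    for block in free_table.values():
        if evicted.last_accessed < block.last_accessed:
            break
        if evicted.num_hashed_tokens < block.num_hashed_tokens:
            evicted = block
    free_table.pop(evicted.block_hash)
    return evicted


table: "OrderedDict[int, Block]" = OrderedDict()
table[1] = Block(1, last_accessed=0.0, num_hashed_tokens=4)
table[2] = Block(2, last_accessed=0.0, num_hashed_tokens=16)
table[3] = Block(3, last_accessed=5.0, num_hashed_tokens=32)
assert evict(table).block_hash == 2  # oldest timestamp, most hashed tokens
```

Compared to the earlier dict-based scan over every free block, the loop above can stop as soon as the timestamp increases, which is why the `TODO` about eviction performance was dropped.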
diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 4019e0bbd90fb..af4eb81646ebe 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -66,7 +66,11 @@ Alongside each architecture, we include some popular models that use it. * - :code:`InternLM2ForCausalLM` - InternLM2 - :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc. - - + - + * - :code:`JAISLMHeadModel` + - Jais + - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc. + - * - :code:`LlamaForCausalLM` - LLaMA, LLaMA-2, Vicuna, Alpaca, Yi - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index bc3b6a582d53d..069830c4d7cb5 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -27,6 +27,7 @@ "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), + "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 263727cac19ff..e75dda750cb26 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -242,8 +242,7 @@ def sample( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, logits, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py new file mode 100644 index 0000000000000..74c8e7f963026 --- /dev/null +++ b/vllm/model_executor/models/jais.py @@ -0,0 +1,351 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py +# Copyright 2023 The vLLM team. +# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights +# reserved. +# Copyright 2023 Cerebras Systems. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Jais model compatible with HuggingFace weights.""" + +import math +from typing import List, Optional, Tuple + +import torch +from torch import nn +from vllm.transformers_utils.configs import JAISConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_rank, +) +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) +from vllm.sequence import SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class SwiGLUActivation(nn.Module): + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + return x1 * nn.functional.silu(x2) + + +def _get_alibi_slopes(n): + + def get_slopes_power_of_2(n): + start = 2**(-(2**-(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2**math.floor(math.log2(n)) + return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes( + 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + + +class JAISAttention(nn.Module): + + def __init__( + self, + config: JAISConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = self.hidden_size // total_num_heads + if hasattr(config, "scale_qk_dot_by_d"): + config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d + self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5 + self.scale = self.head_dim**-self.attn_scale_power + + self.c_attn = QKVParallelLinear( + self.hidden_size, + self.head_dim, + total_num_heads, + bias=True, + linear_method=linear_method, + ) + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + linear_method=linear_method, + ) + + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end] + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scale, + alibi_slopes=alibi_slopes, + ) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + key_cache, value_cache = kv_cache + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) + attn_output, _ = self.c_proj(attn_output) + return attn_output + + +class JAISMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: JAISConfig, + linear_method: 
Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.swiglu = config.activation_function == "swiglu" + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + linear_method=linear_method, + ) + self.c_fc2 = (ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + linear_method=linear_method, + ) if self.swiglu else None) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + linear_method=linear_method, + ) + + self.act = SwiGLUActivation() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.swiglu: + hidden_states2, _ = self.c_fc2(hidden_states) + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = (self.act(hidden_states, hidden_states2) + if self.swiglu else self.act(hidden_states)) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class JAISBlock(nn.Module): + + def __init__( + self, + config: JAISConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = JAISAttention(config, linear_method) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = JAISMLP(inner_dim, config, linear_method) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +class JAISModel(nn.Module): + + def __init__( + self, + config: JAISConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + assert not config.add_cross_attention + assert not config.scale_attn_by_inverse_layer_idx + assert not config.reorder_and_upcast_attn + self.embed_dim = config.hidden_size + self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) + self.wpe = (nn.Embedding(config.max_position_embeddings, + self.embed_dim) + if config.position_embedding_type != "alibi" else None) + if hasattr(config, "embeddings_scale"): + self.embeddings_scale = config.embeddings_scale + else: + self.embeddings_scale = config.mup_embeddings_scale + self.h = nn.ModuleList([ + JAISBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + inputs_embeds = self.wte(input_ids) + if self.wpe is not None: + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + else: + hidden_states = inputs_embeds + hidden_states *= torch.tensor(float(self.embeddings_scale), + dtype=hidden_states.dtype) + + for i in range(len(self.h)): + layer = self.h[i] + hidden_states = layer(hidden_states, kv_caches[i], input_metadata) + + 
hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class JAISLMHeadModel(nn.Module): + + def __init__( + self, + config: JAISConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.linear_method = linear_method + self.transformer = JAISModel(config, linear_method) + self.lm_head_weight = self.transformer.wte.weight + if hasattr(config, "width_scale"): + self.output_logits_scale = config.width_scale + else: + self.output_logits_scale = (config.mup_output_alpha * + config.mup_width_scale) + self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, + scale=self.output_logits_scale) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "lm_head.weight" in name: + # GPT-2 ties the weights of the embedding layer and the final + # linear layer. + continue + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if "relative_pe" in name: + continue + if not name.startswith("transformer."): + name = "transformer." + name + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) \ No newline at end of file diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 5e1f0439aec51..081e81768b236 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -10,6 +10,7 @@ "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) "starcoder2": Starcoder2Config, + "jais": JAISConfig, } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 4966526f15184..150ee2ce97ad5 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -5,10 +5,12 @@ # `FalconConfig` class from the official HuggingFace transformers library. 
from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.starcoder2 import Starcoder2Config +from vllm.transformers_utils.configs.jais import JAISConfig __all__ = [ "ChatGLMConfig", "MPTConfig", "RWConfig", "Starcoder2Config", + "JAISConfig", ] diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py new file mode 100644 index 0000000000000..94f438716f8bf --- /dev/null +++ b/vllm/transformers_utils/configs/jais.py @@ -0,0 +1,234 @@ +# coding=utf-8 +# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright 2023 Cerebras Systems. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAIS configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class JAISConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a + [`JAISModel`]. It is used to instantiate a JAIS model according to the + specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used + to control the model outputs. Read the documentation from + [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the JAIS model. Defines the number of different + tokens that can be represented by the + `inputs_ids` passed when calling [`JAISModel`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used + with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the + Transformer encoder. + n_inner (`int`, *optional*, defaults to None): + Dimensionality of the inner feed-forward layers. `None` will set + it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu"`): + Activation function, to be selected in the list + `["relu", "silu", "gelu", "tanh", "gelu_new", "swiglu"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in + the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. 
+ scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values + attentions (not used by all models). + scale_attn_by_inverse_layer_idx (`bool`, *optional*, + defaults to `False`): + Whether to additionally scale attention weights by + `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention + (dot-product) + and upcast attention dot-product/softmax to float() when training + with mixed precision. + position_embedding_type (`str`, *optional*, defaults to `"learned"`): + Positional embedding can be either `"alibi"` or `"learned"`. + mup_width_scale (`float`, *optional*, defaults to 1.0): + muP parameter to scale learning rate and initializers. Calculated + as (`d_model,0 / d_model`), where + `d_model` is the model's width and `d_model,0` is the proxy + model's width. + mup_embeddings_scale (`float`, *optional*, defaults to 1.0): + muP parameter to scale token and position embeddings. + mup_output_alpha (`float`, *optional*, defaults to 1.0): + muP parameter to scale output logits + (`output_logits_scale = mup_output_alpha * mup_width_scale`). + mup_scale_qk_dot_by_d (`bool`, *optional*, defaults to `False`): + Scale attention weights by dividing by hidden_size instead of + sqrt(hidden_size). Need to set scale_attn_weights to `True` as + well. + alibi_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for ALiBi + embeddings. Currently only supports linear + scaling strategy. Can specify either the scaling `factor` (must be + a float greater than 1) for fixed scaling + or `train_seq_len` for dynamic scaling on input samples with + sequence length > `train_seq_len`. The expected + formats are `{"type": strategy name, "factor": scaling factor}` or + `{"type": strategy name, + "train_seq_len": training sequence length}`. + architectures (`List`, *optional*, defaults to ['JAISLMHeadModel']): + architecture names for Jais. 
+ + Example: + + ```python + >>> from transformers import JAISConfig, JAISModel + + >>> # Initializing a JAIS configuration + >>> configuration = JAISConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = JAISModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "jais" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1024, + n_embd=768, + n_layer=12, + n_head=12, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + position_embedding_type="learned", + mup_width_scale=1.0, + mup_embeddings_scale=1.0, + mup_output_alpha=1.0, + mup_scale_qk_dot_by_d=False, + alibi_scaling=None, + architectures=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx + self.reorder_and_upcast_attn = reorder_and_upcast_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + self.position_embedding_type = position_embedding_type + self.mup_width_scale = mup_width_scale + self.mup_embeddings_scale = mup_embeddings_scale + self.mup_output_alpha = mup_output_alpha + self.mup_scale_qk_dot_by_d = mup_scale_qk_dot_by_d + + self.alibi_scaling = alibi_scaling + self._alibi_scaling_validation() + if architectures is None: + architectures = ["JAISLMHeadModel"] + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + architectures=architectures, + **kwargs, + ) + + def _alibi_scaling_validation(self): + """ + Validate the `alibi_scaling` configuration. 
+ """ + if self.alibi_scaling is None: + return + + if (not isinstance(self.alibi_scaling, dict) + or len(self.alibi_scaling) != 2): + raise ValueError( + "`alibi_scaling` must be a dictionary with two fields," + "`type` and `factor` or `type` and `train_seq_len`, " + f"got {self.alibi_scaling}") + alibi_scaling_type = self.alibi_scaling.get("type", None) + alibi_scaling_factor = self.alibi_scaling.get("factor", None) + alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None) + if alibi_scaling_type is None or alibi_scaling_type != "linear": + raise ValueError(f"`alibi_scaling`'s type field must be 'linear'," + f"got {alibi_scaling_type}") + if (alibi_scaling_factor is not None + and not isinstance(alibi_scaling_factor, float) + or alibi_scaling_factor <= 1.0): + raise ValueError( + f"`alibi_scaling`'s factor field must be a float > 1.0," + f"got {alibi_scaling_factor}") + if (alibi_dynamic_scaling is not None + and not isinstance(alibi_dynamic_scaling, int) + or alibi_dynamic_scaling <= 1): + raise ValueError( + f"`alibi_scaling`'s `train_seq_len` field must be an" + f"integer > 1, got {alibi_dynamic_scaling}") From 865732342b4e3b8a4ef38f28a2a5bdb87cf3f970 Mon Sep 17 00:00:00 2001 From: Roy Date: Thu, 21 Mar 2024 18:07:48 +0800 Subject: [PATCH 09/10] [Misc][Log] Add log for tokenizer length not equal to vocabulary size (#3500) --- vllm/engine/llm_engine.py | 8 ++++++++ vllm/entrypoints/openai/serving_engine.py | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2280481cca9cb..b726cdd7a2048 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -169,6 +169,14 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): self.tokenizer: BaseTokenizerGroup = get_tokenizer_group( self.parallel_config.tokenizer_pool_config, **init_kwargs) + if len(self.get_tokenizer()) != self.model_config.get_vocab_size(): + logger.warning( + f"The tokenizer's vocabulary size {len(self.get_tokenizer())}" + f" does not match the model's vocabulary size " + f"{self.model_config.get_vocab_size()}. This might " + f"cause an error in decoding. Please change config.json " + "to match the tokenizer's vocabulary size.") + def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 2db884945c491..976046beec245 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -68,6 +68,14 @@ async def _post_init(self): tokenizer_mode=engine_model_config.tokenizer_mode, trust_remote_code=engine_model_config.trust_remote_code) + if len(self.tokenizer) != engine_model_config.get_vocab_size(): + logger.warning( + f"The tokenizer's vocabulary size {len(self.tokenizer)}" + f" does not match the model's vocabulary size " + f"{engine_model_config.get_vocab_size()}. This might " + f"cause an error in decoding. Please change config.json " + "to match the tokenizer's vocabulary size.") + async def show_available_models(self) -> ModelList: """Show available models. 
Right now we only have one model.""" model_cards = [ From c188ecb080501c5ccb34bbd6542978284c547122 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 21 Mar 2024 07:58:12 -0700 Subject: [PATCH 10/10] [Misc] Bump up transformers to v4.39.0 & Remove StarCoder2Config (#3551) Co-authored-by: Roy Co-authored-by: Roger Meier --- requirements-rocm.txt | 2 +- requirements.txt | 2 +- vllm/model_executor/models/starcoder2.py | 8 +-- vllm/transformers_utils/config.py | 10 ---- vllm/transformers_utils/configs/__init__.py | 2 - vllm/transformers_utils/configs/starcoder2.py | 55 ------------------- 6 files changed, 3 insertions(+), 76 deletions(-) delete mode 100644 vllm/transformers_utils/configs/starcoder2.py diff --git a/requirements-rocm.txt b/requirements-rocm.txt index c30479e40f521..07d94cd94f5fa 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -7,7 +7,7 @@ ray >= 2.9 sentencepiece # Required for LLaMA tokenizer. numpy tokenizers>=0.15.0 -transformers >= 4.38.0 # Required for Gemma. +transformers >= 4.39.0 # Required for StarCoder2. fastapi uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. diff --git a/requirements.txt b/requirements.txt index c9a5bd6619402..e136defad4943 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ ray >= 2.9 sentencepiece # Required for LLaMA tokenizer. numpy torch == 2.1.2 -transformers >= 4.38.0 # Required for Gemma. +transformers >= 4.39.0 # Required for StarCoder2. xformers == 0.0.23.post1 # Required for CUDA 12.1. fastapi uvicorn[standard] diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index e418951a633ab..e5003361bdf2a 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -22,6 +22,7 @@ import torch from torch import nn +from transformers import Starcoder2Config from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -42,13 +43,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -try: - from transformers import Starcoder2Config -except ImportError: - # fallback to PretrainedConfig - # NOTE: Please install transformers from source or use transformers>=4.39.0 - from transformers import PretrainedConfig as Starcoder2Config - KVCache = Tuple[torch.Tensor, torch.Tensor] diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 081e81768b236..dc226248910e2 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -9,7 +9,6 @@ "mpt": MPTConfig, "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) - "starcoder2": Starcoder2Config, "jais": JAISConfig, } @@ -18,15 +17,6 @@ def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None, code_revision: Optional[str] = None) -> PretrainedConfig: - # FIXME(woosuk): This is a temporary fix for StarCoder2. - # Remove this when the model is supported by HuggingFace transformers. 
- if "bigcode" in model and "starcoder2" in model: - config_class = _CONFIG_REGISTRY["starcoder2"] - config = config_class.from_pretrained(model, - revision=revision, - code_revision=code_revision) - return config - try: config = AutoConfig.from_pretrained( model, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 150ee2ce97ad5..6fed2fab8c438 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -4,13 +4,11 @@ # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig -from vllm.transformers_utils.configs.starcoder2 import Starcoder2Config from vllm.transformers_utils.configs.jais import JAISConfig __all__ = [ "ChatGLMConfig", "MPTConfig", "RWConfig", - "Starcoder2Config", "JAISConfig", ] diff --git a/vllm/transformers_utils/configs/starcoder2.py b/vllm/transformers_utils/configs/starcoder2.py deleted file mode 100644 index 2879cd0445275..0000000000000 --- a/vllm/transformers_utils/configs/starcoder2.py +++ /dev/null @@ -1,55 +0,0 @@ -from transformers import PretrainedConfig - - -class Starcoder2Config(PretrainedConfig): - model_type = "starcoder2" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=49152, - hidden_size=3072, - intermediate_size=12288, - num_hidden_layers=30, - num_attention_heads=24, - num_key_value_heads=2, - hidden_act="gelu_pytorch_tanh", - max_position_embeddings=4096, - initializer_range=0.018042, - norm_epsilon=1e-5, - use_cache=True, - bos_token_id=50256, - eos_token_id=50256, - rope_theta=10000.0, - sliding_window=None, - attention_dropout=0.0, - residual_dropout=0.0, - embedding_dropout=0.0, - use_bias=True, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.sliding_window = sliding_window - self.use_bias = use_bias - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.norm_epsilon = norm_epsilon - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - self.residual_dropout = residual_dropout - self.embedding_dropout = embedding_dropout - - super().__init__( - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - **kwargs, - ) - if self.architectures is None: - self.architectures = ['Starcoder2ForCausalLM']