diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 17f4c33670821..6d052d0f7f4a4 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -49,6 +49,9 @@ steps: - label: Samplers Test command: pytest -v -s samplers +- label: LogitsProcessor Test + command: pytest -v -s test_logits_processor.py + - label: Worker Test command: pytest -v -s worker diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 30a8ad03c8ada..38560c251696a 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -13,6 +13,7 @@ import vllm from vllm.config import LoRAConfig from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.model_loader import get_model from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -85,7 +86,8 @@ def dummy_model() -> nn.Module: ("outact", nn.Sigmoid()), # Special handling for lm_head & sampler ("lm_head", ParallelLMHead(512, 10)), - ("sampler", Sampler(512)) + ("logits_processor", LogitsProcessor(512)), + ("sampler", Sampler()) ])) model.config = MagicMock() return model @@ -110,7 +112,8 @@ def dummy_model_gate_up() -> nn.Module: ("outact", nn.Sigmoid()), # Special handling for lm_head & sampler ("lm_head", ParallelLMHead(512, 10)), - ("sampler", Sampler(512)) + ("logits_processor", LogitsProcessor(512)), + ("sampler", Sampler()) ])) model.config = MagicMock() return model diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 46f054c5b84ef..7dfc3952016f5 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -13,14 +13,14 @@ QKVParallelLinearWithLora, VocabParallelEmbeddingWithLoRA, RowParallelLinearWithLoRA, - SamplerWithLoRA, + LogitsProcessorWithLoRA, LoRAMapping, BaseLayerWithLoRA, ) from vllm.lora.models import (LoRALayerWeights, convert_mapping, PackedLoRALayerWeights) from vllm.config import LoRAConfig 
-from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear, @@ -394,7 +394,7 @@ def create_random_embedding_layer(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_lm_head_sampler(dist_init, num_loras, device) -> None: +def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: torch.set_default_device(device) max_loras = 8 @@ -402,28 +402,29 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None: max_lora_rank=8, lora_dtype=torch.float16) - def create_random_sampler_layer(): + def _pretest(): linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size, 1024, 32000) linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data[:, 32000:] = 0 - sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000) - lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype, - linear.weight.device) - lora_sampler.create_lora_weights(max_loras, lora_config) + logits_processor = LogitsProcessor( + 32000 + lora_config.lora_extra_vocab_size, 32000) + lora_logits_processor = LogitsProcessorWithLoRA( + logits_processor, 1024, linear.weight.dtype, linear.weight.device) + lora_logits_processor.create_lora_weights(max_loras, lora_config) - return linear, sampler, lora_sampler + return linear, logits_processor, lora_logits_processor for i in range(10): set_random_seed(i) id_to_index = get_random_id_to_index(num_loras, max_loras) - linear, sampler, lora_sampler = create_random_sampler_layer() + linear, logits_processor, lora_logits_processor = _pretest() # NOTE: all the generated loras share the same embeddings tensor. 
lora_dict, _ = populate_loras( id_to_index, - layer=lora_sampler, + layer=lora_logits_processor, layer_weights=linear.weight, generate_embeddings_tensor=1024, ) @@ -447,34 +448,37 @@ def create_random_sampler_layer(): 32000, lora_config.lora_extra_vocab_size, ) - lora_sampler.set_mapping(*mapping_info, ) + lora_logits_processor.set_mapping(*mapping_info, ) - lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), - embedding=linear.weight, - embedding_bias=None) + lora_result = lora_logits_processor._get_logits( + hidden_states=torch.cat(inputs), + embedding=linear.weight, + embedding_bias=None) original_weight = linear.weight.clone() - linear.weight[sampler.org_vocab_size:sampler.org_vocab_size + + linear.weight[logits_processor. + org_vocab_size:logits_processor.org_vocab_size + embeddings_tensor_len] = embeddings_tensor - sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size + logits_processor.org_vocab_size = (32000 + + lora_config.lora_extra_vocab_size) expected_results = [] for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] - result = sampler._get_logits(hidden_states=input_, - embedding=linear.weight, - embedding_bias=None) + result = logits_processor._get_logits(hidden_states=input_, + embedding=linear.weight, + embedding_bias=None) result[:, 32000 + embeddings_tensor_len:] = float("-inf") result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) - sampler.org_vocab_size = 32000 + logits_processor.org_vocab_size = 32000 # Check that resetting the lora weights succeeds for slot_idx in range(max_loras): - lora_sampler.reset_lora(slot_idx) + lora_logits_processor.reset_lora(slot_idx) inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=[0], @@ -488,14 +492,16 @@ def create_random_sampler_layer(): mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, 32000, 
lora_config.lora_extra_vocab_size) - lora_sampler.set_mapping(*mapping_info, ) - - lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), - embedding=original_weight, - embedding_bias=None)[:, :32000] - expected_result = sampler._get_logits(hidden_states=torch.cat(inputs), - embedding=original_weight, - embedding_bias=None) + lora_logits_processor.set_mapping(*mapping_info, ) + + lora_result = lora_logits_processor._get_logits( + hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None)[:, :32000] + expected_result = logits_processor._get_logits( + hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None) rtol, atol = TOLERANCES[lora_result.dtype] assert torch.allclose(lora_result, diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index b0c6e1c09eebc..92aec831d02e2 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -15,17 +15,12 @@ class MockLogitsSampler(Sampler): - def __init__(self, vocab_size: int, fake_logits: torch.Tensor): - super().__init__(vocab_size=vocab_size) + def __init__(self, fake_logits: torch.Tensor): + super().__init__() self.fake_logits = fake_logits def forward(self, *args, **kwargs): - with patch( - "vllm.model_executor.layers.sampler._prune_hidden_states", - lambda x, y: x), patch( - "vllm.model_executor.layers.sampler.Sampler._get_logits", - lambda *args, **kwargs: self.fake_logits): - return super().forward(*args, **kwargs) + return super().forward(*args, **kwargs) def _prepare_test( @@ -36,7 +31,7 @@ def _prepare_test( fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=input_tensor.dtype) - sampler = MockLogitsSampler(32000, fake_logits) + sampler = MockLogitsSampler(fake_logits) model_runner = ModelRunner(None, None, None, None, None) return input_tensor, fake_logits, sampler, model_runner @@ -70,9 +65,7 @@ def _do_sample( sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, 
prompt_lens, subquery_lens=prompt_lens) - return sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) + return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -85,8 +78,8 @@ def test_sampler_all_greedy(seed: int, device: str): batch_size) sampling_params = SamplingParams(temperature=0) - sampler_output = _do_sample(batch_size, input_tensor, sampler, - model_runner, sampling_params) + sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampling_params) expected = torch.argmax(fake_logits, dim=-1) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: @@ -111,8 +104,8 @@ def test_sampler_all_random(seed: int, device: str): temperature=1.0, n=random.randint(1, 10), ) - sampler_output = _do_sample(batch_size, input_tensor, sampler, - model_runner, sampling_params) + sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampling_params) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: @@ -127,8 +120,7 @@ def test_sampler_all_random_seed(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + _, fake_logits, sampler, model_runner = _prepare_test(batch_size) for i in range(batch_size): fake_logits[i, i] = 1e2 @@ -138,8 +130,8 @@ def test_sampler_all_random_seed(seed: int, device: str): n=random.randint(1, 10), seed=random.randint(0, 10000), ) - sampler_output = _do_sample(batch_size, input_tensor, sampler, - model_runner, sampling_params) + sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampling_params) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: @@ -154,18 +146,17 @@ def 
test_sampler_all_random_seed_deterministic(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + _, fake_logits, sampler, model_runner = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=1.0, n=random.randint(1, 10), seed=random.randint(0, 10000), ) - first_sampler_output = _do_sample(batch_size, input_tensor, sampler, + first_sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params) - second_sampler_output = _do_sample(batch_size, input_tensor, sampler, + second_sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params) assert first_sampler_output == second_sampler_output @@ -179,15 +170,14 @@ def test_sampler_all_beam(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, _, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler, model_runner = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=0, best_of=2, use_beam_search=True, ) - _do_sample(batch_size, input_tensor, sampler, model_runner, - sampling_params) + _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params) # no assertion here as I am not sure how to determine whether # the outputs are expected - in other words, this just tests # whether there are no exceptions in the sampler @@ -246,8 +236,7 @@ def test_sampler_mixed(seed: int, device: str): def test_sampling(model_runner: ModelRunner): sampling_metadata = model_runner._prepare_sample( seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - sampler_output = sampler(embedding=None, - hidden_states=input_tensor, + sampler_output = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) for i, (sequence_output, metadata) in enumerate( @@ -294,48 +283,6 @@ def 
test_sampling(model_runner: ModelRunner): del model_runner -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_logits_processors(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - input_tensor, _, sampler, model_runner = _prepare_test(batch_size) - - # This sample logits processor gives maximum score to the i-th token, - # where i is the length of the input sequence. - # We therefore expect the output token sequence to be [0, 1, 2, ...] - def pick_ith(token_ids, logits): - logits[len(token_ids)] = torch.finfo(logits.dtype).max - return logits - - seq_group_metadata_list = [] - prompt_lens = [] - for i in range(batch_size): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0, - logits_processors=[pick_ith]), - block_tables={0: [1]}, - )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) - sampler_output = sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) - for _, sequence_output in enumerate(sampler_output): - for idx, nth_output in enumerate(sequence_output.samples): - assert nth_output.output_token == idx - - del model_runner - - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) def test_sampler_top_k_top_p(seed: int, device: str): @@ -352,7 +299,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): size=(batch_size, vocab_size), device=input_tensor.device, dtype=input_tensor.dtype) - sampler = MockLogitsSampler(32000, fake_logits) + sampler = MockLogitsSampler(fake_logits) model_runner = ModelRunner(None, None, None, None, None) generation_model = 
GenerationMixin() @@ -391,9 +338,7 @@ def mock_sample(probs, *args, **kwargs): return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs] with patch("vllm.model_executor.layers.sampler._sample", mock_sample): - sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) + sampler(logits=fake_logits, sampling_metadata=sampling_metadata) hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone()) hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) assert torch.allclose(hf_probs, sample_probs, atol=1e-5) diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py new file mode 100644 index 0000000000000..fe321520114f7 --- /dev/null +++ b/tests/test_logits_processor.py @@ -0,0 +1,94 @@ +import random +from typing import Tuple +from unittest.mock import patch + +import pytest +import torch + +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.utils import set_random_seed +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.worker.model_runner import ModelRunner + + +class MockLogitsProcessor(LogitsProcessor): + + def __init__(self, vocab_size: int, scale: float, + fake_logits: torch.Tensor): + super().__init__(vocab_size=vocab_size, scale=scale) + self.fake_logits = fake_logits.clone() + + def forward(self, *args, **kwargs): + with patch( + "vllm.model_executor.layers.logits_processor._prune_hidden_states", + lambda x, y: x + ), patch( + "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits", + lambda *args, **kwargs: self.fake_logits): + return super().forward(*args, **kwargs) + + +def _prepare_test( + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]: + vocab_size = 32000 + input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) + fake_logits = torch.full((batch_size, vocab_size), + 1e-2, + dtype=input_tensor.dtype) + logits_processor 
= MockLogitsProcessor(32000, 0.5, fake_logits) + model_runner = ModelRunner(None, None, None, None, None) + return input_tensor, fake_logits, logits_processor, model_runner + + +RANDOM_SEEDS = list(range(128)) +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_logits_processors(seed: int, device: str): + set_random_seed(seed) + torch.set_default_device(device) + batch_size = random.randint(1, 256) + input_tensor, fake_logits, logits_processor, model_runner = _prepare_test( + batch_size) + + # This sample logits processor gives infinite score to the i-th token, + # where i is the length of the input sequence. + # We therefore expect the output token sequence to be [0, 1, 2, ...] + def pick_ith(token_ids, logits): + logits[len(token_ids)] = float("inf") + return logits + + seq_group_metadata_list = [] + prompt_lens = [] + for i in range(batch_size): + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: SequenceData([1, 2, 3])}, + sampling_params=SamplingParams(temperature=0, + logits_processors=[pick_ith]), + block_tables={0: [1]}, + )) + prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + + sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens) + logits_processor_output = logits_processor( + embedding=None, + hidden_states=input_tensor, + sampling_metadata=sampling_metadata) + + assert torch.isinf(logits_processor_output[:, 0]).all() + + fake_logits *= logits_processor.scale + assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1], + 1e-4) + + del model_runner diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 99e6cdeee6364..f6cd1390d4bce 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -10,7 +10,6 @@ from vllm.config import 
LoRAConfig from vllm.lora.punica import add_lora, add_lora_slice, bgmv -from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, @@ -20,6 +19,7 @@ RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( @@ -783,11 +783,11 @@ def weight(self): return self.base_layer.weight -class SamplerWithLoRA(BaseLayerWithLoRA): +class LogitsProcessorWithLoRA(BaseLayerWithLoRA): def __init__( self, - base_layer: Sampler, + base_layer: LogitsProcessor, hidden_size: int, dtype: torch.dtype, device: torch.device, @@ -806,6 +806,10 @@ def logits_as_hidden_states(self): def vocab_size(self): return self.base_layer.vocab_size + @property + def scale(self): + return self.base_layer.scale + @property def org_vocab_size(self): return self.base_layer.org_vocab_size @@ -968,14 +972,14 @@ def from_layer( return layer -def from_layer_sampler( - layer: Sampler, +def from_layer_logits_processor( + layer: LogitsProcessor, lm_head: ParallelLMHead, max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, -) -> SamplerWithLoRA: - ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype, - lm_head.weight.device) +) -> LogitsProcessorWithLoRA: + ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, + lm_head.weight.dtype, lm_head.weight.device) ret.create_lora_weights(max_loras, lora_config, model_config) return ret diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 6fe07b69b3203..d1bac7617e1d4 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -14,7 +14,7 @@ from vllm.utils import LRUCache, in_wsl from vllm.lora.layers import (BaseLayerWithLoRA, 
LoRAMapping, from_layer, - from_layer_sampler) + from_layer_logits_processor) from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule @@ -421,11 +421,14 @@ def _create_lora_modules(self): self.model.config)) # (yard1): TODO make this more robust if "lm_head" in module_name: - sampler_module = self.model.get_submodule("sampler") + logits_processor_module = self.model.get_submodule( + "logits_processor") new_module = replace_submodule( - self.model, "sampler", - from_layer_sampler(sampler_module, module, self.lora_slots, - self.lora_config, self.model.config)) + self.model, "logits_processor", + from_layer_logits_processor(logits_processor_module, + module, self.lora_slots, + self.lora_config, + self.model.config)) self.register_module(module_name, new_module) self._register_packed_modules(module_name) new_module.set_mapping(self.base_indices, self.sampler_indices, diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py new file mode 100644 index 0000000000000..baa113c342c28 --- /dev/null +++ b/vllm/model_executor/layers/logits_processor.py @@ -0,0 +1,106 @@ +"""A layer that computes logits from hidden_states.""" +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.utils import is_neuron + +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_gather) +from vllm.model_executor.sampling_metadata import SamplingMetadata + + +class LogitsProcessor(nn.Module): + """Process logits and apply logits processors from sampling metadata. + + This layer does the following: + 1. Gather logits from model hidden_states. + 2. Scale logits if needed. + 3. Apply logits processors (if any). + """ + + def __init__(self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: Optional[float] = 1.0) -> None: + """ + Args: + scale: A scaling factor to apply to the logits. 
+ """ + super().__init__() + self.scale = scale + self.vocab_size = vocab_size + # Transformers-neuronx generate outputs as logits directly. + self.logits_as_hidden_states = is_neuron() + # original vocabulary size (without LoRA). + self.org_vocab_size = org_vocab_size or vocab_size + + def forward( + self, + embedding: torch.Tensor, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + embedding_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.logits_as_hidden_states: + logits = hidden_states + else: + hidden_states = _prune_hidden_states(hidden_states, + sampling_metadata) + + # Get the logits for the next tokens. + logits = self._get_logits(hidden_states, embedding, embedding_bias) + + if logits is not None: + logits *= self.scale + + # Apply logits processors (if any). + logits = _apply_logits_processors(logits, sampling_metadata) + + return logits + + def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_gather(logits) + # Remove paddings in vocab (if any). 
+ if logits is not None: + logits = logits[:, :self.org_vocab_size] + return logits + + +def _prune_hidden_states( + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + return hidden_states.index_select(0, + sampling_metadata.selected_token_indices) + + +def _apply_logits_processors( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + logits_row_idx = 0 + found_logits_processors = False + for seq_ids, sampling_params in sampling_metadata.seq_groups: + logits_processors = sampling_params.logits_processors + if logits_processors: + found_logits_processors = True + for seq_id in seq_ids: + logits_row = logits[logits_row_idx] + token_ids = sampling_metadata.seq_data[seq_id].output_token_ids + for logits_processor in logits_processors: + logits_row = logits_processor(token_ids, logits_row) + logits[logits_row_idx] = logits_row + logits_row_idx += 1 + else: + logits_row_idx += len(seq_ids) + if found_logits_processors: + assert logits_row_idx == logits.shape[0] + return logits diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index ac8336ca0f9ad..63e494586efb5 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -4,8 +4,6 @@ import torch import torch.nn as nn -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_gather) from vllm.model_executor.sampling_metadata import (SamplingMetadata, SamplingTensors) from vllm.sampling_params import SamplingParams, SamplingType @@ -13,7 +11,6 @@ SamplerOutput, SequenceData, SequenceGroupOutput, SequenceOutput) from vllm.model_executor.layers.ops.sample import (sample as sample_triton) -from vllm.utils import is_neuron class Sampler(nn.Module): @@ -31,58 +28,14 @@ class Sampler(nn.Module): parameters (e.g., sampling method, temperature, top-p, top-k, etc.). 
""" - def __init__(self, - vocab_size: int, - org_vocab_size: Optional[int] = None) -> None: - super().__init__() - self.vocab_size = vocab_size - # Transformers-neuronx generate outputs as logits directly. - self.logits_as_hidden_states = is_neuron() - # original vocabulary size (without LoRA). - self.org_vocab_size = org_vocab_size or vocab_size - - def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, - embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: - # Get the logits for the next tokens. - logits = torch.matmul(hidden_states, embedding.t()) - if embedding_bias is not None: - logits += embedding_bias - logits = tensor_model_parallel_gather(logits) - # Remove paddings in vocab (if any). - if logits is not None: - logits = logits[:, :self.org_vocab_size] - return logits - def forward( self, - embedding: torch.Tensor, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, - embedding_bias: Optional[torch.Tensor] = None, ) -> Optional[SamplerOutput]: - # Get the hidden states that we use for sampling. - if self.logits_as_hidden_states: - logits = hidden_states - else: - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) - - # Get the logits for the next tokens. - logits = self._get_logits(hidden_states, embedding, embedding_bias) - - # Only perform sampling in the driver worker. - # Note: `_get_logits` is still distributed across TP workers because - # the `embedding` weight is distributed across TP workers. - # TODO(zhuohan): Change the get_logits part to a separate stage. - if not sampling_metadata.perform_sampling: - return None - assert logits is not None _, vocab_size = logits.shape - # Apply logits processors (if any). - logits = _apply_logits_processors(logits, sampling_metadata) - # Prepare sampling tensors with pinned memory to avoid blocking. 
(sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) = SamplingTensors.from_sampling_metadata( @@ -124,14 +77,6 @@ def forward( prompt_logprobs, sample_logprobs) -def _prune_hidden_states( - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - return hidden_states.index_select(0, - sampling_metadata.selected_token_indices) - - def _get_bin_counts_and_mask( tokens: torch.Tensor, vocab_size: int, @@ -149,30 +94,6 @@ def _get_bin_counts_and_mask( return bin_counts, mask -def _apply_logits_processors( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - logits_row_idx = 0 - found_logits_processors = False - for seq_ids, sampling_params in sampling_metadata.seq_groups: - logits_processors = sampling_params.logits_processors - if logits_processors: - found_logits_processors = True - for seq_id in seq_ids: - logits_row = logits[logits_row_idx] - token_ids = sampling_metadata.seq_data[seq_id].output_token_ids - for logits_processor in logits_processors: - logits_row = logits_processor(token_ids, logits_row) - logits[logits_row_idx] = logits_row - logits_row_idx += 1 - else: - logits_row_idx += len(seq_ids) - if found_logits_processors: - assert logits_row_idx == logits.shape[0] - return logits - - def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, output_tokens_tensor: torch.Tensor, presence_penalties: torch.Tensor, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index cbf472750e294..968b9ebba87b2 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -34,6 +34,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, 
ParallelLMHead) @@ -295,7 +296,8 @@ def __init__(self, self.linear_method = linear_method self.model = BaiChuanModel(config, position_embedding, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -308,13 +310,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 0548b2b140b1b..851c475206661 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -30,6 +30,7 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -273,7 +274,8 @@ def __init__( self.linear_method = linear_method self.transformer = BloomModel(config, linear_method) self.lm_head_weight = self.transformer.word_embeddings.weight - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -286,13 +288,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: 
SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 1c5dcfacaff2b..15e7de03b61f1 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -17,6 +17,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -332,7 +333,8 @@ def __init__( self.linear_method = linear_method self.transformer = ChatGLMModel(config, linear_method) self.lm_head_weight = self.transformer.output_layer.weight - self.sampler = Sampler(config.padded_vocab_size) + self.logits_processor = LogitsProcessor(config.padded_vocab_size) + self.sampler = Sampler() def forward( self, @@ -345,13 +347,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def 
load_weights(self, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 13c080cb02774..eff93e706f5dc 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -38,6 +38,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -372,7 +373,8 @@ def __init__( self.linear_method = linear_method self.model = DeepseekModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -385,13 +387,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: Optional[torch.Tensor], + logits: Optional[torch.Tensor], sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 3c148be5b10f4..7626dbe62293f 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -34,6 +34,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from 
vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -373,7 +374,8 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -390,13 +392,18 @@ def forward( ) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 386a36cf492d6..fd3dbe798cd8e 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -30,6 +30,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -281,7 +282,8 @@ def __init__( self.config = config self.linear_method = linear_method self.model = GemmaModel(config, linear_method) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -295,13 +297,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: 
SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.model.embed_tokens.weight, + hidden_states, sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.model.embed_tokens.weight, - hidden_states, sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 3f7b21e5a4133..263727cac19ff 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -30,6 +30,7 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -216,7 +217,8 @@ def __init__( self.linear_method = linear_method self.transformer = GPT2Model(config, linear_method) self.lm_head_weight = self.transformer.wte.weight - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -229,12 +231,17 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/gpt_bigcode.py 
b/vllm/model_executor/models/gpt_bigcode.py index 5c30d47d93e36..65caabae60daa 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -31,6 +31,7 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -237,7 +238,8 @@ def __init__( self.linear_method = linear_method self.transformer = GPTBigCodeModel(config, linear_method) self.lm_head_weight = self.transformer.wte.weight - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -250,13 +252,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 93dce7b67a7a5..c956a12f3e46e 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -30,6 +30,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -224,7 
+225,8 @@ def __init__( config.n_embd, bias=True, ) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -237,13 +239,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata, self.lm_head.bias) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata, self.lm_head.bias) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 98107350e60b9..db2173936e7d9 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -30,6 +30,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -238,7 +239,8 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -251,13 +253,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.embed_out.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: 
torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.embed_out.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 7b2215ef4bda5..93026fc01f0f0 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -14,6 +14,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -250,7 +251,8 @@ def __init__( self.linear_method = linear_method self.model = InternLM2Model(config, linear_method) self.output = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -263,13 +265,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.output.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.output.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 4c163dfdab537..757b75129845c 100644 --- a/vllm/model_executor/models/llama.py 
+++ b/vllm/model_executor/models/llama.py @@ -37,6 +37,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) @@ -325,7 +326,11 @@ def __init__( # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) - self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = Sampler() def forward( self, @@ -338,13 +343,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index d47834e519697..68a3a298444ae 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -37,6 +37,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, 
ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) @@ -369,7 +370,9 @@ def __init__( # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) - self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -382,13 +385,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: Optional[torch.Tensor], + logits: Optional[torch.Tensor], sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 25c7f1978c0dc..b4dfc439d50e9 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -39,6 +39,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -344,7 +345,8 @@ def __init__( self.linear_method = linear_method self.model = MixtralModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -357,13 +359,18 @@ def forward( input_metadata) return 
hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: Optional[torch.Tensor], + logits: Optional[torch.Tensor], sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 16ecac3d0529a..7a2568817858c 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -13,6 +13,7 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -259,7 +260,8 @@ def __init__( self.transformer = MPTModel(config, linear_method) self.lm_head_weight = self.transformer.wte.weight - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -272,13 +274,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git 
a/vllm/model_executor/models/neuron/llama.py b/vllm/model_executor/models/neuron/llama.py index e2856da99d9b1..32c43c4944fac 100644 --- a/vllm/model_executor/models/neuron/llama.py +++ b/vllm/model_executor/models/neuron/llama.py @@ -7,6 +7,7 @@ from transformers import LlamaConfig from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput @@ -25,7 +26,8 @@ def __init__( self.config = config self.linear_method = linear_method self.model = None - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -45,13 +47,18 @@ def forward( start_ids=seq_ids.flatten()) return logits + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.model.chkpt_model.lm_head, + hidden_states, sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.model.chkpt_model.lm_head, - hidden_states, sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/neuron/mistral.py b/vllm/model_executor/models/neuron/mistral.py index a302cce30abab..24fc0fa0aacab 100755 --- a/vllm/model_executor/models/neuron/mistral.py +++ b/vllm/model_executor/models/neuron/mistral.py @@ -6,6 +6,7 @@ from transformers import MistralConfig from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from 
vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput @@ -26,7 +27,8 @@ def __init__( self.linear_method = linear_method self.model = None self.lm_head = None - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -48,13 +50,18 @@ def forward( start_ids=seq_ids) return logits + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.model.chkpt_model.lm_head, + hidden_states, sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.model.chkpt_model.lm_head, - hidden_states, sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 2b0a420e82faf..19f2be6da8ed3 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -51,6 +51,7 @@ RowParallelLinear, ) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -336,7 +337,8 @@ def __init__(self, self.lm_head_weight = (self.model.transformer.wte.weight if config.weight_tying else self.model.transformer.ff_out.weight) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -353,13 +355,18 @@ def forward( ) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> 
torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights( diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 782f43ce265bd..a12f63b58f52b 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -31,6 +31,7 @@ QKVParallelLinear, ReplicatedLinear, RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -292,7 +293,8 @@ def __init__( self.linear_method = linear_method self.model = OPTModel(config, linear_method) self.lm_head_weight = self.model.decoder.embed_tokens.weight - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -305,13 +307,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 
6039b1cdc3534..86428e320e0f7 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -18,6 +18,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -256,7 +257,8 @@ def __init__( self.linear_method = linear_method self.model = OrionModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -269,13 +271,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 039dc7a9b7675..ef70c823dc905 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -49,6 +49,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -240,7 
+241,8 @@ def __init__(self, self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, bias=True) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -254,14 +256,18 @@ def forward( return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata, self.lm_head.bias) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - head = self.lm_head - next_tokens = self.sampler(head.weight, hidden_states, - sampling_metadata, head.bias) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index d4d5a4e8bb9a5..61ac2c6c605c6 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -19,6 +19,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -230,7 +231,8 @@ def __init__( self.linear_method = linear_method self.transformer = QWenModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -243,13 +245,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + 
logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 12e0feddcb7f1..6698f01b7c701 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -37,6 +37,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -300,11 +301,15 @@ def __init__( self.linear_method = linear_method self.model = Qwen2Model(config, linear_method) - if not config.tie_word_embeddings: + if config.tie_word_embeddings: + self.lm_head_weight = self.model.embed_tokens.weight + else: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.lm_head_weight = self.lm_head.weight - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -317,17 +322,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - if self.config.tie_word_embeddings: - lm_head_weight 
= self.model.embed_tokens.weight - else: - lm_head_weight = self.lm_head.weight - next_tokens = self.sampler(lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index c66f327beee7a..7624ca89ee670 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -33,6 +33,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -238,7 +239,8 @@ def __init__( self.linear_method = linear_method self.model = StableLMEpochModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -251,13 +253,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index cfbb1bdb7909e..e418951a633ab 100644 --- a/vllm/model_executor/models/starcoder2.py +++ 
b/vllm/model_executor/models/starcoder2.py @@ -32,6 +32,7 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) @@ -254,7 +255,9 @@ def __init__(self, padding_size=DEFAULT_VOCAB_PADDING_SIZE, ) self.lm_head_weight = self.lm_head.weight - self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -267,13 +270,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: Optional[torch.Tensor], + logits: Optional[torch.Tensor], sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index cfccbbb20adc5..347b9380f1113 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -613,9 +613,16 @@ def execute_model( input_metadata=input_metadata, ) + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Only perform sampling in the driver worker. + if not sampling_metadata.perform_sampling: + return None + # Sample the next token. output = self.model.sample( - hidden_states=hidden_states, + logits=logits, sampling_metadata=sampling_metadata, ) return output