From b3fa8cb8d7254d494c00f932b15ebbcf9b78fc1a Mon Sep 17 00:00:00 2001
From: roy
Date: Tue, 12 Mar 2024 21:29:04 +0800
Subject: [PATCH 01/19] refactor

---
 vllm/model_executor/layers/logit_processor.py | 110 ++++++++++++++++++
 vllm/model_executor/layers/sampler.py         |  82 +-------------
 vllm/model_executor/models/llama.py           |  15 ++-
 3 files changed, 122 insertions(+), 85 deletions(-)
 create mode 100644 vllm/model_executor/layers/logit_processor.py

diff --git a/vllm/model_executor/layers/logit_processor.py b/vllm/model_executor/layers/logit_processor.py
new file mode 100644
index 0000000000000..32327b0dbeb88
--- /dev/null
+++ b/vllm/model_executor/layers/logit_processor.py
@@ -0,0 +1,110 @@
+"""A layer that computes logits from hidden_states."""
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from vllm.utils import is_neuron
+
+from vllm.model_executor.parallel_utils.communication_op import (
+    tensor_model_parallel_gather)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+
+class LogitProcessor(nn.Module):
+    """Process logits and apply logits processors from sampling metadata.
+
+    This layer does the following:
+    1. Gather logits from model hidden_states.
+    2. Scale logits if needed.
+    3. Apply logits processors (if any).
+    """
+
+    def __init__(self,
+                 vocab_size: int,
+                 org_vocab_size: Optional[int] = None,
+                 scale: Optional[float] = 1.0) -> None:
+        """
+        Args:
+            scale: A scaling factor to apply to the logits.
+        """
+        super().__init__()
+        self.scale = scale
+        # Transformers-neuronx generates outputs as logits directly.
+        self.logits_as_hidden_states = is_neuron()
+        # original vocabulary size (without LoRA).
+        self.org_vocab_size = org_vocab_size or vocab_size
+
+    def forward(
+        self,
+        embedding: torch.Tensor,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if self.logits_as_hidden_states:
+            logits = hidden_states
+        else:
+            hidden_states = _prune_hidden_states(hidden_states,
+                                                 sampling_metadata)
+
+            # Get the logits for the next tokens.
+            logits = self._get_logits(hidden_states, embedding, embedding_bias)
+
+        logits *= self.scale
+
+        # Only perform sampling in the driver worker.
+        # Note: `_get_logits` is still distributed across TP workers because
+        # the `embedding` weight is distributed across TP workers.
+        if not sampling_metadata.perform_sampling:
+            return None
+
+        # Apply logits processors (if any).
+        logits = _apply_logits_processors(logits, sampling_metadata)
+
+        return logits
+
+    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
+                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+        # Get the logits for the next tokens.
+        logits = torch.matmul(hidden_states, embedding.t())
+        if embedding_bias is not None:
+            logits += embedding_bias
+        logits = tensor_model_parallel_gather(logits)
+        # Remove paddings in vocab (if any).
+        if logits is not None:
+            logits = logits[:, :self.org_vocab_size]
+        return logits
+
+
+def _prune_hidden_states(
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+    return hidden_states.index_select(0,
+                                      sampling_metadata.selected_token_indices)
+
+
+def _apply_logits_processors(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    logits_row_idx = 0
+    found_logits_processors = False
+    for seq_ids, sampling_params in sampling_metadata.seq_groups:
+        logits_processors = sampling_params.logits_processors
+        if logits_processors:
+            found_logits_processors = True
+            for seq_id in seq_ids:
+                logits_row = logits[logits_row_idx]
+                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
+                for logits_processor in logits_processors:
+                    logits_row = logits_processor(token_ids, logits_row)
+                logits[logits_row_idx] = logits_row
+                logits_row_idx += 1
+        else:
+            logits_row_idx += len(seq_ids)
+    if found_logits_processors:
+        assert logits_row_idx == logits.shape[0]
+    return logits
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 4377b845df628..8b63c82511fa4 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -4,15 +4,12 @@
 import torch
 import torch.nn as nn
 
-from vllm.model_executor.parallel_utils.communication_op import (
-    tensor_model_parallel_gather)
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                    SamplingTensors)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                            SamplerOutput, SequenceData, SequenceGroupOutput,
                            SequenceOutput)
-from vllm.utils import is_neuron
 
 
 class Sampler(nn.Module):
@@ -30,58 +27,14 @@ class Sampler(nn.Module):
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
     """
 
-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-        # Transformers-neuronx generate outputs as logits directly.
-        self.logits_as_hidden_states = is_neuron()
-        # original vocabulary size (without LoRA).
-        self.org_vocab_size = org_vocab_size or vocab_size
-
-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-        # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
-        return logits
-
     def forward(
         self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-        embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        if self.logits_as_hidden_states:
-            logits = hidden_states
-        else:
-            hidden_states = _prune_hidden_states(hidden_states,
-                                                 sampling_metadata)
-
-        # Get the logits for the next tokens.
-        logits = self._get_logits(hidden_states, embedding, embedding_bias)
-
-        # Only perform sampling in the driver worker.
-        # Note: `_get_logits` is still distributed across TP workers because
-        # the `embedding` weight is distributed across TP workers.
-        # TODO(zhuohan): Change the get_logits part to a separate stage.
-        if not sampling_metadata.perform_sampling:
-            return None
-
         assert logits is not None
         _, vocab_size = logits.shape
 
-        # Apply logits processors (if any).
-        logits = _apply_logits_processors(logits, sampling_metadata)
-
         # Prepare sampling tensors with pinned memory to avoid blocking.
         (sampling_tensors, do_penalties, do_top_p_top_k,
          do_min_p) = SamplingTensors.from_sampling_metadata(
@@ -122,15 +75,6 @@ def forward(
                                   prompt_logprobs, sample_logprobs)
 
 
-def _prune_hidden_states(
-    hidden_states: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-    return hidden_states.index_select(0,
-                                      sampling_metadata.selected_token_indices)
-
-
 def _get_bin_counts_and_mask(
     tokens: torch.Tensor,
     vocab_size: int,
@@ -148,30 +92,6 @@ def _get_bin_counts_and_mask(
     return bin_counts, mask
 
 
-def _apply_logits_processors(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    logits_row_idx = 0
-    found_logits_processors = False
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
-            for seq_id in seq_ids:
-                logits_row = logits[logits_row_idx]
-                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
-                for logits_processor in logits_processors:
-                    logits_row = logits_processor(token_ids, logits_row)
-                logits[logits_row_idx] = logits_row
-                logits_row_idx += 1
-        else:
-            logits_row_idx += len(seq_ids)
-    if found_logits_processors:
-        assert logits_row_idx == logits.shape[0]
-    return logits
-
-
 def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                      output_tokens_tensor: torch.Tensor,
                      presence_penalties: torch.Tensor,
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 4c163dfdab537..48aacb42f87f4 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -37,7 +37,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.sampler import (Sampler, LogitProcessor)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
 from vllm.model_executor.parallel_utils.parallel_state import (
@@ -325,7 +325,11 @@ def __init__(
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
-        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
+
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logit_processor = LogitProcessor(self.unpadded_vocab_size,
+                                              config.vocab_size, logit_scale)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -343,8 +347,11 @@ def sample(
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   sampling_metadata)
+
+        logits = self.logit_processor(self.lm_head.weight, hidden_states,
+                                      sampling_metadata)
+
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,
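Patch 01 moves next-token logit computation out of `Sampler` into the new `LogitProcessor` layer, so sampling no longer touches the embedding weights. The core math the layer wraps is a projection onto the transposed embedding plus an optional scale; a minimal PyTorch sketch with toy shapes (tensor-parallel gathering, vocab-padding removal, and the Neuron path are deliberately omitted here):

    import torch

    def compute_logits_sketch(hidden_states: torch.Tensor,
                              embedding: torch.Tensor,
                              scale: float = 1.0) -> torch.Tensor:
        # hidden_states: [num_tokens, hidden]; embedding: [vocab, hidden].
        logits = torch.matmul(hidden_states, embedding.t())
        # Scaling is applied before any per-request logits processors run.
        return logits * scale

    hidden = torch.randn(4, 8)
    weight = torch.randn(16, 8)
    assert compute_logits_sketch(hidden, weight, scale=0.5).shape == (4, 16)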
From 2e2b65ca80ff33cfd3949400138a4f3e7fe17077 Mon Sep 17 00:00:00 2001
From: roy
Date: Tue, 12 Mar 2024 21:44:11 +0800
Subject: [PATCH 02/19] add test

---
 tests/samplers/test_logit_processor.py | 88 ++++++++++++++++++++++++++
 tests/samplers/test_sampler.py         | 55 ++--------------
 vllm/model_executor/models/llama.py    |  3 +-
 3 files changed, 94 insertions(+), 52 deletions(-)
 create mode 100644 tests/samplers/test_logit_processor.py

diff --git a/tests/samplers/test_logit_processor.py b/tests/samplers/test_logit_processor.py
new file mode 100644
index 0000000000000..ddef21b2c0c1f
--- /dev/null
+++ b/tests/samplers/test_logit_processor.py
@@ -0,0 +1,88 @@
+import random
+from typing import Tuple
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.logit_processor import LogitProcessor
+from vllm.model_executor.utils import set_random_seed
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.worker.model_runner import ModelRunner
+
+
+class MockLogitsProcessor(LogitProcessor):
+
+    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
+        super().__init__(vocab_size=vocab_size)
+        self.fake_logits = fake_logits
+
+    def forward(self, *args, **kwargs):
+        with patch(
+                "vllm.model_executor.layers.logit_processor._prune_hidden_states",
+                lambda x, y: x
+        ), patch(
+                "vllm.model_executor.layers.logit_processor.LogitProcessor._get_logits",
+                lambda *args, **kwargs: self.fake_logits):
+            return super().forward(*args, **kwargs)
+
+
+def _prepare_test(
+    batch_size: int
+) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]:
+    vocab_size = 32000
+    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
+    fake_logits = torch.full((batch_size, vocab_size),
+                             1e-2,
+                             dtype=input_tensor.dtype)
+    logit_processor = MockLogitsProcessor(32000, fake_logits)
+    model_runner = ModelRunner(None, None, None, None, None)
+    return input_tensor, fake_logits, logit_processor, model_runner
+
+
+RANDOM_SEEDS = list(range(128))
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_logits_processors(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, logit_processor, model_runner = _prepare_test(
+        batch_size)
+
+    # This sample logits processor gives infinite score to the i-th token,
+    # where i is the length of the input sequence.
+    # We therefore expect the output token sequence to be [0, 1, 2, ...]
+    def pick_ith(token_ids, logits):
+        logits[len(token_ids)] = float("inf")
+        return logits
+
+    seq_group_metadata_list = []
+    prompt_lens = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(temperature=0,
+                                               logits_processors=[pick_ith]),
+                block_tables={0: [1]},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
+    logit_processor_output = logit_processor(
+        embedding=None,
+        hidden_states=input_tensor,
+        sampling_metadata=sampling_metadata)
+    assert logit_processor_output == fake_logits
+
+    del model_runner
diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 1bc8703d1a8e0..66afe62754dda 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -15,17 +15,12 @@
 
 class MockLogitsSampler(Sampler):
 
-    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
-        super().__init__(vocab_size=vocab_size)
+    def __init__(self, fake_logits: torch.Tensor):
+        super().__init__()
         self.fake_logits = fake_logits
 
     def forward(self, *args, **kwargs):
-        with patch(
-                "vllm.model_executor.layers.sampler._prune_hidden_states",
-                lambda x, y: x), patch(
-                    "vllm.model_executor.layers.sampler.Sampler._get_logits",
-                    lambda *args, **kwargs: self.fake_logits):
-            return super().forward(*args, **kwargs)
+        return super().forward(*args, **kwargs)
 
 
 def _prepare_test(
@@ -36,7 +31,7 @@ def _prepare_test(
     fake_logits = torch.full((batch_size, vocab_size),
                              1e-2,
                              dtype=input_tensor.dtype)
-    sampler = MockLogitsSampler(32000, fake_logits)
+    sampler = MockLogitsSampler(fake_logits)
     model_runner = ModelRunner(None, None, None, None, None)
     return input_tensor, fake_logits, sampler, model_runner
 
@@ -294,48 +289,6 @@ def test_sampling(model_runner: ModelRunner):
     del model_runner
 
 
-@pytest.mark.parametrize("seed", RANDOM_SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_sampler_logits_processors(seed: int, device: str):
-    set_random_seed(seed)
-    torch.set_default_device(device)
-    batch_size = random.randint(1, 256)
-    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
-
-    # This sample logits processor gives infinite score to the i-th token,
-    # where i is the length of the input sequence.
-    # We therefore expect the output token sequence to be [0, 1, 2, ...]
-    def pick_ith(token_ids, logits):
-        logits[len(token_ids)] = float("inf")
-        return logits
-
-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(temperature=0,
-                                               logits_processors=[pick_ith]),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
-    for _, sequence_output in enumerate(sampler_output):
-        for idx, nth_output in enumerate(sequence_output.samples):
-            assert nth_output.output_token == idx
-
-    del model_runner
-
-
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_top_k_top_p(seed: int, device: str):
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 48aacb42f87f4..ee502c1064000 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -37,7 +37,8 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import (Sampler, LogitProcessor)
+from vllm.model_executor.layers.logit_processor import LogitProcessor
+from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
 from vllm.model_executor.parallel_utils.parallel_state import (
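The new test exercises `SamplingParams.logits_processors`, the per-request hook that receives the token ids generated so far for a sequence together with its logits row and returns the (possibly modified) row. A hedged end-user sketch of the same hook; the model name is illustrative and any model served by vLLM would do:

    from vllm import LLM, SamplingParams

    def ban_token_zero(token_ids, logits):
        # Invoked once per sequence per decoding step.
        logits[0] = float("-inf")
        return logits

    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(temperature=0.8,
                            logits_processors=[ban_token_zero])
    outputs = llm.generate(["Hello, my name is"], params)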
From 4d4c31f2ad566363af5a2455a205a0a360151320 Mon Sep 17 00:00:00 2001
From: roy
Date: Wed, 13 Mar 2024 19:57:35 +0800
Subject: [PATCH 03/19] fix

---
 tests/samplers/test_sampler.py                | 40 ++++++++-----------
 tests/{samplers => }/test_logit_processor.py  | 16 +++++---
 vllm/model_executor/layers/logit_processor.py | 11 ++---
 vllm/model_executor/models/llama.py           | 22 +++++-----
 vllm/worker/model_runner.py                   |  8 +++-
 5 files changed, 45 insertions(+), 52 deletions(-)
 rename tests/{samplers => }/test_logit_processor.py (86%)

diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 66afe62754dda..723a0f403c021 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -65,9 +65,7 @@ def _do_sample(
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens,
                                                      subquery_lens=prompt_lens)
-    return sampler(embedding=None,
-                   hidden_states=input_tensor,
-                   sampling_metadata=sampling_metadata)
+    return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
 
 
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
@@ -80,8 +78,8 @@ def test_sampler_all_greedy(seed: int, device: str):
         batch_size)
 
     sampling_params = SamplingParams(temperature=0)
-    sampler_output = _do_sample(batch_size, input_tensor, sampler,
-                                model_runner, sampling_params)
+    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
+                                sampling_params)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -106,8 +104,8 @@ def test_sampler_all_random(seed: int, device: str):
         temperature=1.0,
         n=random.randint(1, 10),
     )
-    sampler_output = _do_sample(batch_size, input_tensor, sampler,
-                                model_runner, sampling_params)
+    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
+                                sampling_params)
 
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -122,8 +120,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
+    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
 
     for i in range(batch_size):
         fake_logits[i, i] = 1e2
@@ -133,8 +130,8 @@ def test_sampler_all_random_seed(seed: int, device: str):
         n=random.randint(1, 10),
         seed=random.randint(0, 10000),
    )
-    sampler_output = _do_sample(batch_size, input_tensor, sampler,
-                                model_runner, sampling_params)
+    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
+                                sampling_params)
 
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -149,18 +146,17 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
+    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
 
     sampling_params = SamplingParams(
         temperature=1.0,
         n=random.randint(1, 10),
         seed=random.randint(0, 10000),
     )
-    first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+    first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                       model_runner, sampling_params)
 
-    second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+    second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                        model_runner, sampling_params)
 
     assert first_sampler_output == second_sampler_output
@@ -174,15 +170,14 @@ def test_sampler_all_beam(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
+    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
 
     sampling_params = SamplingParams(
         temperature=0,
         best_of=2,
         use_beam_search=True,
     )
-    _do_sample(batch_size, input_tensor, sampler, model_runner,
-               sampling_params)
+    _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -241,8 +236,7 @@ def test_sampler_mixed(seed: int, device: str):
     def test_sampling(model_runner: ModelRunner):
         sampling_metadata = model_runner._prepare_sample(
             seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
-        sampler_output = sampler(embedding=None,
-                                 hidden_states=input_tensor,
+        sampler_output = sampler(logits=fake_logits,
                                  sampling_metadata=sampling_metadata)
 
         for i, (sequence_output, metadata) in enumerate(
@@ -305,7 +299,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
                               size=(batch_size, vocab_size),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
-    sampler = MockLogitsSampler(32000, fake_logits)
+    sampler = MockLogitsSampler(fake_logits)
     model_runner = ModelRunner(None, None, None, None, None)
 
     generation_model = GenerationMixin()
@@ -344,9 +338,7 @@ def mock_sample(probs, logprobs, sampling_metadata):
         return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
 
     with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
-        sampler(embedding=None,
-                hidden_states=input_tensor,
-                sampling_metadata=sampling_metadata)
+        sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
     hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
     hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
     assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
diff --git a/tests/samplers/test_logit_processor.py b/tests/test_logit_processor.py
similarity index 86%
rename from tests/samplers/test_logit_processor.py
rename to tests/test_logit_processor.py
index ddef21b2c0c1f..d1959af0dc7c3 100644
--- a/tests/samplers/test_logit_processor.py
+++ b/tests/test_logit_processor.py
@@ -13,9 +13,10 @@
 
 class MockLogitsProcessor(LogitProcessor):
 
-    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
-        super().__init__(vocab_size=vocab_size)
-        self.fake_logits = fake_logits
+    def __init__(self, vocab_size: int, scale: float,
+                 fake_logits: torch.Tensor):
+        super().__init__(vocab_size=vocab_size, scale=scale)
+        self.fake_logits = fake_logits.clone()
 
     def forward(self, *args, **kwargs):
         with patch(
@@ -35,7 +36,7 @@ def _prepare_test(
     fake_logits = torch.full((batch_size, vocab_size),
                              1e-2,
                              dtype=input_tensor.dtype)
-    logit_processor = MockLogitsProcessor(32000, fake_logits)
+    logit_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
     model_runner = ModelRunner(None, None, None, None, None)
     return input_tensor, fake_logits, logit_processor, model_runner
 
@@ -83,6 +84,11 @@ def pick_ith(token_ids, logits):
         embedding=None,
         hidden_states=input_tensor,
         sampling_metadata=sampling_metadata)
-    assert logit_processor_output == fake_logits
+
+    assert torch.isinf(logit_processor_output[:, 0]).all()
+
+    fake_logits *= logit_processor.scale
+    assert torch.allclose(logit_processor_output[:, 1], fake_logits[:, 1],
+                          1e-4)
 
     del model_runner
diff --git a/vllm/model_executor/layers/logit_processor.py b/vllm/model_executor/layers/logit_processor.py
index 32327b0dbeb88..69ff5a99a4373 100644
--- a/vllm/model_executor/layers/logit_processor.py
+++ b/vllm/model_executor/layers/logit_processor.py
@@ -51,16 +51,11 @@ def forward(
             # Get the logits for the next tokens.
             logits = self._get_logits(hidden_states, embedding, embedding_bias)
 
-        logits *= self.scale
+        if logits is not None:
+            logits *= self.scale
 
-        # Only perform sampling in the driver worker.
-        # Note: `_get_logits` is still distributed across TP workers because
-        # the `embedding` weight is distributed across TP workers.
-        if not sampling_metadata.perform_sampling:
-            return None
-
-        # Apply logits processors (if any).
-        logits = _apply_logits_processors(logits, sampling_metadata)
+            # Apply logits processors (if any).
+            logits = _apply_logits_processors(logits, sampling_metadata)
 
         return logits
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index ee502c1064000..41a16c6b6ba70 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -332,26 +332,22 @@ def __init__(
                                               config.vocab_size, logit_scale)
         self.sampler = Sampler()
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
+    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+                kv_caches: List[KVCache], input_metadata: InputMetadata,
+                sampling_metadata: SamplingMetadata) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata)
-        return hidden_states
+
+        logits = self.logit_processor(self.lm_head.weight, hidden_states,
+                                      sampling_metadata)
+
+        return logits
 
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-
-        logits = self.logit_processor(self.lm_head.weight, hidden_states,
-                                      sampling_metadata)
-
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 7eac576e3f0fe..129b346800b1c 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -587,16 +587,20 @@ def execute_model(
             model_executable = self.graph_runners[graph_batch_size]
         else:
             model_executable = self.model
-        hidden_states = model_executable(
+        logits = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
+            sampling_metadata=sampling_metadata,
         )
 
+        if not sampling_metadata.perform_sampling:
+            return None
+
         # Sample the next token.
         output = self.model.sample(
-            hidden_states=hidden_states,
+            logits=logits,
             sampling_metadata=sampling_metadata,
         )
         return output

From ab108139ea9ee9e0cb6b6ac33b1c14878a8ca385 Mon Sep 17 00:00:00 2001
From: roy
Date: Wed, 13 Mar 2024 20:14:28 +0800
Subject: [PATCH 04/19] fix test

---
 vllm/model_executor/models/opt.py | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 782f43ce265bd..10e3c9a3efcc1 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -31,6 +31,7 @@
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.logit_processor import LogitProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -292,26 +293,28 @@ def __init__(
         self.linear_method = linear_method
         self.model = OPTModel(config, linear_method)
         self.lm_head_weight = self.model.decoder.embed_tokens.weight
-        self.sampler = Sampler(config.vocab_size)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logit_processor = LogitProcessor(self.unpadded_vocab_size,
+                                              config.vocab_size, logit_scale)
+        self.sampler = Sampler()
+
+    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+                kv_caches: List[KVCache], input_metadata: InputMetadata,
+                sampling_metadata: SamplingMetadata) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata)
-        return hidden_states
+
+        logits = self.logit_processor(self.lm_head_weight, hidden_states,
+                                      sampling_metadata)
+
+        return logits
 
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

From 97be522f74baabd67060603300d9750bc0cc92e5 Mon Sep 17 00:00:00 2001
From: roy
Date: Wed, 13 Mar 2024 20:20:22 +0800
Subject: [PATCH 05/19] fix opt

---
 vllm/model_executor/models/opt.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 10e3c9a3efcc1..21f1988010667 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -293,9 +293,7 @@ def __init__(
         self.linear_method = linear_method
         self.model = OPTModel(config, linear_method)
         self.lm_head_weight = self.model.decoder.embed_tokens.weight
-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logit_processor = LogitProcessor(self.unpadded_vocab_size,
-                                              config.vocab_size, logit_scale)
+        self.logit_processor = LogitProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
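Patch 03 made the model `forward` return logits directly, which conflicts with CUDA graph capture: a captured graph replays a fixed dense forward, while the logit step depends on per-iteration `sampling_metadata`. Patch 06 below therefore splits execution into three calls. A condensed sketch of the driver-side flow it establishes, abbreviated from the `execute_model` hunk in that patch:

    # The dense forward, possibly replayed from a captured CUDA graph.
    hidden_states = model_executable(
        input_ids=input_tokens,
        positions=input_positions,
        kv_caches=kv_caches,
        input_metadata=input_metadata,
    )
    # Logits are computed eagerly, outside any captured graph.
    logits = self.model.compute_logits(hidden_states, sampling_metadata)
    # Non-driver tensor-parallel workers stop here; only the driver samples.
    if not sampling_metadata.perform_sampling:
        return None
    output = self.model.sample(logits=logits, sampling_metadata=sampling_metadata)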
From b3ddea341f3cf059812aee79bc6a6bfc4e8f447d Mon Sep 17 00:00:00 2001
From: roy
Date: Wed, 13 Mar 2024 20:42:26 +0800
Subject: [PATCH 06/19] fix cuda graph

---
 vllm/model_executor/models/llama.py | 13 ++++++++++---
 vllm/model_executor/models/opt.py   | 13 ++++++++++---
 vllm/worker/model_runner.py         |  6 ++++--
 3 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 41a16c6b6ba70..86c73cc7af35b 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -332,12 +332,19 @@ def __init__(
                                               config.vocab_size, logit_scale)
         self.sampler = Sampler()
 
-    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
-                kv_caches: List[KVCache], input_metadata: InputMetadata,
-                sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata)
+        return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
         logits = self.logit_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
 
         return logits
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 21f1988010667..3a8b7a2a9ebb3 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -296,12 +296,19 @@ def __init__(
         self.logit_processor = LogitProcessor(config.vocab_size)
         self.sampler = Sampler()
 
-    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
-                kv_caches: List[KVCache], input_metadata: InputMetadata,
-                sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata)
+        return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
         logits = self.logit_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
 
         return logits
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 129b346800b1c..0de58d1763361 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -587,14 +587,16 @@ def execute_model(
             model_executable = self.graph_runners[graph_batch_size]
         else:
             model_executable = self.model
-        logits = model_executable(
+        hidden_states = model_executable(
             input_ids=input_tokens,
             positions=input_positions,
             kv_caches=kv_caches,
             input_metadata=input_metadata,
-            sampling_metadata=sampling_metadata,
         )
 
+        # Compute the logits.
+        logits = self.model.compute_logits(hidden_states, sampling_metadata)
+
         if not sampling_metadata.perform_sampling:
             return None
 
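After patch 06, every model is expected to expose the same three-step interface, and the remaining patches mostly propagate names through it. An illustrative summary of that contract (this class does not exist in the tree; it only restates the shape of the methods the worker relies on):

    class CausalLMContract:

        def forward(self, input_ids, positions, kv_caches, input_metadata):
            """Return hidden states; safe to capture in a CUDA graph."""

        def compute_logits(self, hidden_states, sampling_metadata):
            """Project hidden states to vocabulary logits via the logits processor."""

        def sample(self, logits, sampling_metadata):
            """Run the Sampler, which now consumes logits only."""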
From 1ebe550a9d72694aea16566a0a6cd2f6b430108d Mon Sep 17 00:00:00 2001
From: roy
Date: Wed, 13 Mar 2024 21:17:03 +0800
Subject: [PATCH 07/19] rename

---
 .buildkite/test-pipeline.yaml                  |  3 +++
 ..._processor.py => test_logits_processor.py} | 20 +++++++++----------
 ...logit_processor.py => logits_processor.py} |  2 +-
 vllm/model_executor/models/llama.py            | 11 +++++-----
 vllm/model_executor/models/opt.py              |  9 ++++-----
 5 files changed, 23 insertions(+), 22 deletions(-)
 rename tests/{test_logit_processor.py => test_logits_processor.py} (80%)
 rename vllm/model_executor/layers/{logit_processor.py => logits_processor.py} (99%)

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 42a1eacb6de57..9c8fdb9482efd 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -49,6 +49,9 @@ steps:
 - label: Samplers Test
   command: pytest -v -s samplers --forked
 
+- label: LogitsProcessor Test
+  command: pytest -v -s test_logits_processor.py
+
 - label: Worker Test
   command: pytest -v -s worker
 
diff --git a/tests/test_logit_processor.py b/tests/test_logits_processor.py
similarity index 80%
rename from tests/test_logit_processor.py
rename to tests/test_logits_processor.py
index d1959af0dc7c3..68e0310625d08 100644
--- a/tests/test_logit_processor.py
+++ b/tests/test_logits_processor.py
@@ -5,7 +5,7 @@
 import pytest
 import torch
 
-from vllm.model_executor.layers.logit_processor import LogitProcessor
+from vllm.model_executor.layers.logits_processor import LogitProcessor
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.worker.model_runner import ModelRunner
@@ -20,10 +20,10 @@ def forward(self, *args, **kwargs):
         with patch(
-                "vllm.model_executor.layers.logit_processor._prune_hidden_states",
+                "vllm.model_executor.layers.logits_processor._prune_hidden_states",
                 lambda x, y: x
         ), patch(
-                "vllm.model_executor.layers.logit_processor.LogitProcessor._get_logits",
+                "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits",
                 lambda *args, **kwargs: self.fake_logits):
             return super().forward(*args, **kwargs)
 
@@ -36,9 +36,9 @@ def _prepare_test(
     fake_logits = torch.full((batch_size, vocab_size),
                              1e-2,
                              dtype=input_tensor.dtype)
-    logit_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
+    logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
     model_runner = ModelRunner(None, None, None, None, None)
-    return input_tensor, fake_logits, logit_processor, model_runner
+    return input_tensor, fake_logits, logits_processor, model_runner
 
@@ -53,7 +53,7 @@ def test_logits_processors(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, logit_processor, model_runner = _prepare_test(
+    input_tensor, fake_logits, logits_processor, model_runner = _prepare_test(
         batch_size)
 
@@ -80,15 +80,15 @@ def pick_ith(token_ids, logits):
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens,
                                                      subquery_lens=prompt_lens)
-    logit_processor_output = logit_processor(
+    logits_processor_output = logits_processor(
         embedding=None,
         hidden_states=input_tensor,
         sampling_metadata=sampling_metadata)
 
-    assert torch.isinf(logit_processor_output[:, 0]).all()
+    assert torch.isinf(logits_processor_output[:, 0]).all()
 
-    fake_logits *= logit_processor.scale
-    assert torch.allclose(logit_processor_output[:, 1], fake_logits[:, 1],
+    fake_logits *= logits_processor.scale
+    assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1],
                           1e-4)
 
     del model_runner
diff --git a/vllm/model_executor/layers/logit_processor.py b/vllm/model_executor/layers/logits_processor.py
similarity index 99%
rename from vllm/model_executor/layers/logit_processor.py
rename to vllm/model_executor/layers/logits_processor.py
index 69ff5a99a4373..5aba4ade8cc58 100644
--- a/vllm/model_executor/layers/logit_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -11,7 +11,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 
-class LogitProcessor(nn.Module):
+class LogitsProcessor(nn.Module):
     """Process logits and apply logits processors from sampling metadata.
 
     This layer does the following:
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 86c73cc7af35b..757b75129845c 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -37,7 +37,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.logit_processor import LogitProcessor
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
@@ -328,8 +328,8 @@ def __init__(
         )
 
         logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logit_processor = LogitProcessor(self.unpadded_vocab_size,
-                                              config.vocab_size, logit_scale)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size, logit_scale)
         self.sampler = Sampler()
 
     def forward(
@@ -345,9 +345,8 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logit_processor(self.lm_head.weight, hidden_states,
-                                      sampling_metadata)
-
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
         return logits
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 3a8b7a2a9ebb3..a12f63b58f52b 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -31,7 +31,7 @@
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.logit_processor import LogitProcessor
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -293,7 +293,7 @@ def __init__(
         self.linear_method = linear_method
         self.model = OPTModel(config, linear_method)
         self.lm_head_weight = self.model.decoder.embed_tokens.weight
-        self.logit_processor = LogitProcessor(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -309,9 +309,8 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logit_processor(self.lm_head_weight, hidden_states,
-                                      sampling_metadata)
-
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
         return logits

From 8b642acdc81414bbd78d6a15e03c187074b2920e Mon Sep 17 00:00:00 2001
From: roy
Date: Wed, 13 Mar 2024 21:23:36 +0800
Subject: [PATCH 08/19] fix test

---
 tests/test_logits_processor.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py
index 68e0310625d08..fe321520114f7 100644
--- a/tests/test_logits_processor.py
+++ b/tests/test_logits_processor.py
@@ -5,13 +5,13 @@
 import pytest
 import torch
 
-from vllm.model_executor.layers.logits_processor import LogitProcessor
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.worker.model_runner import ModelRunner
 
 
-class MockLogitsProcessor(LogitProcessor):
+class MockLogitsProcessor(LogitsProcessor):
 
     def __init__(self, vocab_size: int, scale: float,
                  fake_logits: torch.Tensor):
         super().__init__(vocab_size=vocab_size, scale=scale)
         self.fake_logits = fake_logits.clone()
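Because the extra vocabulary rows used by LoRA adapters now live in the logit computation rather than in sampling, the LoRA machinery must wrap the new layer instead of `Sampler`. A simplified sketch of the submodule swap patch 09 performs in `vllm/lora/models.py` (`replace_submodule` and `from_layer_logits_processor` are the real helpers shown in the diff; the surrounding plumbing is trimmed):

    # Inside the LoRA manager, when the "lm_head" module is encountered:
    logits_processor_module = self.model.get_submodule("logits_processor")
    new_module = replace_submodule(
        self.model, "logits_processor",
        from_layer_logits_processor(logits_processor_module, module,
                                    self.lora_slots, self.lora_config,
                                    self.model.config))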
From 4892fb948f190fb72ff2ae3a1896a1791c9c7b29 Mon Sep 17 00:00:00 2001
From: roy
Date: Thu, 14 Mar 2024 21:05:05 +0800
Subject: [PATCH 09/19] fix lora

---
 vllm/lora/layers.py | 16 ++++++++--------
 vllm/lora/models.py | 13 ++++++++-----
 2 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 99e6cdeee6364..135cd1bca85e4 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -10,7 +10,7 @@
 
 from vllm.config import LoRAConfig
 from vllm.lora.punica import add_lora, add_lora_slice, bgmv
-from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
@@ -783,11 +783,11 @@ def weight(self):
         return self.base_layer.weight
 
 
-class SamplerWithLoRA(BaseLayerWithLoRA):
+class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
 
     def __init__(
         self,
-        base_layer: Sampler,
+        base_layer: LogitsProcessor,
         hidden_size: int,
         dtype: torch.dtype,
         device: torch.device,
@@ -968,14 +968,14 @@ def from_layer(
     return layer
 
 
-def from_layer_sampler(
-    layer: Sampler,
+def from_layer_logits_processor(
+    layer: LogitsProcessor,
     lm_head: ParallelLMHead,
     max_loras: int,
     lora_config: LoRAConfig,
     model_config: Optional[PretrainedConfig] = None,
-) -> SamplerWithLoRA:
-    ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype,
-                          lm_head.weight.device)
+) -> LogitsProcessorWithLoRA:
+    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
+                                  lm_head.weight.dtype, lm_head.weight.device)
     ret.create_lora_weights(max_loras, lora_config, model_config)
     return ret
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 238da256b7cdc..5c9e6c6923053 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -14,7 +14,7 @@
 from vllm.utils import LRUCache, in_wsl
 
 from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
-                              from_layer_sampler)
+                              from_layer_logits_processor)
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
 
@@ -421,11 +421,14 @@ def _create_lora_modules(self):
                                self.model.config))
             # (yard1): TODO make this more robust
             if "lm_head" in module_name:
-                sampler_module = self.model.get_submodule("sampler")
+                logits_processor_module = self.model.get_submodule(
+                    "logits_processor")
                 new_module = replace_submodule(
-                    self.model, "sampler",
-                    from_layer_sampler(sampler_module, module, self.lora_slots,
-                                       self.lora_config, self.model.config))
+                    self.model, "logits_processor",
+                    from_layer_logits_processor(logits_processor_module,
+                                                module, self.lora_slots,
+                                                self.lora_config,
+                                                self.model.config))
             self.register_module(module_name, new_module)
             self._register_packed_modules(module_name)
             new_module.set_mapping(self.base_indices, self.sampler_indices,

From d75c42c07e4f59942409ca6e5abbd72bfceefe4f Mon Sep 17 00:00:00 2001
From: roy
Date: Thu, 14 Mar 2024 22:12:46 +0800
Subject: [PATCH 10/19] fix other models

---
 vllm/model_executor/models/baichuan.py      | 15 +++++++++++----
 vllm/model_executor/models/bloom.py         | 15 +++++++++++----
 vllm/model_executor/models/chatglm.py       | 15 +++++++++++----
 vllm/model_executor/models/deepseek.py      | 15 +++++++++++----
 vllm/model_executor/models/falcon.py        | 15 +++++++++++----
 vllm/model_executor/models/gemma.py         | 15 +++++++++++----
 vllm/model_executor/models/gpt2.py          | 15 +++++++++++----
 vllm/model_executor/models/gpt_bigcode.py   | 15 +++++++++++----
 vllm/model_executor/models/gpt_j.py         | 15 +++++++++++----
 vllm/model_executor/models/gpt_neox.py      | 15 +++++++++++----
 vllm/model_executor/models/internlm2.py     | 15 +++++++++++----
 vllm/model_executor/models/mixtral.py       | 16 ++++++++++++----
 vllm/model_executor/models/mixtral_quant.py | 15 +++++++++++----
 vllm/model_executor/models/mpt.py           | 15 +++++++++++----
 vllm/model_executor/models/olmo.py          | 15 +++++++++++----
 vllm/model_executor/models/orion.py         | 15 +++++++++++----
 vllm/model_executor/models/phi.py           | 16 +++++++++++-----
 vllm/model_executor/models/qwen.py          | 15 +++++++++++----
 vllm/model_executor/models/qwen2.py         | 15 +++++++++++----
 vllm/model_executor/models/stablelm.py      | 15 +++++++++++----
 vllm/model_executor/models/starcoder2.py    | 16 ++++++++++++----
 21 files changed, 233 insertions(+), 85 deletions(-)

diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index cbf472750e294..968b9ebba87b2 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -34,6 +34,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -295,7 +296,8 @@ def __init__(self,
         self.linear_method = linear_method
         self.model = BaiChuanModel(config, position_embedding, linear_method)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -308,13 +310,18 @@ def forward(
                                    input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index 0548b2b140b1b..851c475206661 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -30,6 +30,7 @@
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -273,7 +274,8 @@ def __init__(
         self.linear_method = linear_method
         self.transformer = BloomModel(config, linear_method)
         self.lm_head_weight = self.transformer.word_embeddings.weight
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -286,13 +288,18 @@ def forward(
                                          input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 1c5dcfacaff2b..15e7de03b61f1 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -17,6 +17,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -332,7 +333,8 @@ def __init__(
         self.linear_method = linear_method
         self.transformer = ChatGLMModel(config, linear_method)
         self.lm_head_weight = self.transformer.output_layer.weight
-        self.sampler = Sampler(config.padded_vocab_size)
+        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -345,13 +347,18 @@ def forward(
                                          input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,
diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py
index 13c080cb02774..eff93e706f5dc 100644
--- a/vllm/model_executor/models/deepseek.py
+++ b/vllm/model_executor/models/deepseek.py
@@ -38,6 +38,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -372,7 +373,8 @@ def __init__(
         self.linear_method = linear_method
         self.model = DeepseekModel(config, linear_method)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -385,13 +387,18 @@ def forward(
                                    input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: Optional[torch.Tensor],
+        logits: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index 3c148be5b10f4..7626dbe62293f 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -34,6 +34,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -373,7 +374,8 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -390,13 +392,18 @@ def forward(
         )
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 386a36cf492d6..fd3dbe798cd8e 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -30,6 +30,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -281,7 +282,8 @@ def __init__(
         self.config = config
         self.linear_method = linear_method
         self.model = GemmaModel(config, linear_method)
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -295,13 +297,18 @@ def forward(
                                    input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.model.embed_tokens.weight,
+                                       hidden_states, sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.model.embed_tokens.weight,
-                                   hidden_states, sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 3f7b21e5a4133..263727cac19ff 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -30,6 +30,7 @@
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -216,7 +217,8 @@ def __init__(
         self.linear_method = linear_method
         self.transformer = GPT2Model(config, linear_method)
         self.lm_head_weight = self.transformer.wte.weight
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -229,12 +231,17 @@ def forward(
                                          input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 5c30d47d93e36..65caabae60daa 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -31,6 +31,7 @@
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -237,7 +238,8 @@ def __init__(
         self.linear_method = linear_method
         self.transformer = GPTBigCodeModel(config, linear_method)
         self.lm_head_weight = self.transformer.wte.weight
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -250,13 +252,18 @@ def forward(
                                          input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index 93dce7b67a7a5..c956a12f3e46e 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -30,6 +30,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -224,7 +225,8 @@ def __init__(
             config.n_embd,
             bias=True,
         )
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -237,13 +239,18 @@ def forward(
                                          input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata, self.lm_head.bias)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   sampling_metadata, self.lm_head.bias)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index 98107350e60b9..db2173936e7d9 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -30,6 +30,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -238,7 +239,8 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -251,13 +253,18 @@ def forward(
                                          input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.embed_out.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.embed_out.weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index 7b2215ef4bda5..93026fc01f0f0 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -14,6 +14,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -250,7 +251,8 @@ def __init__(
         self.linear_method = linear_method
         self.model = InternLM2Model(config, linear_method)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -263,13 +265,18 @@ def forward(
                                    input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.output.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.output.weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
 
     def load_weights(self,
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index d47834e519697..68a3a298444ae 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -37,6 +37,7 @@
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
@@ -369,7 +370,9 @@ def __init__(
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
-        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -382,13 +385,18 @@ def forward(
                                    input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: Optional[torch.Tensor],
+        logits: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,
diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py
index 25c7f1978c0dc..b4dfc439d50e9 100644
--- a/vllm/model_executor/models/mixtral_quant.py
+++ b/vllm/model_executor/models/mixtral_quant.py
@@ -39,6 +39,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -344,7 +345,8 @@ def __init__(
         self.linear_method = linear_method
         self.model = MixtralModel(config, linear_method)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -357,13 +359,18 @@ def forward(
                                    input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: Optional[torch.Tensor],
+        logits: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
def load_weights(self, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 16ecac3d0529a..7a2568817858c 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -13,6 +13,7 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -259,7 +260,8 @@ def __init__( self.transformer = MPTModel(config, linear_method) self.lm_head_weight = self.transformer.wte.weight - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -272,13 +274,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 2b0a420e82faf..19f2be6da8ed3 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -51,6 +51,7 @@ RowParallelLinear, ) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -336,7 +337,8 @@ def __init__(self, self.lm_head_weight = (self.model.transformer.wte.weight if config.weight_tying else self.model.transformer.ff_out.weight) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -353,13 +355,18 @@ def forward( ) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights( diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 6039b1cdc3534..86428e320e0f7 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -18,6 +18,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -256,7 +257,8 @@ def __init__( self.linear_method = linear_method self.model = OrionModel(config, linear_method) 
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -269,13 +271,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 039dc7a9b7675..ef70c823dc905 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -49,6 +49,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -240,7 +241,8 @@ def __init__(self, self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, bias=True) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -254,14 +256,18 @@ def forward( return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata, self.lm_head.bias) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - head = self.lm_head - next_tokens = self.sampler(head.weight, hidden_states, - sampling_metadata, head.bias) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index d4d5a4e8bb9a5..61ac2c6c605c6 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -19,6 +19,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -230,7 +231,8 @@ def __init__( self.linear_method = linear_method self.transformer = QWenModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -243,13 +245,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: 
torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 3e4f843e649b4..aa60744d14843 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -37,6 +37,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -300,7 +301,8 @@ def __init__( self.linear_method = linear_method self.model = Qwen2Model(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -313,13 +315,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index c66f327beee7a..7624ca89ee670 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -33,6 +33,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -238,7 +239,8 @@ def __init__( self.linear_method = linear_method self.model = StableLMEpochModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -251,13 +253,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index cfbb1bdb7909e..e418951a633ab 100644 --- a/vllm/model_executor/models/starcoder2.py +++ 
b/vllm/model_executor/models/starcoder2.py @@ -32,6 +32,7 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) @@ -254,7 +255,9 @@ def __init__(self, padding_size=DEFAULT_VOCAB_PADDING_SIZE, ) self.lm_head_weight = self.lm_head.weight - self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -267,13 +270,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: Optional[torch.Tensor], + logits: Optional[torch.Tensor], sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, From 1b29f924b6dcfb5e916529d814ba12b1fb44eb8f Mon Sep 17 00:00:00 2001 From: roy Date: Thu, 14 Mar 2024 22:21:46 +0800 Subject: [PATCH 11/19] fix lora test --- tests/lora/test_layers.py | 37 +++++++++++++++++++------------------ vllm/lora/layers.py | 2 +- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 46f054c5b84ef..b59587ff7a224 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -13,7 +13,7 @@ QKVParallelLinearWithLora, VocabParallelEmbeddingWithLoRA, RowParallelLinearWithLoRA, - SamplerWithLoRA, + LogitsProcessorWithLoRA, LoRAMapping, BaseLayerWithLoRA, ) @@ -21,6 +21,7 @@ PackedLoRALayerWeights) from vllm.config import LoRAConfig from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear, @@ -402,28 +403,28 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None: max_lora_rank=8, lora_dtype=torch.float16) - def create_random_sampler_layer(): + def create_random_lora_logits_processor_layer(): linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size, 1024, 32000) linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data[:, 32000:] = 0 - sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000) - lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype, + logits_processor = LogitsProcessor(32000 + lora_config.lora_extra_vocab_size, 32000) + lora_logits_processor = LogitsProcessorWithLoRA(logits_processor, 1024, linear.weight.dtype, linear.weight.device) - lora_sampler.create_lora_weights(max_loras, lora_config) + lora_logits_processor.create_lora_weights(max_loras, lora_config) - return linear, sampler, lora_sampler + return linear, logits_processor, lora_logits_processor for i in range(10): set_random_seed(i) id_to_index = get_random_id_to_index(num_loras, max_loras) - linear, sampler, lora_sampler = create_random_sampler_layer() + linear, logits_processor, lora_logits_processor = create_random_lora_logits_processor_layer() # NOTE: all 
the generated loras share the same embeddings tensor. lora_dict, _ = populate_loras( id_to_index, - layer=lora_sampler, + layer=lora_logits_processor, layer_weights=linear.weight, generate_embeddings_tensor=1024, ) @@ -447,34 +448,34 @@ def create_random_sampler_layer(): 32000, lora_config.lora_extra_vocab_size, ) - lora_sampler.set_mapping(*mapping_info, ) + lora_logits_processor.set_mapping(*mapping_info, ) - lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), + lora_result = lora_logits_processor._get_logits(hidden_states=torch.cat(inputs), embedding=linear.weight, embedding_bias=None) original_weight = linear.weight.clone() - linear.weight[sampler.org_vocab_size:sampler.org_vocab_size + + linear.weight[logits_processor.org_vocab_size:logits_processor.org_vocab_size + embeddings_tensor_len] = embeddings_tensor - sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size + logits_processor.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size expected_results = [] for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] - result = sampler._get_logits(hidden_states=input_, + result = logits_processor._get_logits(hidden_states=input_, embedding=linear.weight, embedding_bias=None) result[:, 32000 + embeddings_tensor_len:] = float("-inf") result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) - sampler.org_vocab_size = 32000 + logits_processor.org_vocab_size = 32000 # Check that resetting the lora weights succeeds for slot_idx in range(max_loras): - lora_sampler.reset_lora(slot_idx) + lora_logits_processor.reset_lora(slot_idx) inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=[0], @@ -488,12 +489,12 @@ def create_random_sampler_layer(): mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, 32000, lora_config.lora_extra_vocab_size) - lora_sampler.set_mapping(*mapping_info, ) + lora_logits_processor.set_mapping(*mapping_info, ) - lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), + lora_result = lora_logits_processor._get_logits(hidden_states=torch.cat(inputs), embedding=original_weight, embedding_bias=None)[:, :32000] - expected_result = sampler._get_logits(hidden_states=torch.cat(inputs), + expected_result = logits_processor._get_logits(hidden_states=torch.cat(inputs), embedding=original_weight, embedding_bias=None) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 135cd1bca85e4..a886851b5be98 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -10,7 +10,6 @@ from vllm.config import LoRAConfig from vllm.lora.punica import add_lora, add_lora_slice, bgmv -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, @@ -20,6 +19,7 @@ RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( From 0f13076202c4450f209fe456b4437936e10c8dc1 Mon Sep 17 00:00:00 2001 From: roy Date: Thu, 14 Mar 2024 22:27:15 +0800 Subject: [PATCH 12/19] format --- tests/lora/test_layers.py | 45 ++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git 
a/tests/lora/test_layers.py b/tests/lora/test_layers.py index b59587ff7a224..7dfc3952016f5 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -20,7 +20,6 @@ from vllm.lora.models import (LoRALayerWeights, convert_mapping, PackedLoRALayerWeights) from vllm.config import LoRAConfig -from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -395,7 +394,7 @@ def create_random_embedding_layer(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_lm_head_sampler(dist_init, num_loras, device) -> None: +def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: torch.set_default_device(device) max_loras = 8 @@ -403,14 +402,15 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None: max_lora_rank=8, lora_dtype=torch.float16) - def create_random_lora_logits_processor_layer(): + def _pretest(): linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size, 1024, 32000) linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data[:, 32000:] = 0 - logits_processor = LogitsProcessor(32000 + lora_config.lora_extra_vocab_size, 32000) - lora_logits_processor = LogitsProcessorWithLoRA(logits_processor, 1024, linear.weight.dtype, - linear.weight.device) + logits_processor = LogitsProcessor( + 32000 + lora_config.lora_extra_vocab_size, 32000) + lora_logits_processor = LogitsProcessorWithLoRA( + logits_processor, 1024, linear.weight.dtype, linear.weight.device) lora_logits_processor.create_lora_weights(max_loras, lora_config) return linear, logits_processor, lora_logits_processor @@ -419,7 +419,7 @@ def create_random_lora_logits_processor_layer(): set_random_seed(i) id_to_index = get_random_id_to_index(num_loras, max_loras) - linear, logits_processor, lora_logits_processor = create_random_lora_logits_processor_layer() + linear, logits_processor, lora_logits_processor = _pretest() # NOTE: all the generated loras share the same embeddings tensor. lora_dict, _ = populate_loras( @@ -450,22 +450,25 @@ def create_random_lora_logits_processor_layer(): ) lora_logits_processor.set_mapping(*mapping_info, ) - lora_result = lora_logits_processor._get_logits(hidden_states=torch.cat(inputs), - embedding=linear.weight, - embedding_bias=None) + lora_result = lora_logits_processor._get_logits( + hidden_states=torch.cat(inputs), + embedding=linear.weight, + embedding_bias=None) original_weight = linear.weight.clone() - linear.weight[logits_processor.org_vocab_size:logits_processor.org_vocab_size + + linear.weight[logits_processor. 
+ org_vocab_size:logits_processor.org_vocab_size + embeddings_tensor_len] = embeddings_tensor - logits_processor.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size + logits_processor.org_vocab_size = (32000 + + lora_config.lora_extra_vocab_size) expected_results = [] for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] result = logits_processor._get_logits(hidden_states=input_, - embedding=linear.weight, - embedding_bias=None) + embedding=linear.weight, + embedding_bias=None) result[:, 32000 + embeddings_tensor_len:] = float("-inf") result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling expected_results.append(result) @@ -491,12 +494,14 @@ def create_random_lora_logits_processor_layer(): lora_config.lora_extra_vocab_size) lora_logits_processor.set_mapping(*mapping_info, ) - lora_result = lora_logits_processor._get_logits(hidden_states=torch.cat(inputs), - embedding=original_weight, - embedding_bias=None)[:, :32000] - expected_result = logits_processor._get_logits(hidden_states=torch.cat(inputs), - embedding=original_weight, - embedding_bias=None) + lora_result = lora_logits_processor._get_logits( + hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None)[:, :32000] + expected_result = logits_processor._get_logits( + hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None) rtol, atol = TOLERANCES[lora_result.dtype] assert torch.allclose(lora_result, From 8b66abd4d0a61d7450bdc91c50bc9188df70b12f Mon Sep 17 00:00:00 2001 From: roy Date: Thu, 14 Mar 2024 23:19:24 +0800 Subject: [PATCH 13/19] fix test --- vllm/model_executor/layers/logits_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 5aba4ade8cc58..baa113c342c28 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -30,6 +30,7 @@ def __init__(self, """ super().__init__() self.scale = scale + self.vocab_size = vocab_size # Transformers-neuronx generate outputs as logits directly. self.logits_as_hidden_states = is_neuron() # original vocabulary size (without LoRA). 
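Note: the per-model changes earlier in this series all follow one pattern, which the fixes above keep converging on. As a reference, here is a minimal sketch of that pattern. It is illustrative only: the class name CausalLMSkeleton and the bare config argument are placeholders, and real models additionally implement forward() and load_weights().

    import torch
    from typing import Optional

    from vllm.model_executor.layers.logits_processor import LogitsProcessor
    from vllm.model_executor.layers.sampler import Sampler
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
    from vllm.model_executor.sampling_metadata import SamplingMetadata
    from vllm.sequence import SamplerOutput


    class CausalLMSkeleton(torch.nn.Module):

        def __init__(self, config) -> None:
            super().__init__()
            self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
            # LogitsProcessor now owns the lm_head projection, TP gather,
            # vocab-padding removal, scaling, and user logits processors.
            self.logits_processor = LogitsProcessor(config.vocab_size)
            # Sampler is left with sampling only and takes no vocab size.
            self.sampler = Sampler()

        def compute_logits(self, hidden_states: torch.Tensor,
                           sampling_metadata: SamplingMetadata) -> torch.Tensor:
            return self.logits_processor(self.lm_head.weight, hidden_states,
                                         sampling_metadata)

        def sample(self, logits: torch.Tensor,
                   sampling_metadata: SamplingMetadata
                   ) -> Optional[SamplerOutput]:
            return self.sampler(logits, sampling_metadata)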
From 4c5703dbe439c3c687b73c1e6593a7dce2791cda Mon Sep 17 00:00:00 2001 From: roy Date: Thu, 14 Mar 2024 23:31:17 +0800 Subject: [PATCH 14/19] fix layers --- vllm/lora/layers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index a886851b5be98..a00c2811b716a 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -805,6 +805,10 @@ def logits_as_hidden_states(self): @property def vocab_size(self): return self.base_layer.vocab_size + + @property + def scale(self): + return self.base_layer.scale @property def org_vocab_size(self): From 6bdcbf410fc09db810af70e48c2ed29baf6f8144 Mon Sep 17 00:00:00 2001 From: roy Date: Thu, 14 Mar 2024 23:37:47 +0800 Subject: [PATCH 15/19] format --- vllm/lora/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index a00c2811b716a..f6cd1390d4bce 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -805,7 +805,7 @@ def logits_as_hidden_states(self): @property def vocab_size(self): return self.base_layer.vocab_size - + @property def scale(self): return self.base_layer.scale From 02603675bb733b3214393d06e25a324283da7949 Mon Sep 17 00:00:00 2001 From: roy Date: Fri, 15 Mar 2024 00:56:40 +0800 Subject: [PATCH 16/19] fix test and neuron --- tests/lora/conftest.py | 4 ++-- vllm/model_executor/models/neuron/llama.py | 15 +++++++++++---- vllm/model_executor/models/neuron/mistral.py | 15 +++++++++++---- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 30a8ad03c8ada..da939d331487f 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -85,7 +85,7 @@ def dummy_model() -> nn.Module: ("outact", nn.Sigmoid()), # Special handling for lm_head & sampler ("lm_head", ParallelLMHead(512, 10)), - ("sampler", Sampler(512)) + ("sampler", Sampler()) ])) model.config = MagicMock() return model @@ -110,7 +110,7 @@ def dummy_model_gate_up() -> nn.Module: ("outact", nn.Sigmoid()), # Special handling for lm_head & sampler ("lm_head", ParallelLMHead(512, 10)), - ("sampler", Sampler(512)) + ("sampler", Sampler()) ])) model.config = MagicMock() return model diff --git a/vllm/model_executor/models/neuron/llama.py b/vllm/model_executor/models/neuron/llama.py index e2856da99d9b1..32c43c4944fac 100644 --- a/vllm/model_executor/models/neuron/llama.py +++ b/vllm/model_executor/models/neuron/llama.py @@ -7,6 +7,7 @@ from transformers import LlamaConfig from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput @@ -25,7 +26,8 @@ def __init__( self.config = config self.linear_method = linear_method self.model = None - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -45,13 +47,18 @@ def forward( start_ids=seq_ids.flatten()) return logits + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.model.chkpt_model.lm_head, + hidden_states, sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.model.chkpt_model.lm_head, 
- hidden_states, sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/neuron/mistral.py b/vllm/model_executor/models/neuron/mistral.py index a302cce30abab..24fc0fa0aacab 100755 --- a/vllm/model_executor/models/neuron/mistral.py +++ b/vllm/model_executor/models/neuron/mistral.py @@ -6,6 +6,7 @@ from transformers import MistralConfig from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput @@ -26,7 +27,8 @@ def __init__( self.linear_method = linear_method self.model = None self.lm_head = None - self.sampler = Sampler(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() def forward( self, @@ -48,13 +50,18 @@ def forward( start_ids=seq_ids) return logits + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.model.chkpt_model.lm_head, + hidden_states, sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.model.chkpt_model.lm_head, - hidden_states, sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, From 6d95bc04e15645fdded427839e06750e4e450542 Mon Sep 17 00:00:00 2001 From: roy Date: Fri, 15 Mar 2024 07:51:23 +0800 Subject: [PATCH 17/19] fix test --- tests/lora/conftest.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index da939d331487f..38560c251696a 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -13,6 +13,7 @@ import vllm from vllm.config import LoRAConfig from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.model_loader import get_model from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -85,6 +86,7 @@ def dummy_model() -> nn.Module: ("outact", nn.Sigmoid()), # Special handling for lm_head & sampler ("lm_head", ParallelLMHead(512, 10)), + ("logits_processor", LogitsProcessor(512)), ("sampler", Sampler()) ])) model.config = MagicMock() @@ -110,6 +112,7 @@ def dummy_model_gate_up() -> nn.Module: ("outact", nn.Sigmoid()), # Special handling for lm_head & sampler ("lm_head", ParallelLMHead(512, 10)), + ("logits_processor", LogitsProcessor(512)), ("sampler", Sampler()) ])) model.config = MagicMock() From 0179bd150f218f3e3adb7aa512ccc085531977ea Mon Sep 17 00:00:00 2001 From: roy Date: Tue, 19 Mar 2024 20:46:23 +0800 Subject: [PATCH 18/19] fix --- vllm/model_executor/models/qwen2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 20fb6a6afdaca..6698f01b7c701 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -308,7 +308,6 @@ def __init__( config.hidden_size) self.lm_head_weight = self.lm_head.weight - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() 
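The scale property added in the "fix layers" patch above works because LogitsProcessorWithLoRA wraps a base LogitsProcessor and forwards its read-only attributes. A minimal sketch of that delegation idiom follows, assuming a generic wrapper name; only the property bodies mirror the actual vllm/lora/layers.py code.

    import torch.nn as nn

    class LoRAWrapperSketch(nn.Module):
        # LoRA-specific weights and mappings live on the wrapper, while
        # read-only attributes of the wrapped layer are proxied through
        # properties so callers see a single uniform interface.
        def __init__(self, base_layer: nn.Module) -> None:
            super().__init__()
            self.base_layer = base_layer

        @property
        def scale(self):
            return self.base_layer.scale

        @property
        def vocab_size(self):
            return self.base_layer.vocab_size

        @property
        def org_vocab_size(self):
            return self.base_layer.org_vocab_size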
From 744c0c5e7d8c2dd2f79982b41fb9c990218b0fa3 Mon Sep 17 00:00:00 2001 From: roy Date: Thu, 21 Mar 2024 06:47:48 +0800 Subject: [PATCH 19/19] fix comment --- vllm/worker/model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index cadb377bdb0a4..510ff122606a6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -579,6 +579,7 @@ def execute_model( # Compute the logits. logits = self.model.compute_logits(hidden_states, sampling_metadata) + # Only perform sampling in the driver worker. if not sampling_metadata.perform_sampling: return None
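With the series applied, the driver-side flow in vllm/worker/model_runner.py separates into three stages. The following condensed sketch is illustrative rather than the literal execute_model body: argument plumbing, CUDA-graph capture, and broadcast logic are omitted, and execute_model_sketch and its inputs are hypothetical names.

    def execute_model_sketch(model, model_inputs, sampling_metadata):
        # Stage 1: run the transformer stack to produce hidden states.
        hidden_states = model(**model_inputs)
        # Stage 2: project hidden states to vocabulary logits; the
        # LogitsProcessor inside compute_logits handles the TP gather,
        # vocab-padding removal, scaling, and user logits processors.
        logits = model.compute_logits(hidden_states, sampling_metadata)
        # Only perform sampling in the driver worker; non-driver ranks
        # return early once their part of the gather is done.
        if not sampling_metadata.perform_sampling:
            return None
        # Stage 3: sample the next tokens from the finished logits.
        return model.sample(logits=logits,
                            sampling_metadata=sampling_metadata)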