From 33a3a8c39b7d0b4d2de66e55cedc4a04837156d7 Mon Sep 17 00:00:00 2001
From: Lalit Pradhan
Date: Thu, 21 Mar 2024 07:27:24 +0000
Subject: [PATCH] adapted to #3233 and bug fix for gpt2

---
 tests/models/test_models.py        |   1 -
 vllm/model_executor/models/gpt2.py |   3 +-
 vllm/model_executor/models/jais.py | 116 ++++-------------------------
 3 files changed, 14 insertions(+), 106 deletions(-)

diff --git a/tests/models/test_models.py b/tests/models/test_models.py
index 5488149227dff..fb567e837d281 100644
--- a/tests/models/test_models.py
+++ b/tests/models/test_models.py
@@ -20,7 +20,6 @@
     "stabilityai/stablelm-3b-4e1t",
     "allenai/OLMo-1B",
     "bigcode/starcoder2-3b",
-    "core42/jais-13b",
 ]
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 263727cac19ff..e75dda750cb26 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -242,8 +242,7 @@ def sample(
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head_weight, logits,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,
diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py
index 261f570284ec0..471322a0ea144 100644
--- a/vllm/model_executor/models/jais.py
+++ b/vllm/model_executor/models/jais.py
@@ -34,17 +34,8 @@
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.sampler import (
-    Sampler,
-    _prune_hidden_states,
-    _apply_logits_processors,
-    _apply_penalties,
-    _apply_top_k_top_p,
-    _apply_min_p,
-    _sample,
-    _get_logprobs,
-    _build_sampler_output,
-)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, )
 from vllm.model_executor.parallel_utils.parallel_state import (
@@ -85,90 +76,6 @@ def get_slopes_power_of_2(n):
             2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
 
 
-class JAISSampler(Sampler):
-
-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None) -> None:
-        super().__init__(vocab_size, org_vocab_size)
-
-    def forward(
-        self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-        output_logits_scale: float,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        if self.logits_as_hidden_states:
-            logits = hidden_states
-        else:
-            hidden_states = _prune_hidden_states(hidden_states,
-                                                 sampling_metadata)
-
-        # Get the logits for the next tokens.
-        logits = self._get_logits(hidden_states, embedding, embedding_bias)
-        if logits is not None:
-            logits *= torch.tensor(float(output_logits_scale),
-                                   dtype=logits.dtype)
-
-        # Only perform sampling in the driver worker.
-        # Note: `_get_logits` is still distributed across TP workers because
-        # the `embedding` weight is distributed across TP workers.
-        # TODO(zhuohan): Change the get_logits part to a separate stage.
-        if not sampling_metadata.perform_sampling:
-            return None
-
-        assert logits is not None
-        _, vocab_size = logits.shape
-
-        # Apply logits processors (if any).
-        logits = _apply_logits_processors(logits, sampling_metadata)
-
-        # Prepare sampling tensors with pinned memory to avoid blocking.
-        (sampling_tensors, do_penalties, do_top_p_top_k,
-         do_min_p) = (SamplingTensors.from_sampling_metadata(
-             sampling_metadata, vocab_size, logits.device, logits.dtype))
-
-        # Apply presence and frequency penalties.
-        if do_penalties:
-            logits = _apply_penalties(
-                logits,
-                sampling_tensors.prompt_tokens,
-                sampling_tensors.output_tokens,
-                sampling_tensors.presence_penalties,
-                sampling_tensors.frequency_penalties,
-                sampling_tensors.repetition_penalties,
-            )
-
-        # Apply temperature scaling.
-        # Use in-place division to avoid creating a new tensor.
-        logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
-
-        if do_top_p_top_k:
-            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
-                                        sampling_tensors.top_ks)
-
-        if do_min_p:
-            logits = _apply_min_p(logits, sampling_tensors.min_ps)
-
-        # We use float32 for probabilities and log probabilities.
-        # Compute the probabilities.
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities.
-        # Use log_softmax to ensure numerical stability.
-        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-
-        # Sample the next tokens.
-        sample_results = _sample(probs, logprobs, sampling_metadata)
-        # Get the logprobs query results.
-        prompt_logprobs, sample_logprobs = _get_logprobs(
-            logprobs, sampling_metadata, sample_results)
-        return _build_sampler_output(sample_results, sampling_metadata,
-                                     prompt_logprobs, sample_logprobs)
-
-
 class JAISAttention(nn.Module):
 
     def __init__(
         self,
@@ -381,7 +288,9 @@ def __init__(
         else:
             self.output_logits_scale = (config.mup_output_alpha *
                                         config.mup_width_scale)
-        self.sampler = JAISSampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
+                                                scale=self.output_logits_scale)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -394,17 +303,18 @@ def forward(
                                          input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(
-            self.lm_head_weight,
-            hidden_states,
-            sampling_metadata,
-            self.output_logits_scale,
-        )
+        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
 
     def load_weights(
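The jais.py changes above follow the sampler split introduced by #3233: the per-model JAISSampler subclass, which scaled logits by output_logits_scale inside forward(), is replaced by a LogitsProcessor that owns that scaling plus the stock Sampler, with logit computation exposed as a separate compute_logits() step. The gpt2.py hunk fixes a leftover call site from the same migration, since Sampler.forward() no longer accepts the lm-head weight. Below is a minimal sketch of the resulting model-facing interface, using only the constructs visible in the diff; the class name ToyJAISStyleHead, the hidden_size parameter, and the weight shape are hypothetical, and the SamplingMetadata import path is assumed for vLLM of this vintage.

    import torch
    import torch.nn as nn

    from vllm.model_executor.layers.logits_processor import LogitsProcessor
    from vllm.model_executor.layers.sampler import Sampler
    from vllm.model_executor.sampling_metadata import SamplingMetadata


    class ToyJAISStyleHead(nn.Module):
        """Hypothetical head showing the post-#3233 compute_logits/sample
        split; not part of this patch."""

        def __init__(self, vocab_size: int, hidden_size: int,
                     output_logits_scale: float = 1.0) -> None:
            super().__init__()
            # The lm-head weight is handed to the logits processor at call
            # time, mirroring self.lm_head_weight in jais.py above.
            self.lm_head_weight = nn.Parameter(
                torch.empty(vocab_size, hidden_size))
            # LogitsProcessor now owns the scaling that JAISSampler.forward()
            # used to apply; scale defaults to 1.0, which is the gpt2 case.
            self.logits_processor = LogitsProcessor(
                vocab_size=vocab_size, scale=output_logits_scale)
            self.sampler = Sampler()

        def compute_logits(self, hidden_states: torch.Tensor,
                           sampling_metadata: SamplingMetadata
                           ) -> torch.Tensor:
            # Project hidden states through the lm head, apply the scale,
            # and keep only the positions selected for sampling.
            return self.logits_processor(self.lm_head_weight, hidden_states,
                                         sampling_metadata)

        def sample(self, logits: torch.Tensor,
                   sampling_metadata: SamplingMetadata):
            # The sampler consumes scaled logits directly; it no longer
            # takes the lm-head weight or a per-model scale argument.
            return self.sampler(logits, sampling_metadata)


    # Driver-side call order after the split:
    #   hidden_states = model(input_ids, positions, kv_caches, input_metadata)
    #   logits = model.compute_logits(hidden_states, sampling_metadata)
    #   output = model.sample(logits, sampling_metadata)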