
Commit

adapted to vllm-project#3233 and bug fix for gpt2
Lalit Pradhan committed Mar 21, 2024
1 parent 85cc0ce commit 33a3a8c
Showing 3 changed files with 14 additions and 106 deletions.
1 change: 0 additions & 1 deletion tests/models/test_models.py
@@ -20,7 +20,6 @@
"stabilityai/stablelm-3b-4e1t",
"allenai/OLMo-1B",
"bigcode/starcoder2-3b",
"core42/jais-13b",
]


3 changes: 1 addition & 2 deletions vllm/model_executor/models/gpt2.py
@@ -242,8 +242,7 @@ def sample(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head_weight, logits,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self,
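
The gpt2.py fix above follows from the sampler refactor in vllm-project#3233: logit computation was split out of the Sampler into a separate logits-processor stage, so the sampler is now called with only the logits and the sampling metadata, and the old call that still passed self.lm_head_weight was a leftover. Below is a minimal, self-contained sketch of that two-stage pattern; the helper names and the greedy sampling are illustrative stand-ins in plain torch, not the real vLLM classes.

import torch


def compute_logits(hidden_states: torch.Tensor,
                   lm_head_weight: torch.Tensor) -> torch.Tensor:
    # Stand-in for the logits-processor stage: project hidden states onto the
    # vocabulary with the LM head weight.
    return hidden_states @ lm_head_weight.t()


def sample(logits: torch.Tensor) -> torch.Tensor:
    # Stand-in for the sampler stage; it consumes logits only. Real sampling
    # also applies temperatures, penalties and top-k/top-p; greedy here.
    return torch.argmax(logits, dim=-1)


hidden = torch.randn(4, 16)        # (num_tokens, hidden_size)
lm_head = torch.randn(32000, 16)   # (vocab_size, hidden_size)
next_tokens = sample(compute_logits(hidden, lm_head))
print(next_tokens.shape)           # torch.Size([4])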
116 changes: 13 additions & 103 deletions vllm/model_executor/models/jais.py
@@ -34,17 +34,8 @@
QKVParallelLinear,
RowParallelLinear,
)
-from vllm.model_executor.layers.sampler import (
-    Sampler,
-    _prune_hidden_states,
-    _apply_logits_processors,
-    _apply_penalties,
-    _apply_top_k_top_p,
-    _apply_min_p,
-    _sample,
-    _get_logprobs,
-    _build_sampler_output,
-)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, )
from vllm.model_executor.parallel_utils.parallel_state import (
@@ -85,90 +76,6 @@ def get_slopes_power_of_2(n):
2 * closest_power_of_2)[0::2][:n - closest_power_of_2])


-class JAISSampler(Sampler):
-
-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None) -> None:
-        super().__init__(vocab_size, org_vocab_size)
-
-    def forward(
-        self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-        output_logits_scale: float,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        if self.logits_as_hidden_states:
-            logits = hidden_states
-        else:
-            hidden_states = _prune_hidden_states(hidden_states,
-                                                 sampling_metadata)
-
-        # Get the logits for the next tokens.
-        logits = self._get_logits(hidden_states, embedding, embedding_bias)
-        if logits is not None:
-            logits *= torch.tensor(float(output_logits_scale),
-                                   dtype=logits.dtype)
-
-        # Only perform sampling in the driver worker.
-        # Note: `_get_logits` is still distributed across TP workers because
-        # the `embedding` weight is distributed across TP workers.
-        # TODO(zhuohan): Change the get_logits part to a separate stage.
-        if not sampling_metadata.perform_sampling:
-            return None
-
-        assert logits is not None
-        _, vocab_size = logits.shape
-
-        # Apply logits processors (if any).
-        logits = _apply_logits_processors(logits, sampling_metadata)
-
-        # Prepare sampling tensors with pinned memory to avoid blocking.
-        (sampling_tensors, do_penalties, do_top_p_top_k,
-         do_min_p) = (SamplingTensors.from_sampling_metadata(
-             sampling_metadata, vocab_size, logits.device, logits.dtype))
-
-        # Apply presence and frequency penalties.
-        if do_penalties:
-            logits = _apply_penalties(
-                logits,
-                sampling_tensors.prompt_tokens,
-                sampling_tensors.output_tokens,
-                sampling_tensors.presence_penalties,
-                sampling_tensors.frequency_penalties,
-                sampling_tensors.repetition_penalties,
-            )
-
-        # Apply temperature scaling.
-        # Use in-place division to avoid creating a new tensor.
-        logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
-
-        if do_top_p_top_k:
-            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
-                                        sampling_tensors.top_ks)
-
-        if do_min_p:
-            logits = _apply_min_p(logits, sampling_tensors.min_ps)
-
-        # We use float32 for probabilities and log probabilities.
-        # Compute the probabilities.
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities.
-        # Use log_softmax to ensure numerical stability.
-        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-
-        # Sample the next tokens.
-        sample_results = _sample(probs, logprobs, sampling_metadata)
-        # Get the logprobs query results.
-        prompt_logprobs, sample_logprobs = _get_logprobs(
-            logprobs, sampling_metadata, sample_results)
-        return _build_sampler_output(sample_results, sampling_metadata,
-                                     prompt_logprobs, sample_logprobs)
-
-
class JAISAttention(nn.Module):

def __init__(
@@ -381,7 +288,9 @@ def __init__(
else:
self.output_logits_scale = (config.mup_output_alpha *
config.mup_width_scale)
-        self.sampler = JAISSampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
+                                                scale=self.output_logits_scale)
+        self.sampler = Sampler()

def forward(
self,
@@ -394,17 +303,18 @@ def forward(
input_metadata)
return hidden_states

+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
def sample(
self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(
-            self.lm_head_weight,
-            hidden_states,
-            sampling_metadata,
-            self.output_logits_scale,
-        )
+        next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(
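
Net effect on jais.py: the custom JAISSampler, which re-implemented roughly a hundred lines of the stock sampler just to multiply the logits by output_logits_scale, is replaced by the standard Sampler plus a LogitsProcessor constructed with scale=self.output_logits_scale, and the scaling now happens in the new compute_logits hook. A small self-contained sketch of that interface follows; ToyJAISHead, its sizes and scale factors, and the greedy sampler stand-in are illustrative assumptions, not vLLM internals.

import torch


class ToyJAISHead:
    # Toy mirror of the new interface: compute_logits() applies the LM head
    # and the muP-style output scale; sample() consumes only the logits.
    def __init__(self, hidden_size: int, vocab_size: int,
                 mup_output_alpha: float, mup_width_scale: float) -> None:
        self.lm_head_weight = torch.randn(vocab_size, hidden_size)
        self.output_logits_scale = mup_output_alpha * mup_width_scale

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = hidden_states @ self.lm_head_weight.t()
        return logits * self.output_logits_scale

    def sample(self, logits: torch.Tensor) -> torch.Tensor:
        # Greedy stand-in for the stock Sampler.
        return torch.argmax(logits, dim=-1)


# Arbitrary toy sizes and scale factors, chosen only for the demonstration.
model = ToyJAISHead(hidden_size=8, vocab_size=100,
                    mup_output_alpha=2.0, mup_width_scale=0.125)
hidden = torch.randn(3, 8)
tokens = model.sample(model.compute_logits(hidden))
print(tokens)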
