hide logprobs calculation by condition, add dataclass for raw logprobs info, fix logprobs_random
Valery Chernov committed Jan 8, 2024
1 parent 4c67afb commit 294815b
Showing 5 changed files with 95 additions and 68 deletions.
8 changes: 5 additions & 3 deletions serve/mlc_serve/api/protocol.py
@@ -78,18 +78,20 @@ class ChatCompletionRequest(BaseModel):
class TopLogprobs(BaseModel):
"""An OpenAI API compatible schema for logprobs output."""

token: str
# token is string in OpenAI, but for unification int type is added
token: Union[str, int]
logprob: float
bytes: Optional[List] = None


class LogprobsContent(BaseModel):
"""An OpenAI API compatible schema for logprobs output."""

token: str
# token is string in OpenAI, but for unification int type is added
token: Union[str, int]
logprob: float
bytes: Optional[List] = None
top_logprobs: List[TopLogprobs]
top_logprobs: List[TopLogprobs] # It can be empty


class Logprobs(BaseModel):
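For orientation, here is a minimal sketch of how the two schemas above can be populated after this change. It is illustrative only: the token values and logprobs are invented, and the import assumes the serve/ directory is on PYTHONPATH so that mlc_serve.api.protocol is importable.

from mlc_serve.api.protocol import LogprobsContent, TopLogprobs

# A detokenized entry, as the OpenAI-compatible API returns it.
entry_str = LogprobsContent(
    token="Hello",
    logprob=-0.11,
    top_logprobs=[
        TopLogprobs(token="Hello", logprob=-0.11),
        TopLogprobs(token="Hi", logprob=-2.35),
    ],
)

# The "unification" case from the comment above: an integer token id is also
# accepted, and top_logprobs may be left empty when none were requested.
entry_int = LogprobsContent(token=15043, logprob=-0.11, top_logprobs=[])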
1 change: 1 addition & 0 deletions serve/mlc_serve/engine/__init__.py
@@ -17,5 +17,6 @@
PROMPT_SEQEUNCE_INDEX,
get_prompt_sequence_id,
LOGPROBS_TYPE,
RawLogprobsInfo,
)
from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX
9 changes: 9 additions & 0 deletions serve/mlc_serve/engine/base.py
@@ -6,6 +6,7 @@

from typing import List, Callable, Any, Optional, Dict, Tuple
import inspect
import numpy as np

from .sampling_params import SamplingParams, SamplingType

@@ -14,6 +15,14 @@
# ((token, logprob), [(top1_token, top1_logprob), ...])


@dataclass
class RawLogprobsInfo:
current_token: np.array
current_logprob: np.array
top_tokens: Optional[np.array]
top_logprobs: Optional[np.array]


# TODO(@sunggg): consider transition to something like Pydantic
@dataclass
class MLCServeEngineConfig:
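A rough sketch of what a single RawLogprobsInfo record could hold for one sequence. The numbers are made up and the log-softmax/top-k is done in plain NumPy purely for illustration; in the engine these arrays come from torch tensors in paged_cache_model.py, and the two top_* fields stay None when top_logprobs was not requested.

import numpy as np

from mlc_serve.engine import RawLogprobsInfo  # re-exported by this commit

# Toy logits for one sequence over a 5-token vocabulary.
logits = np.array([1.0, 3.0, 0.5, 2.0, -1.0])
logprobs = logits - np.log(np.exp(logits).sum())  # log-softmax

token = int(np.argmax(logprobs))             # greedy pick
top_tokens = np.argsort(logprobs)[::-1][:2]  # top-2 token ids
top_logprobs = logprobs[top_tokens]

info = RawLogprobsInfo(
    current_token=np.array(token),
    current_logprob=logprobs[token],
    top_tokens=top_tokens,
    top_logprobs=top_logprobs,
)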
4 changes: 2 additions & 2 deletions serve/mlc_serve/engine/model_module.py
@@ -6,8 +6,8 @@

from .base import (
ChatMessage,
LOGPROBS_TYPE,
MLCServeEngineConfig,
RawLogprobsInfo,
RequestId,
RequestState,
SequenceId,
@@ -51,7 +51,7 @@ class TextGenerationResult:
# making this a list of token ids to leave room for speculative decoding
generated_tokens: List[int]
error: Optional[str]
logprob_info: Optional[LOGPROBS_TYPE] = None
logprob_info: Optional[RawLogprobsInfo] = None


class KVCache(Protocol):
141 changes: 78 additions & 63 deletions serve/mlc_serve/model/paged_cache_model.py
@@ -3,7 +3,6 @@
from pathlib import Path
from collections import defaultdict
from typing import List, Union, Optional, Tuple
from dataclasses import dataclass

import structlog
import numpy as np
@@ -20,9 +19,9 @@
MLCServeEngineConfig,
SamplingParams,
LOGPROB_TOP_K_MAX,
LOGPROBS_TYPE,
SequenceId,
PROMPT_SEQEUNCE_INDEX,
RawLogprobsInfo,
get_prompt_sequence_id,
)
from ..engine.model_module import (
@@ -35,6 +34,49 @@
LOG = structlog.stdlib.get_logger(__name__)


def fetch_raw_logprob_infos(
logits,
res_tokens,
sampling_params,
) -> List[Optional[RawLogprobsInfo]]:
logprob_infos: List[Optional[RawLogprobsInfo]] = []
num_seq = logits.shape[0]
for index in range(num_seq):
if (
sampling_params[index].logprobs is None or
not sampling_params[index].logprobs
):
# Logprob sampling
logprobs = torch.log_softmax(logits[index], dim=-1)
res_logprob = logprobs[res_tokens[index]].cpu().numpy()

top_logprobs_num = sampling_params[index].top_logprobs
if (
top_logprobs_num is None or
top_logprobs_num == 0
):
top_logprobs = None
top_tokens = None
else:
assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs"
top_logprobs, top_tokens = torch.topk(
logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True
)
top_tokens=top_tokens.cpu().numpy(),
top_logprobs=top_logprobs.cpu().numpy()

# Set to raw logprob info
logprob_infos.append(RawLogprobsInfo(
current_token=res_tokens[index].cpu().numpy(),
current_logprob=res_logprob,
top_tokens=top_tokens,
top_logprobs=top_logprobs,
))
else:
logprob_infos.append(None)
return logprob_infos


def _apply_top_p_top_k(logits, top_ps, top_ks):
p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
@@ -63,7 +105,7 @@ def sample(
sampling_params: List[SamplingParams],
vocab_size: int,
check_safety=False,
) -> Optional[Tuple[np.ndarray, Optional[Tuple[Tuple, Tuple]]]]:
) -> Optional[Tuple[np.ndarray, List[Optional[RawLogprobsInfo]]]]:
def _is_safe_to_sample(prob_like):
return (
torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0))
@@ -82,21 +124,18 @@ def _is_safe_to_sample(prob_like):
logits_greedy = logits[mask_greedy]

if logits_greedy.shape[0] > 0:
# Greedy sampling
logprobs = torch.log_softmax(logits_greedy, dim=-1)
res_greedy_logprob, res_greedy = torch.max(logprobs, dim=-1)

top_greedy_logprob, top_greedy = torch.topk(
logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
res_greedy = torch.argmax(logits_greedy, -1)

logprob_infos_greedy = fetch_raw_logprob_infos(
logits_greedy,
res_greedy,
sampling_params[mask_greedy],
)
# Convert to numpy
res_greedy_logprob = res_greedy_logprob.cpu().numpy()

res_greedy = res_greedy.cpu().numpy()
top_greedy_logprob = top_greedy_logprob.cpu().numpy()
top_greedy = top_greedy.cpu().numpy()
# Case when there's only greedy sampling
if logits_greedy.shape[0] == num_seq:
return res_greedy, ((res_greedy, res_greedy_logprob), (top_greedy, top_greedy_logprob))
return res_greedy, logprob_infos_greedy

temperatures = []
top_ps = []
@@ -131,7 +170,6 @@ def _is_safe_to_sample(prob_like):

if param.logit_bias:
logits[i][param.logit_bias_index] += torch.Tensor(param.logit_bias_value).type_as(logits).to(device=logits.device)


logits_random = logits[mask_random]

@@ -140,43 +178,40 @@ def _is_safe_to_sample(prob_like):
logits_random.div_(t.unsqueeze(dim=1))

if do_top_p or do_top_k:
# TODO(vvchernov): looks like there is misprinting. Should logits_random be returned?
# If no, where are logits used below?
logits = _apply_top_p_top_k(logits_random, top_ps, top_ks)

probs = torch.softmax(logits_random, dim=-1)
logprobs = torch.log_softmax(logits_greedy, dim=-1)
top_random_logprob, top_random = torch.topk(
logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
)
top_random_logprob = top_random_logprob.cpu().numpy()
top_random = top_random.cpu().numpy()

if check_safety and not _is_safe_to_sample(probs):
return None

res_random = torch.multinomial(probs, 1, True).cpu().numpy()[:, 0]
res_random_logprobs = torch.gather(logprobs, dim=-1, index=torch.tensor(res_random, dtype=torch.int64, device=logits.device)).cpu().numpy()
res_random = torch.multinomial(probs, 1, True)[:, 0]

logprob_infos_random = fetch_raw_logprob_infos(
logits_random,
res_random,
sampling_params[mask_random],
)

res_random = res_random.cpu().numpy()
# Case when there's only random sampling
if logits_random.shape[0] == num_seq:
return res_random, ((res_random, res_random_logprobs), (top_random, top_random_logprob))
return res_random, logprob_infos_random

res = np.empty((num_seq,), dtype=np.int32)
res_logprobs = np.empty((num_seq,), dtype=np.float32)
top = np.empty((num_seq, LOGPROB_TOP_K_MAX), dtype=np.int32)
top_logprobs = np.empty((num_seq, LOGPROB_TOP_K_MAX), dtype=np.float32)

res[mask_random] = res_random
res_logprobs[mask_random] = res_random_logprobs
top[mask_random] = top_random
top_logprobs[mask_random] = top_random_logprob

logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq
logprob_infos[mask_random] = logprob_infos_random

if logits_greedy.shape[0] > 0:
res[mask_greedy] = res_greedy
res_logprobs[mask_greedy] = res_greedy_logprob
top[mask_greedy] = top_greedy
top_logprobs[mask_greedy] = top_greedy_logprob

return res, ((res, res_logprobs), (top, top_logprobs))
logprob_infos[mask_greedy] = logprob_infos_greedy

return res, logprob_infos


def load_disco_module(artifact_path, lib_path, num_shards):
@@ -228,26 +263,6 @@ def get_tvm_model(config, dev):
return load_disco_module(config.model_artifact_path, lib_path, config.num_shards)


def fetch_logprobs(
logprob_info: Optional[Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]],
index: int,
sampling_param: SamplingParams,
) -> Optional[LOGPROBS_TYPE]: # np.ndarray inside
"""Fetch the logprob information with index"""
if (
sampling_param.logprobs is None or
not sampling_param.logprobs or
logprob_info is None
):
return None
(res, res_logprobs), (top, top_logprobs) = logprob_info
return (res[index],res_logprobs[index]), \
list(zip(
top[index][:sampling_param.top_logprobs],
top_logprobs[index][:sampling_param.top_logprobs]
))


def _prepare_inputs(
sequence_ids,
all_token_ids,
@@ -528,7 +543,7 @@ def generate(
cache.pending_copy_from_to = []

try:
next_tokens, logprob_info = sample(logits, sampling_params, self.vocab_size)
next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size)
assert next_tokens is not None
outputs = []
for i, (sequence_id, new_token) in enumerate(
@@ -544,7 +559,7 @@
sequence_id=SequenceId(sequence_id.request_id, seq_id),
generated_tokens=[new_token],
error=None,
logprob_info=fetch_logprobs(logprob_info, seq_id, sampling_params[seq_id]),
logprob_info=logprob_infos[i],
)
)
else:
@@ -553,7 +568,7 @@
sequence_id=sequence_id,
generated_tokens=[new_token],
error=None,
logprob_info=fetch_logprobs(logprob_info, i, sampling_params[i]),
logprob_info=logprob_infos[i],
)
)

@@ -569,7 +584,7 @@
for i, (sequence_id, logits_per_token, sampling_param) in enumerate(
zip(sequence_ids, torch.from_dlpack(logits), sampling_params)
):
maybe_new_token, logprob_info = sample(
maybe_new_token, logprob_infos = sample(
torch.unsqueeze(logits_per_token, 0),
[sampling_param],
self.vocab_size,
@@ -590,7 +605,7 @@
),
generated_tokens=[new_token], # type: ignore
error=None,
logprob_info=fetch_logprobs(logprob_info, i, sampling_param)
logprob_info=logprob_infos[i]
)
)
else:
@@ -599,7 +614,7 @@
sequence_id=sequence_id,
generated_tokens=[new_token], # type: ignore
error=None,
logprob_info=fetch_logprobs(logprob_info, i, sampling_param)
logprob_info=logprob_infos[i]
)
)
else:
@@ -612,7 +627,7 @@
),
generated_tokens=[],
error=err_msg,
logprob_info=fetch_logprobs(logprob_info, i, sampling_param)
logprob_info=logprob_infos[i]
)
)
else:
@@ -621,7 +636,7 @@
sequence_id=sequence_id,
generated_tokens=[],
error=err_msg,
logprob_info=fetch_logprobs(logprob_info, i, sampling_param)
logprob_info=logprob_infos[i]
)
)

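Since TextGenerationResult.logprob_info is now an Optional[RawLogprobsInfo] instead of the old nested-tuple LOGPROBS_TYPE, downstream code reads fields rather than unpacking tuples. The helper below is a hedged sketch of that conversion into the OpenAI-style schema, not code from this repository; it keeps token ids as ints (the "unification" case in protocol.py), where a real server would detokenize them to strings first.

from typing import List, Optional

from mlc_serve.api.protocol import LogprobsContent, TopLogprobs
from mlc_serve.engine import RawLogprobsInfo


def to_logprobs_content(info: Optional[RawLogprobsInfo]) -> Optional[LogprobsContent]:
    # Logprobs were not requested for this sequence.
    if info is None:
        return None

    top: List[TopLogprobs] = []
    if info.top_tokens is not None and info.top_logprobs is not None:
        top = [
            TopLogprobs(token=int(t), logprob=float(lp))
            for t, lp in zip(info.top_tokens, info.top_logprobs)
        ]

    return LogprobsContent(
        token=int(info.current_token),
        logprob=float(info.current_logprob),
        top_logprobs=top,  # may be empty
    )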
