Commit
updates after review
Valery Chernov committed Jan 5, 2024
1 parent f11b7f8 commit 4c67afb
Showing 4 changed files with 20 additions and 18 deletions.
20 changes: 11 additions & 9 deletions serve/mlc_serve/api/handler.py
@@ -241,19 +241,21 @@ async def collect_result_stream(
             finish_reasons[seq.index] = seq.finish_reason.value  # type: ignore
 
     choices = []
-    for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons)):
-        content = []
-        if logprob_infos[index] != []:
-            for logprob_info in logprob_infos[index]:
+    for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)):
+        logprobs_content = []
+        if logprob_info_seq != []:
+            for logprob_info in logprob_info_seq:
+                cur_token_logprob_info = logprob_info[0]
+                top_logprobs_info = logprob_info[1]
                 top_logprobs = [TopLogprobs(
                     token=str(token),
                     logprob=float(logprob),
                     # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
                     bytes=None,
-                ) for token, logprob in logprob_info[1]]
-                content.append(LogprobsContent(
-                    token=str(logprob_info[0][0]),
-                    logprob=float(logprob_info[0][1]),
+                ) for token, logprob in top_logprobs_info]
+                logprobs_content.append(LogprobsContent(
+                    token=str(cur_token_logprob_info[0]),
+                    logprob=float(cur_token_logprob_info[1]),
                     # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
                     bytes=None,
                     top_logprobs=top_logprobs,
@@ -262,7 +264,7 @@ async def collect_result_stream(
             index=index,
             message=ChatMessage(role="assistant", content="".join(chunks)),
             finish_reason=finish_reason,
-            logprobs=Logprobs(content=content),
+            logprobs=Logprobs(content=logprobs_content),
         )
         choices.append(choice)

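Note: the rewritten loop assumes each per-sequence entry of logprob_infos is a list of records that pair the sampled token's (token, logprob) with its top-k (token, logprob) candidates, which is exactly what logprob_info[0] and logprob_info[1] unpack above. A minimal sketch of that assumed shape (token strings and values are illustrative only; the real type is LOGPROBS_TYPE from mlc_serve.engine):

# Hypothetical example of one record from logprob_info_seq, matching how the loop indexes it.
logprob_info = (
    ("hello", -0.31),                    # logprob_info[0]: (sampled token, its logprob)
    [("hello", -0.31), ("hi", -1.52)],   # logprob_info[1]: top-k (token, logprob) pairs
)
cur_token_logprob_info = logprob_info[0]  # feeds LogprobsContent token/logprob
top_logprobs_info = logprob_info[1]       # feeds the TopLogprobs list comprehension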
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/__init__.py
@@ -18,4 +18,4 @@
     get_prompt_sequence_id,
     LOGPROBS_TYPE,
 )
-from .sampling_params import SamplingParams, SamplingType, TOP_LOGPROBS_NUMBER
+from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX
6 changes: 3 additions & 3 deletions serve/mlc_serve/engine/sampling_params.py
@@ -10,7 +10,7 @@
 from typing import Dict, Optional
 
 _SAMPLING_EPS = 1e-5
-TOP_LOGPROBS_NUMBER = 5
+LOGPROB_TOP_K_MAX = 5
 
 
 class SamplingType(IntEnum):
@@ -105,9 +105,9 @@ def _verify_args(self) -> None:
                         f"logit bias must be in [-100, 100], got {bias} for token {token}."
                     )
         if self.logprobs is not None and self.logprobs:
-            if (self.top_logprobs < 1 or self.top_logprobs > TOP_LOGPROBS_NUMBER):
+            if (self.top_logprobs < 1 or self.top_logprobs > LOGPROB_TOP_K_MAX):
                 raise ValueError(
-                    f"top_logprobs must be between 1 and {TOP_LOGPROBS_NUMBER}, got {self.top_logprobs}."
+                    f"top_logprobs must be between 1 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}."
                 )
 
     def _verify_greedy_sampling(self) -> None:
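The rename does not change the accepted range, only the constant's name. A short usage sketch of the check above, assuming SamplingParams runs _verify_args on construction and that logprobs/top_logprobs are regular fields (field names come from the hunk; everything else is illustrative):

from mlc_serve.engine import SamplingParams, LOGPROB_TOP_K_MAX

ok = SamplingParams(logprobs=True, top_logprobs=3)   # 1 <= 3 <= LOGPROB_TOP_K_MAX, accepted
try:
    SamplingParams(logprobs=True, top_logprobs=LOGPROB_TOP_K_MAX + 1)
except ValueError as err:
    print(err)  # expected: top_logprobs must be between 1 and 5, got 6.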
10 changes: 5 additions & 5 deletions serve/mlc_serve/model/paged_cache_model.py
@@ -19,7 +19,7 @@
     SamplingType,
     MLCServeEngineConfig,
     SamplingParams,
-    TOP_LOGPROBS_NUMBER,
+    LOGPROB_TOP_K_MAX,
     LOGPROBS_TYPE,
     SequenceId,
     PROMPT_SEQEUNCE_INDEX,
@@ -87,7 +87,7 @@ def _is_safe_to_sample(prob_like):
         res_greedy_logprob, res_greedy = torch.max(logprobs, dim=-1)
 
         top_greedy_logprob, top_greedy = torch.topk(
-            logprobs, k=TOP_LOGPROBS_NUMBER, dim=-1, largest=True, sorted=True
+            logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
         )
         # Convert to numpy
         res_greedy_logprob = res_greedy_logprob.cpu().numpy()
@@ -145,7 +145,7 @@ def _is_safe_to_sample(prob_like):
         probs = torch.softmax(logits_random, dim=-1)
         logprobs = torch.log_softmax(logits_greedy, dim=-1)
         top_random_logprob, top_random = torch.topk(
-            logprobs, k=TOP_LOGPROBS_NUMBER, dim=-1, largest=True, sorted=True
+            logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
         )
         top_random_logprob = top_random_logprob.cpu().numpy()
         top_random = top_random.cpu().numpy()
@@ -161,8 +161,8 @@
 
         res = np.empty((num_seq,), dtype=np.int32)
         res_logprobs = np.empty((num_seq,), dtype=np.float32)
-        top = np.empty((num_seq, TOP_LOGPROBS_NUMBER), dtype=np.int32)
-        top_logprobs = np.empty((num_seq, TOP_LOGPROBS_NUMBER), dtype=np.float32)
+        top = np.empty((num_seq, LOGPROB_TOP_K_MAX), dtype=np.int32)
+        top_logprobs = np.empty((num_seq, LOGPROB_TOP_K_MAX), dtype=np.float32)
 
         res[mask_random] = res_random
         res_logprobs[mask_random] = res_random_logprobs
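For reference, the topk pattern these hunks keep using can be reproduced standalone with plain torch; the shapes and literal k below are illustrative, whereas in the file k comes from LOGPROB_TOP_K_MAX:

import torch

logits = torch.randn(2, 32000)                  # (num_seq, vocab_size)
logprobs = torch.log_softmax(logits, dim=-1)
top_logprob, top = torch.topk(
    logprobs, k=5, dim=-1, largest=True, sorted=True
)
top_logprob = top_logprob.cpu().numpy()         # (2, 5) float32 logprobs
top = top.cpu().numpy()                         # (2, 5) int64 token ids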
