Enable Logprobs in MLC Batch Serving #82
Conversation
Force-pushed from 56f5b41 to 0994bd8.
The OpenAI API doesn't specify how logprobs should be calculated and returned. I think it's better to wait until their new logprob API is released. Besides, the logprob-related logic in vllm looks very complicated and is scattered across their codebase. Does this change implement the same logic?
I also find vllm's implementation quite complicated, so I only referred to their example as linked in the PR description. Given that OpenAI's API is not released yet, I don't have a strong opinion on whether to integrate it right now. It's more of a use-case-based decision, so I would like to see how other folks think 👀.
I think we can merge tentative logprob support now as long as
Although the OpenAI spec does not expose logprobs yet, in my understanding we need this for @vvchernov's accuracy testing work. Let's get his feedback and coordinate with his PR #69.
@zxybazh, seems like you might need to rebase. See the conflicts.
Force-pushed from 3be4501 to 5957ae8.
Hello @sunggg and @zxybazh! Thanks for pinging me! Good job! I think I need to spend some more time on it since I'm not so familiar with this code. I've added some suggestions to fix. But now I want to ask about and discuss some things related to logprobs and why we need them, to clarify the understanding on both sides.
But before that: I noticed in your example that the first top token is an empty string (""). I saw the same thing when testing logprobs on the mlc-llm side and avoided it by removing all processing related to the system prompt. I'm not sure this is correct behavior. Where did you get the example? Is it an actual output measured after the feature was added?
Logprobs are a powerful LLM tool and are used in many scenarios:
Due to this I'm not sure that "logprobs will not be frequently requested" in the near future. A brief summary of the loglikelihood approach: a request consists of a context string and a continuation string. The context is a question or the beginning of a sentence. Usually there are four continuations (answers to the question or continuations of the context), one of which is assumed to be correct. We force the model to continue the context along the continuation tokens (rather than generate its own continuation) and sum their logprobs (the sum is always negative). Then the maximum of the four sums is found, and if it belongs to the correct continuation, the model's answer is counted as correct.
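A minimal sketch of this scoring scheme, independent of any particular serving engine (the per-token logprob lists are assumed to come from whatever the server returns; names and values here are illustrative only):

```python
from typing import List

def continuation_logprob(token_logprobs: List[float], ctx_len: int) -> float:
    """Sum only the per-token logprobs that belong to the continuation."""
    return sum(token_logprobs[ctx_len:])

def pick_answer(all_token_logprobs: List[List[float]], ctx_lens: List[int]) -> int:
    """Return the index of the continuation with the highest total logprob."""
    scores = [
        continuation_logprob(lps, n) for lps, n in zip(all_token_logprobs, ctx_lens)
    ]
    return max(range(len(scores)), key=scores.__getitem__)

# Example: four candidate continuations, each with one context token; the third wins.
scores_example = [
    [-0.1, -2.3, -0.5],
    [-0.1, -4.0, -3.1],
    [-0.1, -0.2, -0.4],
    [-0.1, -5.0, -2.2],
]
print(pick_answer(scores_example, ctx_lens=[1, 1, 1, 1]))  # -> 2
```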
@masahi @zxybazh @sunggg any thoughts? cc @binarybana
Force-pushed from d62d756 to e232862.
Thanks @vvchernov for the detailed response! I agree logprobs could be very useful. From an implementation perspective, I think it is possible to implement less intrusive logprob generation, since sometimes it is not required. The challenging part to me is when there are different logprob requests in the same batch (some without logprobs, some with different logprob top-k numbers); it would be a bit hard to handle without losing performance. In that case we may need to separate the requests and build two topologies, as you said. For your first comment on the "empty string", I'm not sure yet how this was generated, but it likely contains some escape characters that the current decoder cannot directly decode, which makes it look like an empty string. The example is generated by running a
Logprobs generated here are per token during the decode step; the definition of logprob could vary since OpenAI doesn't have a clear rule for it right now. We can try generating logprobs during the prefill step if that's what we need.
There are many sampling parameters that control when model generation should stop; the most frequently used one is length, a.k.a. the max token number. For now the input is a prompt, just like our regular completion use case.
Right now I'm producing logprobs for decode steps only. Could you share an example where the new token is not included in the top-5 logprobs? If that's the case, we can still output the logprob of the selected (specified) token.
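One way to handle a mixed batch without paying for logprob bookkeeping everywhere is to partition requests by their logprob setting before sampling. A rough sketch, with an illustrative `Request` type that stands in for the engine's actual request structures:

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Request:
    request_id: str
    logprobs: bool = False
    top_logprobs: int = 0  # how many top alternatives to report, if any

def partition_by_logprobs(batch: List[Request]) -> Dict[str, List[Request]]:
    """Split a batch so logprob computation only runs where it is needed."""
    groups: Dict[str, List[Request]] = {"plain": [], "logprobs": []}
    for req in batch:
        groups["logprobs" if req.logprobs else "plain"].append(req)
    return groups
```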
Hello @zxybazh! Thank you for the quick response!
My general point was not that logprobs are a valuable tool (of course they are), but that they may become a basic need for a client. Therefore I pinged @binarybana to discuss it with us and share his point of view. In either case, separating requests with/without logprobs and using two topologies for processing looks like a solution.
A few more details on my implementation and testing of loglikelihood calculation on the mlc-llm side. I compared results between the original HF llama2-7b and llama-7b-chat-hf-q0f16 from mlc-llm on specific samples. A fairly large gap was observed consistently. There were two reasons: 1. The system prompt. Even when I set an empty prompt there was still some gap: some tokens (like [INST], [/INST]) were still added in front. The gap disappeared when I commented out the SystemPromptProcessing method (I used the mlc-chat pipeline for initial tests). 2. I observed that tokenizer.decode(context + continuation) != tokenizer.decode(context) + tokenizer.decode(continuation). I always observed an incorrect first token for the continuation, which is an empty string in the encoded state. Once I cut it, the gap became minimal.
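For illustration, a small check of this tokenizer non-additivity at the context/continuation boundary (using Hugging Face `transformers`; the model name is just an example, any SentencePiece-style tokenizer can show the effect):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

context, continuation = "The capital of France is", " Paris"

joint = tok.encode(context + continuation, add_special_tokens=False)
split = tok.encode(context, add_special_tokens=False) + tok.encode(
    continuation, add_special_tokens=False
)

# The two token sequences may differ at the boundary, so the first
# continuation token has to be handled carefully when slicing logprobs.
print(joint == split)
```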
For now, that is all we need. The decode step is not needed for loglikelihood calculation.
Hi, guys. Thank you for the fruitful discussion!
What are the differences among these options? My general take is: why not follow the OpenAI spec, given that it is a standard? OpenAI is working on bringing
If we only need it for prefill, I think the performance penalty might be marginal. @vvchernov, it seems like prefill suffices for your current need, but do you think you may need
Also, I believe we can have a more efficient discussion if we have performance data. So, for the three options below, can we clarify what we would need eventually, so that @zxybazh can benchmark the performance implications?
Hello @sunggg!
Hello @masahi @zxybazh @sunggg! After offline discussion and analyzing the OpenAI-style implementation, I've prepared a summary of the options I see for implementing loglikelihood support here.
What is the problem with the current PR? If the context string is used as input, then at each decode step: (1) the logprob of the corresponding continuation token should be output, and it may not be among the top-5 logprobs; (2) the next token should be the corresponding continuation token, not the one the LLM "thinks" it should be; and (3) the process should stop when the continuation tokens run out. I looked at how lm-evaluation-harness used the deprecated OpenAI API for loglikelihood task evaluation:

```python
response = oa_completion(
    engine=self.engine,
    prompt=inps,
    echo=True,
    max_tokens=0,
    temperature=0.0,
    logprobs=10,
)
```

One more note: the completion endpoint is used, so the cache is reset on each request.

```python
logprobs = response["logprobs"]["token_logprobs"]
continuation_logprobs = sum(logprobs[ctxlen:])
```

So the main point of this text: the prefill step should be used for the loglikelihood approach. The weak spot here is that all logprobs have to be output. I see two options to do this: (1) two topologies of the same model are used. One returns only the last set of logits, without a performance penalty (the current state). The other returns all logits, which are processed on the CPU as needed. But Jason does not prefer this scenario. (2) One topology is used. We use conditions and a key to calculate all logits only when needed. I expect a slight performance penalty due to the logits projection before output (e.g. for llama2), and, if I'm not mistaken, the GPU calculates both branches of the condition.

One more thing: if we plan to support the speculative decoding scenario, we need to do the same, namely get the set of last logprobs for a given continuation.

@zxybazh, my plan now is to start from your branch and add some fixes plus support for all logits at the prefill step. I will push my changes to your branch, or we can decide how to move them here.

P.S. Jason told me that @tqchen knows how to calculate perplexity without a performance penalty. Tianqi, if your understanding differs from what I explained here, please share it with us. cc @binarybana
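A minimal sketch of what this echo-style, prefill-only scoring amounts to, written directly against PyTorch rather than the serving engine (function and argument names are illustrative, not the PR's API):

```python
import torch
import torch.nn.functional as F

def loglikelihood_from_prefill(
    logits: torch.Tensor,     # [seq_len, vocab] logits for the full context+continuation
    token_ids: torch.Tensor,  # [seq_len] token ids of context followed by continuation
    ctx_len: int,             # number of context tokens at the front
) -> float:
    """Score a fixed continuation from a single prefill pass (echo-style)."""
    # Logits at position t predict token t+1, so shift targets by one.
    log_probs = F.log_softmax(logits[:-1].float(), dim=-1)
    targets = token_ids[1:]
    per_token = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    # per_token[ctx_len - 1] is the logprob of the first continuation token.
    return per_token[ctx_len - 1 :].sum().item()
```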
The new OpenAI API for logprobs is out.
Hello @masahi @zxybazh @sunggg! I've upstreamed the new OpenAI API format in this PR, which updates the current branch. I suggest you review it, merge it into @zxybazh's branch, and then merge this PR. cc @binarybana
Thanks @vvchernov for the quick update after OpenAI's new API came out. I've merged his PR into mine and will rebase to sync with the current branch.
```python
top_greedy_logprob, top_greedy = torch.topk(logprobs, k=5, dim=-1, largest=True, sorted=True)
# Convert to numpy
res_greedy_logprob = res_greedy_logprob.cpu().numpy()
```
This looks like a performance bottleneck. Could we keep working with torch tensors, or do the conversion asynchronously? Another scenario suggested by Masa is to calculate logprobs only when needed.
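A possible shape of that "only when needed" path, keeping everything on the GPU and skipping the work entirely when no request asked for logprobs (an illustrative sketch, not the PR's actual code):

```python
from typing import Optional, Tuple
import torch

def maybe_top_logprobs(
    logits: torch.Tensor,  # [batch, vocab]
    need_logprobs: bool,
    k: int = 5,
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    """Compute top-k logprobs only when some request in the batch wants them."""
    if not need_logprobs:
        return None  # no extra softmax/topk/copy on the hot path
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    top_logprob, top_token = torch.topk(logprobs, k=k, dim=-1, largest=True, sorted=True)
    # Stay on the GPU; defer any .cpu()/.numpy() conversion until results are serialized.
    return top_logprob, top_token
```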
Force-pushed from b64db01 to 9b053e8.
I think the PR description should be updated, since the logprob response format was changed.
Great write-up, @zxybazh! To me, there is no obvious reason not to try
Fix performance reduction
Hello guys! The latest benchmark measurements show a very small performance reduction for the case when logprobs are not requested (-0.25%), which is within the run-to-run deviation range. Requests with logprobs are significantly slower. To avoid mixing requests with and without logprobs, I think the logical next step is to separate them (e.g. create a LogprobRequest); that would also fully resolve the issue for the case when logprobs are not requested.
OK, merging now, but please follow up with my latest comments.
```python
for info in logprob_infos:
    if info is not None:
        check = True
        break
```
Just return `logprob_infos` here. No need for the `check` variable.
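For reference, a sketch of that simplified helper (the `RawLogprobsInfo` alias here is only a stand-in for the PR's actual type):

```python
from typing import List, Optional

RawLogprobsInfo = dict  # stand-in for the PR's dataclass
RawLogprobsInfos = List[Optional[RawLogprobsInfo]]

def check_logprob_infos(
    logprob_infos: RawLogprobsInfos,
) -> Optional[RawLogprobsInfos]:
    """Collapse a list that contains no real info into None, else return it as-is."""
    for info in logprob_infos:
        if info is not None:
            return logprob_infos
    return None
```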
Actually you don't need this function at all. You can fuse `get_raw_logprob_infos` and `check_logprob_infos` by returning `None` from the former after doing the same check done by the latter.
First of all, outputting `logprob_infos` as `None` instead of a list of `None`s allowed me to decrease the performance reduction from 2% to 0.25% for the case when logprobs are not requested. It was the last mile in this task, so I do need such a check. A bit more detail: it looks like pydantic and dataclass are noticeably slower than standard classes like dict for init and filling (of course, still on the order of 100 ns), and that starts to matter for very intensive token generation with 7B models.
Unfortunately, I cannot simply embed `check_logprob_infos` into `get_raw_logprob_infos` (it was a good idea, of course), because the latter iterates not over the full list but only over the part corresponding to the greedy or random indices. Thus it does not work for the mixed case when random and greedy requests are processed together, as benchmark_throughput does by default.
```python
) -> Optional[RawLogprobsInfos]:
    if logprob_infos is None or logprob_infos[i] is None:
        return None
    return [logprob_infos[i]]
```
No need for this to be a method of this class. Move it to model_common.py.
Sorry, but I also cannot do that. Earlier it was done that way, but it led to a performance reduction. In the common case `logprob_infos` is one of two things: 1. `None`, the case when logprobs are not requested; this trick avoids the performance reduction for that case. 2. A list of `None` or `logprob_info` entries.
When we create `TextGenerationResult`, we need only one element from the list. That means we have to check that it is a list (not `None`) and extract the element, and this cannot be transferred to the model_common.py side.
My doubt is that we could potentially expect a single `RawLogprobsInfo` element in `TextGenerationResult` instead of a list with one element, which would slightly simplify the code. But I've seen that `generated_tokens` is a list, with some comments about speculative decoding, so I've decided that the logprobs info should correspond to each token in that list (i.e. it should also be a list).
I see what you mean. This function doesn't touch `self` at all, but its implementation is valid only when used by this class. So it cannot be moved to model_common.py, which is supposed to be a collection of general utilities.
But soon I'm adding a PyTorch-based implementation of the model, in a file pt_model.py. I don't want to repeat this class there. Assuming that `get_logprob_infos` is called only by the TVM or the PT model, I think it is OK to put it inside model_common.py.
```python
    logits[ind],
    token_ids[ind],
    top_logprobs,
)
```
Just to double check if this is correct. It seems like `i` and `ind` are switched to me:

```python
logprob_infos[ind] = get_raw_logprob_info(
    logits[i],
    token_ids[i],
    top_logprobs,
```
I've double-checked. It is correct; moreover, it would have failed with an index-out-of-bounds error in the benchmark test if it were not correct.
This PR enables the logprobs option in the MLC server, following vLLM's example and OpenAI's API.
Example query:
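(The original example was not captured in this copy; below is a hypothetical request sketch. The endpoint path, port, model name, and exact parameter names are placeholders following OpenAI's logprobs format, not necessarily the server's actual ones.)

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # hypothetical local MLC serve endpoint
    json={
        "model": "llama-2-7b-chat-hf-q0f16",       # placeholder model name
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 16,
        "logprobs": True,    # request per-token logprobs (OpenAI-style)
        "top_logprobs": 5,   # also return the 5 most likely alternatives per token
    },
)
print(resp.json())
```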
Example response:
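(Likewise, an abridged sketch of the logprobs payload inside one choice; the field names follow OpenAI's chat-completion logprobs format and the values are made up for illustration.)

```python
# Sketch of choices[0]["logprobs"] in the OpenAI-style response:
choice_logprobs = {
    "content": [
        {
            "token": "Hello",
            "logprob": -0.31,
            "top_logprobs": [
                {"token": "Hello", "logprob": -0.31},
                {"token": "Hi", "logprob": -1.42},
            ],
        },
        # ... one entry per generated token
    ]
}
```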
Ready for review. CC @sunggg @masahi