Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push logprob generation to LLMEngine #3065

Merged
merged 9 commits into from
Mar 4, 2024

Conversation

Yard1
Copy link
Collaborator

@Yard1 Yard1 commented Feb 27, 2024

This PR moves the logprob detokenization logic away from the OpenAI server to the LLMEngine, allowing for consistent output between the two. It also is a first step towards making the OpenAI server more lightweight by pushing down some of its responsibilities.

It also ensures the logprob tokens are detokenized with the previous tokens in mind (same as generated tokens), which will make them more accurate.

Copy link
Collaborator

@esmeetu esmeetu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

vllm/sequence.py Outdated
and self.logprobs == other.logprobs)
equal = (self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token)
log_probs_equal = ((len(other.logprobs) == len(self.logprobs))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it better to move this to Logprobs's __eq__?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good call!

@AguirreNicolas
Copy link
Contributor

I pointed out an issue w.r.t to LogProbs in a previous PR:

Please refers to the source of the ChatCompletionResponseChoice (class Choice source) in OpenAI.

There is is a new class called ChoiceLogprobs (source) that is not the same as LogProbs.

IMO, first it should be implemented both ChatCompletionTokenLogprob and TopLogprob (source)

API reference about the structure: https://platform.openai.com/docs/api-reference/chat/object

Maybe this PR is a good point to merge/attend this into vLLM ?

@Yard1
Copy link
Collaborator Author

Yard1 commented Feb 28, 2024

@AguirreNicolas IIUC, considering that change would need to be mainly implemented in OpenAI server, I think it should be independent of this PR.

@Yard1 Yard1 requested a review from esmeetu February 29, 2024 01:54
@Yard1
Copy link
Collaborator Author

Yard1 commented Feb 29, 2024

@esmeetu I had to add some extra logic, ptal again

Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Yard1! I'd realized that something like this was needed while making changes to use a threadpool for tokenization (per #2879 (comment)). I'll wait until this is merged before opening the PR for that.

Comment on lines +106 to +108
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does the error need to be the first response? This would also delay the first responses until after the first token is generated (which could include any time queuing I think)?

Copy link
Collaborator Author

@Yard1 Yard1 Mar 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember the openai client package was not able to handle errors unless they were the first thing that came out of the endpoint. I think the current version may be more robust, though. Will see if the test can still pass with the previous layout.

I think the slight delay in response is fine, it will not affect e2e time

for token_id, sample_logprob in logprobs.items():
if (sample_logprob.decoded_token is None and token_id != -1):
all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
_, new_text, prefix_offset, read_offset = detokenize_incrementally(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without knowing the OpenAI behaviour, IMHO it would be more appropriate here to use convert_ids_to_tokens and include the explicit/atomic token strings. Otherwise the text may not line up with the token.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the OpenAI behavior is exactly that - the logprob token text depends on the previous tokens and is not constant.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Yard1 ok, thanks! I'll take a closer look at this.

Copy link
Collaborator

@esmeetu esmeetu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Yard1 LGTM! Could you merge the latest branch and pass the CI?

@@ -30,6 +30,7 @@ class EngineArgs:
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_paddings: int = 256
max_log_probs: int = 5
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is max_logprobs better than this?
And we can add comments for why default value is 5 (from OpenAI API Reference?).

@Yard1 Yard1 enabled auto-merge (squash) March 4, 2024 19:07
@Yard1 Yard1 merged commit 22de452 into vllm-project:main Mar 4, 2024
22 checks passed
@njhill
Copy link
Member

njhill commented Mar 7, 2024

@Yard1 I realized that this is an API-breaking change for anyone consuming logprobs via the engine API (it actually broke our integration). I'm not sure what the project stance on this is w.r.t. semantic versioning but at minimum I guess it should be highlighted in the 0.3.4 release notes.

@pcmoritz
Copy link
Collaborator

pcmoritz commented Mar 8, 2024

@Yard1 This also breaks https://github.com/EleutherAI/lm-evaluation-harness -- we should either fix the harness or roll back the API change :)

  File "/home/ray/anaconda3/bin/lm_eval", line 8, in <module>
    sys.exit(cli_evaluate())
  File "/home/ray/default/lm-evaluation-harness/lm_eval/__main__.py", line 318, in cli_evaluate
    results = evaluator.simple_evaluate(
  File "/home/ray/default/lm-evaluation-harness/lm_eval/utils.py", line 288, in _wrapper
    return fn(*args, **kwargs)
  File "/home/ray/default/lm-evaluation-harness/lm_eval/evaluator.py", line 230, in simple_evaluate
    results = evaluate(
  File "/home/ray/default/lm-evaluation-harness/lm_eval/utils.py", line 288, in _wrapper
    return fn(*args, **kwargs)
  File "/home/ray/default/lm-evaluation-harness/lm_eval/evaluator.py", line 368, in evaluate
    resps = getattr(lm, reqtype)(cloned_reqs)
  File "/home/ray/default/lm-evaluation-harness/lm_eval/api/model.py", line 321, in loglikelihood
    return self._loglikelihood_tokens(new_reqs)
  File "/home/ray/default/lm-evaluation-harness/lm_eval/models/vllm_causallms.py", line 379, in _loglikelihood_tokens
    answer = self._parse_logprobs(
  File "/home/ray/default/lm-evaluation-harness/lm_eval/models/vllm_causallms.py", line 416, in _parse_logprobs
    continuation_logprobs = sum(
TypeError: unsupported operand type(s) for +: 'int' and 'Logprob'

@Yard1
Copy link
Collaborator Author

Yard1 commented Mar 8, 2024

We should fix the harness.

@pcmoritz
Copy link
Collaborator

pcmoritz commented Mar 8, 2024

Sounds good to me -- can you make a PR for it?

Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants