Support per-request seed #2514
Conversation
Haven't looked into the PR yet, but we can guarantee the request ID is unique.

One issue I have encountered when investigating something similar is that, due to limited precision in the forward pass, different batch sizes can actually lead to different tokens being sampled despite the seed being set. Something to keep in mind here!

Thanks @Yard1! Yes, I'm aware of that, and it's even the case for greedy sampling. But this should hopefully allow for "mostly stable" results: float16 is much better than bfloat16, and the quantized case is probably worse. Even the OpenAI docs for the param say something along the lines of "best effort" :) I am just getting back to this PR now, so will do some tests.
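For reference, a minimal usage sketch of the feature being added (assuming the `seed` field on `SamplingParams` as proposed in this PR; the model name is only an example, and per the discussion above the determinism is best-effort):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # example model, not from this PR

# Two runs with the same seed should (best-effort) sample the same tokens;
# determinism can still break across different batch compositions / dtypes.
params = SamplingParams(temperature=0.8, top_p=0.95, seed=42)

out_1 = llm.generate(["The capital of France is"], params)
out_2 = llm.generate(["The capital of France is"], params)

assert out_1[0].outputs[0].text == out_2[0].outputs[0].text
```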
Can we add a test to ensure that the sampling is deterministic for the same seed?
Yes!

Personally, I really like this feature. Looking forward to the merge!

@WoosukKwon I saw you flagged this for the imminent release ... I am working on the extra tests right now and should push them in the next hour or so.

@njhill Oh, I just dropped it from the release tracker. Sorry if I pushed it too tight. For the next release, I think we will just focus on bug fixes. Let's ship this in v0.3.2.
vllm/sequence.py (outdated)

```diff
@@ -359,6 +362,7 @@ class SequenceGroupMetadata:
     sampling_params: The sampling parameters used to generate the outputs.
     block_tables: The block tables. (Seq id -> list of physical block
         numbers)
+    state: A dict for holding internal state tied to this sequence group.
```
we should make this a dataclass with defined fields
@Yard1 I'd intentionally kept this opaque to maintain separation of concerns, since the torch generator is lower level and managed only by the model runner / sampler. If we define a dataclass here then the member type would be torch.Generator, and as far as I can see torch is decoupled from the engine layer. In this case the model runner just needs access to somewhere tied to the lifecycle of the sequence group that it can stash the corresponding Generator. But happy to change it if you're sure about this!
It's mainly about ensuring we avoid errors caused by e.g. typos in key names. Type hinting/checking is secondary.
@Yard1 have now changed it to use a dataclass
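For context, a rough sketch of what the dataclass could look like (the name and field are illustrative, not necessarily the final code in this PR):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group.

    The generator is created lazily by the model runner / sampler when the
    request specifies a seed; the engine layer never touches it directly.
    """
    generator: Optional[torch.Generator] = None
```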
If the SamplingParams object passed to LLMEngine.add_request() is mutated after it returns, it could affect the async sampling process for that request. Suggested by @Yard1 #2514 (comment)
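A sketch of the defensive copy that commit describes (the helper below is purely illustrative; in the PR the copy would live inside LLMEngine.add_request()):

```python
import copy

from vllm import SamplingParams


def copy_params_for_engine(sampling_params: SamplingParams) -> SamplingParams:
    # Take a copy so that later mutation by the caller cannot affect the
    # in-flight (possibly asynchronous) sampling for this request.
    return copy.deepcopy(sampling_params)


params = SamplingParams(temperature=0.7, seed=42)
engine_params = copy_params_for_engine(params)
params.temperature = 0.0                 # caller mutates its object later...
assert engine_params.temperature == 0.7  # ...the engine's copy is unaffected
```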
```python
vllm_outputs_seed_1_2 = vllm_model.generate(example_prompts,
                                            sampling_params_seed_1)
vllm_outputs_seed_2_2 = vllm_model.generate(example_prompts,
                                            sampling_params_seed_2)
```
can we also shuffle the prompts here? it would also be great if we could test multiple different seeds in one batch
@Yard1 I've updated the test to do a mix of different/same seeds in a single batch, also updated the other mixed batch sampler test to shuffle and compare the seeded requests.
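Roughly the kind of check being discussed, sketched against the engine API (the exact add_request/step usage, prompts, and model here are assumptions for illustration, not the final test code):

```python
import random

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # example model
engine = llm.llm_engine

prompts = ["Hello, my name is", "The capital of France is", "The future of AI is"]
seeds = [None, 100, 200]  # mix unseeded and seeded requests in one batch


def run_batch(order):
    # Submit every request up front so they are scheduled together, in the
    # given order, then drain the engine.
    for idx in order:
        engine.add_request(str(idx), prompts[idx],
                           SamplingParams(temperature=1.0, seed=seeds[idx],
                                          max_tokens=16))
    finished = {}
    while engine.has_unfinished_requests():
        for out in engine.step():
            if out.finished:
                finished[out.request_id] = out.outputs[0].token_ids
    return finished


order = list(range(len(prompts)))
first = run_batch(order)
random.shuffle(order)   # different batch positions on the second run
second = run_batch(order)

# Seeded requests should reproduce the same tokens regardless of position.
for idx, seed in enumerate(seeds):
    if seed is not None:
        assert first[str(idx)] == second[str(idx)]
```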
Revert enforcement of best_of == 1 when using seed
Rather than SamplingParams object
Per @Yard1's review comment
I'm going to merge this considering the newest commit just removes a comment. Here's the passing CI for commit 1a774fd: https://buildkite.com/vllm/ci/builds/1465
I encountered an issue where, with the same prompt and parameters top_k=-1, top_p=1, temperature=0, and using continuous batching, there is a certain probability that the responses will differ when the number of concurrent requests is greater than 1. However, when testing with offline inference and a batch size of 2, the responses are always the same. It seems that continuous batching may affect the results of greedy sampling.

@tdeng521 Yes, this is expected due to precision-related differences when floating point ops are accumulated differently, including different matmul implementations being used for different batch sizes, etc. You should see it to a lesser degree if you try with float32.
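For anyone who wants to reduce (not eliminate) that effect, the float32 suggestion just means loading the model in full precision; a small sketch with an example model:

```python
from vllm import LLM, SamplingParams

# float32 narrows (but does not remove) the batch-composition-dependent
# rounding differences described above.
llm = LLM(model="facebook/opt-125m", dtype="float32")  # example model
out = llm.generate(["Hello, my name is"], SamplingParams(temperature=0.0))
```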
@WoosukKwon @zhuohan123 @simon-mo please let me know if this looks reasonable!
Question (now n/a): Can we rely on request_id to be unique? If not, this may also require assigning a guaranteed-unique internal id.

Resolves #1211, #1595