fixing issue 46 #136
Conversation
Signed-off-by: gkoren <[email protected]>
Sincere apologies for the late review!
I really appreciate you tackling this issue, which is not trivial. Unfortunately it's more complex than that, for two reasons:
- We need to handle tensor parallelism, which means that modifying the logits probably needs to be done within `DistributedLogprob`. It's likely going to be a bit tricky to implement, but it should be doable, at the expense of a few more steps to handle top_k / top_p. Note that it may be more efficient (and less memory intensive) to gather only the top_k logits from each rank.
- We also need to modify the logits used in the loss here.
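As a minimal single-process sketch of what the transformation itself involves (NumPy, not the tensor-parallel `DistributedLogprob` path; the function name `transform_logits` and its signature are made up for illustration):

```python
import numpy as np

def transform_logits(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Apply temperature / top_k / top_p to a 1-D logits vector.

    Illustrative only: a real implementation would operate on batched
    tensors and, under tensor parallelism, gather candidates across ranks.
    """
    logits = logits / temperature
    if top_k > 0:
        # keep only the top_k largest logits, mask the rest to -inf
        kth = np.sort(logits)[-top_k]
        logits = np.where(logits < kth, -np.inf, logits)
    if top_p < 1.0:
        # nucleus filtering: keep the smallest set of tokens whose
        # cumulative probability reaches top_p, mask the rest to -inf
        order = np.argsort(logits)[::-1]
        probs = np.exp(logits[order] - np.max(logits))
        probs = probs / probs.sum()
        cutoff = np.searchsorted(np.cumsum(probs), top_p) + 1
        masked = np.full_like(logits, -np.inf)
        masked[order[:cutoff]] = logits[order[:cutoff]]
        logits = masked
    return logits
```

Masked positions get probability zero after a softmax, which is exactly why applying this to the reference policy can make the KL infinite, as discussed above.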
I also think we should add flags to control where exactly these transformations are applied. I'm actually not sure it's a good idea to apply it when computing the KL penalty term, because:
- If we apply it to the reference policy, it may lead to infinite KL due to top_p / top_k (when we sample a token that has zero probability under the reference policy).
- If we don't apply it to the reference policy, then we may start with a high KL penalty from the start, which could cause some issues.
I would thus suggest adding some fine-grained control over where we apply this transformation, with the following default values:
```yaml
model:
  ppo:
    transform_logits_from_sampling_params:
      loss: True
      kl_penalty_actor: False
      kl_penalty_ref: ${.kl_penalty_actor}
```
This way we will be able to easily experiment with various configurations to see what actually works best in practice.
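For illustration, such flags could gate the transformation roughly as follows (a hedged sketch: the helper name `maybe_transform` and the plain-dict config are made up; the real config is Hydra/OmegaConf, and the `${.kl_penalty_actor}` interpolation is resolved by hand here):

```python
# Illustrative stand-in for the proposed config section.
cfg = {
    "transform_logits_from_sampling_params": {
        "loss": True,
        "kl_penalty_actor": False,
        "kl_penalty_ref": False,  # mirrors kl_penalty_actor by default
    }
}

def maybe_transform(logits, cfg, where, transform):
    """Apply `transform` to `logits` only if the flag for `where` is enabled.

    `where` is one of "loss", "kl_penalty_actor", "kl_penalty_ref".
    """
    if cfg["transform_logits_from_sampling_params"][where]:
        return transform(logits)
    return logits
```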
```python
# apply the sampling params to the logits - focusing only on the generated tokens.
context_length = context_lengths.min().item()
resp_logits = logits[:, context_length - 1 :].contiguous()
if not samparams.get("use_greedy", False):  # if use_greedy is True, use the logits as is
```
Minor suggestion: move this up two lines and write it

```python
if samparams.get("use_greedy", False):
    return logits
```

which will avoid a couple of useless ops & extra indent.
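The restructured function might then look like this (a sketch only: the function name and the body after the early return are hypothetical placeholders, not the PR's actual code):

```python
def apply_sampling_params(logits, samparams):
    # early return: greedy decoding leaves the logits untouched, so skip
    # the slicing / transformation work (and the extra indent) entirely
    if samparams.get("use_greedy", False):
        return logits
    # ... the temperature / top_k / top_p transformations would follow here;
    # a bare temperature scaling stands in for them in this sketch
    temperature = samparams.get("temperature", 1.0)
    return [l / temperature for l in logits]
```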
Also: I think we should add a check to skip temperature scaling if temp == 1, and to skip top_p / top_k if they are equal to 1.0 (or 0.0) / 1. This way we don't mess with the logits for no good reason.
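A small sketch of such a guard (the helper name is made up; the "disabled" defaults follow the comment above: temperature 1, top_p 1.0, top_k 0):

```python
def needs_transform(temperature=1.0, top_p=1.0, top_k=0):
    """True only if at least one sampling param would actually change the logits."""
    return temperature != 1.0 or top_p < 1.0 or top_k > 0
```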
replaced by new PR #186
What does this PR do?

Apply sampling params to the logprobs of the response tokens (see issue #46).

The application of sampling params is done by default, to be consistent with the response generation process (done in `text_generation_utils.py`). The following sampling parameters were taken into account: `use_greedy`, `temperature`, `top_p` and `top_k`.

Note that:
- when `use_greedy` is set to True (default), generation doesn't change the logits, so the original logits are used to compute the log prob, ignoring the other sampling params (`top_p`, `temperature` and `top_k`)
- the transformation is only applied when computing the log probs (i.e. when `compute_logprob` is True)

Additional Information
- `sampling_params`
- issue #46