Fake DPO / KTO #599
```diff
@@ -1,6 +1 @@
-Defines the loss function H2O LLM Studio utilizes during model training. The loss function is a differentiable function measuring the prediction error. The model utilizes gradients of the loss function to update the model weights during training.
-
-- **TokenAveragedCrossEntropy**
-  - H2O LLM Studio utilizes Cross Entropy as the loss function and averages over all tokens (excluding padding tokens) of one batch.
-- **SampleAveragedCrossEntropy**
-  - H2O LLM Studio utilizes Cross Entropy as the loss function and averages over samples of one batch.
+Defines the loss function H2O LLM Studio utilizes during model training. The loss function is a differentiable function measuring the prediction error. The model utilizes gradients of the loss function to update the model weights during training. The options depend on the selected Problem Type.
```
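Not part of the diff: a minimal sketch of the two cross-entropy averaging variants described in the removed tooltip text. Tensor shapes and names are illustrative only.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: logits (batch, seq_len, vocab), labels (batch, seq_len),
# with padding positions marked as -100.
logits = torch.randn(2, 5, 32)
labels = torch.randint(0, 32, (2, 5))
labels[:, -1] = -100  # pretend the last token of each sample is padding

per_token = F.cross_entropy(
    logits.transpose(1, 2), labels, ignore_index=-100, reduction="none"
)  # (batch, seq_len); ignored positions contribute zero loss
mask = (labels != -100).float()

# TokenAveragedCrossEntropy: one average over all non-padding tokens in the batch
token_averaged = per_token.sum() / mask.sum()

# SampleAveragedCrossEntropy: average per sample first, then over samples
sample_averaged = (per_token.sum(dim=1) / mask.sum(dim=1)).mean()
```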
```diff
@@ -0,0 +1 @@
+The column in the dataset containing the user prompt for the rejected answer. By default, this can be set to None to use the same prompt as for the accepted answer; it should only be changed if the accepted and rejected answers have different prompts, such as when using KTOPairLoss.
```
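Not part of the diff: a hypothetical row illustrating the case where the rejected answer was collected under its own prompt. The column names here are illustrative, not the actual config keys.

```python
import pandas as pd

# Hypothetical columns; the real column names are chosen in the experiment settings.
df = pd.DataFrame(
    {
        "prompt": ["Summarize the article."],
        "chosen_answer": ["A concise, faithful summary."],
        "rejected_prompt": ["Summarize the article in one word."],  # differs from "prompt"
        "rejected_answer": ["Article."],
    }
)
```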
```diff
@@ -4,7 +4,7 @@
 """

 import logging
-from typing import Any, KeysView, Tuple
+from typing import Any, KeysView

 import torch
 import torch.nn.functional as F
@@ -33,7 +33,7 @@ def forward(
         policy_rejected_logps: torch.FloatTensor,
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+    ):
         pi_logratios = policy_chosen_logps - policy_rejected_logps
         ref_logratios = reference_chosen_logps - reference_rejected_logps
```
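For context (not part of the diff): the difference `pi_logratios - ref_logratios` is the logits term of the standard sigmoid DPO objective,

$$
\mathcal{L}_{\mathrm{DPO}} = -\log \sigma\Big(\beta\big[(\log \pi_\theta(y_w \mid x) - \log \pi_\theta(y_l \mid x)) - (\log \pi_{\mathrm{ref}}(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x))\big]\Big),
$$

where $y_w$ and $y_l$ denote the chosen and rejected answers and $\beta$ corresponds to `cfg.training.beta`. The subclasses below (hinge, IPO, and now KTO) only change how these quantities are turned into a loss.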
```diff
@@ -67,6 +67,53 @@ def get_losses(self, logits):
         return losses


+class KTOPairLoss(nn.Module):
+    """
+    Implements the original paired KTO loss.
+    Adapted from https://github.com/ContextualAI/HALOs
+    and https://github.com/huggingface/trl
+    """
+
+    def __init__(self, cfg: Any):
+        super().__init__()
+        self.cfg = cfg
+
+    def forward(
+        self,
+        policy_chosen_logps: torch.FloatTensor,
+        policy_rejected_logps: torch.FloatTensor,
+        reference_chosen_logps: torch.FloatTensor,
+        reference_rejected_logps: torch.FloatTensor,
+    ):
+        chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
+        rejected_KL = (
+            (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)
+        )
+
+        chosen_logratios = policy_chosen_logps - reference_chosen_logps
+        rejected_logratios = policy_rejected_logps - reference_rejected_logps
+        losses = torch.cat(
+            (
+                1
+                - F.sigmoid(self.cfg.training.beta * (chosen_logratios - rejected_KL)),
+                1
+                - F.sigmoid(self.cfg.training.beta * (chosen_KL - rejected_logratios)),
+            ),
+            0,
+        )
+
+        chosen_rewards = (
+            self.cfg.training.beta
+            * (policy_chosen_logps - reference_chosen_logps).detach()
+        ).float()
+        rejected_rewards = (
+            self.cfg.training.beta
+            * (policy_rejected_logps - reference_rejected_logps).detach()
+        ).float()
+
+        return losses.mean(), chosen_rewards.mean(), rejected_rewards.mean()
+
+
 class HingeLoss(DPOLoss):
     def get_losses(self, logits):
         losses = torch.relu(1 - self.cfg.training.beta * logits)
```
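Not part of the diff: a quick usage sketch of the new loss. The config namespace is mocked and the log-probabilities are random placeholders for the per-sample values that the training loop would compute under the policy and the frozen reference model.

```python
import torch
from types import SimpleNamespace

# Assumes KTOPairLoss from the diff above is in scope.
# Mocked config; in H2O LLM Studio this comes from the experiment configuration.
cfg = SimpleNamespace(training=SimpleNamespace(beta=0.1))
loss_fn = KTOPairLoss(cfg)

batch = 4
policy_chosen = torch.randn(batch)       # per-sample log-probs of chosen answers (policy)
policy_rejected = torch.randn(batch)
reference_chosen = torch.randn(batch)    # same, under the frozen reference model
reference_rejected = torch.randn(batch)

loss, chosen_rewards, rejected_rewards = loss_fn(
    policy_chosen, policy_rejected, reference_chosen, reference_rejected
)
print(loss.item(), chosen_rewards.item(), rejected_rewards.item())
```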
```diff
@@ -95,6 +142,7 @@ class Losses:
         "DPOLoss": DPOLoss,
         "HingeLoss": HingeLoss,
         "IPOLoss": IPOLoss,
+        "KTOPairLoss": KTOPairLoss,
     }

     @classmethod
```

Reviewer comment on the added `"KTOPairLoss"` entry: I guess mid-term it may make sense to add …
```diff
@@ -113,4 +161,9 @@ def get(cls, name: str) -> Any:


 # see https://github.com/huggingface/trl/commit/29d439a2043edf4455b05cae5a1e2ade69d22794
-LOSS_REDUCTION = {"DPOLoss": False, "HingeLoss": True, "IPOLoss": True}
+LOSS_REDUCTION = {
+    "DPOLoss": False,
+    "KTOPairLoss": False,
+    "HingeLoss": True,
+    "IPOLoss": True,
+}
```
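Not part of the diff: the linked TRL commit concerns whether per-sample log-probabilities are averaged or summed over tokens. A sketch of how such a reduction flag is typically consumed; this is an assumption for illustration, not the actual H2O LLM Studio helper.

```python
import torch

def get_batch_logps(
    per_token_logps: torch.Tensor,  # (batch, seq_len)
    loss_mask: torch.Tensor,        # (batch, seq_len), 1 for answer tokens, 0 elsewhere
    average_log_prob: bool,         # hypothetical flag corresponding to LOSS_REDUCTION
) -> torch.Tensor:
    # Average per-token log-probs over answer tokens when the flag is True,
    # otherwise sum them (the original DPO/KTO formulation).
    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    return (per_token_logps * loss_mask).sum(-1)
```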
Comment: This was a side fix I did in this PR; some models can't be merged on CPU in float16.