Fake DPO / KTO #599

Merged
merged 30 commits into from
Feb 6, 2024
30 commits
729021f
h2o-wave = "==1.0.0" and minor upgrades
pascal-pfeiffer Jan 27, 2024
c1c03d8
Update requirements.txt
pascal-pfeiffer Jan 27, 2024
82057fe
SQLAlchemy
pascal-pfeiffer Jan 27, 2024
0591878
Merge remote-tracking branch 'origin/pp/upgrade_python_deps' into pp/…
pascal-pfeiffer Jan 27, 2024
4849292
Update requirements.txt
pascal-pfeiffer Jan 27, 2024
b6d46b4
replaced deprecated use_auth_token with token
pascal-pfeiffer Jan 27, 2024
c4182df
upd changed-filesv35 to changed-filesv41
pascal-pfeiffer Jan 27, 2024
d6ae214
mapped_column
pascal-pfeiffer Jan 27, 2024
581729c
use of __wave_submission_name__
pascal-pfeiffer Jan 29, 2024
f329b5e
more __wave_submission_name__
pascal-pfeiffer Jan 29, 2024
cfcdc56
Merge branch 'main' into pp/upgrade_python_deps
pascal-pfeiffer Jan 29, 2024
de88f05
c
psinger Jan 29, 2024
35715e8
m
psinger Jan 29, 2024
e009706
c
psinger Jan 29, 2024
cb3ccd1
m
psinger Jan 29, 2024
6043f1e
c
psinger Jan 31, 2024
965f576
c
psinger Jan 31, 2024
535444e
lock
psinger Feb 1, 2024
2a2a79f
format
psinger Feb 1, 2024
565dc46
tooltips
psinger Feb 1, 2024
c980a9b
m
psinger Feb 1, 2024
ab1d6c3
pipfile
psinger Feb 1, 2024
c5d1cbd
format
psinger Feb 1, 2024
1165467
Update requirements.txt
psinger Feb 1, 2024
235bc76
fix
psinger Feb 2, 2024
5c44e9d
Merge branch 'psi/dpofakepairs' of github.com:h2oai/h2o-llmstudio int…
psinger Feb 2, 2024
33a4ccd
Merge branch 'main' into psi/dpofakepairs
psinger Feb 5, 2024
241b05b
c
psinger Feb 5, 2024
33708cf
c
psinger Feb 5, 2024
976523d
README
psinger Feb 6, 2024
4 changes: 2 additions & 2 deletions Pipfile
@@ -14,7 +14,7 @@ python_version = "3.10"
[packages]
torch = {index = "pytorch", version = "==2.1.2+cu118"}
tqdm = ">=4.65.0, <5.0.0"
transformers = "==4.36.1"
transformers = "==4.37.1"
numpy = ">=1.23.2, <2.0.0"
pandas = ">=2.1.0, <3.0.0"
scikit-learn = ">=1.0.2, <2.0.0"
@@ -46,7 +46,7 @@ Jinja2 = ">=3.1.3, <4.0.0"
tenacity = ">=8.2.2, <9.0.0"
h2o-wave = "==0.26.3"
tiktoken = "==0.5.1"
hf-transfer = "==0.1.3"
hf-transfer = "==0.1.5"
peft = "==0.5.0"
azure-storage-file-datalake = ">=12.12.0"
deepspeed = "==0.11.1"
216 changes: 140 additions & 76 deletions Pipfile.lock

Large diffs are not rendered by default.

6 changes: 1 addition & 5 deletions README.md
@@ -53,16 +53,12 @@ Using CLI for fine-tuning LLMs:

## What's New

- [PR 599](https://github.com/h2oai/h2o-llmstudio/pull/599) Added `KTOPairLoss` for DPO modeling allowing to train models with simple preference data. Data currently needs to be manually prepared by randomly matching positive and negative examples as pairs.
- [PR 592](https://github.com/h2oai/h2o-llmstudio/pull/592) Starting to deprecate RLHF in favor of DPO/IPO optimization. Training is disabled, but old experiments are still viewable. RLHF will be fully removed in a future release.
- [PR 530](https://github.com/h2oai/h2o-llmstudio/pull/530) Introduced a new problem type for DPO/IPO optimization. This optimization technique can be used as an alternative to RLHF.
- [PR 288](https://github.com/h2oai/h2o-llmstudio/pull/288) Introduced Deepspeed for sharded training allowing to train larger models on machines with multiple GPUs. Requires NVLink. This feature replaces FSDP and offers more flexibility. Deepspeed requires a system installation of cudatoolkit and we recommend using version 11.8. See [Recommended Install](#recommended-install).
- [PR 449](https://github.com/h2oai/h2o-llmstudio/pull/449) New problem type for Causal Classification Modeling allows to train binary and multiclass models using LLMs.
- [PR 364](https://github.com/h2oai/h2o-llmstudio/pull/364) User secrets are now handled more securely and flexible. Support for handling secrets using the 'keyring' library was added. User settings are tried to be migrated automatically.
- [PR 328](https://github.com/h2oai/h2o-llmstudio/pull/328) RLHF is now a separate problem type. Note that starting a new RLHF experiment from an old experiment that used RLHF is no longer supported. To continue from a previous experiment, please start a new experiment and enter the settings from the previous experiment manually.
- [PR 308](https://github.com/h2oai/h2o-llmstudio/pull/308) Sequence to sequence models have been added as a new problem type.
- [PR 152](https://github.com/h2oai/h2o-llmstudio/pull/152) Add RLHF functionality for fine-tuning LLMs.
- [PR 132](https://github.com/h2oai/h2o-llmstudio/pull/131) Add 4bit training that allows training of larger LLM backbones with less GPU memory. See [here](https://huggingface.co/blog/4bit-transformers-bitsandbytes) for a comprehensive summary of this method.
- [PR 40](https://github.com/h2oai/h2o-llmstudio/pull/40) Added functionality for supporting nested conversations in data. A new `parent_id_column` can be selected for datasets to support tree-like structures in your conversational data. Additional `augmentation` settings have been added for this feature.

Please note that due to current rapid development we cannot guarantee full backwards compatibility of new functionality. We thus recommend to pin the version of the framework to the one you used for your experiments. For resetting, please delete/backup your `data` and `output` folders.

@@ -1,6 +1 @@
Defines the loss function H2O LLM Studio utilizes during model training. The loss function is a differentiable function measuring the prediction error. The model utilizes gradients of the loss function to update the model weights during training.

- **TokenAveragedCrossEntropy**
- H2O LLM Studio utilizes Cross Entropy as the loss function and averages over all tokens (excluding padding tokens) of one batch.
- **SampleAveragedCrossEntropy**
- H2O LLM Studio utilizes Cross Entropy as the loss function and averages over samples of one batch.
Defines the loss function H2O LLM Studio utilizes during model training. The loss function is a differentiable function measuring the prediction error. The model utilizes gradients of the loss function to update the model weights during training. The options depend on the selected Problem Type.
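The two cross-entropy options named in the removed lines, `TokenAveragedCrossEntropy` and `SampleAveragedCrossEntropy`, differ only in how per-token losses are averaged. The following is a minimal pure-Python sketch of that difference. It assumes the per-token losses are already computed and that padding is marked with a 0/1 mask. The function names and inputs are illustrative and are not taken from H2O LLM Studio.

```python
from typing import List


def token_averaged(losses: List[List[float]], mask: List[List[int]]) -> float:
    """Average over every non-padding token in the batch."""
    total = sum(
        loss * keep
        for row_losses, row_mask in zip(losses, mask)
        for loss, keep in zip(row_losses, row_mask)
    )
    count = sum(keep for row_mask in mask for keep in row_mask)
    return total / count


def sample_averaged(losses: List[List[float]], mask: List[List[int]]) -> float:
    """Average within each sample first, then across samples."""
    per_sample = [
        sum(loss * keep for loss, keep in zip(row_losses, row_mask)) / sum(row_mask)
        for row_losses, row_mask in zip(losses, mask)
    ]
    return sum(per_sample) / len(per_sample)


# One long sample with low loss, one short sample with high loss.
losses = [[1.0, 1.0, 1.0, 1.0], [4.0, 0.0, 0.0, 0.0]]
mask = [[1, 1, 1, 1], [1, 0, 0, 0]]

print(token_averaged(losses, mask))   # long samples dominate
print(sample_averaged(losses, mask))  # every sample counts equally
```

With token averaging, long samples contribute more to the batch loss. With sample averaging, a short sample weighs as much as a long one.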
@@ -0,0 +1 @@
The column in the dataset containing the user prompt for the rejected answer. By default this can be set to None to take the same prompt as for the accepted answer and should only be changed if the accepted and rejected answers exhibit different prompts, such as when using KTOPairLoss.
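Since the README entry for this PR says that pairs currently have to be prepared manually by randomly matching positive and negative examples, here is a rough sketch of one way to do that with the standard library only. The column names `chosen_response` and `rejected_response` are the config defaults and `rejected_prompt` is one of the preferred names in the config. The `prompt` column name, the `make_fake_pairs` helper, and the sample data are assumptions made for illustration.

```python
import random
from typing import Dict, List, Tuple

Example = Tuple[str, str]  # (prompt, answer)


def make_fake_pairs(
    positives: List[Example], negatives: List[Example], seed: int = 0
) -> List[Dict[str, str]]:
    """Match every positive example with a randomly drawn negative example."""
    rng = random.Random(seed)  # fixed seed keeps the pairing reproducible
    rows = []
    for prompt, answer in positives:
        # The negative keeps its own prompt, so the two prompts may differ.
        rejected_prompt, rejected_answer = rng.choice(negatives)
        rows.append(
            {
                "prompt": prompt,
                "chosen_response": answer,
                "rejected_prompt": rejected_prompt,
                "rejected_response": rejected_answer,
            }
        )
    return rows


positives = [("What is 2 + 2?", "4"), ("Capital of France?", "Paris")]
negatives = [("What is 3 + 3?", "7"), ("Capital of Spain?", "Lisbon")]
rows = make_fake_pairs(positives, negatives)
for row in rows:
    print(row)
```

The resulting rows could then be written to a CSV file, for example with `csv.DictWriter`, and the rejected prompt column selected in the dataset settings.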
19 changes: 13 additions & 6 deletions llm_studio/app_utils/sections/chat.py
@@ -190,19 +190,16 @@ def load_cfg_model_tokenizer(
cfg.architecture.backbone_dtype = "float16"
cfg.architecture.pretrained = True

# if "cpu" in device:
# cfg.architecture.backbone_dtype = "float32"

with torch.device(cfg.environment._device):
model = cfg.architecture.model_class(cfg)
cfg.architecture.pretrained_weights = os.path.join(
experiment_path, "checkpoint.pth"
)
load_checkpoint(cfg, model, strict=False)

if merge and cfg.training.lora:
# merges the LoRa layers into the base model.
# This is needed if one wants to use the base model as a standalone model.
logger.info("Merging LORA layers with base model.")
model.backbone = model.backbone.merge_and_unload()

if device == "cpu_shard":
max_memory = get_balanced_memory(
model,
@@ -213,6 +210,16 @@ def load_cfg_model_tokenizer(
device_map=device_map,
)

if merge and cfg.training.lora:
# merges the LoRa layers into the base model.
# This is needed if one wants to use the base model as a standalone model.
logger.info("Merging LORA layers with base model.")
if device == "cpu":
Review comment (Collaborator Author): This was a side fix I did in this PR; some models can't be merged on CPU in float16.

model = model.to(torch.float32)
model.backbone = model.backbone.merge_and_unload()
if device == "cpu":
model = model.to(torch.float16)

model = model.eval()
model.backbone.use_cache = True

7 changes: 7 additions & 0 deletions llm_studio/python_configs/text_dpo_modeling_config.py
@@ -29,18 +29,25 @@ class ConfigDPODataset(ConfigNLPCausalLMDataset):
limit_chained_samples: bool = True
mask_prompt_labels: bool = True

rejected_prompt_column: str = "None"
answer_column: str = "chosen_response"
rejected_answer_column: str = "rejected_response"

def __post_init__(self):
super().__post_init__()
self._possible_values["rejected_prompt_column"] = possible_values.Columns(
prefer_with=lambda column: column
in ("rejected_input", "rejected_prompt", "rejected_instruction"),
add_none=True,
)
self._possible_values["rejected_answer_column"] = possible_values.Columns(
prefer_with=lambda column: column
in ("rejected_answer", "rejected_response")
)

self._visibility["limit_chained_samples"] = -1
self._visibility["mask_prompt_labels"] = -1
self._order.insert("rejected_prompt_column", after="prompt_column")
self._order.insert("rejected_answer_column", after="answer_column")
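The `prefer_with` predicate and `add_none` flag in the config above suggest how a default column might be chosen. The sketch below illustrates that idea only. `pick_default_column` is a hypothetical helper, and the actual behavior of `possible_values.Columns` may differ, in particular in what it falls back to when no column matches.

```python
from typing import Callable, List


def pick_default_column(
    columns: List[str],
    prefer_with: Callable[[str], bool],
    add_none: bool = False,
) -> str:
    """Return the first preferred column, else the first available choice."""
    # Assumption: "None" is offered as the first choice when add_none is set.
    choices = (["None"] if add_none else []) + list(columns)
    if not choices:
        raise ValueError("no columns to choose from")
    for column in columns:
        if prefer_with(column):
            return column
    return choices[0]


def prefers_rejected_prompt(column: str) -> bool:
    # Same candidate names as in the config above.
    return column in ("rejected_input", "rejected_prompt", "rejected_instruction")


with_match = pick_default_column(
    ["prompt", "rejected_prompt"], prefers_rejected_prompt, add_none=True
)
without_match = pick_default_column(
    ["prompt", "chosen_response"], prefers_rejected_prompt, add_none=True
)
print(with_match, without_match)
```

Under these assumptions, a dataset without a matching column falls back to `"None"`, which matches the documented default of reusing the accepted answer's prompt.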


15 changes: 12 additions & 3 deletions llm_studio/src/datasets/text_dpo_modeling_ds.py
@@ -25,12 +25,21 @@ def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
cfg.dataset.limit_chained_samples
), "Need to enable limit_chained_samples for dpo training"
super().__init__(df=df, cfg=cfg, mode=mode)

with PatchedAttribute(
cfg.dataset, "answer_column", cfg.dataset.rejected_answer_column
):
self.conversation_chain_handler_rejected = ConversationChainHandler(
self.df, cfg
)
if cfg.dataset.rejected_prompt_column != "None":
with PatchedAttribute(
cfg.dataset, "prompt_column", cfg.dataset.rejected_prompt_column
):
self.conversation_chain_handler_rejected = ConversationChainHandler(
self.df, cfg
)
else:
self.conversation_chain_handler_rejected = ConversationChainHandler(
self.df, cfg
)
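The change above temporarily swaps `answer_column`, and optionally `prompt_column`, while the handler for rejected samples is built. The sketch below shows that pattern with a stand-in context manager. `patched_attribute` is a hypothetical re-implementation for illustration, and the project's own `PatchedAttribute` may differ. The nesting of the `if` inside the outer `with` is also an assumption, since the diff view above has lost its indentation.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def patched_attribute(obj, name, value):
    """Temporarily set obj.name to value and restore it afterwards."""
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, original)


# Stand-in for cfg.dataset, with illustrative column names.
dataset_cfg = SimpleNamespace(
    prompt_column="prompt",
    answer_column="chosen_response",
    rejected_prompt_column="rejected_prompt",
    rejected_answer_column="rejected_response",
)

seen = []  # records the columns a handler would read while patched
with patched_attribute(dataset_cfg, "answer_column", dataset_cfg.rejected_answer_column):
    if dataset_cfg.rejected_prompt_column != "None":
        with patched_attribute(
            dataset_cfg, "prompt_column", dataset_cfg.rejected_prompt_column
        ):
            seen.append((dataset_cfg.prompt_column, dataset_cfg.answer_column))
    else:
        seen.append((dataset_cfg.prompt_column, dataset_cfg.answer_column))

print(seen)
print(dataset_cfg.prompt_column, dataset_cfg.answer_column)
```

After both blocks exit, the config holds its original column names again, so the handler for accepted samples is unaffected.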

def __getitem__(self, idx: int) -> Dict:
"""Reads a single text observation."""
59 changes: 56 additions & 3 deletions llm_studio/src/losses/text_dpo_modeling_losses.py
@@ -4,7 +4,7 @@
"""

import logging
from typing import Any, KeysView, Tuple
from typing import Any, KeysView

import torch
import torch.nn.functional as F
@@ -33,7 +33,7 @@ def forward(
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
):
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps

@@ -67,6 +67,53 @@ def get_losses(self, logits):
return losses


class KTOPairLoss(nn.Module):
    """
    Implements original paired KTO implementation
    Adopted from https://github.com/ContextualAI/HALOs
    and https://github.com/huggingface/trl
    """

    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg

    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ):
        chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
        rejected_KL = (
            (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)
        )

        chosen_logratios = policy_chosen_logps - reference_chosen_logps
        rejected_logratios = policy_rejected_logps - reference_rejected_logps
        losses = torch.cat(
            (
                1
                - F.sigmoid(self.cfg.training.beta * (chosen_logratios - rejected_KL)),
                1
                - F.sigmoid(self.cfg.training.beta * (chosen_KL - rejected_logratios)),
            ),
            0,
        )

        chosen_rewards = (
            self.cfg.training.beta
            * (policy_chosen_logps - reference_chosen_logps).detach()
        ).float()
        rejected_rewards = (
            self.cfg.training.beta
            * (policy_rejected_logps - reference_rejected_logps).detach()
        ).float()

        return losses.mean(), chosen_rewards.mean(), rejected_rewards.mean()
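To make the arithmetic of `KTOPairLoss.forward` easier to follow, the sketch below repeats the same computation on plain Python lists instead of tensors. It is an illustration only: `kto_pair_loss`, the default `beta=0.1`, and the sample numbers are assumptions, and there is no gradient tracking or `detach()` here.

```python
import math
from typing import List, Tuple


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def mean(values: List[float]) -> float:
    return sum(values) / len(values)


def kto_pair_loss(
    policy_chosen: List[float],
    policy_rejected: List[float],
    reference_chosen: List[float],
    reference_rejected: List[float],
    beta: float = 0.1,
) -> Tuple[float, float, float]:
    """Return (loss, mean chosen reward, mean rejected reward)."""
    chosen_logratios = [p - r for p, r in zip(policy_chosen, reference_chosen)]
    rejected_logratios = [p - r for p, r in zip(policy_rejected, reference_rejected)]
    # Batch-level estimates, clamped at zero like .clamp(min=0) above.
    chosen_kl = max(mean(chosen_logratios), 0.0)
    rejected_kl = max(mean(rejected_logratios), 0.0)
    # Chosen samples are compared against the rejected estimate and vice versa.
    losses = [1 - sigmoid(beta * (c - rejected_kl)) for c in chosen_logratios]
    losses += [1 - sigmoid(beta * (chosen_kl - r)) for r in rejected_logratios]
    return mean(losses), beta * mean(chosen_logratios), beta * mean(rejected_logratios)


# Policy identical to the reference: all log-ratios are zero.
neutral = kto_pair_loss([-2.0], [-2.0], [-2.0], [-2.0])
# Policy raises the chosen answer and lowers the rejected one.
improved = kto_pair_loss([-1.0], [-3.0], [-2.0], [-2.0])
print(neutral, improved)
```

When the policy equals the reference, every term is `1 - sigmoid(0)`, so the loss is 0.5 and both rewards are zero. Moving probability toward the chosen answer and away from the rejected one pushes the loss below 0.5.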


class HingeLoss(DPOLoss):
def get_losses(self, logits):
losses = torch.relu(1 - self.cfg.training.beta * logits)
@@ -95,6 +142,7 @@ class Losses:
"DPOLoss": DPOLoss,
"HingeLoss": HingeLoss,
"IPOLoss": IPOLoss,
"KTOPairLoss": KTOPairLoss,
Review comment (Contributor): I guess KTOPairLoss needs to be added to the LOSS_REDUCTION dict.

Mid-term it may make sense to move the get_batch_logps function directly into the loss calculation instead of using it in the model (and pass an output dict with logits + labels to the loss functions). But not high priority atm.

}

@classmethod
@@ -113,4 +161,9 @@ def get(cls, name: str) -> Any:


# see https://github.com/huggingface/trl/commit/29d439a2043edf4455b05cae5a1e2ade69d22794
LOSS_REDUCTION = {"DPOLoss": False, "HingeLoss": True, "IPOLoss": True}
LOSS_REDUCTION = {
"DPOLoss": False,
"KTOPairLoss": False,
"HingeLoss": True,
"IPOLoss": True,
}
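The diff does not show how `LOSS_REDUCTION` is consumed. The sketch below assumes, in line with the linked TRL commit, that the flag decides whether per-token log-probabilities are averaged (`True`) or summed (`False`) per sequence. `batch_logps` is a hypothetical stand-in for the `get_batch_logps` function mentioned in the review comment, written on plain lists.

```python
from typing import List

# Copied from the diff above.
LOSS_REDUCTION = {
    "DPOLoss": False,
    "KTOPairLoss": False,
    "HingeLoss": True,
    "IPOLoss": True,
}


def batch_logps(
    token_logps: List[List[float]],
    mask: List[List[int]],
    average_log_prob: bool,
) -> List[float]:
    """Reduce per-token log-probs to one value per sequence."""
    reduced = []
    for row, row_mask in zip(token_logps, mask):
        total = sum(logp * keep for logp, keep in zip(row, row_mask))
        # Assumption: True means length-normalized, False means plain sum.
        reduced.append(total / sum(row_mask) if average_log_prob else total)
    return reduced


token_logps = [[-1.0, -2.0, -3.0]]
mask = [[1, 1, 0]]  # last position is padding
summed = batch_logps(token_logps, mask, LOSS_REDUCTION["KTOPairLoss"])
averaged = batch_logps(token_logps, mask, LOSS_REDUCTION["IPOLoss"])
print(summed, averaged)
```

Under this reading, `KTOPairLoss` follows `DPOLoss` in using summed sequence log-probabilities.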
24 changes: 12 additions & 12 deletions requirements.txt
@@ -15,8 +15,8 @@ bitsandbytes==0.41.1
bleach==6.1.0; python_version >= '3.8'
blessed==1.20.0; python_version >= '2.7'
bokeh==3.3.4; python_version >= '3.9'
boto3==1.34.30; python_version >= '3.8'
botocore==1.34.30; python_version >= '3.8'
boto3==1.34.32; python_version >= '3.8'
botocore==1.34.32; python_version >= '3.8'
bravado==11.0.3; python_version not in '3.0, 3.1, 3.2, 3.3, 3.4' and python_full_version != '3.5.0'
bravado-core==6.1.1; python_version >= '3.7'
certifi==2023.11.17; python_version >= '3.6'
@@ -27,7 +27,7 @@ colorama==0.4.6; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.
contourpy==1.2.0; python_version >= '3.9'
coolname==2.2.0
cramjam==2.8.1; python_version >= '3.7'
cryptography==42.0.1; python_version >= '3.7'
cryptography==42.0.2; python_version >= '3.7'
datasets==2.15.0; python_full_version >= '3.8.0'
deepspeed==0.11.1
dill==0.3.7; python_version >= '3.7'
@@ -39,14 +39,14 @@ filelock==3.13.1; python_version >= '3.8'
fqdn==1.5.1
frozenlist==1.4.1; python_version >= '3.8'
fsspec[http]==2023.10.0; python_version >= '3.8'
future==0.18.3; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'
future==0.18.3; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2'
gitdb==4.0.11; python_version >= '3.7'
gitpython==3.1.41; python_version >= '3.7'
gputil==1.4.0
greenlet==3.0.3; platform_machine == 'aarch64' or (platform_machine == 'ppc64le' or (platform_machine == 'x86_64' or (platform_machine == 'amd64' or (platform_machine == 'AMD64' or (platform_machine == 'win32' or platform_machine == 'WIN32')))))
h11==0.14.0; python_version >= '3.7'
h2o-wave==0.26.3; python_full_version >= '3.7.1'
hf-transfer==0.1.3; python_version >= '3.7'
hf-transfer==0.1.5; python_version >= '3.7'
hjson==3.1.0
httpcore==1.0.2; python_version >= '3.8'
httpx==0.26.0; python_version >= '3.8'
@@ -65,7 +65,7 @@ jsonpointer==2.4
jsonref==1.1.0; python_version >= '3.7'
jsonschema[format-nongpl]==4.21.1; python_version >= '3.8'
jsonschema-specifications==2023.12.1; python_version >= '3.8'
kaggle==1.6.3
kaggle==1.6.4
keyring==24.2.0; python_version >= '3.8'
markupsafe==2.1.4; python_version >= '3.7'
monotonic==1.6
@@ -93,8 +93,8 @@ pyarrow-hotfix==0.6; python_version >= '3.5'
pycparser==2.21
pydantic==1.10.14; python_version >= '3.7'
pyjwt==2.8.0; python_version >= '3.7'
python-dateutil==2.8.2; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
python-slugify==8.0.2; python_version >= '3.7'
python-dateutil==2.8.2; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'
python-slugify==8.0.3; python_version >= '3.7'
pytz==2023.4
pyyaml==6.0.1; python_version >= '3.6'
readchar==4.0.5; python_version >= '3.7'
@@ -114,8 +114,8 @@ scipy==1.12.0; python_version >= '3.9'
secretstorage==3.3.3; sys_platform == 'linux'
sentencepiece==0.1.99
setuptools==69.0.3; python_version >= '3.8'
simplejson==3.19.2; python_version >= '2.5' and python_version not in '3.0, 3.1, 3.2, 3.3'
six==1.16.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
simplejson==3.19.2; python_version >= '2.5' and python_version not in '3.0, 3.1, 3.2'
six==1.16.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'
smmap==5.0.1; python_version >= '3.7'
sniffio==1.3.0; python_version >= '3.7'
soupsieve==2.5; python_version >= '3.8'
@@ -130,11 +130,11 @@ text-unidecode==1.3
threadpoolctl==3.2.0; python_version >= '3.8'
tiktoken==0.5.1; python_version >= '3.8'
tokenizers==0.15.1; python_version >= '3.7'
toml==0.10.2; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'
toml==0.10.2; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2'
torch==2.1.2+cu118
tornado==6.4; python_version >= '3.8'
tqdm==4.66.1; python_version >= '3.7'
transformers==4.36.1; python_full_version >= '3.8.0'
transformers==4.37.1; python_full_version >= '3.8.0'
triton==2.1.0; platform_system == 'Linux' and platform_machine == 'x86_64'
types-python-dateutil==2.8.19.20240106; python_version >= '3.8'
typing-extensions==4.9.0; python_version >= '3.8'