-
-
Notifications
You must be signed in to change notification settings - Fork 927
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DPO cleanup #1126
DPO cleanup #1126
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome PR! I left a comment in case you see fit. Also, maybe it could be tackled in a different PR, but the preprocess
command could also be updated to allow checking rl
datasets:
+ if parsed_cfg.rl:
+ _ = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+ else:
+ _ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
- _ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
d5f97c3
to
c0a1553
Compare
|
||
def load(strategy, cfg): | ||
try: | ||
load_fn = strategy.split(".")[-1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is most likely not correct. The strategy
includes underscores, not .
, such as intel_apply_chatml
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def load(strategy, cfg):
try:
load_fn = strategy.split("_")[-1]
#strategy = ".".join(strategy.split("_")[:-1])
LOG.info(load_fn)
LOG.info(strategy)
mod = importlib.import_module(f".{load_fn}", "axolotl.prompt_strategies.dpo")
func = getattr(mod, strategy)
load_kwargs = {}
return func(cfg, **load_kwargs)
except Exception as e: # pylint: disable=broad-exception-caught
LOG.warning(e)
return None
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This works for me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the intention is the setting is something like
type: chatml.argilla
in which case it will load the argilla function from the axolotl.prompt_strategies.dpo.chatml
module.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @winglian 👋🏻 thanks. That makes sense. I will test it later today 👍🏻
Co-authored-by: Agus <[email protected]>
Co-authored-by: Agus <[email protected]>
* cleanup dpo to be a little more extensible, add zephyr/nectar strategy * fix eos slash * support for eval split * fix kwargs * handle empty evals * don't load peft model for dpo * ensure dpo traning args gets bf16 for peft if applicable * fix duplicate kwargs for bf16 * make sure to respect the configured lr scheduler * supprt trainer callback to push config to wandb * set dataloader preload args * ensure that we are loading the lora when merging * Update src/axolotl/utils/data.py Co-authored-by: Agus <[email protected]> * support local datasets for dpo Co-authored-by: Agus <[email protected]> * chore: lint * dpo/kto/ipo smoke tests w lora, simplify dpo dataset type names * add split to dpo tests * fix rebase/merging error * handle edge case w logging * use accelerator for dpo datasets so it doesn't break the logger * missing args * validate checkpoint is an adapter for now * log warning when dataset strategy is not loadable --------- Co-authored-by: Agus <[email protected]>
Description
This PR cleans up some hardcoding, improves the integration with trl's DPOTrainer and adds support for dpo prompt_strategies.