diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 7d47d0e496..63f9272acc 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -59,6 +59,29 @@
 LOG = logging.getLogger("axolotl.core.trainer_builder")
 
 
+def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
+    """
+    Merge `tag_names` into `kwargs["tags"]` and return `kwargs` (None is returned as-is).
+
+    Whenever `kwargs["tags"]` is written, a new list is built: `tag_names` is usually a
+    class-level list shared by every trainer instance, and a list passed by the caller
+    belongs to the caller, so neither may be extended in place.
+    """
+    if isinstance(tag_names, str):
+        tag_names = [tag_names]
+
+    if kwargs is not None:
+        current_tags = kwargs.get("tags")
+        if current_tags is None:
+            kwargs["tags"] = list(tag_names)
+        elif isinstance(current_tags, list):
+            kwargs["tags"] = [*current_tags, *tag_names]
+        elif isinstance(current_tags, str):
+            kwargs["tags"] = [*tag_names, current_tags]
+
+    return kwargs
+
+
 @dataclass
 class AxolotlTrainingArguments(TrainingArguments):
     """
@@ -349,30 +372,13 @@ def compute_loss(self, model, inputs, return_outputs=False):
         # return (loss, outputs) if return_outputs else loss
         return super().compute_loss(model, inputs, return_outputs=return_outputs)
 
-    def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
-        if isinstance(tag_names, str):
-            tag_names = [tag_names]
-
-        if kwargs is not None:
-            if "tags" not in kwargs:
-                kwargs["tags"] = tag_names
-            elif "tags" in kwargs and isinstance(kwargs["tags"], list):
-                kwargs["tags"].extend(tag_names)
-            elif "tags" in kwargs and isinstance(kwargs["tags"], str):
-                tag_names.append(kwargs["tags"])
-                kwargs["tags"] = tag_names
-
-        return kwargs
-
     @wraps(Trainer.push_to_hub)
     def push_to_hub(self, *args, **kwargs) -> str:
         """
         Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
         model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
         """
-        kwargs = self._sanitize_kwargs_for_tagging(
-            tag_names=self.tag_names, kwargs=kwargs
-        )
+        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
 
         return super().push_to_hub(*args, **kwargs)
 
@@ -471,6 +477,25 @@ def create_scheduler(
         return self.lr_scheduler
 
 
+class AxolotlDPOTrainer(DPOTrainer):
+    """
+    Extend the base DPOTrainer for axolotl helpers
+    """
+
+    # merged into the `tags` kwarg on every push_to_hub call; shared by all instances
+    tag_names = ["axolotl", "dpo"]
+
+    @wraps(DPOTrainer.push_to_hub)
+    def push_to_hub(self, *args, **kwargs) -> str:
+        """
+        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
+        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
+        """
+        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
+
+        return super().push_to_hub(*args, **kwargs)
+
+
 class TrainerBuilderBase(abc.ABC):
     """
     Base class for trainer builder
@@ -1076,7 +1101,7 @@ def build(self, total_num_steps):
             dpo_trainer_kwargs[
                 "precompute_ref_log_probs"
             ] = self.cfg.precompute_ref_log_probs
-        dpo_trainer = DPOTrainer(
+        dpo_trainer = AxolotlDPOTrainer(
             self.model,
             self.model_ref,
             args=training_args,