From 7a38dbe67444fc1224713b68084a04c72b2f6fc9 Mon Sep 17 00:00:00 2001 From: NJordan72 Date: Tue, 24 Dec 2024 16:18:50 -0500 Subject: [PATCH] fix: allow trainer builder to use custom jinja chat template (#2219) * fix: allow trainer builder to use custom jinja chat template * chore: use get_chat_template_from_config Co-authored-by: Chirag Jain * fix: swap imports --------- Co-authored-by: Chirag Jain --- src/axolotl/core/trainer_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index fffddac81..e81740399 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -68,7 +68,7 @@ ) from axolotl.utils.callbacks.lisa import lisa_callback_factory from axolotl.utils.callbacks.profiler import PytorchProfilerCallback -from axolotl.utils.chat_templates import get_chat_template +from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, @@ -1834,8 +1834,8 @@ def build(self, total_num_steps): training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) if self.cfg.chat_template: - training_arguments_kwargs["chat_template"] = get_chat_template( - self.cfg.chat_template, + training_arguments_kwargs["chat_template"] = get_chat_template_from_config( + cfg=self.cfg, tokenizer=self.tokenizer, )