From 2f64ea90792c8fc5ec5fe270f2bdea747e082246 Mon Sep 17 00:00:00 2001
From: bursteratom
Date: Sun, 24 Nov 2024 14:29:42 -0500
Subject: [PATCH 1/6] rebased add_ds_model_card

---
 src/axolotl/core/trainer_builder.py | 67 +++++++++++++++++++++++------
 src/axolotl/train.py                | 24 +++++++++--
 2 files changed, 73 insertions(+), 18 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 75219a274..885a3402e 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -107,6 +107,24 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
     return kwargs
 
 
+def _sanitize_kwargs_for_ds_tagging(dataset_tag_names, kwargs=None):
+    if isinstance(dataset_tag_names, type(None)):
+        return kwargs
+    if isinstance(dataset_tag_names, str):
+        dataset_tag_names = [dataset_tag_names]
+
+    if kwargs is not None:
+        if "datasets" not in kwargs:
+            kwargs["datasets"] = dataset_tag_names
+        elif "datasets" in kwargs and isinstance(kwargs["datasets"], list):
+            kwargs["datasets"].extend(dataset_tag_names)
+        elif "datasets" in kwargs and isinstance(kwargs["datasets"], str):
+            dataset_tag_names.append(kwargs["datasets"])
+            kwargs["datasets"] = dataset_tag_names
+
+    return kwargs
+
+
 @dataclass
 class AxolotlTrainingMixins:
     """
@@ -410,10 +428,12 @@ def __init__(
         *_args,
         bench_data_collator=None,
         eval_data_collator=None,
+        dataset_tag_names=None,
         **kwargs,
     ):
         self.bench_data_collator = bench_data_collator
         self.eval_data_collator = eval_data_collator
+        self.dataset_tag_names = dataset_tag_names
         super().__init__(*_args, **kwargs)
         self.train_data_collator = self.data_collator
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
@@ -871,6 +891,9 @@ def push_to_hub(self, *args, **kwargs) -> str:
         Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
         model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
         """
+        kwargs = _sanitize_kwargs_for_ds_tagging(
+            dataset_tag_names=self.dataset_tag_names, kwargs=kwargs
+        )
         kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
 
         return super().push_to_hub(*args, **kwargs)
@@ -994,8 +1017,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
 
     tag_names = ["axolotl", "dpo"]
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, dataset_tag_names=None, **kwargs):
         super().__init__(*args, **kwargs)
+        self.dataset_tag_names = dataset_tag_names
         self.optimizer = None
 
     def create_optimizer(self):
@@ -1034,6 +1058,9 @@ def push_to_hub(self, *args, **kwargs) -> str:
         Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
         model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
         """
+        kwargs = _sanitize_kwargs_for_ds_tagging(
+            dataset_tag_names=self.dataset_tag_names, kwargs=kwargs
+        )
         kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
 
         return super().push_to_hub(*args, **kwargs)
@@ -1212,17 +1239,11 @@ def get_post_trainer_create_callbacks(self, trainer):
         Callbacks added after the trainer is created, usually b/c these need access to the trainer
         """
         callbacks = []
-        if self.cfg.plugins:
-            plugin_manager = PluginManager.get_instance()
-            callbacks.extend(
-                [
-                    cb
-                    for cb in plugin_manager.add_callbacks_post_trainer(
-                        self.cfg, trainer
-                    )
-                    if cb
-                ]
-            )
+
+        plugin_manager = PluginManager.get_instance()
+        callbacks.extend(
+            plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
+        )
         return callbacks
 
     def hook_pre_create_training_args(self, training_arguments_kwargs):
@@ -1269,7 +1290,7 @@ def get_callbacks(self):
         return callbacks
 
     def get_post_trainer_create_callbacks(self, trainer):
-        callbacks = []
+        callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
         if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
             LogPredictionCallback = log_prediction_callback_factory(
                 trainer, self.tokenizer, "wandb"
@@ -1307,7 +1328,17 @@ def get_post_trainer_create_callbacks(self, trainer):
         if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
             callbacks.append(lisa_callback_factory(trainer))
 
-        callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
+        if self.cfg.plugins:
+            plugin_manager = PluginManager.get_instance()
+            callbacks.extend(
+                [
+                    cb
+                    for cb in plugin_manager.add_callbacks_post_trainer(
+                        self.cfg, trainer
+                    )
+                    if cb
+                ]
+            )
         return callbacks
 
     def _get_trainer_cls(self):
@@ -1755,6 +1786,10 @@ def build(self, total_num_steps):
         else:
            trainer_kwargs["tokenizer"] = self.tokenizer
 
+        if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
+            trainer_kwargs["dataset_tag_names"] = [
+                d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
+            ]
         trainer = trainer_cls(
             model=self.model,
             train_dataset=self.train_dataset,
@@ -2028,6 +2063,10 @@ def build(self, total_num_steps):
         else:
             dpo_trainer_kwargs["tokenizer"] = self.tokenizer
 
+        if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
+            dpo_trainer_kwargs["dataset_tag_names"] = [
+                d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
+            ]
         dpo_trainer = trainer_cls(
             *trainer_cls_args,
             args=training_args,
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 5fde4d384..7ea0367f7 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -259,11 +259,27 @@ def terminate_handler(_, __, model_weakref):
             model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 
     if not cfg.hub_model_id:
+        from huggingface_hub import HfApi
+        from huggingface_hub.utils import RepositoryNotFoundError
+
         try:
-            trainer.create_model_card(
-                model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
-            )
-        except (AttributeError, UnicodeDecodeError):
+            # Check to make sure the base model is from HuggingFace not a local directory
+            hf_api = HfApi()
+            hf_api.model_info(cfg.base_model)
+
+            model_card_kwarg = {"model_name": cfg.output_dir.lstrip("./")}
+            if cfg.datasets is not None:
+                if cfg.rl is not None or cfg.reward_model:
+                    model_card_kwarg["dataset_name"] = [
+                        d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
+                    ]
+                else:
+                    model_card_kwarg["dataset_tags"] = [
+                        d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
+                    ]
+
+            trainer.create_model_card(**model_card_kwarg)
+        except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
             pass
     elif cfg.hub_model_id:
         # defensively push to the hub to ensure the model card is updated
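
Note: the helper introduced in this patch folds the configured dataset paths
into the "datasets" entry of the kwargs that `push_to_hub` forwards to
model-card creation. A minimal sketch of the merge semantics, assuming a tree
with PATCH 1 applied (the import path is the patched module; the dataset
names are placeholders, not values from this series):

    # Sketch: an existing string "datasets" entry is promoted to a list and
    # kept after the configured tags rather than being clobbered.
    from axolotl.core.trainer_builder import _sanitize_kwargs_for_ds_tagging

    kwargs = {"commit_message": "End of training", "datasets": "c4"}
    kwargs = _sanitize_kwargs_for_ds_tagging(
        dataset_tag_names=["tatsu-lab/alpaca"], kwargs=kwargs
    )
    assert kwargs["datasets"] == ["tatsu-lab/alpaca", "c4"]
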
""" + kwargs = _sanitize_kwargs_for_ds_tagging( + dataset_tag_names=self.dataset_tag_names, kwargs=kwargs + ) kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) return super().push_to_hub(*args, **kwargs) @@ -1212,17 +1239,11 @@ def get_post_trainer_create_callbacks(self, trainer): Callbacks added after the trainer is created, usually b/c these need access to the trainer """ callbacks = [] - if self.cfg.plugins: - plugin_manager = PluginManager.get_instance() - callbacks.extend( - [ - cb - for cb in plugin_manager.add_callbacks_post_trainer( - self.cfg, trainer - ) - if cb - ] - ) + + plugin_manager = PluginManager.get_instance() + callbacks.extend( + plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer) + ) return callbacks def hook_pre_create_training_args(self, training_arguments_kwargs): @@ -1269,7 +1290,7 @@ def get_callbacks(self): return callbacks def get_post_trainer_create_callbacks(self, trainer): - callbacks = [] + callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) if self.cfg.use_wandb and self.cfg.eval_table_size > 0: LogPredictionCallback = log_prediction_callback_factory( trainer, self.tokenizer, "wandb" @@ -1307,7 +1328,17 @@ def get_post_trainer_create_callbacks(self, trainer): if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers: callbacks.append(lisa_callback_factory(trainer)) - callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer)) + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + callbacks.extend( + [ + cb + for cb in plugin_manager.add_callbacks_post_trainer( + self.cfg, trainer + ) + if cb + ] + ) return callbacks def _get_trainer_cls(self): @@ -1755,6 +1786,10 @@ def build(self, total_num_steps): else: trainer_kwargs["tokenizer"] = self.tokenizer + if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None: + trainer_kwargs["dataset_tag_names"] = [ + d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() + ] trainer = trainer_cls( model=self.model, train_dataset=self.train_dataset, @@ -2028,6 +2063,10 @@ def build(self, total_num_steps): else: dpo_trainer_kwargs["tokenizer"] = self.tokenizer + if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer): + dpo_trainer_kwargs["dataset_tag_names"] = [ + d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() + ] dpo_trainer = trainer_cls( *trainer_cls_args, args=training_args, diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 5fde4d384..7ea0367f7 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -259,11 +259,27 @@ def terminate_handler(_, __, model_weakref): model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) if not cfg.hub_model_id: + from huggingface_hub import HfApi + from huggingface_hub.utils import RepositoryNotFoundError + try: - trainer.create_model_card( - model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8") - ) - except (AttributeError, UnicodeDecodeError): + # Check to make sure the base model is from HuggingFace not a local directory + hf_api = HfApi() + hf_api.model_info(cfg.base_model) + + model_card_kwarg = {"model_name": cfg.output_dir.lstrip("./")} + if cfg.datasets is not None: + if cfg.rl is not None or cfg.reward_model: + model_card_kwarg["dataset_name"] = [ + d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() + ] + else: + model_card_kwarg["dataset_tags"] = [ + d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() + ] + + 
From d10b063832904ac6d5f942aa5f5b94f7c5badf22 Mon Sep 17 00:00:00 2001
From: bursteratom
Date: Mon, 25 Nov 2024 09:49:11 -0500
Subject: [PATCH 3/6] fix redundancy

---
 src/axolotl/core/trainer_builder.py | 2 --
 src/axolotl/train.py                | 2 +-
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index db6c180a1..3f27ded36 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -108,8 +108,6 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
 
 
 def _sanitize_kwargs_for_ds_tagging(dataset_tag_names, kwargs=None):
-    if isinstance(dataset_tag_names, type(None)):
-        return kwargs
     if isinstance(dataset_tag_names, str):
         dataset_tag_names = [dataset_tag_names]
 
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 7ea0367f7..304486856 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -267,7 +267,7 @@ def terminate_handler(_, __, model_weakref):
             hf_api = HfApi()
             hf_api.model_info(cfg.base_model)
 
-            model_card_kwarg = {"model_name": cfg.output_dir.lstrip("./")}
+            model_card_kwarg = {"model_name": cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")}
             if cfg.datasets is not None:
                 if cfg.rl is not None or cfg.reward_model:
                     model_card_kwarg["dataset_name"] = [
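
Note: dropping the early `None` return here leaves a small hole that PATCH 5
below closes: a `None` tag list now slips past the `isinstance` checks and is
written into the kwargs verbatim. A minimal repro against the helper as it
stands after this patch (assuming it is in scope):

    # isinstance(None, str) is False and {} is not None, so the merge branch
    # runs and stores None under the "datasets" key.
    kwargs = _sanitize_kwargs_for_ds_tagging(dataset_tag_names=None, kwargs={})
    assert kwargs == {"datasets": None}
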
From 4dde7e1a330e123463ce3cf8b7e31c22e2fa753f Mon Sep 17 00:00:00 2001
From: bursteratom
Date: Mon, 25 Nov 2024 10:03:30 -0500
Subject: [PATCH 4/6] lint

---
 src/axolotl/train.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 304486856..39af9f45c 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -267,7 +267,11 @@ def terminate_handler(_, __, model_weakref):
             hf_api = HfApi()
             hf_api.model_info(cfg.base_model)
 
-            model_card_kwarg = {"model_name": cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")}
+            model_card_kwarg = {
+                "model_name": cfg.output_dir.lstrip("./")
+                .encode("utf-8")
+                .decode("utf-8")
+            }
             if cfg.datasets is not None:
                 if cfg.rl is not None or cfg.reward_model:
                     model_card_kwarg["dataset_name"] = [

From c7d6fdb93a33529929ff8da26b074999b860c1bd Mon Sep 17 00:00:00 2001
From: bursteratom
Date: Mon, 25 Nov 2024 10:33:20 -0500
Subject: [PATCH 5/6] include case when ds_tag is none

---
 src/axolotl/core/trainer_builder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 3f27ded36..6848763b4 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -111,7 +111,7 @@ def _sanitize_kwargs_for_ds_tagging(dataset_tag_names, kwargs=None):
     if isinstance(dataset_tag_names, str):
         dataset_tag_names = [dataset_tag_names]
 
-    if kwargs is not None:
+    if (dataset_tag_names is not None) and (kwargs is not None):
         if "datasets" not in kwargs:
             kwargs["datasets"] = dataset_tag_names
         elif "datasets" in kwargs and isinstance(kwargs["datasets"], list):
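
Note: with the combined guard from PATCH 5, a missing tag list now leaves the
kwargs untouched instead of injecting `None`. A quick check, assuming the
patched helper is in scope:

    kwargs = _sanitize_kwargs_for_ds_tagging(None, kwargs={"tags": ["axolotl"]})
    assert kwargs == {"tags": ["axolotl"]}

PATCH 6 below then renames the keyword from "datasets" to "dataset_tags" so
it matches the `dataset_tags` parameter that
`transformers.Trainer.create_model_card` actually accepts.
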
From a8b73ea14f5332e35316ec529f97a8365a65d77c Mon Sep 17 00:00:00 2001
From: bursteratom
Date: Tue, 26 Nov 2024 10:49:40 -0500
Subject: [PATCH 6/6] conform to kwargs in create_model_card

---
 src/axolotl/core/trainer_builder.py | 40 ++++++++++++++++----------------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 6848763b4..57febd291 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -107,18 +107,18 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
     return kwargs
 
 
-def _sanitize_kwargs_for_ds_tagging(dataset_tag_names, kwargs=None):
-    if isinstance(dataset_tag_names, str):
-        dataset_tag_names = [dataset_tag_names]
-
-    if (dataset_tag_names is not None) and (kwargs is not None):
-        if "datasets" not in kwargs:
-            kwargs["datasets"] = dataset_tag_names
-        elif "datasets" in kwargs and isinstance(kwargs["datasets"], list):
-            kwargs["datasets"].extend(dataset_tag_names)
-        elif "datasets" in kwargs and isinstance(kwargs["datasets"], str):
-            dataset_tag_names.append(kwargs["datasets"])
-            kwargs["datasets"] = dataset_tag_names
+def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
+    if isinstance(dataset_tags, str):
+        dataset_tags = [dataset_tags]
+
+    if (dataset_tags is not None) and (kwargs is not None):
+        if "dataset_tags" not in kwargs:
+            kwargs["dataset_tags"] = dataset_tags
+        elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
+            kwargs["dataset_tags"].extend(dataset_tags)
+        elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
+            dataset_tags.append(kwargs["dataset_tags"])
+            kwargs["dataset_tags"] = dataset_tags
 
     return kwargs
 
@@ -426,12 +426,12 @@ def __init__(
         *_args,
         bench_data_collator=None,
         eval_data_collator=None,
-        dataset_tag_names=None,
+        dataset_tags=None,
         **kwargs,
     ):
         self.bench_data_collator = bench_data_collator
         self.eval_data_collator = eval_data_collator
-        self.dataset_tag_names = dataset_tag_names
+        self.dataset_tags = dataset_tags
         super().__init__(*_args, **kwargs)
         self.train_data_collator = self.data_collator
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
@@ -890,7 +890,7 @@ def push_to_hub(self, *args, **kwargs) -> str:
         model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
         """
         kwargs = _sanitize_kwargs_for_ds_tagging(
-            dataset_tag_names=self.dataset_tag_names, kwargs=kwargs
+            dataset_tags=self.dataset_tags, kwargs=kwargs
         )
         kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
 
@@ -1015,9 +1015,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
 
     tag_names = ["axolotl", "dpo"]
 
-    def __init__(self, *args, dataset_tag_names=None, **kwargs):
+    def __init__(self, *args, dataset_tags=None, **kwargs):
         super().__init__(*args, **kwargs)
-        self.dataset_tag_names = dataset_tag_names
+        self.dataset_tags = dataset_tags
         self.optimizer = None
 
     def create_optimizer(self):
@@ -1057,7 +1057,7 @@ def push_to_hub(self, *args, **kwargs) -> str:
         model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
         """
         kwargs = _sanitize_kwargs_for_ds_tagging(
-            dataset_tag_names=self.dataset_tag_names, kwargs=kwargs
+            dataset_tags=self.dataset_tags, kwargs=kwargs
         )
         kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
 
@@ -1781,7 +1781,7 @@ def build(self, total_num_steps):
             trainer_kwargs["tokenizer"] = self.tokenizer
 
         if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
-            trainer_kwargs["dataset_tag_names"] = [
+            trainer_kwargs["dataset_tags"] = [
                 d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
             ]
         trainer = trainer_cls(
@@ -2058,7 +2058,7 @@ def build(self, total_num_steps):
             dpo_trainer_kwargs["tokenizer"] = self.tokenizer
 
         if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
-            dpo_trainer_kwargs["dataset_tag_names"] = [
+            dpo_trainer_kwargs["dataset_tags"] = [
                 d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
             ]
         dpo_trainer = trainer_cls(