Add ds model card, rebased #2101

Merged 6 commits on Dec 3, 2024
35 changes: 34 additions & 1 deletion src/axolotl/core/trainer_builder.py
@@ -107,6 +107,22 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
    return kwargs


def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
    if isinstance(dataset_tags, str):
        dataset_tags = [dataset_tags]

    if (dataset_tags is not None) and (kwargs is not None):
        if "dataset_tags" not in kwargs:
            kwargs["dataset_tags"] = dataset_tags
        elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
            kwargs["dataset_tags"].extend(dataset_tags)
        elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
            dataset_tags.append(kwargs["dataset_tags"])
            kwargs["dataset_tags"] = dataset_tags

    return kwargs


@dataclass
class AxolotlTrainingMixins:
"""
@@ -410,10 +426,12 @@ def __init__(
        *_args,
        bench_data_collator=None,
        eval_data_collator=None,
        dataset_tags=None,
        **kwargs,
    ):
        self.bench_data_collator = bench_data_collator
        self.eval_data_collator = eval_data_collator
        self.dataset_tags = dataset_tags
        super().__init__(*_args, **kwargs)
        self.train_data_collator = self.data_collator
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
@@ -871,6 +889,9 @@ def push_to_hub(self, *args, **kwargs) -> str:
        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        kwargs = _sanitize_kwargs_for_ds_tagging(
            dataset_tags=self.dataset_tags, kwargs=kwargs
        )
        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)

        return super().push_to_hub(*args, **kwargs)
@@ -994,8 +1015,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):

    tag_names = ["axolotl", "dpo"]

    def __init__(self, *args, **kwargs):
    def __init__(self, *args, dataset_tags=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_tags = dataset_tags
        self.optimizer = None

    def create_optimizer(self):
@@ -1034,6 +1056,9 @@ def push_to_hub(self, *args, **kwargs) -> str:
        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        kwargs = _sanitize_kwargs_for_ds_tagging(
            dataset_tags=self.dataset_tags, kwargs=kwargs
        )
        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)

        return super().push_to_hub(*args, **kwargs)
@@ -1755,6 +1780,10 @@ def build(self, total_num_steps):
        else:
            trainer_kwargs["tokenizer"] = self.tokenizer

        if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
            trainer_kwargs["dataset_tags"] = [
                d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
            ]
        trainer = trainer_cls(
            model=self.model,
            train_dataset=self.train_dataset,
@@ -2028,6 +2057,10 @@ def build(self, total_num_steps):
        else:
            dpo_trainer_kwargs["tokenizer"] = self.tokenizer

        if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
            dpo_trainer_kwargs["dataset_tags"] = [
                d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
            ]
        dpo_trainer = trainer_cls(
            *trainer_cls_args,
            args=training_args,
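As context for the two `dataset_tags` blocks above, a standalone sketch (with invented config entries) of how dataset paths become tags: Hub-style paths are kept, while entries whose `path` resolves to a local directory are filtered out, so only shareable dataset names end up in the model card.

from pathlib import Path

# Invented entries for illustration; "data/my_local_dataset" stands in for a
# dataset stored on disk rather than on the Hugging Face Hub.
datasets = [
    {"path": "mhenrichsen/alpaca_2k_test"},
    {"path": "data/my_local_dataset"},
]
dataset_tags = [d["path"] for d in datasets if not Path(d["path"]).is_dir()]
# -> ["mhenrichsen/alpaca_2k_test"], assuming data/my_local_dataset is a directory
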
28 changes: 24 additions & 4 deletions src/axolotl/train.py
@@ -259,11 +259,31 @@ def terminate_handler(_, __, model_weakref):
        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

    if not cfg.hub_model_id:
        from huggingface_hub import HfApi
        from huggingface_hub.utils import RepositoryNotFoundError

        try:
            trainer.create_model_card(
                model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
            )
        except (AttributeError, UnicodeDecodeError):
            # Check to make sure the base model is from HuggingFace not a local directory
            hf_api = HfApi()
            hf_api.model_info(cfg.base_model)

            model_card_kwarg = {
                "model_name": cfg.output_dir.lstrip("./")
                .encode("utf-8")
                .decode("utf-8")
            }
            if cfg.datasets is not None:
                if cfg.rl is not None or cfg.reward_model:
                    model_card_kwarg["dataset_name"] = [
                        d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
                    ]
                else:
                    model_card_kwarg["dataset_tags"] = [
                        d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
                    ]

            trainer.create_model_card(**model_card_kwarg)
        except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
            pass
    elif cfg.hub_model_id:
        # defensively push to the hub to ensure the model card is updated
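
To summarize the branching in this hunk, a minimal self-contained sketch (the `SimpleNamespace` config and its values are invented) of how the model-card keyword is chosen: RL and reward-model runs pass the dataset list as `dataset_name`, other runs pass it as `dataset_tags`, presumably because the TRL-based trainers and `transformers.Trainer.create_model_card` accept different keyword names. A base model that is not on the Hub makes `hf_api.model_info` raise `RepositoryNotFoundError`, so model-card creation is skipped in that case.

from types import SimpleNamespace

def pick_dataset_kwarg(cfg) -> str:
    # Mirrors the branching above: RL / reward-model runs use `dataset_name`,
    # everything else uses `dataset_tags`.
    if cfg.rl is not None or cfg.reward_model:
        return "dataset_name"
    return "dataset_tags"

cfg = SimpleNamespace(rl=None, reward_model=False)  # hypothetical config
assert pick_dataset_kwarg(cfg) == "dataset_tags"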