From e1915f5625b2330555c3f61816dd003fb939ae13 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 2 Oct 2024 21:02:48 -0400 Subject: [PATCH] Multimodal Vision Llama - rudimentary support (#1940) --------- Co-authored-by: Sunny Co-authored-by: sunny --- docs/input_output.qmd | 2 +- docs/multimodal.qmd | 28 +++ examples/llama-3-vision/lora-11b.yaml | 63 +++++ src/axolotl/cli/__init__.py | 7 +- src/axolotl/core/trainer_builder.py | 20 +- src/axolotl/monkeypatch/attention/mllama.py | 229 ++++++++++++++++++ src/axolotl/monkeypatch/multipack.py | 1 + .../monkeypatch/stablelm_attn_hijack_flash.py | 1 + src/axolotl/prompt_strategies/__init__.py | 4 +- .../prompt_strategies/chat_template.py | 60 ++++- src/axolotl/train.py | 10 +- src/axolotl/utils/chat_templates.py | 46 ++-- src/axolotl/utils/collators/__init__.py | 10 + .../{collators.py => collators/batching.py} | 35 +-- src/axolotl/utils/collators/core.py | 4 + src/axolotl/utils/collators/mamba.py | 38 +++ src/axolotl/utils/collators/mm_chat.py | 77 ++++++ src/axolotl/utils/config/__init__.py | 25 +- .../config/models/input/v0_4_1/__init__.py | 29 ++- src/axolotl/utils/data/sft.py | 48 +++- src/axolotl/utils/models.py | 98 ++++++-- src/axolotl/utils/trainer.py | 16 +- .../prompt_strategies/test_chat_templates.py | 21 +- .../test_chat_templates_advanced.py | 46 +++- 24 files changed, 799 insertions(+), 119 deletions(-) create mode 100644 docs/multimodal.qmd create mode 100644 examples/llama-3-vision/lora-11b.yaml create mode 100644 src/axolotl/monkeypatch/attention/mllama.py create mode 100644 src/axolotl/utils/collators/__init__.py rename src/axolotl/utils/{collators.py => collators/batching.py} (90%) create mode 100644 src/axolotl/utils/collators/core.py create mode 100644 src/axolotl/utils/collators/mamba.py create mode 100644 src/axolotl/utils/collators/mm_chat.py diff --git a/docs/input_output.qmd b/docs/input_output.qmd index 7715dd250d..6559578d18 100644 --- a/docs/input_output.qmd +++ b/docs/input_output.qmd @@ -205,7 +205,7 @@ ds = load_from_disk(f'last_run_prepared/{directory[0]}/') hi there!. goodbye farewell ``` -We can check that the right tokens are ingored by comparing the labels +We can check that the right tokens are ignored by comparing the labels to each token: ```python diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd new file mode 100644 index 0000000000..2381566adb --- /dev/null +++ b/docs/multimodal.qmd @@ -0,0 +1,28 @@ +# MultiModal / Vision Language Models (BETA) + +### Supported Models + +- Mllama, i.e. llama with vision models + +### Usage + +Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA, +you'll need to use the following in YAML in combination with the rest of the required hyperparams. + +```yaml +base_model: alpindale/Llama-3.2-11B-Vision-Instruct +processor_type: AutoProcessor +skip_prepare_dataset: true + +chat_template: llama3_2_vision +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + field_messages: messages +remove_unused_columns: false +sample_packing: false + +# only finetune the Language model, leave the vision model and vision tower frozen +lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' +``` diff --git a/examples/llama-3-vision/lora-11b.yaml b/examples/llama-3-vision/lora-11b.yaml new file mode 100644 index 0000000000..b2e4946418 --- /dev/null +++ b/examples/llama-3-vision/lora-11b.yaml @@ -0,0 +1,63 @@ +base_model: alpindale/Llama-3.2-11B-Vision-Instruct +processor_type: AutoProcessor +strict: false + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +chat_template: llama3_2_vision +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + field_messages: messages +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: lora +lora_model_dir: + +sequence_len: 8192 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +local_rank: +logging_steps: 1 +flash_attention: true +eager_attention: + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 13c5b4ab58..a1d84b6a16 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -40,7 +40,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process from axolotl.utils.mlflow_ import setup_mlflow_env_vars -from axolotl.utils.models import load_tokenizer +from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env from axolotl.utils.wandb_ import setup_wandb_env_vars @@ -430,9 +430,12 @@ def load_datasets( cli_args: TrainerCliArgs, ) -> TrainDatasetMeta: tokenizer = load_tokenizer(cfg) + processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( - cfg, tokenizer + cfg, + tokenizer, + processor=processor, ) if cli_args.debug or cfg.debug: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 249398f850..4893e63dc2 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -61,12 +61,14 @@ log_prediction_callback_factory, ) from axolotl.utils.callbacks.lisa import lisa_callback_factory +from axolotl.utils.chat_templates import chat_templates from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, MambaDataCollator, V2BatchSamplerDataCollatorForSeq2Seq, ) +from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.models import ensure_dtype from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.schedulers import ( @@ -250,6 +252,10 @@ class AxolotlTrainingMixins: "help": "workaround to pass an alternate lr scheduler to the HF trainer" }, ) + chat_template: Optional[str] = field( + default=None, + metadata={"help": "Chat template converting chat messages to text"}, + ) @dataclass @@ -1043,10 +1049,11 @@ class TrainerBuilderBase(abc.ABC): _model_ref = None _peft_config = None - def __init__(self, cfg, model, tokenizer): + def __init__(self, cfg, model, tokenizer, processor=None): self.cfg = cfg self.model = model self.tokenizer = tokenizer + self.processor = processor # in case the model supports tagging, add the axolotl tag. # This makes sure the tag is correctly pushed even if a user calls @@ -1515,6 +1522,10 @@ def build(self, total_num_steps): ) training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) + if self.cfg.chat_template: + training_arguments_kwargs["chat_template"] = chat_templates( + self.cfg.chat_template + ) if self.cfg.rl == "orpo": training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha @@ -1661,7 +1672,12 @@ def build_collator( else: collator = BatchSamplerDataCollatorForSeq2Seq else: - collator = DataCollatorForSeq2Seq + if self.cfg.processor_type and self.processor: + collator = MultiModalChatDataCollator + kwargs["processor"] = self.processor + kwargs["chat_template"] = training_args.chat_template + else: + collator = DataCollatorForSeq2Seq return collator( self.tokenizer, diff --git a/src/axolotl/monkeypatch/attention/mllama.py b/src/axolotl/monkeypatch/attention/mllama.py new file mode 100644 index 0000000000..0b18b716d5 --- /dev/null +++ b/src/axolotl/monkeypatch/attention/mllama.py @@ -0,0 +1,229 @@ +""" +Monkeypatch for Vision Llama for FA2 support +""" +# pylint: disable=duplicate-code + +from typing import Optional, Tuple + +import torch +from flash_attn.flash_attn_interface import flash_attn_func +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.models.mllama.configuration_mllama import MllamaTextConfig +from transformers.models.mllama.modeling_mllama import ( + MllamaTextCrossAttention, + MllamaTextSelfAttention, + apply_rotary_pos_emb, + repeat_kv, +) +from transformers.utils import is_flash_attn_greater_or_equal_2_10 + + +class MllamaTextCrossFlashAttention2(MllamaTextCrossAttention): + """ + Mllama flash cross-attention module. This module inherits from `MllamaTextCrossAttention` and + implements the forward pass using Flash Attention for improved performance. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Check if flash attention version is greater or equal to 2.1 + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[ # pylint: disable=unused-argument + torch.Tensor + ] = None, + output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + query_states = self.q_norm(query_states) + + if cross_attention_states is not None: + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view( + bsz, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = self.k_norm(key_states) + if past_key_value is not None: + key_states, value_states = past_key_value.update( + key_states, + value_states, + self.layer_idx, + {"cache_position": cache_position}, + ) + elif cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" + ) + + # Transpose to get the expected layout for flash attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # Apply Flash Attention + dropout_rate = self.dropout if self.training else 0.0 + output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_p=dropout_rate, + softmax_scale=None, + causal=False, + return_attn_probs=output_attentions, + ) + + attn_output = output.contiguous().view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MllamaTextSelfFlashAttention2(MllamaTextSelfAttention): + """ + Mllama flash self-attention module. This module inherits from `MllamaTextSelfAttention` and + implements the forward pass using Flash Attention for improved performance. + """ + + def __init__(self, config: MllamaTextConfig, layer_idx: int, *args, **kwargs): + super().__init__(config, layer_idx, *args, **kwargs) + + # Check if flash attention version is greater or equal to 2.1 + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument + past_key_value=None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, # pylint: disable=unused-argument + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x num_heads x head_dim + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Transpose to get the expected layout for flash attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.dropout if self.training else 0.0 + + # Handle potential silent casting to float32 + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = ( + self.config._pre_quantization_dtype # pylint: disable=protected-access + ) + else: + target_dtype = self.q_proj.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=True, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def patch_mllama(): + from transformers.models.mllama.modeling_mllama import ( + MLLAMA_TEXT_ATTENTION_CLASSES, + MLLAMA_TEXT_CROSS_ATTENTION_CLASSES, + MLLAMA_VISION_ATTENTION_CLASSES, + MllamaPreTrainedModel, + ) + + MllamaPreTrainedModel._supports_flash_attn_2 = ( # pylint: disable=protected-access + True + ) + MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2 + MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[ + "flash_attention_2" + ] = MllamaTextCrossFlashAttention2 + # fallback to SDPA + MLLAMA_VISION_ATTENTION_CLASSES[ + "flash_attention_2" + ] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"] diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 44fc4cb473..85101cd3c4 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -10,6 +10,7 @@ from axolotl.monkeypatch.utils import get_unpad_data SUPPORTED_MULTIPACK_MODEL_TYPES = [ + "mllama_text_model", "llama", "mistral", "mixtral", diff --git a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py index 0269f90157..67e9337e36 100644 --- a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py @@ -16,6 +16,7 @@ # This code is based off the following work: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py +# pylint: disable=duplicate-code """ PyTorch StableLM Epoch model. """ import importlib import math diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index f5699a0871..66cd5deeb9 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -9,7 +9,7 @@ LOG = logging.getLogger("axolotl.prompt_strategies") -def load(strategy, tokenizer, cfg, ds_cfg): +def load(strategy, tokenizer, cfg, ds_cfg, processor=None): try: load_fn = "load" if strategy.split(".")[-1].startswith("load_"): @@ -24,6 +24,8 @@ def load(strategy, tokenizer, cfg, ds_cfg): sig = inspect.signature(func) if "ds_cfg" in sig.parameters: load_kwargs["ds_cfg"] = ds_cfg + if "processor" in sig.parameters: + load_kwargs["processor"] = processor return func(tokenizer, cfg, **load_kwargs) except ModuleNotFoundError: return None diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 88e748895d..48d52dae11 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -5,6 +5,8 @@ import logging from typing import Any, Dict, List, Optional +from transformers import ProcessorMixin + from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompters import IGNORE_TOKEN_ID, Prompter from axolotl.utils.chat_templates import chat_templates @@ -20,6 +22,7 @@ class ChatTemplatePrompter(Prompter): def __init__( self, tokenizer, + processor=None, chat_template=None, max_length=2048, message_field_role: str = "from", @@ -44,11 +47,12 @@ def __init__( self.message_field_training = message_field_training self.message_field_training_detail = message_field_training_detail self.tokenizer = tokenizer + self.processor: ProcessorMixin = processor self.chat_template = chat_template self.max_length = max_length self.drop_system_message = drop_system_message - def build_prompt(self, conversation, add_generation_prompt=False): + def build_prompt(self, conversation, add_generation_prompt=False, images=None): turns = [ { "role": self.roles[t[self.message_field_role]], @@ -61,6 +65,28 @@ def build_prompt(self, conversation, add_generation_prompt=False): if self.drop_system_message and turns[0]["role"] == "system": turns = turns[1:] + if self.processor: + text = self.processor.apply_chat_template( + turns, + chat_template=self.chat_template, + tokenize=False, + add_generation_prompt=add_generation_prompt, + ) + batch = self.processor( + text=text, + images=images, + return_tensors="pt", + truncation=True, + max_length=self.max_length, + ) + # workaround since processor works in batches instead of single examples + for k, val in batch.items(): + if k in ["pixel_values"]: + batch[k] = val.tolist() + else: + batch[k] = val.squeeze().tolist() + return batch + return self.tokenizer.apply_chat_template( turns, truncation=True, @@ -191,6 +217,7 @@ def __init__( super().__init__(prompter, tokenizer, train_on_inputs, sequence_len) self.roles_to_train = roles_to_train if roles_to_train is not None else [] self.train_on_eos = train_on_eos + self.images = "images" @property def messages(self): @@ -209,10 +236,21 @@ def tokenize_prompt(self, prompt): and not self.prompter.message_field_training_detail ): turns = self.get_conversation_thread(prompt) + images = self.get_images(prompt) prompt_ids = self.prompter.build_prompt( - turns[:-1], add_generation_prompt=True + turns[:-1], + add_generation_prompt=True, + images=images, ) - input_ids = self.prompter.build_prompt(turns) + tokenized_res = self.prompter.build_prompt(turns, images=images) + tokenized_prompt = {} + if isinstance(tokenized_res, list): + input_ids = prompt_ids + tokenized_res[len(prompt_ids) :] + tokenized_prompt["input_ids"] = input_ids + tokenized_prompt["attention_mask"] = [1] * len(input_ids) + else: + input_ids = tokenized_res["input_ids"] + tokenized_prompt = tokenized_res if not self.train_on_inputs: user_prompt_len = len(prompt_ids) @@ -220,17 +258,9 @@ def tokenize_prompt(self, prompt): else: labels = input_ids - tokenized_prompt = { - "input_ids": input_ids, - "labels": labels, - "attention_mask": [1] * len(input_ids), - } + tokenized_prompt["labels"] = labels return tokenized_prompt - LOG.info(self.roles_to_train) - LOG.info(self.train_on_eos) - LOG.info(self.prompter.message_field_training) - LOG.info(self.prompter.message_field_training_detail) turns = prompt[self.messages] input_ids = self.prompter.build_prompt(turns) @@ -368,8 +398,11 @@ def find_turn(self, conversation_ids, turn, turn_content): def get_conversation_thread(self, prompt): return prompt[self.messages] + def get_images(self, prompt): + return prompt.get(self.images, None) + -def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): +def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None): ds_cfg = ds_cfg or {} prompter_params = { @@ -386,6 +419,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): "drop_system_message": ds_cfg.get("drop_system_message", False), # we need to add one for detecting sequences with exceeding the `sequence_len` limit. "max_length": cfg.sequence_len + 1, + "processor": processor, } strategy_params = { diff --git a/src/axolotl/train.py b/src/axolotl/train.py index b21b0b269c..855dbc2d3b 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -24,7 +24,7 @@ from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except -from axolotl.utils.models import load_model, load_tokenizer +from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.trainer import setup_trainer try: @@ -69,6 +69,9 @@ def train( main_process_only=True, ) tokenizer = load_tokenizer(cfg) + processor = None + if cfg.is_multimodal: + processor = load_processor(cfg, tokenizer) train_dataset = dataset_meta.train_dataset eval_dataset = dataset_meta.eval_dataset @@ -96,7 +99,9 @@ def train( LOG.debug(msg) # we wait unitl the last possible moment to setup Accelerator Accelerator() - model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) + model, peft_config = load_model( + cfg, tokenizer, processor=processor, inference=cli_args.inference + ) model.generation_config.do_sample = True model_ref = None @@ -122,6 +127,7 @@ def train( eval_dataset, (model, model_ref, peft_config), tokenizer, + processor, total_num_steps, ) diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 7a96f5c1e1..7468ae8b15 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -3,6 +3,20 @@ These templates are used for formatting messages in a conversation. """ +CHAT_TEMPLATES = { + "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", + "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. + "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", + "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", + "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", + "llama3_2_vision": '{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == "" %}\n {{- raise_exception("Prompting with images is incompatible with system messages.") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n {%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n {%- endif %}\n {{- "Cutting Knowledge Date: December 2023\\n" }}\n {{- "Today Date: " + date_string + "\\n\\n" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- "<|eot_id|>" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\' }}\n {%- if message[\'content\'] is string %}\n {{- message[\'content\'] }}\n {%- else %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {{- \'<|image|>\' }}\n {%- elif content[\'type\'] == \'text\' %}\n {{- content[\'text\'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- "<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n', + "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", + "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", + "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", + "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', +} + def chat_templates(user_choice: str): """ @@ -18,20 +32,22 @@ def chat_templates(user_choice: str): ValueError: If the user_choice is not found in the templates. """ - templates = { - "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", - "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. - "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", - "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", - "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", - "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", - "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", - "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", - "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", - "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', - } - - if user_choice in templates: - return templates[user_choice] + if user_choice in CHAT_TEMPLATES: + return CHAT_TEMPLATES[user_choice] raise ValueError(f"Template '{user_choice}' not found.") + + +def register_chat_template(template_name: str, chat_template: str): + """ + Registers chat templates. + + Args: + template_name (str): The name of the template. + chat_template (str): The template string. + """ + + if template_name in CHAT_TEMPLATES: + raise ValueError(f"Template '{template_name}' already exists.") + + CHAT_TEMPLATES[template_name] = chat_template diff --git a/src/axolotl/utils/collators/__init__.py b/src/axolotl/utils/collators/__init__.py new file mode 100644 index 0000000000..93502b67d7 --- /dev/null +++ b/src/axolotl/utils/collators/__init__.py @@ -0,0 +1,10 @@ +""" +shared axolotl collators for multipack, mamba, multimodal +""" +from .batching import ( # noqa: F401 + BatchSamplerDataCollatorForSeq2Seq, + DataCollatorForSeq2Seq, + PretrainingBatchSamplerDataCollatorForSeq2Seq, + V2BatchSamplerDataCollatorForSeq2Seq, +) +from .mamba import MambaDataCollator # noqa: F401 diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators/batching.py similarity index 90% rename from src/axolotl/utils/collators.py rename to src/axolotl/utils/collators/batching.py index 26c7fa9f3c..7cf771421c 100644 --- a/src/axolotl/utils/collators.py +++ b/src/axolotl/utils/collators/batching.py @@ -1,17 +1,14 @@ """ DataCollator for axolotl to pad labels and position_ids for packed sequences """ + from dataclasses import dataclass -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Optional, Union import numpy as np -import torch -import transformers from transformers import PreTrainedTokenizerBase from transformers.utils import PaddingStrategy -IGNORE_INDEX = -100 - @dataclass class DataCollatorForSeq2Seq: @@ -183,34 +180,6 @@ def __call__(self, features, return_tensors=None): return super().__call__(out_features, return_tensors=return_tensors) -@dataclass -class MambaDataCollator: - """ - Collator for State Space Models (Mamba) - """ - - tokenizer: transformers.PreTrainedTokenizer - - def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: - input_ids, labels = tuple( - [torch.LongTensor(instance[key]) for instance in instances] - for key in ("input_ids", "labels") - ) - input_ids = torch.nn.utils.rnn.pad_sequence( - input_ids, - batch_first=True, - padding_value=self.tokenizer.pad_token_id, - ) - labels = torch.nn.utils.rnn.pad_sequence( - labels, batch_first=True, padding_value=IGNORE_INDEX - ) - - return { - "input_ids": input_ids, - "labels": labels, - } - - @dataclass class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): """ diff --git a/src/axolotl/utils/collators/core.py b/src/axolotl/utils/collators/core.py new file mode 100644 index 0000000000..0eae0c3bda --- /dev/null +++ b/src/axolotl/utils/collators/core.py @@ -0,0 +1,4 @@ +""" +basic shared collator constants +""" +IGNORE_INDEX = -100 diff --git a/src/axolotl/utils/collators/mamba.py b/src/axolotl/utils/collators/mamba.py new file mode 100644 index 0000000000..0c4a22fcc0 --- /dev/null +++ b/src/axolotl/utils/collators/mamba.py @@ -0,0 +1,38 @@ +""" +collators for Mamba +""" +from dataclasses import dataclass +from typing import Dict, Sequence + +import torch +import transformers + +from axolotl.utils.collators.core import IGNORE_INDEX + + +@dataclass +class MambaDataCollator: + """ + Collator for State Space Models (Mamba) + """ + + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple( + [torch.LongTensor(instance[key]) for instance in instances] + for key in ("input_ids", "labels") + ) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id, + ) + labels = torch.nn.utils.rnn.pad_sequence( + labels, batch_first=True, padding_value=IGNORE_INDEX + ) + + return { + "input_ids": input_ids, + "labels": labels, + } diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py new file mode 100644 index 0000000000..f49e97f37f --- /dev/null +++ b/src/axolotl/utils/collators/mm_chat.py @@ -0,0 +1,77 @@ +""" +Collators for multi-modal chat messages and packing +""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +from transformers import PreTrainedTokenizerBase, ProcessorMixin +from transformers.data.data_collator import DataCollatorMixin +from transformers.utils import PaddingStrategy + + +@dataclass +class MultiModalChatDataCollator(DataCollatorMixin): + """ + Collator for multi-modal chat messages + """ + + tokenizer: PreTrainedTokenizerBase + processor: ProcessorMixin + return_tensors: str = "pt" + chat_template: Optional[str] = None + packing: bool = False + max_images: int = -1 + padding: Union[bool, str, PaddingStrategy] = True + pad_to_multiple_of: Optional[int] = None + + def __post_init__(self): + if self.packing: + raise ValueError("Packing is currently not supported.") + + def torch_call( + self, examples: List[Union[List[int], Any, Dict[str, Any]]] + ) -> Dict[str, Any]: + # Handle dict or lists with proper padding and conversion to tensor. + + return self.__class__.process_rows( + examples, self.processor, self.chat_template, self.max_images + ) + + @staticmethod + def process_rows(examples, processor, chat_template, max_images, length_only=False): + # HINT: use `_torch_collate_batch` to stack and pad tensors + # see also DataCollatorWithFlattening and DefaultDataCollator + + # *** This is COPIED from the trl example sft_vlm.py code *** + # use this as a starting point + + # Get the texts and images, and apply the chat template + texts = [ + processor.apply_chat_template( + example["messages"], chat_template=chat_template, tokenize=False + ) + for example in examples + ] + images = [example["images"] for example in examples] + + if max_images > 0: + images = [img_batch[:max_images] for img_batch in images] + + # Tokenize the texts and process the images + batch = processor(text=texts, images=images, return_tensors="pt", padding=True) + + # The labels are the input_ids, and we mask the padding tokens in the loss computation + labels = batch["input_ids"].clone() + labels[labels == processor.tokenizer.pad_token_id] = -100 # + # Ignore the image token index in the loss computation (model specific) + image_token_id = processor.tokenizer.convert_tokens_to_ids( + processor.image_token + ) + labels[labels == image_token_id] = -100 + batch["labels"] = labels + + if length_only: + return { + "length": [len(sample["input_ids"]) for sample in batch["input_ids"]] + } + return batch diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 82436e8d79..f732db06fc 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -121,15 +121,36 @@ def normalize_config(cfg): cfg.base_model_config = cfg.base_model model_config = load_model_config(cfg) - cfg.model_config_type = model_config.model_type cfg.tokenizer_config = ( cfg.tokenizer_config or cfg.base_model_config or cfg.base_model ) + cfg.is_multimodal = ( + hasattr(model_config, "model_type") + and model_config.model_type in ["llava", "mllama"] + or any( + multimodal_name in cfg.base_model.lower() + for multimodal_name in [ + "pixtral", + ] + ) + or cfg.is_multimodal + ) + if cfg.is_multimodal: + cfg.processor_config = ( + cfg.processor_config or cfg.base_model_config or cfg.base_model + ) + model_config = model_config.text_config + + cfg.model_config_type = model_config.model_type + # figure out if the model is llama cfg.is_llama_derived_model = ( - (hasattr(model_config, "model_type") and model_config.model_type == "llama") + ( + hasattr(model_config, "model_type") + and model_config.model_type == ["llama", "mllama_text_model"] + ) or cfg.is_llama_derived_model or "llama" in cfg.base_model.lower() or (cfg.type_of_model and "llama" in cfg.type_of_model.lower()) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 4e07c9260a..fced5e639d 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -188,6 +188,7 @@ class ChatTemplate(str, Enum): gemma = "gemma" # pylint: disable=invalid-name cohere = "cohere" # pylint: disable=invalid-name llama3 = "llama3" # pylint: disable=invalid-name + llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name phi_3 = "phi_3" # pylint: disable=invalid-name phi_35 = "phi_35" # pylint: disable=invalid-name deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name @@ -228,11 +229,12 @@ class LoraConfig(BaseModel): lora_r: Optional[int] = None lora_alpha: Optional[int] = None lora_fan_in_fan_out: Optional[bool] = None - lora_target_modules: Optional[List[str]] = None + lora_target_modules: Optional[Union[str, List[str]]] = None lora_target_linear: Optional[bool] = None lora_modules_to_save: Optional[List[str]] = None lora_dropout: Optional[float] = 0.0 peft_layers_to_transform: Optional[List[int]] = None + peft_layers_pattern: Optional[List[str]] = None peft: Optional[PeftConfig] = None peft_use_dora: Optional[bool] = None peft_use_rslora: Optional[bool] = None @@ -328,6 +330,9 @@ class ModelInputConfig(BaseModel): tokenizer_type: Optional[str] = Field( default=None, metadata={"help": "transformers tokenizer class"} ) + processor_type: Optional[str] = Field( + default=None, metadata={"help": "transformers processor class"} + ) trust_remote_code: Optional[bool] = None model_kwargs: Optional[Dict[str, Any]] = None @@ -530,6 +535,7 @@ class Config: dataset_prepared_path: Optional[str] = None dataset_shard_num: Optional[int] = None dataset_shard_idx: Optional[int] = None + skip_prepare_dataset: Optional[bool] = False pretraining_dataset: Optional[ # type: ignore conlist(Union[PretrainingDataset, SFTDataset], min_length=1) @@ -997,6 +1003,18 @@ def check_eval_packing(cls, data): return data + @model_validator(mode="before") + @classmethod + def check_mm_prepare(cls, data): + if data.get("skip_prepare_dataset"): + if data.get("remove_unused_columns") is None: + LOG.info( + "setting `remove_unused_columns: false` for skip_prepare_dataset" + ) + data["remove_unused_columns"] = False + + return data + @model_validator(mode="before") @classmethod def check_warmup(cls, data): @@ -1052,6 +1070,15 @@ def check_frozen(cls, data): return data + @model_validator(mode="before") + @classmethod + def check_peft_layers_pattern(cls, data): + if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"): + raise ValueError( + "peft_layers_pattern requires peft_layers_to_transform to be set" + ) + return data + @model_validator(mode="after") def check_fft_possible_bad_config(self): if ( diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 1b6df1cded..7d6922cbf2 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -51,20 +51,31 @@ LOG = logging.getLogger("axolotl") -def prepare_dataset(cfg, tokenizer): +def prepare_dataset(cfg, tokenizer, processor=None): prompters = [] if not cfg.pretraining_dataset: with zero_first(is_local_main_process()): if cfg.test_datasets: train_dataset, _, prompters = load_prepare_datasets( - tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train" + tokenizer, + cfg, + DEFAULT_DATASET_PREPARED_PATH, + split="train", + processor=processor, ) _, eval_dataset, _ = load_prepare_datasets( - tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test" + tokenizer, + cfg, + DEFAULT_DATASET_PREPARED_PATH, + split="test", + processor=processor, ) else: train_dataset, eval_dataset, prompters = load_prepare_datasets( - tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH + tokenizer, + cfg, + DEFAULT_DATASET_PREPARED_PATH, + processor=processor, ) else: path = cfg.pretraining_dataset @@ -123,6 +134,7 @@ def load_tokenized_prepared_datasets( cfg, default_dataset_prepared_path, split="train", + processor=None, ) -> Tuple[DatasetDict, List[Prompter]]: cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets tokenizer_name = cfg.tokenizer_config @@ -180,6 +192,7 @@ def load_tokenized_prepared_datasets( cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")) and not cfg.is_preprocess + and not cfg.skip_prepare_dataset ): LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") dataset = load_from_disk(str(prepared_ds_path)) @@ -423,12 +436,16 @@ def for_d_in_datasets(dataset_configs): dataset=ds, d_base_type=d_base_type, d_prompt_style=d_prompt_style, + processor=processor, ) datasets.append(dataset_wrapper) prompters.append(dataset_prompter) - LOG.info("merging datasets") - dataset = concatenate_datasets(datasets) + if len(datasets) == 1: + dataset = datasets[0] + else: + LOG.info("merging datasets") + dataset = concatenate_datasets(datasets) if len(datasets) > 1: if cfg.shuffle_merged_datasets: @@ -437,9 +454,10 @@ def for_d_in_datasets(dataset_configs): else: LOG.debug("NOT shuffling merged datasets") - dataset, _ = process_datasets_for_packing(cfg, dataset, None) + if not cfg.skip_prepare_dataset: + dataset, _ = process_datasets_for_packing(cfg, dataset, None) - if cfg.local_rank == 0: + if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") dataset.save_to_disk(str(prepared_ds_path)) if cfg.push_dataset_to_hub: @@ -478,9 +496,14 @@ def load_prepare_datasets( cfg, default_dataset_prepared_path, split="train", + processor=None, ) -> Tuple[Dataset, Dataset, List[Prompter]]: dataset, prompters = load_tokenized_prepared_datasets( - tokenizer, cfg, default_dataset_prepared_path, split=split + tokenizer, + cfg, + default_dataset_prepared_path, + split=split, + processor=processor, ) if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: @@ -546,6 +569,7 @@ def get_dataset_wrapper( d_base_type, dataset, d_prompt_style=None, + processor=None, ): dataset_wrapper = None dataset_prompter = None @@ -578,7 +602,11 @@ def get_dataset_wrapper( dataset, **ds_kwargs, ) - elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset): + elif cfg.skip_prepare_dataset: + dataset_wrapper = dataset + elif ds_strategy := load( + config_dataset.type, tokenizer, cfg, config_dataset, processor=processor + ): dataset_prompter = UnsupportedPrompter() dataset_wrapper = TokenizedPromptDataset( ds_strategy, diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index e183301991..c18af9760f 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -28,12 +28,17 @@ AddedToken, AutoConfig, AutoModelForCausalLM, + AutoModelForVision2Seq, + AutoProcessor, AutoTokenizer, AwqConfig, BitsAndBytesConfig, GPTQConfig, + LlavaForConditionalGeneration, + MllamaForConditionalGeneration, PreTrainedModel, PreTrainedTokenizerBase, + ProcessorMixin, ) from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled @@ -80,6 +85,9 @@ def get_module_class_from_name(module, name): def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]): + if cfg.is_multimodal: + model_config = model_config.text_config + quant_config_exists = ( hasattr(model_config, "quantization_config") and model_config.quantization_config @@ -299,11 +307,31 @@ def load_tokenizer(cfg): return tokenizer +def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): + processor_kwargs: Dict[str, Any] = {} # do we actually need this? + + processor_cls = AutoProcessor + if cfg.processor_type: + processor_cls = getattr(transformers, cfg.processor_type) + + processor = processor_cls.from_pretrained( + cfg.processor_config, + trust_remote_code=cfg.trust_remote_code or False, + tokenizer=tokenizer, + **processor_kwargs, + ) + + return processor + + def load_model( cfg: DictDefault, tokenizer: PreTrainedTokenizerBase, + *, + processor: ProcessorMixin = None, # pylint: disable=unused-argument inference: bool = False, reference_model: bool = False, + **kwargs, # pylint: disable=unused-argument ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: """ Load a model for a given configuration and tokenizer. @@ -319,12 +347,23 @@ def load_model( plugin_manager = PluginManager.get_instance() plugin_manager.pre_model_load(cfg) + if cfg.is_multimodal: + text_model_config = model_config.text_config + else: + text_model_config = model_config + # TODO refactor as a kwarg load_in_8bit = cfg.load_in_8bit if cfg.gradient_checkpointing == "unsloth": transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper + if hasattr(model_config, "model_type") and model_config.model_type == "mllama": + if cfg.flash_attention: + from axolotl.monkeypatch.attention.mllama import patch_mllama + + patch_mllama() + if hasattr(model_config, "model_type") and model_config.model_type == "btlm": if cfg.flash_attention: from axolotl.monkeypatch.btlm_attn_hijack_flash import ( @@ -461,6 +500,19 @@ def load_model( max_memory = cfg.max_memory device_map = cfg.device_map + AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name + if cfg.is_multimodal: + if model_config.model_type == "llava": + AutoModelLoader = ( # pylint: disable=invalid-name + LlavaForConditionalGeneration + ) + elif model_config.model_type == "mllama": + AutoModelLoader = ( # pylint: disable=invalid-name + MllamaForConditionalGeneration + ) + else: + AutoModelLoader = AutoModelForVision2Seq # pylint: disable=invalid-name + if cfg.gpu_memory_limit: gpu_memory_limit = ( str(cfg.gpu_memory_limit) + "GiB" @@ -478,7 +530,7 @@ def load_model( from accelerate import infer_auto_device_map with init_empty_weights(): - model_canvas = AutoModelForCausalLM.from_config( + model_canvas = AutoModelLoader.from_config( model_config, trust_remote_code=cfg.trust_remote_code or False ) model_canvas.tie_weights() @@ -633,6 +685,8 @@ def load_model( quantization_config = ( quantization_config or model_kwargs["quantization_config"] ) + if cfg.is_multimodal: + model_config.text_config = text_model_config model = load_sharded_model_quant( base_model, model_config, @@ -651,7 +705,9 @@ def load_model( if "device_map" in model_kwargs: del model_kwargs["device_map"] - model = AutoModelForCausalLM.from_pretrained( + if cfg.is_multimodal: + model_config.text_config = text_model_config + model = AutoModelLoader.from_pretrained( base_model, config=model_config, **model_kwargs, @@ -690,13 +746,17 @@ def load_model( and not cfg.trust_remote_code ): if cfg.gptq: - model = AutoModelForCausalLM.from_pretrained( + if cfg.is_multimodal: + model_config.text_config = text_model_config + model = AutoModelLoader.from_pretrained( base_model, config=model_config, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) else: + if cfg.is_multimodal: + model_config.text_config = text_model_config model = getattr(transformers, model_type).from_pretrained( base_model, config=model_config, @@ -707,21 +767,23 @@ def load_model( # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this # when training starts if ( - hasattr(model_config, "max_seq_len") - and model_config.max_seq_len + hasattr(text_model_config, "max_seq_len") + and text_model_config.max_seq_len and cfg.sequence_len > model_config.max_seq_len ): - model_config.max_seq_len = cfg.sequence_len + text_model_config.max_seq_len = cfg.sequence_len LOG.warning(f"increasing context length to {cfg.sequence_len}") elif ( - hasattr(model_config, "max_sequence_length") - and model_config.max_sequence_length - and cfg.sequence_len > model_config.max_sequence_length + hasattr(text_model_config, "max_sequence_length") + and text_model_config.max_sequence_length + and cfg.sequence_len > text_model_config.max_sequence_length ): - model_config.max_sequence_length = cfg.sequence_len + text_model_config.max_sequence_length = cfg.sequence_len LOG.warning(f"increasing context length to {cfg.sequence_len}") if cfg.gptq: - model = AutoModelForCausalLM.from_pretrained( + if cfg.is_multimodal: + model_config.text_config = text_model_config + model = AutoModelLoader.from_pretrained( base_model, config=model_config, trust_remote_code=cfg.trust_remote_code or False, @@ -734,7 +796,9 @@ def load_model( if "device_map" in model_kwargs: del model_kwargs["device_map"] - model = AutoModelForCausalLM.from_pretrained( + if cfg.is_multimodal: + model_config.text_config = text_model_config + model = AutoModelLoader.from_pretrained( base_model, config=model_config, trust_remote_code=cfg.trust_remote_code or False, @@ -1016,12 +1080,17 @@ def load_lora(model, cfg, inference=False, config_only=False): from peft import LoraConfig, get_peft_model - lora_target_modules = list(cfg.lora_target_modules or []) + lora_target_modules = cfg.lora_target_modules or [] if cfg.lora_target_linear: linear_names = find_all_linear_names(model) LOG.info(f"found linear modules: {repr(sorted(linear_names))}") - lora_target_modules = list(set(lora_target_modules + linear_names)) + lora_target_modules_as_list = ( + lora_target_modules + if isinstance(lora_target_modules, list) + else [lora_target_modules] + ) + lora_target_modules = list(set(lora_target_modules_as_list + linear_names)) lora_config_kwargs = {} loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits @@ -1040,6 +1109,7 @@ def load_lora(model, cfg, inference=False, config_only=False): lora_alpha=cfg.lora_alpha, target_modules=lora_target_modules, layers_to_transform=cfg.peft_layers_to_transform, + layers_pattern=cfg.peft_layers_pattern, lora_dropout=cfg.lora_dropout, fan_in_fan_out=cfg.lora_fan_in_fan_out, modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None, diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 89ae4e6970..17276dd8ed 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -306,7 +306,7 @@ def process_pretraining_datasets_for_packing( def calculate_total_num_steps(cfg, train_dataset, update=True): - if not cfg.total_num_tokens: + if not cfg.total_num_tokens and not cfg.skip_prepare_dataset: total_num_tokens = np.sum( train_dataset.data.column("input_ids") .to_pandas() @@ -319,7 +319,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): skip_estimates = cfg.model_config_type == "mamba" - if not skip_estimates and not cfg.total_supervised_tokens: + if ( + not skip_estimates + and not cfg.total_supervised_tokens + and not cfg.skip_prepare_dataset + ): total_supervised_tokens = ( train_dataset.data.column("labels") .to_pandas() @@ -478,13 +482,15 @@ def prepare_opinionated_env(cfg): os.environ["TOKENIZERS_PARALLELISM"] = "false" -def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): +def setup_trainer( + cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps +): if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]: - trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) + trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor) trainer_builder.model_ref = model[1] trainer_builder.peft_config = model[2] else: - trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer) + trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor) trainer_builder.train_dataset = train_dataset trainer_builder.eval_dataset = eval_dataset diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index 28210b7ae8..20533504ce 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -73,7 +73,7 @@ def test_llama3(self, llama3_tokenizer, assistant_dataset): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_templates("llama3"), + chat_template=chat_templates("llama3"), message_field_role="role", message_field_content="content", roles={ @@ -113,7 +113,7 @@ def test_phi35(self, phi35_tokenizer, assistant_dataset): strategy = ChatTemplateStrategy( ChatTemplatePrompter( phi35_tokenizer, - chat_templates("phi_35"), + chat_template=chat_templates("phi_35"), message_field_role="role", message_field_content="content", roles={ @@ -171,7 +171,7 @@ def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_templates("llama3"), + chat_template=chat_templates("llama3"), message_field_role="role", message_field_content="content", message_field_training="training", @@ -227,8 +227,11 @@ class TestSharegptChatTemplateLlama3: def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset): LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts") + # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, train_on_eos="none", @@ -277,8 +280,11 @@ def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset): def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset): LOG.info("Testing ShareGPT style datasets with llama-3 human prompts") + # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, train_on_eos="none", @@ -327,8 +333,11 @@ def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset): def test_llama3_system_human(self, llama3_tokenizer, basic_dataset): LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts") + # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, train_on_eos="none", diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py index f18fb39423..50429e3a26 100644 --- a/tests/prompt_strategies/test_chat_templates_advanced.py +++ b/tests/prompt_strategies/test_chat_templates_advanced.py @@ -34,7 +34,9 @@ def find_sublist(full_list, sub_list): def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_inputs=True") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=True, sequence_len=512, @@ -77,7 +79,9 @@ def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_inputs=False") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -118,7 +122,9 @@ def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): LOG.info("Testing roles_to_train with assistant only") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -144,7 +150,9 @@ def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): LOG.info("Testing roles_to_train with all roles") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=True, sequence_len=512, @@ -175,7 +183,9 @@ def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with empty roles_to_train") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -194,7 +204,9 @@ def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='all'") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -219,7 +231,9 @@ def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='turn'") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -267,7 +281,9 @@ def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='last'") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -298,7 +314,9 @@ def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='none'") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -324,7 +342,9 @@ def test_drop_system_message(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with drop_system_message=True") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_templates("llama3"), drop_system_message=True + llama3_tokenizer, + chat_template=chat_templates("llama3"), + drop_system_message=True, ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -350,7 +370,9 @@ def test_custom_roles(self, llama3_tokenizer): } strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_templates("llama3"), roles=custom_roles + llama3_tokenizer, + chat_template=chat_templates("llama3"), + roles=custom_roles, ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -402,7 +424,7 @@ def test_message_field_training(self, llama3_tokenizer): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_templates("llama3"), + chat_template=chat_templates("llama3"), message_field_training="train", message_field_training_detail="train_detail", ),