diff --git a/examples/gemma2/reward-model.yaml b/examples/gemma2/reward-model.yaml
index b492c6f93..ef9dfce07 100644
--- a/examples/gemma2/reward-model.yaml
+++ b/examples/gemma2/reward-model.yaml
@@ -1,6 +1,8 @@
 base_model: google/gemma-2-2b
 # optionally might have model_type or tokenizer_type
 model_type: AutoModelForSequenceClassification
+model_cfg:
+  num_labels: 1
 tokenizer_type: AutoTokenizer
 # Automatically upload checkpoint and final model to HF
 # hub_model_id: username/custom_model_name
diff --git a/examples/qwen2/dpo.yaml b/examples/qwen2/dpo.yaml
index e924be195..c3d1632d8 100644
--- a/examples/qwen2/dpo.yaml
+++ b/examples/qwen2/dpo.yaml
@@ -24,7 +24,7 @@ datasets:
 
 dataset_prepared_path:
 val_set_size: 0.0
-output_dir: ./outputs/dpo-out
+output_dir: ./outputs/out
 
 sequence_len: 2048
 sample_packing: false
diff --git a/examples/qwen2/reward-model.yaml b/examples/qwen2/reward-model.yaml
new file mode 100644
index 000000000..820f70656
--- /dev/null
+++ b/examples/qwen2/reward-model.yaml
@@ -0,0 +1,68 @@
+base_model: Qwen/Qwen2.5-0.5B
+# optionally might have model_type or tokenizer_type
+model_type: AutoModelForSequenceClassification
+model_cfg:
+  num_labels: 1
+tokenizer_type: AutoTokenizer
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+reward_model: true
+chat_template: qwen_25
+datasets:
+  - path: argilla/distilabel-intel-orca-dpo-pairs
+    type: bradley_terry.chat_template
+val_set_size: 0.0
+output_dir: ./outputs/out
+remove_unused_columns: false
+
+sequence_len: 2048
+sample_packing: false
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch:
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index d07b10ce3..917f682d1 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -411,7 +411,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
 
     # load the config from the yaml file
     with open(config, encoding="utf-8") as file:
-        cfg: DictDefault = DictDefault(yaml.safe_load(file))
+        cfg = DictDefault(yaml.safe_load(file))
     # if there are any options passed in the cli, if it is something that seems valid from the yaml,
     # then overwrite the value
     cfg_keys = cfg.keys()
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 0781c6798..512fced0c 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -395,6 +395,7 @@ class ModelInputConfig(BaseModel):
 
     trust_remote_code: Optional[bool] = None
     model_kwargs: Optional[Dict[str, Any]] = None
+    model_cfg: Optional[Dict[str, Any]] = None
 
     @field_validator("trust_remote_code")
     @classmethod
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 523fd76fe..26d88452c 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -50,8 +50,8 @@
 from axolotl.common.architectures import MOE_ARCH_BLOCK
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.monkeypatch.multipack import (
-    SUPPORTED_MULTIPACK_MODEL_TYPES,
     patch_for_multipack,
+    SUPPORTED_MULTIPACK_MODEL_TYPES,
 )
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
@@ -138,7 +138,9 @@ def load_model_config(cfg):
     config_kwargs = {}
     if cfg.revision_of_model:
         config_kwargs["revision"] = cfg.revision_of_model
-
+    if cfg.model_cfg:
+        for k, v in cfg.model_cfg.items():
+            config_kwargs[k] = v
     try:
         model_config = AutoConfig.from_pretrained(
             model_config_name,
@@ -641,9 +643,9 @@ def set_quantization_config(self) -> None:
                 )
             else:
                 if self.cfg.gptq_disable_exllama is not None:
-                    self.model_config.quantization_config[
-                        "disable_exllama"
-                    ] = self.cfg.gptq_disable_exllama
+                    self.model_config.quantization_config["disable_exllama"] = (
+                        self.cfg.gptq_disable_exllama
+                    )
                 self.model_kwargs["quantization_config"] = GPTQConfig(
                     **self.model_config.quantization_config
                 )
@@ -1290,7 +1292,7 @@ def setup_quantized_peft_meta_for_training(model: nn.Module):
 
 
 def load_lora(model, cfg, inference=False, config_only=False):
     # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
-    from peft import LoraConfig, get_peft_model
+    from peft import get_peft_model, LoraConfig
 
     lora_target_modules = cfg.lora_target_modules or []
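
A minimal sketch (not part of the diff) of how the new model_cfg mapping is expected to reach the model, assuming standard transformers behavior; the model name below is illustrative. Each key is merged into the kwargs passed to AutoConfig.from_pretrained, mirroring the loop added to load_model_config above, so num_labels: 1 gives the sequence-classification head a single scalar output, which is what a Bradley-Terry reward model trains on.

from transformers import AutoConfig, AutoModelForSequenceClassification

# Mirrors the YAML model_cfg block; every key becomes a config kwarg.
model_cfg = {"num_labels": 1}

config_kwargs = {}
for k, v in model_cfg.items():
    config_kwargs[k] = v

# The overridden config is built before the model is instantiated.
config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B", **config_kwargs)
model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B", config=config
)
print(model.config.num_labels)  # 1 -> a single scalar reward per sequence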