Skip to content

Commit

Permalink
adding model_cfg to set num_labels
Browse files Browse the repository at this point in the history
  • Loading branch information
SalmanMohammadi committed Jan 7, 2025
1 parent 3915abe commit 2ca689c
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 8 deletions.
2 changes: 2 additions & 0 deletions examples/gemma2/reward-model.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
base_model: google/gemma-2-2b
# optionally might have model_type or tokenizer_type
model_type: AutoModelForSequenceClassification
model_cfg:
num_labels: 1
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
Expand Down
2 changes: 1 addition & 1 deletion examples/qwen2/dpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ datasets:

dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/dpo-out
output_dir: ./outputs/out

sequence_len: 2048
sample_packing: false
Expand Down
68 changes: 68 additions & 0 deletions examples/qwen2/reward-model.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
base_model: Qwen/Qwen2.5-0.5B
# optionally might have model_type or tokenizer_type
model_type: AutoModelForSequenceClassification
model_cfg:
num_labels: 1
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

reward_model: true
chat_template: qwen_25
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
remove_unused_columns: false

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
2 changes: 1 addition & 1 deletion src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):

# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
Expand Down
1 change: 1 addition & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ class ModelInputConfig(BaseModel):
trust_remote_code: Optional[bool] = None

model_kwargs: Optional[Dict[str, Any]] = None
model_cfg: Optional[Dict[str, Any]] = None

@field_validator("trust_remote_code")
@classmethod
Expand Down
14 changes: 8 additions & 6 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
SUPPORTED_MULTIPACK_MODEL_TYPES,
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
Expand Down Expand Up @@ -138,7 +138,9 @@ def load_model_config(cfg):
config_kwargs = {}
if cfg.revision_of_model:
config_kwargs["revision"] = cfg.revision_of_model

if cfg.model_cfg:
for k, v in cfg.model_cfg.items():
config_kwargs[k] = v
try:
model_config = AutoConfig.from_pretrained(
model_config_name,
Expand Down Expand Up @@ -641,9 +643,9 @@ def set_quantization_config(self) -> None:
)
else:
if self.cfg.gptq_disable_exllama is not None:
self.model_config.quantization_config[
"disable_exllama"
] = self.cfg.gptq_disable_exllama
self.model_config.quantization_config["disable_exllama"] = (
self.cfg.gptq_disable_exllama
)
self.model_kwargs["quantization_config"] = GPTQConfig(
**self.model_config.quantization_config
)
Expand Down Expand Up @@ -1290,7 +1292,7 @@ def setup_quantized_peft_meta_for_training(model: nn.Module):
def load_lora(model, cfg, inference=False, config_only=False):
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]

from peft import LoraConfig, get_peft_model
from peft import get_peft_model, LoraConfig

lora_target_modules = cfg.lora_target_modules or []

Expand Down

0 comments on commit 2ca689c

Please sign in to comment.