diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 23ac0952ed..249398f850 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -456,7 +456,7 @@ def create_optimizer(self): if self.args.loraplus_lr_ratio is not None: loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) loraplus_lr_embedding = getattr( - self.args, "loraplus_lr_embedding", None + self.args, "loraplus_lr_embedding", 1e-6 ) self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init opt_model, diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 2217855083..4e07c9260a 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -298,6 +298,13 @@ def validate_qlora(self): raise ValueError("Require cfg.load_in_4bit to be True for qlora") return self + @field_validator("loraplus_lr_embedding") + @classmethod + def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding): + if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str): + loraplus_lr_embedding = float(loraplus_lr_embedding) + return loraplus_lr_embedding + class ReLoRAConfig(BaseModel): """ReLoRA configuration subset"""