Skip to content

Commit

Permalink
[Trainer] Correct behavior of _load_best_model for PEFT models (#…
Browse files Browse the repository at this point in the history
…24103)

* v1

* some refactor

- add ST format as well

* fix

* add `ADAPTER_WEIGHTS_NAME` & `ADAPTER_SAFE_WEIGHTS_NAME`
  • Loading branch information
younesbelkada authored and sgugger committed Jun 8, 2023
1 parent 17db177 commit 53e1f5c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
28 changes: 20 additions & 8 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@
)
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import (
ADAPTER_SAFE_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
Expand Down Expand Up @@ -2177,11 +2179,20 @@ def _load_best_model(self):
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)

model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
if (
os.path.exists(best_model_path)
or os.path.exists(best_safe_model_path)
or os.path.exists(best_adapter_model_path)
or os.path.exists(best_safe_adapter_model_path)
):
if self.is_deepspeed_enabled:
deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
else:
has_been_loaded = True
if is_sagemaker_mp_enabled():
if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
# If the 'user_content.pt' file exists, load with the new smp api.
Expand All @@ -2207,10 +2218,10 @@ def _load_best_model(self):
self.accelerator, model, self.state.best_model_checkpoint
)
else:
if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
# If train base_8_bit_models using PEFT & LoRA, assume that adapter have been saved properly.
if is_peft_available() and isinstance(model, PeftModel):
# If train a model using PEFT & LoRA, assume that adapter have been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")):
if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
# Load_adapter has no return value present, modify it when appropriate.
from torch.nn.modules.module import _IncompatibleKeys
Expand All @@ -2219,12 +2230,13 @@ def _load_best_model(self):
else:
logger.warning(
"The intermediate checkpoints of PEFT may not be saved correctly, "
"using `TrainerCallback` to save adapter_model.bin in corresponding folders, "
f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, "
"here are some examples https://github.com/huggingface/peft/issues/96"
)
has_been_loaded = False
else:
# We can't do pure 8bit training using transformers.
logger.warning("Could not loading a quantized checkpoint.")
logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
has_been_loaded = False
else:
# We load the model state dict on the CPU to avoid an OOM error.
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
Expand All @@ -2236,7 +2248,7 @@ def _load_best_model(self):
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
# which takes *args instead of **kwargs
load_result = model.load_state_dict(state_dict, False)
if not is_sagemaker_mp_enabled():
if not is_sagemaker_mp_enabled() and has_been_loaded:
self._issue_warnings_after_load(load_result)
elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
load_result = load_sharded_checkpoint(
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@

WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
TF_WEIGHTS_NAME = "model.ckpt"
Expand Down

0 comments on commit 53e1f5c

Please sign in to comment.