Skip to content

Commit

Permalink
FIX Warning abt config.json when the base model is local. (#1668)
Browse files Browse the repository at this point in the history
Fix incorrect warning when loading local model.
  • Loading branch information
elementary-particle authored May 21, 2024
1 parent 691bc22 commit bc6a999
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/peft/utils/save_and_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,25 +185,26 @@ def get_peft_model_state_dict(

# For some models e.g. diffusers the text config file is stored in a subfolder
# we need to make sure we can download that config.
has_remote_config = False
has_base_config = False

# ensure that this check is not performed in HF offline mode, see #1452
if model_id is not None:
exists = check_file_exists_on_hf_hub(model_id, "config.json")
local_config_exists = os.path.exists(os.path.join(model_id, "config.json"))
exists = local_config_exists or check_file_exists_on_hf_hub(model_id, "config.json")
if exists is None:
# check failed, could not determine if it exists or not
warnings.warn(
f"Could not find a config file in {model_id} - will assume that the vocabulary was not modified."
)
has_remote_config = False
has_base_config = False
else:
has_remote_config = exists
has_base_config = exists

# check if the vocab size of the base model is different from the vocab size of the finetuned model
if (
vocab_size
and model_id
and has_remote_config
and has_base_config
and (vocab_size != model.config.__class__.from_pretrained(model_id).vocab_size)
):
warnings.warn(
Expand Down
22 changes: 22 additions & 0 deletions tests/test_hub_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,28 @@ def test_subfolder(self):
assert isinstance(model, PeftModel)


class TestLocalModel:
def test_local_model_saving_no_warning(self, recwarn, tmp_path):
# When the model is saved, the library checks for vocab changes by
# examining `config.json` in the model path.
# However, previously, those checks only covered huggingface hub models.
# This test makes sure that the local `config.json` is checked as well.
# If `save_pretrained` could not find the file, it will issue a warning.
model_id = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_id)
local_dir = tmp_path / model_id
model.save_pretrained(local_dir)
del model

base_model = AutoModelForCausalLM.from_pretrained(local_dir)
peft_config = LoraConfig()
peft_model = get_peft_model(base_model, peft_config)
peft_model.save_pretrained(local_dir)

for warning in recwarn.list:
assert "Could not find a config file" not in warning.message.args[0]


class TestBaseModelRevision:
def test_save_and_load_base_model_revision(self, tmp_path):
r"""
Expand Down

0 comments on commit bc6a999

Please sign in to comment.