Fix target_modules type in config.from_pretrained
Fixes huggingface#1045, supersedes huggingface#1041

Description

When loading a config from a file, we currently set the loaded
attributes directly on the config instance. However, this sidesteps the
__post_init__ call, which is required to convert target_modules to a
set. This PR fixes the bug by no longer setting attributes on the config
instance directly; instead, the loaded attributes are merged into the
init kwargs so that everything goes through __init__ (and thus
__post_init__).
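
To illustrate the mechanics, here is a minimal sketch (using a stand-in
DemoConfig, not PEFT's actual config class): dataclasses run
__post_init__ only at construction time, so plain attribute assignment
skips the conversion.

from dataclasses import dataclass
from typing import List, Optional, Union

@dataclass
class DemoConfig:
    target_modules: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        # runs at construction only, not on later attribute assignment
        if isinstance(self.target_modules, list):
            self.target_modules = set(self.target_modules)

config = DemoConfig(target_modules=["q_proj", "v_proj"])
print(type(config.target_modules))  # <class 'set'> -- __post_init__ ran

config = DemoConfig()
config.target_modules = ["q_proj", "v_proj"]  # mimics the old from_pretrained path
print(type(config.target_modules))  # <class 'list'> -- conversion was skipped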

Other changes

While working on this, I did a slight refactor of the config tests.

1. All config classes are included now (some were missing before).
2. Use parameterized instead of looping through the classes inside each test (a short sketch follows this list).
3. Added a unit test for the aforementioned bug.
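
The refactored tests follow roughly this pattern (simplified; the full
class tuple and all tests are in the diff below):

import unittest

from parameterized import parameterized

from peft import IA3Config, LoraConfig

ALL_CONFIG_CLASSES = (IA3Config, LoraConfig)  # abridged here; full tuple in the diff

class PeftConfigTester(unittest.TestCase):
    @parameterized.expand(ALL_CONFIG_CLASSES)
    def test_to_dict(self, config_class):
        # parameterized generates one independent test case per config class,
        # so a failure points at the offending class instead of aborting a loop
        config = config_class()
        self.assertTrue(isinstance(config.to_dict(), dict))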

Notes

This fix is based on my comment here:

huggingface#1041 (review)

Normally, I'd just wait for Sourab to reply, but since he's off for a
few days, I created a separate PR.

Another way we could achieve this is to override __setattr__ on the
config class so that it explicitly converts target_modules to a set.
That would also cover the case where a user does something like:

config = ...
config.target_modules = ["a", "b", "c"]

and we then wouldn't need to rely on __post_init__ at all. However, I
would propose to save the heavy guns for when they're absolutely
necessary.
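
For completeness, such an override could look roughly like this; a
hypothetical sketch, not something this PR implements:

class SomeConfig:
    def __setattr__(self, name, value):
        # convert on every assignment, covering __init__ and manual sets alike
        if name == "target_modules" and isinstance(value, list):
            value = set(value)
        super().__setattr__(name, value)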
BenjaminBossan committed Oct 23, 2023
1 parent aaa7e9f commit 4f3cef2
Showing 2 changed files with 94 additions and 82 deletions.
8 changes: 2 additions & 6 deletions src/peft/config.py
@@ -130,12 +130,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional
         else:
             config_cls = cls
 
-        config = config_cls(**class_kwargs)
-
-        for key, value in loaded_attributes.items():
-            if hasattr(config, key):
-                setattr(config, key, value)
-
+        kwargs = {**class_kwargs, **loaded_attributes}
+        config = config_cls(**kwargs)
         return config
 
     @classmethod
168 changes: 92 additions & 76 deletions tests/test_config.py
@@ -20,11 +20,15 @@
 import warnings
 
 import pytest
+from parameterized import parameterized
 
 from peft import (
+    AdaLoraConfig,
     AdaptionPromptConfig,
     IA3Config,
+    LoHaConfig,
     LoraConfig,
+    MultitaskPromptTuningConfig,
     PeftConfig,
     PrefixTuningConfig,
     PromptEncoder,
@@ -35,20 +39,22 @@
 
 PEFT_MODELS_TO_TEST = [("lewtun/tiny-random-OPTForCausalLM-delta", "v1")]
 
-
-class PeftConfigTestMixin:
-    all_config_classes = (
-        LoraConfig,
-        PromptEncoderConfig,
-        PrefixTuningConfig,
-        PromptTuningConfig,
-        AdaptionPromptConfig,
-        IA3Config,
-    )
+ALL_CONFIG_CLASSES = (
+    AdaptionPromptConfig,
+    AdaLoraConfig,
+    IA3Config,
+    LoHaConfig,
+    LoraConfig,
+    MultitaskPromptTuningConfig,
+    PrefixTuningConfig,
+    PromptEncoderConfig,
+    PromptTuningConfig,
+)
 
 
-class PeftConfigTester(unittest.TestCase, PeftConfigTestMixin):
-    def test_methods(self):
+class PeftConfigTester(unittest.TestCase):
+    @parameterized.expand(ALL_CONFIG_CLASSES)
+    def test_methods(self, config_class):
         r"""
         Test if all configs have the expected methods. Here we test
         - to_dict
@@ -57,109 +63,107 @@ def test_methods(self):
         - from_json_file
         """
         # test if all configs have the expected methods
-        for config_class in self.all_config_classes:
-            config = config_class()
-            self.assertTrue(hasattr(config, "to_dict"))
-            self.assertTrue(hasattr(config, "save_pretrained"))
-            self.assertTrue(hasattr(config, "from_pretrained"))
-            self.assertTrue(hasattr(config, "from_json_file"))
-
-    def test_task_type(self):
-        for config_class in self.all_config_classes:
-            # assert this will not fail
-            _ = config_class(task_type="test")
-
-    def test_from_pretrained(self):
+        config = config_class()
+        self.assertTrue(hasattr(config, "to_dict"))
+        self.assertTrue(hasattr(config, "save_pretrained"))
+        self.assertTrue(hasattr(config, "from_pretrained"))
+        self.assertTrue(hasattr(config, "from_json_file"))
+
+    @parameterized.expand(ALL_CONFIG_CLASSES)
+    def test_task_type(self, config_class):
+        config_class(task_type="test")
+
+    @parameterized.expand(ALL_CONFIG_CLASSES)
+    def test_from_pretrained(self, config_class):
         r"""
         Test if the config is correctly loaded using:
         - from_pretrained
         """
-        for config_class in self.all_config_classes:
-            for model_name, revision in PEFT_MODELS_TO_TEST:
-                # Test we can load config from delta
-                _ = config_class.from_pretrained(model_name, revision=revision)
+        for model_name, revision in PEFT_MODELS_TO_TEST:
+            # Test we can load config from delta
+            config_class.from_pretrained(model_name, revision=revision)
 
-    def test_save_pretrained(self):
+    @parameterized.expand(ALL_CONFIG_CLASSES)
+    def test_save_pretrained(self, config_class):
         r"""
         Test if the config is correctly saved and loaded using
         - save_pretrained
         """
-        for config_class in self.all_config_classes:
-            config = config_class()
-            with tempfile.TemporaryDirectory() as tmp_dirname:
-                config.save_pretrained(tmp_dirname)
+        config = config_class()
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            config.save_pretrained(tmp_dirname)
 
-                config_from_pretrained = config_class.from_pretrained(tmp_dirname)
-                self.assertEqual(config.to_dict(), config_from_pretrained.to_dict())
+            config_from_pretrained = config_class.from_pretrained(tmp_dirname)
+            self.assertEqual(config.to_dict(), config_from_pretrained.to_dict())
 
-    def test_from_json_file(self):
-        for config_class in self.all_config_classes:
-            config = config_class()
-            with tempfile.TemporaryDirectory() as tmp_dirname:
-                config.save_pretrained(tmp_dirname)
+    @parameterized.expand(ALL_CONFIG_CLASSES)
+    def test_from_json_file(self, config_class):
+        config = config_class()
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            config.save_pretrained(tmp_dirname)
 
-                config_from_json = config_class.from_json_file(os.path.join(tmp_dirname, "adapter_config.json"))
-                self.assertEqual(config.to_dict(), config_from_json)
+            config_from_json = config_class.from_json_file(os.path.join(tmp_dirname, "adapter_config.json"))
+            self.assertEqual(config.to_dict(), config_from_json)
 
-    def test_to_dict(self):
+    @parameterized.expand(ALL_CONFIG_CLASSES)
+    def test_to_dict(self, config_class):
         r"""
         Test if the config can be correctly converted to a dict using:
         - to_dict
         """
-        for config_class in self.all_config_classes:
-            config = config_class()
-            self.assertTrue(isinstance(config.to_dict(), dict))
+        config = config_class()
+        self.assertTrue(isinstance(config.to_dict(), dict))
 
-    def test_from_pretrained_cache_dir(self):
+    @parameterized.expand(ALL_CONFIG_CLASSES)
+    def test_from_pretrained_cache_dir(self, config_class):
         r"""
         Test if the config is correctly loaded with extra kwargs
         """
         with tempfile.TemporaryDirectory() as tmp_dirname:
-            for config_class in self.all_config_classes:
-                for model_name, revision in PEFT_MODELS_TO_TEST:
-                    # Test we can load config from delta
-                    _ = config_class.from_pretrained(model_name, revision=revision, cache_dir=tmp_dirname)
+            for model_name, revision in PEFT_MODELS_TO_TEST:
+                # Test we can load config from delta
+                config_class.from_pretrained(model_name, revision=revision, cache_dir=tmp_dirname)
 
     def test_from_pretrained_cache_dir_remote(self):
         r"""
         Test if the config is correctly loaded with a checkpoint from the hub
         """
         with tempfile.TemporaryDirectory() as tmp_dirname:
-            _ = PeftConfig.from_pretrained("ybelkada/test-st-lora", cache_dir=tmp_dirname)
+            PeftConfig.from_pretrained("ybelkada/test-st-lora", cache_dir=tmp_dirname)
             self.assertTrue("models--ybelkada--test-st-lora" in os.listdir(tmp_dirname))
 
-    def test_set_attributes(self):
+    @parameterized.expand(ALL_CONFIG_CLASSES)
+    def test_set_attributes(self, config_class):
         # manually set attributes and check if they are correctly written
-        for config_class in self.all_config_classes:
-            config = config_class(peft_type="test")
+        config = config_class(peft_type="test")
 
-            # save pretrained
-            with tempfile.TemporaryDirectory() as tmp_dirname:
-                config.save_pretrained(tmp_dirname)
+        # save pretrained
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            config.save_pretrained(tmp_dirname)
 
-                config_from_pretrained = config_class.from_pretrained(tmp_dirname)
-                self.assertEqual(config.to_dict(), config_from_pretrained.to_dict())
+            config_from_pretrained = config_class.from_pretrained(tmp_dirname)
+            self.assertEqual(config.to_dict(), config_from_pretrained.to_dict())
 
-    def test_config_copy(self):
+    @parameterized.expand(ALL_CONFIG_CLASSES)
+    def test_config_copy(self, config_class):
         # see https://github.com/huggingface/peft/issues/424
-        for config_class in self.all_config_classes:
-            config = config_class()
-            copied = copy.copy(config)
-            self.assertEqual(config.to_dict(), copied.to_dict())
+        config = config_class()
+        copied = copy.copy(config)
+        self.assertEqual(config.to_dict(), copied.to_dict())
 
-    def test_config_deepcopy(self):
+    @parameterized.expand(ALL_CONFIG_CLASSES)
+    def test_config_deepcopy(self, config_class):
         # see https://github.com/huggingface/peft/issues/424
-        for config_class in self.all_config_classes:
-            config = config_class()
-            copied = copy.deepcopy(config)
-            self.assertEqual(config.to_dict(), copied.to_dict())
+        config = config_class()
+        copied = copy.deepcopy(config)
+        self.assertEqual(config.to_dict(), copied.to_dict())
 
-    def test_config_pickle_roundtrip(self):
+    @parameterized.expand(ALL_CONFIG_CLASSES)
+    def test_config_pickle_roundtrip(self, config_class):
         # see https://github.com/huggingface/peft/issues/424
-        for config_class in self.all_config_classes:
-            config = config_class()
-            copied = pickle.loads(pickle.dumps(config))
-            self.assertEqual(config.to_dict(), copied.to_dict())
+        config = config_class()
+        copied = pickle.loads(pickle.dumps(config))
+        self.assertEqual(config.to_dict(), copied.to_dict())
 
     def test_prompt_encoder_warning_num_layers(self):
         # This test checks that if a prompt encoder config is created with an argument that is ignored, there should be
@@ -182,3 +186,15 @@ def test_prompt_encoder_warning_num_layers(self):
             PromptEncoder(config)
         expected_msg = "for MLP, the argument `encoder_num_layers` is ignored. Exactly 2 MLP layers are used."
         assert str(record.list[0].message) == expected_msg
+
+    @parameterized.expand([LoHaConfig, LoraConfig, IA3Config])
+    def test_save_pretrained_with_target_modules(self, config_class):
+        # See #1041, #1045
+        config = config_class(target_modules=["a", "list"])
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            config.save_pretrained(tmp_dirname)
+
+            config_from_pretrained = config_class.from_pretrained(tmp_dirname)
+            self.assertEqual(config.to_dict(), config_from_pretrained.to_dict())
+            # explicit test that target_modules should be converted to set
+            self.assertTrue(isinstance(config_from_pretrained.target_modules, set))