From f0ca96200d40138e3e2b541ce9789e1119131a99 Mon Sep 17 00:00:00 2001 From: calpt Date: Mon, 30 Dec 2024 00:12:37 +0100 Subject: [PATCH 1/5] Fix default Lora/ (IA)^3 scaling in forward (#770) Resolves issue described in #760. **IMPORTANT**: this fix restores weights compatibility with adapter-transformers. Compatibility to previous adapters versions is kept via a compat patch. ## Details The current implementation of LoRA/ (IA)^3 in `adapters ` versions < 1.1.0 does not correctly implement adapter states scaling via the LoRA `alpha` attribute, effectively ignoring `alpha` and always applying a scaling of 1.0. This PR restores the correct original behavior as found in adapter-transformers/ original LoRA implementation. As this change breaks all adapters pre-trained using `adapters` versions 0.1.0 - 1.0.1, a backward compatibility patch is introduced that automatically sets `alpha = r` for LoRAs for adapters that were trained using affected versions. This ensures all previous adapters continue to behave exactly as trained (ie give the exact same output using newer versions). --------- Co-authored-by: TimoImhof <62378375+TimoImhof@users.noreply.github.com> --- setup.py | 4 +++- src/adapters/__init__.py | 2 +- src/adapters/loading.py | 20 ++++++++++++++++++++ src/adapters/methods/lora.py | 2 ++ 4 files changed, 26 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index e7389c8be6..1666ae3d0a 100644 --- a/setup.py +++ b/setup.py @@ -34,6 +34,7 @@ "isort>=5.5.4", "Jinja2==2.11.3", "nltk", + "packaging", "parameterized", "pillow", "protobuf", @@ -136,11 +137,12 @@ def deps_list(*pkgs): # when modifying the following list, make sure to update src/transformers/dependency_versions_check.py install_requires = [ deps["transformers"], + deps["packaging"], ] setup( name="adapters", - version="1.0.1", + version="1.1.0.dev0", author="The AdapterHub team and community contributors", author_email="calpt@mail.de", description="A Unified Library for Parameter-Efficient and Modular Transfer Learning", diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index a917828e72..88549c6969 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.0.1" +__version__ = "1.1.0.dev0" from typing import TYPE_CHECKING diff --git a/src/adapters/loading.py b/src/adapters/loading.py index b1918b0a0f..69747e04cb 100644 --- a/src/adapters/loading.py +++ b/src/adapters/loading.py @@ -6,6 +6,7 @@ from typing import Callable, Mapping, Optional, Sequence, Tuple import torch +from packaging.version import Version try: @@ -368,6 +369,23 @@ def _rename_legacy_weights(self, k): k = k.replace(old, new) return k + def _fix_backward_compat(self, config): + # Fix error in previous versions for LoRA/ (IA)^3 + ADAPTER_PREFIX = "adapters." + MIN_VERSION = Version("1.1.0") + + version = config.get("version", "") + if version.startswith(ADAPTER_PREFIX) and Version(version[len(ADAPTER_PREFIX) :]) < MIN_VERSION: + if ( + config["config"].get("architecture", None) == "lora" + and config["config"]["r"] != config["config"]["alpha"] + ): + logger.warning( + "Loading a LoRA trained using a faulty scaling implementation of a previous library version. Editing the configuration to make sure the adapter works as trained." + "See https://github.com/adapter-hub/adapters/pull/770 for more." + ) + config["config"]["alpha"] = config["config"]["r"] + # This method is used to remove unnecessary invertible adapters from task adapters using the old format. # In the old format, task adapters e.g. using seq_bn config specify inv. adapters but don't use them. # As inv. adapters would be incorrectly used in the new implementation, @@ -560,6 +578,8 @@ def load( # The conversion to a set and then back to a list removes all duplicates leave_out = list(set(leave_out + config["config"]["leave_out"])) config["config"]["leave_out"] = leave_out + # Fix issues + self._fix_backward_compat(config) adapter_name = load_as or config["name"] # If the adapter is not part of the model, add it diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index d56a11a91d..8f3bc29401 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -100,6 +100,7 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens hidden_states = hidden_states * gate else: gate = None + hidden_states = hidden_states * self.scaling return hidden_states, gate @@ -171,6 +172,7 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens hidden_states = hidden_states * gate else: gate = None + hidden_states = hidden_states * self.scaling return hidden_states, gate From d6054cb5d8230a39e834518560c7f6b26b9acf89 Mon Sep 17 00:00:00 2001 From: Julian Fong <44014224+julian-fong@users.noreply.github.com> Date: Mon, 6 Jan 2025 12:04:32 -0500 Subject: [PATCH 2/5] [BUG] Fix `AdapterPlus` config (#775) This pr fixes the configuration parameters set in the `AdapterPlusConfig` edit: This pr also incorporates some updates as described inside the comments in #764 1) Added some more information regarding training configurations inside the `AdapterPlusConfig` and its corresponding notebook 2) Added more info regarding layer norms inside the documentation --- docs/methods.md | 5 +++++ notebooks/ViT_AdapterPlus_FineTuning.ipynb | 15 +++++++++++++-- src/adapters/configuration/adapter_config.py | 11 ++++++++++- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/docs/methods.md b/docs/methods.md index 302b1973d3..95226e3578 100644 --- a/docs/methods.md +++ b/docs/methods.md @@ -59,6 +59,11 @@ _Papers:_ * [Adapters Strike Back](https://arxiv.org/pdf/2406.06820) (Steitz and Roth., 2024) * [AdapterHub: A Framework for Adapting Transformers](https://arxiv.org/pdf/2007.07779.pdf) (Pfeiffer et al., 2020) +```{eval-rst} +.. note:: + The two parameters ``original_ln_before`` and ``original_ln_after`` inside bottleneck adapters control both the addition of the residual input and the application of the pretrained layer norm. If the original model does not apply a layer norm function at a specific position of the forward function (e.g after the FFN layer), the two bottleneck parameters of the adapter set at that same position will only control the application of the residual input. +``` + ## Language Adapters - Invertible Adapters _Configuration class_: [`SeqBnInvConfig`](adapters.SeqBnInvConfig), [`DoubleSeqBnInvConfig`](adapters.DoubleSeqBnInvConfig) diff --git a/notebooks/ViT_AdapterPlus_FineTuning.ipynb b/notebooks/ViT_AdapterPlus_FineTuning.ipynb index 1cf549ea75..6833a6b0e1 100644 --- a/notebooks/ViT_AdapterPlus_FineTuning.ipynb +++ b/notebooks/ViT_AdapterPlus_FineTuning.ipynb @@ -205,7 +205,18 @@ "source": [ "### Loading the `ViT` model and the `AdapterPlusConfig`\n", "\n", - "Here we load the `vit-base-patch16-224-in21k` model similar to the one used in the `AdapterConfig` paper. We will load the model using the `adapters` `AutoAdapterModel` and add the corresponding `AdapterPlusConfig`. To read more about the config, you can check out the docs page [here](https://docs.adapterhub.ml/methods#bottleneck-adapters) under `AdapterPlusConfig`" + "Here we load the `vit-base-patch16-224-in21k` model similar to the one used in the `AdapterConfig` paper. We will load the model using the `adapters` `AutoAdapterModel` and add the corresponding `AdapterPlusConfig`. To read more about the config, you can check out the docs page [here](https://docs.adapterhub.ml/methods#bottleneck-adapters) under `AdapterPlusConfig`.\n", + "\n", + "#### Important Note\n", + "\n", + "Please note that some configurations of the adapters parameters `original_ln_after`, `original_ln_before`, and \n", + "`residual_before_ln` may result in performance issues when training. \n", + "\n", + "In the general case:\n", + "\n", + "1) At least one of `original_ln_before` or `original_ln_after` should be set to `True` in order to ensure that the original residual\n", + " connection from pre-training is preserved. \n", + "2) If `original_ln_after` is set to `False`, `residual_before_ln` must also be set to `False` to ensure convergence during training." ] }, { @@ -218,7 +229,7 @@ "from adapters import AdapterPlusConfig\n", "\n", "model = ViTAdapterModel.from_pretrained(model_name_or_path)\n", - "config = AdapterPlusConfig(original_ln_after=True)\n", + "config = AdapterPlusConfig()\n", "\n", "model.add_adapter(\"adapterplus_config\", config)\n", "model.add_image_classification_head(\"adapterplus_config\", num_labels=num_classes)\n", diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index b5249cb9f5..9e1cf052ac 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -374,10 +374,19 @@ class ParBnConfig(BnConfig): class AdapterPlusConfig(BnConfig): """ The AdapterPlus config architecture proposed by Jan-Martin O, Steitz and Stefan Roth. See https://arxiv.org/pdf/2406.06820 + + Please note that some configurations of the adapters parameters `original_ln_after`, `original_ln_before`, and + `residual_before_ln` may result in performance issues when training. + + In the general case: + 1) At least one of `original_ln_before` or `original_ln_after` should be set to True in order to ensure that the original residual + connection from pre-training is preserved. + 2) If `original_ln_after` is set to `False`, `residual_before_ln` must also be set to `False` to ensure convergence during training. """ original_ln_after: bool = False - residual_before_ln: bool = True + original_ln_before: bool = True + residual_before_ln: bool = False stochastic_depth: float = 0.1 init_weights: str = "houlsby" scaling: Union[float, str] = "channel" From 7c2357f8d49b6dedab9ab83143b6cbbff5d92301 Mon Sep 17 00:00:00 2001 From: calpt Date: Wed, 8 Jan 2025 11:20:04 +0100 Subject: [PATCH 3/5] Upgrade Transformers to v4.47.x (#776) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Leon Engländer --- .github/workflows/adapter_docs_build.yml | 2 +- .github/workflows/tests_torch.yml | 16 ++--- hf_transformers | 2 +- setup.py | 2 +- .../models/deberta/modeling_deberta.py | 63 +++++++++++-------- .../models/deberta_v2/modeling_deberta_v2.py | 41 ++++++++---- 6 files changed, 79 insertions(+), 47 deletions(-) diff --git a/.github/workflows/adapter_docs_build.yml b/.github/workflows/adapter_docs_build.yml index 187f57d82c..35fab0de49 100644 --- a/.github/workflows/adapter_docs_build.yml +++ b/.github/workflows/adapter_docs_build.yml @@ -18,7 +18,7 @@ jobs: fetch-depth: 0 - uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: "3.10" - name: Install run: | pip install setuptools==57.4.0 diff --git a/.github/workflows/tests_torch.yml b/.github/workflows/tests_torch.yml index fd5930ebb6..cb8c61be1b 100644 --- a/.github/workflows/tests_torch.yml +++ b/.github/workflows/tests_torch.yml @@ -32,8 +32,8 @@ jobs: submodules: true - uses: actions/setup-python@v2 with: - python-version: 3.8 - - uses: actions/cache@v2 + python-version: "3.10" + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} @@ -53,8 +53,8 @@ jobs: submodules: true - uses: actions/setup-python@v2 with: - python-version: 3.8 - - uses: actions/cache@v2 + python-version: "3.10" + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} @@ -76,8 +76,8 @@ jobs: submodules: true - uses: actions/setup-python@v2 with: - python-version: 3.8 - - uses: actions/cache@v2 + python-version: "3.10" + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} @@ -99,8 +99,8 @@ jobs: submodules: true - uses: actions/setup-python@v2 with: - python-version: 3.8 - - uses: actions/cache@v2 + python-version: "3.10" + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} diff --git a/hf_transformers b/hf_transformers index 052e652d6d..241c04d368 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit 052e652d6d53c2b26ffde87e039b723949a53493 +Subproject commit 241c04d36867259cdf11dbb4e9d9a60f9cb65ebc diff --git a/setup.py b/setup.py index 1666ae3d0a..d7a15ef921 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ "timeout-decorator", "torch", "torchvision", - "transformers~=4.46.3", + "transformers~=4.47.1", ] diff --git a/src/adapters/models/deberta/modeling_deberta.py b/src/adapters/models/deberta/modeling_deberta.py index 4380b5e038..77c6117b19 100644 --- a/src/adapters/models/deberta/modeling_deberta.py +++ b/src/adapters/models/deberta/modeling_deberta.py @@ -16,12 +16,13 @@ import torch import torch.utils.checkpoint +from torch import nn from transformers.models.deberta.modeling_deberta import ( DebertaOutput, DebertaSelfOutput, DisentangledSelfAttention, - XSoftmax, + scaled_size_sqrt, ) from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel @@ -95,71 +96,83 @@ def forward( """ + # >>> START AH Changes <<< attention_mask = prefix_attention_mask(attention_mask, dim=3, prefix_value=1) # type: ignore attention_mask = prefix_attention_mask(attention_mask, dim=2, prefix_value=1) # type: ignore + # >>> END AH Changes <<< if query_states is None: qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1) query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1) else: - - def linear(w, b, x): - if b is not None: - return torch.matmul(x, w.t()) + b.t() - else: - return torch.matmul(x, w.t()) # + b.t() - ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0) qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)] - qkvb = [None] * 3 - - q = linear(qkvw[0], qkvb[0], query_states.to(dtype=qkvw[0].dtype)) - k, v = [linear(qkvw[i], qkvb[i], hidden_states.to(dtype=qkvw[i].dtype)) for i in range(1, 3)] + q = torch.matmul(qkvw[0], query_states.t().to(dtype=qkvw[0].dtype)) + k = torch.matmul(qkvw[1], hidden_states.t().to(dtype=qkvw[1].dtype)) + v = torch.matmul(qkvw[2], hidden_states.t().to(dtype=qkvw[2].dtype)) query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]] + # >>> START AH Changes <<< query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask) + # >>> END AH Changes <<< query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :]) value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :]) + # >>> START AH Changes <<< orig_key_layer = key_layer # save this for relative attention key_layer, value_layer, attention_mask = self.prefix_tuning( key_layer, value_layer, hidden_states, attention_mask, False ) (query_layer, orig_key_layer) = adjust_tensors_for_parallel(key_layer, query_layer, orig_key_layer) + # >>> END AH Changes <<< - rel_att = None + rel_att: int = 0 # Take the dot product between "query" and "key" to get the raw attention scores. scale_factor = 1 + len(self.pos_att_type) - scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor) + scale = scaled_size_sqrt(query_layer, scale_factor) query_layer = query_layer / scale.to(dtype=query_layer.dtype) attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - if self.relative_attention: + + if self.relative_attention and rel_embeddings is not None and relative_pos is not None: rel_embeddings = self.pos_dropout(rel_embeddings) + # >>> START AH Changes <<< rel_att = self.disentangled_att_bias( query_layer, orig_key_layer, relative_pos, rel_embeddings, scale_factor ) + # >>> END AH Changes <<< if rel_att is not None: - rel_att_padded = torch.zeros_like(attention_scores) - rel_att_padded[:, :, :, -rel_att.size(-1) :] = rel_att - attention_scores = attention_scores + rel_att_padded + # >>> START AH Changes <<< + # rel_att is set to 0 by default, i.e. rel_att is always not None (don't know why HuggingFace does this). + # Hence, we must check whether rel_att is a tensor and if so, pad it with zeros to be able to add it to attention_scores. + if isinstance(rel_att, torch.Tensor): + rel_att_padded = torch.zeros_like(attention_scores) + rel_att_padded[:, :, :, -rel_att.size(-1) :] = rel_att + attention_scores = attention_scores + rel_att_padded + else: + attention_scores = attention_scores + rel_att + # >>> END AH Changes <<< # bxhxlxd - if self.talking_head: + if self.head_logits_proj is not None: attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_mask = attention_mask.bool() + attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min) + # bsz x height x length x dimension + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + attention_probs.masked_fill(attention_mask, 0) + attention_probs = self.dropout(attention_probs) - if self.talking_head: + if self.head_weights_proj is not None: attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (-1,) context_layer = context_layer.view(new_context_layer_shape) - if output_attentions: - return (context_layer, attention_probs) - else: - return context_layer + if not output_attentions: + return (context_layer, None) + return (context_layer, attention_probs) diff --git a/src/adapters/models/deberta_v2/modeling_deberta_v2.py b/src/adapters/models/deberta_v2/modeling_deberta_v2.py index bc41ae82af..2b673c491f 100644 --- a/src/adapters/models/deberta_v2/modeling_deberta_v2.py +++ b/src/adapters/models/deberta_v2/modeling_deberta_v2.py @@ -16,12 +16,13 @@ import torch import torch.utils.checkpoint +from torch import nn from transformers.models.deberta_v2.modeling_deberta_v2 import ( DebertaV2Output, DebertaV2SelfOutput, DisentangledSelfAttention, - XSoftmax, + scaled_size_sqrt, ) from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel @@ -90,11 +91,15 @@ def forward( The embedding of relative distances. It's a tensor of shape [\\(2 \\times \\text{max_relative_positions}\\), *hidden_size*]. """ + # >>> START AH Changes <<< attention_mask = prefix_attention_mask(attention_mask, dim=3, prefix_value=1) # type: ignore attention_mask = prefix_attention_mask(attention_mask, dim=2, prefix_value=1) # type: ignore + # >>> END AH Changes <<< if query_states is None: query_states = hidden_states + + # >>> START AH Changes <<< query_layer = self.transpose_for_scores_extended(self.query_proj(query_states), self.num_attention_heads) key_layer = self.transpose_for_scores_extended(self.key_proj(hidden_states), self.num_attention_heads) value_layer = self.transpose_for_scores_extended(self.value_proj(hidden_states), self.num_attention_heads) @@ -112,6 +117,7 @@ def forward( key_layer = key_layer.contiguous().view(-1, key_layer.size(2), key_layer.size(-1)) value_layer = value_layer.contiguous().view(-1, value_layer.size(2), value_layer.size(-1)) orig_key_layer = orig_key_layer.contiguous().view(-1, orig_key_layer.size(2), orig_key_layer.size(-1)) + # >>> END AH Changes <<< rel_att = None # Take the dot product between "query" and "key" to get the raw attention scores. @@ -120,25 +126,39 @@ def forward( scale_factor += 1 if "p2c" in self.pos_att_type: scale_factor += 1 - scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor) - attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale.to(dtype=query_layer.dtype) + scale = scaled_size_sqrt(query_layer, scale_factor) + attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype)) if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) + # >>> START AH Changes <<< rel_att = self.disentangled_attention_bias( query_layer, orig_key_layer, relative_pos, rel_embeddings, scale_factor ) + # >>> END AH Changes <<< if rel_att is not None: - rel_att_padded = torch.zeros_like(attention_scores) - rel_att_padded[:, :, -rel_att.size(2) :] = rel_att - attention_scores = attention_scores + rel_att_padded + # >>> START AH Changes <<< + # rel_att is set to 0 by default, i.e. rel_att is always not None (don't know why HuggingFace does this). + # Hence, we must check whether rel_att is a tensor and if so, pad it with zeros to be able to add it to attention_scores. + if isinstance(rel_att, torch.Tensor): + rel_att_padded = torch.zeros_like(attention_scores) + rel_att_padded[:, :, -rel_att.size(2) :] = rel_att + attention_scores = attention_scores + rel_att_padded + else: + attention_scores = attention_scores + rel_att + # >>> END AH Changes <<< + attention_scores = attention_scores attention_scores = attention_scores.view( -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) ) + attention_mask = attention_mask.bool() + attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min) # bsz x height x length x dimension - attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + attention_probs.masked_fill(attention_mask, 0) + attention_probs = self.dropout(attention_probs) context_layer = torch.bmm( attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer @@ -150,7 +170,6 @@ def forward( ) new_context_layer_shape = context_layer.size()[:-2] + (-1,) context_layer = context_layer.view(new_context_layer_shape) - if output_attentions: - return (context_layer, attention_probs) - else: - return context_layer + if not output_attentions: + return (context_layer, None) + return (context_layer, attention_probs) From 9edc20d37e7a14e5513266b7da8ab2c1c8a58069 Mon Sep 17 00:00:00 2001 From: calpt Date: Wed, 8 Jan 2025 18:06:06 +0100 Subject: [PATCH 4/5] Allow saving, loading and pushing adapter compositions together (#771) Closes #441; closes #747. This PR introduces a set of new methods for saving, loading and pushing entire adapter compositions with one command: - `save_adapter_setup()` - `load_adapter_setup()` - `push_adapter_setup_to_hub()` They require two main params: - `adapter_setup`: the adapter composition to be saved. Identical to what can be specified for `active_adapters` - `head_setup`: for models with heads, the head setup to save along with the adapters. Identical to what can be specified for `active_head` Docs [here](https://github.com/adapter-hub/adapters/blob/04e69957a2bfc8093e2593186f7ebb2e71f88ec9/docs/loading.md#saving-and-loading-adapter-compositions) ### Example ```python model = AutoAdapterModel.from_pretrained("roberta-base") # create a complex setup model.add_adapter("a", config=SeqBnConfig()) model.add_adapter("b", config=SeqBnConfig()) model.add_adapter("c", config=SeqBnConfig()) model.add_adapter_fusion(["a", "b"]) model.add_classification_head("head_a") model.add_classification_head("head_b") adapter_setup = Stack(Fuse("a", "b"), "c") head_setup = BatchSplit("head_a", "head_b", batch_sizes=[1, 1]) model.set_active_adapters(adapter_setup) model.active_head = head_setup # save model.save_adapter_setup("checkpoint", adapter_setup, head_setup=head_setup) # push model.push_adapter_setup_to_hub("calpt/random_adapter_setup_test", adapter_setup, head_setup=head_setup) # re-load # model2 = AutoAdapterModel.from_pretrained("roberta-base") # model2.load_adapter_setup("checkpoint", set_active=True) ``` --------- Co-authored-by: Timo Imhof --- docs/adapter_composition.md | 2 + docs/loading.md | 36 ++++ docs/quickstart.md | 2 +- src/adapters/composition.py | 35 ++++ src/adapters/hub_mixin.py | 95 ++++++++- src/adapters/model_mixin.py | 278 ++++++++++++++++++++++++++- src/adapters/utils.py | 11 +- tests/methods/test_adapter_common.py | 43 +++++ tests/test_clip.py | 3 + 9 files changed, 497 insertions(+), 8 deletions(-) diff --git a/docs/adapter_composition.md b/docs/adapter_composition.md index 0e35f7b21b..b55dccef1c 100644 --- a/docs/adapter_composition.md +++ b/docs/adapter_composition.md @@ -125,6 +125,8 @@ model.active_adapters = ac.Fuse("d", "e", "f") To learn how training an _AdapterFusion_ layer works, check out [this Colab notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/03_Adapter_Fusion.ipynb) from the `adapters` repo. +To save and upload the full composition setup with adapters and fusion layer in one line of code, check out the docs on [saving and loading adapter compositions](loading.md#saving-and-loading-adapter-compositions). + ### Retrieving AdapterFusion attentions Finally, it is possible to retrieve the attention scores computed by each fusion layer in a forward pass of the model. diff --git a/docs/loading.md b/docs/loading.md index 8af81820d9..a1a37ed6d8 100644 --- a/docs/loading.md +++ b/docs/loading.md @@ -94,3 +94,39 @@ We will go through the different arguments and their meaning one by one: To load the adapter using a custom name, we can use the `load_as` parameter. - Finally, `set_active` will directly activate the loaded adapter for usage in each model forward pass. Otherwise, you have to manually activate the adapter via `set_active_adapters()`. + +## Saving and loading adapter compositions + +In addition to saving and loading individual adapters, you can also save, load and share entire [compositions of adapters](adapter_composition.md) with a single line of code. +_Adapters_ provides three methods for this purpose that work very similar to those for single adapters: + +- [`save_adapter_setup()`](adapters.ModelWithHeadsAdaptersMixin.save_adapter_setup) to save an adapter composition along with prediction heads to the local file system. +- [`load_adapter_setup()`](adapters.ModelWithHeadsAdaptersMixin.load_adapter_setup) to load a saved adapter composition from the local file system or the Model Hub. +- [`push_adapter_setup_to_hub()`](adapters.hub_mixin.PushAdapterToHubMixin.push_adapter_setup_to_hub) to upload an adapter setup along with prediction heads to the Model Hub. See our [Hugging Face Model Hub guide](huggingface_hub.md) for more. + +As an example, this is how you would save and load an AdapterFusion setup of three adapters with a prediction head: + +```python +# Create an AdapterFusion +model = AutoAdapterModel.from_pretrained("bert-base-uncased") +model.load_adapter("sentiment/sst-2@ukp", config=SeqBnConfig(), with_head=False) +model.load_adapter("nli/multinli@ukp", config=SeqBnConfig(), with_head=False) +model.load_adapter("sts/qqp@ukp", config=SeqBnConfig(), with_head=False) +model.add_adapter_fusion(["sst-2", "mnli", "qqp"]) +model.add_classification_head("clf_head") +adapter_setup = Fuse("sst-2", "mnli", "qqp") +head_setup = "clf_head" +model.set_active_adapters(adapter_setup) +model.active_head = head_setup + +# Train AdapterFusion ... + +# Save +model.save_adapter_setup("checkpoint", adapter_setup, head_setup=head_setup) + +# Push to Hub +model.push_adapter_setup_to_hub("/fusion_setup", adapter_setup, head_setup=head_setup) + +# Re-load +# model.load_adapter_setup("checkpoint", set_active=True) +``` diff --git a/docs/quickstart.md b/docs/quickstart.md index 9cefe33cc1..6e8b7fd49f 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -105,7 +105,7 @@ model = AutoAdapterModel.from_pretrained(example_path) model.load_adapter(example_path) ``` -Similar to how the weights of the full model are saved, the `save_adapter()` will create a file for saving the adapter weights and a file for saving the adapter configuration in the specified directory. +Similar to how the weights of the full model are saved, [`save_adapter()`](adapters.ModelWithHeadsAdaptersMixin.save_adapter) will create a file for saving the adapter weights and a file for saving the adapter configuration in the specified directory. Finally, if we have finished working with adapters, we can restore the base Transformer to its original form by deactivating and deleting the adapter: diff --git a/src/adapters/composition.py b/src/adapters/composition.py index 48a6bc8acf..a44b9c5aac 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -1,4 +1,5 @@ import itertools +import sys import warnings from collections.abc import Sequence from typing import List, Optional, Set, Tuple, Union @@ -45,6 +46,31 @@ def parallel_channels(self): def flatten(self) -> Set[str]: return set(itertools.chain(*[[b] if isinstance(b, str) else b.flatten() for b in self.children])) + def _get_save_kwargs(self): + return None + + def to_dict(self): + save_dict = { + "type": self.__class__.__name__, + "children": [ + c.to_dict() if isinstance(c, AdapterCompositionBlock) else {"type": "single", "children": [c]} + for c in self.children + ], + } + if kwargs := self._get_save_kwargs(): + save_dict["kwargs"] = kwargs + return save_dict + + @classmethod + def from_dict(cls, data): + children = [] + for child in data["children"]: + if child["type"] == "single": + children.append(child["children"][0]) + else: + children.append(cls.from_dict(child)) + return getattr(sys.modules[__name__], data["type"])(*children, **data.get("kwargs", {})) + class Parallel(AdapterCompositionBlock): def __init__(self, *parallel_adapters: List[str]): @@ -80,12 +106,18 @@ def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], s super().__init__(*split_adapters) self.splits = splits if isinstance(splits, list) else [splits] * len(split_adapters) + def _get_save_kwargs(self): + return {"splits": self.splits} + class BatchSplit(AdapterCompositionBlock): def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], batch_sizes: Union[List[int], int]): super().__init__(*split_adapters) self.batch_sizes = batch_sizes if isinstance(batch_sizes, list) else [batch_sizes] * len(split_adapters) + def _get_save_kwargs(self): + return {"batch_sizes": self.batch_sizes} + class Average(AdapterCompositionBlock): def __init__( @@ -105,6 +137,9 @@ def __init__( else: self.weights = [1 / len(average_adapters)] * len(average_adapters) + def _get_save_kwargs(self): + return {"weights": self.weights} + # Mapping each composition block type to the allowed nested types ALLOWED_NESTINGS = { diff --git a/src/adapters/hub_mixin.py b/src/adapters/hub_mixin.py index c23c92eb7e..61942426d9 100644 --- a/src/adapters/hub_mixin.py +++ b/src/adapters/hub_mixin.py @@ -4,6 +4,8 @@ from transformers.utils.generic import working_or_temp_dir +from .composition import AdapterCompositionBlock + logger = logging.getLogger(__name__) @@ -35,7 +37,7 @@ from adapters import AutoAdapterModel model = AutoAdapterModel.from_pretrained("{model_name}") -adapter_name = model.load_adapter("{adapter_repo_name}", set_active=True) +adapter_name = model.{load_fn}("{adapter_repo_name}", set_active=True) ``` ## Architecture & Training @@ -66,6 +68,7 @@ def _save_adapter_card( language: Optional[str] = None, license: Optional[str] = None, metrics: Optional[List[str]] = None, + load_fn: str = "load_adapter", **kwargs, ): # Key remains "adapter-transformers", see: https://github.com/huggingface/huggingface.js/pull/459 @@ -103,6 +106,7 @@ def _save_adapter_card( model_name=self.model_name, dataset_name=dataset_name, head_info=head_info, + load_fn=load_fn, adapter_repo_name=adapter_repo_name, architecture_training=kwargs.pop("architecture_training", DEFAULT_TEXT), results=kwargs.pop("results", DEFAULT_TEXT), @@ -133,8 +137,6 @@ def push_adapter_to_hub( Args: repo_id (str): The name of the repository on the model hub to upload to. adapter_name (str): The name of the adapter to be uploaded. - organization (str, optional): Organization in which to push the adapter - (you must be a member of this organization). Defaults to None. datasets_tag (str, optional): Dataset identifier from https://huggingface.co/datasets. Defaults to None. local_path (str, optional): Local path used as clone directory of the adapter repository. @@ -156,6 +158,8 @@ def push_adapter_to_hub( Branch to push the uploaded files to. commit_description (`str`, *optional*): The description of the commit that will be created + adapter_card_kwargs (Optional[dict], optional): Additional arguments to pass to the adapter card text generation. + Currently includes: tags, language, license, metrics, architecture_training, results, citation. Returns: str: The url of the adapter repository on the model hub. @@ -190,3 +194,88 @@ def push_adapter_to_hub( revision=revision, commit_description=commit_description, ) + + def push_adapter_setup_to_hub( + self, + repo_id: str, + adapter_setup: Union[str, list, AdapterCompositionBlock], + head_setup: Optional[Union[bool, str, list, AdapterCompositionBlock]] = None, + datasets_tag: Optional[str] = None, + local_path: Optional[str] = None, + commit_message: Optional[str] = None, + private: Optional[bool] = None, + token: Optional[Union[bool, str]] = None, + overwrite_adapter_card: bool = False, + create_pr: bool = False, + revision: str = None, + commit_description: str = None, + adapter_card_kwargs: Optional[dict] = None, + ): + """Upload an adapter setup to HuggingFace's Model Hub. + + Args: + repo_id (str): The name of the repository on the model hub to upload to. + adapter_setup (Union[str, list, AdapterCompositionBlock]): The adapter setup to be uploaded. Usually an adapter composition block. + head_setup (Optional[Union[bool, str, list, AdapterCompositionBlock]], optional): The head setup to be uploaded. + datasets_tag (str, optional): Dataset identifier from https://huggingface.co/datasets. Defaults to + None. + local_path (str, optional): Local path used as clone directory of the adapter repository. + If not specified, will create a temporary directory. Defaults to None. + commit_message (:obj:`str`, `optional`): + Message to commit while pushing. Will default to :obj:`"add config"`, :obj:`"add tokenizer"` or + :obj:`"add model"` depending on the type of the class. + private (:obj:`bool`, `optional`): + Whether or not the repository created should be private (requires a paying subscription). + token (`bool` or `str`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url` + is not specified. + overwrite_adapter_card (bool, optional): Overwrite an existing adapter card with a newly generated one. + If set to `False`, will only generate an adapter card, if none exists. Defaults to False. + create_pr (bool, optional): + Whether or not to create a PR with the uploaded files or directly commit. + revision (`str`, *optional*): + Branch to push the uploaded files to. + commit_description (`str`, *optional*): + The description of the commit that will be created + adapter_card_kwargs (Optional[dict], optional): Additional arguments to pass to the adapter card text generation. + Currently includes: tags, language, license, metrics, architecture_training, results, citation. + + Returns: + str: The url of the adapter repository on the model hub. + """ + use_temp_dir = not os.path.isdir(local_path) if local_path else True + + # Create repo or get retrieve an existing repo + repo_id = self._create_repo(repo_id, private=private, token=token) + + # Commit and push + logger.info('Pushing adapter setup "%s" to model hub at %s ...', adapter_setup, repo_id) + with working_or_temp_dir(working_dir=local_path, use_temp_dir=use_temp_dir) as work_dir: + files_timestamps = self._get_files_timestamps(work_dir) + # Save adapter and optionally create model card + if head_setup is not None: + save_kwargs = {"head_setup": head_setup} + else: + save_kwargs = {} + self.save_adapter_setup(work_dir, adapter_setup, **save_kwargs) + if overwrite_adapter_card or not os.path.exists(os.path.join(work_dir, "README.md")): + adapter_card_kwargs = adapter_card_kwargs or {} + self._save_adapter_card( + work_dir, + str(adapter_setup), + repo_id, + datasets_tag=datasets_tag, + load_fn="load_adapter_setup", + **adapter_card_kwargs, + ) + return self._upload_modified_files( + work_dir, + repo_id, + files_timestamps, + commit_message=commit_message, + token=token, + create_pr=create_pr, + revision=revision, + commit_description=commit_description, + ) diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 659a6cfcff..3154af5ac8 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1,4 +1,5 @@ import inspect +import json import logging import os from abc import ABC, abstractmethod @@ -15,6 +16,7 @@ from transformers.modeling_outputs import ModelOutput from transformers.utils import is_accelerate_available +from . import __version__ from .composition import AdapterCompositionBlock, Fuse, Stack, parse_composition from .configuration import ADAPTER_CONFIG_MAP, AdapterConfig, AdapterFusionConfig, BnConfig from .context import AdapterSetup, ForwardContext @@ -27,7 +29,15 @@ from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool from .methods.prompt_tuning import PromptTuningLayer from .methods.reft import init_reft -from .utils import EMBEDDING_FILE, TOKENIZER_PATH, get_adapter_config_hash, inherit_doc, patch_forward +from .utils import ( + EMBEDDING_FILE, + SETUP_CONFIG_NAME, + TOKENIZER_PATH, + get_adapter_config_hash, + inherit_doc, + patch_forward, + resolve_adapter_path, +) from .wrappers.configuration import SUBMODEL_NAMES, init_adapters_config @@ -802,7 +812,7 @@ def load_adapter( adapter_name_or_path (str): can be either: - the identifier of a pre-trained task adapter to be loaded from Adapter Hub - - a path to a directory containing adapter weights saved using `model.saved_adapter()` + - a path to a directory containing adapter weights saved using `model.save_adapter()` - a URL pointing to a zip folder containing a saved adapter module config (dict or str, optional): Deprecated. version (str, optional): The version of the adapter to be loaded. @@ -881,6 +891,161 @@ def load_adapter_fusion( ) return load_name + def _save_adapter_setup_config( + self, + save_directory: str, + adapter_setup: AdapterCompositionBlock, + head_setup: Optional[Union[bool, str, list, AdapterCompositionBlock]] = None, + ): + setup_config = { + "adapter_setup": adapter_setup.to_dict(), + "head_setup": head_setup.to_dict() if isinstance(head_setup, AdapterCompositionBlock) else head_setup, + "version": "adapters." + __version__, + } + with open(join(save_directory, SETUP_CONFIG_NAME), "w") as f: + json.dump(setup_config, f, indent=2) + + def _load_adapter_setup_config( + self, load_directory: str + ) -> Tuple[AdapterCompositionBlock, Optional[AdapterCompositionBlock]]: + with open(join(load_directory, SETUP_CONFIG_NAME), "r") as f: + setup_config = json.load(f) + adapter_setup = AdapterCompositionBlock.from_dict(setup_config["adapter_setup"]) + head_setup = setup_config["head_setup"] + if isinstance(head_setup, dict): + head_setup = AdapterCompositionBlock.from_dict(head_setup) + return adapter_setup, head_setup + + def _save_adapter_setup_weights( + self, + save_directory: str, + adapter_setup: AdapterCompositionBlock, + meta_dict: dict = None, + custom_weights_loaders: Optional[List[WeightsLoader]] = None, + use_safetensors: bool = False, + ): + # Save single adapters + for adapter_name in adapter_setup.flatten(): + save_path = join(save_directory, adapter_name) + self.save_adapter(save_path, adapter_name, meta_dict=meta_dict, use_safetensors=use_safetensors) + # Save adapter fusions + fusions = [] + if isinstance(adapter_setup, Fuse): + fusions.append(adapter_setup) + for child_setup in adapter_setup.children: + if isinstance(child_setup, Fuse): + fusions.append(child_setup) + for fusion in fusions: + save_path = join(save_directory, fusion.name) + self.save_adapter_fusion(save_path, fusion, meta_dict=meta_dict, use_safetensors=use_safetensors) + # Save additional custom weights + if custom_weights_loaders: + for weights_loader in custom_weights_loaders: + weights_loader.save(save_directory, adapter_name) + + def _load_adapter_setup_weights( + self, + load_directory: str, + adapter_setup: AdapterCompositionBlock, + custom_weights_loaders: Optional[List[WeightsLoader]] = None, + set_active: bool = False, + use_safetensors: bool = False, + ): + # Load single adapters + for adapter_name in adapter_setup.flatten(): + save_path = join(load_directory, adapter_name) + self.load_adapter(save_path, use_safetensors=use_safetensors) + # Load adapter fusions + fusions = [] + if isinstance(adapter_setup, Fuse): + fusions.append(adapter_setup) + for child_setup in adapter_setup.children: + if isinstance(child_setup, Fuse): + fusions.append(child_setup) + for fusion in fusions: + save_path = join(load_directory, fusion.name) + self.load_adapter_fusion(save_path, use_safetensors=use_safetensors) + # Load additional custom weights + if custom_weights_loaders: + for weights_loader in custom_weights_loaders: + weights_loader.load(load_directory) + + if set_active: + self.set_active_adapters(adapter_setup) + + def save_adapter_setup( + self, + save_directory: str, + adapter_setup: Union[str, list, AdapterCompositionBlock], + meta_dict: dict = None, + custom_weights_loaders: Optional[List[WeightsLoader]] = None, + use_safetensors: bool = False, + ): + """Saves an adapter setup to a directory so that it can be shared or reloaded using `load_adapter_setup()`. + + Args: + save_directory (str): Path to a directory where the adapter setup should be saved. + adapter_setup (Union[str, list, AdapterCompositionBlock]): The adapter setup to be saved. Usually an adapter composition block. + use_safetensors (bool, optional): If True, weights are saved via `safetensors`. Otherwise, the regular torch save method is used. + """ + os.makedirs(save_directory, exist_ok=True) + adapter_setup = parse_composition(adapter_setup, model_type=self.config.model_type) + + self._save_adapter_setup_config(save_directory, adapter_setup) + self._save_adapter_setup_weights( + save_directory, + adapter_setup, + meta_dict=meta_dict, + custom_weights_loaders=custom_weights_loaders, + use_safetensors=use_safetensors, + ) + + def load_adapter_setup( + self, + adapter_setup_name_or_path: str, + version: str = None, + custom_weights_loaders: Optional[List[WeightsLoader]] = None, + set_active: bool = False, + use_safetensors: bool = False, + **kwargs, + ) -> Tuple[AdapterCompositionBlock, Any]: + """Loads an adapter setup from the local file system or a remote location. + + Args: + adapter_setup_name_or_path (str): can be either: + + - the identifier of a repository on the HuggingFace Model Hub. + - a path to a directory containing adapter weights saved using `model.save_adapter_setup()` + - a URL pointing to a zip folder containing a saved adapter module + version (str, optional): The version of the adapter to be loaded. + set_active (bool, optional): + Set the loaded adapter setup to be the active one. By default (False), the adapter setup is loaded but not + activated. + use_safetensors (bool, optional): If True, weights are loaded via `safetensors` if safetensors checkpoint is available. Otherwise, the regular torch save method is used. + + Returns: + Tuple[AdapterCompositionBlock, Any]: The loaded adapter setup and the head setup if available. + """ + resolved_folder = resolve_adapter_path( + adapter_setup_name_or_path, + version=version, + do_exists_check=False, + **kwargs, + ) + adapter_setup, head_setup = self._load_adapter_setup_config(resolved_folder) + self._load_adapter_setup_weights( + resolved_folder, + adapter_setup, + custom_weights_loaders=custom_weights_loaders, + set_active=set_active, + use_safetensors=use_safetensors, + ) + + if head_setup: + logger.warning("Loaded adapter setup contains a head setup that is not supported by the current model.") + + return adapter_setup, head_setup + def save_all_adapters( self, save_directory: str, @@ -1857,6 +2022,115 @@ def load_adapter_fusion( **kwargs, ) + def save_adapter_setup( + self, + save_directory: str, + adapter_setup: Union[str, list, AdapterCompositionBlock], + head_setup: Optional[Union[bool, str, list, AdapterCompositionBlock]] = None, + meta_dict: dict = None, + custom_weights_loaders: Optional[List[WeightsLoader]] = None, + use_safetensors: bool = False, + ): + """Saves an adapter setup to a directory so that it can be shared or reloaded using `load_adapter_setup()`. + + Args: + save_directory (str): Path to a directory where the adapter setup should be saved. + adapter_setup (Union[str, list, AdapterCompositionBlock]): The adapter setup to be saved. Usually an adapter composition block. + head_setup (Optional[Union[bool, str, list, AdapterCompositionBlock]], optional): The head setup to be saved. Can be either: + + - True: save the default head for models without flex heads. + - str: save a single head with the given name. + - list: save a list of heads. + - AdapterCompositionBlock: save a custom head setup. + - None (default): do not save any heads. + use_safetensors (bool, optional): If True, weights are saved via `safetensors`. Otherwise, the regular torch save method is used. + """ + os.makedirs(save_directory, exist_ok=True) + adapter_setup = parse_composition(adapter_setup, model_type=self.config.model_type) + + self._save_adapter_setup_config(save_directory, adapter_setup, head_setup) + self._save_adapter_setup_weights( + save_directory, + adapter_setup, + meta_dict=meta_dict, + custom_weights_loaders=custom_weights_loaders, + use_safetensors=use_safetensors, + ) + + if head_setup is True: + self.save_head(save_directory, use_safetensors=use_safetensors) + elif head_setup: + heads_to_save = [] + if isinstance(head_setup, AdapterCompositionBlock): + heads_to_save = head_setup.flatten() + elif isinstance(head_setup, list): + heads_to_save = head_setup + elif isinstance(head_setup, str): + heads_to_save = [head_setup] + for head_name in heads_to_save: + save_path = join(save_directory, head_name) + self.save_head(save_path, head_name, use_safetensors=use_safetensors) + + def load_adapter_setup( + self, + adapter_setup_name_or_path: str, + version: str = None, + custom_weights_loaders: Optional[List[WeightsLoader]] = None, + set_active: bool = False, + use_safetensors: bool = False, + **kwargs, + ) -> str: + """Loads an adapter setup from the local file system or a remote location. + + Args: + adapter_setup_name_or_path (str): can be either: + + - the identifier of a repository on the HuggingFace Model Hub. + - a path to a directory containing adapter weights saved using `model.save_adapter_setup()` + - a URL pointing to a zip folder containing a saved adapter module + version (str, optional): The version of the adapter to be loaded. + set_active (bool, optional): + Set the loaded adapter setup to be the active one. By default (False), the adapter setup is loaded but not + activated. + use_safetensors (bool, optional): If True, weights are loaded via `safetensors` if safetensors checkpoint is available. Otherwise, the regular torch save method is used. + + Returns: + Tuple[AdapterCompositionBlock, Any]: The loaded adapter setup and the head setup if available. + """ + resolved_folder = resolve_adapter_path( + adapter_setup_name_or_path, + version=version, + do_exists_check=False, + **kwargs, + ) + adapter_setup, head_setup = self._load_adapter_setup_config(resolved_folder) + self._load_adapter_setup_weights( + resolved_folder, + adapter_setup, + custom_weights_loaders=custom_weights_loaders, + set_active=set_active, + use_safetensors=use_safetensors, + ) + + if head_setup is True: + self.load_head(resolved_folder, use_safetensors=use_safetensors) + elif head_setup: + heads_to_load = [] + if isinstance(head_setup, AdapterCompositionBlock): + heads_to_load = head_setup.flatten() + elif isinstance(head_setup, list): + heads_to_load = head_setup + elif isinstance(head_setup, str): + heads_to_load = [head_setup] + for head_name in heads_to_load: + save_path = join(resolved_folder, head_name) + self.load_head(save_path, head_name, use_safetensors=use_safetensors) + + if set_active: + self.active_head = head_setup + + return adapter_setup, head_setup + def save_all_heads(self, save_directory: str, use_safetensors: bool = False): """Saves all prediction heads of this model to subfolders of the given location. diff --git a/src/adapters/utils.py b/src/adapters/utils.py index 1103d9fffb..7c0540850a 100644 --- a/src/adapters/utils.py +++ b/src/adapters/utils.py @@ -53,6 +53,7 @@ SAFE_ADAPTERFUSION_WEIGHTS_NAME = "model_adapter_fusion.safetensors" EMBEDDING_FILE = "embedding.pt" TOKENIZER_PATH = "tokenizer" +SETUP_CONFIG_NAME = "adapter_setup.json" ADAPTER_HUB_URL = "https://raw.githubusercontent.com/Adapter-Hub/Hub/master/dist/v2/" ADAPTER_HUB_INDEX_FILE = ADAPTER_HUB_URL + "index/{}.json" @@ -671,6 +672,7 @@ def resolve_adapter_path( model_name: str = None, adapter_config: Union[dict, str] = None, version: str = None, + do_exists_check: bool = True, **kwargs, ) -> str: """ @@ -701,8 +703,13 @@ def resolve_adapter_path( # path to a local folder saved using save() elif isdir(adapter_name_or_path): if ( - isfile(join(adapter_name_or_path, WEIGHTS_NAME)) or isfile(join(adapter_name_or_path, SAFE_WEIGHTS_NAME)) - ) and isfile(join(adapter_name_or_path, CONFIG_NAME)): + not do_exists_check + or ( + isfile(join(adapter_name_or_path, WEIGHTS_NAME)) + or isfile(join(adapter_name_or_path, SAFE_WEIGHTS_NAME)) + ) + and isfile(join(adapter_name_or_path, CONFIG_NAME)) + ): return adapter_name_or_path else: raise EnvironmentError( diff --git a/tests/methods/test_adapter_common.py b/tests/methods/test_adapter_common.py index 1ea6cd6f37..717d3af98e 100644 --- a/tests/methods/test_adapter_common.py +++ b/tests/methods/test_adapter_common.py @@ -1,4 +1,5 @@ import copy +import os import tempfile import torch @@ -17,8 +18,10 @@ MAMConfig, SeqBnConfig, SeqBnInvConfig, + Stack, ) from adapters.heads.language_modeling import CausalLMHead +from adapters.utils import SETUP_CONFIG_NAME from transformers import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, CLIPConfig from transformers.testing_utils import require_torch, torch_device @@ -475,3 +478,43 @@ def test_batch_split_training(self): base_with_change |= not torch.equal(v1, v2) self.assertTrue(adapters_with_change) self.assertFalse(base_with_change) + + def test_load_adapter_setup(self): + if self.config_class not in ADAPTER_MODEL_MAPPING: + self.skipTest("Does not support flex heads.") + model1, model2 = create_twin_models(self.model_class, self.config) + + # Create a complex setup + model1.add_adapter("a", config=SeqBnConfig()) + model1.add_adapter("b", config=SeqBnConfig()) + model1.add_adapter("c", config=SeqBnConfig()) + model1.add_adapter_fusion(["a", "b"]) + self.add_head(model1, "head_a") + self.add_head(model1, "head_b") + adapter_setup = Stack(Fuse("a", "b"), "c") + head_setup = BatchSplit("head_a", "head_b", batch_sizes=[2, 1]) + model1.set_active_adapters(adapter_setup) + model1.active_head = head_setup + + with tempfile.TemporaryDirectory() as temp_dir: + model1.save_adapter_setup(temp_dir, adapter_setup, head_setup=head_setup) + + self.assertTrue(os.path.exists(os.path.join(temp_dir, SETUP_CONFIG_NAME))) + + # also tests that set_active works + model2.load_adapter_setup(temp_dir, set_active=True) + + # check if adapter was correctly loaded + for name in ["a", "b", "c"]: + self.assertTrue(name in model2.adapters_config) + self.assertEqual(adapter_setup, model2.active_adapters) + + # check equal output + input_data = self.get_input_samples(config=model1.config) + model1.to(torch_device) + model2.to(torch_device) + output1 = model1(**input_data) + output2 = model2(**input_data) + self.assertEqual(len(output1), len(output2)) + self.assertTrue(torch.allclose(output1[0][0], output2[0][0], atol=1e-4)) + self.assertTrue(torch.allclose(output1[1][0], output2[1][0], atol=1e-4)) diff --git a/tests/test_clip.py b/tests/test_clip.py index 30be353f74..ead9c7d561 100644 --- a/tests/test_clip.py +++ b/tests/test_clip.py @@ -225,3 +225,6 @@ class CLIPAdapterTest( def test_adapter_fusion_save_with_head(self): # This test is not applicable to CLIP self.skipTest("Not applicable to CLIP.") + + def test_load_adapter_setup(self): + self.skipTest("Not applicable to CLIP.") From 303c34bdd91f37e656fca5f60624d86b991def3c Mon Sep 17 00:00:00 2001 From: calpt Date: Wed, 8 Jan 2025 18:29:24 +0100 Subject: [PATCH 5/5] Custom names for AdapterFusion layers (#774) Resolves #316. This PR implements the option to specify a custom name for an added AdapterFusion layer. The name can be specified when adding a fusion layer like this: ```python model.add_adapter_fusion(["adapter1", "adapter2"], name="custom_name_fusion") ``` Afterwards, to address the custom-name fusion, specify the name in the `Fuse` block. E.g. for activation: ```python model.set_active_adapters(Fuse("adapter1", "adapter2", name="custom_name_fusion")) ``` Some fusion-specific methods can either take the named `Fuse` block or directly the fusion name: ```python # saving model.save_adapter_fusion("./checkpoint_dir", Fuse("adapter1", "adapter2", name="custom_name_fusion")) # or: # model.save_adapter_fusion("./checkpoint_dir", "custom_name_fusion") # deleting model.delete_adapter_fusion(Fuse("adapter1", "adapter2", name="custom_name_fusion")) # or: # model.delete_adapter_fusion("custom_name_fusion") ``` --------- Co-authored-by: Timo Imhof --- src/adapters/composition.py | 8 +- .../configuration/model_adapters_config.py | 30 +++++-- src/adapters/loading.py | 12 ++- src/adapters/methods/bottleneck.py | 8 +- src/adapters/model_mixin.py | 27 +++--- tests/test_adapter_fusion_common.py | 83 +++++++++++++++++++ 6 files changed, 140 insertions(+), 28 deletions(-) diff --git a/src/adapters/composition.py b/src/adapters/composition.py index a44b9c5aac..6c17fb8ebd 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -92,13 +92,17 @@ def __init__(self, *stack_layers: List[Union[AdapterCompositionBlock, str]]): class Fuse(AdapterCompositionBlock): - def __init__(self, *fuse_stacks: List[Union[AdapterCompositionBlock, str]]): + def __init__(self, *fuse_stacks: List[Union[AdapterCompositionBlock, str]], name: Optional[str] = None): super().__init__(*fuse_stacks) + self._name = name # TODO-V2 pull this up to all block classes? @property def name(self): - return ",".join([c if isinstance(c, str) else c.last() for c in self.children]) + if self._name: + return self._name + else: + return ",".join([c if isinstance(c, str) else c.last() for c in self.children]) class Split(AdapterCompositionBlock): diff --git a/src/adapters/configuration/model_adapters_config.py b/src/adapters/configuration/model_adapters_config.py index 3ae7dcf56c..f742028b67 100644 --- a/src/adapters/configuration/model_adapters_config.py +++ b/src/adapters/configuration/model_adapters_config.py @@ -1,7 +1,7 @@ import copy import logging from collections.abc import Collection, Mapping -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union from .. import __version__ from ..composition import AdapterCompositionBlock @@ -27,6 +27,7 @@ def __init__(self, **kwargs): self.fusions: Mapping[str, str] = kwargs.pop("fusions", {}) self.fusion_config_map = kwargs.pop("fusion_config_map", {}) + self.fusion_name_map = kwargs.pop("fusion_name_map", {}) # TODO-V2 Save this with config? self.active_setup: Optional[AdapterCompositionBlock] = None @@ -131,7 +132,7 @@ def add(self, adapter_name: str, config: Optional[Union[str, dict]] = None): self.adapters[adapter_name] = config_name logger.info(f"Adding adapter '{adapter_name}'.") - def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]: + def get_fusion(self, fusion_name: Union[str, List[str]]) -> Tuple[Optional[dict], Optional[list]]: """ Gets the config dictionary for a given AdapterFusion. @@ -140,6 +141,7 @@ def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]: Returns: Optional[dict]: The AdapterFusion configuration. + Optional[list]: The names of the adapters to fuse. """ if isinstance(fusion_name, list): fusion_name = ",".join(fusion_name) @@ -149,20 +151,31 @@ def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]: config = self.fusion_config_map.get(config_name, None) else: config = ADAPTERFUSION_CONFIG_MAP.get(config_name, None) + + if fusion_name in self.fusion_name_map: + adapter_names = self.fusion_name_map[fusion_name] + else: + adapter_names = fusion_name.split(",") + + return config, adapter_names else: - config = None - return config + return None, None - def add_fusion(self, fusion_name: Union[str, List[str]], config: Optional[Union[str, dict]] = None): + def add_fusion( + self, adapter_names: List[str], config: Optional[Union[str, dict]] = None, fusion_name: Optional[str] = None + ): """ Adds a new AdapterFusion. Args: - fusion_name (Union[str, List[str]]): The name of the AdapterFusion or the adapters to fuse. + adapter_names (List[str]): The names of the adapters to fuse. config (Optional[Union[str, dict]], optional): AdapterFusion config. Defaults to None. + fusion_name (Optional[str], optional): The name of the AdapterFusion. If not specified, will default to comma-separated adapter names. """ - if isinstance(fusion_name, list): - fusion_name = ",".join(fusion_name) + if fusion_name is None: + fusion_name = ",".join(adapter_names) + else: + self.fusion_name_map[fusion_name] = adapter_names if fusion_name in self.fusions: raise ValueError(f"An AdapterFusion with the name '{fusion_name}' has already been added.") if config is None: @@ -218,6 +231,7 @@ def to_dict(self): output_dict["fusion_config_map"][k] = v.to_dict() else: output_dict["fusion_config_map"][k] = copy.deepcopy(v) + output_dict["fusion_name_map"] = copy.deepcopy(self.fusion_name_map) return output_dict def __eq__(self, other): diff --git a/src/adapters/loading.py b/src/adapters/loading.py index 69747e04cb..55ba1db45b 100644 --- a/src/adapters/loading.py +++ b/src/adapters/loading.py @@ -639,7 +639,7 @@ def save_to_state_dict(self, name: str): if name not in self.model.adapters_config.fusions: raise ValueError(f"No AdapterFusion with name '{name}' available.") - adapter_fusion_config = self.model.adapters_config.get_fusion(name) + adapter_fusion_config, _ = self.model.adapters_config.get_fusion(name) config_dict = build_full_config( adapter_fusion_config, @@ -676,13 +676,14 @@ def save(self, save_directory: str, name: str, meta_dict=None): else: assert isdir(save_directory), "Saving path should be a directory where the head can be saved." - adapter_fusion_config = self.model.adapters_config.get_fusion(name) + adapter_fusion_config, adapter_names = self.model.adapters_config.get_fusion(name) # Save the adapter fusion configuration config_dict = build_full_config( adapter_fusion_config, self.model.config, name=name, + adapter_names=adapter_names, model_name=self.model.model_name, model_class=self.model.__class__.__name__, ) @@ -746,9 +747,14 @@ def load(self, save_directory, load_as=None, loading_info=None, **kwargs): config = self.weights_helper.load_weights_config(save_directory) adapter_fusion_name = load_as or config["name"] + adapter_names = config.get("adapter_names", adapter_fusion_name) if adapter_fusion_name not in self.model.adapters_config.fusions: self.model.add_adapter_fusion( - adapter_fusion_name, config["config"], overwrite_ok=True, set_active=kwargs.pop("set_active", True) + adapter_names, + config["config"], + name=adapter_fusion_name, + overwrite_ok=True, + set_active=kwargs.pop("set_active", True), ) else: logger.warning("Overwriting existing adapter fusion module '{}'".format(adapter_fusion_name)) diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py index ff12a91cd7..889941d2b9 100644 --- a/src/adapters/methods/bottleneck.py +++ b/src/adapters/methods/bottleneck.py @@ -96,9 +96,9 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: def add_fusion_layer(self, adapter_names: Union[List, str]): """See BertModel.add_fusion_layer""" - adapter_names = adapter_names if isinstance(adapter_names, list) else adapter_names.split(",") + fusion_name = ",".join(adapter_names) if isinstance(adapter_names, list) else adapter_names + fusion_config, adapter_names = self.adapters_config.get_fusion(fusion_name) if self.adapters_config.common_config_value(adapter_names, self.location_key): - fusion_config = self.adapters_config.get_fusion(adapter_names) dropout_prob = fusion_config.dropout_prob or getattr(self.model_config, "attention_probs_dropout_prob", 0) fusion = BertFusion( fusion_config, @@ -106,7 +106,7 @@ def add_fusion_layer(self, adapter_names: Union[List, str]): dropout_prob, ) fusion.train(self.training) # make sure training mode is consistent - self.adapter_fusion_layer[",".join(adapter_names)] = fusion + self.adapter_fusion_layer[fusion_name] = fusion def delete_fusion_layer(self, adapter_names: Union[List, str]): adapter_names = adapter_names if isinstance(adapter_names, str) else ",".join(adapter_names) @@ -223,7 +223,7 @@ def compose_fuse(self, adapter_setup: Fuse, state: BottleneckState, lvl: int = 0 context = ForwardContext.get_context() # config of _last_ fused adapter is significant - fusion_config = self.adapters_config.get_fusion(adapter_setup.name) + fusion_config, _ = self.adapters_config.get_fusion(adapter_setup.name) last = adapter_setup.last() last_adapter = self.adapters[last] hidden_states, query, residual = last_adapter.pre_forward( diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 3154af5ac8..62de6178ac 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -638,6 +638,7 @@ def add_adapter_fusion( self, adapter_names: Union[Fuse, list, str], config=None, + name: str = None, overwrite_ok: bool = False, set_active: bool = False, ): @@ -655,6 +656,8 @@ def add_adapter_fusion( - a string identifying a pre-defined adapter fusion configuration - a dictionary representing the adapter fusion configuration - the path to a file containing the adapter fusion configuration + name (str, optional): + Name of the AdapterFusion layer. If not specified, the name is generated automatically from the fused adapter names. overwrite_ok (bool, optional): Overwrite an AdapterFusion layer with the same name if it exists. By default (False), an exception is thrown. @@ -662,22 +665,24 @@ def add_adapter_fusion( Activate the added AdapterFusion. By default (False), the AdapterFusion is added but not activated. """ if isinstance(adapter_names, Fuse): + if name is None: + name = adapter_names.name adapter_names = adapter_names.children elif isinstance(adapter_names, str): adapter_names = adapter_names.split(",") + if name is None: + name = ",".join(adapter_names) if isinstance(config, dict): config = AdapterFusionConfig.from_dict(config) # ensure config is ok and up-to-date # In case adapter already exists and we allow overwriting, explicitly delete the existing one first - if overwrite_ok and self.adapters_config.get_fusion(adapter_names) is not None: - self.delete_adapter_fusion(adapter_names) - self.adapters_config.add_fusion(adapter_names, config=config) - self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(adapter_names)) - self.apply_to_basemodel_childs(lambda i, child: child.add_fusion_layer(adapter_names)) + if overwrite_ok and self.adapters_config.get_fusion(name)[0] is not None: + self.delete_adapter_fusion(name) + self.adapters_config.add_fusion(adapter_names, config=config, fusion_name=name) + self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(name)) + self.apply_to_basemodel_childs(lambda i, child: child.add_fusion_layer(name)) if set_active: - if not isinstance(adapter_names, list): - adapter_names = adapter_names.split(",") - self.set_active_adapters(Fuse(*adapter_names)) + self.set_active_adapters(Fuse(*adapter_names, name=name)) def delete_adapter(self, adapter_name: str): """ @@ -710,7 +715,7 @@ def delete_adapter_fusion(self, adapter_names: Union[Fuse, list, str]): adapter_names (Union[Fuse, list, str]): AdapterFusion layer to delete. """ if isinstance(adapter_names, Fuse): - adapter_fusion_name = ",".join(adapter_names.children) + adapter_fusion_name = adapter_names.name elif isinstance(adapter_names, list): adapter_fusion_name = ",".join(adapter_names) elif isinstance(adapter_names, str): @@ -776,7 +781,7 @@ def save_adapter_fusion( ValueError: If the given AdapterFusion name is invalid. """ if isinstance(adapter_names, Fuse): - adapter_fusion_name = ",".join(adapter_names.children) + adapter_fusion_name = adapter_names.name elif isinstance(adapter_names, list): adapter_fusion_name = ",".join(adapter_names) elif isinstance(adapter_names, str): @@ -1094,7 +1099,7 @@ def save_all_adapter_fusions( """ os.makedirs(save_directory, exist_ok=True) for name in self.adapters_config.fusions: - adapter_fusion_config = self.adapters_config.get_fusion(name) + adapter_fusion_config, _ = self.adapters_config.get_fusion(name) h = get_adapter_config_hash(adapter_fusion_config) save_path = join(save_directory, name) if meta_dict: diff --git a/tests/test_adapter_fusion_common.py b/tests/test_adapter_fusion_common.py index ccc860f667..695808eb24 100644 --- a/tests/test_adapter_fusion_common.py +++ b/tests/test_adapter_fusion_common.py @@ -214,3 +214,86 @@ def test_output_adapter_fusion_attentions(self): self.assertEqual(len(per_layer_scores), 1) for k, v in per_layer_scores.items(): self.assertEqual(self.default_input_samples_shape[0], v.shape[0], k) + + def test_add_adapter_fusion_custom_name(self): + config_name = "seq_bn" + model = self.get_model() + model.eval() + + name1 = f"{config_name}-1" + name2 = f"{config_name}-2" + model.add_adapter(name1, config=config_name) + model.add_adapter(name2, config=config_name) + + # adapter is correctly added to config + self.assertTrue(name1 in model.adapters_config) + self.assertTrue(name2 in model.adapters_config) + + # add fusion with default name + model.add_adapter_fusion([name1, name2]) + model.to(torch_device) + + # check forward pass + input_data = self.get_input_samples(config=model.config) + model.set_active_adapters(Fuse(name1, name2)) + fusion_default_ref_output = model(**input_data) + + # add fusion with custom name + model.add_adapter_fusion([name1, name2], name="custom_name_fusion") + model.to(torch_device) + + self.assertIn(f"{name1},{name2}", model.adapters_config.fusions) + self.assertIn("custom_name_fusion", model.adapters_config.fusions) + self.assertIn("custom_name_fusion", model.adapters_config.fusion_name_map) + + # check forward pass + model.set_active_adapters(Fuse(name1, name2, name="custom_name_fusion")) + fusion_custom_output = model(**input_data) + model.set_active_adapters(Fuse(name1, name2)) + fusion_default_output = model(**input_data) + model.set_active_adapters(None) + base_output = model(**input_data) + + self.assertFalse(torch.equal(fusion_default_ref_output[0], base_output[0])) + self.assertTrue(torch.equal(fusion_default_ref_output[0], fusion_default_output[0])) + self.assertFalse(torch.equal(fusion_custom_output[0], fusion_default_output[0])) + self.assertFalse(torch.equal(fusion_custom_output[0], base_output[0])) + + # delete only the custom fusion + model.delete_adapter_fusion(Fuse(name1, name2, name="custom_name_fusion")) + # model.delete_adapter_fusion("custom_name_fusion") + + self.assertIn(f"{name1},{name2}", model.adapters_config.fusions) + self.assertNotIn("custom_name_fusion", model.adapters_config.fusions) + + def test_load_adapter_fusion_custom_name(self): + model1 = self.get_model() + model1.eval() + + name1 = "name1" + name2 = "name2" + model1.add_adapter(name1) + model1.add_adapter(name2) + + model2 = copy.deepcopy(model1) + model2.eval() + + model1.add_adapter_fusion([name1, name2], name="custom_name_fusion") + model1.set_active_adapters(Fuse(name1, name2, name="custom_name_fusion")) + + with tempfile.TemporaryDirectory() as temp_dir: + model1.save_adapter_fusion(temp_dir, "custom_name_fusion") + # also tests that set_active works + model2.load_adapter_fusion(temp_dir, set_active=True) + + # check if adapter was correctly loaded + self.assertEqual(model1.adapters_config.fusions.keys(), model2.adapters_config.fusions.keys()) + + # check equal output + in_data = self.get_input_samples(config=model1.config) + model1.to(torch_device) + model2.to(torch_device) + output1 = model1(**in_data) + output2 = model2(**in_data) + self.assertEqual(len(output1), len(output2)) + self.assertTrue(torch.equal(output1[0], output2[0]))