From e67152c05e1b508fec910b8038adc10e5a1af6cc Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 28 Feb 2024 16:21:19 +0530 Subject: [PATCH 01/11] add: apply_dora method for conv2d. --- src/peft/tuners/lora/layer.py | 44 +++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 83617b69a2..95ed1600c2 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -510,6 +510,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig if weight is not None: # the layer is already completely initialized, this is an update self.to(base_layer.weight.device, dtype=weight.dtype) + self.set_adapter(self.active_adapters) def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: @@ -652,8 +653,8 @@ def __init__( super().__init__() LoraLayer.__init__(self, base_layer) - if use_dora: - raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + # if use_dora: + # raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") self._active_adapter = adapter_name self.update_layer( @@ -699,6 +700,13 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig if weight is not None: # the layer is already completely initialized, this is an update self.to(base_layer.weight.device, dtype=weight.dtype) + + if use_dora: + self.dora_init(adapter_name) + self.use_dora[adapter_name] = True + else: + self.use_dora[adapter_name] = False + self.set_adapter(self.active_adapters) def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: @@ -797,6 +805,38 @@ def get_delta_weight(self, adapter) -> torch.Tensor: return output_tensor + def _get_weight_norm(weight, lora_weight, scaling) -> torch.Tensor: + # calculate L2 norm of weight matrix, channel-wise + weight = weight + scaling * lora_weight + weight_norm = weight.norm(p=2, dim=(1, 2, 3)) + return weight_norm + + def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter): + """ + For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer + output. + """ + base_layer = self.get_base_layer() + weight = base_layer.weight + lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1)) + lora_weight = lora_weight.reshape(weight.shape) + magnitude = self.lora_magnitude_vector[active_adapter] + weight_norm = self._get_weight_norm(weight, lora_weight, scaling) + # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353) + # "[...] we suggest treating ||V +∆V ||_c in + # Eq. (5) as a constant, thereby detaching it from the gradient + # graph. 
This means that while ||V + ∆V ||_c dynamically + # reflects the updates of ∆V , it won’t receive any gradient + # during backpropagation" + weight_norm = weight_norm.detach() + weight_norm = weight_norm.reshape((1, base_layer.out_channels, 1, 1)) + mag_norm_scale = magnitude / weight_norm + result_dora = (mag_norm_scale - 1) * ( + F.conv2d(x, weight, None, base_layer.stride, base_layer.padding, base_layer.dilation, base_layer.groups) + ) + mag_norm_scale * lora_B(lora_A(x)) * scaling + + return result_dora + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: if self.disable_adapters: if self.merged: From 7d797a9de058a4cde6f63f3b9415fc3df786cffe Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 7 Mar 2024 16:15:27 +0530 Subject: [PATCH 02/11] implement forward pass. --- src/peft/tuners/lora/layer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 79e3f01045..8b3129c4c0 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -653,9 +653,6 @@ def __init__( super().__init__() LoraLayer.__init__(self, base_layer) - # if use_dora: - # raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") - self._active_adapter = adapter_name self.update_layer( adapter_name, @@ -856,7 +853,12 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: dropout = self.lora_dropout[active_adapter] scaling = self.scaling[active_adapter] x = x.to(lora_A.weight.dtype) - result = result + lora_B(lora_A(dropout(x))) * scaling + + if not self.use_dora[active_adapter]: + result = result + lora_B(lora_A(dropout(x))) * scaling + else: + x = dropout(x) + result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter) result = result.to(torch_result_dtype) return result From 17d5cfe9bfdb0ed96e01909968477aa3288c1733 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 7 Mar 2024 16:21:37 +0530 Subject: [PATCH 03/11] implement merge --- src/peft/tuners/lora/layer.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 8b3129c4c0..f43105c4e2 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -731,7 +731,20 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N # Note that safe_merge will be slower than the normal merge # because of the copy operation. orig_weights = base_layer.weight.data.clone() - orig_weights = orig_weights + self.get_delta_weight(active_adapter) + delta_weight = self.get_delta_weight(active_adapter) + + if not self.use_dora[active_adapter]: + orig_weights = orig_weights + delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(orig_weights, delta_weight, scaling=1).detach() + # We need to cache weight_norm because it has to be based on the original weights. 
We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + orig_weights = dora_factor.view(-1, 1) * (orig_weights + delta_weight) if not torch.isfinite(orig_weights).all(): raise ValueError( @@ -739,7 +752,21 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N ) base_layer.weight.data = orig_weights else: - base_layer.weight.data = base_layer.weight.data + self.get_delta_weight(active_adapter) + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + base_layer.weight.data = base_layer.weight.data + delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(base_layer.weight, delta_weight, scaling=1).detach() + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + new_weight = dora_factor.view(-1, 1) * (base_layer.weight.data + delta_weight) + base_layer.weight.data = new_weight + self.merged_adapters.append(active_adapter) def unmerge(self) -> None: From 2162d1707f67b732056d9e617e05ff7e715a1401 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Mar 2024 17:10:50 +0530 Subject: [PATCH 04/11] add: tests --- tests/test_custom_models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 9db559a3d8..b60edf70ba 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -68,6 +68,8 @@ ("Embedding + transformers Conv1D 3 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb", "conv1d"]}), ("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}), ("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}), + ("Conv2d 1 LoRA with LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}), + ("Conv2d 2 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"], "use_dora": True}), ####### # IA³ # ####### From 537931b4f830380e4ba703c1f9bd852639c0407c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Mar 2024 17:14:35 +0530 Subject: [PATCH 05/11] fix typo --- tests/test_custom_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index b60edf70ba..632a8d8c35 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -68,7 +68,7 @@ ("Embedding + transformers Conv1D 3 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb", "conv1d"]}), ("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}), ("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}), - ("Conv2d 1 LoRA with LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}), + ("Conv2d 1 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}), ("Conv2d 2 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"], "use_dora": True}), ####### # IA³ # From 978e99237a26ce8e10fc494c969da0729d4922b6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 12 Mar 
2024 16:06:28 +0530 Subject: [PATCH 06/11] fix dora conv2d implementation (merging still pending) --- src/peft/tuners/lora/layer.py | 9 +++++---- tests/test_custom_models.py | 3 +++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index f43105c4e2..966bf07408 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -181,7 +181,8 @@ def dora_init(self, adapter_name: str) -> None: scaling = self.scaling[adapter_name] with gather_params_ctx(self.get_base_layer()): weight = self.get_base_layer().weight - lora_weight = lora_B.weight @ lora_A.weight + lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1)) + lora_weight = lora_weight.reshape(weight.shape) weight_norm = self._get_weight_norm(weight, lora_weight, scaling) self.lora_magnitude_vector = nn.ParameterDict() self.lora_magnitude_vector[adapter_name] = nn.Parameter(weight_norm, requires_grad=True) @@ -764,7 +765,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N # different value self._cache_store(f"{active_adapter}-weight_norm", weight_norm) dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm - new_weight = dora_factor.view(-1, 1) * (base_layer.weight.data + delta_weight) + new_weight = dora_factor.view(-1, 1, 1, 1) * (base_layer.weight.data + delta_weight) base_layer.weight.data = new_weight self.merged_adapters.append(active_adapter) @@ -829,10 +830,11 @@ def get_delta_weight(self, adapter) -> torch.Tensor: return output_tensor - def _get_weight_norm(weight, lora_weight, scaling) -> torch.Tensor: + def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor: # calculate L2 norm of weight matrix, channel-wise weight = weight + scaling * lora_weight weight_norm = weight.norm(p=2, dim=(1, 2, 3)) + weight_norm = weight_norm.reshape((1, self.get_base_layer().out_channels, 1, 1)) return weight_norm def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter): @@ -853,7 +855,6 @@ def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter): # reflects the updates of ∆V , it won’t receive any gradient # during backpropagation" weight_norm = weight_norm.detach() - weight_norm = weight_norm.reshape((1, base_layer.out_channels, 1, 1)) mag_norm_scale = magnitude / weight_norm result_dora = (mag_norm_scale - 1) * ( F.conv2d(x, weight, None, base_layer.stride, base_layer.padding, base_layer.dilation, base_layer.groups) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 632a8d8c35..5b947dc42e 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -648,7 +648,10 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c @parameterized.expand(TEST_CASES) def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs): X = self.prepare_inputs_for_testing() + print(f"Class name: {self.transformers_class.__name__}") model = self.transformers_class.from_pretrained(model_id).to(self.torch_device).eval() + for k in X: + print(k, X[k].shape) outputs_base = model(**X) config = config_cls( From e95f1a7728c301d7a9b020694e05c99abeeffbc7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 12 Mar 2024 16:25:26 +0530 Subject: [PATCH 07/11] fix merging tests --- src/peft/tuners/lora/layer.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 966bf07408..252965a762 100644 
--- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -745,7 +745,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N # different value self._cache_store(f"{active_adapter}-weight_norm", weight_norm) dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm - orig_weights = dora_factor.view(-1, 1) * (orig_weights + delta_weight) + orig_weights = dora_factor.view(-1, 1, 1, 1) * (orig_weights + delta_weight) if not torch.isfinite(orig_weights).all(): raise ValueError( @@ -780,7 +780,15 @@ def unmerge(self) -> None: while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() if active_adapter in self.lora_A.keys(): - self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + weight = self.get_base_layer().weight + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + weight.data -= delta_weight + else: + weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + weight_orig = weight.data / dora_factor.view(-1, 1, 1, 1) - delta_weight + weight.data = weight_orig def get_delta_weight(self, adapter) -> torch.Tensor: """ From b599bfedb793cb47fb73eceaeb1f38586cf5bd57 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 12 Mar 2024 17:14:44 +0530 Subject: [PATCH 08/11] documentation --- src/peft/tuners/lora/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index 695fed6f82..2bacf8f0f7 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -105,7 +105,7 @@ class LoraConfig(PeftConfig): Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, especially at low - ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces a bigger overhead than + ranks. Right now, DoRA only supports non-quantized linear and Conv2D layers. DoRA introduces a bigger overhead than pure LoRA, so it is recommended to merge weights for inference. For more information, see https://arxiv.org/abs/2402.09353. """ @@ -238,7 +238,7 @@ class LoraConfig(PeftConfig): "Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the " "weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the " "magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, " - "especially at low ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces " + "especially at low ranks. Right now, DoRA only supports non-quantized linear and Conv2D layers. DoRA introduces " "a bigger overhead than pure LoRA, so it is recommended to merge weights for inference. For more " "information, see https://arxiv.org/abs/2402.09353." ) From d32e0fbc2fb6074fcc0aa78eae38f182b7569d49 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 12 Mar 2024 17:16:40 +0530 Subject: [PATCH 09/11] use keyword args for the F.conv2d call. 
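Illustration (not from the PEFT sources; shapes are made up): the positional and keyword forms of `F.conv2d` are equivalent — the signature is `conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)` — but the keyword form makes it much harder to pass an argument in the wrong slot.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 8)   # (N, C_in, H, W), sizes chosen only for the example
w = torch.randn(8, 4, 3, 3)   # (C_out, C_in, kH, kW)

# positional form used before this patch: (input, weight, bias, stride, padding, dilation, groups)
out_positional = F.conv2d(x, w, None, 1, 1, 1, 1)
# keyword form introduced here: same result, self-documenting arguments
out_keyword = F.conv2d(x, w, bias=None, stride=1, padding=1, dilation=1, groups=1)

assert torch.equal(out_positional, out_keyword)
```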
--- src/peft/tuners/lora/config.py | 4 ++-- src/peft/tuners/lora/layer.py | 10 +++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index 2bacf8f0f7..afe9993ea7 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -105,8 +105,8 @@ class LoraConfig(PeftConfig): Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, especially at low - ranks. Right now, DoRA only supports non-quantized linear and Conv2D layers. DoRA introduces a bigger overhead than - pure LoRA, so it is recommended to merge weights for inference. For more information, see + ranks. Right now, DoRA only supports non-quantized linear and Conv2D layers. DoRA introduces a bigger + overhead than pure LoRA, so it is recommended to merge weights for inference. For more information, see https://arxiv.org/abs/2402.09353. """ diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 252965a762..c13e58b729 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -865,7 +865,15 @@ def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter): weight_norm = weight_norm.detach() mag_norm_scale = magnitude / weight_norm result_dora = (mag_norm_scale - 1) * ( - F.conv2d(x, weight, None, base_layer.stride, base_layer.padding, base_layer.dilation, base_layer.groups) + F.conv2d( + x, + weight, + bias=None, + stride=base_layer.stride, + padding=base_layer.padding, + dilation=base_layer.dilation, + groups=base_layer.groups, + ) ) + mag_norm_scale * lora_B(lora_A(x)) * scaling return result_dora From 743a74c370a78f449cc35d7ade9c22b1fcf4e1fa Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 12 Mar 2024 17:17:20 +0530 Subject: [PATCH 10/11] remove print. --- tests/test_custom_models.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 5b947dc42e..947530ef15 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -648,10 +648,8 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c @parameterized.expand(TEST_CASES) def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs): X = self.prepare_inputs_for_testing() - print(f"Class name: {self.transformers_class.__name__}") model = self.transformers_class.from_pretrained(model_id).to(self.torch_device).eval() - for k in X: - print(k, X[k].shape) + outputs_base = model(**X) config = config_cls( From 641c5b52d410687d0f6ed012d64f7c9d1b4c4bb8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 12 Mar 2024 17:49:50 +0530 Subject: [PATCH 11/11] fix dora_init and make the weightnorm code leaner. 
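The "leaner" weight-norm code is a pure refactor: for a 4D Conv2d weight of shape (out_channels, in_channels, kH, kW), `norm(p=2, dim=(1, 2, 3), keepdim=True).transpose(1, 0)` yields the same (1, out_channels, 1, 1) tensor as the previous norm-then-reshape, without needing `self.get_base_layer().out_channels`. A standalone sanity check (plain PyTorch, illustrative sizes, not part of the PEFT sources):

```python
import torch

w = torch.randn(8, 4, 3, 3)  # (out_channels, in_channels, kH, kW)

old = w.norm(p=2, dim=(1, 2, 3)).reshape((1, w.shape[0], 1, 1))  # previous two-step version
new = w.norm(p=2, dim=(1, 2, 3), keepdim=True).transpose(1, 0)   # leaner version from this patch

assert new.shape == (1, 8, 1, 1)
assert torch.allclose(old, new)
```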
---
 src/peft/tuners/lora/layer.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 4adac2cbdb..501ef53ef6 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -183,8 +183,9 @@ def dora_init(self, adapter_name: str) -> None:
             weight = self.get_base_layer().weight
             quant_state = getattr(self.get_base_layer(), "state", None)
             weight = dequantize_bnb_weight(weight, state=quant_state)  # no-op if not bnb
-            if lora_A.weight.data.ndim == 4:  # For handling LoRAs applied to Conv2Ds.
+            if weight.data.ndim == 4:  # For handling LoRAs applied to Conv2Ds.
                 lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1))
+                lora_weight = lora_weight.reshape(weight.shape)
             else:
                 lora_weight = lora_B.weight @ lora_A.weight
             weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
@@ -848,8 +849,8 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
     def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
         # calculate L2 norm of weight matrix, channel-wise
         weight = weight + scaling * lora_weight
-        weight_norm = weight.norm(p=2, dim=(1, 2, 3))
-        weight_norm = weight_norm.reshape((1, self.get_base_layer().out_channels, 1, 1))
+        # the following is needed to have compatibility with the 4D weight tensors of Conv2D
+        weight_norm = weight.norm(p=2, dim=(1, 2, 3), keepdim=True).transpose(1, 0)
         return weight_norm

     def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter):
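Taken together, the patches implement DoRA for Conv2d the same way as for Linear, just with channel-wise norms over 4D weights: the merged weight is `m * (W + scaling * B A) / ||W + scaling * B A||_c`, and the un-merged forward pass adds `(m/||.|| - 1) * conv(x, W) + (m/||.||) * scaling * lora_B(lora_A(x))` on top of the base output. The sketch below checks that the decomposed forward from `_apply_dora` matches a merge of the weights; it is a standalone approximation in plain PyTorch with made-up shapes and padding, not the PEFT classes themselves.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# made-up Conv2d + LoRA shapes: base weight (C_out, C_in, k, k), LoRA rank r
N, C_in, C_out, k, r, scaling = 2, 4, 8, 3, 2, 0.5
x = torch.randn(N, C_in, 16, 16)
W = torch.randn(C_out, C_in, k, k)   # frozen base weight
A = torch.randn(r, C_in, k, k)       # lora_A: Conv2d(C_in, r, k), same kernel/padding as base
B = torch.randn(C_out, r, 1, 1)      # lora_B: Conv2d(r, C_out, 1)

# LoRA delta weight, as in dora_init / _apply_dora: flatten, matmul, reshape back to 4D
delta = torch.mm(B.flatten(start_dim=1), A.flatten(start_dim=1)).reshape(W.shape)

# channel-wise L2 norm of W + scaling*delta, shaped (1, C_out, 1, 1)
norm = (W + scaling * delta).norm(p=2, dim=(1, 2, 3), keepdim=True).transpose(1, 0)
# the magnitude vector is initialized to this norm; perturb it so the check is non-trivial
m = norm * (1.0 + 0.1 * torch.randn_like(norm))

# merged path, as in merge(): dora_factor.view(-1, 1, 1, 1) * (W + delta_weight)
W_merged = (m / norm).view(-1, 1, 1, 1) * (W + scaling * delta)
out_merged = F.conv2d(x, W_merged, padding=1)

# decomposed path, as in forward() + _apply_dora(): base output + DoRA extra output
base_out = F.conv2d(x, W, padding=1)                     # what the frozen base layer returns
lora_out = F.conv2d(F.conv2d(x, A, padding=1), B)        # lora_B(lora_A(x))
mag_norm_scale = m / norm                                # broadcasts over (N, C_out, H, W)
extra = (mag_norm_scale - 1) * F.conv2d(x, W, padding=1) + mag_norm_scale * lora_out * scaling
out_decomposed = base_out + extra

assert torch.allclose(out_merged, out_decomposed, rtol=1e-4, atol=1e-4)
```

When the magnitude still equals its initialization (the detached weight norm), `mag_norm_scale` is all ones and the extra term reduces to plain LoRA, i.e. `lora_B(lora_A(x)) * scaling`.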