diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index 1f6658ddcf..cc5c60a753 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -104,7 +104,10 @@ class LoraConfig(PeftConfig): use_dora (`bool`): Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is - handled by a separate learnable parameter. This can improve the performance of LoRA, especially at low + handled by a separate learnable parameter. This can improve the performance of LoRA especially at low + ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger overhead than pure + LoRA, so it is recommended to merge weights for inference. For more information, see + https://arxiv.org/abs/2402.09353. layer_replication(`List[Tuple[int, int]]`): Build a new stack of layers by stacking the original model layers according to the ranges specified. This allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will @@ -239,7 +242,7 @@ class LoraConfig(PeftConfig): "Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the " "weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the " "magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, " - "especially at low ranks. Right now, DoRA only supports linear layers. DoRA introduces a bigger" + "especially at low ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger" "overhead than pure LoRA, so it is recommended to merge weights for inference. For more information, " "see https://arxiv.org/abs/2402.09353." ) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index a415016a4a..501ef53ef6 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -183,7 +183,11 @@ def dora_init(self, adapter_name: str) -> None: weight = self.get_base_layer().weight quant_state = getattr(self.get_base_layer(), "state", None) weight = dequantize_bnb_weight(weight, state=quant_state) # no-op if not bnb - lora_weight = lora_B.weight @ lora_A.weight + if weight.data.ndim == 4: # For handling LoRAs applied to Conv2Ds. + lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1)) + lora_weight = lora_weight.reshape(weight.shape) + else: + lora_weight = lora_B.weight @ lora_A.weight weight_norm = self._get_weight_norm(weight, lora_weight, scaling) self.lora_magnitude_vector = nn.ParameterDict() self.lora_magnitude_vector[adapter_name] = nn.Parameter(weight_norm, requires_grad=True) @@ -515,6 +519,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig if weight is not None: # the layer is already completely initialized, this is an update self.to(base_layer.weight.device, dtype=weight.dtype) + self.set_adapter(self.active_adapters) def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: @@ -657,9 +662,6 @@ def __init__( super().__init__() LoraLayer.__init__(self, base_layer) - if use_dora: - raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") - self._active_adapter = adapter_name self.update_layer( adapter_name, @@ -704,6 +706,13 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig if weight is not None: # the layer is already completely initialized, this is an update self.to(base_layer.weight.device, dtype=weight.dtype) + + if use_dora: + self.dora_init(adapter_name) + self.use_dora[adapter_name] = True + else: + self.use_dora[adapter_name] = False + self.set_adapter(self.active_adapters) def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: @@ -731,7 +740,20 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N # Note that safe_merge will be slower than the normal merge # because of the copy operation. orig_weights = base_layer.weight.data.clone() - orig_weights = orig_weights + self.get_delta_weight(active_adapter) + delta_weight = self.get_delta_weight(active_adapter) + + if not self.use_dora[active_adapter]: + orig_weights = orig_weights + delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(orig_weights, delta_weight, scaling=1).detach() + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + orig_weights = dora_factor.view(-1, 1, 1, 1) * (orig_weights + delta_weight) if not torch.isfinite(orig_weights).all(): raise ValueError( @@ -739,7 +761,21 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N ) base_layer.weight.data = orig_weights else: - base_layer.weight.data = base_layer.weight.data + self.get_delta_weight(active_adapter) + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + base_layer.weight.data = base_layer.weight.data + delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(base_layer.weight, delta_weight, scaling=1).detach() + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + new_weight = dora_factor.view(-1, 1, 1, 1) * (base_layer.weight.data + delta_weight) + base_layer.weight.data = new_weight + self.merged_adapters.append(active_adapter) def unmerge(self) -> None: @@ -752,7 +788,15 @@ def unmerge(self) -> None: while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() if active_adapter in self.lora_A.keys(): - self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + weight = self.get_base_layer().weight + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + weight.data -= delta_weight + else: + weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + weight_orig = weight.data / dora_factor.view(-1, 1, 1, 1) - delta_weight + weight.data = weight_orig def get_delta_weight(self, adapter) -> torch.Tensor: """ @@ -802,6 +846,46 @@ def get_delta_weight(self, adapter) -> torch.Tensor: return output_tensor + def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor: + # calculate L2 norm of weight matrix, channel-wise + weight = weight + scaling * lora_weight + # the following is needed to have compatibility with the 4D weight tensors of Conv2D + weight_norm = weight.norm(p=2, dim=(1, 2, 3), keepdim=True).transpose(1, 0) + return weight_norm + + def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter): + """ + For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer + output. + """ + base_layer = self.get_base_layer() + weight = base_layer.weight + lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1)) + lora_weight = lora_weight.reshape(weight.shape) + magnitude = self.lora_magnitude_vector[active_adapter] + weight_norm = self._get_weight_norm(weight, lora_weight, scaling) + # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353) + # "[...] we suggest treating ||V +∆V ||_c in + # Eq. (5) as a constant, thereby detaching it from the gradient + # graph. This means that while ||V + ∆V ||_c dynamically + # reflects the updates of ∆V , it won’t receive any gradient + # during backpropagation" + weight_norm = weight_norm.detach() + mag_norm_scale = magnitude / weight_norm + result_dora = (mag_norm_scale - 1) * ( + F.conv2d( + x, + weight, + bias=None, + stride=base_layer.stride, + padding=base_layer.padding, + dilation=base_layer.dilation, + groups=base_layer.groups, + ) + ) + mag_norm_scale * lora_B(lora_A(x)) * scaling + + return result_dora + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: if self.disable_adapters: if self.merged: @@ -821,7 +905,12 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: dropout = self.lora_dropout[active_adapter] scaling = self.scaling[active_adapter] x = x.to(lora_A.weight.dtype) - result = result + lora_B(lora_A(dropout(x))) * scaling + + if not self.use_dora[active_adapter]: + result = result + lora_B(lora_A(dropout(x))) * scaling + else: + x = dropout(x) + result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter) result = result.to(torch_result_dtype) return result diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 9db559a3d8..947530ef15 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -68,6 +68,8 @@ ("Embedding + transformers Conv1D 3 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb", "conv1d"]}), ("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}), ("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}), + ("Conv2d 1 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}), + ("Conv2d 2 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"], "use_dora": True}), ####### # IA³ # ####### @@ -647,6 +649,7 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs): X = self.prepare_inputs_for_testing() model = self.transformers_class.from_pretrained(model_id).to(self.torch_device).eval() + outputs_base = model(**X) config = config_cls(