Feat: Support for Conv2D DoRA (#1516)
sayakpaul authored Mar 12, 2024
1 parent 3eb6bba commit 3b63996
Showing 3 changed files with 105 additions and 10 deletions.
7 changes: 5 additions & 2 deletions src/peft/tuners/lora/config.py
@@ -104,7 +104,10 @@ class LoraConfig(PeftConfig):
use_dora (`bool`):
Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the weights
into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is
handled by a separate learnable parameter. This can improve the performance of LoRA, especially at low
handled by a separate learnable parameter. This can improve the performance of LoRA especially at low
ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger overhead than pure
LoRA, so it is recommended to merge weights for inference. For more information, see
https://arxiv.org/abs/2402.09353.
layer_replication(`List[Tuple[int, int]]`):
Build a new stack of layers by stacking the original model layers according to the ranges specified. This
allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will
@@ -239,7 +242,7 @@ class LoraConfig(PeftConfig):
"Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the "
"weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the "
"magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, "
"especially at low ranks. Right now, DoRA only supports linear layers. DoRA introduces a bigger"
"especially at low ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger"
"overhead than pure LoRA, so it is recommended to merge weights for inference. For more information, "
"see https://arxiv.org/abs/2402.09353."
)
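For orientation, the updated docstring and help string correspond to configurations like the minimal sketch below. The `target_modules=["conv2d"]` and `use_dora=True` combination mirrors the new test cases added in this commit; the rank and alpha values, and the module name itself, are illustrative assumptions about the wrapped model.

```python
# Minimal sketch: enabling DoRA on a Conv2d module via LoraConfig.
# target_modules and use_dora mirror the new test cases; r and lora_alpha are arbitrary.
from peft import LoraConfig

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["conv2d"],  # name of an nn.Conv2d module in the base model
    use_dora=True,              # DoRA now covers Conv2d layers as well as Linear
)
```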
105 changes: 97 additions & 8 deletions src/peft/tuners/lora/layer.py
@@ -183,7 +183,11 @@ def dora_init(self, adapter_name: str) -> None:
weight = self.get_base_layer().weight
quant_state = getattr(self.get_base_layer(), "state", None)
weight = dequantize_bnb_weight(weight, state=quant_state) # no-op if not bnb
lora_weight = lora_B.weight @ lora_A.weight
if weight.data.ndim == 4: # For handling LoRAs applied to Conv2Ds.
lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1))
lora_weight = lora_weight.reshape(weight.shape)
else:
lora_weight = lora_B.weight @ lora_A.weight
weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
self.lora_magnitude_vector = nn.ParameterDict()
self.lora_magnitude_vector[adapter_name] = nn.Parameter(weight_norm, requires_grad=True)
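The new 4D branch builds the composed kernel for a Conv2d LoRA pair by flattening both factors to 2D, multiplying them, and reshaping back to the base kernel's shape; the `else` branch keeps the original 2D matmul for Linear layers unchanged. A standalone sketch of that computation, with shapes assumed to follow PEFT's Conv2d LoRA parametrization (lora_A as a k×k conv down to rank r, lora_B as a 1×1 conv back to the output channels):

```python
# Standalone sketch of the 4D branch above; all tensors are random stand-ins.
import torch

out_c, in_c, kh, kw, r = 8, 3, 3, 3, 4
weight = torch.randn(out_c, in_c, kh, kw)    # base_layer.weight of a Conv2d
lora_A_w = torch.randn(r, in_c, kh, kw)      # lora_A.weight: rank-r conv over the input
lora_B_w = torch.randn(out_c, r, 1, 1)       # lora_B.weight: 1x1 conv back to out channels

# (out, r) @ (r, in*kh*kw) -> (out, in*kh*kw), then reshape back to 4D
lora_weight = torch.mm(lora_B_w.flatten(start_dim=1), lora_A_w.flatten(start_dim=1))
lora_weight = lora_weight.reshape(weight.shape)
assert lora_weight.shape == weight.shape == (8, 3, 3, 3)
```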
@@ -515,6 +519,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
if weight is not None:
# the layer is already completely initialized, this is an update
self.to(base_layer.weight.device, dtype=weight.dtype)

self.set_adapter(self.active_adapters)

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
@@ -657,9 +662,6 @@ def __init__(
super().__init__()
LoraLayer.__init__(self, base_layer)

if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

self._active_adapter = adapter_name
self.update_layer(
adapter_name,
@@ -704,6 +706,13 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
if weight is not None:
# the layer is already completely initialized, this is an update
self.to(base_layer.weight.device, dtype=weight.dtype)

if use_dora:
self.dora_init(adapter_name)
self.use_dora[adapter_name] = True
else:
self.use_dora[adapter_name] = False

self.set_adapter(self.active_adapters)

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
@@ -731,15 +740,42 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
orig_weights = base_layer.weight.data.clone()
orig_weights = orig_weights + self.get_delta_weight(active_adapter)
delta_weight = self.get_delta_weight(active_adapter)

if not self.use_dora[active_adapter]:
orig_weights = orig_weights + delta_weight
else:
# handle dora
# since delta_weight already includes scaling, set it to 1 here
weight_norm = self._get_weight_norm(orig_weights, delta_weight, scaling=1).detach()
# We need to cache weight_norm because it has to be based on the original weights. We
# cannot calculate it on the fly based on the merged weights when unmerging because it's a
# different value
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
orig_weights = dora_factor.view(-1, 1, 1, 1) * (orig_weights + delta_weight)

if not torch.isfinite(orig_weights).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
base_layer.weight.data = orig_weights
else:
base_layer.weight.data = base_layer.weight.data + self.get_delta_weight(active_adapter)
delta_weight = self.get_delta_weight(active_adapter)
if not self.use_dora[active_adapter]:
base_layer.weight.data = base_layer.weight.data + delta_weight
else:
# handle dora
# since delta_weight already includes scaling, set it to 1 here
weight_norm = self._get_weight_norm(base_layer.weight, delta_weight, scaling=1).detach()
# We need to cache weight_norm because it has to be based on the original weights. We
# cannot calculate it on the fly based on the merged weights when unmerging because it's a
# different value
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
new_weight = dora_factor.view(-1, 1, 1, 1) * (base_layer.weight.data + delta_weight)
base_layer.weight.data = new_weight

self.merged_adapters.append(active_adapter)

def unmerge(self) -> None:
@@ -752,7 +788,15 @@ def unmerge(self) -> None:
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.lora_A.keys():
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
weight = self.get_base_layer().weight
delta_weight = self.get_delta_weight(active_adapter)
if not self.use_dora[active_adapter]:
weight.data -= delta_weight
else:
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
weight_orig = weight.data / dora_factor.view(-1, 1, 1, 1) - delta_weight
weight.data = weight_orig
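Merge and unmerge are exact inverses because the merge path caches the channel-wise norm computed on the original weights. A numeric sketch of the round trip, using random stand-ins for the base weight, the LoRA delta, and the learned magnitude vector:

```python
# Numeric sketch of the DoRA merge/unmerge round trip above; values are stand-ins,
# not taken from a real layer.
import torch

weight = torch.randn(8, 3, 3, 3)               # base conv kernel
delta_weight = 0.01 * torch.randn(8, 3, 3, 3)  # get_delta_weight(): already includes scaling

# channel-wise L2 norm of W + delta, as in _get_weight_norm -> shape (1, 8, 1, 1)
weight_norm = (weight + delta_weight).norm(p=2, dim=(1, 2, 3), keepdim=True).transpose(1, 0)
magnitude = torch.rand_like(weight_norm) + 0.5  # stands in for lora_magnitude_vector

dora_factor = magnitude / weight_norm
merged = dora_factor.view(-1, 1, 1, 1) * (weight + delta_weight)

# unmerge: divide by the cached factor, then subtract the LoRA delta
restored = merged / dora_factor.view(-1, 1, 1, 1) - delta_weight
assert torch.allclose(restored, weight, atol=1e-5)
```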

def get_delta_weight(self, adapter) -> torch.Tensor:
"""
@@ -802,6 +846,46 @@ def get_delta_weight(self, adapter) -> torch.Tensor:

return output_tensor

def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
# calculate L2 norm of weight matrix, channel-wise
weight = weight + scaling * lora_weight
# the following is needed to have compatibility with the 4D weight tensors of Conv2D
weight_norm = weight.norm(p=2, dim=(1, 2, 3), keepdim=True).transpose(1, 0)
return weight_norm
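For a 4D kernel, this reduces over the input-channel and spatial dimensions, leaving one L2 norm per output channel. A quick shape check under an assumed 8-output-channel kernel:

```python
# Quick shape check for the channel-wise norm used above (8 output channels assumed).
import torch

w = torch.randn(8, 3, 3, 3)
norm = w.norm(p=2, dim=(1, 2, 3), keepdim=True).transpose(1, 0)
print(norm.shape)  # torch.Size([1, 8, 1, 1]): one L2 norm per output channel
```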

def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter):
"""
For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
output.
"""
base_layer = self.get_base_layer()
weight = base_layer.weight
lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1))
lora_weight = lora_weight.reshape(weight.shape)
magnitude = self.lora_magnitude_vector[active_adapter]
weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
# see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
# "[...] we suggest treating ||V +∆V ||_c in
# Eq. (5) as a constant, thereby detaching it from the gradient
# graph. This means that while ||V + ∆V ||_c dynamically
# reflects the updates of ∆V , it won’t receive any gradient
# during backpropagation"
weight_norm = weight_norm.detach()
mag_norm_scale = magnitude / weight_norm
result_dora = (mag_norm_scale - 1) * (
F.conv2d(
x,
weight,
bias=None,
stride=base_layer.stride,
padding=base_layer.padding,
dilation=base_layer.dilation,
groups=base_layer.groups,
)
) + mag_norm_scale * lora_B(lora_A(x)) * scaling

return result_dora
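Since forward() already adds the base layer's own output, _apply_dora only returns the difference between the full DoRA output, m / ||W + s·BA||_c applied channel-wise to (Wx + s·BAx), and the plain base output Wx. The sketch below checks this rearrangement numerically; the layer shapes are assumptions in the style of PEFT's Conv2d LoRA, and the magnitude vector is a random stand-in.

```python
# Sketch checking the rearrangement used in _apply_dora: with norm = ||W + s*B*A||_c,
# (m / norm) * (W x + s * B A x) == base(x) + result_dora.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
base = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
lora_A = nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=False)
lora_B = nn.Conv2d(4, 8, kernel_size=1, bias=False)
scaling = 0.5
x = torch.randn(2, 3, 16, 16)

lora_weight = torch.mm(lora_B.weight.flatten(1), lora_A.weight.flatten(1)).reshape(base.weight.shape)
weight_norm = (base.weight + scaling * lora_weight).norm(p=2, dim=(1, 2, 3), keepdim=True).transpose(1, 0)
magnitude = torch.rand_like(weight_norm) + 0.5    # stand-in for lora_magnitude_vector
mag_norm_scale = magnitude / weight_norm          # (1, 8, 1, 1), broadcasts over batch and spatial dims

result_dora = (mag_norm_scale - 1) * base(x) + mag_norm_scale * lora_B(lora_A(x)) * scaling

# reference: apply the fully decomposed DoRA weight in a single convolution
dora_weight = mag_norm_scale.view(-1, 1, 1, 1) * (base.weight + scaling * lora_weight)
reference = F.conv2d(x, dora_weight, padding=1)
assert torch.allclose(base(x) + result_dora, reference, atol=1e-4)
```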

def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if self.disable_adapters:
if self.merged:
@@ -821,7 +905,12 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
result = result + lora_B(lora_A(dropout(x))) * scaling

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
x = dropout(x)
result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter)

result = result.to(torch_result_dtype)
return result
3 changes: 3 additions & 0 deletions tests/test_custom_models.py
@@ -68,6 +68,8 @@
("Embedding + transformers Conv1D 3 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb", "conv1d"]}),
("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}),
("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}),
("Conv2d 1 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
("Conv2d 2 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"], "use_dora": True}),
#######
# IA³ #
#######
@@ -647,6 +649,7 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c
def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
X = self.prepare_inputs_for_testing()
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device).eval()

outputs_base = model(**X)

config = config_cls(
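The two new parametrizations run Conv2d targets with use_dora=True through the shared custom-model test suite. A rough standalone analogue of what they exercise is sketched below; the toy module is an assumption for illustration, not the actual fixture from tests/test_custom_models.py, and init_lora_weights=False is only used so that merging visibly changes the weights. merge_adapter() and unmerge_adapter() are existing PeftModel helpers that route into the merge/unmerge methods shown in this diff.

```python
# Rough standalone analogue of the new Conv2d DoRA test cases; the toy model is hypothetical.
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class ToyConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2d = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.lin0 = nn.Linear(8, 2)

    def forward(self, x):
        return self.lin0(self.conv2d(x).mean(dim=(2, 3)))

torch.manual_seed(0)
x = torch.randn(4, 3, 16, 16)
# init_lora_weights=False gives non-zero adapters, so merging actually changes the weights
config = LoraConfig(target_modules=["conv2d", "lin0"], use_dora=True, init_lora_weights=False)
model = get_peft_model(ToyConvNet(), config).eval()

with torch.no_grad():
    out_adapter = model(x)
    model.merge_adapter()    # exercises the DoRA-aware Conv2d merge path added above
    out_merged = model(x)
    model.unmerge_adapter()  # restores the original base weights via the cached norm
    out_unmerged = model(x)

assert torch.allclose(out_adapter, out_merged, atol=1e-4)
assert torch.allclose(out_adapter, out_unmerged, atol=1e-4)
```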
