ENH Support Conv3d layer in LoRA and IA3 (#2082)

jsilter authored Sep 25, 2024
1 parent 58ca0ad commit 0f9bdad

Showing 10 changed files with 169 additions and 37 deletions.
4 changes: 2 additions & 2 deletions src/peft/tuners/ia3/__init__.py
@@ -15,11 +15,11 @@
from peft.import_utils import is_bnb_4bit_available, is_bnb_available

from .config import IA3Config
from .layer import Conv2d, IA3Layer, Linear
from .layer import Conv2d, Conv3d, IA3Layer, Linear
from .model import IA3Model


__all__ = ["Conv2d", "IA3Config", "IA3Layer", "IA3Model", "Linear"]
__all__ = ["Conv2d", "Conv3d", "IA3Config", "IA3Layer", "IA3Model", "Linear"]


def __getattr__(name):
33 changes: 25 additions & 8 deletions src/peft/tuners/ia3/layer.py
@@ -38,7 +38,7 @@ def __init__(self, base_layer: nn.Module, is_feedforward: bool, **kwargs) -> Non
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
in_features, out_features = base_layer.in_features, base_layer.out_features
elif isinstance(base_layer, nn.Conv2d):
elif isinstance(base_layer, (nn.Conv2d, nn.Conv3d)):
in_features, out_features = base_layer.in_channels, base_layer.out_channels
elif isinstance(base_layer, nn.Embedding):
in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim
@@ -184,7 +184,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
return result


class Conv2d(nn.Module, IA3Layer):
class _ConvNd(nn.Module, IA3Layer):
def __init__(
self,
base_layer: nn.Module,
@@ -198,15 +198,15 @@ def __init__(
IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward)
self.fan_in_fan_out = fan_in_fan_out
self._active_adapter = adapter_name
self._kernel_dim = base_layer.weight.dim()

self.update_layer(adapter_name, init_ia3_weights)

def update_layer(self, adapter_name, init_ia3_weights):
# Actual trainable parameters
if self.is_feedforward:
weight = torch.randn((1, self.in_features, 1, 1))
else:
weight = torch.randn((1, self.out_features, 1, 1))
num_features = self.in_features if self.is_feedforward else self.out_features
weights_size = (1, num_features) + (1,) * (self._kernel_dim - 2)
weight = torch.randn(weights_size)
self.ia3_l[adapter_name] = nn.Parameter(weight)
if init_ia3_weights:
self.reset_ia3_parameters(adapter_name)
@@ -236,7 +236,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
base_layer = self.get_base_layer()
ia3_scaling = self.ia3_l[active_adapter].data
if not self.is_feedforward:
ia3_scaling = ia3_scaling.permute(1, 0, 2, 3)
ia3_scaling = ia3_scaling.transpose(0, 1)

if safe_merge:
output_weight = torch.mul(base_layer.weight.data, ia3_scaling).clone()
@@ -272,7 +272,7 @@ def unmerge(self) -> None:
# divide by (IA)^3 vector. Add tolerance to avoid division by zero
ia3_scaling = self.ia3_l[active_adapter].data
if not self.is_feedforward:
ia3_scaling = ia3_scaling.permute(1, 0, 2, 3)
ia3_scaling = ia3_scaling.transpose(0, 1)
base_layer.weight.data = torch.div(base_layer.weight.data, ia3_scaling + 1e-8)

if not self.is_feedforward and (base_layer.bias is not None):
@@ -308,3 +308,20 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:

result = result.to(previous_dtype)
return result


class Conv2d(_ConvNd):
# IA3 implemented in a 2D convolutional layer

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self._kernel_dim == 4:
raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}")


class Conv3d(_ConvNd):
# IA3 implemented in a 3D convolutional layer
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self._kernel_dim == 5:
raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}")
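
Note: the generalization above keeps the learned (IA)^3 vector at shape (1, num_features, 1, ..., 1), matched to the kernel rank, and uses transpose(0, 1) instead of the hard-coded permute(1, 0, 2, 3) so the same broadcast covers both 4-D (Conv2d) and 5-D (Conv3d) kernels. A minimal standalone sketch, not part of this commit:

import torch
import torch.nn as nn

def make_ia3_vector(num_features: int, kernel_dim: int) -> torch.Tensor:
    # kernel_dim is 4 for Conv2d (O, I, kH, kW) and 5 for Conv3d (O, I, kD, kH, kW)
    return torch.randn((1, num_features) + (1,) * (kernel_dim - 2))

conv3d = nn.Conv3d(8, 16, kernel_size=3)
ia3_l = make_ia3_vector(conv3d.out_channels, conv3d.weight.dim())  # shape (1, 16, 1, 1, 1)

# Merging (non-feedforward case) scales the kernel along its output-channel axis,
# so the vector is transposed to (16, 1, 1, 1, 1) before broadcasting.
merged = conv3d.weight.data * ia3_l.transpose(0, 1)
assert merged.shape == conv3d.weight.shape  # (16, 8, 3, 3, 3)
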
4 changes: 3 additions & 1 deletion src/peft/tuners/ia3/model.py
@@ -33,7 +33,7 @@
_get_submodules,
)

from .layer import Conv2d, IA3Layer, Linear
from .layer import Conv2d, Conv3d, IA3Layer, Linear


class IA3Model(BaseTuner):
@@ -121,6 +121,8 @@ def _create_new_module(ia3_config, adapter_name, target, **kwargs):
new_module = Linear4bit(target, adapter_name, is_feedforward=is_feedforward, **fourbit_kwargs)
elif isinstance(target, torch.nn.Conv2d):
new_module = Conv2d(target, adapter_name, is_feedforward=is_feedforward, **kwargs)
elif isinstance(target, torch.nn.Conv3d):
new_module = Conv3d(target, adapter_name, is_feedforward=is_feedforward, **kwargs)
elif isinstance(target_base_layer, torch.nn.Linear):
if kwargs["fan_in_fan_out"]:
warnings.warn(
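
With Conv3d dispatched to its own module class, (IA)^3 can target 3-D convolutions through the usual config path. A hypothetical usage sketch, not part of this commit (the toy model and target-module names are made up for illustration, and it assumes PEFT's custom-model support via get_peft_model):

import torch.nn as nn
from peft import IA3Config, get_peft_model

base_model = nn.Sequential(
    nn.Conv3d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv3d(16, 32, kernel_size=3, padding=1),
)

# Target the two Conv3d submodules ("0" and "2" in the Sequential).
config = IA3Config(target_modules=["0", "2"], feedforward_modules=[])
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()
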
3 changes: 2 additions & 1 deletion src/peft/tuners/lora/__init__.py
@@ -16,7 +16,7 @@

from .config import LoftQConfig, LoraConfig, LoraRuntimeConfig
from .gptq import QuantLinear
from .layer import Conv2d, Embedding, Linear, LoraLayer
from .layer import Conv2d, Conv3d, Embedding, Linear, LoraLayer
from .model import LoraModel


@@ -25,6 +25,7 @@
"LoraRuntimeConfig",
"LoftQConfig",
"Conv2d",
"Conv3d",
"Embedding",
"LoraLayer",
"Linear",
23 changes: 18 additions & 5 deletions src/peft/tuners/lora/dora.py
@@ -48,7 +48,7 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=Fals
base_layer = deepcopy(base_layer)

weight = dequantize_module_weight(base_layer)
if weight.data.ndim == 4: # For handling LoRAs applied to Conv2Ds.
if weight.data.ndim >= 4: # For handling LoRAs applied to Conv layers.
lora_weight = torch.mm(lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1))
lora_weight = lora_weight.reshape(weight.shape)
else:
@@ -133,12 +133,13 @@ def __repr__(self) -> str:
return "lora.dora." + rep


class DoraConv2dLayer(DoraLinearLayer):
class _DoraConvNdLayer(DoraLinearLayer):
def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
# calculate L2 norm of weight matrix, column-wise
weight = weight + scaling * lora_weight
# the following is needed to have compatibility with the 4D weight tensors of Conv2D
weight_norm = weight.norm(p=2, dim=(1, 2, 3), keepdim=True).transpose(1, 0)
# the following is needed to have compatibility with the 4/5D weight tensors of Conv2D/3D
dim = tuple(range(1, weight.dim()))
weight_norm = weight.norm(p=2, dim=dim, keepdim=True).transpose(1, 0)
return weight_norm

def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
@@ -160,7 +161,7 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
weight_norm = weight_norm.detach()
mag_norm_scale = magnitude / weight_norm
result_dora = (mag_norm_scale - 1) * (
F.conv2d(
self.conv_fn(
x,
weight,
bias=None,
@@ -176,3 +177,15 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
def __repr__(self) -> str:
rep = super().__repr__()
return "lora.dora." + rep


class DoraConv2dLayer(_DoraConvNdLayer):
def __init__(self, fan_in_fan_out):
super().__init__(fan_in_fan_out)
self.conv_fn = F.conv2d


class DoraConv3dLayer(_DoraConvNdLayer):
def __init__(self, fan_in_fan_out):
super().__init__(fan_in_fan_out)
self.conv_fn = F.conv3d
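
The norm change above is what makes the DoRA layer dimension-agnostic: the L2 norm is taken over every axis except the output-channel axis, instead of the hard-coded dim=(1, 2, 3) that only fits 4-D Conv2d kernels. A small standalone check, not part of this commit:

import torch

def per_output_channel_norm(weight: torch.Tensor) -> torch.Tensor:
    # dim is (1, 2, 3) for a 4-D Conv2d kernel and (1, 2, 3, 4) for a 5-D Conv3d kernel
    dim = tuple(range(1, weight.dim()))
    return weight.norm(p=2, dim=dim, keepdim=True).transpose(1, 0)

w2d = torch.randn(16, 8, 3, 3)      # Conv2d weight: (out, in, kH, kW)
w3d = torch.randn(16, 8, 3, 3, 3)   # Conv3d weight: (out, in, kD, kH, kW)
print(per_output_channel_norm(w2d).shape)  # torch.Size([1, 16, 1, 1])
print(per_output_channel_norm(w3d).shape)  # torch.Size([1, 16, 1, 1, 1])
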
68 changes: 55 additions & 13 deletions src/peft/tuners/lora/layer.py
@@ -29,7 +29,7 @@
from peft.utils.other import transpose

from .config import LoraConfig
from .dora import DoraConv2dLayer, DoraEmbeddingLayer, DoraLinearLayer
from .dora import DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer, _DoraConvNdLayer


class LoraLayer(BaseTunerLayer):
@@ -63,6 +63,8 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
in_features, out_features = base_layer.in_features, base_layer.out_features
elif isinstance(base_layer, nn.Conv2d):
in_features, out_features = base_layer.in_channels, base_layer.out_channels
elif isinstance(base_layer, nn.Conv3d):
in_features, out_features = base_layer.in_channels, base_layer.out_channels
elif isinstance(base_layer, nn.Embedding):
in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim
elif isinstance(base_layer, Conv1D):
@@ -851,8 +853,8 @@ def __repr__(self) -> str:
return "lora." + rep


class Conv2d(nn.Module, LoraLayer):
# Lora implemented in a conv2d layer
class _ConvNd(nn.Module, LoraLayer):
# Lora implemented in a conv(2,3)d layer
def __init__(
self,
base_layer: nn.Module,
@@ -869,6 +871,8 @@ def __init__(
LoraLayer.__init__(self, base_layer)

self._active_adapter = adapter_name
self._kernel_dim = base_layer.weight.dim()

self.update_layer(
adapter_name,
r,
@@ -896,8 +900,10 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
kernel_size = base_layer.kernel_size
stride = base_layer.stride
padding = base_layer.padding
self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)
self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)
conv_layer = type(base_layer)
out_kernel = out_stride = (1,) * (self._kernel_dim - 2)
self.lora_A[adapter_name] = conv_layer(self.in_features, r, kernel_size, stride, padding, bias=False)
self.lora_B[adapter_name] = conv_layer(r, self.out_features, out_kernel, out_stride, bias=False)
if use_rslora:
self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
else:
@@ -919,18 +925,26 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig

self.set_adapter(self.active_adapters)

def _get_dora_factor_view(self):
return (-1,) + (1,) * (self._kernel_dim - 1)

def dora_init(self, adapter_name: str) -> None:
if self.lora_magnitude_vector is None:
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)

dora_layer = DoraConv2dLayer(fan_in_fan_out=False)
dora_layer_class = self._get_dora_layer_class()
dora_layer = dora_layer_class(fan_in_fan_out=False)
lora_A = self.lora_A[adapter_name].weight
lora_B = self.lora_B[adapter_name].weight
scaling = self.scaling[adapter_name]
dora_layer.update_layer(base_layer=self.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling)
self.lora_magnitude_vector[adapter_name] = dora_layer

def _get_dora_layer_class(self) -> type[_DoraConvNdLayer]:
# Subclasses should override this method to return the appropriate DoraLayer class
raise NotImplementedError

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights inside the base weights
@@ -973,7 +987,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
# different value
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
orig_weights = dora_factor.view(-1, 1, 1, 1) * (orig_weights + delta_weight)
orig_weights = dora_factor.view(*self._get_dora_factor_view()) * (orig_weights + delta_weight)

if not torch.isfinite(orig_weights).all():
raise ValueError(
@@ -997,7 +1011,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
# different value
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
new_weight = dora_factor.view(-1, 1, 1, 1) * (base_layer.weight.data + delta_weight)
new_weight = dora_factor.view(*self._get_dora_factor_view()) * (
base_layer.weight.data + delta_weight
)
base_layer.weight.data = new_weight

self.merged_adapters.append(active_adapter)
@@ -1019,7 +1035,7 @@ def unmerge(self) -> None:
else:
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
weight_orig = weight.data / dora_factor.view(-1, 1, 1, 1) - delta_weight
weight_orig = weight.data / dora_factor.view(*self._get_dora_factor_view()) - delta_weight
weight.data = weight_orig

def get_delta_weight(self, adapter) -> torch.Tensor:
@@ -1052,12 +1068,11 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
3
) * self.scaling[adapter]
else:
# conv2d 3x3
output_tensor = (
F.conv2d(
weight_A.permute(1, 0, 2, 3),
self.conv_fn(
weight_A.transpose(0, 1),
weight_B,
).permute(1, 0, 2, 3)
).transpose(0, 1)
* self.scaling[adapter]
)

@@ -1115,6 +1130,30 @@ def __repr__(self) -> str:
return "lora." + rep


class Conv2d(_ConvNd):
# Lora implemented in a conv2d layer
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self._kernel_dim == 4:
raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv2d

def _get_dora_layer_class(self):
return DoraConv2dLayer


class Conv3d(_ConvNd):
# Lora implemented in a conv3d layer
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self._kernel_dim == 5:
raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv3d

def _get_dora_layer_class(self):
return DoraConv3dLayer


def dispatch_default(
target: torch.nn.Module,
adapter_name: str,
@@ -1136,6 +1175,9 @@ def dispatch_default(
elif isinstance(target_base_layer, torch.nn.Conv2d):
kwargs.update(lora_config.loftq_config)
new_module = Conv2d(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, torch.nn.Conv3d):
kwargs.update(lora_config.loftq_config)
new_module = Conv3d(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, torch.nn.Linear):
if kwargs["fan_in_fan_out"]:
warnings.warn(
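
For context on the get_delta_weight path above: lora_A is a conv with the base layer's kernel size, lora_B is a 1x1(x1) conv, and the delta kernel comes from convolving A's weight (with its in/out axes swapped) with B's weight, then swapping back. A standalone sketch of the Conv3d case, not part of this commit:

import torch
import torch.nn.functional as F

in_ch, out_ch, r = 8, 16, 4
lora_A = torch.nn.Conv3d(in_ch, r, kernel_size=3, bias=False)           # weight: (r, in_ch, 3, 3, 3)
lora_B = torch.nn.Conv3d(r, out_ch, kernel_size=(1, 1, 1), bias=False)  # weight: (out_ch, r, 1, 1, 1)

delta = F.conv3d(
    lora_A.weight.transpose(0, 1),  # (in_ch, r, 3, 3, 3), treated as a batch of in_ch inputs
    lora_B.weight,
).transpose(0, 1)                   # -> (out_ch, in_ch, 3, 3, 3), same shape as the base kernel

assert delta.shape == (out_ch, in_ch, 3, 3, 3)
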
3 changes: 2 additions & 1 deletion src/peft/tuners/lora/model.py
@@ -349,7 +349,8 @@ def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs):
# no module could be matched
raise ValueError(
f"Target module {target} is not supported. Currently, only the following modules are supported: "
"`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`."
"`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, "
"`transformers.pytorch_utils.Conv1D`."
)

return new_module
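
With Conv3d added to the supported modules, a LoraConfig can target 3-D conv layers directly. A hypothetical end-to-end sketch, not part of this commit (the toy model and the module name "conv" are made up for illustration):

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class TinyVideoNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(3, 8, kernel_size=3, padding=1)
        self.head = nn.Linear(8, 2)

    def forward(self, x):                          # x: (N, 3, D, H, W)
        feats = self.conv(x).mean(dim=(2, 3, 4))   # global average pool over D, H, W
        return self.head(feats)

config = LoraConfig(r=4, lora_alpha=8, target_modules=["conv"])
peft_model = get_peft_model(TinyVideoNet(), config)
peft_model.print_trainable_parameters()

out = peft_model(torch.randn(2, 3, 8, 16, 16))
print(out.shape)  # torch.Size([2, 2])
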
2 changes: 1 addition & 1 deletion src/peft/utils/other.py
@@ -295,7 +295,7 @@ def _mixed_batch_forward(
# This is a special method that handles the case when users pass the argument `adapter_names`. This is an
# extra argument that allows mixing different adapters in the same batch at inference time.

SUPPORTED_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, torch.nn.Conv1d)
SUPPORTED_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)

module_names = ", ".join([module.__name__ for module in SUPPORTED_MODULES])
