Adalora deepspeed #1625

Merged · 1 commit · Apr 12, 2024
src/peft/tuners/adalora/layer.py (15 additions, 1 deletion)
@@ -15,14 +15,22 @@
 import warnings
 from typing import Any, List, Optional

+import packaging
 import torch
+import transformers
 from torch import nn

 from peft.tuners.lora import LoraLayer
 from peft.tuners.tuners_utils import check_adapters_to_merge
 from peft.utils import transpose


+if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.33.0"):
+    from transformers.integrations import deepspeed_config
+else:
+    from transformers.deepspeed import deepspeed_config
+
+
 class AdaLoraLayer(LoraLayer):
     # List all names of layers that may contain adapter weights
     # Note: ranknum doesn't need to be included as it is not an nn.Module
@@ -253,7 +261,13 @@ def update_ipt(self, model):
                     self.exp_avg_ipt[n] = torch.zeros_like(p)
                     self.exp_avg_unc[n] = torch.zeros_like(p)
                 with torch.no_grad():
-                    self.ipt[n] = (p * p.grad).abs().detach()
+                    if deepspeed_config() is not None:
+                        import deepspeed
+
+                        grad = deepspeed.utils.safe_get_full_grad(p)
+                        self.ipt[n] = (p * grad).abs().detach()
+                    else:
+                        self.ipt[n] = (p * p.grad).abs().detach()
                     # Sensitivity smoothing
                     self.exp_avg_ipt[n] = self.beta1 * self.exp_avg_ipt[n] + (1 - self.beta1) * self.ipt[n]
                     # Uncertainty quantification
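The hunk above changes how AdaLoRA's rank allocator reads gradients when training runs under DeepSpeed: with ZeRO partitioning, p.grad on a sharded parameter is not directly usable, so the full gradient is fetched through deepspeed.utils.safe_get_full_grad. A minimal sketch of the same importance computation, assuming DeepSpeed has been configured through transformers (the importance_score helper is illustrative, not part of PEFT):

import torch


def importance_score(p: torch.nn.Parameter) -> torch.Tensor:
    # Version-gated import, mirroring the hunk above.
    try:
        from transformers.integrations import deepspeed_config  # transformers >= 4.33.0
    except ImportError:
        from transformers.deepspeed import deepspeed_config  # older transformers

    if deepspeed_config() is not None:
        import deepspeed

        # Gather the full (un-partitioned) gradient for this parameter.
        grad = deepspeed.utils.safe_get_full_grad(p)
    else:
        # Plain PyTorch training: the local gradient is already complete.
        grad = p.grad

    with torch.no_grad():
        # AdaLoRA's sensitivity estimate: |w * dL/dw|, element-wise.
        return (p * grad).abs().detach()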
src/peft/tuners/adalora/model.py (6 additions, 1 deletion)
@@ -27,6 +27,7 @@
     get_auto_gptq_quant_linear,
     get_quantization_config,
 )
+from peft.utils.integrations import gather_params_ctx

 from .gptq import SVDQuantLinear
 from .layer import AdaLoraLayer, RankAllocator, SVDLinear
@@ -244,7 +245,11 @@ def forward(self, *args, **kwargs):
             num_param = 0
             for n, p in self.model.named_parameters():
                 if ("lora_A" in n or "lora_B" in n) and self.trainable_adapter_name in n:
-                    para_cov = p @ p.T if "lora_A" in n else p.T @ p
+                    if p.shape == torch.Size([0]):
+                        with gather_params_ctx(p, fwd_module=self):
+                            para_cov = p @ p.T if "lora_A" in n else p.T @ p
+                    else:
+                        para_cov = p @ p.T if "lora_A" in n else p.T @ p
                     I = torch.eye(*para_cov.size(), out=torch.empty_like(para_cov))  # noqa: E741
                     I.requires_grad = False
                     num_param += 1
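The second hunk guards AdaLoRA's orthogonality regularizer against ZeRO-3 partitioning: a sharded lora_A or lora_B weight shows up locally as an empty tensor of shape torch.Size([0]), so the covariance p @ p.T can only be computed after the full parameter has been gathered via gather_params_ctx (defined in src/peft/utils/integrations.py below). A rough sketch of that pattern, with an illustrative lora_covariance helper:

import torch

from peft.utils.integrations import gather_params_ctx


def lora_covariance(name: str, p: torch.nn.Parameter, fwd_module: torch.nn.Module) -> torch.Tensor:
    # Under DeepSpeed ZeRO-3 a partitioned parameter is only an empty placeholder locally.
    if p.shape == torch.Size([0]):
        with gather_params_ctx(p, fwd_module=fwd_module):
            # Inside the context the full weight is materialized on this rank.
            return p @ p.T if "lora_A" in name else p.T @ p
    # Non-partitioned case: compute the covariance directly.
    return p @ p.T if "lora_A" in name else p.T @ p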
src/peft/tuners/lora/layer.py (1 addition, 1 deletion)
@@ -180,7 +180,7 @@ def dora_init(self, adapter_name: str) -> None:
         lora_A = self.lora_A[adapter_name]
         lora_B = self.lora_B[adapter_name]
         scaling = self.scaling[adapter_name]
-        with gather_params_ctx(self.get_base_layer()):
+        with gather_params_ctx(self.get_base_layer().parameters()):
             weight = self.get_base_layer().weight
             quant_state = getattr(self.get_base_layer(), "state", None)
             weight = dequantize_bnb_weight(weight, state=quant_state)  # no-op if not bnb
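This one-line change follows from the new gather_params_ctx signature (see src/peft/utils/integrations.py below): DoRA initialization needs the full base-layer weight, so the helper is now handed the layer's parameters rather than the layer itself. A minimal sketch of the call pattern, with a plain nn.Linear standing in for the base layer:

import torch.nn as nn

from peft.utils.integrations import gather_params_ctx

base_layer = nn.Linear(16, 16)  # stand-in for the adapter's base layer

# Gathers weight and bias under DeepSpeed ZeRO-3; a no-op otherwise.
with gather_params_ctx(base_layer.parameters()):
    weight = base_layer.weight             # full, un-partitioned weight
    weight_norm = weight.norm(p=2, dim=1)  # per-output-channel norm, a DoRA-style magnitude (illustrative)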
src/peft/tuners/prompt_tuning/model.py (1 addition, 1 deletion)
@@ -80,7 +80,7 @@ def __init__(self, config, word_embeddings):
                 init_token_ids = init_token_ids * num_reps
             init_token_ids = init_token_ids[:total_virtual_tokens]
             init_token_ids = torch.LongTensor(init_token_ids).to(word_embeddings.weight.device)
-            with gather_params_ctx(word_embeddings):
+            with gather_params_ctx(word_embeddings.parameters()):
                 word_embedding_weights = word_embeddings(init_token_ids).detach().clone()
             word_embedding_weights = word_embedding_weights.to(torch.float32)
             self.embedding.weight = torch.nn.Parameter(word_embedding_weights)
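The same idea applies to text-based prompt-tuning initialization: under ZeRO-3 the word-embedding lookup only returns real values once the embedding weight has been gathered, so the module's parameters are passed to the helper. A hedged sketch with made-up vocabulary size, hidden size, and token ids:

import torch
import torch.nn as nn

from peft.utils.integrations import gather_params_ctx

word_embeddings = nn.Embedding(32000, 768)        # stand-in for the model's input embedding
init_token_ids = torch.tensor([101, 2023, 2003])  # illustrative token ids

with gather_params_ctx(word_embeddings.parameters()):
    # Inside the context the embedding weight is fully materialized, so the lookup works.
    init_weights = word_embeddings(init_token_ids).detach().clone()
init_weights = init_weights.to(torch.float32)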
src/peft/utils/integrations.py (2 additions, 3 deletions)
@@ -20,7 +20,7 @@


 @contextmanager
-def gather_params_ctx(module: torch.nn.Module, modifier_rank: int = 0):
+def gather_params_ctx(param, modifier_rank: int = 0, fwd_module: torch.nn.Module = None):
     """Call DeepSpeed GatheredParameters context manager if DeepSpeed is enabled, otherwise do nothing."""
     if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.33.0"):
         from transformers.integrations import is_deepspeed_zero3_enabled
@@ -33,8 +33,7 @@ def gather_params_ctx(module: torch.nn.Module, modifier_rank: int = 0):

     import deepspeed

-    params_to_gather = module.parameters()
-    with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=modifier_rank):
+    with deepspeed.zero.GatheredParameters(param, modifier_rank=modifier_rank, fwd_module=fwd_module):
         yield
     return

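After this refactor, gather_params_ctx accepts a single parameter or an iterable of parameters and forwards the optional fwd_module to deepspeed.zero.GatheredParameters, which DeepSpeed uses to register the gathered weights as external parameters of that module; everything remains a no-op when ZeRO-3 is not enabled. A brief usage sketch:

import torch.nn as nn

from peft.utils.integrations import gather_params_ctx

module = nn.Linear(8, 8)  # illustrative module
p = module.weight

# Single (possibly partitioned) parameter, registered against the module that will use it.
with gather_params_ctx(p, fwd_module=module):
    cov = p @ p.T

# Or an iterable of parameters, as the DoRA and prompt-tuning call sites above do.
with gather_params_ctx(module.parameters()):
    w = module.weight.detach().clone()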