LoraHelpers.py
import torch
import torch.nn as nn
import transformers
import pickle
class LoraModule(nn.Module):
    """Wraps a frozen nn.Linear and adds a trainable low-rank (LoRA) update.

    The forward pass computes original(x) + alpha * B(A(x)), where A and B
    are the rank-r linear maps held in ``self.lora_module``.
    """
    def __init__(self, orig_module: nn.Linear, r: int = 8, alpha: float = 1.0):
        assert isinstance(orig_module, nn.Linear), "Original module should be a Linear layer"
        assert isinstance(r, int), "r (rank) should be an integer"
        assert isinstance(alpha, (int, float)), "alpha should be a number"
        super().__init__()
        self.alpha = float(alpha)
        # Freeze the wrapped layer; only the LoRA factors are trained.
        self.original_module = orig_module
        self.original_module.requires_grad_(False)
        orig_in_features = orig_module.in_features
        orig_out_features = orig_module.out_features
        # Low-rank decomposition: in_features -> r -> out_features, no biases.
        self.lora_module = nn.Sequential(
            nn.Linear(orig_in_features, r, bias=False),
            nn.Linear(r, orig_out_features, bias=False),
        )
        # Standard LoRA initialisation: zero the up-projection so the wrapped
        # layer initially behaves exactly like the original one.
        nn.init.zeros_(self.lora_module[1].weight)

    def forward(self, x, *args, **kwargs):
        # Frozen base output plus the scaled low-rank correction.
        return self.original_module(x) + self.alpha * self.lora_module(x)

    def set_alpha(self, new_alpha: float):
        assert isinstance(new_alpha, (int, float)), "New alpha value must be a number"
        self.alpha = float(new_alpha)
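
# Illustrative usage sketch (not part of the original helpers): wrapping a single
# nn.Linear in a LoraModule. Because the up-projection is zero-initialised above,
# the wrapped layer initially reproduces the original layer's output exactly.
#
#     base = nn.Linear(16, 32)
#     wrapped = LoraModule(base, r=4, alpha=1.0)
#     x = torch.randn(2, 16)
#     assert torch.allclose(wrapped(x), base(x))
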
def convert_model_to_lora_model(model: transformers.modeling_utils.PreTrainedModel):
    """Freezes all existing parameters and wraps every nn.Linear in a LoraModule, in place."""
    for name, param in model.named_parameters():
        # Freeze the pretrained weights; only the injected LoRA factors will train.
        param.requires_grad = False
        if param.ndim == 1:
            # Keep 1-D parameters (biases, norm weights) in float32 for stability.
            param.data = param.data.to(torch.float32)
            continue
        # Walk from the model root to the module that owns this parameter.
        names = name.split('.')[:-1]
        module_pointer = model
        module_pointer_parent = None
        for layer in names:
            module_pointer_parent = module_pointer
            module_pointer = getattr(module_pointer, layer)
            if isinstance(module_pointer, LoraModule):
                break
        if isinstance(module_pointer, LoraModule):
            # Already wrapped (e.g. reached again via the same layer's bias parameter).
            continue
        if isinstance(module_pointer, nn.Linear):
            # Replace the Linear layer with a LoraModule that wraps it.
            lora_module = LoraModule(module_pointer)
            setattr(module_pointer_parent, names[-1], lora_module)
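
# Illustrative sketch (not part of the original helpers): after conversion, only
# the injected LoRA factors should remain trainable. `model` here stands for any
# module containing nn.Linear layers, e.g. a transformers PreTrainedModel; the
# type hint above is not enforced at runtime.
#
#     convert_model_to_lora_model(model)
#     trainable = [n for n, p in model.named_parameters() if p.requires_grad]
#     # Every entry should now contain 'lora_module'.
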
def change_lora_alpha(model: transformers.modeling_utils.PreTrainedModel, new_alpha: float):
    """Sets the LoRA scaling factor alpha on every LoraModule in the model."""
    for name, param in model.named_parameters():
        names = name.split('.')[:-1]
        module_pointer = model
        for layer in names:
            module_pointer = getattr(module_pointer, layer)
            if isinstance(module_pointer, LoraModule):
                break
        if isinstance(module_pointer, LoraModule):
            module_pointer.set_alpha(new_alpha)
def dump_lora_weights(model: transformers.modeling_utils.PreTrainedModel, filename: str):
    """Pickles the LoRA adapter (the low-rank nn.Sequential of each LoraModule) to a file."""
    lora_weights = {}
    for name, param in model.named_parameters():
        names = name.split('.')[:-1]
        module_pointer = model
        for layer in names:
            module_pointer = getattr(module_pointer, layer)
            if isinstance(module_pointer, LoraModule):
                break
        if isinstance(module_pointer, LoraModule):
            # Keyed by the full parameter name; load_lora_weights strips the
            # trailing components to recover the wrapped layer's path.
            lora_weights[name] = module_pointer.lora_module
    with open(filename, 'wb') as f:
        pickle.dump(lora_weights, f)
def load_lora_weights(model: transformers.modeling_utils.PreTrainedModel, filename: str):
    """Loads pickled LoRA adapters into a freshly loaded (not yet converted) model.

    Each still-unwrapped nn.Linear found at a saved path is wrapped in a LoraModule
    whose low-rank factors are replaced by the loaded ones; paths that no longer
    point at a plain Linear (e.g. already wrapped) are skipped.
    """
    with open(filename, 'rb') as f:
        lora_weights = pickle.load(f)
    for key in lora_weights:
        # Drop the last two components (e.g. '...original_module.weight') to get
        # the path of the Linear layer that was wrapped.
        names = key.split('.')[:-2]
        module_pointer = model
        module_pointer_parent = None
        for layer in names:
            module_pointer_parent = module_pointer
            module_pointer = getattr(module_pointer, layer)
        if isinstance(module_pointer, nn.Linear):
            # Wrap the Linear layer and swap in the saved low-rank factors.
            lora_module = LoraModule(module_pointer)
            lora_module.lora_module = lora_weights[key]
            setattr(module_pointer_parent, names[-1], lora_module)
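
# Minimal end-to-end sketch (not part of the original helpers). It uses a tiny
# toy module (ToyModel, defined here only for illustration) instead of a real
# transformers PreTrainedModel so that it runs without downloading weights;
# with a real model the calls are identical.
if __name__ == "__main__":
    import os
    import tempfile

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)
            self.norm = nn.LayerNorm(8)

        def forward(self, x):
            return self.norm(self.proj(x))

    model = ToyModel()
    convert_model_to_lora_model(model)   # wraps model.proj in a LoraModule
    change_lora_alpha(model, 0.5)        # adjust the LoRA scaling factor

    # Save the adapters, then load them into a fresh, unconverted copy.
    path = os.path.join(tempfile.mkdtemp(), "lora_weights.pkl")
    dump_lora_weights(model, path)
    fresh = ToyModel()
    load_lora_weights(fresh, path)
    print(type(fresh.proj))              # -> LoraModule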