-
Notifications
You must be signed in to change notification settings - Fork 38
/
sparsify_ckpt_unsharded.py
70 lines (66 loc) · 3.13 KB
/
sparsify_ckpt_unsharded.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
1. Unshard ckpt using `python /home/niklas/OLMoE/scripts/unshard.py /data/niklas/llm/checkpoints/23485/step954000 /data/niklas/llm/checkpoints/1b-954000-unsharded --safe-tensors --model-only`
2. Run this script via `python /home/niklas/OLMoE/scripts/sparsify_ckpt_unsharded.py /data/niklas/llm/checkpoints/1b-954000-unsharded/model.safetensors`
"""
import copy
import sys
import torch
from olmo.safetensors_util import safetensors_file_to_state_dict, state_dict_to_safetensors_file
path = sys.argv[1]
sd = safetensors_file_to_state_dict(path)
tensors = {}
swiglu = True
noise = False
share = False
interleave = False
n_experts = 8
D = 2048
def noise_injection(weight, noise_ratio=0.5, init_std=0.02):
    """Re-randomize a random subset of ``weight``'s entries, in place.

    Each entry is independently selected with probability ``noise_ratio`` and
    replaced by a fresh draw from N(0, ``init_std``**2); all other entries are
    left untouched.  Used to break the symmetry between experts that start out
    as identical copies of the same dense weights.

    Args:
        weight: tensor to perturb (modified in place).
        noise_ratio: probability that any given entry is replaced.
        init_std: standard deviation of the replacement noise.

    Returns:
        The same ``weight`` tensor, for call-site convenience.
    """
    # rand_like/empty_like allocate directly on weight's device with weight's
    # dtype, avoiding the CPU float32 round-trip of torch.FloatTensor(...) and
    # the full deepcopy the original used just to obtain noise storage.
    mask = torch.rand_like(weight) < noise_ratio
    rand_weight = torch.nn.init.normal_(torch.empty_like(weight), mean=0.0, std=init_std)
    weight[mask] = rand_weight[mask]
    return weight
# Convert each transformer block's dense MLP weights into n_experts-way MoE
# keys.  Iterate over a snapshot of the keys because entries are popped as we
# go; any key that is not an MLP weight is copied through unchanged.
for key in list(sd.keys()):
    if "ff_proj.weight" in key:
        # Dense up/gate projection, e.g. transformer.blocks.<i>.ff_proj.weight;
        # the block index is the third dot-separated component.
        block_num = int(key.split(".")[2])
        if interleave and block_num % 2 == 0:
            # Interleaved mode: even-numbered blocks stay dense.
            tensors[key] = sd.pop(key)
            continue
        new_key = key.replace("ff_proj.weight", "ffn.experts.mlp.w1")
        if swiglu:
            new_key_v1 = new_key.replace("w1", "v1")
            # OLMo takes the F.silu on the second part of the tensor which corresponds to v1
            v1, w1 = sd.pop(key).chunk(2, dim=0) # e.g. [16384, 2048]
            # Every expert starts as an identical copy of the dense weights,
            # concatenated along the output (row) dimension.
            tensors[new_key] = torch.cat([w1] * n_experts, dim=0)
            tensors[new_key_v1] = torch.cat([v1] * n_experts, dim=0)
            if noise:
                # Break expert symmetry by re-randomizing a fraction of entries.
                tensors[new_key] = noise_injection(tensors[new_key])
                tensors[new_key_v1] = noise_injection(tensors[new_key_v1])
            if share:
                # Keep one un-replicated copy of the dense MLP as a shared expert.
                share_key = new_key.replace("experts.mlp.w1", "shared_expert.up_proj.weight")
                share_key_v1 = new_key_v1.replace("experts.mlp.v1", "shared_expert.gate_proj.weight")
                tensors[share_key] = w1
                tensors[share_key_v1] = v1
        else:
            # Non-SwiGLU MLP: replicate the whole projection as-is.
            tensors[new_key] = torch.cat([sd.pop(key)] * n_experts, dim=0)
    elif ("ff_out.weight" in key) and (key != 'transformer.ff_out.weight'):
        # Dense down projection.  The model-level output head also matches the
        # "ff_out.weight" substring, so it is explicitly excluded here.
        block_num = int(key.split(".")[2])
        if interleave and block_num % 2 == 0:
            tensors[key] = sd.pop(key)
            continue
        new_key = key.replace("ff_out.weight", "ffn.experts.mlp.w2")
        w = sd.pop(key)
        # w2 is stored transposed relative to the dense ff_out weight.
        tensors[new_key] = torch.cat([w.t()] * n_experts, dim=0)
        if noise:
            tensors[new_key] = noise_injection(tensors[new_key])
        if share:
            share_key = new_key.replace("experts.mlp.w2", "shared_expert.down_proj.weight")
            tensors[share_key] = w
        # Add router
        router_key = key.replace("ff_out.weight", "ffn.router.layer.weight")
        # tensors[router_key] = torch.ones((n_experts, D)).squeeze() # Worse perf
        # NOTE(review): .squeeze() is a no-op on an (n_experts, D) tensor when
        # n_experts > 1 — presumably left over from a single-expert experiment.
        tensors[router_key] = torch.nn.init.normal_(torch.ones((n_experts, D)).squeeze(), std=0.02)
    else:
        # Embeddings, attention, norms, etc.: copy through unchanged.
        tensors[key] = sd.pop(key)
# Write the converted checkpoint next to the input file.
state_dict_to_safetensors_file(tensors, path.replace("model.safetensors", "model_sparse.safetensors"))