-
Notifications
You must be signed in to change notification settings - Fork 1
/
evo_prune_search.py
343 lines (315 loc) · 13.2 KB
/
evo_prune_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
import argparse
import random
import copy
import os
from tqdm import trange
from typing import List, Optional
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
try:
import wandb
has_wandb = True
except ModuleNotFoundError:
has_wandb = False
from src.data_utils import get_data
from src.common_utils import fix_seed
from src.metrics import compute_perplexity, compute_kl_div
def load_layers(
    model: "AutoModelForCausalLM",
    layer_names: List[str],
    new_state: List[int],
    sparse_weights_path: str,
):
    """Swap in pre-computed sparse weights for every layer whose level changed.

    ``model.state`` records the level currently loaded into each layer, in the
    same order as ``layer_names``. Only layers whose level in ``new_state``
    differs from ``model.state`` are reloaded, from
    ``<sparse_weights_path>/<layer_name>/<level>.pth``.

    Args:
        model: Model whose submodules are addressed by ``layer_names``; it must
            carry a ``state`` attribute (list of currently loaded levels).
        layer_names: Submodule names, one per searchable layer.
        new_state: Target level for each entry of ``layer_names``.
        sparse_weights_path: Root directory holding the per-layer weight files.

    Raises:
        AttributeError: If ``model`` has no ``state`` attribute.
        ValueError: If ``new_state`` or ``model.state`` does not have exactly one
            entry per layer (``zip`` would otherwise silently skip layers).
    """
    if not hasattr(model, "state"):
        raise AttributeError("model.state must be initialised (one entry per layer) before calling load_layers")
    num_layers = len(layer_names)
    if len(new_state) != num_layers or len(model.state) != num_layers:
        raise ValueError(
            f"Expected {num_layers} levels, got new_state={len(new_state)}, model.state={len(model.state)}"
        )
    for layer_name, new_level, old_level in zip(layer_names, new_state, model.state):
        if new_level == old_level:
            continue
        layer = model.get_submodule(layer_name)
        # NOTE: torch.load is pickle-based; only point sparse_weights_path at trusted files.
        layer.weight.data = torch.load(
            os.path.join(sparse_weights_path, layer_name, f"{new_level}.pth"),
            map_location=layer.weight.device,
        ).to(layer.weight.dtype)
    # Keep a private copy so later mutation of the caller's list cannot make
    # model.state disagree with the weights that are actually loaded.
    model.state = list(new_state)
def compute_fitness(
    model,
    data,
    fitness_fn,
    target_logits: Optional[List[torch.Tensor]] = None,
) -> float:
    """Score ``model`` on ``data`` with the selected fitness metric.

    The search keeps candidates with the lowest score, so lower is better.

    Args:
        model: Model to evaluate.
        data: Calibration samples to score on.
        fitness_fn: ``"ppl"`` (perplexity) or ``"kl"`` (KL divergence against
            ``target_logits``).
        target_logits: Reference logits aligned with ``data``; used only for ``"kl"``.

    Raises:
        ValueError: If ``fitness_fn`` is neither ``"ppl"`` nor ``"kl"`` (previously
            any unknown value silently fell through to the KL branch).
    """
    if fitness_fn == "ppl":
        return compute_perplexity(model, data)
    if fitness_fn == "kl":
        return compute_kl_div(model, data, target_logits)
    raise ValueError(f"Unknown fitness_fn {fitness_fn!r}; expected 'ppl' or 'kl'")
def selection(
    model,
    layer_names,
    sparse_weights_path: str,
    candidates,
    num_survive: int,
    calibration_data,
    num_tokens: int,
    fitness_fn: str = "ppl",
    target_logits: Optional[List[torch.Tensor]] = None,
):
    """Score every candidate on one random calibration minibatch and keep the best.

    The minibatch is built from distinct, randomly drawn calibration samples until
    it holds exactly ``num_tokens`` tokens (the last sample is truncated to fit).
    If the whole calibration set holds fewer than ``num_tokens`` tokens, every
    sample is used. All candidates are scored on the same minibatch.

    Returns:
        ``(survivors, fitnesses)``: the ``num_survive`` candidates with the lowest
        fitness, in ascending order of fitness, and their fitness values.

    Raises:
        ValueError: If ``calibration_data`` is empty, or ``fitness_fn == "kl"``
            and ``target_logits`` is not given.
    """
    if len(calibration_data) == 0:
        raise ValueError("calibration_data is empty")
    if fitness_fn == "kl" and target_logits is None:
        raise ValueError("target_logits are required when fitness_fn == 'kl'")

    calibration_minibatch = []
    target_logits_minibatch = []
    seen_ids = set()
    tokens_used = 0
    while tokens_used < num_tokens:
        # Without this guard the rejection loop below spins forever once every
        # sample has been drawn and num_tokens still has not been reached.
        if len(seen_ids) == len(calibration_data):
            break
        minibatch_id = random.randint(0, len(calibration_data) - 1)
        if minibatch_id in seen_ids:  # sample without replacement
            continue
        seen_ids.add(minibatch_id)
        sample = calibration_data[minibatch_id]
        remaining = num_tokens - tokens_used
        if sample.shape[1] > remaining:
            # Truncate the final sample so the minibatch holds exactly num_tokens tokens.
            calibration_minibatch.append(sample[:, :remaining])
            if fitness_fn == "kl":
                target_logits_minibatch.append(target_logits[minibatch_id][:, :remaining])
            tokens_used = num_tokens
        else:
            calibration_minibatch.append(sample)
            if fitness_fn == "kl":
                target_logits_minibatch.append(target_logits[minibatch_id])
            tokens_used += sample.shape[1]
    if len(target_logits_minibatch) == 0:
        target_logits_minibatch = None

    fitnesses = []
    for candidate in candidates:
        load_layers(model, layer_names, candidate, sparse_weights_path)
        fitnesses.append(compute_fitness(model, calibration_minibatch, fitness_fn, target_logits_minibatch))
    # Keep only the best (lowest fitness) candidates.
    best_ids = np.argsort(fitnesses)[:num_survive]
    return [candidates[i] for i in best_ids], [fitnesses[i] for i in best_ids]
def parse_args():
    """Parse and validate the command-line arguments of the evolutionary search.

    Besides argparse's own checks, this rejects the following configurations,
    which would otherwise fail or misbehave later:
    - survivor and token schedules of different lengths (``zip`` would silently
      drop stages);
    - a survivor count below 1 (no candidate would be left to pick);
    - ``--eval_every`` below 1 (modulo by zero).
    """
    parser = argparse.ArgumentParser()
    # Model params
    parser.add_argument(
        "--model_name_or_path",
        required=True,
        type=str,
        help="The name or path to the model being pruned",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="The name or path to the tokenizer. By default use model tokenizer.",
    )
    # Data params
    parser.add_argument(
        "--calibration_data",
        type=str,
        required=True,
        help="The name of the dataset or path used for calibration.",
    )
    parser.add_argument("--calibration_tokens", required=True, type=int, help="Number of tokens for calibration.")
    parser.add_argument(
        "--calibration_sequence_length", default=None, type=int, help="Length of calibration sequences."
    )
    parser.add_argument(
        "--eval_datasets",
        nargs="+",
        type=str,
        default=["fineweb_edu", "wikitext2", "c4"],
        help="Datasets used for evaluation",
    )
    parser.add_argument("--eval_every", default=1, type=int, help="Eval every # generations.")
    parser.add_argument("--eval_tokens", default=524288, type=int, help="Number of tokens for evaluation.")
    parser.add_argument("--eval_sequence_length", default=None, type=int, help="Length of evaluation sequences.")
    parser.add_argument("--fitness_fn", choices=["ppl", "kl"], default="kl", help="Fitness function.")
    parser.add_argument("--max_level", default=99999, type=int, help="Max admissible level.")
    parser.add_argument(
        "--max_total_deviation",
        default=99999,
        type=int,
        help="Max admissible total deviation (sum of absolute differences to uniform pruning).",
    )
    # Logging params
    parser.add_argument("--log_wandb", default=False, action="store_true", help="Whether to log to W&B")
    # Evolutionary Search params
    parser.add_argument("--generations", type=int, required=True, help="Number of generations in evolutionary search")
    parser.add_argument("--offspring", type=int, required=True, help="Number of offspring generated per parent")
    parser.add_argument("--sparse_weights_path", type=str, required=True, help="Path to sparse weights")
    parser.add_argument(
        "--survivors_per_selection",
        type=int,
        nargs="+",
        required=True,
        help="Number of survivors after each stage of selection",
    )
    parser.add_argument(
        "--tokens_per_selection",
        type=int,
        nargs="+",
        required=True,
        help="Number of calibration tokens at each stage of selection",
    )
    # Misc params
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "float32", "bfloat16"],
        help="dtype to load the model.",
    )
    parser.add_argument("--seed", default=0, type=int, help="Random seed.")
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default=None,
        choices=["eager", "sdpa", "flash_attention_2"],
        help="Attention implementation for both teacher and student models: eager, sdpa, or flash_attention_2",
    )
    parser.add_argument(
        "--memory_efficient", action="store_true", help="Whether to use memory efficient implementation."
    )
    parser.add_argument("--use_fast_tokenizer", action="store_true", help="Whether to use fast tokenizer.")
    # Save params
    parser.add_argument(
        "--configuration_name", type=str, default="final_configuration.txt", help="Name of final configuration"
    )
    args = parser.parse_args()

    # Cross-argument validation (parser.error prints usage and exits with status 2).
    if len(args.survivors_per_selection) != len(args.tokens_per_selection):
        parser.error(
            "--survivors_per_selection and --tokens_per_selection must have the same number of values"
        )
    if any(num_survive < 1 for num_survive in args.survivors_per_selection):
        parser.error("--survivors_per_selection values must be >= 1")
    if args.eval_every < 1:
        parser.error("--eval_every must be >= 1")
    return args
def _pick_layer(levels, layer_names, sparse_weights_path, max_level, step):
    """Randomly draw a layer index whose level can move by ``step`` (rejection sampling).

    A move is admissible when ``abs(level + step) <= max_level`` and the weight file
    for the new level exists. NOTE: loops forever if no layer admits the move.
    """
    while True:
        layer_id = random.randint(0, len(levels) - 1)
        new_level = levels[layer_id] + step
        if abs(new_level) > max_level:
            continue
        if os.path.exists(os.path.join(sparse_weights_path, layer_names[layer_id], f"{new_level}.pth")):
            return layer_id


def _mutate(parent, layer_names, sparse_weights_path, max_level):
    """Return a copy of ``parent`` with 1-3 paired level moves applied.

    Each move lowers one layer's level by 1 and raises another's by 1, so the sum
    of levels is preserved. Both indices may coincide, in which case the move is a no-op.
    """
    offspring = list(parent)
    num_flips = min(random.randint(1, 3), random.randint(1, 3))  # bias towards lower values
    for _ in range(num_flips):
        # Both indices are chosen against the levels before this move is applied.
        decr_id = _pick_layer(offspring, layer_names, sparse_weights_path, max_level, step=-1)
        incr_id = _pick_layer(offspring, layer_names, sparse_weights_path, max_level, step=1)
        offspring[decr_id] -= 1
        offspring[incr_id] += 1
    return offspring


def _generate_offspring(parent, layer_names, args):
    """Build ``args.offspring`` distinct mutants of ``parent`` within the deviation budget.

    NOTE: loops forever if fewer than ``args.offspring`` admissible distinct mutants exist.
    """
    offspring_list = []
    while len(offspring_list) < args.offspring:
        offspring = _mutate(parent, layer_names, args.sparse_weights_path, args.max_level)
        if offspring in offspring_list:  # avoid duplicates
            continue
        # skip if total deviation from uniform (all-zero) levels exceeds the threshold
        if sum(map(abs, offspring)) > args.max_total_deviation:
            continue
        offspring_list.append(offspring)
    return offspring_list


def _evaluate(model, eval_names, eval_datasets, calibration_data, log_dict, only=None):
    """Print and record perplexity on the eval datasets and on the calibration data.

    If ``only`` is given, eval datasets whose name is not in it are skipped.
    Results are written into ``log_dict`` under ``ppl_eval/<name>`` and ``ppl_train``.
    """
    for eval_dataset_name, eval_dataset in zip(eval_names, eval_datasets):
        if only is not None and eval_dataset_name not in only:
            continue
        ppl_eval = compute_perplexity(model, eval_dataset)
        print(f"{eval_dataset_name}: {ppl_eval:.2f}")
        log_dict[f"ppl_eval/{eval_dataset_name}"] = ppl_eval
    ppl_train = compute_perplexity(model, calibration_data)
    print(f"ppl_train: {ppl_train:.2f}")
    log_dict["ppl_train"] = ppl_train


def main():
    """Run the evolutionary search over per-layer sparsity levels.

    Each generation mutates the current configuration (``parent``), runs the
    multi-stage selection schedule on the offspring plus the parent, and keeps
    the winner. The final configuration is written to
    ``<sparse_weights_path>/<configuration_name>`` and evaluated on all eval datasets.
    """
    args = parse_args()
    # Fix seed
    fix_seed(args.seed)
    # Init W&B logger
    if args.log_wandb:
        if not has_wandb:
            raise ModuleNotFoundError("`wandb` not installed, try pip install `wandb`")
        wandb.init(config=args)
    device = "cuda"
    if args.dtype != "auto":
        args.dtype = getattr(torch, args.dtype)
    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        device_map=None if args.memory_efficient else "auto",
        low_cpu_mem_usage=True,
        torch_dtype=args.dtype,
        attn_implementation=args.attn_implementation,
    )
    model.config.use_cache = False  # do not use cache
    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name or args.model_name_or_path, use_fast=args.use_fast_tokenizer
    )
    default_sequence_length = min(model.config.max_position_embeddings, 8192)
    # Load calibration data
    args.calibration_sequence_length = args.calibration_sequence_length or default_sequence_length
    calibration_data = get_data(
        args.calibration_data,
        args.calibration_tokens,
        args.calibration_sequence_length,
        tokenizer,
        train=True,
    )
    # Load eval datasets
    args.eval_sequence_length = args.eval_sequence_length or default_sequence_length
    eval_datasets = [
        get_data(
            eval_dataset_name,
            args.eval_tokens,  # ignored for WikiText2 and C4
            args.eval_sequence_length,
            tokenizer,
            train=False,
        )
        for eval_dataset_name in args.eval_datasets
    ]
    target_logits = []
    if args.fitness_fn == "kl":
        # Logits of the model before any sparse weights are swapped in, used as the KL target.
        # NOTE(review): with --memory_efficient no device_map is passed, so the model presumably
        # stays on CPU while inputs are moved to "cuda" — confirm this path works.
        with torch.no_grad():
            for i in trange(0, len(calibration_data), desc="Computing target logits (calib)", leave=False):
                target_logits.append(model(calibration_data[i].to(device)).logits.cpu())
    # Every sub-directory of sparse_weights_path is one searchable layer.
    layer_names = sorted(
        layer_name
        for layer_name in os.listdir(args.sparse_weights_path)
        if os.path.isdir(os.path.join(args.sparse_weights_path, layer_name))
    )
    parent = [0] * len(layer_names)
    # None never equals a level, so the first load_layers call loads every layer.
    model.state = [None] * len(layer_names)
    train_fitness = float("inf")
    log_dict = {}
    stages = list(zip(args.survivors_per_selection, args.tokens_per_selection))
    for generation in range(args.generations):
        print(f"Generation {generation + 1}/{args.generations}")
        print(f"Current search point: {parent}")
        print(f"Train fitness: {train_fitness:.2e}")
        load_layers(model, layer_names, parent, args.sparse_weights_path)
        # Evaluate current search point (intermediate evals only use fineweb_edu)
        if generation % args.eval_every == 0:
            _evaluate(model, args.eval_datasets, eval_datasets, calibration_data, log_dict, only=("fineweb_edu",))
        if args.log_wandb:
            wandb.log(log_dict)

        offspring_list = _generate_offspring(parent, layer_names, args)
        for stage_idx, (num_survive, num_tokens) in enumerate(stages):
            # Elitist EA: the parent competes in the final stage only. The stage is
            # identified by index; comparing survivor counts also matched earlier
            # stages that happened to use the same count.
            if stage_idx == len(stages) - 1 and parent not in offspring_list:
                offspring_list.append(parent)
            offspring_list, train_fitnesses = selection(
                model=model,
                layer_names=layer_names,
                sparse_weights_path=args.sparse_weights_path,
                candidates=offspring_list,
                num_survive=num_survive,
                calibration_data=calibration_data,
                num_tokens=num_tokens,
                fitness_fn=args.fitness_fn,
                target_logits=target_logits,
            )
        # selection returns candidates sorted by fitness, so index 0 is the winner.
        train_fitness = train_fitnesses[0]
        parent = offspring_list[0]
        print(f"Train fitnesses: {train_fitness:.2e}")
        log_dict["train_fitness"] = train_fitness
    # Save final configuration
    with open(os.path.join(args.sparse_weights_path, args.configuration_name), "w") as f:
        f.write("\n".join(f"{layer_name}: {level}" for layer_name, level in zip(layer_names, parent)))
    # Log final configuration
    print("Final configuration:")
    print(parent)
    # selection() leaves whichever candidate it scored last loaded, which is not
    # necessarily the winner; load the final configuration before evaluating it.
    load_layers(model, layer_names, parent, args.sparse_weights_path)
    # Final evaluation
    _evaluate(model, args.eval_datasets, eval_datasets, calibration_data, log_dict)
    if args.log_wandb:
        wandb.log(log_dict)


if __name__ == "__main__":
    main()