# run_part2_a.py
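# Benchmarks weight pruning on a pretrained GPT-2 Medium model: for each sparsity
# level in the sweep below, the script briefly trains on WikiText with pruned
# weights/gradients, then measures validation loss and tokens/sec, logging
# everything to a timestamped part2a_*.log file.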
import os
from contextlib import nullcontext
import numpy as np
import time
import torch
import json
import argparse
from model import GPT
from utils import L2PruningHandler
import tiktoken
from tqdm import tqdm
# -----------------------------------------------------------------------------
prune_method = 'l2norm' # 'l2norm' or 'individual'
# -----------------------------------------------------------------------------
batch_size = 8
block_size = 1024
iterations = 100
real_data = True
seed = 1337
gradient_accumulation_steps = 40
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
profile = False # use pytorch profiler, or just simple benchmarking?
exec(open('configurator.py').read()) # overrides from command line or config file
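# e.g. `python run_part2_a.py --prune_method=individual --batch_size=4`
# (assuming the standard nanoGPT configurator.py, which parses --key=value overrides)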
# -----------------------------------------------------------------------------
curr_time = time.strftime("%m%d-%H%M%S")
log_file = open(f'part2a_{prune_method}_{curr_time}.log', 'a')
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
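# mixed-precision autocast context for the forward passes; note that no GradScaler
# is used, so on GPUs without bf16 support the fp16 backward pass may underflow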
# data loading init
if real_data:
    dataset = 'wikitext'
    data_dir = os.path.join('data', dataset)
    train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    start = 0
    def get_batch(split="train"):
        global start
        if split == "train":
            data = train_data
        elif split == "val":
            data = val_data
        else:
            raise ValueError(f"Invalid split: {split}")
        if split == "train":
            ix = torch.randint(len(data) - block_size, (batch_size,))
        elif split == "val":
            ix = torch.arange(start, start + batch_size * block_size, block_size)
            start += batch_size * block_size
        else:
            raise ValueError(f"Invalid split: {split}")
        x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
        y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
        return x, y
else:
    # alternatively, use fixed random data to take data loading out of the measurement
    x = torch.randint(50304, (batch_size, block_size), device=device)
    y = torch.randint(50304, (batch_size, block_size), device=device)
    get_batch = lambda split: (x, y)
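# Training batches are random block_size windows; validation batches walk the
# val split sequentially in block_size strides, advancing the global `start`
# cursor (which is reset to 0 before each validation pass).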
model = GPT.from_pretrained("gpt2-medium")
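# L2PruningHandler comes from the local utils module; its internals are not shown
# here, but it presumably sets up the bookkeeping (masks/hooks) that
# model.prune_weight / prune_grad / mask_weight rely on for l2norm pruning.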
if prune_method == 'l2norm':
    handler = L2PruningHandler(model)
    handler.handle()
val_steps = (len(val_data)-1) // (batch_size * block_size) - 1
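# number of sequential val batches that fit in the val split; the -1 terms leave
# headroom so the shifted targets never index past the end of the memmap
# (note: this line assumes real_data=True, since val_data only exists in that branch)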
model.to(device)
optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type)
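# Sweep sparsity from 1.0 down to ~0.1 in steps of 0.1 (np.arange excludes the 0.0
# endpoint; exact values are subject to float rounding). At each level the model is
# re-pruned, briefly trained, and then evaluated with its weights masked.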
for sparsity in np.arange(1.0, 0.0, -0.1):
    train_loss = []
    model.train()
    model.prune_weight(sparsity=sparsity, method=prune_method)
    prev_time = time.time()
    for num_steps in tqdm(range(iterations)):
        optimizer.zero_grad(set_to_none=True)
        for _ in range(gradient_accumulation_steps):
            X, Y = get_batch('train')
            with ctx:
                logits, loss = model(X, Y)
            loss.backward()
        train_loss.append(loss.item())
        model.prune_grad(method=prune_method)
        optimizer.step()
    curr_time = time.time()
    print(f"sparsity: {sparsity:.2f}, training time per iteration: {(curr_time - prev_time) / iterations:.4f}", file=log_file, flush=True)
    lossf = np.mean(train_loss)
    print(f"sparsity: {sparsity:.2f}, training loss: {lossf:.4f}", file=log_file, flush=True)
    # validation
    model.eval()
    val_loss = []
    time_lst = []
    with torch.no_grad():
        start = 0
        # precompute weight for l2norm
        model.mask_weight()
        torch.cuda.synchronize()
        for k in tqdm(range(val_steps)):
            X, Y = get_batch('val')
            prev_time = time.time()
            with ctx:
                logits, loss = model(X, Y)
            curr_time = time.time()
            time_lst.append(curr_time - prev_time)
            val_loss.append(loss.item())
    val_lossf = np.mean(val_loss)
    num_token_per_sec = batch_size * block_size / np.mean(time_lst)
    print(f"sparsity: {sparsity:.2f}, validation number of tokens per second: {num_token_per_sec:.4f}", file=log_file, flush=True)
    print(f"sparsity: {sparsity:.2f}, validation loss: {val_lossf:.4f}", file=log_file, flush=True)
log_file.close()