# minGPT; version dffb6a1 -- https://github.com/karpathy/minGPT/tree/feature/lightning
# The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy
"""
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a
self-attention block
- all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
Based on: upstream/feature/lightning
"""
import logging
import math
from typing import Tuple

import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch.nn import functional as F

from ...utils import force_assert

logger = logging.getLogger(__name__)


class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here, but an explicit
    implementation is included to show that there is nothing too scary going on.
    """

    def __init__(self, n_embd, block_size, n_head, attn_pdrop, resid_pdrop):
        super().__init__()
        force_assert(
            n_embd % n_head == 0,
            (
                f"(n_embd={n_embd}) % (n_head={n_head}) != 0. "
                "Please correct these model hyperparameters, as required by the self-attention "
                "transformer module. For more information, see "
                "https://pytorch.org/docs/stable/generated/"
                "torch.nn.MultiheadAttention.html#torch.nn.MultiheadAttention, or the paper "
                '"Attention is all you need" (https://arxiv.org/abs/1706.03762).'
            ),
        )
        self.n_head = n_head
        # key, query, value projections for all heads
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
        )

    def forward(self, x, layer_past=None):
        B, T, C = x.size()
        # calculate query, key, values for all heads in batch and move the head
        # dim forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = (
            self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        )  # (B, nh, T, hs)
        v = (
            self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        )  # (B, nh, T, hs)
        # causal self-attention; self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = (
            y.transpose(1, 2).contiguous().view(B, T, C)
        )  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_drop(self.proj(y))
        return y
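

# Illustrative shape walk-through (a sketch, not part of the module; the sizes
# are made up for the example): with B=2 sequences of T=8 tokens, C=32 channels,
# and n_head=4 heads, each head attends over hs = C // n_head = 8 dimensions.
#
#   attn = CausalSelfAttention(n_embd=32, block_size=8, n_head=4,
#                              attn_pdrop=0.1, resid_pdrop=0.1)
#   y = attn(torch.randn(2, 8, 32))  # -> torch.Size([2, 8, 32])
#
# The causal mask ensures position t only attends to positions <= t, so the
# first row of each (T, T) attention matrix has a single nonzero entry.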


class Block(nn.Module):
    """an unassuming Transformer block"""

    def __init__(self, n_embd, block_size, n_head, attn_pdrop, resid_pdrop):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, block_size, n_head, attn_pdrop, resid_pdrop)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
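

# A sketch of what one Block does (hypothetical sizes): the pre-LayerNorm
# residual form means the input flows through unchanged while each sub-layer
# only adds a correction, which keeps deep stacks trainable.
#
#   block = Block(n_embd=32, block_size=8, n_head=4, attn_pdrop=0.1, resid_pdrop=0.1)
#   out = block(torch.randn(2, 8, 32))  # same shape as the input: (2, 8, 32)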


class LightGPT(pl.LightningModule):
    """the full GPT language model, with a context size of block_size"""

    def __init__(
        self,
        # model definition args
        vocab_size: int,  # size of the vocabulary (number of possible tokens)
        block_size: int,  # length of the model's context window in time
        n_layer: int,  # depth of the model; number of Transformer blocks in sequence
        n_embd: int,  # the "width" of the model, number of channels in each Transformer
        n_head: int,  # number of heads in each multi-head attention inside each Transformer block
        # model optimization args
        learning_rate: float = 3e-4,  # the base learning rate of the model
        weight_decay: float = 0.1,  # amount of regularizing L2 weight decay on MatMul ops
        betas: Tuple[float, float] = (0.9, 0.95),  # momentum terms (betas) for the Adam optimizer
        embd_pdrop: float = 0.1,  # \in [0, 1]: amount of dropout on input embeddings
        resid_pdrop: float = 0.1,  # \in [0, 1]: amount of dropout in each residual connection
        attn_pdrop: float = 0.1,  # \in [0, 1]: amount of dropout on the attention matrix
    ):
        super().__init__()
        # save these for optimizer init later
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.betas = betas
        # input embedding stem: drop(content + position)
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(embd_pdrop)
        # deep transformer: just a sequence of transformer blocks
        self.blocks = nn.Sequential(
            *[Block(n_embd, block_size, n_head, attn_pdrop, resid_pdrop) for _ in range(n_layer)]
        )
        # decoder: one more layernorm at the end, then decode the answers
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(
            n_embd, vocab_size, bias=False
        )  # no need for an extra bias due to the one in ln_f
        self.block_size = block_size
        self.apply(self._init_weights)
        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        r"""
        Vanilla model initialization:
        - all MatMul weights drawn from \mathcal{N}(0, 0.02) and all biases set to zero
        - all LayerNorm post-normalization scaling set to the identity, so weight=1, bias=0
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
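
    # What _init_weights amounts to once wired through self.apply(...) above:
    # apply() visits every submodule recursively, so every nn.Linear and
    # nn.Embedding ends up with weight ~ Normal(0, std=0.02) (Linear biases
    # zeroed), and every nn.LayerNorm is reset to the identity transform.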

    def configure_optimizers(self):
        """
        This long function is unfortunately doing something very simple and is being very
        defensive: we are separating all parameters of the model into two buckets, those
        that will experience weight decay for regularization and those that won't (biases,
        and layernorm/embedding weights), and then returning the PyTorch optimizer object.
        """
        # separate out all parameters to those that will and won't experience regularizing
        # weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear,)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = f"{mn}.{pn}" if mn else pn  # full param name
                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)
        # special-case the position embedding parameter in the root GPT module as not decayed
        no_decay.add("pos_emb")
        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        force_assert(
            len(inter_params) == 0,
            f"parameters {inter_params} made it into both decay/no_decay sets!",
        )
        force_assert(
            len(param_dict.keys() - union_params) == 0,
            (
                f"parameters {param_dict.keys() - union_params} were not separated into either "
                "decay/no_decay set!"
            ),
        )
        # create the pytorch optimizer object
        optim_groups = [
            {
                "params": [param_dict[pn] for pn in sorted(list(decay))],
                "weight_decay": self.weight_decay,
            },
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=self.betas)
        return optimizer
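
    # Sketch of the resulting split for this model (derived from the module
    # tree above): 'head.weight' and the Linear weights inside each block
    # (e.g. 'blocks.0.attn.key.weight') land in the decayed group, while
    # 'pos_emb', 'tok_emb.weight', all biases, and LayerNorm weights get
    # weight_decay=0.0.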

    def forward(self, idx):
        b, t = idx.size()
        force_assert(t <= self.block_size, "Cannot forward, model block size is exhausted.")
        # forward the GPT model
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits
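
    # Minimal inference sketch (hypothetical sizes, not part of the module):
    #
    #   model = LightGPT(vocab_size=100, block_size=8, n_layer=2, n_embd=32, n_head=4)
    #   idx = torch.randint(0, 100, (2, 8))  # (B, T) integer token indices
    #   logits = model(idx)                  # (2, 8, 100): per-position vocab scores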

    def step_(self, split, batch, batch_idx=None):
        idx, targets = batch
        logits = self(idx)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        self.logger.experiment.add_scalars("loss", {split: loss}, self.global_step)
        return {"loss": loss}

    def training_step(self, *args, **kwargs):
        return self.step_("train", *args, **kwargs)

    def validation_step(self, *args, **kwargs):
        return self.step_("val", *args, **kwargs)

    def test_step(self, *args, **kwargs):
        return self.step_("test", *args, **kwargs)