Make example generation work when the model is torch.compile-d
IggShaman committed Aug 26, 2024
1 parent 6104ab1 commit dc054e5
Showing 1 changed file with 23 additions and 5 deletions.
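
The gist of the change, as I read it: torch.compile specializes compiled graphs on input shapes. The original sampling loop grew xgen by one column per step with torch.cat, so every decoding step presented a new sequence length and could trigger a recompile. This commit instead pads the prompt out to a fixed (B, T) buffer up front and writes each sampled token in place, with a cursor (at_idx) tracking the next free position, so the forward pass always sees one static shape. A minimal standalone sketch of the pattern, with illustrative names (sample_fixed_shape and buf are mine, not the script's) and greedy argmax standing in for the script's top-k sampling:

import torch

def sample_fixed_shape(model, prompt, B, T, max_length):
    # prompt: (B, t0) LongTensor of token ids, t0 <= max_length <= T
    # pre-allocate the full (B, T) buffer so the compiled forward
    # always sees the same input shape and never recompiles
    buf = torch.zeros(B, T, dtype=torch.long, device=prompt.device)
    buf[:, :prompt.size(1)] = prompt
    at_idx = prompt.size(1)  # cursor: next position to fill
    while at_idx < max_length:
        with torch.no_grad():
            logits = model(buf)  # (B, T, V); toy model returns logits only
        # read the prediction at the last *real* position, not at T-1
        next_tok = logits[:, at_idx - 1, :].argmax(dim=-1, keepdim=True)  # (B, 1)
        buf[:, at_idx:at_idx + 1] = next_tok  # in-place write instead of torch.cat
        at_idx += 1
    return buf[:, :max_length]

(Note that the x-axis padding block appears twice in the diff below; the second application computes pad_x_to = 0, so it is a harmless no-op.)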
train_gpt2.py: 28 changes (23 additions & 5 deletions)
@@ -339,7 +339,7 @@ def get_most_likely_row(tokens, mask, logits):
 model = GPT(GPTConfig(vocab_size=50304))
 # model = GPT.from_pretrained("gpt2") # or init from OpenAI GPT-2
 model.to(device)
-use_compile = False # torch.compile interferes with HellaSwag eval and Generation. TODO fix
+use_compile = True # torch.compile interferes with HellaSwag. TODO fix
 if use_compile:
     model = torch.compile(model)
 if ddp:
@@ -444,23 +444,38 @@ def get_lr(it):
                 f.write(f"{step} hella {acc_norm:.4f}\n")

     # once in a while generate from the model (except step 0, which is noise)
-    if ((step > 0 and step % 250 == 0) or last_step) and (not use_compile):
+    if True or (step > 0 and step % 250 == 0) or last_step:
         model.eval()
         num_return_sequences = 4
         max_length = 32
         tokens = enc.encode("Hello, I'm a language model,")
         tokens = torch.tensor(tokens, dtype=torch.long)
         tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
+        at_idx = tokens.size(1)
+        pad_x_to = (T if use_compile else max_length) - tokens.size(1)
+        tokens = torch.cat((tokens, torch.zeros([tokens.size(0), pad_x_to], dtype=torch.long)), dim=1)
+
+        pad_x_to = (T if use_compile else max_length) - tokens.size(1)
+        tokens = torch.cat((tokens, torch.zeros([tokens.size(0), pad_x_to], dtype=torch.long)), dim=1)
+
+        # pad the y axis
+        if B != num_return_sequences and use_compile:
+            # When B is smaller than the number of examples we generate, we can simply
+            # loop over the example generation code.
+            assert num_return_sequences <= B, \
+                f"TODO: {num_return_sequences=} is > {B=}; add support for that"
+            tokens = F.pad(tokens, (0, 0, 0, B - num_return_sequences), 'constant', enc.eot_token)
+
         xgen = tokens.to(device)
         sample_rng = torch.Generator(device=device)
         sample_rng.manual_seed(42 + ddp_rank)
-        while xgen.size(1) < max_length:
+        while at_idx < max_length:
             # forward the model to get the logits
             with torch.no_grad():
                 with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                     logits, loss = model(xgen) # (B, T, vocab_size)
                 # take the logits at the last position
-                logits = logits[:, -1, :] # (B, vocab_size)
+                logits = logits[:, at_idx-1, :] # (B, vocab_size)
                 # get the probabilities
                 probs = F.softmax(logits, dim=-1)
                 # do top-k sampling of 50 (huggingface pipeline default)
@@ -472,7 +487,10 @@ def get_lr(it):
                 # gather the corresponding indices
                 xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
                 # append to the sequence
-                xgen = torch.cat((xgen, xcol), dim=1)
+                xgen[:, at_idx:at_idx + 1] = xcol
+
+                at_idx += 1
+
         # print the generated text
         for i in range(num_return_sequences):
             tokens = xgen[i, :max_length].tolist()
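An alternative worth noting: instead of padding everything to a static (B, T), newer PyTorch releases can be told which dimension is expected to vary, so the compiler builds one dynamic-shape graph rather than recompiling per length. A sketch under that assumption (torch._dynamo.mark_dynamic exists in PyTorch 2.x, but I have not tested it against this script):

import torch

xgen = tokens.to(device)
# hint that dim 1 (the sequence length) will vary across calls,
# so torch.compile traces a dynamic-shape graph for it
torch._dynamo.mark_dynamic(xgen, 1)

The static-padding approach the commit takes trades that compiler dependence for a simple shape invariant, at the cost of running the forward pass over all T positions on every step.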
