Top k and temperature to prevent the same token repeating
jmaczan committed Jun 22, 2024
1 parent 06583df commit 2c75855
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions src/run.py
@@ -1,9 +1,12 @@
 import torch
+from torch.nn import functional as F
 from data_loader import get_tokenizer
 from train import load_checkpoint
 from gpt import GPT
 
 default_max_output=100
+default_temperature=1.0
+default_top_k=10
 
 def load_model(model_path, device):
     checkpoint = torch.load(model_path, map_location=device)
@@ -36,7 +39,7 @@ def prepare_context(text, tokenizer, context_window):
 
     return torch.tensor(tokens).unsqueeze(0)
 
-def inference(prompt, model, tokenizer, context_window, max_output=default_max_output):
+def inference(prompt, model, tokenizer, context_window, max_output=default_max_output, temperature=default_temperature, top_k=default_top_k):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     model.eval()
@@ -47,11 +50,15 @@ def inference(prompt, model, tokenizer, context_window, max_output=default_max_o
     with torch.no_grad():
         for _ in range(max_output):
             outputs = model(context)
-            next_token = outputs[0, -1, :]
-            next_token = torch.argmax(next_token, dim=-1).item()
-            output.append(next_token)
-
-            print(output)
+            next_token = outputs[0, -1, :] / temperature
+
+            top_k_logits, top_k_indices = torch.topk(next_token, top_k)
+            top_k_probs = F.softmax(top_k_logits, dim=-1)
+            next_token = torch.multinomial(top_k_probs, 1).item()
+
+            output.append(top_k_indices[next_token].item())
+
+            print(tokenizer.decode(output, skip_special_tokens=True))
 
             context = torch.cat([context, torch.tensor([[next_token]], device=device)], dim=1)
             context = context[:, -context_window:]
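For reference, here is a minimal, self-contained sketch of the same idea, temperature scaling followed by top-k sampling, assuming the logits arrive as a 1-D tensor over the vocabulary. The helper name sample_next_token and the random-logits example are illustrative only, not part of the repository.

import torch
from torch.nn import functional as F

def sample_next_token(logits, temperature=1.0, top_k=10):
    # Scale logits: temperature < 1 sharpens the distribution, > 1 flattens it
    scaled = logits / temperature
    # Keep only the top_k highest logits and remember their vocabulary indices
    top_k_logits, top_k_indices = torch.topk(scaled, top_k)
    # Normalize the kept logits into a probability distribution
    top_k_probs = F.softmax(top_k_logits, dim=-1)
    # Sample a position within the top-k set ...
    pos = torch.multinomial(top_k_probs, 1).item()
    # ... and map it back to the actual vocabulary id before using it downstream
    return top_k_indices[pos].item()

# Example with random logits over a 50-token vocabulary (no model required):
logits = torch.randn(50)
print(sample_next_token(logits, temperature=0.8, top_k=5))

Whatever id the sampler returns is what should be appended to both the generated output and the running context, so that the next forward pass conditions on the token that was actually chosen.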
