Skip to content

Commit

Permalink
Count token using github.com/pkoukk/tiktoken-go
Browse files Browse the repository at this point in the history
  • Loading branch information
leafduo committed May 4, 2023
1 parent 57d0e30 commit f40c69b
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 84 deletions.
65 changes: 32 additions & 33 deletions count_token.go
Original file line number Diff line number Diff line change
@@ -1,43 +1,42 @@
package main

import (
"log"
"sync"
"fmt"

tokenizer "github.com/samber/go-gpt-3-encoder"
"github.com/pkoukk/tiktoken-go"
"github.com/sashabaranov/go-openai"
)

var encoder *tokenizer.Encoder
var once sync.Once

func CountToken(msg string) (int, error) {
once.Do(func() {
var err error
encoder, err = tokenizer.NewEncoder()
if err != nil {
log.Fatal(err)
}
})

/**
The exact algorithm for counting tokens has been documented at
https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb. However, we
are not using that algorithm in this project because no Go library currently implements the ChatGPT algorithm. The
Go library github.com/samber/go-gpt-3-encoder does implement the encoding used by GPT-3(p50k_base), but not
ChatGPT(cl100k_base). Additionally, counting tokens is not a critical part of this project; we only require a rough
estimation of the token count.
Based on my naïve experiments, the token count is not significantly different when English text is tokenized.
However, there is a significant difference when tokenizing Chinese or Japanese text. cl100k_base generates far fewer
tokens than p50k_base when tokenizing Chinese or Japanese. Most Chinese characters are counted as 1 token in
cl100k_base, whereas in p50k_base, they are mostly counted as 2 tokens.
*/

encoded, err := encoder.Encode(msg)
func NumTokensFromMessages(messages []openai.ChatCompletionMessage, model string) (num_tokens int) {
tkm, err := tiktoken.EncodingForModel(model)
if err != nil {
return 0, err
err = fmt.Errorf("EncodingForModel: %v", err)
fmt.Println(err)
return
}

// 4 is the number of tokens added by the encoder
return len(encoded) + 4, nil
var tokens_per_message int
var tokens_per_name int
if model == "gpt-3.5-turbo-0301" || model == "gpt-3.5-turbo" {
tokens_per_message = 4
tokens_per_name = -1
} else if model == "gpt-4-0314" || model == "gpt-4" {
tokens_per_message = 3
tokens_per_name = 1
} else {
fmt.Println("Warning: model not found. Using cl100k_base encoding.")
tokens_per_message = 3
tokens_per_name = 1
}

for _, message := range messages {
num_tokens += tokens_per_message
num_tokens += len(tkm.Encode(message.Content, nil, nil))
num_tokens += len(tkm.Encode(message.Role, nil, nil))
if message.Name != "" {
num_tokens += tokens_per_name
}
}
num_tokens += 3
return num_tokens
}
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ require (
github.com/sourcegraph/conc v0.3.0
)

require github.com/google/uuid v1.3.0 // indirect

require (
github.com/dlclark/regexp2 v1.7.0 // indirect
github.com/dlclark/regexp2 v1.8.1 // indirect
github.com/pkoukk/tiktoken-go v0.1.1
github.com/samber/lo v1.37.0 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
)
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@ github.com/caarlos0/env/v7 v7.0.0 h1:cyczlTd/zREwSr9ch/mwaDl7Hse7kJuUY8hvHfXu5WI
github.com/caarlos0/env/v7 v7.0.0/go.mod h1:LPPWniDUq4JaO6Q41vtlyikhMknqymCLBw0eX4dcH1E=
github.com/dlclark/regexp2 v1.7.0 h1:7lJfhqlPssTb1WQx4yvTHN0uElPEv52sbaECrAQxjAo=
github.com/dlclark/regexp2 v1.7.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc=
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo=
github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/samber/go-gpt-3-encoder v0.3.1 h1:YWb9GsGYUgSX/wPtsEHjyNGRQXsQ9vDCg9SU2x9uMeU=
github.com/samber/go-gpt-3-encoder v0.3.1/go.mod h1:27nvdvk9ZtALyNtgs9JsPCMYja0Eleow/XzgjqwRtLU=
github.com/samber/lo v1.37.0 h1:XjVcB8g6tgUp8rsPsJ2CvhClfImrpL04YpQHXeHPhRw=
Expand Down
60 changes: 13 additions & 47 deletions gpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,7 @@ type GPT struct {
type UserState struct {
TelegramID int64
LastActiveTime time.Time
HistoryMessage []Message
}

type Message struct {
Role string
Content string
TokenCount int
}

func convertMessageToChatCompletionMessage(msg []Message) []openai.ChatCompletionMessage {
var result []openai.ChatCompletionMessage
for _, m := range msg {
result = append(result, openai.ChatCompletionMessage{
Role: m.Role,
Content: m.Content,
})
}
return result
HistoryMessage []openai.ChatCompletionMessage
}

func NewGPT() *GPT {
Expand Down Expand Up @@ -71,36 +54,33 @@ func (gpt *GPT) SendMessage(userID int64, msg string, answerChan chan<- string)
gpt.userState[userID] = &UserState{
TelegramID: userID,
LastActiveTime: time.Now(),
HistoryMessage: []Message{},
HistoryMessage: []openai.ChatCompletionMessage{},
}
}

user := gpt.userState[userID]

userTokenCount, err := CountToken(msg)
if err != nil {
log.Print(err)
}
user.HistoryMessage = append(user.HistoryMessage, Message{
Role: "user",
Content: msg,
TokenCount: userTokenCount,
user.HistoryMessage = append(user.HistoryMessage, openai.ChatCompletionMessage{
Role: "user",
Content: msg,
})
user.LastActiveTime = time.Now()

for NumTokensFromMessages(user.HistoryMessage, "gpt-3.5-turbo") > 3500 {
user.HistoryMessage = user.HistoryMessage[1:]
}

c := openai.NewClient(cfg.OpenAIAPIKey)
ctx := context.Background()

log.Print(user.HistoryMessage)

req := openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
Temperature: cfg.ModelTemperature,
TopP: 1,
N: 1,
// PresencePenalty: 0.2,
// FrequencyPenalty: 0.2,
Messages: convertMessageToChatCompletionMessage(user.HistoryMessage),
Messages: user.HistoryMessage,
Stream: true,
}

Expand All @@ -112,7 +92,6 @@ func (gpt *GPT) SendMessage(userID int64, msg string, answerChan chan<- string)
}

var currentAnswer string
var assistantTokenCount int

defer stream.Close()
for {
Expand All @@ -128,28 +107,15 @@ func (gpt *GPT) SendMessage(userID int64, msg string, answerChan chan<- string)
break
}

// It seems that OpenAI sends one token per event, so we can count the tokens by the number of events we
// receive.
assistantTokenCount++
currentAnswer += response.Choices[0].Delta.Content
answerChan <- currentAnswer
}

user.HistoryMessage = append(user.HistoryMessage, Message{
Role: "assistant",
Content: currentAnswer,
TokenCount: assistantTokenCount,
user.HistoryMessage = append(user.HistoryMessage, openai.ChatCompletionMessage{
Role: "assistant",
Content: currentAnswer,
})

var totalTokenCount int
for i := len(user.HistoryMessage) - 1; i >= 0; i-- {
totalTokenCount += user.HistoryMessage[i].TokenCount
if totalTokenCount > 3500 {
user.HistoryMessage = user.HistoryMessage[i+1:]
return true, nil
}
}

return false, nil
}

Expand Down
5 changes: 2 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,23 +148,22 @@ func main() {
close(throttledAnswerChan)
})
wg.Go(func() {
msg, err := bot.Send(tgbotapi.NewChatAction(userID, tgbotapi.ChatTyping))
_, err := bot.Send(tgbotapi.NewChatAction(userID, tgbotapi.ChatTyping))
if err != nil {
log.Print(err)
}

var messageID int

for currentAnswer := range throttledAnswerChan {
fmt.Print(messageID)
if messageID == 0 {
msg, err := bot.Send(tgbotapi.NewMessage(userID, currentAnswer))
if err != nil {
log.Print(err)
}
messageID = msg.MessageID
} else {
editedMsg := tgbotapi.NewEditMessageText(userID, msg.MessageID, currentAnswer)
editedMsg := tgbotapi.NewEditMessageText(userID, messageID, currentAnswer)
_, err := bot.Send(editedMsg)
if err != nil {
log.Print(err)
Expand Down

0 comments on commit f40c69b

Please sign in to comment.