Estimate token count using github.com/samber/go-gpt-3-encoder
leafduo committed May 4, 2023
1 parent 72ab396 commit 57d0e30
Showing 5 changed files with 136 additions and 43 deletions.
43 changes: 43 additions & 0 deletions count_token.go
@@ -0,0 +1,43 @@
package main

import (
"log"
"sync"

tokenizer "github.com/samber/go-gpt-3-encoder"
)

var encoder *tokenizer.Encoder
var once sync.Once

func CountToken(msg string) (int, error) {
once.Do(func() {
var err error
encoder, err = tokenizer.NewEncoder()
if err != nil {
log.Fatal(err)
}
})

/*
The exact algorithm for counting tokens is documented at
https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb. We are not
using that algorithm in this project because no Go library currently implements it for ChatGPT. The Go library
github.com/samber/go-gpt-3-encoder implements the encoding used by GPT-3 (p50k_base), but not the one used by
ChatGPT (cl100k_base). Token counting is not a critical part of this project anyway; we only need a rough
estimate.
Based on my naïve experiments, the two encodings produce similar counts for English text, but differ
significantly for Chinese and Japanese: cl100k_base generates far fewer tokens than p50k_base, counting most
Chinese characters as 1 token where p50k_base mostly counts them as 2.
*/

encoded, err := encoder.Encode(msg)
if err != nil {
return 0, err
}

// Per the OpenAI cookbook, each chat message carries roughly 4 tokens of formatting overhead, hence the +4.
return len(encoded) + 4, nil
}
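
For reference, a minimal usage sketch of the helper above (the sample string is illustrative; the result is a rough p50k_base estimate, so CJK text will be overcounted relative to ChatGPT's cl100k_base):

	n, err := CountToken("你好,世界")
	if err != nil {
		log.Print(err)
	}
	log.Printf("estimated tokens: %d", n)
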
7 changes: 7 additions & 0 deletions go.mod
@@ -5,6 +5,13 @@ go 1.20
require (
github.com/caarlos0/env/v7 v7.0.0
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1
github.com/samber/go-gpt-3-encoder v0.3.1
github.com/sashabaranov/go-openai v1.4.2
github.com/sourcegraph/conc v0.3.0
)

require (
github.com/dlclark/regexp2 v1.7.0 // indirect
github.com/samber/lo v1.37.0 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
)
8 changes: 8 additions & 0 deletions go.sum
@@ -1,8 +1,16 @@
github.com/caarlos0/env/v7 v7.0.0 h1:cyczlTd/zREwSr9ch/mwaDl7Hse7kJuUY8hvHfXu5WI=
github.com/caarlos0/env/v7 v7.0.0/go.mod h1:LPPWniDUq4JaO6Q41vtlyikhMknqymCLBw0eX4dcH1E=
github.com/dlclark/regexp2 v1.7.0 h1:7lJfhqlPssTb1WQx4yvTHN0uElPEv52sbaECrAQxjAo=
github.com/dlclark/regexp2 v1.7.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc=
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8=
github.com/samber/go-gpt-3-encoder v0.3.1 h1:YWb9GsGYUgSX/wPtsEHjyNGRQXsQ9vDCg9SU2x9uMeU=
github.com/samber/go-gpt-3-encoder v0.3.1/go.mod h1:27nvdvk9ZtALyNtgs9JsPCMYja0Eleow/XzgjqwRtLU=
github.com/samber/lo v1.37.0 h1:XjVcB8g6tgUp8rsPsJ2CvhClfImrpL04YpQHXeHPhRw=
github.com/samber/lo v1.37.0/go.mod h1:9vaz2O4o8oOnK23pd2TrXufcbdbJIa3b6cstBWKpopA=
github.com/sashabaranov/go-openai v1.4.2 h1:IhacPY7O+ljlBoZRQe9VpsLNm0b4PHa6fOBGA9O4vfc=
github.com/sashabaranov/go-openai v1.4.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
71 changes: 47 additions & 24 deletions gpt.go
@@ -18,7 +18,24 @@ type GPT struct {
type UserState struct {
TelegramID int64
LastActiveTime time.Time
HistoryMessage []openai.ChatCompletionMessage
HistoryMessage []Message
}

type Message struct {
Role string
Content string
TokenCount int
}

func convertMessageToChatCompletionMessage(msg []Message) []openai.ChatCompletionMessage {
var result []openai.ChatCompletionMessage
for _, m := range msg {
result = append(result, openai.ChatCompletionMessage{
Role: m.Role,
Content: m.Content,
})
}
return result
}

func NewGPT() *GPT {
@@ -47,22 +64,27 @@ func NewGPT() *GPT {
return gpt
}

func (gpt *GPT) SendMessage(userID int64, msg string, answerChan chan<- string) error {
func (gpt *GPT) SendMessage(userID int64, msg string, answerChan chan<- string) (bool, error) {
gpt.clearUserContextIfExpires(userID)

if _, ok := gpt.userState[userID]; !ok {
gpt.userState[userID] = &UserState{
TelegramID: userID,
LastActiveTime: time.Now(),
HistoryMessage: []openai.ChatCompletionMessage{},
HistoryMessage: []Message{},
}
}

user := gpt.userState[userID]

user.HistoryMessage = append(user.HistoryMessage, openai.ChatCompletionMessage{
Role: "user",
Content: msg,
userTokenCount, err := CountToken(msg)
if err != nil {
log.Print(err)
}
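// If counting fails, fall back to a zero count; the trimming below only needs a rough estimate.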
user.HistoryMessage = append(user.HistoryMessage, Message{
Role: "user",
Content: msg,
TokenCount: userTokenCount,
})
user.LastActiveTime = time.Now()

@@ -78,18 +100,19 @@ func (gpt *GPT) SendMessage(userID int64, msg string, answerChan chan<- string)
N: 1,
// PresencePenalty: 0.2,
// FrequencyPenalty: 0.2,
Messages: user.HistoryMessage,
Messages: convertMessageToChatCompletionMessage(user.HistoryMessage),
Stream: true,
}

stream, err := c.CreateChatCompletionStream(ctx, req)
if err != nil {
log.Print(err)
user.HistoryMessage = user.HistoryMessage[:len(user.HistoryMessage)-1]
return err
return false, err
}

var currentAnswer string
var assistantTokenCount int

defer stream.Close()
for {
@@ -105,29 +128,29 @@ func (gpt *GPT) SendMessage(userID int64, msg string, answerChan chan<- string)
break
}

fmt.Printf("%+v\n", response)
// It seems that OpenAI sends one token per event, so we can count the tokens by the number of events we
// receive.
assistantTokenCount++
currentAnswer += response.Choices[0].Delta.Content
answerChan <- currentAnswer
}

user.HistoryMessage = append(user.HistoryMessage, openai.ChatCompletionMessage{
Role: "assistant",
Content: currentAnswer,
user.HistoryMessage = append(user.HistoryMessage, Message{
Role: "assistant",
Content: currentAnswer,
TokenCount: assistantTokenCount,
})

return nil

// answer := resp.Choices[0].Message

// users[userID].HistoryMessage = append(users[userID].HistoryMessage, answer)

// var contextTrimmed bool
// if resp.Usage.TotalTokens > 3500 {
// users[userID].HistoryMessage = users[userID].HistoryMessage[1:]
// contextTrimmed = true
// }
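// Walk the history from newest to oldest; once the running total exceeds the 3500-token budget,
// drop every older message and report that the context was trimmed.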
var totalTokenCount int
for i := len(user.HistoryMessage) - 1; i >= 0; i-- {
totalTokenCount += user.HistoryMessage[i].TokenCount
if totalTokenCount > 3500 {
user.HistoryMessage = user.HistoryMessage[i+1:]
return true, nil
}
}

// return answer.Content, contextTrimmed, nil
return false, nil
}

func (gpt *GPT) clearUserContextIfExpires(userID int64) bool {
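
To make the trimming pass in SendMessage concrete, here is a self-contained sketch; trimHistory and the sample token counts are hypothetical illustrations, but the loop mirrors the committed logic:

	package main

	import "fmt"

	type Msg struct{ TokenCount int }

	// trimHistory keeps the newest messages whose cumulative token count
	// stays within budget and reports whether anything was dropped.
	func trimHistory(history []Msg, budget int) ([]Msg, bool) {
		total := 0
		for i := len(history) - 1; i >= 0; i-- {
			total += history[i].TokenCount
			if total > budget {
				return history[i+1:], true
			}
		}
		return history, false
	}

	func main() {
		history := []Msg{{1200}, {1500}, {1000}, {600}}
		kept, trimmed := trimHistory(history, 3500)
		fmt.Println(len(kept), trimmed) // 3 true: the oldest (1200-token) message is dropped
	}
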
50 changes: 31 additions & 19 deletions main.go
@@ -110,24 +110,36 @@ func main() {

wg := conc.NewWaitGroup()
wg.Go(func() {
err := gpt.SendMessage(userID, msg, answerChan)
contextTrimmed, err := gpt.SendMessage(userID, msg, answerChan)
if err != nil {
log.Print(err)

_, err = bot.Send(tgbotapi.NewMessage(userID, err.Error()))
if err != nil {
log.Print(err)
}

return
}

if contextTrimmed {
msg := tgbotapi.NewMessage(userID, "Context trimmed.")
_, err = bot.Send(msg)
if err != nil {
log.Print(err)
}
}
})
wg.Go(func() {
lastUpdateTime := time.Now()
var currentAnswer string
for answer := range answerChan {
currentAnswer = answer
// Limit message to 1 message per second
// Update the message every 2.5 seconds to avoid hitting Telegram API limits. Although the documentation
// states that the limit is one message per second, in practice edits are rate-limited more aggressively.
// https://core.telegram.org/bots/faq#my-bot-is-hitting-limits-how-do-i-avoid-this
if lastUpdateTime.Add(time.Second + time.Millisecond).Before(time.Now()) {
if lastUpdateTime.Add(2500*time.Millisecond).Before(time.Now()) {
throttledAnswerChan <- currentAnswer
lastUpdateTime = time.Now()
}
@@ -136,32 +148,32 @@
close(throttledAnswerChan)
})
wg.Go(func() {
msg, err := bot.Send(tgbotapi.NewMessage(userID, "Generating..."))
_, err := bot.Send(tgbotapi.NewChatAction(userID, tgbotapi.ChatTyping))
if err != nil {
log.Print(err)
return
}

var messageID int

for currentAnswer := range throttledAnswerChan {
editedMsg := tgbotapi.NewEditMessageText(userID, msg.MessageID, currentAnswer)
_, err := bot.Send(editedMsg)
if err != nil {
log.Print(err)
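// Send the first chunk as a new message; edit that message in place for subsequent chunks.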
if messageID == 0 {
msg, err := bot.Send(tgbotapi.NewMessage(userID, currentAnswer))
if err != nil {
log.Print(err)
}
messageID = msg.MessageID
} else {
editedMsg := tgbotapi.NewEditMessageText(userID, messageID, currentAnswer)
_, err := bot.Send(editedMsg)
if err != nil {
log.Print(err)
}
}
}
})

wg.Wait()

// TODO: count tokens
// if contextTrimmed {
// msg := tgbotapi.NewMessage(update.Message.Chat.ID, "Context trimmed.")
// msg.DisableNotification = true
// err = send(bot, msg)
// if err != nil {
// log.Print(err)
// }
// }
}
}
}
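
As a footnote, here is a minimal sketch of the same 2.5-second throttling pattern in isolation (the throttle helper and sample values are hypothetical; note that it also flushes the final pending value so the last answer isn't lost):

	package main

	import (
		"fmt"
		"time"
	)

	// throttle forwards at most one value per interval from in to out,
	// and flushes the last pending value before closing out.
	func throttle(in <-chan string, out chan<- string, interval time.Duration) {
		defer close(out)
		var last time.Time
		var current string
		var pending bool
		for v := range in {
			current, pending = v, true
			if time.Since(last) >= interval {
				out <- current
				pending = false
				last = time.Now()
			}
		}
		if pending {
			out <- current // deliver the final answer even if the interval hasn't elapsed
		}
	}

	func main() {
		in := make(chan string)
		out := make(chan string)
		go throttle(in, out, 2500*time.Millisecond)
		go func() {
			for _, s := range []string{"Hel", "Hello, wo", "Hello, world"} {
				in <- s
			}
			close(in)
		}()
		for s := range out {
			fmt.Println(s) // prints "Hel" immediately, then "Hello, world"
		}
	}
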
