diff --git a/globals/logger.go b/globals/logger.go index 6f1eeb51..70ea0149 100644 --- a/globals/logger.go +++ b/globals/logger.go @@ -2,10 +2,11 @@ package globals import ( "fmt" + "strings" + "github.com/natefinch/lumberjack" "github.com/sirupsen/logrus" "github.com/spf13/viper" - "strings" ) const DefaultLoggerFile = "chatnio.log" @@ -25,7 +26,7 @@ func (l *AppLogger) Format(entry *logrus.Entry) ([]byte, error) { ) if !viper.GetBool("log.ignore_console") { - fmt.Println(data) + fmt.Print(data) } return []byte(data), nil diff --git a/utils/buffer.go b/utils/buffer.go index 00e3e82e..99b409fe 100644 --- a/utils/buffer.go +++ b/utils/buffer.go @@ -197,11 +197,6 @@ func (b *Buffer) IsFunctionCalling() bool { return b.FunctionCall != nil || b.ToolCalls != nil } -func (b *Buffer) WriteBytes(data []byte) []byte { - b.Write(string(data)) - return data -} - func (b *Buffer) IsEmpty() bool { return b.Cursor == 0 && !b.IsFunctionCalling() } @@ -237,12 +232,12 @@ func (b *Buffer) SetInputTokens(tokens int) { b.InputTokens = tokens } -func (b *Buffer) CountInputToken() int { - return b.InputTokens +func (b *Buffer) CountOutputToken() int { + return b.ReadTimes() * GetWeightByModel(b.Model) } func (b *Buffer) CountOutputToken() int { - return b.ReadTimes() * GetWeightByModel(b.Model) + return b.CountInputToken() + b.CountOutputToken() } func (b *Buffer) CountToken() int { diff --git a/utils/tokenizer.go b/utils/tokenizer.go index de38c010..7784f7dd 100644 --- a/utils/tokenizer.go +++ b/utils/tokenizer.go @@ -3,8 +3,9 @@ package utils import ( "chat/globals" "fmt" - "github.com/pkoukk/tiktoken-go" "strings" + + "github.com/pkoukk/tiktoken-go" ) // Using https://github.com/pkoukk/tiktoken-go @@ -45,9 +46,10 @@ func GetWeightByModel(model string) int { } } } -func NumTokensFromMessages(messages []globals.Message, model string) (tokens int) { +func NumTokensFromMessages(messages []globals.Message, model string, responseType bool) (tokens int) { tokensPerMessage := GetWeightByModel(model) tkm, err := tiktoken.EncodingForModel(model) + if err != nil { // the method above was deprecated, use the recall method instead // can not encode messages, use length of messages as a proxy for number of tokens @@ -59,16 +61,20 @@ func NumTokensFromMessages(messages []globals.Message, model string) (tokens int if globals.DebugMode { globals.Debug(fmt.Sprintf("[tiktoken] error encoding messages: %s (model: %s), using default model instead", err, model)) } - return NumTokensFromMessages(messages, globals.GPT3Turbo0613) + return NumTokensFromMessages(messages, globals.GPT3Turbo0613, responseType) } for _, message := range messages { - tokens += - len(tkm.Encode(message.Content, nil, nil)) + - len(tkm.Encode(message.Role, nil, nil)) + - tokensPerMessage + tokens += len(tkm.Encode(message.Content, nil, nil)) + + if !responseType { + tokens += len(tkm.Encode(message.Role, nil, nil)) + tokensPerMessage + } + } + + if !responseType { + tokens += 3 // every reply is primed with <|start|>assistant<|message|> } - tokens += 3 // every reply is primed with <|start|>assistant<|message|> if globals.DebugMode { globals.Debug(fmt.Sprintf("[tiktoken] num tokens from messages: %d (tokens per message: %d, model: %s)", tokens, tokensPerMessage, model)) @@ -76,8 +82,8 @@ func NumTokensFromMessages(messages []globals.Message, model string) (tokens int return tokens } -func CountTokenPrice(messages []globals.Message, model string) int { - return NumTokensFromMessages(messages, model) * GetWeightByModel(model) +func NumTokensFromResponse(response string, model string) int { + return NumTokensFromMessages([]globals.Message{{Content: response}}, model, true) } func CountInputQuota(charge Charge, token int) float32 { @@ -88,10 +94,10 @@ func CountInputQuota(charge Charge, token int) float32 { return 0 } -func CountOutputToken(charge Charge, model string, token int) float32 { +func CountOutputToken(charge Charge, token int) float32 { switch charge.GetType() { case globals.TokenBilling: - return float32(token*GetWeightByModel(model)) / 1000 * charge.GetOutput() + return float32(token) / 1000 * charge.GetOutput() case globals.TimesBilling: return charge.GetOutput() default: