Skip to content

Commit

Permalink
feat(genai): add tokenizer package (#10699)
Browse files Browse the repository at this point in the history
  • Loading branch information
eliben authored Aug 20, 2024
1 parent f5833e6 commit 214af16
Show file tree
Hide file tree
Showing 3 changed files with 312 additions and 0 deletions.
38 changes: 38 additions & 0 deletions vertexai/genai/tokenizer/example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tokenizer_test

import (
"fmt"
"log"

"cloud.google.com/go/vertexai/genai"
"cloud.google.com/go/vertexai/genai/tokenizer"
)

func ExampleTokenizer_CountTokens() {
tok, err := tokenizer.New("gemini-1.5-flash")
if err != nil {
log.Fatal(err)
}

ntoks, err := tok.CountTokens(genai.Text("a prompt"), genai.Text("another prompt"))
if err != nil {
log.Fatal(err)
}

fmt.Println("total token count:", ntoks.TotalTokens)

// Output: total token count: 4
}
156 changes: 156 additions & 0 deletions vertexai/genai/tokenizer/tokenizer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package tokenizer provides local token counting for Gemini models. This
// tokenizer downloads its model from the web, but otherwise doesn't require
// an API call for every CountTokens invocation.
package tokenizer

import (
"bytes"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"os"
"path/filepath"

"cloud.google.com/go/vertexai/genai"
"cloud.google.com/go/vertexai/internal/sentencepiece"
)

var supportedModels = map[string]bool{
"gemini-1.0-pro": true,
"gemini-1.5-pro": true,
"gemini-1.5-flash": true,
"gemini-1.0-pro-001": true,
"gemini-1.0-pro-002": true,
"gemini-1.5-pro-001": true,
"gemini-1.5-flash-001": true,
}

// Tokenizer is a local tokenizer for text.
type Tokenizer struct {
encoder *sentencepiece.Encoder
}

// CountTokensResponse is the response of [Tokenizer.CountTokens].
type CountTokensResponse struct {
TotalTokens int32
}

// New creates a new [Tokenizer] from a model name; the model name is the same
// as you would pass to a [genai.Client.GenerativeModel].
func New(modelName string) (*Tokenizer, error) {
if !supportedModels[modelName] {
return nil, fmt.Errorf("model %s is not supported", modelName)
}

data, err := loadModelData(gemmaModelURL, gemmaModelHash)
if err != nil {
return nil, fmt.Errorf("loading model: %w", err)
}

encoder, err := sentencepiece.NewEncoder(bytes.NewReader(data))
if err != nil {
return nil, fmt.Errorf("creating encoder: %w", err)
}

return &Tokenizer{encoder: encoder}, nil
}

// CountTokens counts the tokens in all the given parts and returns their
// sum. Only [genai.Text] parts are suppored; an error will be returned if
// non-text parts are provided.
func (tok *Tokenizer) CountTokens(parts ...genai.Part) (*CountTokensResponse, error) {
sum := 0

for _, part := range parts {
if t, ok := part.(genai.Text); ok {
toks := tok.encoder.Encode(string(t))
sum += len(toks)
} else {
return nil, fmt.Errorf("Tokenizer.CountTokens only supports Text parts")
}
}

return &CountTokensResponse{TotalTokens: int32(sum)}, nil
}

// gemmaModelURL is the URL from which we download the model file.
const gemmaModelURL = "https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model"

// gemmaModelHash is the expected hash of the model file (as calculated
// by [hashString]).
const gemmaModelHash = "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2"

// downloadModelFile downloads a file from the given URL.
func downloadModelFile(url string) ([]byte, error) {
resp, err := http.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()

return io.ReadAll(resp.Body)
}

// hashString computes a hex string of the SHA256 hash of data.
func hashString(data []byte) string {
hash256 := sha256.Sum256(data)
return hex.EncodeToString(hash256[:])
}

// loadModelData loads model data from the given URL, using a local file-system
// cache. wantHash is the hash (as returned by [hashString] expected on the
// loaded data.
//
// Caching logic:
//
// Assuming $TEMP_DIR is the temporary directory used by the OS, this function
// uses the file $TEMP_DIR/vertexai_tokenizer_model/$urlhash as a cache, where
// $urlhash is hashString(url).
//
// If this cache file doesn't exist, or the data it contains doesn't match
// wantHash, downloads data from the URL and writes it into the cache. If the
// URL's data doesn't match the hash, an error is returned.
func loadModelData(url string, wantHash string) ([]byte, error) {
urlhash := hashString([]byte(url))
cacheDir := filepath.Join(os.TempDir(), "vertexai_tokenizer_model")
cachePath := filepath.Join(cacheDir, urlhash)

cacheData, err := os.ReadFile(cachePath)
if err != nil || hashString(cacheData) != wantHash {
cacheData, err = downloadModelFile(url)
if err != nil {
return nil, fmt.Errorf("loading cache and downloading model: %w", err)
}

if hashString(cacheData) != wantHash {
return nil, fmt.Errorf("downloaded model hash mismatch")
}

err = os.MkdirAll(cacheDir, 0770)
if err != nil {
return nil, fmt.Errorf("creating cache dir: %w", err)
}
err = os.WriteFile(cachePath, cacheData, 0660)
if err != nil {
return nil, fmt.Errorf("writing cache file: %w", err)
}
}

return cacheData, nil
}
118 changes: 118 additions & 0 deletions vertexai/genai/tokenizer/tokenizer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tokenizer

import (
"fmt"
"os"
"path/filepath"
"testing"

"cloud.google.com/go/vertexai/genai"
)

func TestDownload(t *testing.T) {
b, err := downloadModelFile(gemmaModelURL)
if err != nil {
t.Fatal(err)
}

if hashString(b) != gemmaModelHash {
t.Errorf("gemma model hash doesn't match")
}
}

func TestLoadModelData(t *testing.T) {
// Tests that loadModelData manages to load the model properly, and download
// a new one as needed.
checkDataAndErr := func(data []byte, err error) {
t.Helper()
if err != nil {
t.Error(err)
}
gotHash := hashString(data)
if gotHash != gemmaModelHash {
t.Errorf("got hash=%v, want=%v", gotHash, gemmaModelHash)
}
}

data, err := loadModelData(gemmaModelURL, gemmaModelHash)
checkDataAndErr(data, err)

// The cache should exist now and have the right data, try again.
data, err = loadModelData(gemmaModelURL, gemmaModelHash)
checkDataAndErr(data, err)

// Overwrite cache file with wrong data, and try again.
cacheDir := filepath.Join(os.TempDir(), "vertexai_tokenizer_model")
cachePath := filepath.Join(cacheDir, hashString([]byte(gemmaModelURL)))
_ = os.MkdirAll(cacheDir, 0770)
_ = os.WriteFile(cachePath, []byte{0, 1, 2, 3}, 0660)
data, err = loadModelData(gemmaModelURL, gemmaModelHash)
checkDataAndErr(data, err)
}

func TestCreateTokenizer(t *testing.T) {
// Create a tokenizer successfully
_, err := New("gemini-1.5-flash")
if err != nil {
t.Error(err)
}

// Create a tokenizer with an unsupported model
_, err = New("gemini-0.92")
if err == nil {
t.Errorf("got no error, want error")
}
}

func TestCountTokens(t *testing.T) {
var tests = []struct {
parts []genai.Part
wantCount int32
}{
{[]genai.Part{genai.Text("hello world")}, 2},
{[]genai.Part{genai.Text("<table><th></th></table>")}, 4},
{[]genai.Part{genai.Text("hello world"), genai.Text("<table><th></th></table>")}, 6},
}

tok, err := New("gemini-1.5-flash")
if err != nil {
t.Error(err)
}

for i, tt := range tests {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
got, err := tok.CountTokens(tt.parts...)
if err != nil {
t.Error(err)
}
if got.TotalTokens != tt.wantCount {
t.Errorf("got %v, want %v", got.TotalTokens, tt.wantCount)
}
})
}
}

func TestCountTokensNonText(t *testing.T) {
tok, err := New("gemini-1.5-flash")
if err != nil {
t.Error(err)
}

_, err = tok.CountTokens(genai.Text("foo"), genai.ImageData("format", []byte{0, 1}))
if err == nil {
t.Error("got no error, want error")
}
}

0 comments on commit 214af16

Please sign in to comment.