From 8df7f75390fd970865c84b75e657ff07c7770f91 Mon Sep 17 00:00:00 2001
From: Simone Vellei
Date: Thu, 14 Sep 2023 10:09:03 +0200
Subject: [PATCH] feat: implement llm cache

---
 examples/llm/cache/main.go                   |  41 +++++++
 index/simpleVectorIndex/simpleVectorIndex.go |  20 ++--
 llm/cache/cache.go                           | 108 +++++++++++++++++++
 llm/openai/openai.go                         |  27 +++++
 4 files changed, 186 insertions(+), 10 deletions(-)
 create mode 100644 examples/llm/cache/main.go
 create mode 100644 llm/cache/cache.go

diff --git a/examples/llm/cache/main.go b/examples/llm/cache/main.go
new file mode 100644
index 00000000..b6f6d475
--- /dev/null
+++ b/examples/llm/cache/main.go
@@ -0,0 +1,41 @@
+package main
+
+import (
+	"bufio"
+	"context"
+	"fmt"
+	"os"
+	"strings"
+
+	openaiembedder "github.com/henomis/lingoose/embedder/openai"
+	simplevectorindex "github.com/henomis/lingoose/index/simpleVectorIndex"
+	"github.com/henomis/lingoose/llm/cache"
+	"github.com/henomis/lingoose/llm/openai"
+)
+
+func main() {
+
+	embedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2)
+	index := simplevectorindex.New("db", ".", embedder)
+	llm := openai.NewCompletion().WithCompletionCache(cache.New(embedder, index).WithTopK(3))
+
+	for {
+		text := askUserInput("What is your question?")
+
+		response, err := llm.Completion(context.Background(), text)
+		if err != nil {
+			fmt.Println(err)
+			continue
+		}
+
+		fmt.Println(response)
+	}
+}
+
+func askUserInput(question string) string {
+	fmt.Printf("%s > ", question)
+	reader := bufio.NewReader(os.Stdin)
+	name, _ := reader.ReadString('\n')
+	name = strings.TrimSuffix(name, "\n")
+	return name
+}
diff --git a/index/simpleVectorIndex/simpleVectorIndex.go b/index/simpleVectorIndex/simpleVectorIndex.go
index eb145519..621e0739 100644
--- a/index/simpleVectorIndex/simpleVectorIndex.go
+++ b/index/simpleVectorIndex/simpleVectorIndex.go
@@ -6,9 +6,9 @@ import (
 	"fmt"
 	"math"
 	"os"
-	"strconv"
 	"strings"
 
+	"github.com/google/uuid"
 	"github.com/henomis/lingoose/document"
 	"github.com/henomis/lingoose/embedder"
 	"github.com/henomis/lingoose/index"
@@ -53,7 +53,6 @@ func (s *Index) LoadFromDocuments(ctx context.Context, documents []document.Docu
 		return fmt.Errorf("%s: %w", index.ErrInternal, err)
 	}
 
-	id := 0
 	for i := 0; i < len(documents); i += defaultBatchSize {
 
 		end := i + defaultBatchSize
@@ -72,8 +71,11 @@ func (s *Index) LoadFromDocuments(ctx context.Context, documents []document.Docu
 		}
 
 		for j, document := range documents[i:end] {
-			s.data = append(s.data, buildDataFromEmbeddingAndDocument(id, embeddings[j], document))
-			id++
+			id, err := uuid.NewUUID()
+			if err != nil {
+				return fmt.Errorf("%s: %w", index.ErrInternal, err)
+			}
+			s.data = append(s.data, buildDataFromEmbeddingAndDocument(id.String(), embeddings[j], document))
 		}
 	}
 
@@ -87,14 +89,14 @@ func (s *Index) LoadFromDocuments(ctx context.Context, documents []document.Docu
 }
 
 func buildDataFromEmbeddingAndDocument(
-	id int,
+	id string,
 	embedding embedder.Embedding,
 	document document.Document,
 ) data {
 	metadata := index.DeepCopyMetadata(document.Metadata)
 	metadata[index.DefaultKeyContent] = document.Content
 	return data{
-		ID:       fmt.Sprintf("%d", id),
+		ID:       id,
 		Values:   embedding,
 		Metadata: metadata,
 	}
@@ -148,13 +150,11 @@ func (s *Index) Add(ctx context.Context, item *index.Data) error {
 	}
 
 	if item.ID == "" {
-		lastID := s.data[len(s.data)-1].ID
-		lastIDAsInt, err := strconv.Atoi(lastID)
+		id, err := uuid.NewUUID()
 		if err != nil {
 			return fmt.Errorf("%s: %w", index.ErrInternal, err)
 		}
-
-		item.ID = fmt.Sprintf("%d", lastIDAsInt+1)
+		item.ID = id.String()
 	}
 
 	s.data = append(
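Note: switching the index IDs from sequential integers to UUIDs also removes a latent crash in Add, which previously read s.data[len(s.data)-1] and panicked when Add was the first operation on an empty index. Below is a minimal sketch of the new behavior; it assumes Add only appends in memory (as the hunk above shows) and that constructing the index without ever loading documents is valid — neither is exercised elsewhere in this patch:

package main

import (
	"context"
	"fmt"

	openaiembedder "github.com/henomis/lingoose/embedder/openai"
	"github.com/henomis/lingoose/index"
	simplevectorindex "github.com/henomis/lingoose/index/simpleVectorIndex"
	"github.com/henomis/lingoose/types"
)

func main() {
	embedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2)
	idx := simplevectorindex.New("db", ".", embedder)

	// No ID is set, so Add generates one; before this patch the first Add
	// on an empty index panicked while deriving the "next" sequential ID.
	item := &index.Data{
		Values:   []float64{0.1, 0.2, 0.3},
		Metadata: types.Meta{"content": "hello"},
	}
	if err := idx.Add(context.Background(), item); err != nil {
		fmt.Println(err)
		return
	}

	// Prints a time-based (version 1) UUID rather than "0", "1", ...
	fmt.Println(item.ID)
}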
diff --git a/llm/cache/cache.go b/llm/cache/cache.go
new file mode 100644
index 00000000..4782d690
--- /dev/null
+++ b/llm/cache/cache.go
@@ -0,0 +1,108 @@
+package cache
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/henomis/lingoose/embedder"
+	"github.com/henomis/lingoose/index"
+	indexoption "github.com/henomis/lingoose/index/option"
+	"github.com/henomis/lingoose/types"
+)
+
+type Embedder interface {
+	Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error)
+}
+
+type Index interface {
+	Search(context.Context, []float64, ...indexoption.Option) (index.SearchResults, error)
+	Add(context.Context, *index.Data) error
+}
+
+var ErrCacheMiss = fmt.Errorf("cache miss")
+
+const (
+	defaultTopK            = 1
+	defaultScoreThreshold  = 0.9
+	cacheAnswerMetadataKey = "cache-answer"
+)
+
+type Cache struct {
+	embedder       Embedder
+	index          Index
+	topK           int
+	scoreThreshold float64
+}
+
+type CacheResult struct {
+	Answer    []string
+	Embedding []float64
+}
+
+func New(embedder Embedder, index Index) *Cache {
+	return &Cache{
+		embedder:       embedder,
+		index:          index,
+		topK:           defaultTopK,
+		scoreThreshold: defaultScoreThreshold,
+	}
+}
+
+func (c *Cache) WithTopK(topK int) *Cache {
+	c.topK = topK
+	return c
+}
+
+func (c *Cache) WithScoreThreshold(scoreThreshold float64) *Cache {
+	c.scoreThreshold = scoreThreshold
+	return c
+}
+
+func (c *Cache) Get(ctx context.Context, query string) (*CacheResult, error) {
+
+	embedding, err := c.embedder.Embed(ctx, []string{query})
+	if err != nil {
+		return nil, err
+	}
+
+	results, err := c.index.Search(ctx, embedding[0], indexoption.WithTopK(c.topK))
+	if err != nil {
+		return nil, err
+	}
+
+	answers, cacheHit := c.extractResults(results)
+	if cacheHit {
+		return &CacheResult{
+			Answer:    answers,
+			Embedding: embedding[0],
+		}, nil
+	}
+
+	return &CacheResult{Embedding: embedding[0]}, ErrCacheMiss
+}
+
+func (c *Cache) Set(ctx context.Context, embedding []float64, answer string) error {
+	return c.index.Add(ctx, &index.Data{
+		Values: embedding,
+		Metadata: types.Meta{
+			cacheAnswerMetadataKey: answer,
+		},
+	})
+}
+
+func (c *Cache) extractResults(results index.SearchResults) ([]string, bool) {
+	var output []string
+
+	for _, result := range results {
+		if result.Score > c.scoreThreshold {
+			answer, ok := result.Metadata[cacheAnswerMetadataKey].(string)
+			if !ok {
+				continue
+			}
+
+			output = append(output, answer)
+		}
+	}
+
+	return output, len(output) > 0
+}
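Note: the contract worth calling out is that Get reports a miss by returning ErrCacheMiss together with a CacheResult that still carries the query embedding, so the caller can Set the freshly computed answer without paying for a second embedding call. A sketch of a direct caller, mirroring what llm/openai does below; answerSomehow is a hypothetical stand-in for any expensive computation, and an OpenAI API key is assumed to be available in the environment for the embedder:

package main

import (
	"context"
	"errors"
	"fmt"
	"strings"

	openaiembedder "github.com/henomis/lingoose/embedder/openai"
	simplevectorindex "github.com/henomis/lingoose/index/simpleVectorIndex"
	"github.com/henomis/lingoose/llm/cache"
)

// answerSomehow stands in for any expensive call (LLM, database, ...).
// It is a hypothetical placeholder, not part of this patch.
func answerSomehow(prompt string) string {
	return "42: " + prompt
}

func main() {
	ctx := context.Background()
	embedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2)
	llmCache := cache.New(embedder, simplevectorindex.New("db", ".", embedder))

	result, err := llmCache.Get(ctx, "What is the answer to everything?")
	switch {
	case err == nil:
		// Hit: every stored answer scoring above the threshold comes back.
		fmt.Println(strings.Join(result.Answer, "\n"))
	case errors.Is(err, cache.ErrCacheMiss):
		// Miss: result still carries the query embedding, so the fresh
		// answer is stored without embedding the query a second time.
		answer := answerSomehow("What is the answer to everything?")
		if err := llmCache.Set(ctx, result.Embedding, answer); err != nil {
			fmt.Println(err)
		}
		fmt.Println(answer)
	default:
		fmt.Println(err) // embedder or index failure
	}
}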
diff --git a/llm/openai/openai.go b/llm/openai/openai.go
index 7219cc7a..0297a873 100644
--- a/llm/openai/openai.go
+++ b/llm/openai/openai.go
@@ -10,6 +10,7 @@ import (
 	"strings"
 
 	"github.com/henomis/lingoose/chat"
+	"github.com/henomis/lingoose/llm/cache"
 	"github.com/henomis/lingoose/types"
 	"github.com/mitchellh/mapstructure"
 	"github.com/sashabaranov/go-openai"
@@ -71,6 +72,7 @@ type OpenAI struct {
 	functionsMaxIterations uint
 	calledFunctionName     *string
 	finishReason           string
+	cache                  *cache.Cache
 }
 
 func New(model Model, temperature float32, maxTokens int, verbose bool) *OpenAI {
@@ -130,6 +132,12 @@ func (o *OpenAI) WithVerbose(verbose bool) *OpenAI {
 	return o
 }
 
+// WithCompletionCache sets the completion cache for the OpenAI instance.
+func (o *OpenAI) WithCompletionCache(cache *cache.Cache) *OpenAI {
+	o.cache = cache
+	return o
+}
+
 // CalledFunctionName returns the name of the function that was called.
 func (o *OpenAI) CalledFunctionName() *string {
 	return o.calledFunctionName
@@ -160,11 +168,30 @@ func NewChat() *OpenAI {
 }
 
 // Completion returns a single completion for the given prompt.
 func (o *OpenAI) Completion(ctx context.Context, prompt string) (string, error) {
+	var cacheResult *cache.CacheResult
+	var err error
+
+	if o.cache != nil {
+		cacheResult, err = o.cache.Get(ctx, prompt)
+		if err == nil {
+			return strings.Join(cacheResult.Answer, "\n"), nil
+		} else if err != cache.ErrCacheMiss {
+			return "", fmt.Errorf("%s: %w", ErrOpenAICompletion, err)
+		}
+	}
+
 	outputs, err := o.BatchCompletion(ctx, []string{prompt})
 	if err != nil {
 		return "", err
 	}
 
+	if o.cache != nil {
+		err = o.cache.Set(ctx, cacheResult.Embedding, outputs[0])
+		if err != nil {
+			return "", fmt.Errorf("%s: %w", ErrOpenAICompletion, err)
+		}
+	}
+
 	return outputs[0], nil
 }
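Note on tuning: WithTopK bounds how many stored entries each lookup compares against the incoming prompt, and WithScoreThreshold sets the minimum similarity for a hit (extractResults keeps answers with Score > scoreThreshold; the defaults are 1 and 0.9). A hedged end-to-end sketch; the threshold value, the prompts, and the presence of OPENAI_API_KEY in the environment are illustrative assumptions, not part of this patch:

package main

import (
	"context"
	"fmt"

	openaiembedder "github.com/henomis/lingoose/embedder/openai"
	simplevectorindex "github.com/henomis/lingoose/index/simpleVectorIndex"
	"github.com/henomis/lingoose/llm/cache"
	"github.com/henomis/lingoose/llm/openai"
)

func main() {
	embedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2)
	index := simplevectorindex.New("db", ".", embedder)

	// A stricter cache than the example above: consider the 5 nearest
	// neighbors, but only accept near-duplicates of a cached question.
	llm := openai.NewCompletion().WithCompletionCache(
		cache.New(embedder, index).
			WithTopK(5).
			WithScoreThreshold(0.95),
	)

	// The first prompt misses the cache and stores its answer; the
	// rephrased prompt hits only if its similarity score exceeds 0.95.
	for _, prompt := range []string{
		"What is the capital of Italy?",
		"Tell me the capital city of Italy",
	} {
		response, err := llm.Completion(context.Background(), prompt)
		if err != nil {
			fmt.Println(err)
			continue
		}
		fmt.Println(response)
	}
}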