feat: implement llm cache
henomis committed Sep 14, 2023
1 parent ec43e05 commit 8df7f75
Showing 4 changed files with 187 additions and 11 deletions.
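In short: the new llm/cache package embeds each prompt and searches a vector index for semantically similar past prompts; when a stored answer scores above the similarity threshold (0.9 by default) it is returned directly, and on a miss the completion is generated as usual and stored together with the prompt embedding. The simple vector index switches from sequential integer IDs to UUID strings, and the OpenAI client gains a WithCompletionCache option.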
41 changes: 41 additions & 0 deletions examples/llm/cache/main.go
@@ -0,0 +1,41 @@
package main

import (
    "bufio"
    "context"
    "fmt"
    "os"
    "strings"

    openaiembedder "github.com/henomis/lingoose/embedder/openai"
    simplevectorindex "github.com/henomis/lingoose/index/simpleVectorIndex"
    "github.com/henomis/lingoose/llm/cache"
    "github.com/henomis/lingoose/llm/openai"
)

func main() {

    embedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2)
    index := simplevectorindex.New("db", ".", embedder)
    llm := openai.NewCompletion().WithCompletionCache(cache.New(embedder, index).WithTopK(3))

    for {
        text := askUserInput("What is your question?")

        response, err := llm.Completion(context.Background(), text)
        if err != nil {
            fmt.Println(err)
            continue
        }

        fmt.Println(response)
    }
}

func askUserInput(question string) string {
    fmt.Printf("%s > ", question)
    reader := bufio.NewReader(os.Stdin)
    name, _ := reader.ReadString('\n')
    name = strings.TrimSuffix(name, "\n")
    return name
}
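A note on running this example: the OpenAI embedder and completion client presumably require a valid OPENAI_API_KEY in the environment, and the index persists under the name "db" in the working directory, so cached answers survive across runs.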
22 changes: 11 additions & 11 deletions index/simpleVectorIndex/simpleVectorIndex.go
@@ -6,9 +6,9 @@ import (
 	"fmt"
 	"math"
 	"os"
-	"strconv"
 	"strings"
 
+	"github.com/google/uuid"
 	"github.com/henomis/lingoose/document"
 	"github.com/henomis/lingoose/embedder"
 	"github.com/henomis/lingoose/index"
@@ -53,7 +53,6 @@ func (s *Index) LoadFromDocuments(ctx context.Context, documents []document.Docu
 		return fmt.Errorf("%s: %w", index.ErrInternal, err)
 	}
 
-	id := 0
 	for i := 0; i < len(documents); i += defaultBatchSize {
 
 		end := i + defaultBatchSize
@@ -72,8 +71,11 @@ func (s *Index) LoadFromDocuments(ctx context.Context, documents []document.Docu
 		}
 
 		for j, document := range documents[i:end] {
-			s.data = append(s.data, buildDataFromEmbeddingAndDocument(id, embeddings[j], document))
-			id++
+			id, err := uuid.NewUUID()
+			if err != nil {
+				return err
+			}
+			s.data = append(s.data, buildDataFromEmbeddingAndDocument(id.String(), embeddings[j], document))
 		}
 
 	}
@@ -87,14 +89,14 @@ func (s *Index) LoadFromDocuments(ctx context.Context, documents []document.Docu
 }
 
 func buildDataFromEmbeddingAndDocument(
-	id int,
+	id string,
 	embedding embedder.Embedding,
 	document document.Document,
 ) data {
 	metadata := index.DeepCopyMetadata(document.Metadata)
 	metadata[index.DefaultKeyContent] = document.Content
 	return data{
-		ID:       fmt.Sprintf("%d", id),
+		ID:       id,
 		Values:   embedding,
 		Metadata: metadata,
 	}
@@ -148,13 +150,11 @@ func (s *Index) Add(ctx context.Context, item *index.Data) error {
 	}
 
 	if item.ID == "" {
-		lastID := s.data[len(s.data)-1].ID
-		lastIDAsInt, err := strconv.Atoi(lastID)
+		id, err := uuid.NewUUID()
 		if err != nil {
-			return fmt.Errorf("%s: %w", index.ErrInternal, err)
+			return err
 		}
-
-		item.ID = fmt.Sprintf("%d", lastIDAsInt+1)
+		item.ID = id.String()
 	}
 
 	s.data = append(
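Note the ID scheme change above: new entries now receive UUID strings instead of lastID+1, which also removes the out-of-range access the old Add path would hit when called on an empty index.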
108 changes: 108 additions & 0 deletions llm/cache/cache.go
@@ -0,0 +1,108 @@
package cache

import (
    "context"
    "fmt"

    "github.com/henomis/lingoose/embedder"
    "github.com/henomis/lingoose/index"
    indexoption "github.com/henomis/lingoose/index/option"
    "github.com/henomis/lingoose/types"
)

type Embedder interface {
    Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error)
}

type Index interface {
    Search(context.Context, []float64, ...indexoption.Option) (index.SearchResults, error)
    Add(context.Context, *index.Data) error
}

var ErrCacheMiss = fmt.Errorf("cache miss")

const (
    defaultTopK            = 1
    defaultScoreThreshold  = 0.9
    cacheAnswerMetadataKey = "cache-answer"
)

type Cache struct {
    embedder       Embedder
    index          Index
    topK           int
    scoreThreshold float64
}

type CacheResult struct {
    Answer    []string
    Embedding []float64
}

func New(embedder Embedder, index Index) *Cache {
    return &Cache{
        embedder:       embedder,
        index:          index,
        topK:           defaultTopK,
        scoreThreshold: defaultScoreThreshold,
    }
}

func (c *Cache) WithTopK(topK int) *Cache {
    c.topK = topK
    return c
}

func (c *Cache) WithScoreThreshold(scoreThreshold float64) *Cache {
    c.scoreThreshold = scoreThreshold
    return c
}

func (c *Cache) Get(ctx context.Context, query string) (*CacheResult, error) {
    embedding, err := c.embedder.Embed(ctx, []string{query})
    if err != nil {
        return nil, err
    }

    results, err := c.index.Search(ctx, embedding[0], indexoption.WithTopK(c.topK))
    if err != nil {
        return nil, err
    }

    answers, cacheHit := c.extractResults(results)
    if cacheHit {
        return &CacheResult{
            Answer:    answers,
            Embedding: embedding[0],
        }, nil
    }

    // On a miss, return the query embedding alongside ErrCacheMiss so that
    // callers can reuse it when storing the fresh answer via Set; returning
    // nil here would make Completion's later Set call dereference nil.
    return &CacheResult{Embedding: embedding[0]}, ErrCacheMiss
}

func (c *Cache) Set(ctx context.Context, embedding []float64, answer string) error {
    return c.index.Add(ctx, &index.Data{
        Values: embedding,
        Metadata: types.Meta{
            cacheAnswerMetadataKey: answer,
        },
    })
}

func (c *Cache) extractResults(results index.SearchResults) ([]string, bool) {
    var output []string

    for _, result := range results {
        if result.Score > c.scoreThreshold {
            answer, ok := result.Metadata[cacheAnswerMetadataKey]
            if !ok {
                continue
            }

            output = append(output, answer.(string))
        }
    }

    return output, len(output) > 0
}
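For reference, here is a minimal sketch (not part of the commit) of the cache-aside flow this API is built for. The helper cachedCompletion and its complete callback are hypothetical names, and the sketch assumes Get returns the query embedding alongside ErrCacheMiss, as in the version above.

package example

import (
    "context"
    "errors"
    "fmt"
    "strings"

    "github.com/henomis/lingoose/llm/cache"
)

// cachedCompletion is a hypothetical helper illustrating the cache-aside
// flow: try Get first, fall back to the model on ErrCacheMiss, then Set.
func cachedCompletion(
    ctx context.Context,
    c *cache.Cache,
    complete func(context.Context, string) (string, error),
    prompt string,
) (string, error) {
    result, err := c.Get(ctx, prompt)
    if err == nil {
        // Cache hit: one or more stored answers passed the score threshold.
        return strings.Join(result.Answer, "\n"), nil
    }
    if !errors.Is(err, cache.ErrCacheMiss) {
        return "", fmt.Errorf("cache lookup: %w", err)
    }

    // Cache miss: generate a fresh completion.
    answer, err := complete(ctx, prompt)
    if err != nil {
        return "", err
    }

    // Reuse the embedding computed during Get; no second Embed call needed.
    if err := c.Set(ctx, result.Embedding, answer); err != nil {
        return "", err
    }
    return answer, nil
}

Returning the embedding on a miss is the design choice that lets Set skip a second Embed call.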
27 changes: 27 additions & 0 deletions llm/openai/openai.go
@@ -10,6 +10,7 @@ import (
 	"strings"
 
 	"github.com/henomis/lingoose/chat"
+	"github.com/henomis/lingoose/llm/cache"
 	"github.com/henomis/lingoose/types"
 	"github.com/mitchellh/mapstructure"
 	"github.com/sashabaranov/go-openai"
@@ -71,6 +72,7 @@ type OpenAI struct {
 	functionsMaxIterations uint
 	calledFunctionName     *string
 	finishReason           string
+	cache                  *cache.Cache
 }
 
 func New(model Model, temperature float32, maxTokens int, verbose bool) *OpenAI {
@@ -130,6 +132,12 @@ func (o *OpenAI) WithVerbose(verbose bool) *OpenAI {
 	return o
 }
 
+// WithCompletionCache sets the completion cache to use for the OpenAI instance.
+func (o *OpenAI) WithCompletionCache(cache *cache.Cache) *OpenAI {
+	o.cache = cache
+	return o
+}
+
 // CalledFunctionName returns the name of the function that was called.
 func (o *OpenAI) CalledFunctionName() *string {
 	return o.calledFunctionName
@@ -160,11 +168,30 @@ func NewChat() *OpenAI {
 
 // Completion returns a single completion for the given prompt.
 func (o *OpenAI) Completion(ctx context.Context, prompt string) (string, error) {
+	var cacheResult *cache.CacheResult
+	var err error
+
+	if o.cache != nil {
+		cacheResult, err = o.cache.Get(ctx, prompt)
+		if err == nil {
+			return strings.Join(cacheResult.Answer, "\n"), nil
+		} else if err != cache.ErrCacheMiss {
+			return "", fmt.Errorf("%s: %w", ErrOpenAICompletion, err)
+		}
+	}
+
 	outputs, err := o.BatchCompletion(ctx, []string{prompt})
 	if err != nil {
 		return "", err
 	}
 
+	if o.cache != nil {
+		err = o.cache.Set(ctx, cacheResult.Embedding, outputs[0])
+		if err != nil {
+			return "", fmt.Errorf("%s: %w", ErrOpenAICompletion, err)
+		}
+	}
+
 	return outputs[0], nil
 }
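With a cache configured, Completion therefore short-circuits on a hit, joining multiple stored answers with newlines when TopK is greater than one, and on a miss stores the fresh output under the embedding computed during the lookup.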

