From 9b7f6a1c354b623ab6583d81781334046b81bee4 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Sat, 9 Mar 2024 17:29:52 +0100 Subject: [PATCH] feat: add voyage embedd and rerank --- embedder/voyage/api.go | 79 ++++++++++ embedder/voyage/voyage.go | 66 ++++++++ examples/embeddings/voyage/main.go | 102 +++++++++++++ examples/transformer/voyage-rerank/main.go | 36 +++++ transformer/voyage-rerank.go | 170 +++++++++++++++++++++ 5 files changed, 453 insertions(+) create mode 100644 embedder/voyage/api.go create mode 100644 embedder/voyage/voyage.go create mode 100644 examples/embeddings/voyage/main.go create mode 100644 examples/transformer/voyage-rerank/main.go create mode 100644 transformer/voyage-rerank.go diff --git a/embedder/voyage/api.go b/embedder/voyage/api.go new file mode 100644 index 00000000..f5141d68 --- /dev/null +++ b/embedder/voyage/api.go @@ -0,0 +1,79 @@ +package voyageembedder + +import ( + "bytes" + "encoding/json" + "io" + + "github.com/henomis/lingoose/embedder" + "github.com/henomis/restclientgo" +) + +type request struct { + Model string `json:"model"` + Input []string `json:"input"` +} + +func (r *request) Path() (string, error) { + return "/embeddings", nil +} + +func (r *request) Encode() (io.Reader, error) { + jsonBytes, err := json.Marshal(r) + if err != nil { + return nil, err + } + + return bytes.NewReader(jsonBytes), nil +} + +func (r *request) ContentType() string { + return "application/json" +} + +type response struct { + HTTPStatusCode int `json:"-"` + acceptContentType string `json:"-"` + Object string `json:"object"` + Data []data `json:"data"` + Model string `json:"model"` + RawBody []byte `json:"-"` +} + +type data struct { + Object string `json:"object"` + Embedding embedder.Embedding `json:"embedding"` + Index int `json:"index"` +} + +func (r *response) SetAcceptContentType(contentType string) { + r.acceptContentType = contentType +} + +func (r *response) Decode(body io.Reader) error { + return json.NewDecoder(body).Decode(r) +} + +func (r *response) SetBody(body io.Reader) error { + b, err := io.ReadAll(body) + if err != nil { + return err + } + + r.RawBody = b + return nil +} + +func (r *response) AcceptContentType() string { + if r.acceptContentType != "" { + return r.acceptContentType + } + return "application/json" +} + +func (r *response) SetStatusCode(code int) error { + r.HTTPStatusCode = code + return nil +} + +func (r *response) SetHeaders(_ restclientgo.Headers) error { return nil } diff --git a/embedder/voyage/voyage.go b/embedder/voyage/voyage.go new file mode 100644 index 00000000..d1f3f7c6 --- /dev/null +++ b/embedder/voyage/voyage.go @@ -0,0 +1,66 @@ +package voyageembedder + +import ( + "context" + "net/http" + "os" + + "github.com/henomis/lingoose/embedder" + "github.com/henomis/restclientgo" +) + +const ( + defaultModel = "voyage-2" + defaultEndpoint = "https://api.voyageai.com/v1" +) + +type Embedder struct { + model string + restClient *restclientgo.RestClient +} + +func New() *Embedder { + apiKey := os.Getenv("VOYAGE_API_KEY") + + return &Embedder{ + restClient: restclientgo.New(defaultEndpoint).WithRequestModifier( + func(req *http.Request) *http.Request { + req.Header.Set("Authorization", "Bearer "+apiKey) + return req + }), + model: defaultModel, + } +} + +func (e *Embedder) WithModel(model string) *Embedder { + e.model = model + return e +} + +// Embed returns the embeddings for the given texts +func (e *Embedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) { + return e.embed(ctx, texts) +} + +// Embed returns the embeddings for the given texts +func (e *Embedder) embed(ctx context.Context, text []string) ([]embedder.Embedding, error) { + resp := &response{} + err := e.restClient.Post( + ctx, + &request{ + Input: text, + Model: e.model, + }, + resp, + ) + if err != nil { + return nil, err + } + + embeddings := make([]embedder.Embedding, len(resp.Data)) + for i, data := range resp.Data { + embeddings[i] = data.Embedding + } + + return embeddings, nil +} diff --git a/examples/embeddings/voyage/main.go b/examples/embeddings/voyage/main.go new file mode 100644 index 00000000..68144045 --- /dev/null +++ b/examples/embeddings/voyage/main.go @@ -0,0 +1,102 @@ +package main + +import ( + "context" + "fmt" + + voyageembedder "github.com/henomis/lingoose/embedder/voyage" + "github.com/henomis/lingoose/index" + indexoption "github.com/henomis/lingoose/index/option" + "github.com/henomis/lingoose/index/vectordb/jsondb" + "github.com/henomis/lingoose/llm/antropic" + "github.com/henomis/lingoose/loader" + "github.com/henomis/lingoose/textsplitter" + "github.com/henomis/lingoose/thread" + "github.com/henomis/lingoose/types" +) + +// download https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt + +func main() { + + index := index.New( + jsondb.New().WithPersist("db.json"), + voyageembedder.New().WithModel("voyage-2"), + ).WithIncludeContents(true).WithAddDataCallback(func(data *index.Data) error { + data.Metadata["contentLen"] = len(data.Metadata["content"].(string)) + return nil + }) + + indexIsEmpty, _ := index.IsEmpty(context.Background()) + + if indexIsEmpty { + err := ingestData(index) + if err != nil { + panic(err) + } + } + + query := "What is the purpose of the NATO Alliance?" + similarities, err := index.Query( + context.Background(), + query, + indexoption.WithTopK(3), + ) + if err != nil { + panic(err) + } + + for _, similarity := range similarities { + fmt.Printf("Similarity: %f\n", similarity.Score) + fmt.Printf("Document: %s\n", similarity.Content()) + fmt.Println("Metadata: ", similarity.Metadata) + fmt.Println("----------") + } + + documentContext := "" + for _, similarity := range similarities { + documentContext += similarity.Content() + "\n\n" + } + + antropicllm := antropic.New().WithModel("claude-3-opus-20240229") + t := thread.New() + t.AddMessage(thread.NewUserMessage().AddContent( + thread.NewTextContent("Based on the following context answer to the" + + "question.\n\nContext:\n{{.context}}\n\nQuestion: {{.query}}").Format( + types.M{ + "query": query, + "context": documentContext, + }, + ), + )) + + err = antropicllm.Generate(context.Background(), t) + if err != nil { + panic(err) + } + + fmt.Println(t) +} + +func ingestData(index *index.Index) error { + + fmt.Printf("Ingesting data...") + + documents, err := loader.NewDirectoryLoader(".", ".txt").Load(context.Background()) + if err != nil { + return err + } + + textSplitter := textsplitter.NewRecursiveCharacterTextSplitter(1000, 20) + + documentChunks := textSplitter.SplitDocuments(documents) + + err = index.LoadFromDocuments(context.Background(), documentChunks) + if err != nil { + return err + } + + fmt.Printf("Done!\n") + + return nil +} diff --git a/examples/transformer/voyage-rerank/main.go b/examples/transformer/voyage-rerank/main.go new file mode 100644 index 00000000..d2ae7b91 --- /dev/null +++ b/examples/transformer/voyage-rerank/main.go @@ -0,0 +1,36 @@ +package main + +import ( + "context" + "fmt" + + "github.com/henomis/lingoose/document" + "github.com/henomis/lingoose/transformer" +) + +func main() { + + r := transformer.NewVoyageRerank() + + documents, err := r.Rerank( + context.Background(), + "What is the capital of the United States?", + []document.Document{ + { + Content: "Carson City is the capital city of the American state of Nevada.", + }, { + Content: "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", + }, { + Content: "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.", + }, + }, + ) + if err != nil { + panic(err) + } + + for _, doc := range documents { + fmt.Println(doc.GetEnrichedContent()) + fmt.Println("-----") + } +} diff --git a/transformer/voyage-rerank.go b/transformer/voyage-rerank.go new file mode 100644 index 00000000..76cdaffc --- /dev/null +++ b/transformer/voyage-rerank.go @@ -0,0 +1,170 @@ +package transformer + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "os" + + "github.com/henomis/lingoose/document" + "github.com/henomis/lingoose/types" + "github.com/henomis/restclientgo" +) + +const ( + defaultVoyageRerankEndpoint = "https://api.voyageai.com/v1" + defaultVoyageRerankModel = "rerank-lite-1" + VoyageRerankScoreMetdataKey = "voyage-rerank-score" +) + +type VoyageRerank struct { + model string + restClient *restclientgo.RestClient +} + +func NewVoyageRerank() *VoyageRerank { + apiKey := os.Getenv("VOYAGE_API_KEY") + + return &VoyageRerank{ + restClient: restclientgo.New(defaultVoyageRerankEndpoint).WithRequestModifier( + func(req *http.Request) *http.Request { + req.Header.Set("Authorization", "Bearer "+apiKey) + return req + }), + model: defaultVoyageRerankModel, + } +} + +func (v *VoyageRerank) WithModel(model string) *VoyageRerank { + v.model = model + return v +} + +func (v *VoyageRerank) Rerank( + ctx context.Context, + query string, + documents []document.Document, +) ([]document.Document, error) { + resp := &voyageRerankResponse{} + err := v.restClient.Post( + ctx, + &voyageRerankRequest{ + Documents: v.documentsToStringSlice(documents), + Model: v.model, + Query: query, + }, + resp, + ) + if err != nil { + return nil, err + } + + return v.rerankDocuments(documents, resp.Data), nil +} + +func (v *VoyageRerank) rerankDocuments( + documents []document.Document, + results []voyageRerankResponseData, +) []document.Document { + rerankedDocuments := make([]document.Document, 0) + for _, result := range results { + index := result.Index + metadata := documents[index].Metadata + if metadata == nil { + metadata = make(types.Meta) + } + metadata[VoyageRerankScoreMetdataKey] = result.RelevanceScore + + rerankedDocuments = append( + rerankedDocuments, + document.Document{ + Content: documents[index].Content, + Metadata: metadata, + }, + ) + } + + return rerankedDocuments +} + +func (v *VoyageRerank) documentsToStringSlice(documents []document.Document) []string { + strings := make([]string, len(documents)) + for i, d := range documents { + strings[i] = d.Content + } + return strings +} + +// API + +type voyageRerankRequest struct { + Model string `json:"model"` + Documents []string `json:"documents"` + Query string `json:"query"` +} + +func (r *voyageRerankRequest) Path() (string, error) { + return "/rerank", nil +} + +func (r *voyageRerankRequest) Encode() (io.Reader, error) { + jsonBytes, err := json.Marshal(r) + if err != nil { + return nil, err + } + + return bytes.NewReader(jsonBytes), nil +} + +func (r *voyageRerankRequest) ContentType() string { + return "application/json" +} + +type voyageRerankResponse struct { + HTTPStatusCode int `json:"-"` + acceptContentType string `json:"-"` + Object string `json:"object"` + Data []voyageRerankResponseData `json:"data"` + Model string `json:"model"` + RawBody []byte `json:"-"` +} + +type voyageRerankResponseData struct { + Object string `json:"object"` + RelevanceScore float64 `json:"relevance_score"` + Index int `json:"index"` +} + +func (r *voyageRerankResponse) SetAcceptContentType(contentType string) { + r.acceptContentType = contentType +} + +func (r *voyageRerankResponse) Decode(body io.Reader) error { + return json.NewDecoder(body).Decode(r) +} + +func (r *voyageRerankResponse) SetBody(body io.Reader) error { + b, err := io.ReadAll(body) + if err != nil { + return err + } + + r.RawBody = b + return nil +} + +func (r *voyageRerankResponse) AcceptContentType() string { + if r.acceptContentType != "" { + return r.acceptContentType + } + return "application/json" +} + +func (r *voyageRerankResponse) SetStatusCode(code int) error { + r.HTTPStatusCode = code + return nil +} + +func (r *voyageRerankResponse) SetHeaders(_ restclientgo.Headers) error { return nil }