Skip to content

Commit

Permalink
Feat Implement index retriever (#130)
Browse files Browse the repository at this point in the history
* feat: implement index retriever

* docs: update README

* docs: fix

* fix interfaces

* fix retriever
  • Loading branch information
henomis authored Sep 13, 2023
1 parent 15a40a1 commit ec43e05
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 13 deletions.
10 changes: 3 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ import (
"context"

openaiembedder "github.com/henomis/lingoose/embedder/openai"
indexoption "github.com/henomis/lingoose/index/option"
"github.com/henomis/lingoose/index/retriever"
simplevectorindex "github.com/henomis/lingoose/index/simpleVectorIndex"
"github.com/henomis/lingoose/llm/openai"
"github.com/henomis/lingoose/loader"
Expand All @@ -67,16 +67,12 @@ import (
)

func main() {
query := "What is the NATO purpose?"
docs, _ := loader.NewPDFToTextLoader("./kb").WithTextSplitter(textsplitter.NewRecursiveCharacterTextSplitter(2000, 200)).Load(context.Background())
openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2)
simplevectorindex.New("db", ".", openaiEmbedder).LoadFromDocuments(context.Background(), docs)
results, _ := simplevectorindex.New("db", ".", openaiEmbedder).Query(context.Background(), query, indexoption.WithTopK(3))
qapipeline.New(openai.NewChat().WithVerbose(true)).Run(context.Background(), query, results.ToDocuments())
qapipeline.New(openai.NewChat().WithVerbose(true)).WithRetriever(retriever.New(simplevectorindex.New("db", ".", openaiembedder.New(openaiembedder.AdaEmbeddingV2)), &docs)).Query(context.Background(), "What is the NATO purpose?")
}
```

This is the _famous_ 6-lines **lingoose** knowledge base chatbot. 🤖
This is the _famous_ 2-lines **lingoose** knowledge base chatbot. 🤖

# Installation

Expand Down
8 changes: 2 additions & 6 deletions examples/embeddings/simplekb/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"context"

openaiembedder "github.com/henomis/lingoose/embedder/openai"
indexoption "github.com/henomis/lingoose/index/option"
"github.com/henomis/lingoose/index/retriever"
simplevectorindex "github.com/henomis/lingoose/index/simpleVectorIndex"
"github.com/henomis/lingoose/llm/openai"
"github.com/henomis/lingoose/loader"
Expand All @@ -13,10 +13,6 @@ import (
)

func main() {
query := "What is the NATO purpose?"
docs, _ := loader.NewPDFToTextLoader("./kb").WithTextSplitter(textsplitter.NewRecursiveCharacterTextSplitter(2000, 200)).Load(context.Background())
openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2)
simplevectorindex.New("db", ".", openaiEmbedder).LoadFromDocuments(context.Background(), docs)
results, _ := simplevectorindex.New("db", ".", openaiEmbedder).Query(context.Background(), query, indexoption.WithTopK(3))
qapipeline.New(openai.NewChat().WithVerbose(true)).Run(context.Background(), query, results.ToDocuments())
qapipeline.New(openai.NewChat().WithVerbose(true)).WithRetriever(retriever.New(simplevectorindex.New("db", ".", openaiembedder.New(openaiembedder.AdaEmbeddingV2)), &docs)).Query(context.Background(), "What is the NATO purpose?")
}
61 changes: 61 additions & 0 deletions index/retriever/retriever.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package retriever

import (
"context"
"fmt"

"github.com/henomis/lingoose/document"
"github.com/henomis/lingoose/index"
indexoption "github.com/henomis/lingoose/index/option"
)

// Index is the subset of index behaviour that Retriever depends on.
type Index interface {
	// LoadFromDocuments receives the retriever's documents. Retriever.Query
	// invokes it before searching until one call has succeeded.
	LoadFromDocuments(ctx context.Context, documents []document.Document) error
	// Query runs a search for query, configured by opts, and reports the
	// matching search results.
	Query(ctx context.Context, query string, opts ...indexoption.Option) (index.SearchResults, error)
}

// defautTopK is the topK value New assigns to every Retriever; WithTopK
// replaces it.
//
// NOTE(review): the identifier is a misspelling of "defaultTopK". It is kept
// as-is because New refers to it by this name.
const defautTopK = 3

// Retriever loads a set of documents into an Index on first use and then
// answers queries against that index.
type Retriever struct {
	// index is the backend that stores the documents and runs searches.
	index Index
	// topK is passed to the index as the WithTopK option on every query.
	topK int
	// documents points at the caller-owned slice handed to the index on the
	// first query; Query fails while it is nil.
	documents *[]document.Document
	// areDocumentsLoaded is set once LoadFromDocuments has succeeded, so the
	// documents are not loaded again on later queries.
	areDocumentsLoaded bool
}

// New builds a Retriever that searches index and, on its first query, loads
// the slice that documents points to. topK starts at defautTopK; the
// documents are not loaded here.
func New(index Index, documents *[]document.Document) *Retriever {
	r := new(Retriever)
	r.index = index
	r.documents = documents
	r.topK = defautTopK
	return r
}

// WithTopK overrides the topK value passed to the index on each query and
// returns the same Retriever so calls can be chained. The value is stored
// without validation.
func (r *Retriever) WithTopK(topK int) *Retriever {
	r.topK = topK
	return r
}

// Query searches the index for query and returns the search results converted
// to documents.
//
// On the first call (and on every call until it succeeds) the retriever's
// documents are loaded into the index before searching. It returns an error
// when no documents pointer was supplied, when loading fails, or when the
// index search fails.
//
// ctx is forwarded to the index for both loading and searching, so a
// cancellation or deadline set by the caller reaches the index.
func (r *Retriever) Query(ctx context.Context, query string) ([]document.Document, error) {
	if r.documents == nil {
		return nil, fmt.Errorf("documents are not defined")
	}

	if !r.areDocumentsLoaded {
		// Pass the caller's ctx rather than a fresh background context so the
		// load honours the caller's cancellation and deadline.
		if err := r.index.LoadFromDocuments(ctx, *r.documents); err != nil {
			return nil, err
		}
		// Only mark as loaded after success so a failed load is retried on
		// the next query.
		r.areDocumentsLoaded = true
	}

	results, err := r.index.Query(ctx, query, indexoption.WithTopK(r.topK))
	if err != nil {
		return nil, err
	}

	return results.ToDocuments(), nil
}
25 changes: 25 additions & 0 deletions pipeline/qa/qa.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package qapipeline

import (
"context"
"fmt"

"github.com/henomis/lingoose/chat"
"github.com/henomis/lingoose/document"
Expand All @@ -18,9 +19,14 @@ const (
qaTubeUserPromptTemplate = "Based on the following context answer to the question.\n\nContext:\n{{.context}}\n\nQuestion: {{.query}}"
)

// Retriever supplies the documents that QAPipeline.Query passes to Run as
// context for a question.
type Retriever interface {
	// Query returns the documents to use as context for query.
	Query(ctx context.Context, query string) ([]document.Document, error)
}

// QAPipeline answers a question from a set of context documents, optionally
// fetching those documents itself through a Retriever.
type QAPipeline struct {
	// llmEngine is the LLM engine passed to New.
	llmEngine pipeline.LlmEngine
	// pipeline is the underlying pipeline built in New.
	pipeline *pipeline.Pipeline
	// retriever is the optional document source used by Query; it stays nil
	// until WithRetriever is called.
	retriever Retriever
}

func New(llmEngine pipeline.LlmEngine) *QAPipeline {
Expand Down Expand Up @@ -49,6 +55,7 @@ func New(llmEngine pipeline.LlmEngine) *QAPipeline {
return &QAPipeline{
llmEngine: llmEngine,
pipeline: pipeline.New(tube),
retriever: nil,
}
}

Expand All @@ -66,6 +73,24 @@ func (p *QAPipeline) WithPrompt(chat *chat.Chat) *QAPipeline {
}
}

// WithRetriever sets the Retriever that Query uses to fetch context documents
// and returns the same pipeline so calls can be chained.
func (p *QAPipeline) WithRetriever(retriever Retriever) *QAPipeline {
	p.retriever = retriever
	return p
}

// Query fetches context documents for query from the configured Retriever and
// then delegates to Run with those documents.
//
// It returns an error when no retriever has been set via WithRetriever, or
// when the retriever itself fails; otherwise it returns whatever Run returns.
func (p *QAPipeline) Query(ctx context.Context, query string) (types.M, error) {
	if p.retriever == nil {
		return nil, fmt.Errorf("retriever is not defined")
	}

	documents, err := p.retriever.Query(ctx, query)
	if err != nil {
		return nil, err
	}

	return p.Run(ctx, query, documents)
}

func (t *QAPipeline) Run(ctx context.Context, query string, documents []document.Document) (types.M, error) {

content := ""
Expand Down

0 comments on commit ec43e05

Please sign in to comment.