Skip to content

Commit

Permalink
fix interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis committed Sep 13, 2023
1 parent dbec5bb commit 59ce0d1
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 25 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@ import (

func main() {
docs, _ := loader.NewPDFToTextLoader("./kb").WithTextSplitter(textsplitter.NewRecursiveCharacterTextSplitter(2000, 200)).Load(context.Background())
qapipeline.New(openai.NewChat().WithVerbose(true)).WithDocuments(&docs).WithRetriever(retriever.New(simplevectorindex.New("db", ".", openaiembedder.New(openaiembedder.AdaEmbeddingV2)))).Query(context.Background(), "What is the NATO purpose?")
qapipeline.New(openai.NewChat().WithVerbose(true)).WithRetriever(retriever.New(simplevectorindex.New("db", ".", openaiembedder.New(openaiembedder.AdaEmbeddingV2)), &docs)).Query(context.Background(), "What is the NATO purpose?")
}

```

This is the _famous_ 2-lines **lingoose** knowledge base chatbot. 🤖
Expand Down
4 changes: 2 additions & 2 deletions examples/embeddings/simplekb/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ import (
)

func main() {
docs, _ := loader.NewPDFToTextLoader("./kb").WithPDFToTextPath("/opt/homebrew/bin/pdftotext").WithTextSplitter(textsplitter.NewRecursiveCharacterTextSplitter(2000, 200)).Load(context.Background())
qapipeline.New(openai.NewChat().WithVerbose(true)).WithDocuments(&docs).WithRetriever(retriever.New(simplevectorindex.New("db", ".", openaiembedder.New(openaiembedder.AdaEmbeddingV2)))).Query(context.Background(), "What is the NATO purpose?")
docs, _ := loader.NewPDFToTextLoader("./kb").WithTextSplitter(textsplitter.NewRecursiveCharacterTextSplitter(2000, 200)).Load(context.Background())
qapipeline.New(openai.NewChat().WithVerbose(true)).WithRetriever(retriever.New(simplevectorindex.New("db", ".", openaiembedder.New(openaiembedder.AdaEmbeddingV2)), &docs)).Query(context.Background(), "What is the NATO purpose?")
}
29 changes: 20 additions & 9 deletions index/retriever/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package retriever

import (
"context"
"fmt"

"github.com/henomis/lingoose/document"
"github.com/henomis/lingoose/index"
Expand All @@ -18,14 +19,17 @@ const (
)

type Retriever struct {
index Index
topK int
index Index
topK int
documents *[]document.Document
documentAreSet bool
}

func New(index Index) *Retriever {
func New(index Index, documents *[]document.Document) *Retriever {
return &Retriever{
index: index,
topK: defautTopK,
index: index,
topK: defautTopK,
documents: documents,
}
}

Expand All @@ -34,10 +38,17 @@ func (r *Retriever) WithTopK(topK int) *Retriever {
return r
}

func (r *Retriever) Query(ctx context.Context, query string, documents []document.Document) ([]document.Document, error) {
err := r.index.LoadFromDocuments(context.Background(), documents)
if err != nil {
return nil, err
func (r *Retriever) Query(ctx context.Context, query string) ([]document.Document, error) {

if r.documents == nil {
return nil, fmt.Errorf("documents are not defined")
}

if !r.documentAreSet {
err := r.index.LoadFromDocuments(context.Background(), *r.documents)
if err != nil {
return nil, err
}
}

results, err := r.index.Query(context.Background(), query, indexoption.WithTopK(r.topK))
Expand Down
14 changes: 2 additions & 12 deletions pipeline/qa/qa.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@ const (
)

type Retriever interface {
Query(context.Context, string, []document.Document) ([]document.Document, error)
Query(context.Context, string) ([]document.Document, error)
}

type QAPipeline struct {
llmEngine pipeline.LlmEngine
pipeline *pipeline.Pipeline
retriever Retriever
documents *[]document.Document
}

func New(llmEngine pipeline.LlmEngine) *QAPipeline {
Expand Down Expand Up @@ -79,21 +78,12 @@ func (p *QAPipeline) WithRetriever(retriever Retriever) *QAPipeline {
return p
}

func (p *QAPipeline) WithDocuments(documents *[]document.Document) *QAPipeline {
p.documents = documents
return p
}

func (q *QAPipeline) Query(ctx context.Context, query string) (types.M, error) {
if q.retriever == nil {
return nil, fmt.Errorf("retriever is not defined")
}

if q.documents == nil {
return nil, fmt.Errorf("documents are not defined")
}

docs, err := q.retriever.Query(ctx, query, *q.documents)
docs, err := q.retriever.Query(ctx, query)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 59ce0d1

Please sign in to comment.