diff --git a/README.md b/README.md index 66f6d18e..c8ad1264 100644 --- a/README.md +++ b/README.md @@ -71,10 +71,9 @@ func main() { docs, _ := loader.NewPDFToTextLoader("./kb").WithTextSplitter(textsplitter.NewRecursiveCharacterTextSplitter(2000, 200)).Load(context.Background()) openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2) simplevectorindex.New("db", ".", openaiEmbedder).LoadFromDocuments(context.Background(), docs) - similarities, _ := simplevectorindex.New("db", ".", openaiEmbedder).SimilaritySearch(context.Background(), query, indexoption.WithTopK(3)) - qapipeline.New(openai.NewChat().WithVerbose(true)).Run(context.Background(), query, similarities.ToDocuments()) + results, _ := simplevectorindex.New("db", ".", openaiEmbedder).Query(context.Background(), query, indexoption.WithTopK(3)) + qapipeline.New(openai.NewChat().WithVerbose(true)).Run(context.Background(), query, results.ToDocuments()) } - ``` This is the _famous_ 6-lines **lingoose** knowledge base chatbot. 🤖 diff --git a/examples/embeddings/knowledge_base/main.go b/examples/embeddings/knowledge_base/main.go index 425d814d..d3623ac6 100644 --- a/examples/embeddings/knowledge_base/main.go +++ b/examples/embeddings/knowledge_base/main.go @@ -49,7 +49,7 @@ func main() { break } - similarities, err := docsVectorIndex.SimilaritySearch(context.Background(), query, indexoption.WithTopK(3)) + similarities, err := docsVectorIndex.Query(context.Background(), query, indexoption.WithTopK(3)) if err != nil { panic(err) } diff --git a/examples/embeddings/pinecone/main.go b/examples/embeddings/pinecone/main.go index 8250cef3..2cd7fb15 100644 --- a/examples/embeddings/pinecone/main.go +++ b/examples/embeddings/pinecone/main.go @@ -47,7 +47,7 @@ func main() { } query := "What is the purpose of the NATO Alliance?" - similarities, err := pineconeIndex.SimilaritySearch( + similarities, err := pineconeIndex.Query( context.Background(), query, indexoption.WithTopK(3), diff --git a/examples/embeddings/qdrant/main.go b/examples/embeddings/qdrant/main.go index b4e64d26..48b07da5 100644 --- a/examples/embeddings/qdrant/main.go +++ b/examples/embeddings/qdrant/main.go @@ -45,7 +45,7 @@ func main() { } query := "What is the purpose of the NATO Alliance?" - similarities, err := qdrantIndex.SimilaritySearch( + similarities, err := qdrantIndex.Query( context.Background(), query, indexoption.WithTopK(3), diff --git a/examples/embeddings/simpleVector/main.go b/examples/embeddings/simpleVector/main.go index c103aae3..96b4b996 100644 --- a/examples/embeddings/simpleVector/main.go +++ b/examples/embeddings/simpleVector/main.go @@ -31,7 +31,7 @@ func main() { } query := "Describe within a paragraph what is the purpose of the NATO Alliance." - similarities, err := docsVectorIndex.SimilaritySearch( + similarities, err := docsVectorIndex.Query( context.Background(), query, indexoption.WithTopK(3), diff --git a/examples/embeddings/simplekb/main.go b/examples/embeddings/simplekb/main.go index 82668a82..73b9a8f4 100644 --- a/examples/embeddings/simplekb/main.go +++ b/examples/embeddings/simplekb/main.go @@ -17,6 +17,6 @@ func main() { docs, _ := loader.NewPDFToTextLoader("./kb").WithTextSplitter(textsplitter.NewRecursiveCharacterTextSplitter(2000, 200)).Load(context.Background()) openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2) simplevectorindex.New("db", ".", openaiEmbedder).LoadFromDocuments(context.Background(), docs) - similarities, _ := simplevectorindex.New("db", ".", openaiEmbedder).SimilaritySearch(context.Background(), query, indexoption.WithTopK(3)) - qapipeline.New(openai.NewChat().WithVerbose(true)).Run(context.Background(), query, similarities.ToDocuments()) + results, _ := simplevectorindex.New("db", ".", openaiEmbedder).Query(context.Background(), query, indexoption.WithTopK(3)) + qapipeline.New(openai.NewChat().WithVerbose(true)).Run(context.Background(), query, results.ToDocuments()) } diff --git a/index/index.go b/index/index.go index a5ca4b8d..0b2e817b 100644 --- a/index/index.go +++ b/index/index.go @@ -18,11 +18,15 @@ const ( DefaultKeyContent = "content" ) -type SearchResult struct { +type Data struct { ID string Values []float64 Metadata types.Meta - Score float64 +} + +type SearchResult struct { + Data + Score float64 } func (s *SearchResult) Content() string { diff --git a/index/pinecone/pinecone.go b/index/pinecone/pinecone.go index edeea294..d29f6387 100644 --- a/index/pinecone/pinecone.go +++ b/index/pinecone/pinecone.go @@ -126,7 +126,7 @@ func (p *Index) IsEmpty(ctx context.Context) (bool, error) { } -func (p *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) { +func (p *Index) Search(ctx context.Context, values []float64, opts ...option.Option) (index.SearchResults, error) { pineconeOptions := &option.Options{ TopK: defaultTopK, @@ -140,7 +140,7 @@ func (p *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti pineconeOptions.Filter = map[string]string{} } - matches, err := p.similaritySearch(ctx, query, pineconeOptions) + matches, err := p.similaritySearch(ctx, values, pineconeOptions) if err != nil { return nil, fmt.Errorf("%s: %w", index.ErrInternal, err) } @@ -150,18 +150,46 @@ func (p *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti return index.FilterSearchResults(searchResults, pineconeOptions.TopK), nil } -func (p *Index) similaritySearch(ctx context.Context, query string, opts *option.Options) ([]pineconeresponse.QueryMatch, error) { +func (p *Index) Query(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) { - err := p.getProjectID(ctx) + pineconeOptions := &option.Options{ + TopK: defaultTopK, + } + + for _, opt := range opts { + opt(pineconeOptions) + } + + if pineconeOptions.Filter == nil { + pineconeOptions.Filter = map[string]string{} + } + + matches, err := p.query(ctx, query, pineconeOptions) if err != nil { return nil, fmt.Errorf("%s: %w", index.ErrInternal, err) } + searchResults := buildSearchResultsFromPineconeMatches(matches, p.includeContent) + + return index.FilterSearchResults(searchResults, pineconeOptions.TopK), nil +} + +func (p *Index) query(ctx context.Context, query string, opts *option.Options) ([]pineconeresponse.QueryMatch, error) { + embeddings, err := p.embedder.Embed(ctx, []string{query}) if err != nil { return nil, err } + return p.similaritySearch(ctx, embeddings[0], opts) +} + +func (p *Index) similaritySearch(ctx context.Context, values []float64, opts *option.Options) ([]pineconeresponse.QueryMatch, error) { + err := p.getProjectID(ctx) + if err != nil { + return nil, fmt.Errorf("%s: %w", index.ErrInternal, err) + } + includeMetadata := true res := &pineconeresponse.VectorQuery{} err = p.pineconeClient.VectorQuery( @@ -170,7 +198,7 @@ func (p *Index) similaritySearch(ctx context.Context, query string, opts *option IndexName: p.indexName, ProjectID: *p.projectID, TopK: int32(opts.TopK), - Vector: embeddings[0], + Vector: values, IncludeMetadata: &includeMetadata, Namespace: &p.namespace, Filter: opts.Filter.(map[string]string), @@ -374,10 +402,12 @@ func buildSearchResultsFromPineconeMatches(matches []pineconeresponse.QueryMatch } searchResults[i] = index.SearchResult{ - ID: id, - Metadata: metadata, - Values: match.Values, - Score: score, + Data: index.Data{ + ID: id, + Metadata: metadata, + Values: match.Values, + }, + Score: score, } } diff --git a/index/qdrant/qdrant.go b/index/qdrant/qdrant.go index ccb1f2de..b5d26728 100644 --- a/index/qdrant/qdrant.go +++ b/index/qdrant/qdrant.go @@ -115,8 +115,7 @@ func (p *Index) IsEmpty(ctx context.Context) (bool, error) { } -func (q *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) { - +func (q *Index) Search(ctx context.Context, values []float64, opts ...option.Option) (index.SearchResults, error) { qdrantOptions := &option.Options{ TopK: defaultTopK, } @@ -125,7 +124,7 @@ func (q *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti opt(qdrantOptions) } - matches, err := q.similaritySearch(ctx, query, qdrantOptions) + matches, err := q.similaritySearch(ctx, values, qdrantOptions) if err != nil { return nil, fmt.Errorf("%s: %w", index.ErrInternal, err) } @@ -135,25 +134,40 @@ func (q *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti return index.FilterSearchResults(searchResults, qdrantOptions.TopK), nil } -func (p *Index) similaritySearch(ctx context.Context, query string, opts *option.Options) ([]qdrantresponse.PointSearchResult, error) { +func (q *Index) Query(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) { + + qdrantOptions := &option.Options{ + TopK: defaultTopK, + } + + for _, opt := range opts { + opt(qdrantOptions) + } - embeddings, err := p.embedder.Embed(ctx, []string{query}) + matches, err := q.query(ctx, query, qdrantOptions) if err != nil { - return nil, err + return nil, fmt.Errorf("%s: %w", index.ErrInternal, err) } + searchResults := buildSearchResultsFromQdrantMatches(matches, q.includeContent) + + return index.FilterSearchResults(searchResults, qdrantOptions.TopK), nil +} + +func (q *Index) similaritySearch(ctx context.Context, values []float64, opts *option.Options) ([]qdrantresponse.PointSearchResult, error) { + if opts.Filter == nil { opts.Filter = qdrantrequest.Filter{} } includeMetadata := true res := &qdrantresponse.PointSearch{} - err = p.qdrantClient.PointSearch( + err := q.qdrantClient.PointSearch( ctx, &qdrantrequest.PointSearch{ - CollectionName: p.collectionName, + CollectionName: q.collectionName, Limit: opts.TopK, - Vector: embeddings[0], + Vector: values, WithPayload: &includeMetadata, Filter: opts.Filter.(qdrantrequest.Filter), }, @@ -166,6 +180,15 @@ func (p *Index) similaritySearch(ctx context.Context, query string, opts *option return res.Result, nil } +func (q *Index) query(ctx context.Context, query string, opts *option.Options) ([]qdrantresponse.PointSearchResult, error) { + embeddings, err := q.embedder.Embed(ctx, []string{query}) + if err != nil { + return nil, err + } + + return q.similaritySearch(ctx, embeddings[0], opts) +} + func (q *Index) createCollectionIfRequired(ctx context.Context) error { if q.createCollection == nil { @@ -299,10 +322,12 @@ func buildSearchResultsFromQdrantMatches(matches []qdrantresponse.PointSearchRes } searchResults[i] = index.SearchResult{ - ID: match.ID, - Metadata: metadata, - Values: match.Vector, - Score: match.Score, + Data: index.Data{ + ID: match.ID, + Metadata: metadata, + Values: match.Vector, + }, + Score: match.Score, } } diff --git a/index/simpleVectorIndex/simpleVectorIndex.go b/index/simpleVectorIndex/simpleVectorIndex.go index 1e73ecc2..e068023c 100644 --- a/index/simpleVectorIndex/simpleVectorIndex.go +++ b/index/simpleVectorIndex/simpleVectorIndex.go @@ -140,7 +140,24 @@ func (s *Index) IsEmpty() (bool, error) { return len(s.data) == 0, nil } -func (s *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) { +func (s *Index) Search(ctx context.Context, values []float64, opts ...option.Option) (index.SearchResults, error) { + sviOptions := &option.Options{ + TopK: defaultTopK, + } + + for _, opt := range opts { + opt(sviOptions) + } + + err := s.load() + if err != nil { + return nil, fmt.Errorf("%s: %w", index.ErrInternal, err) + } + + return s.similaritySearch(ctx, values, sviOptions) +} + +func (s *Index) Query(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) { sviOptions := &option.Options{ TopK: defaultTopK, @@ -160,24 +177,35 @@ func (s *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti return nil, fmt.Errorf("%s: %w", index.ErrInternal, err) } - scores := s.cosineSimilarityBatch(embeddings[0]) + return s.similaritySearch(ctx, embeddings[0], sviOptions) +} + +func (s *Index) similaritySearch( + ctx context.Context, + embedding embedder.Embedding, + opts *option.Options, +) (index.SearchResults, error) { + + scores := s.cosineSimilarityBatch(embedding) searchResults := make([]index.SearchResult, len(scores)) for i, score := range scores { searchResults[i] = index.SearchResult{ - ID: s.data[i].ID, - Values: s.data[i].Values, - Metadata: s.data[i].Metadata, - Score: score, + Data: index.Data{ + ID: s.data[i].ID, + Values: s.data[i].Values, + Metadata: s.data[i].Metadata, + }, + Score: score, } } - if sviOptions.Filter != nil { - searchResults = sviOptions.Filter.(SimpleVectorIndexFilterFn)(searchResults) + if opts.Filter != nil { + searchResults = opts.Filter.(SimpleVectorIndexFilterFn)(searchResults) } - return index.FilterSearchResults(searchResults, sviOptions.TopK), nil + return index.FilterSearchResults(searchResults, opts.TopK), nil } func (s *Index) cosineSimilarity(a embedder.Embedding, b embedder.Embedding) float64 {