Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix filtering and ordering indexes search results #135

Merged
merged 7 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/embeddings/simpleVector/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func main() {
}
}

query := "Describe within a paragraph what is the purpose of the NATO Alliance."
query := "What is the purpose of the NATO Alliance?"
similarities, err := docsVectorIndex.Query(
context.Background(),
query,
Expand Down Expand Up @@ -83,7 +83,7 @@ func ingestData(docsVectorIndex *simplevectorindex.Index, openaiEmbedder index.E
return err
}

textSplitter := textsplitter.NewRecursiveCharacterTextSplitter(2000, 100)
textSplitter := textsplitter.NewRecursiveCharacterTextSplitter(1000, 20)

documentChunks := textSplitter.SplitDocuments(documents)

Expand Down
15 changes: 0 additions & 15 deletions index/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package index
import (
"context"
"errors"
"sort"

"github.com/henomis/lingoose/document"
"github.com/henomis/lingoose/embedder"
Expand Down Expand Up @@ -58,20 +57,6 @@ type Embedder interface {
Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error)
}

func FilterSearchResults(searchResults SearchResults, topK int) SearchResults {
//sort by similarity score
sort.Slice(searchResults, func(i, j int) bool {
return searchResults[i].Score > searchResults[j].Score
})

maxTopK := topK
if maxTopK > len(searchResults) {
maxTopK = len(searchResults)
}

return searchResults[:maxTopK]
}

func DeepCopyMetadata(metadata types.Meta) types.Meta {
metadataCopy := make(types.Meta)
for k, v := range metadata {
Expand Down
12 changes: 6 additions & 6 deletions index/pinecone/pinecone.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Index struct {
namespace string
embedder index.Embedder
includeContent bool
includeValues bool
batchUpsertSize int

createIndex *CreateIndexOptions
Expand All @@ -44,6 +45,7 @@ type Options struct {
IndexName string
Namespace string
IncludeContent bool
IncludeValues bool
BatchUpsertSize *int
CreateIndex *CreateIndexOptions
}
Expand All @@ -65,6 +67,7 @@ func New(options Options, embedder index.Embedder) *Index {
embedder: embedder,
namespace: options.Namespace,
includeContent: options.IncludeContent,
includeValues: options.IncludeValues,
batchUpsertSize: batchUpsertSize,
createIndex: options.CreateIndex,
}
Expand Down Expand Up @@ -165,9 +168,7 @@ func (p *Index) Search(ctx context.Context, values []float64, opts ...option.Opt
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

searchResults := buildSearchResultsFromPineconeMatches(matches, p.includeContent)

return index.FilterSearchResults(searchResults, pineconeOptions.TopK), nil
return buildSearchResultsFromPineconeMatches(matches, p.includeContent), nil
}

func (p *Index) Query(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {
Expand All @@ -188,9 +189,7 @@ func (p *Index) Query(ctx context.Context, query string, opts ...option.Option)
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

searchResults := buildSearchResultsFromPineconeMatches(matches, p.includeContent)

return index.FilterSearchResults(searchResults, pineconeOptions.TopK), nil
return buildSearchResultsFromPineconeMatches(matches, p.includeContent), nil
}

func (p *Index) query(ctx context.Context, query string, opts *option.Options) ([]pineconeresponse.QueryMatch, error) {
Expand Down Expand Up @@ -222,6 +221,7 @@ func (p *Index) similaritySearch(
TopK: int32(opts.TopK),
Vector: values,
IncludeMetadata: &includeMetadata,
IncludeValues: &p.includeValues,
Namespace: &p.namespace,
Filter: opts.Filter.(map[string]string),
},
Expand Down
12 changes: 6 additions & 6 deletions index/qdrant/qdrant.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type Index struct {
collectionName string
embedder index.Embedder
includeContent bool
includeValues bool
batchUpsertSize int

createCollection *CreateCollectionOptions
Expand All @@ -47,6 +48,7 @@ type CreateCollectionOptions struct {
type Options struct {
CollectionName string
IncludeContent bool
IncludeValues bool
BatchUpsertSize *int
CreateCollection *CreateCollectionOptions
}
Expand All @@ -67,6 +69,7 @@ func New(options Options, embedder index.Embedder) *Index {
collectionName: options.CollectionName,
embedder: embedder,
includeContent: options.IncludeContent,
includeValues: options.IncludeValues,
batchUpsertSize: batchUpsertSize,
createCollection: options.CreateCollection,
}
Expand Down Expand Up @@ -150,9 +153,7 @@ func (q *Index) Search(ctx context.Context, values []float64, opts ...option.Opt
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

searchResults := buildSearchResultsFromQdrantMatches(matches, q.includeContent)

return index.FilterSearchResults(searchResults, qdrantOptions.TopK), nil
return buildSearchResultsFromQdrantMatches(matches, q.includeContent), nil
}

func (q *Index) Query(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {
Expand All @@ -169,9 +170,7 @@ func (q *Index) Query(ctx context.Context, query string, opts ...option.Option)
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

searchResults := buildSearchResultsFromQdrantMatches(matches, q.includeContent)

return index.FilterSearchResults(searchResults, qdrantOptions.TopK), nil
return buildSearchResultsFromQdrantMatches(matches, q.includeContent), nil
}

func (q *Index) similaritySearch(
Expand All @@ -192,6 +191,7 @@ func (q *Index) similaritySearch(
Limit: opts.TopK,
Vector: values,
WithPayload: &includeMetadata,
WithVector: &q.includeValues,
Filter: opts.Filter.(qdrantrequest.Filter),
},
res,
Expand Down
74 changes: 55 additions & 19 deletions index/simpleVectorIndex/simpleVectorIndex.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package simplevectorindex
import (
"context"
"encoding/json"
"errors"
"fmt"
"math"
"os"
"sort"
"strings"

"github.com/google/uuid"
Expand Down Expand Up @@ -211,7 +213,10 @@ func (s *Index) similaritySearch(
opts *option.Options,
) (index.SearchResults, error) {
_ = ctx
scores := s.cosineSimilarityBatch(embedding)
scores, err := s.cosineSimilarityBatch(embedding)
if err != nil {
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

searchResults := make([]index.SearchResult, len(scores))

Expand All @@ -230,33 +235,64 @@ func (s *Index) similaritySearch(
searchResults = opts.Filter.(FilterFn)(searchResults)
}

return index.FilterSearchResults(searchResults, opts.TopK), nil
return filterSearchResults(searchResults, opts.TopK), nil
}

func (s *Index) cosineSimilarity(a embedder.Embedding, b embedder.Embedding) float64 {
dotProduct := float64(0.0)
normA := float64(0.0)
normB := float64(0.0)

for i := 0; i < len(a); i++ {
dotProduct += a[i] * b[i]
normA += a[i] * a[i]
normB += b[i] * b[i]
func (s *Index) cosineSimilarity(a []float64, b []float64) (cosine float64, err error) {
var count int
lengthA := len(a)
lengthB := len(b)
if lengthA > lengthB {
count = lengthA
} else {
count = lengthB
}

if normA == 0 || normB == 0 {
return float64(0.0)
sumA := 0.0
s1 := 0.0
s2 := 0.0
for k := 0; k < count; k++ {
if k >= lengthA {
s2 += math.Pow(b[k], 2)
continue
}
if k >= lengthB {
s1 += math.Pow(a[k], 2)
continue
}
sumA += a[k] * b[k]
s1 += math.Pow(a[k], 2)
s2 += math.Pow(b[k], 2)
}

return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
if s1 == 0 || s2 == 0 {
return 0.0, errors.New("vectors should not be null (all zeros)")
}
return sumA / (math.Sqrt(s1) * math.Sqrt(s2)), nil
}

func (s *Index) cosineSimilarityBatch(a embedder.Embedding) []float64 {
func (s *Index) cosineSimilarityBatch(a embedder.Embedding) ([]float64, error) {
var err error
scores := make([]float64, len(s.data))

for i := range s.data {
scores[i] = s.cosineSimilarity(a, s.data[i].Values)
scores[i], err = s.cosineSimilarity(a, s.data[i].Values)
if err != nil {
return nil, err
}
}

return scores, nil
}

func filterSearchResults(searchResults index.SearchResults, topK int) index.SearchResults {
//sort by similarity score
sort.Slice(searchResults, func(i, j int) bool {
return (1 - searchResults[i].Score) < (1 - searchResults[j].Score)
})

maxTopK := topK
if maxTopK > len(searchResults) {
maxTopK = len(searchResults)
}

return scores
return searchResults[:maxTopK]
}