diff --git a/examples/embeddings/simpleVector/main.go b/examples/embeddings/simpleVector/main.go index 96b4b996..5270def4 100644 --- a/examples/embeddings/simpleVector/main.go +++ b/examples/embeddings/simpleVector/main.go @@ -30,7 +30,7 @@ func main() { } } - query := "Describe within a paragraph what is the purpose of the NATO Alliance." + query := "What is the purpose of the NATO Alliance?" similarities, err := docsVectorIndex.Query( context.Background(), query, @@ -83,7 +83,7 @@ func ingestData(docsVectorIndex *simplevectorindex.Index, openaiEmbedder index.E return err } - textSplitter := textsplitter.NewRecursiveCharacterTextSplitter(2000, 100) + textSplitter := textsplitter.NewRecursiveCharacterTextSplitter(1000, 20) documentChunks := textSplitter.SplitDocuments(documents) diff --git a/index/index.go b/index/index.go index 3778ef4d..068e615c 100644 --- a/index/index.go +++ b/index/index.go @@ -3,7 +3,6 @@ package index import ( "context" "errors" - "sort" "github.com/henomis/lingoose/document" "github.com/henomis/lingoose/embedder" @@ -58,20 +57,6 @@ type Embedder interface { Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) } -func FilterSearchResults(searchResults SearchResults, topK int) SearchResults { - //sort by similarity score - sort.Slice(searchResults, func(i, j int) bool { - return searchResults[i].Score > searchResults[j].Score - }) - - maxTopK := topK - if maxTopK > len(searchResults) { - maxTopK = len(searchResults) - } - - return searchResults[:maxTopK] -} - func DeepCopyMetadata(metadata types.Meta) types.Meta { metadataCopy := make(types.Meta) for k, v := range metadata { diff --git a/index/pinecone/pinecone.go b/index/pinecone/pinecone.go index 27137fd1..41405dd2 100644 --- a/index/pinecone/pinecone.go +++ b/index/pinecone/pinecone.go @@ -28,6 +28,7 @@ type Index struct { namespace string embedder index.Embedder includeContent bool + includeValues bool batchUpsertSize int createIndex *CreateIndexOptions @@ -44,6 +45,7 @@ type Options struct { IndexName string Namespace string IncludeContent bool + IncludeValues bool BatchUpsertSize *int CreateIndex *CreateIndexOptions } @@ -65,6 +67,7 @@ func New(options Options, embedder index.Embedder) *Index { embedder: embedder, namespace: options.Namespace, includeContent: options.IncludeContent, + includeValues: options.IncludeValues, batchUpsertSize: batchUpsertSize, createIndex: options.CreateIndex, } @@ -165,9 +168,7 @@ func (p *Index) Search(ctx context.Context, values []float64, opts ...option.Opt return nil, fmt.Errorf("%w: %w", index.ErrInternal, err) } - searchResults := buildSearchResultsFromPineconeMatches(matches, p.includeContent) - - return index.FilterSearchResults(searchResults, pineconeOptions.TopK), nil + return buildSearchResultsFromPineconeMatches(matches, p.includeContent), nil } func (p *Index) Query(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) { @@ -188,9 +189,7 @@ func (p *Index) Query(ctx context.Context, query string, opts ...option.Option) return nil, fmt.Errorf("%w: %w", index.ErrInternal, err) } - searchResults := buildSearchResultsFromPineconeMatches(matches, p.includeContent) - - return index.FilterSearchResults(searchResults, pineconeOptions.TopK), nil + return buildSearchResultsFromPineconeMatches(matches, p.includeContent), nil } func (p *Index) query(ctx context.Context, query string, opts *option.Options) ([]pineconeresponse.QueryMatch, error) { @@ -222,6 +221,7 @@ func (p *Index) similaritySearch( TopK: int32(opts.TopK), Vector: values, IncludeMetadata: &includeMetadata, + IncludeValues: &p.includeValues, Namespace: &p.namespace, Filter: opts.Filter.(map[string]string), }, diff --git a/index/qdrant/qdrant.go b/index/qdrant/qdrant.go index a39ccecc..81ccf6fe 100644 --- a/index/qdrant/qdrant.go +++ b/index/qdrant/qdrant.go @@ -25,6 +25,7 @@ type Index struct { collectionName string embedder index.Embedder includeContent bool + includeValues bool batchUpsertSize int createCollection *CreateCollectionOptions @@ -47,6 +48,7 @@ type CreateCollectionOptions struct { type Options struct { CollectionName string IncludeContent bool + IncludeValues bool BatchUpsertSize *int CreateCollection *CreateCollectionOptions } @@ -67,6 +69,7 @@ func New(options Options, embedder index.Embedder) *Index { collectionName: options.CollectionName, embedder: embedder, includeContent: options.IncludeContent, + includeValues: options.IncludeValues, batchUpsertSize: batchUpsertSize, createCollection: options.CreateCollection, } @@ -150,9 +153,7 @@ func (q *Index) Search(ctx context.Context, values []float64, opts ...option.Opt return nil, fmt.Errorf("%w: %w", index.ErrInternal, err) } - searchResults := buildSearchResultsFromQdrantMatches(matches, q.includeContent) - - return index.FilterSearchResults(searchResults, qdrantOptions.TopK), nil + return buildSearchResultsFromQdrantMatches(matches, q.includeContent), nil } func (q *Index) Query(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) { @@ -169,9 +170,7 @@ func (q *Index) Query(ctx context.Context, query string, opts ...option.Option) return nil, fmt.Errorf("%w: %w", index.ErrInternal, err) } - searchResults := buildSearchResultsFromQdrantMatches(matches, q.includeContent) - - return index.FilterSearchResults(searchResults, qdrantOptions.TopK), nil + return buildSearchResultsFromQdrantMatches(matches, q.includeContent), nil } func (q *Index) similaritySearch( @@ -192,6 +191,7 @@ func (q *Index) similaritySearch( Limit: opts.TopK, Vector: values, WithPayload: &includeMetadata, + WithVector: &q.includeValues, Filter: opts.Filter.(qdrantrequest.Filter), }, res, diff --git a/index/simpleVectorIndex/simpleVectorIndex.go b/index/simpleVectorIndex/simpleVectorIndex.go index 2ebad31a..203104b9 100644 --- a/index/simpleVectorIndex/simpleVectorIndex.go +++ b/index/simpleVectorIndex/simpleVectorIndex.go @@ -3,9 +3,11 @@ package simplevectorindex import ( "context" "encoding/json" + "errors" "fmt" "math" "os" + "sort" "strings" "github.com/google/uuid" @@ -211,7 +213,10 @@ func (s *Index) similaritySearch( opts *option.Options, ) (index.SearchResults, error) { _ = ctx - scores := s.cosineSimilarityBatch(embedding) + scores, err := s.cosineSimilarityBatch(embedding) + if err != nil { + return nil, fmt.Errorf("%w: %w", index.ErrInternal, err) + } searchResults := make([]index.SearchResult, len(scores)) @@ -230,33 +235,64 @@ func (s *Index) similaritySearch( searchResults = opts.Filter.(FilterFn)(searchResults) } - return index.FilterSearchResults(searchResults, opts.TopK), nil + return filterSearchResults(searchResults, opts.TopK), nil } -func (s *Index) cosineSimilarity(a embedder.Embedding, b embedder.Embedding) float64 { - dotProduct := float64(0.0) - normA := float64(0.0) - normB := float64(0.0) - - for i := 0; i < len(a); i++ { - dotProduct += a[i] * b[i] - normA += a[i] * a[i] - normB += b[i] * b[i] +func (s *Index) cosineSimilarity(a []float64, b []float64) (cosine float64, err error) { + var count int + lengthA := len(a) + lengthB := len(b) + if lengthA > lengthB { + count = lengthA + } else { + count = lengthB } - - if normA == 0 || normB == 0 { - return float64(0.0) + sumA := 0.0 + s1 := 0.0 + s2 := 0.0 + for k := 0; k < count; k++ { + if k >= lengthA { + s2 += math.Pow(b[k], 2) + continue + } + if k >= lengthB { + s1 += math.Pow(a[k], 2) + continue + } + sumA += a[k] * b[k] + s1 += math.Pow(a[k], 2) + s2 += math.Pow(b[k], 2) } - - return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)) + if s1 == 0 || s2 == 0 { + return 0.0, errors.New("vectors should not be null (all zeros)") + } + return sumA / (math.Sqrt(s1) * math.Sqrt(s2)), nil } -func (s *Index) cosineSimilarityBatch(a embedder.Embedding) []float64 { +func (s *Index) cosineSimilarityBatch(a embedder.Embedding) ([]float64, error) { + var err error scores := make([]float64, len(s.data)) for i := range s.data { - scores[i] = s.cosineSimilarity(a, s.data[i].Values) + scores[i], err = s.cosineSimilarity(a, s.data[i].Values) + if err != nil { + return nil, err + } + } + + return scores, nil +} + +func filterSearchResults(searchResults index.SearchResults, topK int) index.SearchResults { + //sort by similarity score + sort.Slice(searchResults, func(i, j int) bool { + return (1 - searchResults[i].Score) < (1 - searchResults[j].Score) + }) + + maxTopK := topK + if maxTopK > len(searchResults) { + maxTopK = len(searchResults) } - return scores + return searchResults[:maxTopK] }