Skip to content

Commit

Permalink
Chore/116 refactor indexes Query and Search (#128)
Browse files Browse the repository at this point in the history
* chore: add Data structure

* chore: refactor Query and Search

* chore: update examples

* fix
  • Loading branch information
henomis authored Sep 13, 2023
1 parent be60c13 commit f97cc8b
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 42 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,9 @@ func main() {
docs, _ := loader.NewPDFToTextLoader("./kb").WithTextSplitter(textsplitter.NewRecursiveCharacterTextSplitter(2000, 200)).Load(context.Background())
openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2)
simplevectorindex.New("db", ".", openaiEmbedder).LoadFromDocuments(context.Background(), docs)
similarities, _ := simplevectorindex.New("db", ".", openaiEmbedder).SimilaritySearch(context.Background(), query, indexoption.WithTopK(3))
qapipeline.New(openai.NewChat().WithVerbose(true)).Run(context.Background(), query, similarities.ToDocuments())
results, _ := simplevectorindex.New("db", ".", openaiEmbedder).Query(context.Background(), query, indexoption.WithTopK(3))
qapipeline.New(openai.NewChat().WithVerbose(true)).Run(context.Background(), query, results.ToDocuments())
}

```

This is the _famous_ 6-line **lingoose** knowledge base chatbot. 🤖
Expand Down
2 changes: 1 addition & 1 deletion examples/embeddings/knowledge_base/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func main() {
break
}

similarities, err := docsVectorIndex.SimilaritySearch(context.Background(), query, indexoption.WithTopK(3))
similarities, err := docsVectorIndex.Query(context.Background(), query, indexoption.WithTopK(3))
if err != nil {
panic(err)
}
Expand Down
2 changes: 1 addition & 1 deletion examples/embeddings/pinecone/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func main() {
}

query := "What is the purpose of the NATO Alliance?"
similarities, err := pineconeIndex.SimilaritySearch(
similarities, err := pineconeIndex.Query(
context.Background(),
query,
indexoption.WithTopK(3),
Expand Down
2 changes: 1 addition & 1 deletion examples/embeddings/qdrant/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func main() {
}

query := "What is the purpose of the NATO Alliance?"
similarities, err := qdrantIndex.SimilaritySearch(
similarities, err := qdrantIndex.Query(
context.Background(),
query,
indexoption.WithTopK(3),
Expand Down
2 changes: 1 addition & 1 deletion examples/embeddings/simpleVector/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func main() {
}

query := "Describe within a paragraph what is the purpose of the NATO Alliance."
similarities, err := docsVectorIndex.SimilaritySearch(
similarities, err := docsVectorIndex.Query(
context.Background(),
query,
indexoption.WithTopK(3),
Expand Down
4 changes: 2 additions & 2 deletions examples/embeddings/simplekb/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ func main() {
docs, _ := loader.NewPDFToTextLoader("./kb").WithTextSplitter(textsplitter.NewRecursiveCharacterTextSplitter(2000, 200)).Load(context.Background())
openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2)
simplevectorindex.New("db", ".", openaiEmbedder).LoadFromDocuments(context.Background(), docs)
similarities, _ := simplevectorindex.New("db", ".", openaiEmbedder).SimilaritySearch(context.Background(), query, indexoption.WithTopK(3))
qapipeline.New(openai.NewChat().WithVerbose(true)).Run(context.Background(), query, similarities.ToDocuments())
results, _ := simplevectorindex.New("db", ".", openaiEmbedder).Query(context.Background(), query, indexoption.WithTopK(3))
qapipeline.New(openai.NewChat().WithVerbose(true)).Run(context.Background(), query, results.ToDocuments())
}
8 changes: 6 additions & 2 deletions index/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ const (
DefaultKeyContent = "content"
)

type SearchResult struct {
type Data struct {
ID string
Values []float64
Metadata types.Meta
Score float64
}

// SearchResult is a single entry returned by an index search: the stored
// Data together with the score the index assigned to it.
type SearchResult struct {
// Data is embedded, so its fields are accessible directly on SearchResult.
Data
// Score is the score computed for this entry by the index that produced it.
Score float64
}

func (s *SearchResult) Content() string {
Expand Down
48 changes: 39 additions & 9 deletions index/pinecone/pinecone.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (p *Index) IsEmpty(ctx context.Context) (bool, error) {

}

func (p *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {
func (p *Index) Search(ctx context.Context, values []float64, opts ...option.Option) (index.SearchResults, error) {

pineconeOptions := &option.Options{
TopK: defaultTopK,
Expand All @@ -140,7 +140,7 @@ func (p *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti
pineconeOptions.Filter = map[string]string{}
}

matches, err := p.similaritySearch(ctx, query, pineconeOptions)
matches, err := p.similaritySearch(ctx, values, pineconeOptions)
if err != nil {
return nil, fmt.Errorf("%s: %w", index.ErrInternal, err)
}
Expand All @@ -150,18 +150,46 @@ func (p *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti
return index.FilterSearchResults(searchResults, pineconeOptions.TopK), nil
}

func (p *Index) similaritySearch(ctx context.Context, query string, opts *option.Options) ([]pineconeresponse.QueryMatch, error) {
func (p *Index) Query(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {

err := p.getProjectID(ctx)
pineconeOptions := &option.Options{
TopK: defaultTopK,
}

for _, opt := range opts {
opt(pineconeOptions)
}

if pineconeOptions.Filter == nil {
pineconeOptions.Filter = map[string]string{}
}

matches, err := p.query(ctx, query, pineconeOptions)
if err != nil {
return nil, fmt.Errorf("%s: %w", index.ErrInternal, err)
}

searchResults := buildSearchResultsFromPineconeMatches(matches, p.includeContent)

return index.FilterSearchResults(searchResults, pineconeOptions.TopK), nil
}

// query embeds the text query with the index's embedder and runs a
// similarity search with the resulting vector.
//
// An error from the embedder is returned unchanged. If the embedder succeeds
// but yields no embeddings, an error is returned rather than panicking on the
// out-of-range access to the first embedding.
func (p *Index) query(ctx context.Context, query string, opts *option.Options) ([]pineconeresponse.QueryMatch, error) {
	embeddings, err := p.embedder.Embed(ctx, []string{query})
	if err != nil {
		return nil, err
	}

	// A single input text is expected to produce a single embedding; guard
	// the index below in case the embedder returns an empty result.
	if len(embeddings) == 0 {
		return nil, fmt.Errorf("embedder returned no embeddings for query")
	}

	return p.similaritySearch(ctx, embeddings[0], opts)
}

func (p *Index) similaritySearch(ctx context.Context, values []float64, opts *option.Options) ([]pineconeresponse.QueryMatch, error) {
err := p.getProjectID(ctx)
if err != nil {
return nil, fmt.Errorf("%s: %w", index.ErrInternal, err)
}

includeMetadata := true
res := &pineconeresponse.VectorQuery{}
err = p.pineconeClient.VectorQuery(
Expand All @@ -170,7 +198,7 @@ func (p *Index) similaritySearch(ctx context.Context, query string, opts *option
IndexName: p.indexName,
ProjectID: *p.projectID,
TopK: int32(opts.TopK),
Vector: embeddings[0],
Vector: values,
IncludeMetadata: &includeMetadata,
Namespace: &p.namespace,
Filter: opts.Filter.(map[string]string),
Expand Down Expand Up @@ -374,10 +402,12 @@ func buildSearchResultsFromPineconeMatches(matches []pineconeresponse.QueryMatch
}

searchResults[i] = index.SearchResult{
ID: id,
Metadata: metadata,
Values: match.Values,
Score: score,
Data: index.Data{
ID: id,
Metadata: metadata,
Values: match.Values,
},
Score: score,
}
}

Expand Down
51 changes: 38 additions & 13 deletions index/qdrant/qdrant.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ func (p *Index) IsEmpty(ctx context.Context) (bool, error) {

}

func (q *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {

func (q *Index) Search(ctx context.Context, values []float64, opts ...option.Option) (index.SearchResults, error) {
qdrantOptions := &option.Options{
TopK: defaultTopK,
}
Expand All @@ -125,7 +124,7 @@ func (q *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti
opt(qdrantOptions)
}

matches, err := q.similaritySearch(ctx, query, qdrantOptions)
matches, err := q.similaritySearch(ctx, values, qdrantOptions)
if err != nil {
return nil, fmt.Errorf("%s: %w", index.ErrInternal, err)
}
Expand All @@ -135,25 +134,40 @@ func (q *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti
return index.FilterSearchResults(searchResults, qdrantOptions.TopK), nil
}

func (p *Index) similaritySearch(ctx context.Context, query string, opts *option.Options) ([]qdrantresponse.PointSearchResult, error) {
func (q *Index) Query(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {

qdrantOptions := &option.Options{
TopK: defaultTopK,
}

for _, opt := range opts {
opt(qdrantOptions)
}

embeddings, err := p.embedder.Embed(ctx, []string{query})
matches, err := q.query(ctx, query, qdrantOptions)
if err != nil {
return nil, err
return nil, fmt.Errorf("%s: %w", index.ErrInternal, err)
}

searchResults := buildSearchResultsFromQdrantMatches(matches, q.includeContent)

return index.FilterSearchResults(searchResults, qdrantOptions.TopK), nil
}

func (q *Index) similaritySearch(ctx context.Context, values []float64, opts *option.Options) ([]qdrantresponse.PointSearchResult, error) {

if opts.Filter == nil {
opts.Filter = qdrantrequest.Filter{}
}

includeMetadata := true
res := &qdrantresponse.PointSearch{}
err = p.qdrantClient.PointSearch(
err := q.qdrantClient.PointSearch(
ctx,
&qdrantrequest.PointSearch{
CollectionName: p.collectionName,
CollectionName: q.collectionName,
Limit: opts.TopK,
Vector: embeddings[0],
Vector: values,
WithPayload: &includeMetadata,
Filter: opts.Filter.(qdrantrequest.Filter),
},
Expand All @@ -166,6 +180,15 @@ func (p *Index) similaritySearch(ctx context.Context, query string, opts *option
return res.Result, nil
}

// query embeds the text query with the index's embedder and runs a
// similarity search with the resulting vector.
//
// An error from the embedder is returned unchanged. If the embedder succeeds
// but yields no embeddings, an error is returned rather than panicking on the
// out-of-range access to the first embedding.
func (q *Index) query(ctx context.Context, query string, opts *option.Options) ([]qdrantresponse.PointSearchResult, error) {
	embeddings, err := q.embedder.Embed(ctx, []string{query})
	if err != nil {
		return nil, err
	}

	// A single input text is expected to produce a single embedding; guard
	// the index below in case the embedder returns an empty result.
	if len(embeddings) == 0 {
		return nil, fmt.Errorf("embedder returned no embeddings for query")
	}

	return q.similaritySearch(ctx, embeddings[0], opts)
}

func (q *Index) createCollectionIfRequired(ctx context.Context) error {

if q.createCollection == nil {
Expand Down Expand Up @@ -299,10 +322,12 @@ func buildSearchResultsFromQdrantMatches(matches []qdrantresponse.PointSearchRes
}

searchResults[i] = index.SearchResult{
ID: match.ID,
Metadata: metadata,
Values: match.Vector,
Score: match.Score,
Data: index.Data{
ID: match.ID,
Metadata: metadata,
Values: match.Vector,
},
Score: match.Score,
}
}

Expand Down
46 changes: 37 additions & 9 deletions index/simpleVectorIndex/simpleVectorIndex.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,24 @@ func (s *Index) IsEmpty() (bool, error) {
return len(s.data) == 0, nil
}

func (s *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {
// Search runs a similarity search against the index using an already
// computed embedding vector instead of a text query.
//
// The supplied options are applied on top of a default TopK of defaultTopK.
// s.load is called before searching; if it fails, the error is wrapped with
// index.ErrInternal and returned.
func (s *Index) Search(ctx context.Context, values []float64, opts ...option.Option) (index.SearchResults, error) {
	sviOptions := &option.Options{TopK: defaultTopK}
	for i := range opts {
		opts[i](sviOptions)
	}

	if err := s.load(); err != nil {
		return nil, fmt.Errorf("%s: %w", index.ErrInternal, err)
	}

	return s.similaritySearch(ctx, values, sviOptions)
}

func (s *Index) Query(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {

sviOptions := &option.Options{
TopK: defaultTopK,
Expand All @@ -160,24 +177,35 @@ func (s *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti
return nil, fmt.Errorf("%s: %w", index.ErrInternal, err)
}

scores := s.cosineSimilarityBatch(embeddings[0])
return s.similaritySearch(ctx, embeddings[0], sviOptions)
}

func (s *Index) similaritySearch(
ctx context.Context,
embedding embedder.Embedding,
opts *option.Options,
) (index.SearchResults, error) {

scores := s.cosineSimilarityBatch(embedding)

searchResults := make([]index.SearchResult, len(scores))

for i, score := range scores {
searchResults[i] = index.SearchResult{
ID: s.data[i].ID,
Values: s.data[i].Values,
Metadata: s.data[i].Metadata,
Score: score,
Data: index.Data{
ID: s.data[i].ID,
Values: s.data[i].Values,
Metadata: s.data[i].Metadata,
},
Score: score,
}
}

if sviOptions.Filter != nil {
searchResults = sviOptions.Filter.(SimpleVectorIndexFilterFn)(searchResults)
if opts.Filter != nil {
searchResults = opts.Filter.(SimpleVectorIndexFilterFn)(searchResults)
}

return index.FilterSearchResults(searchResults, sviOptions.TopK), nil
return index.FilterSearchResults(searchResults, opts.TopK), nil
}

func (s *Index) cosineSimilarity(a embedder.Embedding, b embedder.Embedding) float64 {
Expand Down

0 comments on commit f97cc8b

Please sign in to comment.