Skip to content

Commit

Permalink
chore: implement index engines
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis committed Oct 13, 2023
1 parent aa668f9 commit 84d7dc5
Show file tree
Hide file tree
Showing 4 changed files with 378 additions and 381 deletions.
35 changes: 18 additions & 17 deletions examples/embeddings/qdrant/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ import (
"fmt"

openaiembedder "github.com/henomis/lingoose/embedder/openai"
"github.com/henomis/lingoose/index"
qdrantindexengine "github.com/henomis/lingoose/index/engines/qdrant"
indexoption "github.com/henomis/lingoose/index/option"
qdrantindex "github.com/henomis/lingoose/index/qdrant"
"github.com/henomis/lingoose/llm/openai"
"github.com/henomis/lingoose/loader"
"github.com/henomis/lingoose/prompt"
Expand All @@ -18,34 +19,34 @@ import (

func main() {

openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2)

qdrantIndex := qdrantindex.New(
qdrantindex.Options{
CollectionName: "test",
IncludeContent: true,
CreateCollection: &qdrantindex.CreateCollectionOptions{
Dimension: 1536,
Distance: qdrantindex.DistanceCosine,
index := index.New(
qdrantindexengine.New(
qdrantindexengine.Options{
CollectionName: "test",
IncludeContent: true,
CreateCollection: &qdrantindexengine.CreateCollectionOptions{
Dimension: 1536,
Distance: qdrantindexengine.DistanceCosine,
},
},
},
openaiEmbedder,
).WithAPIKeyAndEdpoint("", "http://localhost:6333")
).WithAPIKeyAndEdpoint("", "http://localhost:6333"),
openaiembedder.New(openaiembedder.AdaEmbeddingV2),
).WithIncludeContents(true)

indexIsEmpty, err := qdrantIndex.IsEmpty(context.Background())
indexIsEmpty, err := index.IsEmpty(context.Background())
if err != nil {
panic(err)
}

if indexIsEmpty {
err = ingestData(qdrantIndex)
err = ingestData(index)
if err != nil {
panic(err)
}
}

query := "What is the purpose of the NATO Alliance?"
similarities, err := qdrantIndex.Query(
similarities, err := index.Query(
context.Background(),
query,
indexoption.WithTopK(3),
Expand Down Expand Up @@ -86,7 +87,7 @@ func main() {

}

func ingestData(qdrantIndex *qdrantindex.Index) error {
func ingestData(qdrantIndex *index.Index) error {

documents, err := loader.NewDirectoryLoader(".", ".txt").Load(context.Background())
if err != nil {
Expand Down
223 changes: 223 additions & 0 deletions index/engines/qdrant/qdrant.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
package qdrant

import (
"context"
"fmt"
"os"

"github.com/google/uuid"
"github.com/henomis/lingoose/index"
"github.com/henomis/lingoose/index/option"
qdrantgo "github.com/henomis/qdrant-go"
qdrantrequest "github.com/henomis/qdrant-go/request"
qdrantresponse "github.com/henomis/qdrant-go/response"
)

const (
defaultTopK = 10

Check failure on line 17 in index/engines/qdrant/qdrant.go

View workflow job for this annotation

GitHub Actions / lint

const `defaultTopK` is unused (unused)
)

type IndexEngine struct {
qdrantClient *qdrantgo.Client
collectionName string
includeContent bool
includeValues bool

createCollection *CreateCollectionOptions
}

type Distance string

const (
DistanceCosine Distance = Distance(qdrantrequest.DistanceCosine)
DistanceEuclidean Distance = Distance(qdrantrequest.DistanceEuclidean)
DistanceDot Distance = Distance(qdrantrequest.DistanceDot)
)

type CreateCollectionOptions struct {
Dimension uint64
Distance Distance
OnDisk bool
}

type Options struct {
CollectionName string
IncludeContent bool
IncludeValues bool
BatchUpsertSize *int
CreateCollection *CreateCollectionOptions
}

func New(options Options) *IndexEngine {
apiKey := os.Getenv("QDRANT_API_KEY")
endpoint := os.Getenv("QDRANT_ENDPOINT")

qdrantClient := qdrantgo.New(endpoint, apiKey)

return &IndexEngine{
qdrantClient: qdrantClient,
collectionName: options.CollectionName,
includeContent: options.IncludeContent,
includeValues: options.IncludeValues,
createCollection: options.CreateCollection,
}
}

func (i *IndexEngine) WithAPIKeyAndEdpoint(apiKey, endpoint string) *IndexEngine {
i.qdrantClient = qdrantgo.New(endpoint, apiKey)
return i
}

func (q *IndexEngine) IsEmpty(ctx context.Context) (bool, error) {

Check warning on line 71 in index/engines/qdrant/qdrant.go

View workflow job for this annotation

GitHub Actions / lint

receiver-naming: receiver name q should be consistent with previous receiver name i for IndexEngine (revive)
err := q.createCollectionIfRequired(ctx)
if err != nil {
return true, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

res := &qdrantresponse.CollectionCollectInfo{}
err = q.qdrantClient.CollectionCollectInfo(
ctx,
&qdrantrequest.CollectionCollectInfo{
CollectionName: q.collectionName,
},
res,
)
if err != nil {
return true, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

return res.Result.VectorsCount == 0, nil
}

func (i *IndexEngine) Insert(ctx context.Context, data []index.Data) error {

Check failure on line 92 in index/engines/qdrant/qdrant.go

View workflow job for this annotation

GitHub Actions / lint

ST1016: methods on the same type should have the same receiver name (seen 3x "i", 3x "q") (stylecheck)
err := i.createCollectionIfRequired(ctx)
if err != nil {
return fmt.Errorf("%w: %w", index.ErrInternal, err)
}

var points []qdrantrequest.Point
for _, d := range data {
if d.ID == "" {
id, errUUID := uuid.NewUUID()
if errUUID != nil {
return errUUID
}
d.ID = id.String()
}

point := qdrantrequest.Point{
ID: d.ID,
Vector: d.Values,
Payload: d.Metadata,
}
points = append(points, point)
}

wait := true
req := &qdrantrequest.PointUpsert{
Wait: &wait,
CollectionName: i.collectionName,
Points: points,
}
res := &qdrantresponse.PointUpsert{}

return i.qdrantClient.PointUpsert(ctx, req, res)
}

func (i *IndexEngine) Search(ctx context.Context, values []float64, options *option.Options) (index.SearchResults, error) {

Check failure on line 127 in index/engines/qdrant/qdrant.go

View workflow job for this annotation

GitHub Actions / lint

line is 123 characters (lll)
matches, err := i.similaritySearch(ctx, values, options)
if err != nil {
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

return buildSearchResultsFromQdrantMatches(matches, i.includeContent), nil
}

func (q *IndexEngine) similaritySearch(

Check warning on line 136 in index/engines/qdrant/qdrant.go

View workflow job for this annotation

GitHub Actions / lint

receiver-naming: receiver name q should be consistent with previous receiver name i for IndexEngine (revive)
ctx context.Context,
values []float64,
opts *option.Options,
) ([]qdrantresponse.PointSearchResult, error) {
if opts.Filter == nil {
opts.Filter = qdrantrequest.Filter{}
}

includeMetadata := true
res := &qdrantresponse.PointSearch{}
err := q.qdrantClient.PointSearch(
ctx,
&qdrantrequest.PointSearch{
CollectionName: q.collectionName,
Limit: opts.TopK,
Vector: values,
WithPayload: &includeMetadata,
WithVector: &q.includeValues,
Filter: opts.Filter.(qdrantrequest.Filter),
},
res,
)
if err != nil {
return nil, err
}

return res.Result, nil
}

func (q *IndexEngine) createCollectionIfRequired(ctx context.Context) error {

Check warning on line 166 in index/engines/qdrant/qdrant.go

View workflow job for this annotation

GitHub Actions / lint

receiver-naming: receiver name q should be consistent with previous receiver name i for IndexEngine (revive)
if q.createCollection == nil {
return nil
}

resp := &qdrantresponse.CollectionList{}
err := q.qdrantClient.CollectionList(ctx, &qdrantrequest.CollectionList{}, resp)
if err != nil {
return err
}

for _, collection := range resp.Result.Collections {
if collection.Name == q.collectionName {
return nil
}
}

req := &qdrantrequest.CollectionCreate{
CollectionName: q.collectionName,
Vectors: qdrantrequest.VectorsParams{
Size: q.createCollection.Dimension,
Distance: qdrantrequest.Distance(q.createCollection.Distance),
OnDisk: &q.createCollection.OnDisk,
},
}

err = q.qdrantClient.CollectionCreate(ctx, req, &qdrantresponse.CollectionCreate{})
if err != nil {
return err
}

return nil
}

func buildSearchResultsFromQdrantMatches(
matches []qdrantresponse.PointSearchResult,
includeContent bool,
) index.SearchResults {
searchResults := make([]index.SearchResult, len(matches))

for i, match := range matches {
metadata := index.DeepCopyMetadata(match.Payload)
if !includeContent {
delete(metadata, index.DefaultKeyContent)
}

searchResults[i] = index.SearchResult{
Data: index.Data{
ID: match.ID,
Metadata: metadata,
Values: match.Vector,
},
Score: match.Score,
}
}

return searchResults
}
Loading

0 comments on commit 84d7dc5

Please sign in to comment.