diff --git a/index/simpleVectorIndex/simpleVectorIndex.go b/index/simpleVectorIndex/simpleVectorIndex.go index 53208caf..81bbbec1 100644 --- a/index/simpleVectorIndex/simpleVectorIndex.go +++ b/index/simpleVectorIndex/simpleVectorIndex.go @@ -12,6 +12,7 @@ import ( "github.com/henomis/lingoose/embedder" "github.com/henomis/lingoose/index" "github.com/henomis/lingoose/index/option" + "github.com/henomis/lingoose/types" ) const ( @@ -20,8 +21,9 @@ const ( ) type data struct { - Document document.Document `json:"document"` - Embedding embedder.Embedding `json:"embedding"` + ID string `json:"id"` + Metadata types.Meta `json:"metadata"` + Values []float64 `json:"values"` } type Index struct { @@ -45,10 +47,12 @@ func New(name string, outputPath string, embedder index.Embedder) *Index { } func (s *Index) LoadFromDocuments(ctx context.Context, documents []document.Document) error { + err := s.load() + if err != nil { + return fmt.Errorf("%s: %w", index.ErrInternal, err) + } - s.data = []data{} - - documentIndex := 0 + id := 0 for i := 0; i < len(documents); i += defaultBatchSize { end := i + defaultBatchSize @@ -67,18 +71,13 @@ func (s *Index) LoadFromDocuments(ctx context.Context, documents []document.Docu } for j, document := range documents[i:end] { - s.data = append(s.data, data{ - Document: document, - Embedding: embeddings[j], - }) - - documents[documentIndex].Metadata[index.DefaultKeyID] = fmt.Sprintf("%d", documentIndex) - documentIndex++ + s.data = append(s.data, buildDataFromEmbeddingAndDocument(id, embeddings[j], document)) + id++ } } - err := s.save() + err = s.save() if err != nil { return fmt.Errorf("%s: %w", index.ErrInternal, err) } @@ -86,6 +85,20 @@ func (s *Index) LoadFromDocuments(ctx context.Context, documents []document.Docu return nil } +func buildDataFromEmbeddingAndDocument( + id int, + embedding embedder.Embedding, + document document.Document, +) data { + metadata := index.DeepCopyMetadata(document.Metadata) + metadata[index.DefaultKeyContent] = document.Content + return data{ + ID: fmt.Sprintf("%d", id), + Values: embedding, + Metadata: metadata, + } +} + func (s Index) save() error { jsonContent, err := json.Marshal(s.data) @@ -97,11 +110,14 @@ func (s Index) save() error { } func (s *Index) load() error { - if len(s.data) > 0 { return nil } + if _, err := os.Stat(s.database()); os.IsNotExist(err) { + return s.save() + } + content, err := os.ReadFile(s.database()) if err != nil { return err @@ -149,13 +165,13 @@ func (s *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti searchResponses := make([]index.SearchResponse, len(scores)) for i, score := range scores { - - id := s.data[i].Document.Metadata[index.DefaultKeyID].(string) - searchResponses[i] = index.SearchResponse{ - ID: id, - Document: s.data[i].Document, - Score: score, + ID: s.data[i].ID, + Document: document.Document{ + Content: s.data[i].Metadata[index.DefaultKeyContent].(string), + Metadata: s.data[i].Metadata, + }, + Score: score, } } @@ -185,11 +201,10 @@ func (s *Index) cosineSimilarity(a embedder.Embedding, b embedder.Embedding) flo } func (s *Index) cosineSimilarityBatch(a embedder.Embedding) []float64 { - scores := make([]float64, len(s.data)) for i := range s.data { - scores[i] = s.cosineSimilarity(a, s.data[i].Embedding) + scores[i] = s.cosineSimilarity(a, s.data[i].Values) } return scores