Skip to content

Commit

Permalink
chore: refactor simpleVectorStorage (#126)
Browse files Browse the repository at this point in the history
* chore: refactor simpleVectorStorage

* fix
  • Loading branch information
henomis authored Sep 13, 2023
1 parent 5e29dde commit 1dfd66b
Showing 1 changed file with 37 additions and 22 deletions.
59 changes: 37 additions & 22 deletions index/simpleVectorIndex/simpleVectorIndex.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/henomis/lingoose/embedder"
"github.com/henomis/lingoose/index"
"github.com/henomis/lingoose/index/option"
"github.com/henomis/lingoose/types"
)

const (
Expand All @@ -20,8 +21,9 @@ const (
)

type data struct {
Document document.Document `json:"document"`
Embedding embedder.Embedding `json:"embedding"`
ID string `json:"id"`
Metadata types.Meta `json:"metadata"`
Values []float64 `json:"values"`
}

type Index struct {
Expand All @@ -45,10 +47,12 @@ func New(name string, outputPath string, embedder index.Embedder) *Index {
}

func (s *Index) LoadFromDocuments(ctx context.Context, documents []document.Document) error {
err := s.load()
if err != nil {
return fmt.Errorf("%s: %w", index.ErrInternal, err)
}

s.data = []data{}

documentIndex := 0
id := 0
for i := 0; i < len(documents); i += defaultBatchSize {

end := i + defaultBatchSize
Expand All @@ -67,25 +71,34 @@ func (s *Index) LoadFromDocuments(ctx context.Context, documents []document.Docu
}

for j, document := range documents[i:end] {
s.data = append(s.data, data{
Document: document,
Embedding: embeddings[j],
})

documents[documentIndex].Metadata[index.DefaultKeyID] = fmt.Sprintf("%d", documentIndex)
documentIndex++
s.data = append(s.data, buildDataFromEmbeddingAndDocument(id, embeddings[j], document))
id++
}

}

err := s.save()
err = s.save()
if err != nil {
return fmt.Errorf("%s: %w", index.ErrInternal, err)
}

return nil
}

func buildDataFromEmbeddingAndDocument(
id int,
embedding embedder.Embedding,
document document.Document,
) data {
metadata := index.DeepCopyMetadata(document.Metadata)
metadata[index.DefaultKeyContent] = document.Content
return data{
ID: fmt.Sprintf("%d", id),
Values: embedding,
Metadata: metadata,
}
}

func (s Index) save() error {

jsonContent, err := json.Marshal(s.data)
Expand All @@ -97,11 +110,14 @@ func (s Index) save() error {
}

func (s *Index) load() error {

if len(s.data) > 0 {
return nil
}

if _, err := os.Stat(s.database()); os.IsNotExist(err) {
return s.save()
}

content, err := os.ReadFile(s.database())
if err != nil {
return err
Expand Down Expand Up @@ -149,13 +165,13 @@ func (s *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti
searchResponses := make([]index.SearchResponse, len(scores))

for i, score := range scores {

id := s.data[i].Document.Metadata[index.DefaultKeyID].(string)

searchResponses[i] = index.SearchResponse{
ID: id,
Document: s.data[i].Document,
Score: score,
ID: s.data[i].ID,
Document: document.Document{
Content: s.data[i].Metadata[index.DefaultKeyContent].(string),
Metadata: s.data[i].Metadata,
},
Score: score,
}
}

Expand Down Expand Up @@ -185,11 +201,10 @@ func (s *Index) cosineSimilarity(a embedder.Embedding, b embedder.Embedding) flo
}

func (s *Index) cosineSimilarityBatch(a embedder.Embedding) []float64 {

scores := make([]float64, len(s.data))

for i := range s.data {
scores[i] = s.cosineSimilarity(a, s.data[i].Embedding)
scores[i] = s.cosineSimilarity(a, s.data[i].Values)
}

return scores
Expand Down

0 comments on commit 1dfd66b

Please sign in to comment.