diff --git a/go/plugins/localvec/localvec.go b/go/plugins/localvec/localvec.go index ecc256f79..d03b7d501 100644 --- a/go/plugins/localvec/localvec.go +++ b/go/plugins/localvec/localvec.go @@ -134,6 +134,8 @@ func (r *retriever) Index(ctx context.Context, req *ai.IndexerRequest) error { } // Update the file every time we add documents. + // We use a temporary file to avoid losing the original + // file, in case of a crash. tmpname := r.filename + ".tmp" f, err := os.Create(tmpname) if err != nil { @@ -146,6 +148,9 @@ func (r *retriever) Index(ctx context.Context, req *ai.IndexerRequest) error { if err := f.Close(); err != nil { return err } + if err := os.Rename(tmpname, r.filename); err != nil { + return err + } return nil } diff --git a/go/plugins/localvec/localvec_test.go b/go/plugins/localvec/localvec_test.go index 80f9bf9c9..f50ff9347 100644 --- a/go/plugins/localvec/localvec_test.go +++ b/go/plugins/localvec/localvec_test.go @@ -89,6 +89,93 @@ func TestLocalVec(t *testing.T) { } } +func TestPersistentIndexing(t *testing.T) { + ctx := context.Background() + + const dim = 32 + v1 := make([]float32, dim) + v2 := make([]float32, dim) + v3 := make([]float32, dim) + for i := range v1 { + v1[i] = float32(i) + v2[i] = float32(i) + v3[i] = float32(i) + } + + d1 := ai.DocumentFromText("hello1", nil) + d2 := ai.DocumentFromText("hello2", nil) + d3 := ai.DocumentFromText("goodbye", nil) + + embedder := fakeembedder.New() + embedder.Register(d1, v1) + embedder.Register(d2, v2) + embedder.Register(d3, v3) + + tDir := t.TempDir() + + r, err := newRetriever(ctx, tDir, "testLocalVec", embedder, nil) + if err != nil { + t.Fatal(err) + } + + indexerReq := &ai.IndexerRequest{ + Documents: []*ai.Document{d1, d2}, + } + err = r.Index(ctx, indexerReq) + if err != nil { + t.Fatalf("Index operation failed: %v", err) + } + + retrieverOptions := &RetrieverOptions{ + K: 100, // fetch all docs + } + + retrieverReq := &ai.RetrieverRequest{ + Document: d1, + Options: retrieverOptions, + } + retrieverResp, err := r.Retrieve(ctx, retrieverReq) + if err != nil { + t.Fatalf("Retrieve operation failed: %v", err) + } + + docs := retrieverResp.Documents + if len(docs) != 2 { + t.Errorf("got %d results, expected 2", len(docs)) + } + + rAnother, err := newRetriever(ctx, tDir, "testLocalVec", embedder, nil) + if err != nil { + t.Fatal(err) + } + + indexerReq = &ai.IndexerRequest{ + Documents: []*ai.Document{d3}, + } + err = rAnother.Index(ctx, indexerReq) + if err != nil { + t.Fatalf("Index operation failed: %v", err) + } + + retrieverOptions = &RetrieverOptions{ + K: 100, // fetch all docs + } + + retrieverReq = &ai.RetrieverRequest{ + Document: d1, + Options: retrieverOptions, + } + retrieverResp, err = rAnother.Retrieve(ctx, retrieverReq) + if err != nil { + t.Fatalf("Retrieve operation failed: %v", err) + } + + docs = retrieverResp.Documents + if len(docs) != 3 { + t.Errorf("got %d results, expected 3", len(docs)) + } +} + func TestSimilarity(t *testing.T) { x := []float32{5, 23, 2, 5, 9} y := []float32{3, 21, 2, 5, 14}