Skip to content

Commit

Permalink
Fixes tmc#974 - returning score driven by cosine similarity by (1 - distance) instead of distance (tmc#1048)
Browse files Browse the repository at this point in the history

Fixing the returned value of cosine similarity by (1 - distance)

Co-authored-by: avi.tal <[email protected]>
  • Loading branch information
avi3tal and afavi3tal authored Oct 24, 2024
1 parent 1794009 commit 238d1c7
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
2 changes: 1 addition & 1 deletion vectorstores/pgvector/pgvector.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ func (s Store) SimilaritySearch(
SELECT
data.document,
data.cmetadata,
- data.distance
+ (1 - data.distance) AS score
FROM (
SELECT
filtered_embedding_dims.*,
Expand Down
45 changes: 45 additions & 0 deletions vectorstores/pgvector/pgvector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,51 @@ func TestPgvectorStoreRestWithScoreThreshold(t *testing.T) {
require.Len(t, docs, 10)
}

// TestPgvectorStoreSimilarityScore checks that SimilaritySearch reports a
// similarity score (higher means closer) rather than a raw distance.
//
// It stores three short documents, asks a question that closely matches only
// the one about Japan, and applies a score threshold of 0.8. Exactly one
// document is expected back, and its score must exceed 0.9.
func TestPgvectorStoreSimilarityScore(t *testing.T) {
	t.Parallel()
	pgvectorURL := preCheckEnvSetting(t)
	ctx := context.Background()

	llm, err := openai.New(
		openai.WithEmbeddingModel("text-embedding-ada-002"),
	)
	require.NoError(t, err)
	e, err := embeddings.NewEmbedder(llm)
	require.NoError(t, err)

	conn, err := pgx.Connect(ctx, pgvectorURL)
	require.NoError(t, err)

	store, err := pgvector.New(
		ctx,
		pgvector.WithConn(conn),
		pgvector.WithEmbedder(e),
		pgvector.WithPreDeleteCollection(true),
		pgvector.WithCollectionName(makeNewCollectionName()),
	)
	require.NoError(t, err)

	defer cleanupTestArtifacts(ctx, t, store, pgvectorURL)

	// Reuse the test's ctx instead of creating a second background context.
	_, err = store.AddDocuments(ctx, []schema.Document{
		{PageContent: "Tokyo is the capital city of Japan."},
		{PageContent: "Paris is the city of love."},
		{PageContent: "I like to visit London."},
	})
	require.NoError(t, err)

	// With a score threshold of 0.8 only the Tokyo document should qualify,
	// so exactly 1 of the 3 stored documents is expected.
	docs, err := store.SimilaritySearch(
		ctx,
		"What is the capital city of Japan?",
		3,
		vectorstores.WithScoreThreshold(0.8),
	)
	require.NoError(t, err)
	require.Len(t, docs, 1)
	// Report the actual score on failure so a regression is easy to diagnose.
	require.Truef(t, docs[0].Score > 0.9, "expected score > 0.9, got %v", docs[0].Score)
}

func TestSimilaritySearchWithInvalidScoreThreshold(t *testing.T) {
t.Parallel()
pgvectorURL := preCheckEnvSetting(t)
Expand Down

0 comments on commit 238d1c7

Please sign in to comment.