# exact.py
# Example: exact ColBERT-style MaxSim ranking over multi-vector documents
# stored in PostgreSQL with pgvector.
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint
from pgvector.psycopg import register_vector
import psycopg
# Open an autocommit connection to the example database, so every statement
# below takes effect immediately without an explicit commit.
conn = psycopg.connect(dbname='pgvector_example', autocommit=True)

# Enable pgvector, then register its types on this connection so that NumPy
# arrays can be passed as query parameters.
conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
register_vector(conn)

# Start from a clean table on every run. Each document row stores an array of
# 128-dimensional vectors rather than a single vector.
for ddl in (
    'DROP TABLE IF EXISTS documents',
    'CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embeddings vector(128)[])',
):
    conn.execute(ddl)
# Define max_sim(document, query) in SQL:
#   1. number each query vector (the numbering is only a grouping key, so the
#      order produced by row_number() does not matter),
#   2. pair every query vector with every document vector and compute
#      1 - cosine distance (pgvector's <=> operator),
#   3. keep the best similarity per query vector,
#   4. return the sum of those per-query maxima.
# The derived table is aliased (AS q) because PostgreSQL releases before 16
# reject a subquery in FROM that has no alias.
conn.execute("""
    CREATE OR REPLACE FUNCTION max_sim(document vector[], query vector[]) RETURNS double precision AS $$
        WITH queries AS (
            SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query) AS q
        ),
        documents AS (
            SELECT unnest(document) AS document
        ),
        similarities AS (
            SELECT query_number, 1 - (document <=> query) AS similarity FROM queries CROSS JOIN documents
        ),
        max_similarities AS (
            SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
        )
        SELECT SUM(max_similarity) FROM max_similarities
    $$ LANGUAGE SQL
""")
# Load the ColBERT checkpoint used to embed both documents and queries.
config = ColBERTConfig(doc_maxlen=220, query_maxlen=32)
checkpoint = Checkpoint('colbert-ir/colbertv2.0', colbert_config=config, verbose=0)

# Named `texts` rather than `input` so the builtin input() is not shadowed.
texts = [
    'The dog is barking',
    'The cat is purring',
    'The bear is growling',
]

# docFromText yields one group of embeddings per input text (presumably one
# vector per token -- confirm against the ColBERT API). Each group is stored
# as a single row holding an array of vectors.
doc_embeddings = checkpoint.docFromText(texts, keep_dims=False)
for content, embeddings in zip(texts, doc_embeddings):
    # Convert each embedding to a NumPy array so the registered pgvector
    # adapter can send it; a separate name avoids rebinding the loop variable.
    vectors = [e.numpy() for e in embeddings]
    conn.execute('INSERT INTO documents (content, embeddings) VALUES (%s, %s)', (content, vectors))
# Embed the search text; queryFromText takes a batch, so take the first (only)
# entry and convert each of its embeddings to a NumPy array.
query = 'puppy'
query_vectors = checkpoint.queryFromText([query])[0]
query_embeddings = [e.numpy() for e in query_vectors]

# Score every document with max_sim and return the five highest-scoring rows.
search_sql = (
    'SELECT content, max_sim(embeddings, %s) AS max_sim '
    'FROM documents ORDER BY max_sim DESC LIMIT 5'
)
result = conn.execute(search_sql, (query_embeddings,)).fetchall()
for row in result:
    print(row)