Skip to content

Commit

Permalink
Merge pull request #3 from philippe2803/feature/add-sqlite-vss
Browse files Browse the repository at this point in the history
Adjust for mypy and tests
  • Loading branch information
philippe2803 authored Aug 31, 2023
2 parents 3b425cd + 29f3c16 commit 34a61d3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 21 deletions.
33 changes: 16 additions & 17 deletions libs/langchain/langchain/vectorstores/sqlitevss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import sqlite3
import warnings
from typing import (
TYPE_CHECKING,
Any,
Iterable,
List,
Expand All @@ -18,9 +17,6 @@
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore

if TYPE_CHECKING:
import sqlite_vss # noqa # pylint: disable=unused-import

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -67,10 +63,10 @@ def create_table_if_not_exists(self) -> None:
f"""
CREATE TABLE IF NOT EXISTS {self._table}
(
text_id INT PRIMARY KEY AUTOINCREMENT,
text text,
metadata blob,
text_embedding blob
rowid INTEGER PRIMARY KEY AUTOINCREMENT,
text TEXT,
metadata BLOB,
text_embedding BLOB
)
;
"""
Expand Down Expand Up @@ -108,8 +104,11 @@ def add_texts(
kwargs: vectorstore specific parameters
"""
max_id = self._connection.execute(
f"SELECT max(text_id) as text_id FROM {self._table}"
).fetchone()["text_id"]
f"SELECT max(rowid) as rowid FROM {self._table}"
).fetchone()["rowid"]
if max_id is None: # no text added yet
max_id = 0

embeds = self._embedding.embed_documents(list(texts))
if not metadatas:
metadatas = [{} for _ in texts]
Expand All @@ -123,12 +122,11 @@ def add_texts(
data_input,
)
self._connection.commit()

# pulling every ids we just inserted
results = self._connection.execute(
f"SELECT text_id FROM {self._table} WHERE text_id > {max_id}"
f"SELECT rowid FROM {self._table} WHERE rowid > {max_id}"
)
return [row["text_id"] for row in results]
return [row["rowid"] for row in results]

def similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
Expand All @@ -139,7 +137,7 @@ def similarity_search_with_score_by_vector(
metadata,
distance
FROM {self._table} e
INNER JOIN vss_{self._table} v on v.text_id = e.text_id
INNER JOIN vss_{self._table} v on v.rowid = e.rowid
WHERE vss_search(
v.text_embedding,
vss_search_params('{json.dumps(embedding)}', {k})
Expand All @@ -151,9 +149,8 @@ def similarity_search_with_score_by_vector(

documents = []
for row in results:
doc = Document(
page_content=row["text"], metadata=json.loads(row["metadata"])
)
metadata = json.loads(row["metadata"]) or {}
doc = Document(page_content=row["text"], metadata=metadata)
score = self._euclidean_relevance_score_fn(row["distance"])
documents.append((doc, score))

Expand Down Expand Up @@ -207,6 +204,8 @@ def from_texts(

@staticmethod
def create_connection(db_file: str) -> sqlite3.Connection:
import sqlite_vss

connection = sqlite3.connect(db_file)
connection.row_factory = sqlite3.Row
connection.enable_load_extension(True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_sqlitevss() -> None:
"""Test end to end construction and search."""
docsearch = _sqlite_vss_from_texts()
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata=None)]
assert output == [Document(page_content="foo", metadata={})]


@pytest.mark.requires("sqlite-vss")
Expand All @@ -44,7 +44,7 @@ def test_sqlitevss_with_score() -> None:
Document(page_content="bar", metadata={"page": 1}),
Document(page_content="baz", metadata={"page": 2}),
]
assert scores[0] < scores[1] < scores[2]
assert scores[0] > scores[1] > scores[2]


@pytest.mark.requires("sqlite-vss")
Expand All @@ -53,8 +53,6 @@ def test_sqlitevss_add_extra() -> None:
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _sqlite_vss_from_texts(metadatas=metadatas)

docsearch.add_texts(texts, metadatas)

output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6

0 comments on commit 34a61d3

Please sign in to comment.