CrateDB vector: Add CrateDBVectorSearchMultiCollection #15

Merged · 3 commits · Nov 21, 2023
3 changes: 3 additions & 0 deletions docs/docs/integrations/providers/cratedb.mdx
@@ -106,6 +106,9 @@ export OPENAI_API_KEY=foobar # FIXME
export CRATEDB_CONNECTION_STRING=crate://crate@localhost
```

### Example

Load and index documents, and invoke a query.
```python
from langchain.document_loaders import UnstructuredURLLoader
from langchain.embeddings.openai import OpenAIEmbeddings
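# NOTE: The diff hunk is truncated at this point. The lines below are an
# editorial sketch of how the example presumably continues, assembled from the
# notebook code later in this PR; the collection name, query string, and the
# CRATEDB_CONNECTION_STRING environment lookup are illustrative assumptions,
# not the literal file content.
import os

from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores.cratedb import CrateDBVectorSearch

# Load a document from a URL and split it into chunks.
loader = UnstructuredURLLoader(
    urls=["https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt"]
)
docs = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0).split_documents(loader.load())

# Index the chunks into CrateDB, then run a similarity query.
store = CrateDBVectorSearch.from_documents(
    documents=docs,
    embedding=OpenAIEmbeddings(),
    collection_name="state_of_the_union",
    connection_string=os.environ["CRATEDB_CONNECTION_STRING"],
)
results = store.similarity_search("What did the president say about Ketanji Brown Jackson?")
print(results[0].page_content)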
88 changes: 54 additions & 34 deletions docs/docs/integrations/vectorstores/cratedb.ipynb
@@ -182,7 +182,11 @@
{
"cell_type": "markdown",
"source": [
"Next, you will read input data, and tokenize it."
"## Load and Index Documents\n",
"\n",
"Next, you will read input data, and tokenize it. The module will create a table\n",
"with the name of the collection. Make sure the collection name is unique, and\n",
"that you have the permission to create a table."
],
"metadata": {
"collapsed": false
@@ -196,7 +200,18 @@
"loader = UnstructuredURLLoader(\"https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt\")\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)"
"docs = text_splitter.split_documents(documents)\n",
"\n",
"COLLECTION_NAME = \"state_of_the_union_test\"\n",
"\n",
"embeddings = OpenAIEmbeddings()\n",
"\n",
"db = CrateDBVectorSearch.from_documents(\n",
" embedding=embeddings,\n",
" documents=docs,\n",
" collection_name=COLLECTION_NAME,\n",
" connection_string=CONNECTION_STRING,\n",
")"
],
"metadata": {
"collapsed": false,
@@ -208,39 +223,15 @@
{
"cell_type": "markdown",
"source": [
"## Similarity Search with Euclidean Distance (Default)\n",
"## Search Documents\n",
"\n",
"The module will create a table with the name of the collection. Make sure\n",
"the collection name is unique and that you have the permission to create\n",
"a table."
"### Similarity Search with Euclidean Distance\n",
"Searching by euclidean distance is the default."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-09T08:04:16.696625Z",
"start_time": "2023-09-09T08:02:31.817790Z"
}
},
"outputs": [],
"source": [
"COLLECTION_NAME = \"state_of_the_union_test\"\n",
"\n",
"embeddings = OpenAIEmbeddings()\n",
"\n",
"db = CrateDBVectorSearch.from_documents(\n",
" embedding=embeddings,\n",
" documents=docs,\n",
" collection_name=COLLECTION_NAME,\n",
" connection_string=CONNECTION_STRING,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -277,7 +268,7 @@
{
"cell_type": "markdown",
"source": [
"## Maximal Marginal Relevance Search (MMR)\n",
"### Maximal Marginal Relevance Search (MMR)\n",
"Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents."
],
"metadata": {
@@ -318,11 +309,40 @@
}
}
},
{
"cell_type": "markdown",
"source": [
"### Searching in Multiple Collections\n",
"`CrateDBVectorSearchMultiCollection` is a special adapter which provides similarity search across\n",
"multiple collections. It can not be used for indexing documents."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from langchain.vectorstores.cratedb import CrateDBVectorSearchMultiCollection\n",
"\n",
"multisearch = CrateDBVectorSearchMultiCollection(\n",
" collection_names=[\"test_collection_1\", \"test_collection_2\"],\n",
" embedding_function=embeddings,\n",
" connection_string=CONNECTION_STRING,\n",
")\n",
"docs_with_score = multisearch.similarity_search_with_score(query)"
],
"metadata": {
"collapsed": false
}
},
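As a quick aside (not part of the notebook diff), the scored results returned by the cell above can be inspected like any other LangChain result set, assuming the usual `(Document, score)` tuples and the `query` variable defined earlier in the notebook:

```python
# Sketch: print each hit together with its score. With the default Euclidean
# distance strategy, lower scores indicate closer matches.
for document, score in docs_with_score:
    print(f"score={score:.4f}")
    print(document.page_content[:120])
    print("-" * 60)
```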
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Working with the vector store\n",
"## Working with the Vector Store\n",
"\n",
"In the example above, you created a vector store from scratch. When\n",
"aiming to work with an existing vector store, you can initialize it directly."
@@ -345,7 +365,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Add documents\n",
"### Add Documents\n",
"\n",
"You can also add documents to an existing vector store."
]
@@ -390,7 +410,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Overwriting a vector store\n",
"### Overwriting a Vector Store\n",
"\n",
"If you have an existing collection, you can overwrite it by using `from_documents`,\n",
"aad setting `pre_delete_collection = True`."
@@ -433,7 +453,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using a vector store as a retriever"
"### Using a Vector Store as a Retriever"
]
},
{
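The remaining notebook cells are collapsed in this diff. As a rough sketch of the two truncated steps, "Overwriting a Vector Store" and "Using a Vector Store as a Retriever", reusing `docs`, `embeddings`, `COLLECTION_NAME`, and `CONNECTION_STRING` from earlier cells (the keyword arguments should be checked against the final API):

```python
# Overwrite an existing collection: drop and re-create it before indexing.
db = CrateDBVectorSearch.from_documents(
    documents=docs,
    embedding=embeddings,
    collection_name=COLLECTION_NAME,
    connection_string=CONNECTION_STRING,
    pre_delete_collection=True,
)

# Expose the vector store through the standard LangChain retriever interface.
retriever = db.as_retriever()
print(retriever)
```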
2 changes: 2 additions & 0 deletions libs/langchain/langchain/vectorstores/cratedb/__init__.py
@@ -1,5 +1,7 @@
from .base import CrateDBVectorSearch
from .extended import CrateDBVectorSearchMultiCollection

__all__ = [
"CrateDBVectorSearch",
"CrateDBVectorSearchMultiCollection",
]
38 changes: 35 additions & 3 deletions libs/langchain/langchain/vectorstores/cratedb/base.py
@@ -164,10 +164,24 @@ def add_embeddings(
if not embeddings:
return []
self._init_models(embeddings[0])

# When the user requested to delete the collection before running subsequent
# operations on it, run the deletion gracefully if the table does not exist
# yet.
if self.pre_delete_collection:
self.delete_collection()
try:
self.delete_collection()
except sqlalchemy.exc.ProgrammingError as ex:
if "RelationUnknown" not in str(ex):
raise

# Tables need to be created at runtime, because the `EmbeddingStore.embedding`
# field, a `FloatVector`, needs to be initialized with a dimensionality
# parameter, which is only obtained at runtime.
self.create_tables_if_not_exists()
self.create_collection()

# After setting up the table/collection at runtime, add embeddings.
return super().add_embeddings(
texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
)
@@ -250,8 +264,26 @@ def _query_collection(
collection = self.get_collection(session)
if not collection:
raise ValueError("Collection not found")
return self._query_collection_multi(
collections=[collection], embedding=embedding, k=k, filter=filter
)

filter_by = self.EmbeddingStore.collection_id == collection.uuid
def _query_collection_multi(
self,
collections: List[Any],
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
) -> List[Any]:
"""Query the collection."""
self._init_models(embedding)

collection_names = [coll.name for coll in collections]
collection_uuids = [coll.uuid for coll in collections]
self.logger.info(f"Querying collections: {collection_names}")

with self.Session() as session:
filter_by = self.EmbeddingStore.collection_id.in_(collection_uuids)

if filter is not None:
filter_clauses = []
@@ -271,7 +303,7 @@
) # type: ignore[assignment]
filter_clauses.append(filter_by_metadata)

filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
filter_by = sqlalchemy.and_(filter_by, *filter_clauses) # type: ignore[assignment]

_type = self.EmbeddingStore

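Since `_query_collection_multi()` applies an optional metadata filter on top of the collection restriction, here is a hedged usage sketch of the public search API with such a filter; the `store` object and the metadata key are illustrative, and the filter semantics are inherited from the pgvector-style adapter this code builds on:

```python
# Sketch: restrict similarity search to documents whose metadata matches the
# given key/value pair, in addition to the collection filter applied internally.
results = store.similarity_search(
    "What is CrateDB?",
    k=4,
    filter={"source": "state_of_the_union.txt"},
)
```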
92 changes: 92 additions & 0 deletions libs/langchain/langchain/vectorstores/cratedb/extended.py
@@ -0,0 +1,92 @@
import logging
from typing import (
Any,
Callable,
Dict,
List,
Optional,
)

import sqlalchemy
from sqlalchemy.orm import sessionmaker

from langchain.schema.embeddings import Embeddings
from langchain.vectorstores.cratedb.base import (
DEFAULT_DISTANCE_STRATEGY,
CrateDBVectorSearch,
DistanceStrategy,
)
from langchain.vectorstores.pgvector import _LANGCHAIN_DEFAULT_COLLECTION_NAME


class CrateDBVectorSearchMultiCollection(CrateDBVectorSearch):
"""
Provide functionality for searching multiple collections.
It cannot be used for indexing documents.

To use it, you should have the ``crate[sqlalchemy]`` Python package installed.

Synopsis::

from langchain.vectorstores.cratedb import CrateDBVectorSearchMultiCollection

multisearch = CrateDBVectorSearchMultiCollection(
collection_names=["collection_foo", "collection_bar"],
embedding_function=embeddings,
connection_string=CONNECTION_STRING,
)
docs_with_score = multisearch.similarity_search_with_score(query)
"""

def __init__(
self,
connection_string: str,
embedding_function: Embeddings,
collection_names: List[str] = [_LANGCHAIN_DEFAULT_COLLECTION_NAME],
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, # type: ignore[arg-type]
logger: Optional[logging.Logger] = None,
relevance_score_fn: Optional[Callable[[float], float]] = None,
*,
connection: Optional[sqlalchemy.engine.Connection] = None,
engine_args: Optional[dict[str, Any]] = None,
) -> None:
self.connection_string = connection_string
self.embedding_function = embedding_function
self.collection_names = collection_names
self._distance_strategy = distance_strategy # type: ignore[assignment]
self.logger = logger or logging.getLogger(__name__)
self.override_relevance_score_fn = relevance_score_fn
self.engine_args = engine_args or {}
# Create a connection if not provided, otherwise use the provided connection
self._engine = self.create_engine()
self.Session = sessionmaker(self._engine)
self._conn = connection if connection else self.connect()
self.__post_init__()

@classmethod
def _from(cls, *args: List, **kwargs: Dict): # type: ignore[no-untyped-def,override]
raise NotImplementedError("This adapter cannot be used for indexing documents")

def get_collections(self, session: sqlalchemy.orm.Session) -> Any:
if self.CollectionStore is None:
raise RuntimeError(
"Collection can't be accessed without specifying "
"dimension size of embedding vectors"
)
return self.CollectionStore.get_by_names(session, self.collection_names)

def _query_collection(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
) -> List[Any]:
"""Query multiple collections."""
self._init_models(embedding)
with self.Session() as session:
collections = self.get_collections(session)
if not collections:
raise ValueError("No collections found")
return self._query_collection_multi(
collections=collections, embedding=embedding, k=k, filter=filter
)
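Because `_from()` raises `NotImplementedError`, the indexing constructors inherited from the base class are expected to fail fast. A hedged sketch of that behavior, assuming `from_texts()` is routed through `_from()` and reusing `embeddings` and `CONNECTION_STRING` from the examples above:

```python
# Sketch: the multi-collection adapter is query-only; indexing raises.
try:
    CrateDBVectorSearchMultiCollection.from_texts(
        texts=["this should not be indexed"],
        embedding=embeddings,
        connection_string=CONNECTION_STRING,
    )
except NotImplementedError as ex:
    print(ex)
```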
17 changes: 8 additions & 9 deletions libs/langchain/langchain/vectorstores/cratedb/model.py
@@ -1,5 +1,5 @@
import uuid
from typing import Any, Optional, Tuple
from typing import Any, List, Optional, Tuple

import sqlalchemy
from crate.client.sqlalchemy.types import ObjectType
@@ -53,14 +53,13 @@ class CollectionStore(BaseModel):
def get_by_name(
cls, session: Session, name: str
) -> Optional["CollectionStore"]:
try:
return (
session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined] # noqa: E501
)
except sqlalchemy.exc.ProgrammingError as ex:
if "RelationUnknown" not in str(ex):
raise
return None
return session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined]

@classmethod
def get_by_names(
cls, session: Session, names: List[str]
) -> Optional["List[CollectionStore]"]:
return session.query(cls).filter(cls.name.in_(names)).all() # type: ignore[attr-defined]

@classmethod
def get_or_create(
@@ -52,7 +52,6 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_query(self, text: str) -> List[float]:
"""Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown."""
return self.embed_documents([text])[0]
if text not in self.known_texts:
return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]
return [float(1.0)] * (self.dimensionality - 1) + [
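The tail of this hunk is cut off. For readability, a sketch of what the complete fixed `embed_query()` presumably looks like; the final list element, mirroring the index-based scheme used by `embed_documents()`, is an assumption:

```python
def embed_query(self, text: str) -> List[float]:
    """Return consistent embeddings for the text, if seen before, or a constant
    one if the text is unknown."""
    if text not in self.known_texts:
        # Unknown query text gets a constant vector and is not added to
        # `known_texts`, so querying no longer mutates the fake's state.
        return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]
    return [float(1.0)] * (self.dimensionality - 1) + [
        float(self.known_texts.index(text))
    ]
```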