diff --git a/libs/checkpoint-duckdb/langgraph/store/duckdb/base.py b/libs/checkpoint-duckdb/langgraph/store/duckdb/base.py index e0fb57067..89bf13681 100644 --- a/libs/checkpoint-duckdb/langgraph/store/duckdb/base.py +++ b/libs/checkpoint-duckdb/langgraph/store/duckdb/base.py @@ -23,6 +23,7 @@ Op, PutOp, Result, + SearchItem, SearchOp, ) @@ -283,7 +284,7 @@ def _batch_search_ops( for cur, idx in cursors: rows = cur.fetchall() - items = [_row_to_item(_convert_ns(row[0]), row) for row in rows] + items = [_row_to_search_item(_convert_ns(row[0]), row) for row in rows] results[idx] = items def _batch_list_namespaces_ops( @@ -376,6 +377,22 @@ def _row_to_item( ) +def _row_to_search_item( + namespace: tuple[str, ...], + row: tuple, +) -> SearchItem: + """Convert a row from the database into a SearchItem.""" + # TODO: Add support for search + _, key, val, created_at, updated_at = row + return SearchItem( + value=val if isinstance(val, dict) else json.loads(val), + key=key, + namespace=namespace, + created_at=created_at, + updated_at=updated_at, + ) + + def _group_ops(ops: Iterable[Op]) -> tuple[dict[type, list[tuple[int, Op]]], int]: grouped_ops: dict[type, list[tuple[int, Op]]] = defaultdict(list) tot = 0 diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py index 2107b05a3..1a3aff119 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py @@ -378,15 +378,19 @@ def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]: # a connection not in pipeline mode can only be used by one # thread/coroutine at a time, so we acquire a lock if self.supports_pipeline: - with self.lock, conn.pipeline(), conn.cursor( - binary=True, row_factory=dict_row - ) as cur: + with ( + self.lock, + conn.pipeline(), + conn.cursor(binary=True, row_factory=dict_row) as cur, + ): yield cur else: # Use connection's transaction context manager when pipeline mode not supported - with self.lock, conn.transaction(), conn.cursor( - binary=True, row_factory=dict_row - ) as cur: + with ( + self.lock, + conn.transaction(), + conn.cursor(binary=True, row_factory=dict_row) as cur, + ): yield cur else: with self.lock, conn.cursor(binary=True, row_factory=dict_row) as cur: yield cur diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py index 589520efc..5b07e0067 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py @@ -338,20 +338,25 @@ async def _cursor( # a connection not in pipeline mode can only be used by one # thread/coroutine at a time, so we acquire a lock if self.supports_pipeline: - async with self.lock, conn.pipeline(), conn.cursor( - binary=True, row_factory=dict_row - ) as cur: + async with ( + self.lock, + conn.pipeline(), + conn.cursor(binary=True, row_factory=dict_row) as cur, + ): yield cur else: # Use connection's transaction context manager when pipeline mode not supported - async with self.lock, conn.transaction(), conn.cursor( - binary=True, row_factory=dict_row - ) as cur: + async with ( + self.lock, + conn.transaction(), + conn.cursor(binary=True, row_factory=dict_row) as cur, + ): yield cur else: - async with self.lock, conn.cursor( - binary=True, row_factory=dict_row - ) as cur: + async with ( + self.lock, + conn.cursor(binary=True,
row_factory=dict_row) as cur, + ): yield cur def list( diff --git a/libs/checkpoint-postgres/langgraph/store/postgres/aio.py b/libs/checkpoint-postgres/langgraph/store/postgres/aio.py index 578523052..b90a9a0d5 100644 --- a/libs/checkpoint-postgres/langgraph/store/postgres/aio.py +++ b/libs/checkpoint-postgres/langgraph/store/postgres/aio.py @@ -28,6 +28,7 @@ _decode_ns_bytes, _group_ops, _row_to_item, + _row_to_search_item, ) logger = logging.getLogger(__name__) @@ -146,7 +147,7 @@ async def _batch_search_ops( await cur.execute(query, params) rows = cast(list[Row], await cur.fetchall()) items = [ - _row_to_item( + _row_to_search_item( _decode_ns_bytes(row["prefix"]), row, loader=self._deserializer ) for row in rows @@ -195,9 +196,11 @@ async def _cursor( async with self.lock, conn.pipeline(), conn.cursor(binary=True) as cur: yield cur else: - async with self.lock, conn.transaction(), conn.cursor( - binary=True - ) as cur: + async with ( + self.lock, + conn.transaction(), + conn.cursor(binary=True) as cur, + ): yield cur else: async with conn.cursor(binary=True) as cur: diff --git a/libs/checkpoint-postgres/langgraph/store/postgres/base.py b/libs/checkpoint-postgres/langgraph/store/postgres/base.py index 3bb343b0c..0dfce8871 100644 --- a/libs/checkpoint-postgres/langgraph/store/postgres/base.py +++ b/libs/checkpoint-postgres/langgraph/store/postgres/base.py @@ -35,7 +35,9 @@ ListNamespacesOp, Op, PutOp, + ResponseMetadata, Result, + SearchItem, SearchOp, ) @@ -344,14 +346,18 @@ def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]: # a connection not in pipeline mode can only be used by one # thread/coroutine at a time, so we acquire a lock if self.supports_pipeline: - with self.lock, conn.pipeline(), conn.cursor( - binary=True, row_factory=dict_row - ) as cur: + with ( + self.lock, + conn.pipeline(), + conn.cursor(binary=True, row_factory=dict_row) as cur, + ): yield cur else: - with self.lock, conn.transaction(), conn.cursor( - binary=True, row_factory=dict_row - ) as cur: + with ( + self.lock, + conn.transaction(), + conn.cursor(binary=True, row_factory=dict_row) as cur, + ): yield cur else: with conn.cursor(binary=True, row_factory=dict_row) as cur: @@ -430,7 +436,7 @@ def _batch_search_ops( cur.execute(query, params) rows = cast(list[Row], cur.fetchall()) results[idx] = [ - _row_to_item( + _row_to_search_item( _decode_ns_bytes(row["prefix"]), row, loader=self._deserializer ) for row in rows @@ -517,6 +523,32 @@ def _row_to_item( ) +def _row_to_search_item( + namespace: tuple[str, ...], + row: Row, + *, + loader: Optional[Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]] = None, +) -> SearchItem: + """Convert a row from the database into a SearchItem.""" + loader = loader or _json_loads + val = row["value"] + response_metadata: Optional[ResponseMetadata] = ( + { + "score": float(row["score"]), + } + if row.get("score") is not None + else None + ) + return SearchItem( + value=val if isinstance(val, dict) else loader(val), + key=row["key"], + namespace=namespace, + created_at=row["created_at"], + updated_at=row["updated_at"], + response_metadata=response_metadata, + ) + + def _group_ops(ops: Iterable[Op]) -> tuple[dict[type, list[tuple[int, Op]]], int]: grouped_ops: dict[type, list[tuple[int, Op]]] = defaultdict(list) tot = 0 diff --git a/libs/checkpoint/langgraph/store/base/__init__.py b/libs/checkpoint/langgraph/store/base/__init__.py index 098462339..ef6729d83 100644 --- a/libs/checkpoint/langgraph/store/base/__init__.py +++
b/libs/checkpoint/langgraph/store/base/__init__.py @@ -1,12 +1,27 @@ """Base classes and types for persistent key-value stores. -Stores enable persistence and memory that can be shared across threads, -scoped to user IDs, assistant IDs, or other arbitrary namespaces. +Stores provide long-term memory that persists across threads and conversations. +Supports hierarchical namespaces, key-value storage, and optional vector search. + +Core types: +- BaseStore: Store interface with sync/async operations +- Item: Stored key-value pairs with metadata +- Op: Get/Put/Search/List operations """ from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Iterable, Literal, NamedTuple, Optional, Union, cast +from typing import Any, Iterable, Literal, NamedTuple, Optional, TypedDict, Union, cast + +from langchain_core.embeddings import Embeddings + +from langgraph.store.base._embed import ( + AEmbeddingsFunc, + EmbeddingsFunc, + ensure_embeddings, + get_text_at_path, + tokenize_path, +) class Item: @@ -73,6 +88,52 @@ def dict(self) -> dict: } +class ResponseMetadata(TypedDict, total=False): + """Additional metadata about the response/result.""" + + score: float + """Relevance/similarity score if from a ranked operation.""" + + +class SearchItem(Item): + """Represents a result item with additional response metadata.""" + + __slots__ = "response_metadata" + + def __init__( + self, + namespace: tuple[str, ...], + key: str, + value: dict[str, Any], + created_at: datetime, + updated_at: datetime, + response_metadata: Optional[ResponseMetadata] = None, + ) -> None: + """Initialize a result item. + + Args: + namespace: Hierarchical path to the item. + key: Unique identifier within the namespace. + value: The stored value. + created_at: When the item was first created. + updated_at: When the item was last updated. + response_metadata: Optional metadata about the response/result. + """ + super().__init__( + value=value, + key=key, + namespace=namespace, + created_at=created_at, + updated_at=updated_at, + ) + self.response_metadata = response_metadata or {} + + def dict(self) -> dict: + result = super().dict() + result["response_metadata"] = self.response_metadata + return result + + class GetOp(NamedTuple): """Operation to retrieve an item by namespace and key.""" @@ -93,6 +154,8 @@ class SearchOp(NamedTuple): """Maximum number of items to return.""" offset: int = 0 """Number of items to skip before returning results.""" + query: Optional[str] = None + """The search query for natural language search.""" class PutOp(NamedTuple): @@ -120,6 +183,12 @@ class PutOp(NamedTuple): - Values can be of any serializable type - If None, it indicates that the item should be deleted """ + index: Optional[bool] = None # type: ignore[assignment] + """Whether to index the item (if supported by the store). + + Defaults to True if the store supports indexing. This will embed the document + so it can be queried using search. + """ NameSpacePath = tuple[Union[str, Literal["*"]], ...] 
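A minimal sketch of how the types added above compose — the values here are illustrative only, but the constructors follow the definitions in this hunk (`PutOp` gains `index`, `SearchOp` gains `query`, and `SearchItem` extends `Item` with `response_metadata`):

```python
from datetime import datetime, timezone

from langgraph.store.base import PutOp, SearchItem, SearchOp

# PutOp.index=False opts an item out of embedding; SearchOp.query switches
# search from plain filtering to natural-language (vector) ranking.
put = PutOp(("docs",), "doc1", {"text": "Python tutorial"}, index=True)
search = SearchOp(("docs",), filter=None, limit=5, offset=0, query="python")

# SearchItem carries optional response metadata, e.g. a similarity score.
item = SearchItem(
    namespace=("docs",),
    key="doc1",
    value={"text": "Python tutorial"},
    created_at=datetime.now(timezone.utc),
    updated_at=datetime.now(timezone.utc),
    response_metadata={"score": 0.87},
)
assert item.dict()["response_metadata"] == {"score": 0.87}
```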
@@ -151,34 +220,43 @@ class ListNamespacesOp(NamedTuple): Op = Union[GetOp, SearchOp, PutOp, ListNamespacesOp] -Result = Union[Item, list[Item], list[tuple[str, ...]], None] +Result = Union[Item, list[Item], list[SearchItem], list[tuple[str, ...]], None] class InvalidNamespaceError(ValueError): """Provided namespace is invalid.""" -def _validate_namespace(namespace: tuple[str, ...]) -> None: - if not namespace: - raise InvalidNamespaceError("Namespace cannot be empty.") - for label in namespace: - if not isinstance(label, str): - raise InvalidNamespaceError( - f"Invalid namespace label '{label}' found in {namespace}. Namespace labels" - f" must be strings, but got {type(label).__name__}." - ) - if "." in label: - raise InvalidNamespaceError( - f"Invalid namespace label '{label}' found in {namespace}. Namespace labels cannot contain periods ('.')." - ) - elif not label: - raise InvalidNamespaceError( - f"Namespace labels cannot be empty strings. Got {label} in {namespace}" - ) - if namespace[0] == "langgraph": - raise InvalidNamespaceError( - f'Root label for namespace cannot be "langgraph". Got: {namespace}' - ) +class EmbeddingConfig(TypedDict, total=False): + """Configuration for vector embeddings in a store.""" + + dims: int + """Number of dimensions in the embedding vectors. + + Common embedding models have the following dimensions: + - OpenAI text-embedding-3-large: 256, 1024, or 3072 + - OpenAI text-embedding-3-small: 512 or 1536 + - OpenAI text-embedding-ada-002: 1536 + - Cohere embed-english-v3.0: 1024 + - Cohere embed-english-light-v3.0: 384 + - Cohere embed-multilingual-v3.0: 1024 + - Cohere embed-multilingual-light-v3.0: 384 + """ + + embed: Union[Embeddings, EmbeddingsFunc, AEmbeddingsFunc] + """Embeddings object or function used to generate embeddings from text.""" + aembed: Optional[AEmbeddingsFunc] + """Optional asynchronous function to generate embeddings from text. + + Provide for asynchronous embedding generation if you do not provide + an Embeddings object. + """ + + text_fields: Optional[list[str]] + """Fields to extract text from for embedding generation. + + Defaults to ["__root__"], which embeds the JSON object as a whole. + """ class BaseStore(ABC): @@ -231,14 +309,16 @@ def search( namespace_prefix: tuple[str, ...], /, *, + query: Optional[str] = None, filter: Optional[dict[str, Any]] = None, limit: int = 10, offset: int = 0, - ) -> list[Item]: + ) -> list[SearchItem]: """Search for items within a namespace prefix. Args: namespace_prefix: Hierarchical path prefix to search within. + query: Optional query for natural language search. filter: Key-value pairs to filter results. limit: Maximum number of items to return. offset: Number of items to skip before returning results. @@ -246,18 +326,26 @@ Returns: List of items matching the search criteria. """ - return self.batch([SearchOp(namespace_prefix, filter, limit, offset)])[0] + return self.batch([SearchOp(namespace_prefix, filter, limit, offset, query)])[0] - def put(self, namespace: tuple[str, ...], key: str, value: dict[str, Any]) -> None: + def put( + self, + namespace: tuple[str, ...], + key: str, + value: dict[str, Any], + index: Optional[bool] = None, + ) -> None: """Store or update an item. Args: namespace: Hierarchical path for the item. key: Unique identifier within the namespace. value: Dictionary containing the item's data. + index: Whether to index the item (if supported by the store). + Defaults to True if the store supports indexing.
""" _validate_namespace(namespace) - self.batch([PutOp(namespace, key, value)]) + self.batch([PutOp(namespace, key, value, index=index)]) def delete(self, namespace: tuple[str, ...], key: str) -> None: """Delete an item. @@ -336,14 +424,16 @@ async def asearch( namespace_prefix: tuple[str, ...], /, *, + query: Optional[str] = None, filter: Optional[dict[str, Any]] = None, limit: int = 10, offset: int = 0, - ) -> list[Item]: + ) -> list[SearchItem]: """Asynchronously search for items within a namespace prefix. Args: namespace_prefix: Hierarchical path prefix to search within. + query: Optional query for natural language search. filter: Key-value pairs to filter results. limit: Maximum number of items to return. offset: Number of items to skip before returning results. @@ -351,12 +441,18 @@ async def asearch( Returns: List of items matching the search criteria. """ - return (await self.abatch([SearchOp(namespace_prefix, filter, limit, offset)]))[ - 0 - ] + return ( + await self.abatch( + [SearchOp(namespace_prefix, filter, limit, offset, query)] + ) + )[0] async def aput( - self, namespace: tuple[str, ...], key: str, value: dict[str, Any] + self, + namespace: tuple[str, ...], + key: str, + value: dict[str, Any], + index: Optional[bool] = None, ) -> None: """Asynchronously store or update an item. @@ -364,9 +460,11 @@ async def aput( namespace: Hierarchical path for the item. key: Unique identifier within the namespace. value: Dictionary containing the item's data. + index: Whether to index the item (if supported by the store). + Defaults to True if the store supports indexing. """ _validate_namespace(namespace) - await self.abatch([PutOp(namespace, key, value)]) + await self.abatch([PutOp(namespace, key, value, index)]) async def adelete(self, namespace: tuple[str, ...], key: str) -> None: """Asynchronously delete an item. @@ -427,3 +525,44 @@ async def alist_namespaces( offset=offset, ) return (await self.abatch([op]))[0] + + +def _validate_namespace(namespace: tuple[str, ...]) -> None: + if not namespace: + raise InvalidNamespaceError("Namespace cannot be empty.") + for label in namespace: + if not isinstance(label, str): + raise InvalidNamespaceError( + f"Invalid namespace label '{label}' found in {namespace}. Namespace labels" + f" must be strings, but got {type(label).__name__}." + ) + if "." in label: + raise InvalidNamespaceError( + f"Invalid namespace label '{label}' found in {namespace}. Namespace labels cannot contain periods ('.')." + ) + elif not label: + raise InvalidNamespaceError( + f"Namespace labels cannot be empty strings. Got {label} in {namespace}" + ) + if namespace[0] == "langgraph": + raise InvalidNamespaceError( + f'Root label for namespace cannot be "langgraph". Got: {namespace}' + ) + + +__all__ = [ + "BaseStore", + "Item", + "Op", + "PutOp", + "GetOp", + "SearchOp", + "ListNamespacesOp", + "MatchCondition", + "NameSpacePath", + "NamespaceMatchType", + "Embeddings", + "ensure_embeddings", + "tokenize_path", + "get_text_at_path", +] diff --git a/libs/checkpoint/langgraph/store/base/_embed.py b/libs/checkpoint/langgraph/store/base/_embed.py new file mode 100644 index 000000000..a04b4f138 --- /dev/null +++ b/libs/checkpoint/langgraph/store/base/_embed.py @@ -0,0 +1,368 @@ +"""Utilities for working with embedding functions and LangChain's Embeddings interface. + +This module provides tools to wrap arbitrary embedding functions (both sync and async) +into LangChain's Embeddings interface. 
This enables using custom embedding functions +with LangChain-compatible tools while maintaining support for both synchronous and +asynchronous operations. +""" + +import asyncio +import json +from typing import Any, Awaitable, Callable, Optional, Sequence, Union + +from langchain_core.embeddings import Embeddings + +EmbeddingsFunc = Callable[[Sequence[str]], list[list[float]]] +"""Type for synchronous embedding functions. + +The function should take a sequence of strings and return a list of embeddings, +where each embedding is a list of floats. The dimensionality of the embeddings +should be consistent for all inputs. +""" + +AEmbeddingsFunc = Callable[[Sequence[str]], Awaitable[list[list[float]]]] +"""Type for asynchronous embedding functions. + +Similar to EmbeddingsFunc, but returns an awaitable that resolves to the embeddings. +""" + + +def ensure_embeddings( + embed: Union[Embeddings, EmbeddingsFunc, AEmbeddingsFunc, None], + *, + aembed: Optional[AEmbeddingsFunc] = None, +) -> Embeddings: + """Ensure that an embedding function conforms to LangChain's Embeddings interface. + + This function wraps arbitrary embedding functions to make them compatible with + LangChain's Embeddings interface. It handles both synchronous and asynchronous + functions. + + Args: + embed: Either an existing Embeddings instance, or a function that converts + text to embeddings. If the function is async, it will be used for both + sync and async operations. + aembed: Optional async function for embeddings. If provided, it will be used + for async operations while the sync function is used for sync operations. + Must be None if embed is async. + + Returns: + An Embeddings instance that wraps the provided function(s). + + Example: + >>> def my_embed_fn(texts): return [[0.1, 0.2] for _ in texts] + >>> async def my_async_fn(texts): return [[0.1, 0.2] for _ in texts] + >>> # Wrap a sync function + >>> embeddings = ensure_embeddings(my_embed_fn) + >>> # Wrap an async function + >>> embeddings = ensure_embeddings(my_async_fn) + >>> # Provide both sync and async implementations + >>> embeddings = ensure_embeddings(my_embed_fn, aembed=my_async_fn) + """ + if embed is None and aembed is None: + raise ValueError("embed or aembed must be provided") + if isinstance(embed, Embeddings): + return embed + return EmbeddingsLambda(embed, afunc=aembed) + + +class EmbeddingsLambda(Embeddings): + """Wrapper to convert embedding functions into LangChain's Embeddings interface. + + This class allows arbitrary embedding functions to be used with LangChain-compatible + tools. It supports both synchronous and asynchronous operations, and can be + initialized with either: + 1. A synchronous function for both sync/async operations + 2. An async function for both sync/async operations + 3. Both sync and async functions for their respective operations + + The embedding functions should convert text into fixed-dimensional vectors that + capture the semantic meaning of the text. + + Args: + func: Function that converts text to embeddings. Can be sync or async. + If async, it will be used for both sync and async operations. + afunc: Optional async function for embeddings. If provided, it will be used + for async operations while func is used for sync operations. + Must be None if func is async. + + Example: + >>> def my_embed_fn(texts): + ... # Return 2D embeddings for each text + ... 
return [[0.1, 0.2] for _ in texts] + >>> embeddings = EmbeddingsLambda(my_embed_fn) + >>> result = embeddings.embed_query("hello") # Returns [0.1, 0.2] + """ + + def __init__( + self, + func: Union[EmbeddingsFunc, AEmbeddingsFunc, None], + afunc: Optional[AEmbeddingsFunc] = None, + ) -> None: + if _is_async_callable(func): + if afunc is not None: + raise ValueError( + "afunc must be None if func is async. The async func will be used for both sync and async operations." + ) + self.afunc = func + else: + self.func = func + if afunc is not None: + self.afunc = afunc + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + """Embed a list of texts into vectors. + + Args: + texts: list of texts to convert to embeddings. + + Returns: + list of embeddings, one per input text. Each embedding is a list of floats. + + Raises: + ValueError: If the instance was initialized with only an async function. + """ + func = getattr(self, "func", None) + if func is None: + raise ValueError( + "EmbeddingsLambda was initialized with an async function but no sync function. " + "Use aembed_documents for async operation or provide a sync function." + ) + return func(texts) + + def embed_query(self, text: str) -> list[float]: + """Embed a single piece of text. + + Args: + text: Text to convert to an embedding. + + Returns: + Embedding vector as a list of floats. + + Note: + This is equivalent to calling embed_documents with a single text + and taking the first result. + """ + return self.embed_documents([text])[0] + + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: + """Asynchronously embed a list of texts into vectors. + + Args: + texts: list of texts to convert to embeddings. + + Returns: + list of embeddings, one per input text. Each embedding is a list of floats. + + Note: + If no async function was provided, this falls back to the sync implementation. + """ + afunc = getattr(self, "afunc", None) + if afunc is None: + return await super().aembed_documents(texts) + return await afunc(texts) + + async def aembed_query(self, text: str) -> list[float]: + """Asynchronously embed a single piece of text. + + Args: + text: Text to convert to an embedding. + + Returns: + Embedding vector as a list of floats. + + Note: + This is equivalent to calling aembed_documents with a single text + and taking the first result. + """ + afunc = getattr(self, "afunc", None) + if afunc is None: + return await super().aembed_query(text) + return (await afunc([text]))[0] + + +def get_text_at_path(obj: Any, path: Union[str, list[str]]) -> list[str]: + """Extract text from an object using a path expression or pre-tokenized path. + + Args: + obj: The object to extract text from + path: Either a path string or pre-tokenized path list. 
Path string supports: + - Simple paths: "field1.field2" + - Array indexing: "[0]", "[*]", "[-1]" + - Wildcards: "*" + - Multi-field selection: "{field1,field2}" + - Nested paths in multi-field: "{field1,nested.field2}" + """ + if not path or path == "__root__": + return [json.dumps(obj, sort_keys=True)] + + tokens = tokenize_path(path) if isinstance(path, str) else path + + def _extract_from_obj(obj: Any, tokens: list[str], pos: int) -> list[str]: + if pos >= len(tokens): + if isinstance(obj, (str, int, float, bool)): + return [str(obj)] + elif obj is None: + return [] + elif isinstance(obj, (list, dict)): + return [json.dumps(obj, sort_keys=True)] + return [] + + token = tokens[pos] + results = [] + + if token.startswith("[") and token.endswith("]"): + if not isinstance(obj, list): + return [] + + index = token[1:-1] + if index == "*": + for item in obj: + results.extend(_extract_from_obj(item, tokens, pos + 1)) + else: + try: + idx = int(index) + if idx < 0: + idx = len(obj) + idx + if 0 <= idx < len(obj): + results.extend(_extract_from_obj(obj[idx], tokens, pos + 1)) + except (ValueError, IndexError): + return [] + + elif token.startswith("{") and token.endswith("}"): + if not isinstance(obj, dict): + return [] + + fields = [f.strip() for f in token[1:-1].split(",")] + for field in fields: + nested_tokens = tokenize_path(field) + if nested_tokens: + current_obj: Optional[dict] = obj + for nested_token in nested_tokens: + if ( + isinstance(current_obj, dict) + and nested_token in current_obj + ): + current_obj = current_obj[nested_token] + else: + current_obj = None + break + if current_obj is not None: + if isinstance(current_obj, (str, int, float, bool)): + results.append(str(current_obj)) + elif isinstance(current_obj, (list, dict)): + results.append(json.dumps(current_obj, sort_keys=True)) + + # Handle wildcard + elif token == "*": + if isinstance(obj, dict): + for value in obj.values(): + results.extend(_extract_from_obj(value, tokens, pos + 1)) + elif isinstance(obj, list): + for item in obj: + results.extend(_extract_from_obj(item, tokens, pos + 1)) + + # Handle regular field + else: + if isinstance(obj, dict) and token in obj: + results.extend(_extract_from_obj(obj[token], tokens, pos + 1)) + + return results + + return _extract_from_obj(obj, tokens, 0) + + +# Private utility functions + + +def tokenize_path(path: str) -> list[str]: + """Tokenize a path into components. 
+ + Handles: + - Simple paths: "field1.field2" + - Array indexing: "[0]", "[*]", "[-1]" + - Wildcards: "*" + - Multi-field selection: "{field1,field2}" + """ + if not path: + return [] + + tokens = [] + current: list[str] = [] + i = 0 + while i < len(path): + char = path[i] + + if char == "[": # Handle array index + if current: + tokens.append("".join(current)) + current = [] + bracket_count = 1 + index_chars = ["["] + i += 1 + while i < len(path) and bracket_count > 0: + if path[i] == "[": + bracket_count += 1 + elif path[i] == "]": + bracket_count -= 1 + index_chars.append(path[i]) + i += 1 + tokens.append("".join(index_chars)) + continue + + elif char == "{": # Handle multi-field selection + if current: + tokens.append("".join(current)) + current = [] + brace_count = 1 + field_chars = ["{"] + i += 1 + while i < len(path) and brace_count > 0: + if path[i] == "{": + brace_count += 1 + elif path[i] == "}": + brace_count -= 1 + field_chars.append(path[i]) + i += 1 + tokens.append("".join(field_chars)) + continue + + elif char == ".": # Handle regular field + if current: + tokens.append("".join(current)) + current = [] + else: + current.append(char) + i += 1 + + if current: + tokens.append("".join(current)) + + return tokens + + +def _is_async_callable( + func: Any, +) -> bool: + """Check if a function is async. + + This includes both async def functions and classes with async __call__ methods. + + Args: + func: Function or callable object to check. + + Returns: + True if the function is async, False otherwise. + """ + return ( + asyncio.iscoroutinefunction(func) + or hasattr(func, "__call__") # noqa: B004 + and asyncio.iscoroutinefunction(func.__call__) + ) + + +__all__ = [ + "ensure_embeddings", + "EmbeddingsFunc", + "AEmbeddingsFunc", +] diff --git a/libs/checkpoint/langgraph/store/base/_embed_test_utils.py b/libs/checkpoint/langgraph/store/base/_embed_test_utils.py new file mode 100644 index 000000000..10e1e373c --- /dev/null +++ b/libs/checkpoint/langgraph/store/base/_embed_test_utils.py @@ -0,0 +1,65 @@ +"""Embedding utilities for testing.""" + +import math +import random +from collections import Counter +from typing import Any, Optional + +from langchain_core.embeddings import Embeddings + + +class CharacterEmbeddings(Embeddings): + """Simple character-frequency based embeddings using random projections.""" + + def __init__(self, dims: int = 50, seed: int = 42): + """Initialize with embedding dimensions and random seed.""" + self._rng = random.Random(seed) + self._char_to_idx: dict[str, int] = {} + self._projection: Optional[list[list[float]]] = None + self.dims = dims + + def _ensure_projection_matrix(self, texts: list[str]) -> None: + """Lazily initialize character mapping and projection matrix.""" + if self._projection is None: + chars = sorted(set("".join(texts))) + self._char_to_idx = {c: i for i, c in enumerate(chars)} + self._projection = [ + [self._rng.gauss(0, 1 / math.sqrt(self.dims)) for _ in range(self.dims)] + for _ in range(len(chars)) + ] + + def _embed_one(self, text: str) -> list[float]: + """Embed a single text.""" + counts = Counter(text) + char_vec = [0.0] * len(self._char_to_idx) + + for char, count in counts.items(): + if char in self._char_to_idx: + char_vec[self._char_to_idx[char]] = count + + total = sum(char_vec) + if total > 0: + char_vec = [v / total for v in char_vec] + embedding = [ + sum(a * b for a, b in zip(char_vec, proj)) + for proj in zip(*self._projection) + ] + + norm = math.sqrt(sum(x * x for x in embedding)) + if norm > 0: + embedding = [x / 
norm for x in embedding] + + return embedding + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + """Embed a list of documents.""" + self._ensure_projection_matrix(texts) + return [self._embed_one(text) for text in texts] + + def embed_query(self, text: str) -> list[float]: + """Embed a query string.""" + self._ensure_projection_matrix([text]) + return self._embed_one(text) + + def __eq__(self, other: Any) -> bool: + return isinstance(other, CharacterEmbeddings) and self.dims == other.dims diff --git a/libs/checkpoint/langgraph/store/base/batch.py b/libs/checkpoint/langgraph/store/base/batch.py index b1030942d..bb74eae8e 100644 --- a/libs/checkpoint/langgraph/store/base/batch.py +++ b/libs/checkpoint/langgraph/store/base/batch.py @@ -11,6 +11,7 @@ NameSpacePath, Op, PutOp, + SearchItem, SearchOp, _validate_namespace, ) @@ -43,12 +44,13 @@ async def asearch( namespace_prefix: tuple[str, ...], /, *, + query: Optional[str] = None, filter: Optional[dict[str, Any]] = None, limit: int = 10, offset: int = 0, - ) -> list[Item]: + ) -> list[SearchItem]: fut = self._loop.create_future() - self._aqueue[fut] = SearchOp(namespace_prefix, filter, limit, offset) + self._aqueue[fut] = SearchOp(namespace_prefix, filter, limit, offset, query) return await fut async def aput( @@ -56,10 +58,11 @@ async def aput( namespace: tuple[str, ...], key: str, value: dict[str, Any], + index: Optional[bool] = None, ) -> None: _validate_namespace(namespace) fut = self._loop.create_future() - self._aqueue[fut] = PutOp(namespace, key, value) + self._aqueue[fut] = PutOp(namespace, key, value, index) return await fut async def adelete( diff --git a/libs/checkpoint/langgraph/store/memory/__init__.py b/libs/checkpoint/langgraph/store/memory/__init__.py index 69a315096..605f2b48c 100644 --- a/libs/checkpoint/langgraph/store/memory/__init__.py +++ b/libs/checkpoint/langgraph/store/memory/__init__.py @@ -1,9 +1,47 @@ +"""In-memory key-value store. + +A lightweight store implementation using Python dictionaries. Supports basic +key-value operations and vector search when configured with embeddings. + +Examples: + Basic key-value storage: + store = InMemoryStore() + store.put(("users", "123"), "prefs", {"theme": "dark"}) + item = store.get(("users", "123"), "prefs") + + Vector search with embeddings: + from langchain_openai import OpenAIEmbeddings + store = InMemoryStore(embedding_config={ + "dims": 1536, + "embed": OpenAIEmbeddings(model="text-embedding-3-small"), + }) + + # Store documents + store.put(("docs",), "doc1", {"text": "Python tutorial"}) + store.put(("docs",), "doc2", {"text": "TypeScript guide"}) + + # Search by similarity + results = store.search(("docs",), query="python programming") + + +Note: + For production use cases requiring persistence, use a database-backed store instead. +""" + +import asyncio +import concurrent.futures as cf +import functools +import logging from collections import defaultdict from datetime import datetime, timezone -from typing import Iterable +from importlib import util +from typing import Any, Iterable, Optional + +from langchain_core.embeddings import Embeddings from langgraph.store.base import ( BaseStore, + EmbeddingConfig, GetOp, Item, ListNamespacesOp, @@ -11,69 +49,319 @@ Op, PutOp, Result, + SearchItem, SearchOp, + ensure_embeddings, + get_text_at_path, + tokenize_path, ) +logger = logging.getLogger(__name__) + class InMemoryStore(BaseStore): - """A KV store backed by an in-memory python dictionary. 
+ """In-memory dictionary-backed store with optional vector search. - Useful for testing/experimentation and lightweight PoC's. - For actual persistence, use a Store backed by a proper database. + Examples: + Basic key-value storage: + store = InMemoryStore() + store.put(("users", "123"), "prefs", {"theme": "dark"}) + item = store.get(("users", "123"), "prefs") + + Vector search with embeddings: + from langchain_openai import OpenAIEmbeddings + store = InMemoryStore(embedding_config={ + "dims": 1536, + "embed": OpenAIEmbeddings(model="text-embedding-3-small"), + }) + + # Store documents + store.put(("docs",), "doc1", {"text": "Python tutorial"}) + store.put(("docs",), "doc2", {"text": "TypeScript guide"}) + + # Search by similarity + results = store.search(("docs",), query="python programming") + + Warning: + This store keeps all data in memory. Data is lost when the process exits. + For persistence, use a database-backed store like PostgresStore. + + Tip: + For vector search, install numpy for better performance: + ```bash + pip install numpy + ``` """ - __slots__ = ("_data",) + __slots__ = ( + "_data", + "embedding_config", + "inmem_store", + "embeddings", + "_vectors", + ) - def __init__(self) -> None: + def __init__(self, embedding_config: Optional[EmbeddingConfig] = None) -> None: self._data: dict[tuple[str, ...], dict[str, Item]] = defaultdict(dict) + # [ns][key][path] + self.inmem_store: dict[tuple[str, ...], dict[str, dict[str, list[float]]]] = ( + defaultdict(lambda: defaultdict(dict)) + ) + self.embedding_config = embedding_config + if self.embedding_config: + self.embedding_config = self.embedding_config.copy() + self.embeddings: Optional[Embeddings] = ensure_embeddings( + self.embedding_config.get("embed"), + aembed=self.embedding_config.get("aembed"), + ) + self.embedding_config["__tokenized_fields"] = [ + (p, tokenize_path(p)) if p != "__root__" else (p, p) + for p in (self.embedding_config.get("text_fields") or ["__root__"]) + ] + + else: + self.embedding_config = None + self.embeddings = None def batch(self, ops: Iterable[Op]) -> list[Result]: + # The batch/abatch methods are treated as internal. + # Users should access via put/search/get/list_namespaces/etc. + results, put_ops, search_ops = self._prepare_ops(ops) + if search_ops: + queryinmem_store = self._embed_search_queries(search_ops) + self._batch_search(search_ops, queryinmem_store, results) + + to_embed = self._extract_texts(put_ops) + if to_embed and self.embedding_config and self.embeddings: + embeddings = self.embeddings.embed_documents(list(to_embed)) + self._insertinmem_store(to_embed, embeddings) + self._apply_put_ops(put_ops) + return results + + async def abatch(self, ops: Iterable[Op]) -> list[Result]: + # The batch/abatch methods are treated as internal. + # Users should access via put/search/get/list_namespaces/etc. 
+ results, put_ops, search_ops = self._prepare_ops(ops) + if search_ops: + queryinmem_store = await self._aembed_search_queries(search_ops) + self._batch_search(search_ops, queryinmem_store, results) + + to_embed = self._extract_texts(put_ops) + if to_embed and self.embedding_config and self.embeddings: + embeddings = await self.embeddings.aembed_documents(list(to_embed)) + self._insertinmem_store(to_embed, embeddings) + self._apply_put_ops(put_ops) + return results + + # Helpers + + def _filter_items(self, op: SearchOp) -> list[tuple[Item, list[list[float]]]]: + """Filter items by namespace and filter function, return items with their embeddings.""" + namespace_prefix = op.namespace_prefix + + def filter_func(item: Item) -> bool: + if not op.filter: + return True + + return all( + _compare_values(item.value.get(key), filter_value) + for key, filter_value in op.filter.items() + ) + + filtered = [] + for namespace in self._data: + if not ( + namespace[: len(namespace_prefix)] == namespace_prefix + if len(namespace) >= len(namespace_prefix) + else False + ): + continue + + for key, item in self._data[namespace].items(): + if filter_func(item): + if op.query and ( + embeddings := self.inmem_store[namespace].get(key) + ): + filtered.append((item, list(embeddings.values()))) + else: + filtered.append((item, [])) + return filtered + + def _embed_search_queries( + self, + search_ops: dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]], + ) -> dict[str, list[float]]: + queryinmem_store = {} + if self.embedding_config and self.embeddings and search_ops: + queries = {op.query for (op, _) in search_ops.values() if op.query} + + if queries: + with cf.ThreadPoolExecutor() as executor: + futures = { + q: executor.submit(self.embeddings.embed_query, q) + for q in queries + } + for query, future in futures.items(): + queryinmem_store[query] = future.result() + + return queryinmem_store + + async def _aembed_search_queries( + self, + search_ops: dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]], + ) -> dict[str, list[float]]: + queryinmem_store = {} + if self.embedding_config and self.embeddings and search_ops: + queries = {op.query for (op, _) in search_ops.values() if op.query} + + if queries: + coros = [self.embeddings.aembed_query(q) for q in queries] + results = await asyncio.gather(*coros) + queryinmem_store = dict(zip(queries, results)) + + return queryinmem_store + + def _batch_search( + self, + ops: dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]], + queryinmem_store: dict[str, list[float]], + results: list[Result], + ) -> None: + """Perform batch similarity search for multiple queries.""" + for i, (op, candidates) in ops.items(): + if not candidates: + results[i] = [] + continue + if op.query: + query_embedding = queryinmem_store[op.query] + flat_items, flat_vectors = [], [] + for item, vectors in candidates: + for vector in vectors: + flat_items.append(item) + flat_vectors.append(vector) + scores = _cosine_similarity(query_embedding, flat_vectors) + sorted_results = sorted( + zip(scores, flat_items), key=lambda x: x[0], reverse=True + ) + # max pooling + seen: set[tuple[tuple[str, ...], str]] = set() + kept = [] + for score, item in sorted_results: + key = (item.namespace, item.key) + if key in seen: + continue + ix = len(seen) + seen.add(key) + if ix >= op.offset + op.limit: + break + if ix < op.offset: + continue + + kept.append((score, item)) + + results[i] = [ + SearchItem( + namespace=item.namespace, + key=item.key, + value=item.value, + 
created_at=item.created_at, + updated_at=item.updated_at, + response_metadata={"score": float(score)}, + ) + for score, item in kept + ] + else: + results[i] = [ + SearchItem( + namespace=item.namespace, + key=item.key, + value=item.value, + created_at=item.created_at, + updated_at=item.updated_at, + ) + for (item, _) in candidates[op.offset : op.offset + op.limit] + ] + + def _prepare_ops( + self, ops: Iterable[Op] + ) -> tuple[ + list[Result], + dict[tuple[tuple[str, ...], str], PutOp], + dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]], + ]: results: list[Result] = [] - for op in ops: + put_ops: dict[tuple[tuple[str, ...], str], PutOp] = {} + search_ops: dict[ + int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]] + ] = {} + for i, op in enumerate(ops): if isinstance(op, GetOp): item = self._data[op.namespace].get(op.key) results.append(item) elif isinstance(op, SearchOp): - candidates = [ - item - for namespace, items in self._data.items() - if ( - namespace[: len(op.namespace_prefix)] == op.namespace_prefix - if len(namespace) >= len(op.namespace_prefix) - else False - ) - for item in items.values() - ] - if op.filter: - candidates = [ - item - for item in candidates - if item.value.items() >= op.filter.items() - ] - results.append(candidates[op.offset : op.offset + op.limit]) - elif isinstance(op, PutOp): - if op.value is None: - self._data[op.namespace].pop(op.key, None) - elif op.key in self._data[op.namespace]: - self._data[op.namespace][op.key].value = op.value - self._data[op.namespace][op.key].updated_at = datetime.now( - timezone.utc - ) - else: - self._data[op.namespace][op.key] = Item( - value=op.value, - key=op.key, - namespace=op.namespace, - created_at=datetime.now(timezone.utc), - updated_at=datetime.now(timezone.utc), - ) + search_ops[i] = (op, self._filter_items(op)) results.append(None) elif isinstance(op, ListNamespacesOp): results.append(self._handle_list_namespaces(op)) - return results + elif isinstance(op, PutOp): + put_ops[(op.namespace, op.key)] = op + results.append(None) + else: + raise ValueError(f"Unknown operation type: {type(op)}") - async def abatch(self, ops: Iterable[Op]) -> list[Result]: - return self.batch(ops) + return results, put_ops, search_ops + + def _apply_put_ops(self, put_ops: dict[tuple[tuple[str, ...], str], PutOp]) -> None: + for (namespace, key), op in put_ops.items(): + if op.value is None: + self._data[namespace].pop(key, None) + self.inmem_store[namespace].pop(key, None) + else: + self._data[namespace][key] = Item( + value=op.value, + key=key, + namespace=namespace, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + + def _extract_texts( + self, put_ops: dict[tuple[tuple[str, ...], str], PutOp] + ) -> dict[str, list[tuple[tuple[str, ...], str, str]]]: + if put_ops and self.embedding_config and self.embeddings: + to_embed = defaultdict(list) + + for op in put_ops.values(): + if op.value is not None and op.index is not False: + for path, field in self.embedding_config["__tokenized_fields"]: + texts = get_text_at_path(op.value, field) + if texts: + if len(texts) > 1: + for i, text in enumerate(texts): + to_embed[text].append( + (op.namespace, op.key, f"{path}.{i}") + ) + + else: + to_embed[texts[0]].append((op.namespace, op.key, path)) + + return to_embed + + return {} + + def _insertinmem_store( + self, + to_embed: dict[str, list[tuple[tuple[str, ...], str, str]]], + embeddings: list[list[float]], + ) -> None: + indices = [index for indices in to_embed.values() for index in 
indices] + if len(indices) != len(embeddings): + raise ValueError( + f"Number of embeddings ({len(embeddings)}) does not" + f" match number of indices ({len(indices)})" + ) + for embedding, (ns, key, path) in zip(embeddings, indices): + self.inmem_store[ns][key][path] = embedding def _handle_list_namespaces(self, op: ListNamespacesOp) -> list[tuple[str, ...]]: all_namespaces = list( @@ -94,7 +382,52 @@ def _handle_list_namespaces(self, op: ListNamespacesOp) -> list[tuple[str, ...]] return namespaces[op.offset : op.offset + op.limit] +@functools.lru_cache(maxsize=1) +def _check_numpy() -> bool: + if bool(util.find_spec("numpy")): + return True + logger.warning( + "NumPy not found in the current Python environment. " + "The InMemoryStore will use a pure Python implementation for vector operations, " + "which may significantly impact performance, especially for large datasets or frequent searches. " + "For optimal speed and efficiency, consider installing NumPy: " + "pip install numpy" + ) + return False + + +def _cosine_similarity(X: list[float], Y: list[list[float]]) -> list[float]: + """ + Compute cosine similarity between a vector X and a matrix Y. + Lazy import numpy for efficiency. + """ + if _check_numpy(): + import numpy as np # type: ignore + + X = np.array(X) if not isinstance(X, np.ndarray) else X + Y = np.array(Y) if not isinstance(Y, np.ndarray) else Y + X_norm = np.linalg.norm(X) + Y_norm = np.linalg.norm(Y, axis=1) + + # Avoid division by zero + mask = Y_norm != 0 + similarities = np.zeros_like(Y_norm) + similarities[mask] = np.dot(Y[mask], X) / (Y_norm[mask] * X_norm) + return similarities.tolist() + + similarities = [] + for y in Y: + dot_product = sum(a * b for a, b in zip(X, y)) + norm1 = sum(a * a for a in X) ** 0.5 + norm2 = sum(a * a for a in y) ** 0.5 + similarity = dot_product / (norm1 * norm2) if norm1 > 0 and norm2 > 0 else 0.0 + similarities.append(similarity) + + return similarities + + def _does_match(match_condition: MatchCondition, key: tuple[str, ...]) -> bool: + """Whether a namespace key matches a match condition.""" match_type = match_condition.match_type path = match_condition.path @@ -117,3 +450,44 @@ def _does_match(match_condition: MatchCondition, key: tuple[str, ...]) -> bool: return True else: raise ValueError(f"Unsupported match type: {match_type}") + + +def _compare_values(item_value: Any, filter_value: Any) -> bool: + """Compare values in a JSONB-like way, handling nested objects.""" + if isinstance(filter_value, dict): + if any(k.startswith("$") for k in filter_value): + return all( + _apply_operator(item_value, op_key, op_value) + for op_key, op_value in filter_value.items() + ) + if not isinstance(item_value, dict): + return False + return all( + _compare_values(item_value.get(k), v) for k, v in filter_value.items() + ) + elif isinstance(filter_value, (list, tuple)): + return ( + isinstance(item_value, (list, tuple)) + and len(item_value) == len(filter_value) + and all(_compare_values(iv, fv) for iv, fv in zip(item_value, filter_value)) + ) + else: + return item_value == filter_value + + +def _apply_operator(value: Any, operator: str, op_value: Any) -> bool: + """Apply a comparison operator, matching PostgreSQL's JSONB behavior.""" + if operator == "$eq": + return value == op_value + elif operator == "$gt": + return float(value) > float(op_value) + elif operator == "$gte": + return float(value) >= float(op_value) + elif operator == "$lt": + return float(value) < float(op_value) + elif operator == "$lte": + return float(value) <= float(op_value) + 
elif operator == "$ne": + return value != op_value + else: + raise ValueError(f"Unsupported operator: {operator}") diff --git a/libs/checkpoint/pyproject.toml b/libs/checkpoint/pyproject.toml index 278594fcb..deb7de5c4 100644 --- a/libs/checkpoint/pyproject.toml +++ b/libs/checkpoint/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langgraph-checkpoint" -version = "2.0.5" +version = "2.0.6" description = "Library with base interfaces for LangGraph checkpoint savers." authors = [] license = "MIT" diff --git a/libs/checkpoint/tests/test_store.py b/libs/checkpoint/tests/test_store.py index 0ecd4bd84..550ebe1af 100644 --- a/libs/checkpoint/tests/test_store.py +++ b/libs/checkpoint/tests/test_store.py @@ -1,19 +1,20 @@ import asyncio from datetime import datetime -from typing import Iterable +from typing import Any, Iterable import pytest from pytest_mock import MockerFixture from langgraph.store.base import GetOp, InvalidNamespaceError, Item, Op, PutOp, Result +from langgraph.store.base._embed_test_utils import CharacterEmbeddings from langgraph.store.base.batch import AsyncBatchedBaseStore from langgraph.store.memory import InMemoryStore class MockAsyncBatchedStore(AsyncBatchedBaseStore): - def __init__(self) -> None: + def __init__(self, **kwargs: Any) -> None: super().__init__() - self._store = InMemoryStore() + self._store = InMemoryStore(**kwargs) def batch(self, ops: Iterable[Op]) -> list[Result]: return self._store.batch(ops) @@ -420,3 +421,386 @@ async def test_async_batch_store_deduplication(mocker: MockerFixture) -> None: assert results[0][0].value == doc2 abatch.reset_mock() + + +@pytest.fixture +def fake_embeddings() -> CharacterEmbeddings: + return CharacterEmbeddings(dims=500) + + +def test_vector_store_initialization(fake_embeddings: CharacterEmbeddings) -> None: + """Test store initialization with embedding config.""" + store = InMemoryStore( + embedding_config={"dims": fake_embeddings.dims, "embed": fake_embeddings} + ) + assert store.embedding_config is not None + assert store.embedding_config["dims"] == fake_embeddings.dims + assert store.embedding_config["embed"] == fake_embeddings + + +def test_vector_insert_with_auto_embedding( + fake_embeddings: CharacterEmbeddings, +) -> None: + """Test inserting items that get auto-embedded.""" + store = InMemoryStore( + embedding_config={"dims": fake_embeddings.dims, "embed": fake_embeddings} + ) + docs = [ + ("doc1", {"text": "short text"}), + ("doc2", {"text": "longer text document"}), + ("doc3", {"text": "longest text document here"}), + ("doc4", {"description": "text in description field"}), + ("doc5", {"content": "text in content field"}), + ("doc6", {"body": "text in body field"}), + ] + + for key, value in docs: + store.put(("test",), key, value) + + results = store.search(("test",), query="long text") + assert len(results) > 0 + + doc_order = [r.key for r in results] + assert "doc2" in doc_order + assert "doc3" in doc_order + + +async def test_async_vector_insert_with_auto_embedding( + fake_embeddings: CharacterEmbeddings, +) -> None: + """Test inserting items that get auto-embedded using async methods.""" + store = InMemoryStore( + embedding_config={"dims": fake_embeddings.dims, "embed": fake_embeddings} + ) + docs = [ + ("doc1", {"text": "short text"}), + ("doc2", {"text": "longer text document"}), + ("doc3", {"text": "longest text document here"}), + ("doc4", {"description": "text in description field"}), + ("doc5", {"content": "text in content field"}), + ("doc6", {"body": "text in body field"}), + ] + + for key, value
in docs: + await store.aput(("test",), key, value) + + results = await store.asearch(("test",), query="long text") + assert len(results) > 0 + + doc_order = [r.key for r in results] + assert "doc2" in doc_order + assert "doc3" in doc_order + + +def test_vector_update_with_embedding(fake_embeddings: CharacterEmbeddings) -> None: + """Test that updating items properly updates their embeddings.""" + store = InMemoryStore( + embedding_config={"dims": fake_embeddings.dims, "embed": fake_embeddings} + ) + store.put(("test",), "doc1", {"text": "zany zebra Xerxes"}) + store.put(("test",), "doc2", {"text": "something about dogs"}) + store.put(("test",), "doc3", {"text": "text about birds"}) + + results_initial = store.search(("test",), query="Zany Xerxes") + assert len(results_initial) > 0 + assert results_initial[0].key == "doc1" + initial_score = results_initial[0].response_metadata["score"] + + store.put(("test",), "doc1", {"text": "new text about dogs"}) + + results_after = store.search(("test",), query="Zany Xerxes") + after_score = next( + (r.response_metadata["score"] for r in results_after if r.key == "doc1"), 0.0 + ) + assert after_score < initial_score + + results_new = store.search(("test",), query="new text about dogs") + for r in results_new: + if r.key == "doc1": + assert r.response_metadata["score"] > after_score + + # Don't index this one + store.put(("test",), "doc4", {"text": "new text about dogs"}, index=False) + results_new = store.search(("test",), query="new text about dogs", limit=3) + assert not any(r.key == "doc4" for r in results_new) + + +async def test_async_vector_update_with_embedding( + fake_embeddings: CharacterEmbeddings, +) -> None: + """Test that updating items properly updates their embeddings using async methods.""" + store = InMemoryStore( + embedding_config={"dims": fake_embeddings.dims, "embed": fake_embeddings} + ) + await store.aput(("test",), "doc1", {"text": "zany zebra Xerxes"}) + await store.aput(("test",), "doc2", {"text": "something about dogs"}) + await store.aput(("test",), "doc3", {"text": "text about birds"}) + + results_initial = await store.asearch(("test",), query="Zany Xerxes") + assert len(results_initial) > 0 + assert results_initial[0].key == "doc1" + initial_score = results_initial[0].response_metadata["score"] + + await store.aput(("test",), "doc1", {"text": "new text about dogs"}) + + results_after = await store.asearch(("test",), query="Zany Xerxes") + after_score = next( + (r.response_metadata["score"] for r in results_after if r.key == "doc1"), 0.0 + ) + assert after_score < initial_score + + results_new = await store.asearch(("test",), query="new text about dogs") + for r in results_new: + if r.key == "doc1": + assert r.response_metadata["score"] > after_score + + # Don't index this one + await store.aput(("test",), "doc4", {"text": "new text about dogs"}, index=False) + results_new = await store.asearch(("test",), query="new text about dogs", limit=3) + assert not any(r.key == "doc4" for r in results_new) + + +def test_vector_search_with_filters(fake_embeddings: CharacterEmbeddings) -> None: + """Test combining vector search with filters.""" + inmem_store = InMemoryStore( + embedding_config={"dims": fake_embeddings.dims, "embed": fake_embeddings} + ) + # Insert test documents + docs = [ + ("doc1", {"text": "red apple", "color": "red", "score": 4.5}), + ("doc2", {"text": "red car", "color": "red", "score": 3.0}), + ("doc3", {"text": "green apple", "color": "green", "score": 4.0}), + ("doc4", {"text": "blue car", "color": "blue", 
"score": 3.5}), + ] + + for key, value in docs: + inmem_store.put(("test",), key, value) + + results = inmem_store.search(("test",), query="apple", filter={"color": "red"}) + assert len(results) == 2 + assert results[0].key == "doc1" + + results = inmem_store.search(("test",), query="car", filter={"color": "red"}) + assert len(results) == 2 + assert results[0].key == "doc2" + + results = inmem_store.search( + ("test",), query="bbbbluuu", filter={"score": {"$gt": 3.2}} + ) + assert len(results) == 3 + assert results[0].key == "doc4" + + # Multiple filters + results = inmem_store.search( + ("test",), query="apple", filter={"score": {"$gte": 4.0}, "color": "green"} + ) + assert len(results) == 1 + assert results[0].key == "doc3" + + +async def test_async_vector_search_with_filters( + fake_embeddings: CharacterEmbeddings, +) -> None: + """Test combining vector search with filters using async methods.""" + store = InMemoryStore( + embedding_config={"dims": fake_embeddings.dims, "embed": fake_embeddings} + ) + # Insert test documents + docs = [ + ("doc1", {"text": "red apple", "color": "red", "score": 4.5}), + ("doc2", {"text": "red car", "color": "red", "score": 3.0}), + ("doc3", {"text": "green apple", "color": "green", "score": 4.0}), + ("doc4", {"text": "blue car", "color": "blue", "score": 3.5}), + ] + + for key, value in docs: + await store.aput(("test",), key, value) + + results = await store.asearch(("test",), query="apple", filter={"color": "red"}) + assert len(results) == 2 + assert results[0].key == "doc1" + + results = await store.asearch(("test",), query="car", filter={"color": "red"}) + assert len(results) == 2 + assert results[0].key == "doc2" + + results = await store.asearch( + ("test",), query="bbbbluuu", filter={"score": {"$gt": 3.2}} + ) + assert len(results) == 3 + assert results[0].key == "doc4" + + # Multiple filters + results = await store.asearch( + ("test",), query="apple", filter={"score": {"$gte": 4.0}, "color": "green"} + ) + assert len(results) == 1 + assert results[0].key == "doc3" + + +async def test_async_batched_vector_search_concurrent( + fake_embeddings: CharacterEmbeddings, +) -> None: + """Test concurrent vector search operations using async batched store.""" + store = MockAsyncBatchedStore( + embedding_config={"dims": fake_embeddings.dims, "embed": fake_embeddings} + ) + + colors = ["red", "blue", "green", "yellow", "purple"] + items = ["apple", "car", "house", "book", "phone"] + scores = [3.0, 3.5, 4.0, 4.5, 5.0] + + docs = [] + for i in range(50): + color = colors[i % len(colors)] + item = items[i % len(items)] + score = scores[i % len(scores)] + docs.append( + ( + f"doc{i}", + {"text": f"{color} {item}", "color": color, "score": score, "index": i}, + ) + ) + coros = [ + *[store.aput(("test",), key, value) for key, value in docs], + *[store.adelete(("test",), key) for key, value in docs], + *[store.aput(("test",), key, value) for key, value in docs], + ] + await asyncio.gather(*coros) + + # Prepare multiple search queries with different filters + search_queries: list[tuple[str, dict[str, Any]]] = [ + ("apple", {"color": "red"}), + ("car", {"color": "blue"}), + ("house", {"color": "green"}), + ("phone", {"score": {"$gt": 4.99}}), + ("book", {"score": {"$lte": 3.5}}), + ("apple", {"score": {"$gte": 3.0}, "color": "red"}), + ("car", {"score": {"$lt": 5.1}, "color": "blue"}), + ("house", {"index": {"$gt": 25}}), + ("phone", {"index": {"$lte": 10}}), + ] + + all_results = await asyncio.gather( + *[ + store.asearch(("test",), query=query, filter=filter_) + for 
query, filter_ in search_queries + ] + ) + + for results, (query, filter_) in zip(all_results, search_queries): + assert len(results) > 0, f"No results for query '{query}' with filter {filter_}" + + for result in results: + if "color" in filter_: + assert result.value["color"] == filter_["color"] + + if "score" in filter_: + score = result.value["score"] + for op, value in filter_["score"].items(): + if op == "$gt": + assert score > value + elif op == "$gte": + assert score >= value + elif op == "$lt": + assert score < value + elif op == "$lte": + assert score <= value + + if "index" in filter_: + index = result.value["index"] + for op, value in filter_["index"].items(): + if op == "$gt": + assert index > value + elif op == "$gte": + assert index >= value + elif op == "$lt": + assert index < value + elif op == "$lte": + assert index <= value + + +def test_vector_search_pagination(fake_embeddings: CharacterEmbeddings) -> None: + """Test pagination with vector search.""" + store = InMemoryStore( + embedding_config={"dims": fake_embeddings.dims, "embed": fake_embeddings} + ) + for i in range(5): + store.put(("test",), f"doc{i}", {"text": f"test document number {i}"}) + + results_page1 = store.search(("test",), query="test", limit=2) + results_page2 = store.search(("test",), query="test", limit=2, offset=2) + + assert len(results_page1) == 2 + assert len(results_page2) == 2 + assert results_page1[0].key != results_page2[0].key + + all_results = store.search(("test",), query="test", limit=10) + assert len(all_results) == 5 + + +async def test_async_vector_search_pagination( + fake_embeddings: CharacterEmbeddings, +) -> None: + """Test pagination with vector search using async methods.""" + store = InMemoryStore( + embedding_config={"dims": fake_embeddings.dims, "embed": fake_embeddings} + ) + for i in range(5): + await store.aput(("test",), f"doc{i}", {"text": f"test document number {i}"}) + + results_page1 = await store.asearch(("test",), query="test", limit=2) + results_page2 = await store.asearch(("test",), query="test", limit=2, offset=2) + + assert len(results_page1) == 2 + assert len(results_page2) == 2 + assert results_page1[0].key != results_page2[0].key + + all_results = await store.asearch(("test",), query="test", limit=10) + assert len(all_results) == 5 + + +def test_vector_search_edge_cases(fake_embeddings: CharacterEmbeddings) -> None: + """Test edge cases in vector search.""" + store = InMemoryStore( + embedding_config={"dims": fake_embeddings.dims, "embed": fake_embeddings} + ) + store.put(("test",), "doc1", {"text": "test document"}) + + results = store.search(("test",), query="") + assert len(results) == 1 + + results = store.search(("test",), query=None) + assert len(results) == 1 + + long_query = "test " * 100 + results = store.search(("test",), query=long_query) + assert len(results) == 1 + + special_query = "test!@#$%^&*()" + results = store.search(("test",), query=special_query) + assert len(results) == 1 + + +async def test_async_vector_search_edge_cases( + fake_embeddings: CharacterEmbeddings, +) -> None: + """Test edge cases in vector search using async methods.""" + store = InMemoryStore( + embedding_config={"dims": fake_embeddings.dims, "embed": fake_embeddings} + ) + await store.aput(("test",), "doc1", {"text": "test document"}) + + results = await store.asearch(("test",), query="") + assert len(results) == 1 + + results = await store.asearch(("test",), query=None) + assert len(results) == 1 + + long_query = "test " * 100 + results = await store.asearch(("test",), 
query=long_query) + assert len(results) == 1 + + special_query = "test!@#$%^&*()" + results = await store.asearch(("test",), query=special_query) + assert len(results) == 1
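For reference, an end-to-end sketch of the surface this diff adds, using `InMemoryStore` with a toy stand-in embedding function (the function and example values are illustrative, not part of the change):

```python
import asyncio

from langgraph.store.memory import InMemoryStore

def embed(texts: list[str]) -> list[list[float]]:
    # Toy deterministic "embedding": a few character statistics per text.
    return [[len(t), t.count("p"), t.count("o"), 1.0] for t in texts]

store = InMemoryStore(embedding_config={"dims": 4, "embed": embed})
store.put(("docs",), "doc1", {"text": "python programming"})
store.put(("docs",), "doc2", {"text": "cooking recipes"}, index=False)  # skipped by the indexer

# Vector-ranked search; results are SearchItems carrying a score.
for item in store.search(("docs",), query="programming", limit=5):
    print(item.key, item.response_metadata.get("score"))

# The async surface goes through the same batch machinery; the sync embed
# function is wrapped by EmbeddingsLambda and run in an executor.
async def main() -> None:
    await store.aput(("docs",), "doc3", {"text": "typescript"})
    print([i.key for i in await store.asearch(("docs",), query="typescript")])

asyncio.run(main())
```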
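And a sketch of the path grammar that `EmbeddingConfig["text_fields"]` resolves through `get_text_at_path`; the document and expected outputs below are illustrative, per my reading of the implementation in this diff:

```python
from langgraph.store.base import get_text_at_path

doc = {
    "title": "Intro",
    "meta": {"author": "Ada"},
    "sections": [{"body": "first"}, {"body": "second"}],
}

print(get_text_at_path(doc, "title"))                # ['Intro']
print(get_text_at_path(doc, "sections[*].body"))     # ['first', 'second']
print(get_text_at_path(doc, "sections[-1].body"))    # ['second']
print(get_text_at_path(doc, "{title,meta.author}"))  # ['Intro', 'Ada']
print(get_text_at_path(doc, "__root__"))             # whole doc as one JSON string
```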