diff --git a/libs/checkpoint/langgraph/store/base/batch.py b/libs/checkpoint/langgraph/store/base/batch.py index 079888222..b1030942d 100644 --- a/libs/checkpoint/langgraph/store/base/batch.py +++ b/libs/checkpoint/langgraph/store/base/batch.py @@ -6,6 +6,9 @@ BaseStore, GetOp, Item, + ListNamespacesOp, + MatchCondition, + NameSpacePath, Op, PutOp, SearchOp, @@ -68,6 +71,74 @@ async def adelete( self._aqueue[fut] = PutOp(namespace, key, None) return await fut + async def alist_namespaces( + self, + *, + prefix: Optional[NameSpacePath] = None, + suffix: Optional[NameSpacePath] = None, + max_depth: Optional[int] = None, + limit: int = 100, + offset: int = 0, + ) -> list[tuple[str, ...]]: + fut = self._loop.create_future() + match_conditions = [] + if prefix: + match_conditions.append(MatchCondition(match_type="prefix", path=prefix)) + if suffix: + match_conditions.append(MatchCondition(match_type="suffix", path=suffix)) + + op = ListNamespacesOp( + match_conditions=tuple(match_conditions), + max_depth=max_depth, + limit=limit, + offset=offset, + ) + self._aqueue[fut] = op + return await fut + + +def _dedupe_ops(values: list[Op]) -> tuple[Optional[list[int]], list[Op]]: + """Dedupe operations while preserving order for results. + + Args: + values: List of operations to dedupe + + Returns: + Tuple of (listen indices, deduped operations) + where listen indices map deduped operation results back to original positions + """ + if len(values) <= 1: + return None, list(values) + + dedupped: list[Op] = [] + listen: list[int] = [] + puts: dict[tuple[tuple[str, ...], str], int] = {} + + for op in values: + if isinstance(op, (GetOp, SearchOp, ListNamespacesOp)): + try: + listen.append(dedupped.index(op)) + except ValueError: + listen.append(len(dedupped)) + dedupped.append(op) + elif isinstance(op, PutOp): + putkey = (op.namespace, op.key) + if putkey in puts: + # Overwrite previous put + ix = puts[putkey] + dedupped[ix] = op + listen.append(ix) + else: + puts[putkey] = len(dedupped) + listen.append(len(dedupped)) + dedupped.append(op) + + else: # Any new ops will be treated regularly + listen.append(len(dedupped)) + dedupped.append(op) + + return listen, dedupped + async def _run( aqueue: dict[asyncio.Future, Op], store: weakref.ReferenceType[BaseStore] @@ -81,7 +152,12 @@ async def _run( taken = aqueue.copy() # action each operation try: - results = await s.abatch(taken.values()) + values = list(taken.values()) + listen, dedupped = _dedupe_ops(values) + results = await s.abatch(dedupped) + if listen is not None: + results = [results[ix] for ix in listen] + # set the results of each operation for fut, result in zip(taken, results): fut.set_result(result) diff --git a/libs/checkpoint/tests/test_store.py b/libs/checkpoint/tests/test_store.py index 9d06281d0..0ecd4bd84 100644 --- a/libs/checkpoint/tests/test_store.py +++ b/libs/checkpoint/tests/test_store.py @@ -10,6 +10,18 @@ from langgraph.store.memory import InMemoryStore +class MockAsyncBatchedStore(AsyncBatchedBaseStore): + def __init__(self) -> None: + super().__init__() + self._store = InMemoryStore() + + def batch(self, ops: Iterable[Op]) -> list[Result]: + return self._store.batch(ops) + + async def abatch(self, ops: Iterable[Op]) -> list[Result]: + return self._store.batch(ops) + + async def test_async_batch_store(mocker: MockerFixture) -> None: abatch = mocker.stub() @@ -313,17 +325,6 @@ async def test_cannot_put_empty_namespace() -> None: store.delete(("langgraph", "foo"), "bar") assert store.get(("langgraph", "foo"), "bar") is None - class MockAsyncBatchedStore(AsyncBatchedBaseStore): - def __init__(self) -> None: - super().__init__() - self._store = InMemoryStore() - - def batch(self, ops: Iterable[Op]) -> list[Result]: - return self._store.batch(ops) - - async def abatch(self, ops: Iterable[Op]) -> list[Result]: - return self._store.batch(ops) - async_store = MockAsyncBatchedStore() doc = {"foo": "bar"} @@ -354,3 +355,68 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]: assert (await async_store.asearch(("valid", "namespace")))[0].value == doc await async_store.adelete(("valid", "namespace"), "key") assert (await async_store.aget(("valid", "namespace"), "key")) is None + + +async def test_async_batch_store_deduplication(mocker: MockerFixture) -> None: + abatch = mocker.spy(InMemoryStore, "batch") + store = MockAsyncBatchedStore() + + same_doc = {"value": "same"} + diff_doc = {"value": "different"} + await asyncio.gather( + store.aput(namespace=("test",), key="same", value=same_doc), + store.aput(namespace=("test",), key="different", value=diff_doc), + ) + abatch.reset_mock() + + results = await asyncio.gather( + store.aget(namespace=("test",), key="same"), + store.aget(namespace=("test",), key="same"), + store.aget(namespace=("test",), key="different"), + ) + + assert len(results) == 3 + assert results[0] == results[1] + assert results[0] != results[2] + assert results[0].value == same_doc # type: ignore + assert results[2].value == diff_doc # type: ignore + assert len(abatch.call_args_list) == 1 + ops = list(abatch.call_args_list[0].args[1]) + assert len(ops) == 2 + assert GetOp(("test",), "same") in ops + assert GetOp(("test",), "different") in ops + + abatch.reset_mock() + + doc1 = {"value": 1} + doc2 = {"value": 2} + results = await asyncio.gather( + store.aput(namespace=("test",), key="key", value=doc1), + store.aput(namespace=("test",), key="key", value=doc2), + ) + assert len(abatch.call_args_list) == 1 + ops = list(abatch.call_args_list[0].args[1]) + assert len(ops) == 1 + assert ops[0] == PutOp(("test",), "key", doc2) + assert len(results) == 2 + assert all(result is None for result in results) + + result = await store.aget(namespace=("test",), key="key") + assert result is not None + assert result.value == doc2 + + abatch.reset_mock() + + results = await asyncio.gather( + store.asearch(("test",), filter={"value": 2}), + store.asearch(("test",), filter={"value": 2}), + ) + assert len(abatch.call_args_list) == 1 + ops = list(abatch.call_args_list[0].args[1]) + assert len(ops) == 1 + assert len(results) == 2 + assert results[0] == results[1] + assert len(results[0]) == 1 + assert results[0][0].value == doc2 + + abatch.reset_mock()