Skip to content

Commit

Permalink
Dedup store batch operations (#2534)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Nov 26, 2024
1 parent a4eb4c6 commit 1febec7
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 12 deletions.
78 changes: 77 additions & 1 deletion libs/checkpoint/langgraph/store/base/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
BaseStore,
GetOp,
Item,
ListNamespacesOp,
MatchCondition,
NameSpacePath,
Op,
PutOp,
SearchOp,
Expand Down Expand Up @@ -68,6 +71,74 @@ async def adelete(
self._aqueue[fut] = PutOp(namespace, key, None)
return await fut

async def alist_namespaces(
    self,
    *,
    prefix: Optional[NameSpacePath] = None,
    suffix: Optional[NameSpacePath] = None,
    max_depth: Optional[int] = None,
    limit: int = 100,
    offset: int = 0,
) -> list[tuple[str, ...]]:
    """Queue a namespace-listing operation and await its result.

    Args:
        prefix: If truthy, adds a ``"prefix"`` match condition on this path.
        suffix: If truthy, adds a ``"suffix"`` match condition on this path.
        max_depth: Forwarded unchanged to the ``ListNamespacesOp``.
        limit: Forwarded unchanged to the ``ListNamespacesOp``.
        offset: Forwarded unchanged to the ``ListNamespacesOp``.

    Returns:
        Whatever the queued future is resolved with.
    """
    pending = self._loop.create_future()

    # Build both candidate conditions, then keep only the ones whose path was
    # actually supplied (prefix first, suffix second).
    candidates = (
        MatchCondition(match_type="prefix", path=prefix) if prefix else None,
        MatchCondition(match_type="suffix", path=suffix) if suffix else None,
    )
    conditions = tuple(cond for cond in candidates if cond is not None)

    # Registering the op under its future hands it to the queue; the future is
    # resolved elsewhere.
    self._aqueue[pending] = ListNamespacesOp(
        match_conditions=conditions,
        max_depth=max_depth,
        limit=limit,
        offset=offset,
    )
    return await pending


def _dedupe_ops(values: list[Op]) -> tuple[Optional[list[int]], list[Op]]:
    """Dedupe operations while preserving order for results.

    Read operations (``GetOp``, ``SearchOp``, ``ListNamespacesOp``) that are
    equal to an earlier read *of the same type* share a single slot.
    ``PutOp``s targeting the same ``(namespace, key)`` collapse into the slot
    of the first such put, which ends up holding the last put seen. Any other
    operation type is passed through without deduplication.

    Args:
        values: List of operations to dedupe

    Returns:
        Tuple of (listen indices, deduped operations)
        where listen indices map deduped operation results back to original
        positions. The indices are ``None`` when ``values`` holds at most one
        operation, since nothing can be deduplicated.
    """
    if len(values) <= 1:
        return None, list(values)

    dedupped: list[Op] = []
    listen: list[int] = []
    puts: dict[tuple[tuple[str, ...], str], int] = {}

    for op in values:
        if isinstance(op, (GetOp, SearchOp, ListNamespacesOp)):
            # Require the same concrete type as well as equality. If the Op
            # classes are tuple-like (NOTE(review): presumably NamedTuples --
            # confirm), ``==`` alone ignores the class, so e.g. a SearchOp and
            # a ListNamespacesOp with coinciding field values would be merged
            # and one caller would receive the other operation's result.
            # A linear scan is kept because ops may hold unhashable values.
            ix = next(
                (
                    i
                    for i, seen in enumerate(dedupped)
                    if type(seen) is type(op) and seen == op
                ),
                None,
            )
            if ix is None:
                ix = len(dedupped)
                dedupped.append(op)
            listen.append(ix)
        elif isinstance(op, PutOp):
            putkey = (op.namespace, op.key)
            ix = puts.get(putkey)
            if ix is None:
                ix = puts[putkey] = len(dedupped)
                dedupped.append(op)
            else:
                # Overwrite previous put
                dedupped[ix] = op
            listen.append(ix)
        else:  # Any new ops will be treated regularly
            listen.append(len(dedupped))
            dedupped.append(op)

    return listen, dedupped


async def _run(
aqueue: dict[asyncio.Future, Op], store: weakref.ReferenceType[BaseStore]
Expand All @@ -81,7 +152,12 @@ async def _run(
taken = aqueue.copy()
# action each operation
try:
results = await s.abatch(taken.values())
values = list(taken.values())
listen, dedupped = _dedupe_ops(values)
results = await s.abatch(dedupped)
if listen is not None:
results = [results[ix] for ix in listen]

# set the results of each operation
for fut, result in zip(taken, results):
fut.set_result(result)
Expand Down
88 changes: 77 additions & 11 deletions libs/checkpoint/tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@
from langgraph.store.memory import InMemoryStore


class MockAsyncBatchedStore(AsyncBatchedBaseStore):
    """Batched async store test double backed by an ``InMemoryStore``."""

    def __init__(self) -> None:
        super().__init__()
        # Every batch, sync or async, is served by this in-memory store.
        self._store = InMemoryStore()

    def batch(self, ops: Iterable[Op]) -> list[Result]:
        """Run ``ops`` against the backing in-memory store."""
        backend = self._store
        return backend.batch(ops)

    async def abatch(self, ops: Iterable[Op]) -> list[Result]:
        """Async entry point; delegates to the backing store's sync ``batch``."""
        outcome = self._store.batch(ops)
        return outcome


async def test_async_batch_store(mocker: MockerFixture) -> None:
abatch = mocker.stub()

Expand Down Expand Up @@ -313,17 +325,6 @@ async def test_cannot_put_empty_namespace() -> None:
store.delete(("langgraph", "foo"), "bar")
assert store.get(("langgraph", "foo"), "bar") is None

class MockAsyncBatchedStore(AsyncBatchedBaseStore):
def __init__(self) -> None:
super().__init__()
self._store = InMemoryStore()

def batch(self, ops: Iterable[Op]) -> list[Result]:
return self._store.batch(ops)

async def abatch(self, ops: Iterable[Op]) -> list[Result]:
return self._store.batch(ops)

async_store = MockAsyncBatchedStore()
doc = {"foo": "bar"}

Expand Down Expand Up @@ -354,3 +355,68 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]:
assert (await async_store.asearch(("valid", "namespace")))[0].value == doc
await async_store.adelete(("valid", "namespace"), "key")
assert (await async_store.aget(("valid", "namespace"), "key")) is None


async def test_async_batch_store_deduplication(mocker: MockerFixture) -> None:
    """Concurrent duplicate gets, puts and searches reach the store only once."""
    batch_spy = mocker.spy(InMemoryStore, "batch")
    store = MockAsyncBatchedStore()

    def single_batch_ops() -> list:
        # Exactly one underlying batch call is expected. The spy wraps the
        # class attribute, so args[0] is the store instance and args[1] the ops.
        recorded = batch_spy.call_args_list
        assert len(recorded) == 1
        return list(recorded[0].args[1])

    # Seed two documents, then forget the calls made while seeding.
    same_doc = {"value": "same"}
    diff_doc = {"value": "different"}
    await asyncio.gather(
        store.aput(namespace=("test",), key="same", value=same_doc),
        store.aput(namespace=("test",), key="different", value=diff_doc),
    )
    batch_spy.reset_mock()

    # --- gets: two identical reads plus one distinct read -> two ops ---
    fetched = await asyncio.gather(
        store.aget(namespace=("test",), key="same"),
        store.aget(namespace=("test",), key="same"),
        store.aget(namespace=("test",), key="different"),
    )
    assert len(fetched) == 3
    first, second, third = fetched
    assert first == second
    assert first != third
    assert first.value == same_doc  # type: ignore
    assert third.value == diff_doc  # type: ignore
    get_ops = single_batch_ops()
    assert len(get_ops) == 2
    assert GetOp(("test",), "same") in get_ops
    assert GetOp(("test",), "different") in get_ops

    batch_spy.reset_mock()

    # --- puts: two writes to the same key -> only the last one is sent ---
    doc1 = {"value": 1}
    doc2 = {"value": 2}
    put_results = await asyncio.gather(
        store.aput(namespace=("test",), key="key", value=doc1),
        store.aput(namespace=("test",), key="key", value=doc2),
    )
    put_ops = single_batch_ops()
    assert len(put_ops) == 1
    assert put_ops[0] == PutOp(("test",), "key", doc2)
    assert len(put_results) == 2
    assert all(outcome is None for outcome in put_results)

    stored = await store.aget(namespace=("test",), key="key")
    assert stored is not None
    assert stored.value == doc2

    batch_spy.reset_mock()

    # --- searches: two identical searches -> one op, same result for both ---
    search_results = await asyncio.gather(
        store.asearch(("test",), filter={"value": 2}),
        store.asearch(("test",), filter={"value": 2}),
    )
    search_ops = single_batch_ops()
    assert len(search_ops) == 1
    assert len(search_results) == 2
    assert search_results[0] == search_results[1]
    assert len(search_results[0]) == 1
    assert search_results[0][0].value == doc2

    batch_spy.reset_mock()

0 comments on commit 1febec7

Please sign in to comment.