Commit 01b4d15

Add in-mem vector search

hinthornw committed Nov 27, 2024
1 parent f04ce5d commit 01b4d15

Showing 12 changed files with 1,506 additions and 112 deletions.
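The visible hunks plumb a new `SearchItem` result type, carrying an optional relevance score, through the DuckDB and Postgres stores; the in-memory index that actually produces the scores is not among the files shown here. For orientation only, in-memory vector search of this kind typically reduces to ranking stored embeddings by cosine similarity against a query embedding — a minimal sketch, with every name hypothetical rather than taken from this commit:

import math

def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b))
    return dot / norms if norms else 0.0

def search(index: dict[str, list[float]], query: list[float], limit: int = 10) -> list[tuple[str, float]]:
    """Score every stored embedding against the query; best matches first."""
    scored = [(key, cosine_similarity(vec, query)) for key, vec in index.items()]
    return sorted(scored, key=lambda kv: kv[1], reverse=True)[:limit]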
19 changes: 18 additions & 1 deletion libs/checkpoint-duckdb/langgraph/store/duckdb/base.py
@@ -23,6 +23,7 @@
     Op,
     PutOp,
     Result,
+    SearchItem,
     SearchOp,
 )

@@ -283,7 +284,7 @@ def _batch_search_ops(
 
         for cur, idx in cursors:
             rows = cur.fetchall()
-            items = [_row_to_item(_convert_ns(row[0]), row) for row in rows]
+            items = [_row_to_search_item(_convert_ns(row[0]), row) for row in rows]
             results[idx] = items
 
     def _batch_list_namespaces_ops(
@@ -376,6 +377,22 @@ def _row_to_item(
     )
 
 
+def _row_to_search_item(
+    namespace: tuple[str, ...],
+    row: tuple,
+) -> SearchItem:
+    """Convert a row from the database into a SearchItem."""
+    # TODO: Add support for search
+    _, key, val, created_at, updated_at = row
+    return SearchItem(
+        value=val if isinstance(val, dict) else json.loads(val),
+        key=key,
+        namespace=namespace,
+        created_at=created_at,
+        updated_at=updated_at,
+    )
+
+
 def _group_ops(ops: Iterable[Op]) -> tuple[dict[type, list[tuple[int, Op]]], int]:
     grouped_ops: dict[type, list[tuple[int, Op]]] = defaultdict(list)
     tot = 0
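The DuckDB helper simply repackages the five-column row into the new result type; per the TODO, no score is attached yet. A hypothetical call, assuming the column order implied by the unpacking above (in the store itself the namespace comes from `_convert_ns(row[0])`):

from datetime import datetime, timezone

now = datetime.now(timezone.utc)
# (namespace prefix, key, JSON value, created_at, updated_at) — the shape
# the unpacking in _row_to_search_item expects.
row = ("docs.notes", "doc-1", '{"text": "hello"}', now, now)

item = _row_to_search_item(("docs", "notes"), row)
assert item.key == "doc-1"
assert item.value == {"text": "hello"}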
16 changes: 10 additions & 6 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
@@ -378,15 +378,19 @@ def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
                 # a connection not in pipeline mode can only be used by one
                 # thread/coroutine at a time, so we acquire a lock
                 if self.supports_pipeline:
-                    with self.lock, conn.pipeline(), conn.cursor(
-                        binary=True, row_factory=dict_row
-                    ) as cur:
+                    with (
+                        self.lock,
+                        conn.pipeline(),
+                        conn.cursor(binary=True, row_factory=dict_row) as cur,
+                    ):
                         yield cur
                 else:
                     # Use connection's transaction context manager when pipeline mode not supported
-                    with self.lock, conn.transaction(), conn.cursor(
-                        binary=True, row_factory=dict_row
-                    ) as cur:
+                    with (
+                        self.lock,
+                        conn.transaction(),
+                        conn.cursor(binary=True, row_factory=dict_row) as cur,
+                    ):
                         yield cur
             else:
                 with self.lock, conn.cursor(binary=True, row_factory=dict_row) as cur:
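The cursor rewrites in this and the remaining files are mechanical: a chained `with` statement that previously leaned on a call's parentheses for line continuation becomes a parenthesized context-manager list, a form Python supports from 3.10 onward (the async files below get the same treatment with `async with`). A standalone illustration of the two equivalent spellings — a sketch, not code from the commit:

from contextlib import nullcontext

# Pre-3.10 style: the only way to break the line is inside a call's parentheses.
with nullcontext("a") as a, nullcontext("b") as b:
    print(a, b)

# 3.10+ style: the whole context-manager list is parenthesized, so each
# manager sits on its own line and can take a trailing comma.
with (
    nullcontext("a") as a,
    nullcontext("b") as b,
):
    print(a, b)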
23 changes: 14 additions & 9 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
@@ -338,20 +338,25 @@ async def _cursor(
                 # a connection not in pipeline mode can only be used by one
                 # thread/coroutine at a time, so we acquire a lock
                 if self.supports_pipeline:
-                    async with self.lock, conn.pipeline(), conn.cursor(
-                        binary=True, row_factory=dict_row
-                    ) as cur:
+                    async with (
+                        self.lock,
+                        conn.pipeline(),
+                        conn.cursor(binary=True, row_factory=dict_row) as cur,
+                    ):
                         yield cur
                 else:
                     # Use connection's transaction context manager when pipeline mode not supported
-                    async with self.lock, conn.transaction(), conn.cursor(
-                        binary=True, row_factory=dict_row
-                    ) as cur:
+                    async with (
+                        self.lock,
+                        conn.transaction(),
+                        conn.cursor(binary=True, row_factory=dict_row) as cur,
+                    ):
                         yield cur
             else:
-                async with self.lock, conn.cursor(
-                    binary=True, row_factory=dict_row
-                ) as cur:
+                async with (
+                    self.lock,
+                    conn.cursor(binary=True, row_factory=dict_row) as cur,
+                ):
                     yield cur
 
     def list(
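Note that `_cursor` itself is a generator-based context manager: the `yield cur` hands the cursor to the caller while the enclosing `async with` keeps the lock and the pipeline or transaction open for the caller's entire block. A minimal sketch of that shape, with hypothetical names:

import asyncio
from contextlib import asynccontextmanager

@asynccontextmanager
async def cursor(lock: asyncio.Lock):
    # Everything before the yield runs on entry; the lock stays held
    # until the caller's async-with block finishes.
    async with lock:
        yield "cur"  # stand-in for a real database cursor

async def main() -> None:
    lock = asyncio.Lock()
    async with cursor(lock) as cur:
        print(cur)

asyncio.run(main())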
11 changes: 7 additions & 4 deletions libs/checkpoint-postgres/langgraph/store/postgres/aio.py
@@ -28,6 +28,7 @@
     _decode_ns_bytes,
     _group_ops,
     _row_to_item,
+    _row_to_search_item,
 )
 
 logger = logging.getLogger(__name__)
@@ -146,7 +147,7 @@ async def _batch_search_ops(
                 await cur.execute(query, params)
                 rows = cast(list[Row], await cur.fetchall())
                 items = [
-                    _row_to_item(
+                    _row_to_search_item(
                         _decode_ns_bytes(row["prefix"]), row, loader=self._deserializer
                     )
                     for row in rows
@@ -195,9 +196,11 @@ async def _cursor(
                 async with self.lock, conn.pipeline(), conn.cursor(binary=True) as cur:
                     yield cur
             else:
-                async with self.lock, conn.transaction(), conn.cursor(
-                    binary=True
-                ) as cur:
+                async with (
+                    self.lock,
+                    conn.transaction(),
+                    conn.cursor(binary=True) as cur,
+                ):
                     yield cur
         else:
             async with conn.cursor(binary=True) as cur:
46 changes: 39 additions & 7 deletions libs/checkpoint-postgres/langgraph/store/postgres/base.py
@@ -35,7 +35,9 @@
     ListNamespacesOp,
     Op,
     PutOp,
+    ResponseMetadata,
     Result,
+    SearchItem,
     SearchOp,
 )
 
@@ -344,14 +346,18 @@ def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
                 # a connection not in pipeline mode can only be used by one
                 # thread/coroutine at a time, so we acquire a lock
                 if self.supports_pipeline:
-                    with self.lock, conn.pipeline(), conn.cursor(
-                        binary=True, row_factory=dict_row
-                    ) as cur:
+                    with (
+                        self.lock,
+                        conn.pipeline(),
+                        conn.cursor(binary=True, row_factory=dict_row) as cur,
+                    ):
                         yield cur
                 else:
-                    with self.lock, conn.transaction(), conn.cursor(
-                        binary=True, row_factory=dict_row
-                    ) as cur:
+                    with (
+                        self.lock,
+                        conn.transaction(),
+                        conn.cursor(binary=True, row_factory=dict_row) as cur,
+                    ):
                         yield cur
             else:
                 with conn.cursor(binary=True, row_factory=dict_row) as cur:
@@ -430,7 +436,7 @@ def _batch_search_ops(
             cur.execute(query, params)
             rows = cast(list[Row], cur.fetchall())
             results[idx] = [
-                _row_to_item(
+                _row_to_search_item(
                     _decode_ns_bytes(row["prefix"]), row, loader=self._deserializer
                 )
                 for row in rows
@@ -517,6 +523,32 @@ def _row_to_item(
     )
 
 
+def _row_to_search_item(
+    namespace: tuple[str, ...],
+    row: Row,
+    *,
+    loader: Optional[Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]] = None,
+) -> SearchItem:
+    """Convert a row from the database into a SearchItem."""
+    loader = loader or _json_loads
+    val = row["value"]
+    response_metadata: Optional[ResponseMetadata] = (
+        {
+            "score": float(row["score"]),
+        }
+        if row.get("score") is not None
+        else None
+    )
+    return SearchItem(
+        value=val if isinstance(val, dict) else loader(val),
+        key=row["key"],
+        namespace=namespace,
+        created_at=row["created_at"],
+        updated_at=row["updated_at"],
+        response_metadata=response_metadata,
+    )
+
+
 def _group_ops(ops: Iterable[Op]) -> tuple[dict[type, list[tuple[int, Op]]], int]:
     grouped_ops: dict[type, list[tuple[int, Op]]] = defaultdict(list)
     tot = 0
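This Postgres helper is where the vector-search score actually surfaces: when a query attaches a `score` column, it is forwarded as `response_metadata={"score": ...}`; rows without one yield `response_metadata=None`. A hypothetical dict row (standing in for a psycopg `dict_row`) under that assumption:

from datetime import datetime, timezone

now = datetime.now(timezone.utc)
row = {
    "key": "doc-1",
    "value": {"text": "hello"},  # already a dict, so the loader is skipped
    "created_at": now,
    "updated_at": now,
    "score": 0.87,               # present only for scored searches
}

item = _row_to_search_item(("docs", "notes"), row)
assert item.response_metadata == {"score": 0.87}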