Skip to content

Commit

Permalink
Lancedb storage integration (#455)
Browse files Browse the repository at this point in the history
  • Loading branch information
PrashantDixit0 authored Nov 17, 2023
1 parent 85d0e0d commit f957209
Show file tree
Hide file tree
Showing 9 changed files with 421 additions and 6 deletions.
17 changes: 17 additions & 0 deletions docs/storage.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,22 @@ pip install 'pymemgpt[postgres]'
You will need to have a URI to a Postgres database which support [pgvector](https://github.com/pgvector/pgvector). You can either use a [hosted provider](https://github.com/pgvector/pgvector/issues/54) or [install pgvector](https://github.com/pgvector/pgvector#installation).


## LanceDB
To use the LanceDB backend, enable it by running

```
memgpt configure
```
and selecting `lancedb` for archival storage, then entering a database URI (e.g. `./.lancedb`). If the archival URI is left empty, the default `./.lancedb` is used.

To enable the LanceDB backend, make sure to install the required dependencies with:
```
pip install 'pymemgpt[lancedb]'
```
For more information, check out the [LanceDB docs](https://lancedb.github.io/lancedb/).


## Chroma
(Coming soon)
3 changes: 2 additions & 1 deletion memgpt/autogen/examples/memgpt_coder_autogen.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
"outputs": [],
"source": [
"import openai\n",
"openai.api_key=\"YOUR_API_KEY\""
"\n",
"openai.api_key = \"YOUR_API_KEY\""
]
},
{
Expand Down
11 changes: 10 additions & 1 deletion memgpt/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def configure_cli(config: MemGPTConfig):

def configure_archival_storage(config: MemGPTConfig):
# Configure archival storage backend
archival_storage_options = ["local", "postgres"]
archival_storage_options = ["local", "lancedb", "postgres"]
archival_storage_type = questionary.select(
"Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
).ask()
Expand All @@ -220,8 +220,17 @@ def configure_archival_storage(config: MemGPTConfig):
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
default=config.archival_storage_uri if config.archival_storage_uri else "",
).ask()

if archival_storage_type == "lancedb":
archival_storage_uri = questionary.text(
"Enter lanncedb connection string (e.g. ./.lancedb",
default=config.archival_storage_uri if config.archival_storage_uri else "./.lancedb",
).ask()

return archival_storage_type, archival_storage_uri

# TODO: allow configuring embedding model


@app.command()
def configure():
Expand Down
137 changes: 137 additions & 0 deletions memgpt/connectors/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Optional, List, Iterator
import numpy as np
from tqdm import tqdm
import pandas as pd

from memgpt.config import MemGPTConfig
from memgpt.connectors.storage import StorageConnector, Passage
Expand Down Expand Up @@ -181,3 +182,139 @@ def generate_table_name_agent(self, agent_config: AgentConfig):

def generate_table_name(self, name: str):
return f"memgpt_{self.sanitize_table_name(name)}"


class LanceDBConnector(StorageConnector):
    """Archival storage backed by LanceDB.

    Each passage is stored as a row with columns ``doc_id``, ``text``,
    ``passage_id`` and ``vector`` (the embedding). Because LanceDB infers a
    table's schema from its data, the table is created lazily on the first
    insert rather than in ``__init__``.
    """

    # TODO: this should probably eventually be moved into a parent DB class

    def __init__(self, name: Optional[str] = None):
        config = MemGPTConfig.load()

        # Determine table name: named datasets get a sanitized "memgpt_"
        # prefix; otherwise fall back to a shared default table.
        if name:
            self.table_name = self.generate_table_name(name)
        else:
            self.table_name = "lancedb_tbl"

        printd(f"Using table name {self.table_name}")

        # Connect to the database. The URI is typically a local directory,
        # e.g. "./.lancedb".
        self.uri = config.archival_storage_uri
        if config.archival_storage_uri is None:
            raise ValueError(f"Must specify archival_storage_uri in config {config.config_path}")
        import lancedb  # deferred import: lancedb is an optional dependency

        self.db = lancedb.connect(self.uri)
        # Re-open the table if it already exists so previously stored data is
        # visible; otherwise defer creation until the first insert.
        if self.table_name in self.db.table_names():
            self.table = self.db.open_table(self.table_name)
        else:
            self.table = None

    def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
        """Yield all passages in the table, ``page_size`` rows at a time."""
        if self.table is None:
            return
        offset = 0
        while True:
            # This query API exposes no offset parameter, so request
            # offset+page_size rows and slice off the already-seen prefix.
            rows = self.table.search().limit(offset + page_size).to_list()[offset:]

            # An empty slice means every row has been yielded
            if not rows:
                break

            yield [
                Passage(text=p["text"], embedding=p["vector"], doc_id=p["doc_id"], passage_id=p["passage_id"]) for p in rows
            ]

            offset += page_size

    def get_all(self, limit=10) -> List[Passage]:
        """Return up to ``limit`` passages from the table."""
        if self.table is None:
            return []
        db_passages = self.table.search().limit(limit).to_list()
        return [Passage(text=p["text"], embedding=p["vector"], doc_id=p["doc_id"], passage_id=p["passage_id"]) for p in db_passages]

    def get(self, id: str) -> Optional[Passage]:
        """Return the passage whose ``passage_id`` equals ``id``, or None."""
        if self.table is None:
            return None
        # Filtering is done on a query object (tables themselves have no
        # .where()); rows come back as a list of dicts.
        rows = self.table.search().where(f"passage_id = {id}").to_list()
        if len(rows) == 0:
            return None
        row = rows[0]
        # NOTE: the embedding column is named "vector" (see insert())
        return Passage(text=row["text"], embedding=row["vector"], doc_id=row["doc_id"], passage_id=row["passage_id"])

    def size(self) -> int:
        """Return the number of rows in the table (0 if it doesn't exist yet)."""
        if self.table is not None:
            # count_rows avoids materializing the whole table (and is not
            # capped by the query API's default result limit)
            return self.table.count_rows()
        else:
            print(f"Table with name {self.table_name} not present")
            return 0

    def insert(self, passage: Passage):
        """Insert a single passage, creating the table on first use."""
        data = [{"doc_id": passage.doc_id, "text": passage.text, "passage_id": passage.passage_id, "vector": passage.embedding}]

        if self.table is not None:
            self.table.add(data)
        else:
            # LanceDB needs sample data to infer the schema on creation
            self.table = self.db.create_table(self.table_name, data=data, mode="overwrite")

    def insert_many(self, passages: List[Passage], show_progress=True):
        """Insert many passages at once, optionally showing a progress bar."""
        iterable = tqdm(passages) if show_progress else passages
        data = [
            {"doc_id": passage.doc_id, "text": passage.text, "passage_id": passage.passage_id, "vector": passage.embedding}
            for passage in iterable
        ]
        if not data:
            # Nothing to insert; also avoids creating a table with no schema
            return

        if self.table is not None:
            self.table.add(data)
        else:
            self.table = self.db.create_table(self.table_name, data=data, mode="overwrite")

    def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
        """Return the ``top_k`` passages nearest to ``query_vec``.

        ``query`` (the raw text) is unused here; LanceDB searches by vector only.
        """
        # .to_list() executes the vector search and returns row dicts
        results = self.table.search(query_vec).limit(top_k).to_list()

        # Convert the results into Passage objects ("vector" holds the embedding)
        passages = [
            Passage(text=result["text"], embedding=result["vector"], doc_id=result["doc_id"], passage_id=result["passage_id"])
            for result in results
        ]
        return passages

    def delete(self):
        """Drop the passage table from the database."""
        if self.table_name in self.db.table_names():
            self.db.drop_table(self.table_name)
        # Forget the handle so a later insert recreates the table
        self.table = None

    def save(self):
        # LanceDB persists writes immediately; nothing to flush
        return

    @staticmethod
    def list_loaded_data():
        """List the names of all memgpt-managed datasets in the database."""
        config = MemGPTConfig.load()
        import lancedb

        db = lancedb.connect(config.archival_storage_uri)

        # Only memgpt-managed tables, with the namespace prefix stripped
        tables = db.table_names()
        tables = [table for table in tables if table.startswith("memgpt_")]
        tables = [table.replace("memgpt_", "") for table in tables]
        return tables

    def sanitize_table_name(self, name: str) -> str:
        """Normalize ``name`` into a safe lowercase table identifier."""
        # Remove leading and trailing whitespace
        name = name.strip()

        # Replace spaces and invalid characters with underscores
        name = re.sub(r"\s+|\W+", "_", name)

        # Truncate to the maximum identifier length
        max_length = 63
        if len(name) > max_length:
            name = name[:max_length].rstrip("_")

        # Convert to lowercase
        name = name.lower()

        return name

    def generate_table_name(self, name: str):
        """Prefix a sanitized name with "memgpt_" to namespace our tables."""
        return f"memgpt_{self.sanitize_table_name(name)}"
10 changes: 10 additions & 0 deletions memgpt/connectors/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ def get_storage_connector(name: Optional[str] = None, agent_config: Optional[Age

return PostgresStorageConnector(name=name, agent_config=agent_config)

elif storage_type == "lancedb":
from memgpt.connectors.db import LanceDBConnector

return LanceDBConnector(name=name)

else:
raise NotImplementedError(f"Storage type {storage_type} not implemented")

Expand All @@ -62,6 +67,11 @@ def list_loaded_data():
from memgpt.connectors.db import PostgresStorageConnector

return PostgresStorageConnector.list_loaded_data()

elif storage_type == "lancedb":
from memgpt.connectors.db import LanceDBConnector

return LanceDBConnector.list_loaded_data()
else:
raise NotImplementedError(f"Storage type {storage_type} not implemented")

Expand Down
Loading

0 comments on commit f957209

Please sign in to comment.