Skip to content

Commit

Permalink
Add support for larger archival memory stores (#359)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Nov 9, 2023
1 parent b1ad4f0 commit 350e4af
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 19 deletions.
3 changes: 0 additions & 3 deletions memgpt/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,11 +771,8 @@ def edit_memory(self, name, content):
return None

def edit_memory_append(self, name, content):
    """Append `content` to the named core-memory section and rebuild
    the in-context memory so the change is visible on the next turn.

    Args:
        name: Core memory section to modify (e.g. "persona" or "human").
        content: Text to append to that section.

    Returns:
        None (memory is mutated in place).
    """
    # Leftover debug prints removed; the edit_append return value was unused.
    self.memory.edit_append(name, content)
    # Recompute the system prompt / context after the edit so the agent
    # operates on the updated memory state.
    self.rebuild_memory()
    return None

def edit_memory_replace(self, name, old_content, new_content):
Expand Down
17 changes: 11 additions & 6 deletions memgpt/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,21 +165,26 @@ def attach(
):
# loads the data contained in data source into the agent's memory
from memgpt.connectors.storage import StorageConnector
from tqdm import tqdm

agent_config = AgentConfig.load(agent)
config = MemGPTConfig.load()

# get storage connectors
source_storage = StorageConnector.get_storage_connector(name=data_source)
dest_storage = StorageConnector.get_storage_connector(agent_config=agent_config)

passages = source_storage.get_all()
for p in passages:
len(p.embedding) == config.embedding_dim, f"Mismatched embedding sizes {len(p.embedding)} != {config.embedding_dim}"
dest_storage.insert_many(passages)
size = source_storage.size()
typer.secho(f"Ingesting {size} passages into {agent_config.name}", fg=typer.colors.GREEN)
page_size = 100
generator = source_storage.get_all_paginated(page_size=page_size) # yields List[Passage]
for i in tqdm(range(0, size, page_size)):
passages = next(generator)
dest_storage.insert_many(passages, show_progress=False)

# save destination storage
dest_storage.save()

total_agent_passages = len(dest_storage.get_all())
total_agent_passages = dest_storage.size()

typer.secho(
f"Attached data source {data_source} to agent {agent}, consisting of {len(passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
Expand Down
28 changes: 25 additions & 3 deletions memgpt/connectors/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import re
from tqdm import tqdm
from typing import Optional, List
from typing import Optional, List, Iterator
import numpy as np
from tqdm import tqdm

Expand Down Expand Up @@ -76,9 +76,26 @@ def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfi
self.Session = sessionmaker(bind=self.engine)
self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension

def get_all(self) -> List[Passage]:
def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
    """Yield all passages from the table in pages of at most `page_size`.

    Args:
        page_size: Maximum number of rows fetched per database round trip.

    Yields:
        Lists of `Passage` objects, one list per page, until the table is
        exhausted.
    """
    session = self.Session()
    try:
        offset = 0
        while True:
            # Order by primary key so OFFSET/LIMIT pagination is stable:
            # without an ORDER BY, SQL makes no ordering guarantee and rows
            # could be skipped or duplicated across pages.
            chunk = (
                session.query(self.db_model)
                .order_by(self.db_model.id)
                .offset(offset)
                .limit(page_size)
                .all()
            )
            # An empty chunk means every record has been retrieved.
            if not chunk:
                break
            yield [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in chunk]
            offset += page_size
    finally:
        # Release the connection even if the caller abandons the generator.
        session.close()

def get_all(self, limit: int = 10) -> List[Passage]:
    """Return up to `limit` passages from the table.

    Intended for previews (e.g. `__repr__`); use `get_all_paginated` to
    stream the full table without loading everything into memory.

    Args:
        limit: Maximum number of rows to fetch (default 10).
    """
    session = self.Session()
    try:
        db_passages = session.query(self.db_model).limit(limit).all()
        return [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in db_passages]
    finally:
        # Close the session so the connection returns to the pool.
        session.close()

def get(self, id: str) -> Optional[Passage]:
Expand All @@ -88,6 +105,11 @@ def get(self, id: str) -> Optional[Passage]:
return None
return Passage(text=db_passage.text, embedding=db_passage.embedding, doc_id=db_passage.doc_id, passage_id=db_passage.passage_id)

def size(self) -> int:
    """Return the number of passages (rows) currently stored in the table."""
    session = self.Session()
    try:
        return session.query(self.db_model).count()
    finally:
        # Close the session so the connection returns to the pool.
        session.close()

def insert(self, passage: Passage):
session = self.Session()
db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding)
Expand Down
15 changes: 13 additions & 2 deletions memgpt/connectors/local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, List
from typing import Optional, List, Iterator
from memgpt.config import AgentConfig, MemGPTConfig
from tqdm import tqdm
import re
Expand Down Expand Up @@ -72,11 +72,19 @@ def add_nodes(self, nodes: List[TextNode]):
self.nodes += nodes
self.index = VectorStoreIndex(self.nodes)

def get_all(self) -> List[Passage]:
def get_all_paginated(self, page_size: int = 100) -> Iterator[List[Passage]]:
    """Yield all passages in the index in chunks of at most `page_size`.

    No progress bar is emitted here: callers (e.g. `memgpt attach`) wrap
    this generator in their own tqdm, which previously produced nested
    progress bars, and the DB-backed connector's equivalent method emits
    no progress output either.
    """
    nodes = self.get_nodes()
    for start in range(0, len(nodes), page_size):
        yield [Passage(text=node.text, embedding=node.embedding) for node in nodes[start : start + page_size]]

def get_all(self, limit: int) -> List[Passage]:
    """Return up to `limit` passages from the index.

    Args:
        limit: Maximum number of passages to return; `limit <= 0` yields
            an empty list (previously one passage was still returned).

    Raises:
        ValueError: If a node has no embedding. Was a bare `assert`,
            which `python -O` strips; raising keeps the check enforced.
    """
    passages: List[Passage] = []
    for node in self.get_nodes():
        # Check the bound before appending so `limit` is honored exactly.
        if len(passages) >= limit:
            break
        if node.embedding is None:
            raise ValueError("Node embedding is None")
        passages.append(Passage(text=node.text, embedding=node.embedding))
    return passages

def get(self, id: str) -> Passage:
Expand Down Expand Up @@ -126,3 +134,6 @@ def list_loaded_data():
name = os.path.basename(data_source_file)
sources.append(name)
return sources

def size(self) -> int:
    """Number of passages (text/embedding pairs) currently in the index."""
    return len(self.get_nodes())
13 changes: 11 additions & 2 deletions memgpt/connectors/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
We originally tried to use Llama Index VectorIndex, but their limited API was extremely problematic.
"""
from typing import Optional, List
from typing import Optional, List, Iterator
import re
import pickle
import os
Expand Down Expand Up @@ -66,7 +66,11 @@ def list_loaded_data():
raise NotImplementedError(f"Storage type {storage_type} not implemented")

@abstractmethod
def get_all(self) -> List[Passage]:
def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
    """Yield all stored passages in lists of at most `page_size`."""
    pass

@abstractmethod
def get_all(self, limit: int) -> List[Passage]:
    """Return up to `limit` stored passages (preview, not a full scan)."""
    pass

@abstractmethod
Expand All @@ -89,3 +93,8 @@ def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Pas
def save(self):
"""Save state of storage connector"""
pass

@abstractmethod
def size(self) -> int:
    """Get number of passages (text/embedding pairs) in storage"""
    pass
5 changes: 2 additions & 3 deletions memgpt/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,11 +818,10 @@ async def a_insert(self, memory_string):
def __repr__(self) -> str:
    """Preview of archival memory: a header plus the first few passages."""
    limit = 10  # only show the first 10 passages to keep the repr short
    # get_all(limit) caps the fetch at the storage layer; no need to
    # materialize everything and slice (the old list(...)[:limit] pattern).
    passages = [str(passage.text) for passage in self.storage.get_all(limit)]
    memory_str = "\n".join(passages)
    return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}"

def __len__(self) -> int:
    """Number of passages in archival storage (delegates a COUNT to the
    backend instead of materializing every row)."""
    return self.storage.size()

0 comments on commit 350e4af

Please sign in to comment.