Merge pull request #146 from sarahwooders/main
Support loading data into archival with Llama Index connectors
sarahwooders authored Oct 27, 2023
2 parents a99de78 + 18f1496 commit 23e5221
Showing 8 changed files with 1,764 additions and 16 deletions.
111 changes: 111 additions & 0 deletions memgpt/connectors/connector.py
@@ -0,0 +1,111 @@
"""
This file contains functions for loading data into MemGPT's archival storage.
Data can be loaded with the following command, once a load function is defined:
```
memgpt load <data-connector-type> --name <dataset-name> [ADDITIONAL ARGS]
```
"""

from llama_index import download_loader
from typing import List
import os
import typer
from memgpt.constants import MEMGPT_DIR
from memgpt.utils import estimate_openai_cost, get_index, save_index

app = typer.Typer()


@app.command("directory")
def load_directory(
name: str = typer.Option(help="Name of dataset to load."),
input_dir: str = typer.Option(None, help="Path to directory containing dataset."),
input_files: List[str] = typer.Option(None, help="List of paths to files containing dataset."),
recursive: bool = typer.Option(False, help="Recursively search for files in directory."),
):
from llama_index import SimpleDirectoryReader

if recursive:
assert input_dir is not None, "Must provide input directory if recursive is True."
reader = SimpleDirectoryReader(
input_dir=input_dir,
recursive=True,
)
else:
reader = SimpleDirectoryReader(input_files=input_files)

# load docs
print("Loading data...")
docs = reader.load_data()

# embed docs
print("Indexing documents...")
index = get_index(name, docs)
# save connector information into .memgpt metadata file
save_index(index, name)


@app.command("webpage")
def load_webpage(
name: str = typer.Option(help="Name of dataset to load."),
urls: List[str] = typer.Option(None, help="List of urls to load."),
):
from llama_index import SimpleWebPageReader

docs = SimpleWebPageReader(html_to_text=True).load_data(urls)

# embed docs
print("Indexing documents...")
index = get_index(docs)
# save connector information into .memgpt metadata file
save_index(index, name)


@app.command("database")
def load_database(
name: str = typer.Option(help="Name of dataset to load."),
query: str = typer.Option(help="Database query."),
dump_path: str = typer.Option(None, help="Path to dump file."),
scheme: str = typer.Option(None, help="Database scheme."),
host: str = typer.Option(None, help="Database host."),
port: int = typer.Option(None, help="Database port."),
user: str = typer.Option(None, help="Database user."),
password: str = typer.Option(None, help="Database password."),
dbname: str = typer.Option(None, help="Database name."),
):
from llama_index.readers.database import DatabaseReader

print(dump_path, scheme)

if dump_path is not None:
# read from database dump file
from sqlalchemy import create_engine, MetaData

engine = create_engine(f"sqlite:///{dump_path}")

db = DatabaseReader(engine=engine)
else:
assert dump_path is None, "Cannot provide both dump_path and database connection parameters."
assert scheme is not None, "Must provide database scheme."
assert host is not None, "Must provide database host."
assert port is not None, "Must provide database port."
assert user is not None, "Must provide database user."
assert password is not None, "Must provide database password."
assert dbname is not None, "Must provide database name."

db = DatabaseReader(
scheme=scheme, # Database Scheme
host=host, # Database Host
port=port, # Database Port
user=user, # Database User
password=password, # Database Password
dbname=dbname, # Database Name
)

# load data
docs = db.load_data(query=query)

index = get_index(name, docs)
save_index(index, name)
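The `get_index` and `save_index` helpers are imported from `memgpt.utils` but their bodies are not part of this diff. A minimal sketch of what they plausibly do with the Llama Index API of this era (the function bodies below are assumptions for illustration, not the PR's actual implementation):

```python
import os

from llama_index import VectorStoreIndex

from memgpt.constants import MEMGPT_DIR


def get_index(name, docs):
    # Hypothetical sketch: embed the loaded documents into a vector index
    return VectorStoreIndex.from_documents(docs)


def save_index(index, name):
    # Hypothetical sketch: persist the index under ~/.memgpt/archival/<name>,
    # the directory LocalArchivalMemory (below) reads from
    directory = os.path.join(MEMGPT_DIR, "archival", name)
    os.makedirs(directory, exist_ok=True)
    index.storage_context.persist(persist_dir=directory)
```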
5 changes: 3 additions & 2 deletions memgpt/main.py
@@ -28,15 +28,16 @@

from memgpt.config import Config
from memgpt.constants import MEMGPT_DIR
from memgpt.connectors import connector
from memgpt.openai_tools import (
    configure_azure_support,
    check_azure_embeddings,
    get_set_azure_env_vars,
)

import asyncio

app = typer.Typer()
app.add_typer(connector.app, name="load")


def clear_line():
@@ -109,7 +110,7 @@ def load(memgpt_agent, filename):
        print(f"/load warning: loading persistence manager from {filename} failed with: {e}")


@app.command()
@app.callback(invoke_without_command=True)  # make default command
def run(
    persona: str = typer.Option(None, help="Specify persona"),
    human: str = typer.Option(None, help="Specify human"),
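For context, `app.add_typer(connector.app, name="load")` mounts the connector commands as a `load` subcommand group, which is what makes invocations like `memgpt load directory --name my-docs --input-dir ./docs` resolve to `load_directory`. A standalone sketch of the same typer pattern (the names here are illustrative, not from the PR):

```python
import typer

app = typer.Typer()
sub = typer.Typer()
app.add_typer(sub, name="load")


@sub.command("directory")
def load_directory(name: str = typer.Option(..., help="Name of dataset to load.")):
    # Invoked as: <cli> load directory --name my-docs
    typer.echo(f"loading {name}")


if __name__ == "__main__":
    app()
```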
80 changes: 78 additions & 2 deletions memgpt/memory.py
@@ -1,14 +1,26 @@
from abc import ABC, abstractmethod
import os
import datetime
import re
import faiss
import numpy as np
from typing import Optional, List, Tuple

from .constants import MESSAGE_SUMMARY_WARNING_TOKENS
from .constants import MESSAGE_SUMMARY_WARNING_TOKENS, MEMGPT_DIR
from .utils import cosine_similarity, get_local_time, printd, count_tokens
from .prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
from .openai_tools import acompletions_with_backoff as acreate, async_get_embedding_with_backoff

from llama_index import (
    VectorStoreIndex,
    get_response_synthesizer,
    load_index_from_storage,
    StorageContext,
)
from llama_index.retrievers import VectorIndexRetriever
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.indices.postprocessor import SimilarityPostprocessor


class CoreMemory(object):
    """Held in-context inside the system message
@@ -128,10 +140,26 @@ async def summarize_messages(
class ArchivalMemory(ABC):
    @abstractmethod
    def insert(self, memory_string):
        """Insert new archival memory

        :param memory_string: Memory string to insert
        :type memory_string: str
        """
        pass

    @abstractmethod
    def search(self, query_string, count=None, start=None):
    def search(self, query_string, count=None, start=None) -> Tuple[List[str], int]:
        """Search archival memory

        :param query_string: Query string
        :type query_string: str
        :param count: Number of results to return (None for all)
        :type count: Optional[int]
        :param start: Offset to start returning results from (None if 0)
        :type start: Optional[int]
        :return: Tuple of (list of results, total number of results)
        """
        pass

    @abstractmethod
@@ -515,3 +543,51 @@ async def text_search(self, query_string, count=None, start=None):
            return matches[start:], len(matches)
        else:
            return matches, len(matches)


class LocalArchivalMemory(ArchivalMemory):
    """Archival memory built on top of Llama Index"""

    def __init__(self, archival_memory_database: Optional[str] = None, top_k: Optional[int] = 100):
        """Init function for archival memory

        :param archival_memory_database: name of dataset to pre-fill archival with
        :type archival_memory_database: str
        """

        if archival_memory_database is not None:
            # TODO: load from ~/.memgpt/archival
            directory = f"{MEMGPT_DIR}/archival/{archival_memory_database}"
            assert os.path.exists(directory), f"Archival memory database {archival_memory_database} does not exist"
            storage_context = StorageContext.from_defaults(persist_dir=directory)
            self.index = load_index_from_storage(storage_context)
        else:
            self.index = VectorStoreIndex([])  # start from an empty index
        self.top_k = top_k
        self.retriever = VectorIndexRetriever(
            index=self.index,  # does this get refreshed?
            similarity_top_k=self.top_k,
        )
        # TODO: have some mechanism for cleanup otherwise will lead to OOM
        self.cache = {}

    async def insert(self, memory_string):
        self.index.insert(memory_string)

    async def search(self, query_string, count=None, start=None):
        start = start if start else 0
        count = count if count else self.top_k
        count = min(count + start, self.top_k)

        if query_string not in self.cache:
            self.cache[query_string] = self.retriever.retrieve(query_string)

        results = self.cache[query_string][start : start + count]
        results = [{"timestamp": get_local_time(), "content": node.node.text} for node in results]
        return results, len(results)

    def __repr__(self) -> str:
        return str(self.index.ref_doc_info)
69 changes: 69 additions & 0 deletions memgpt/persistence_manager.py
@@ -7,6 +7,7 @@
    DummyArchivalMemory,
    DummyArchivalMemoryWithEmbeddings,
    DummyArchivalMemoryWithFaiss,
    LocalArchivalMemory,
)
from .utils import get_local_time, printd

@@ -100,6 +101,74 @@ def update_memory(self, new_memory):
        self.memory = new_memory


class LocalStateManager(PersistenceManager):
    """State manager that keeps agent state in memory, with archival memory backed by a local Llama Index store"""

    recall_memory_cls = DummyRecallMemory
    archival_memory_cls = LocalArchivalMemory

    def __init__(self, archival_memory_db=None):
        # Memory held in-state useful for debugging stateful versions
        self.memory = None
        self.messages = []
        self.all_messages = []
        self.archival_memory = LocalArchivalMemory(archival_memory_database=archival_memory_db)

    @staticmethod
    def load(filename):
        with open(filename, "rb") as f:
            return pickle.load(f)

    def save(self, filename):
        with open(filename, "wb") as fh:
            pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL)

    def init(self, agent):
        printd("Initializing LocalStateManager with agent object")
        self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
        self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
        self.memory = agent.memory
        printd(f"LocalStateManager.all_messages.len = {len(self.all_messages)}")
        printd(f"LocalStateManager.messages.len = {len(self.messages)}")

        # Persistence manager also handles DB-related state
        self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)

        # TODO: init archival memory here?

    def trim_messages(self, num):
        self.messages = [self.messages[0]] + self.messages[num:]

    def prepend_to_messages(self, added_messages):
        # first tag with timestamps
        added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]

        printd("LocalStateManager.prepend_to_messages")
        self.messages = [self.messages[0]] + added_messages + self.messages[1:]
        self.all_messages.extend(added_messages)

    def append_to_messages(self, added_messages):
        # first tag with timestamps
        added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]

        printd("LocalStateManager.append_to_messages")
        self.messages = self.messages + added_messages
        self.all_messages.extend(added_messages)

    def swap_system_message(self, new_system_message):
        # first tag with timestamps
        new_system_message = {"timestamp": get_local_time(), "message": new_system_message}

        printd("LocalStateManager.swap_system_message")
        self.messages[0] = new_system_message
        self.all_messages.append(new_system_message)

    def update_memory(self, new_memory):
        printd("LocalStateManager.update_memory")
        self.memory = new_memory


class InMemoryStateManagerWithPreloadedArchivalMemory(InMemoryStateManager):
    archival_memory_cls = DummyArchivalMemory
    recall_memory_cls = DummyRecallMemory
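Finally, a sketch of how `LocalStateManager` ties the pieces together: it swaps the dummy archival memory for the Llama Index-backed one and can be pickled to disk between sessions. The dataset name and the `agent` object below are assumptions for illustration, and whether the embedded index pickles cleanly is not verified by this diff:

```python
from memgpt.persistence_manager import LocalStateManager

# Back the agent's archival memory with the "my-docs" dataset
# created earlier via `memgpt load ...` (hypothetical name)
pm = LocalStateManager(archival_memory_db="my-docs")
pm.init(agent)  # `agent` is an existing MemGPT agent instance

# ... run the agent ...

pm.save("agent_state.pkl")                      # pickle the full manager state
pm = LocalStateManager.load("agent_state.pkl")  # restore in a later session
```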