Skip to content

Commit

Permalink
0.4.3 - fix postgres issues
Browse files Browse the repository at this point in the history
  • Loading branch information
lalalune committed Aug 11, 2023
1 parent 6cc0f19 commit 77a0b1b
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 49 deletions.
2 changes: 1 addition & 1 deletion agentmemory/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def increment_epoch():
epoch = get_epoch()
epoch = epoch + 1
create_memory("epoch", str(epoch))
return epoch
return int(epoch)


def get_epoch():
Expand Down
2 changes: 2 additions & 0 deletions agentmemory/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ def update_memory(category, id, text=None, metadata=None, embedding=None):
if isinstance(value, bool):
debug_log(f"WARNING: Boolean metadata field {key} converted to string")
metadata[key] = str(value)
else:
metadata = {}

metadata["updated_at"] = datetime.datetime.now().timestamp()

Expand Down
126 changes: 80 additions & 46 deletions agentmemory/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def get(
):
category = self.category
table_name = self.client._table_name(category)

if not ids:
if limit is None:
limit = 100 # or another default value
Expand All @@ -60,9 +61,7 @@ def get(
if offset is None:
offset = 0

table_name = self.client._table_name(category)
ids = [int(i) for i in ids]

query = f"SELECT * FROM {table_name} WHERE id=ANY(%s) LIMIT %s OFFSET %s"
params = (ids, limit, offset)

Expand All @@ -71,14 +70,22 @@ def get(

# Convert rows to list of dictionaries
columns = [desc[0] for desc in self.client.cur.description]
result = [dict(zip(columns, row)) for row in rows]
metadata_columns = [col for col in columns if col not in ["id", "document", "embedding"]]

result = []
for row in rows:
item = dict(zip(columns, row))
metadata = {col: item[col] for col in metadata_columns}
item["metadata"] = metadata
result.append(item)

return {
"ids": [row["id"] for row in result],
"documents": [row["document"] for row in result],
"metadatas": [row["metadata"] for row in result],
}


def peek(self, limit=10):
    """Return a small sample of this category's rows.

    Delegates to ``self.get`` with only ``limit`` set, so the result has
    whatever shape ``get`` produces.

    Args:
        limit: Maximum number of rows to request (default 10).
    """
    sample = self.get(limit=limit)
    return sample

Expand Down Expand Up @@ -151,14 +158,21 @@ class PostgresCategory:
def __init__(self, name):
self.name = name


default_model_path = str(Path.home() / ".cache" / "onnx_models")


class PostgresClient:
def __init__(self, connection_string, model_name = "all-MiniLM-L6-v2", model_path = default_model_path):
def __init__(
self,
connection_string,
model_name="all-MiniLM-L6-v2",
model_path=default_model_path,
):
self.connection = psycopg2.connect(connection_string)
self.cur = self.connection.cursor()
from pgvector.psycopg2 import register_vector

register_vector(self.cur) # Register PGVector functions
full_model_path = check_model(model_name=model_name, model_path=model_path)
self.model_path = full_model_path
Expand All @@ -173,7 +187,6 @@ def ensure_table_exists(self, category):
CREATE TABLE IF NOT EXISTS {table_name} (
id SERIAL PRIMARY KEY,
document TEXT NOT NULL,
metadata JSONB,
embedding VECTOR(384)
)
"""
Expand Down Expand Up @@ -226,19 +239,23 @@ def insert_memory(self, category, document, metadata={}, embedding=None, id=None
self._ensure_metadata_columns_exist(category, metadata)
table_name = self._table_name(category)

metadata_string = json.dumps(metadata) # Convert the dict to a JSON string
if embedding is None:
embedding = self.create_embedding(document)

# if the id is None, get the length of the table by counting the number of rows in the category
if id is None:
id = self.get_or_create_collection(category).count()

# Extracting the keys and values from metadata to insert them into respective columns
columns = ["id", "document", "embedding"] + list(metadata.keys())
placeholders = ["%s"] * len(columns)
values = [id, document, embedding] + list(metadata.values())

query = f"""
INSERT INTO {table_name} (id, document, metadata, embedding) VALUES (%s, %s, %s, %s)
INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({', '.join(placeholders)})
RETURNING id;
"""
self.cur.execute(query, (id, document, metadata_string, embedding))
self.cur.execute(query, tuple(values))
self.connection.commit()
return self.cur.fetchone()[0]

Expand All @@ -252,76 +269,93 @@ def add(self, category, documents, metadatas, ids):
with self.connection.cursor() as cur:
for document, metadata, id_ in zip(documents, metadatas, ids):
self._ensure_metadata_columns_exist(category, metadata)

columns = ["id", "document", "embedding"] + list(metadata.keys())
placeholders = ["%s"] * len(columns)
embedding = self.create_embedding(document)
cur.execute(
f"""
INSERT INTO {table_name} (id, document, metadata, embedding)
VALUES (%s, %s %s, %s)
""",
(id_, document, metadata, embedding),
)
values = [id_, document, embedding] + list(metadata.values())

query = f"""
INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({', '.join(placeholders)});
"""
cur.execute(query, tuple(values))
self.connection.commit()

def query(self, category, query_texts, n_results=5):
    """Find the stored rows nearest to each query text.

    Each text is embedded with ``self.create_embedding`` and compared
    against the ``embedding`` column using the ``<->`` distance operator.
    Matches for all query texts are appended to the same flat lists, in
    query order.

    Args:
        category: Category whose table is searched; the table is created
            first if it does not exist.
        query_texts: Iterable of strings to search for.
        n_results: Maximum number of rows returned per query text.

    Returns:
        A dict with the keys ``ids``, ``documents``, ``metadatas``,
        ``embeddings`` and ``distances``. Each ``metadatas`` entry is a
        dict built from every column other than id/document/embedding/
        distance.
    """
    self.ensure_table_exists(category)
    table_name = self._table_name(category)
    results = {
        "ids": [],
        "documents": [],
        "metadatas": [],
        "embeddings": [],
        "distances": [],
    }
    reserved = {"id", "document", "embedding", "distance"}
    # The four leading columns are selected explicitly so they can be read
    # by position; the trailing `*` repeats them and adds the per-key
    # metadata columns, which are picked out by name below.
    sql = f"""
        SELECT id, document, embedding, embedding <-> %s AS distance, *
        FROM {table_name}
        ORDER BY embedding <-> %s
        LIMIT %s
    """
    with self.connection.cursor() as cur:
        for text in query_texts:
            query_emb = self.create_embedding(text)
            cur.execute(sql, (query_emb, query_emb, n_results))
            rows = cur.fetchall()
            # Resolve metadata column positions once per result set rather
            # than searching the column list for every row.
            metadata_slots = [
                (position, desc[0])
                for position, desc in enumerate(cur.description)
                if desc[0] not in reserved
            ]
            for row in rows:
                results["ids"].append(row[0])
                results["documents"].append(row[1])
                results["embeddings"].append(row[2])
                results["distances"].append(row[3])
                results["metadatas"].append(
                    {name: row[position] for position, name in metadata_slots}
                )
    return results

def update(self, category, id_, document=None, metadata=None, embedding=None):
    """Update the document, embedding and/or metadata columns of one row.

    Args:
        category: Category whose table holds the row; the table is created
            first if it does not exist.
        id_: Value matched against the ``id`` column.
        document: New document text. When given, the embedding is rewritten
            as well, using ``embedding`` if supplied or a freshly created
            one otherwise.
        metadata: Dict of column name -> value. Missing columns are added
            through ``self._ensure_metadata_columns_exist`` before the
            UPDATE runs.
        embedding: Explicit embedding. It is also applied when no
            ``document`` is supplied.

    The transaction is committed even when there was nothing to update.
    """
    self.ensure_table_exists(category)
    table_name = self._table_name(category)

    assignments = []
    values = []
    if document:
        if embedding is None:
            embedding = self.create_embedding(document)
        assignments += ["document=%s", "embedding=%s"]
        values += [document, embedding]
    elif embedding is not None:
        # An embedding passed without a document used to be dropped silently.
        assignments.append("embedding=%s")
        values.append(embedding)

    if metadata:
        self._ensure_metadata_columns_exist(category, metadata)
        # NOTE(review): metadata keys are interpolated into the statement as
        # column names and cannot be bound as parameters; callers must only
        # pass trusted identifiers here.
        for key, value in metadata.items():
            assignments.append(f"{key}=%s")
            values.append(value)

    with self.connection.cursor() as cur:
        if assignments:
            query = f"""
                UPDATE {table_name}
                SET {', '.join(assignments)}
                WHERE id=%s
            """
            cur.execute(query, tuple(values) + (id_,))
    self.connection.commit()

def close(self):
    """Close the client's cursor, then its database connection."""
    self.cur.close()
    # The connection is closed once; the repeated call was redundant.
    self.connection.close()
3 changes: 2 additions & 1 deletion agentmemory/tests/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def test_create_event():
event = get_events()[0]
assert event["document"] == "test event"
assert event["metadata"]["test"] == "test"
assert event["metadata"]["epoch"] == 1
print(event["metadata"])
assert int(event["metadata"]["epoch"]) == 1
wipe_category("events")
wipe_category("epoch")

Expand Down
9 changes: 9 additions & 0 deletions agentmemory/tests/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,17 @@ def test_memory_update():
memories = get_memories("test")
memory_id = memories[0]["id"]

update_memory("test", memory_id, "doc 1 updated no", metadata={"test": "test"})
update_memory("test", memory_id, "doc 1 updated", metadata={"test": "test"})

assert get_memory("test", memory_id)["document"] == "doc 1 updated"

create_memory("test", "new memory test", metadata={"test": "test"})
memories = get_memories("test")
memory_id = memories[0]["id"]
update_memory("test", memory_id, "doc 2 updated", metadata={"test": "test"})
assert get_memory("test", memory_id)["document"] == "doc 2 updated"

wipe_category("test")


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

setup(
name='agentmemory',
version='0.4.2',
version='0.4.3',
description='Easy-to-use memory for agents, document search, knowledge graphing and more.',
long_description=long_description, # added this line
long_description_content_type="text/markdown", # and this line
Expand Down

0 comments on commit 77a0b1b

Please sign in to comment.