extract memory llm function call to separate methods
GingerMoon committed Nov 6, 2024
1 parent 77b0912 commit e7b3f2a
Showing 2 changed files with 80 additions and 54 deletions.
109 changes: 63 additions & 46 deletions mem0/memory/graph_memory.py
@@ -48,18 +48,8 @@ def __init__(self, config):
        self.user_id = None
        self.threshold = 0.7

    def add(self, data, filters):
        """
        Adds data to the graph.
        Args:
            data (str): The data to add to the graph.
            filters (dict): A dictionary containing filters to be applied during the addition.
        """

        # retrieve the search results
        search_output = self._search(data, filters)

    # extracts nodes and relations from data
    def _llm_extract_entities(self, data):
        if self.config.graph_store.custom_prompt:
            messages = [
                {
@@ -94,8 +84,10 @@ def add(self, data, filters):
            extracted_entities = []

        logger.debug(f"Extracted entities: {extracted_entities}")
        return extracted_entities

        update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
    def _llm_update_existing_memory(self, existing_entities, extracted_entities):
        update_memory_prompt = get_update_memory_messages(existing_entities, extracted_entities)

        _tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
@@ -111,20 +103,71 @@ def add(self, data, filters):
        )

        to_be_added = []
        to_be_updated = []

        for item in memory_updates["tool_calls"]:
            if item["name"] == "add_graph_memory":
                to_be_added.append(item["arguments"])
            elif item["name"] == "update_graph_memory":
                self._update_relationship(
                    item["arguments"]["source"],
                    item["arguments"]["destination"],
                    item["arguments"]["relationship"],
                    filters,
                )
                to_be_updated.append(item["arguments"])
            elif item["name"] == "noop":
                continue

        return to_be_added, to_be_updated

    # extracts nodes from query, used for searching
    def _llm_extract_nodes(self, query, filters):
        _tools = [SEARCH_TOOL]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
            _tools = [SEARCH_STRUCT_TOOL]
        search_results = self.llm.generate_response(
            messages=[
                {
                    "role": "system",
                    "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities.",
                },
                {"role": "user", "content": query},
            ],
            tools=_tools,
        )

        node_list = []

        for item in search_results["tool_calls"]:
            if item["name"] == "search":
                try:
                    node_list.extend(item["arguments"]["nodes"])
                except Exception as e:
                    logger.error(f"Error in search tool: {e}")

        node_list = list(set(node_list))
        node_list = [node.lower().replace(" ", "_") for node in node_list]

        logger.debug(f"Node list for search query : {node_list}")
        return node_list

    def add(self, data, filters):
        """
        Adds data to the graph.
        Args:
            data (str): The data to add to the graph.
            filters (dict): A dictionary containing filters to be applied during the addition.
        """

        # retrieve the search results
        existing_entities = self._search(data, filters)
        extracted_entities = self._llm_extract_entities(data)
        to_be_added, to_be_updated = self._llm_update_existing_memory(existing_entities, extracted_entities)

        for item in to_be_updated:
            # items are the tool-call argument dicts themselves, so index
            # their fields directly rather than through item["arguments"]
            self._update_relationship(
                item["source"],
                item["destination"],
                item["relationship"],
                filters,
            )

        returned_entities = []

        for item in to_be_added:
@@ -168,34 +211,8 @@ def add(self, data, filters):
        return returned_entities

    def _search(self, query, filters, limit=100):
        _tools = [SEARCH_TOOL]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
            _tools = [SEARCH_STRUCT_TOOL]
        search_results = self.llm.generate_response(
            messages=[
                {
                    "role": "system",
                    "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities.",
                },
                {"role": "user", "content": query},
            ],
            tools=_tools,
        )

        node_list = []

        for item in search_results["tool_calls"]:
            if item["name"] == "search":
                try:
                    node_list.extend(item["arguments"]["nodes"])
                except Exception as e:
                    logger.error(f"Error in search tool: {e}")

        node_list = list(set(node_list))
        node_list = [node.lower().replace(" ", "_") for node in node_list]

        logger.debug(f"Node list for search query : {node_list}")

        node_list = self._llm_extract_nodes(query, filters)
        result_relations = []

        for node in node_list:
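To see the shape of the graph-side refactor end to end, here is a minimal, self-contained sketch of the pipeline that add now delegates to. GraphMemorySketch and its canned data are illustrative stand-ins for the LLM-backed helpers, not code from this commit.

# Minimal sketch of the pipeline GraphMemory.add now delegates to.
# GraphMemorySketch and its canned data are illustrative stand-ins,
# not the mem0 implementation.

class GraphMemorySketch:
    def _search(self, data, filters):
        # stand-in for _llm_extract_nodes plus the graph lookup:
        # relations already stored for this user
        return [{"source": "alice", "relationship": "dislikes", "destination": "coffee"}]

    def _llm_extract_entities(self, data):
        # stand-in for the entity-extraction LLM call
        return [
            {"source": "alice", "relationship": "likes", "destination": "coffee"},
            {"source": "alice", "relationship": "works_at", "destination": "acme"},
        ]

    def _llm_update_existing_memory(self, existing_entities, extracted_entities):
        # stand-in for the UPDATE/ADD/NOOP tool-calling step: route each
        # extracted relation to an add or to an update of an existing edge
        existing_pairs = {(e["source"], e["destination"]) for e in existing_entities}
        to_be_added, to_be_updated = [], []
        for entity in extracted_entities:
            if (entity["source"], entity["destination"]) in existing_pairs:
                to_be_updated.append(entity)
            else:
                to_be_added.append(entity)
        return to_be_added, to_be_updated

    def _update_relationship(self, source, destination, relationship, filters):
        print(f"update: {source} -[{relationship}]-> {destination}")

    def add(self, data, filters):
        # same shape as the refactored method in the diff above
        existing_entities = self._search(data, filters)
        extracted_entities = self._llm_extract_entities(data)
        to_be_added, to_be_updated = self._llm_update_existing_memory(
            existing_entities, extracted_entities
        )
        for item in to_be_updated:
            self._update_relationship(
                item["source"], item["destination"], item["relationship"], filters
            )
        return to_be_added


if __name__ == "__main__":
    added = GraphMemorySketch().add("I like coffee and work at Acme", {"user_id": "alice"})
    print(f"added: {added}")

Keeping filters in add rather than in the helper means _llm_update_existing_memory stays a pure reconciliation step between existing and extracted relations, which is what makes it extractable in the first place.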
25 changes: 17 additions & 8 deletions mem0/memory/main.py
@@ -134,7 +134,7 @@ def add(
        )
        return vector_store_result

    def _add_to_vector_store(self, messages, metadata, filters):
    def _llm_extract_facts(self, messages):
        parsed_messages = parse_messages(messages)

        if self.custom_prompt:
@@ -157,6 +157,21 @@ def _add_to_vector_store(self, messages, metadata, filters):
logging.error(f"Error in new_retrieved_facts: {e}")
new_retrieved_facts = []

        return new_retrieved_facts

    def _llm_new_memories_with_actions(self, retrieved_old_memory, new_retrieved_facts):
        function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts)

        new_memories_with_actions = self.llm.generate_response(
            messages=[{"role": "user", "content": function_calling_prompt}],
            response_format={"type": "json_object"},
        )
        new_memories_with_actions = json.loads(new_memories_with_actions)
        return new_memories_with_actions

    def _add_to_vector_store(self, messages, metadata, filters):
        new_retrieved_facts = self._llm_extract_facts(messages)

        retrieved_old_memory = []
        new_message_embeddings = {}
        for new_mem in new_retrieved_facts:
@@ -178,13 +193,7 @@ def _add_to_vector_store(self, messages, metadata, filters):
            temp_uuid_mapping[str(idx)] = item["id"]
            retrieved_old_memory[idx]["id"] = str(idx)

        function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts)

        new_memories_with_actions = self.llm.generate_response(
            messages=[{"role": "user", "content": function_calling_prompt}],
            response_format={"type": "json_object"},
        )
        new_memories_with_actions = json.loads(new_memories_with_actions)
        new_memories_with_actions = self._llm_new_memories_with_actions(retrieved_old_memory, new_retrieved_facts)

        returned_memories = []
        try:
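The main.py change follows the same pattern: _add_to_vector_store keeps the orchestration while each LLM round-trip lives in its own helper. A rough sketch of the resulting flow; the stub class, its canned responses, and the {"memory": [...]} result shape are assumptions for illustration, not the mem0 implementation.

# Rough sketch of the refactored _add_to_vector_store flow. The stub class,
# its canned responses, and the {"memory": [...]} shape are illustrative
# assumptions, not the mem0 implementation.

import json


class VectorMemorySketch:
    def _llm_extract_facts(self, messages):
        # stand-in for the fact-extraction LLM call (parse_messages + prompt)
        return ["Likes coffee", "Works at Acme"]

    def _llm_new_memories_with_actions(self, retrieved_old_memory, new_retrieved_facts):
        # stand-in for the JSON-mode LLM call that decides what to do with
        # each fact relative to the memories already stored
        response = json.dumps(
            {"memory": [{"text": fact, "event": "ADD"} for fact in new_retrieved_facts]}
        )
        return json.loads(response)

    def _add_to_vector_store(self, messages, metadata, filters):
        # same shape as the refactored method in the diff above
        new_retrieved_facts = self._llm_extract_facts(messages)
        retrieved_old_memory = []  # vector search for similar memories elided
        new_memories_with_actions = self._llm_new_memories_with_actions(
            retrieved_old_memory, new_retrieved_facts
        )
        return new_memories_with_actions["memory"]


if __name__ == "__main__":
    memories = VectorMemorySketch()._add_to_vector_store(
        [{"role": "user", "content": "I like coffee and work at Acme"}],
        metadata={},
        filters={},
    )
    print(memories)

Isolating the two LLM calls this way also makes them individually mockable, so the vector-store bookkeeping around them can be tested without a live model.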
