Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extract memory llm function call to separate methods #2011

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 63 additions & 46 deletions mem0/memory/graph_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,8 @@ def __init__(self, config):
self.user_id = None
self.threshold = 0.7

def add(self, data, filters):
"""
Adds data to the graph.

Args:
data (str): The data to add to the graph.
filters (dict): A dictionary containing filters to be applied during the addition.
"""

# retrieve the search results
search_output = self._search(data, filters)

# extracts nodes and relations from data
def _llm_extract_entities(self, data):
if self.config.graph_store.custom_prompt:
messages = [
{
Expand Down Expand Up @@ -94,8 +84,10 @@ def add(self, data, filters):
extracted_entities = []

logger.debug(f"Extracted entities: {extracted_entities}")
return extracted_entities

update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
def _llm_update_existing_memory(self, existing_entities, extracted_entities):
update_memory_prompt = get_update_memory_messages(existing_entities, extracted_entities)

_tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
Expand All @@ -111,20 +103,71 @@ def add(self, data, filters):
)

to_be_added = []
to_be_updated = []

for item in memory_updates["tool_calls"]:
if item["name"] == "add_graph_memory":
to_be_added.append(item["arguments"])
elif item["name"] == "update_graph_memory":
self._update_relationship(
item["arguments"]["source"],
item["arguments"]["destination"],
item["arguments"]["relationship"],
filters,
)
to_be_updated.append(item["arguments"])
elif item["name"] == "noop":
continue

return to_be_added, to_be_updated

# extracts nodes from query, used for searching
def _llm_extract_nodes(self, query, filters):
_tools = [SEARCH_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [SEARCH_STRUCT_TOOL]
search_results = self.llm.generate_response(
messages=[
{
"role": "system",
"content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities. ***DO NOT*** answer the question itself if the given text is a question.",
},
{"role": "user", "content": query},
],
tools=_tools,
)

node_list = []

for item in search_results["tool_calls"]:
if item["name"] == "search":
try:
node_list.extend(item["arguments"]["nodes"])
except Exception as e:
logger.error(f"Error in search tool: {e}")

node_list = list(set(node_list))
node_list = [node.lower().replace(" ", "_") for node in node_list]

logger.debug(f"Node list for search query : {node_list}")
return node_list

def add(self, data, filters):
"""
Adds data to the graph.

Args:
data (str): The data to add to the graph.
filters (dict): A dictionary containing filters to be applied during the addition.
"""

# retrieve the search results
existing_entities = self._search(data, filters)
extracted_entities = self._llm_extract_entities(data)
to_be_added, to_be_updated = self._llm_update_existing_memory(existing_entities, extracted_entities)

for item in to_be_updated:
self._update_relationship(
item["arguments"]["source"],
item["arguments"]["destination"],
item["arguments"]["relationship"],
filters,
)

returned_entities = []

for item in to_be_added:
Expand Down Expand Up @@ -168,34 +211,8 @@ def add(self, data, filters):
return returned_entities

def _search(self, query, filters, limit=100):
_tools = [SEARCH_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [SEARCH_STRUCT_TOOL]
search_results = self.llm.generate_response(
messages=[
{
"role": "system",
"content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities. ***DO NOT*** answer the question itself if the given text is a question.",
},
{"role": "user", "content": query},
],
tools=_tools,
)

node_list = []

for item in search_results["tool_calls"]:
if item["name"] == "search":
try:
node_list.extend(item["arguments"]["nodes"])
except Exception as e:
logger.error(f"Error in search tool: {e}")

node_list = list(set(node_list))
node_list = [node.lower().replace(" ", "_") for node in node_list]

logger.debug(f"Node list for search query : {node_list}")

node_list = self._llm_extract_nodes(query, filters)
result_relations = []

for node in node_list:
Expand Down
25 changes: 17 additions & 8 deletions mem0/memory/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def add(
)
return vector_store_result

def _add_to_vector_store(self, messages, metadata, filters):
def _llm_extract_facts(self, messages):
parsed_messages = parse_messages(messages)

if self.custom_prompt:
Expand All @@ -157,6 +157,21 @@ def _add_to_vector_store(self, messages, metadata, filters):
logging.error(f"Error in new_retrieved_facts: {e}")
new_retrieved_facts = []

return new_retrieved_facts

def _llm_new_memories_with_actions(self, retrieved_old_memory, new_retrieved_facts):
function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts)

new_memories_with_actions = self.llm.generate_response(
messages=[{"role": "user", "content": function_calling_prompt}],
response_format={"type": "json_object"},
)
new_memories_with_actions = json.loads(new_memories_with_actions)
return new_memories_with_actions

def _add_to_vector_store(self, messages, metadata, filters):
new_retrieved_facts = self._llm_extract_facts(messages)

retrieved_old_memory = []
new_message_embeddings = {}
for new_mem in new_retrieved_facts:
Expand All @@ -178,13 +193,7 @@ def _add_to_vector_store(self, messages, metadata, filters):
temp_uuid_mapping[str(idx)] = item["id"]
retrieved_old_memory[idx]["id"] = str(idx)

function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts)

new_memories_with_actions = self.llm.generate_response(
messages=[{"role": "user", "content": function_calling_prompt}],
response_format={"type": "json_object"},
)
new_memories_with_actions = json.loads(new_memories_with_actions)
new_memories_with_actions = self._llm_new_memories_with_actions(retrieved_old_memory, new_retrieved_facts)

returned_memories = []
try:
Expand Down
Loading