diff --git a/README.md b/README.md index bfe1a13919..91f418c6a6 100644 --- a/README.md +++ b/README.md @@ -160,7 +160,7 @@ history = m.history(memory_id=) ### Graph Memory To initialize Graph Memory you'll need to set up your configuration with graph store providers. -Currently, we support Neo4j as a graph store provider. You can setup [Neo4j](https://neo4j.com/) locally or use the hosted [Neo4j AuraDB](https://neo4j.com/product/auradb/). +Currently, we support FalkorDB and Neo4j as graph store providers. You can set up [FalkorDB](https://www.falkordb.com/) or [Neo4j](https://neo4j.com/) locally or use the hosted [FalkorDB Cloud](https://app.falkordb.cloud/) or [Neo4j AuraDB](https://neo4j.com/product/auradb/). Moreover, you also need to set the version to `v1.1` (*prior versions are not supported*). Here's how you can do it: @@ -169,11 +169,12 @@ from mem0 import Memory config = { "graph_store": { - "provider": "neo4j", + "provider": "falkordb", "config": { - "url": "neo4j+s://xxx", - "username": "neo4j", - "password": "xxx" + "host": "---", + "username": "---", + "password": "---", + "port": "---" } }, "version": "v1.1" diff --git a/cookbooks/mem0_graph_memory.py b/cookbooks/mem0_graph_memory.py new file mode 100644 index 0000000000..b29547d47a --- /dev/null +++ b/cookbooks/mem0_graph_memory.py @@ -0,0 +1,40 @@ +# This example shows how to use the graph config with a FalkorDB graph database +import os +from mem0 import Memory +from dotenv import load_dotenv + +# Loading OpenAI API Key +load_dotenv() +OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY') +USER_ID = "test" + +# Creating the config dict from the environment variables +config = { + "llm": { # This is the language model configuration, use your credentials + "provider": "openai", + "config": { + "model": "gpt-4o-mini", + "temperature": 0 + } + }, + "graph_store": { # See https://app.falkordb.cloud/ for the credentials + "provider": "falkordb", + "config": { + "host": os.environ['HOST'], + "username": 
os.environ['USERNAME'], + "password": os.environ['PASSWORD'], + "port": os.environ['PORT'] + } + }, + "version": "v1.1" +} + +# Create the Memory instance from the config +memory = Memory.from_config(config_dict=config) + +# Use Mem0 to add and search memories +memory.add("I like painting", user_id=USER_ID) +memory.add("I hate playing badminton", user_id=USER_ID) +print(memory.get_all(user_id=USER_ID)) +memory.add("My friend name is john and john has a dog named tommy", user_id=USER_ID) +print(memory.search("What I like to do", user_id=USER_ID)) diff --git a/docs/open-source/graph_memory/features.mdx b/docs/open-source/graph_memory/features.mdx index dd2cc3e352..8cc35d3c30 100644 --- a/docs/open-source/graph_memory/features.mdx +++ b/docs/open-source/graph_memory/features.mdx @@ -19,11 +19,13 @@ from mem0 import Memory config = { "graph_store": { - "provider": "neo4j", + "provider": "falkordb", "config": { - "url": "neo4j+s://xxx", - "username": "neo4j", - "password": "xxx" + "Database": "falkordb", + "host": "---", + "username": "---", + "password": "---", + "port": "---" }, "custom_prompt": "Please only extract entities containing sports related relationships and nothing else.", }, diff --git a/docs/open-source/graph_memory/overview.mdx b/docs/open-source/graph_memory/overview.mdx index fe933db8d4..395c8402e0 100644 --- a/docs/open-source/graph_memory/overview.mdx +++ b/docs/open-source/graph_memory/overview.mdx @@ -36,12 +36,11 @@ allowfullscreen ## Initialize Graph Memory To initialize Graph Memory you'll need to set up your configuration with graph store providers. -Currently, we support Neo4j as a graph store provider. You can setup [Neo4j](https://neo4j.com/) locally or use the hosted [Neo4j AuraDB](https://neo4j.com/product/auradb/). -Moreover, you also need to set the version to `v1.1` (*prior versions are not supported*). +Currently, we support FalkorDB and Neo4j as graph store providers. 
You can set up [FalkorDB](https://www.falkordb.com/) or [Neo4j](https://neo4j.com/) locally or use the hosted [FalkorDB Cloud](https://app.falkordb.cloud/) or [Neo4j AuraDB](https://neo4j.com/product/auradb/). +Moreover, you also need to set the version to `v1.1` (*prior versions are not supported*). If you are using Neo4j locally, then you need to install [APOC plugins](https://neo4j.com/labs/apoc/4.1/installation/). - User can also customize the LLM for Graph Memory from the [Supported LLM list](https://docs.mem0.ai/components/llms/overview) with three levels of configuration: 1. **Main Configuration**: If `llm` is set in the main config, it will be used for all graph operations. @@ -57,11 +56,13 @@ from mem0 import Memory config = { "graph_store": { - "provider": "neo4j", + "provider": "falkordb", "config": { - "url": "neo4j+s://xxx", - "username": "neo4j", - "password": "xxx" + "Database": "falkordb", + "host": "---", + "username": "---", + "password": "---", + "port": "---" } }, "version": "v1.1" @@ -83,11 +84,13 @@ config = { } }, "graph_store": { - "provider": "neo4j", + "provider": "falkordb", "config": { - "url": "neo4j+s://xxx", - "username": "neo4j", - "password": "xxx" + "Database": "falkordb", + "host": "---", + "username": "---", + "password": "---", + "port": "---" }, "llm" : { "provider": "openai", diff --git a/docs/open-source/quickstart.mdx b/docs/open-source/quickstart.mdx index 45d030e97a..2b830d16a6 100644 --- a/docs/open-source/quickstart.mdx +++ b/docs/open-source/quickstart.mdx @@ -63,11 +63,13 @@ from mem0 import Memory config = { "graph_store": { - "provider": "neo4j", + "provider": "falkordb", "config": { - "url": "neo4j+s://---", - "username": "neo4j", - "password": "---" + "Database": "falkordb", + "host": "---", + "username": "---", + "password": "---", + "port": "---" } }, "version": "v1.1" diff --git a/mem0/graphs/configs.py b/mem0/graphs/configs.py index c14249adfd..be5a2bae8d 100644 --- a/mem0/graphs/configs.py +++ 
b/mem0/graphs/configs.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union from pydantic import BaseModel, Field, field_validator, model_validator @@ -20,11 +20,31 @@ def check_host_port_or_path(cls, values): if not url or not username or not password: raise ValueError("Please provide 'url', 'username' and 'password'.") return values + +class FalkorDBConfig(BaseModel): + host: Optional[str] = Field(None, description="Host address for the graph database") + username: Optional[str] = Field(None, description="Username for the graph database") + password: Optional[str] = Field(None, description="Password for the graph database") + port: Optional[int] = Field(None, description="Port for the graph database") + # Default database name is mandatory in langchain + database: str = "_default_" + + @model_validator(mode="before") + def check_host_port_or_path(cls, values): + host, port = ( + values.get("host"), + values.get("port"), + ) + if not host or not port: + raise ValueError( + "Please provide 'host' and 'port'." 
+ ) + return values class GraphStoreConfig(BaseModel): - provider: str = Field(description="Provider of the data store (e.g., 'neo4j')", default="neo4j") - config: Neo4jConfig = Field(description="Configuration for the specific data store", default=None) + provider: str = Field(description="Provider of the data store (e.g., 'falkordb', 'neo4j')", default="falkordb") + config: Union[FalkorDBConfig, Neo4jConfig] = Field(description="Configuration for the specific data store", default=None) llm: Optional[LlmConfig] = Field(description="LLM configuration for querying the graph store", default=None) custom_prompt: Optional[str] = Field( description="Custom prompt to fetch entities from the given text", default=None @@ -35,5 +55,11 @@ def validate_config(cls, v, values): provider = values.data.get("provider") if provider == "neo4j": return Neo4jConfig(**v.model_dump()) + elif provider == "falkordb": + config = v.model_dump() + # In case the user try to use diffrent database name + config["database"] = "_default_" + + return FalkorDBConfig(**config) else: raise ValueError(f"Unsupported graph store provider: {provider}") diff --git a/mem0/graphs/utils.py b/mem0/graphs/utils.py index efc14db03d..db637fe6d6 100644 --- a/mem0/graphs/utils.py +++ b/mem0/graphs/utils.py @@ -53,6 +53,52 @@ Adhere strictly to these guidelines to ensure high-quality knowledge graph extraction.""" +FALKORDB_QUERY = """ +MATCH (n) +WHERE n.embedding IS NOT NULL AND n.user_id = $user_id +WITH n, + reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / + (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * + sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))) AS similarity +WHERE similarity >= $threshold +MATCH (n)-[r]->(m) +RETURN n.name AS source, Id(n) AS source_id, type(r) AS relation, Id(r) AS relation_id, m.name AS destination, Id(m) AS destination_id, 
similarity +UNION +MATCH (n) +WHERE n.embedding IS NOT NULL AND n.user_id = $user_id +WITH n, + reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / + (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * + sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))) AS similarity +WHERE similarity >= $threshold +MATCH (m)-[r]->(n) +RETURN m.name AS source, Id(m) AS source_id, type(r) AS relation, Id(r) AS relation_id, n.name AS destination, Id(n) AS destination_id, similarity +ORDER BY similarity DESC +""" + + +NEO4J_QUERY = """ +MATCH (n) +WHERE n.embedding IS NOT NULL AND n.user_id = $user_id +WITH n, + round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / + (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * + sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity +WHERE similarity >= $threshold +MATCH (n)-[r]->(m) +RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity +UNION +MATCH (n) +WHERE n.embedding IS NOT NULL AND n.user_id = $user_id +WITH n, + round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / + (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * + sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity +WHERE similarity >= $threshold +MATCH (m)-[r]->(n) +RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity +ORDER BY similarity DESC +""" def 
get_update_memory_prompt(existing_memories, memory, template): return template.format(existing_memories=existing_memories, memory=memory) diff --git a/mem0/memory/graph_memory.py b/mem0/memory/graph_memory.py index 55d644c901..a13a948626 100644 --- a/mem0/memory/graph_memory.py +++ b/mem0/memory/graph_memory.py @@ -1,10 +1,5 @@ import logging -try: - from langchain_community.graphs import Neo4jGraph -except ImportError: - raise ImportError("langchain_community is not installed. Please install it using pip install langchain-community") - try: from rank_bm25 import BM25Okapi except ImportError: @@ -22,8 +17,14 @@ UPDATE_MEMORY_STRUCT_TOOL_GRAPH, UPDATE_MEMORY_TOOL_GRAPH, ) -from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages -from mem0.utils.factory import EmbedderFactory, LlmFactory +from mem0.graphs.utils import ( + EXTRACT_ENTITIES_PROMPT, + FALKORDB_QUERY, + NEO4J_QUERY, + get_update_memory_messages, +) +from mem0.utils.factory import EmbedderFactory, LlmFactory, GraphFactory + logger = logging.getLogger(__name__) @@ -31,10 +32,11 @@ class MemoryGraph: def __init__(self, config): self.config = config - self.graph = Neo4jGraph( - self.config.graph_store.config.url, - self.config.graph_store.config.username, - self.config.graph_store.config.password, + self.graph = GraphFactory.create( + self.config.graph_store.provider, self.config.graph_store.config + ) + self.embedding_model = EmbedderFactory.create( + self.config.embedder.provider, self.config.embedder.config ) self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config) @@ -161,7 +163,7 @@ def add(self, data, filters): "user_id": filters["user_id"], } - _ = self.graph.query(cypher, params=params) + _ = self.graph_query(cypher, params=params) logger.info(f"Added {len(to_be_added)} new memories to the graph") @@ -201,36 +203,20 @@ def _search(self, query, filters, limit=100): for node in node_list: n_embedding = 
self.embedding_model.embed(node) - cypher_query = """ - MATCH (n) - WHERE n.embedding IS NOT NULL AND n.user_id = $user_id - WITH n, - round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / - (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * - sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity - WHERE similarity >= $threshold - MATCH (n)-[r]->(m) - RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity - UNION - MATCH (n) - WHERE n.embedding IS NOT NULL AND n.user_id = $user_id - WITH n, - round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / - (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * - sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity - WHERE similarity >= $threshold - MATCH (m)-[r]->(n) - RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity - ORDER BY similarity DESC - LIMIT $limit - """ + if self.config.graph_store.provider == "falkordb": + cypher_query = FALKORDB_QUERY + elif self.config.graph_store.provider == "neo4j": + cypher_query = NEO4J_QUERY + else: + raise ValueError("Unsupported graph database provider for querying") + params = { "n_embedding": n_embedding, "threshold": self.threshold, "user_id": filters["user_id"], "limit": limit, } - ans = self.graph.query(cypher_query, params=params) + ans = self.graph_query(cypher_query, params=params) result_relations.extend(ans) return result_relations @@ -255,7 +241,7 @@ def search(self, query, filters, limit=100): if not search_output: return [] - 
search_outputs_sequence = [[item["source"], item["relation"], item["destination"]] for item in search_output] + search_outputs_sequence = [[item[0], item[2], item[4]] for item in search_output] bm25 = BM25Okapi(search_outputs_sequence) tokenized_query = query.split(" ") @@ -275,7 +261,7 @@ def delete_all(self, filters): DETACH DELETE n """ params = {"user_id": filters["user_id"]} - self.graph.query(cypher, params=params) + self.graph_query(cypher, params=params) def get_all(self, filters, limit=100): """ @@ -296,17 +282,15 @@ def get_all(self, filters, limit=100): RETURN n.name AS source, type(r) AS relationship, m.name AS target LIMIT $limit """ - results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit}) + results = self.graph_query(query, params={"user_id": filters["user_id"]}) final_results = [] for result in results: - final_results.append( - { - "source": result["source"], - "relationship": result["relationship"], - "target": result["target"], - } - ) + final_results.append({ + "source": result[0], + "relationship": result[1], + "target": result[2] + }) logger.info(f"Retrieved {len(final_results)} relationships") @@ -334,7 +318,7 @@ def _update_relationship(self, source, target, relationship, filters): MERGE (n1 {name: $source, user_id: $user_id}) MERGE (n2 {name: $target, user_id: $user_id}) """ - self.graph.query( + self.graph_query( check_and_create_query, params={"source": source, "target": target, "user_id": filters["user_id"]}, ) @@ -344,7 +328,7 @@ def _update_relationship(self, source, target, relationship, filters): MATCH (n1 {name: $source, user_id: $user_id})-[r]->(n2 {name: $target, user_id: $user_id}) DELETE r """ - self.graph.query( + self.graph_query( delete_query, params={"source": source, "target": target, "user_id": filters["user_id"]}, ) @@ -355,10 +339,34 @@ def _update_relationship(self, source, target, relationship, filters): CREATE (n1)-[r:{relationship}]->(n2) RETURN n1, r, n2 """ - result = 
self.graph.query( + result = self.graph_query( create_query, params={"source": source, "target": target, "user_id": filters["user_id"]}, ) if not result: raise Exception(f"Failed to update or create relationship between {source} and {target}") + + def graph_query(self, query, params): + """ + Execute a Cypher query on the graph database. + FalkorDB supported multi-graph usage, the graphs is switched based on the user_id. + + Args: + query (str): The Cypher query to execute. + params (dict): A dictionary containing params to be applied during the query. + + Returns: + list: A list of dictionaries containing the results of the query. + """ + if self.config.graph_store.provider == "falkordb": + # TODO: Use langchain to switch graphs after the multi-graph feature is released + self.graph._graph = self.graph._driver.select_graph(params["user_id"]) + + query_output = self.graph.query(query, params=params) + + if self.config.graph_store.provider == "neo4j": + query_output = [list(d.values()) for d in query_output] + + + return query_output \ No newline at end of file diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index 5e0defc307..280d1f0e39 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -77,3 +77,20 @@ def create(cls, provider_name, config): return vector_store_instance(**config) else: raise ValueError(f"Unsupported VectorStore provider: {provider_name}") + +class GraphFactory: + provider_to_class = { + "falkordb": "langchain_community.graphs.FalkorDBGraph", + "neo4j": "langchain_community.graphs.Neo4jGraph", + } + + @classmethod + def create(cls, provider_name, config): + class_type = cls.provider_to_class.get(provider_name) + if class_type: + if not isinstance(config, dict): + config = config.model_dump() + graph_instance = load_class(class_type) + return graph_instance(**config) + else: + raise ValueError(f"Unsupported graph provider: {provider_name}") \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 
66b08af481..cdc7bf2b66 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -372,6 +372,19 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "falkordb" +version = "1.0.8" +description = "Python client for interacting with FalkorDB database" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "falkordb-1.0.8.tar.gz", hash = "sha256:14a68ab9d684553caf8302602c18c8148c403a0d124a8a5f45de9ea43529b2c6"}, +] + +[package.dependencies] +redis = ">=5.0.1,<6.0.0" + [[package]] name = "frozenlist" version = "1.4.1" @@ -1635,6 +1648,24 @@ numpy = "*" [package.extras] dev = ["pytest"] +[[package]] +name = "redis" +version = "5.0.8" +description = "Python client for Redis database and key-value store" +optional = false +python-versions = ">=3.7" +files = [ + {file = "redis-5.0.8-py3-none-any.whl", hash = "sha256:56134ee08ea909106090934adc36f65c9bcbbaecea5b21ba704ba6fb561f8eb4"}, + {file = "redis-5.0.8.tar.gz", hash = "sha256:0c5b10d387568dfe0698c6fad6615750c24170e548ca2deac10c649d463e9870"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""} + +[package.extras] +hiredis = ["hiredis (>1.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] + [[package]] name = "requests" version = "2.32.3" diff --git a/pyproject.toml b/pyproject.toml index 95a0505785..8f81f323f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ sqlalchemy = "^2.0.31" langchain-community = "^0.3.1" neo4j = "^5.23.1" rank-bm25 = "^0.2.2" +falkordb = "^1.0.8" [tool.poetry.group.test.dependencies] pytest = "^8.2.2"