diff --git a/.gitignore b/.gitignore index ad91734..40cb0e8 100644 --- a/.gitignore +++ b/.gitignore @@ -129,3 +129,4 @@ dmypy.json .pyre/ src/kg_chat/graph_output/knowledge_graph.html data/database/*.db +tests/input/database/*.db diff --git a/README.md b/README.md index c79b1af..e7e21cc 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,9 @@ LLM-based chatbot that queries and visualizes [`KGX`](https://github.com/biolink 3. Install the APOC plugin in Neo4j Desktop. 4. Update settings to match [`neo4j_db_settings.conf`](conf_files/neo4j_db_settings.conf). -### General Setup +### General Setup + +#### For Developers 1. Clone this repository. 2. Create a virtual environment and install dependencies: ```shell @@ -23,6 +25,16 @@ LLM-based chatbot that queries and visualizes [`KGX`](https://github.com/biolink ``` 3. Replace [`data/nodes.tsv`](data/nodes.tsv) and [`data/edges.tsv`](data/edges.tsv) with desired KGX files if needed. +### For using kg-chat as a dependency + +```shell +pip install kg-chat +``` +OR +```shell +poetry add kg-chat@latest +``` + ### Supported Backends - DuckDB [default] - Neo4j @@ -31,27 +43,28 @@ LLM-based chatbot that queries and visualizes [`KGX`](https://github.com/biolink 1. **Import KG**: Load nodes and edges into a database (default: duckdb). ```shell - poetry run kg import + poetry run kg import --data-dir data ``` -2. **Test Query**: Run a test query. +2. **Test Query**: Run a test query. + > NOTE: `--data-dir` is a required parameter for all commands. This is the path for the directory which contains the nodes.tsv and edges.tsv file. The filenames are expected to be exactly that. ```shell - poetry run kg test-query + poetry run kg test-query --data-dir data ``` 3. **QnA**: Ask questions about the data. ```shell - poetry run kg qna "how many nodes do we have here?" + poetry run kg qna "how many nodes do we have here?" --data-dir data ``` 4. **Chat**: Start an interactive chat session. ```shell - poetry run kg chat + poetry run kg chat --data-dir data ``` 5. **App**: Deploy a local web application. ```shell - poetry run kg app + poetry run kg app --data-dir data ``` ### Visualization diff --git a/docs/commands.rst b/docs/commands.rst index 832e196..ede4206 100644 --- a/docs/commands.rst +++ b/docs/commands.rst @@ -5,13 +5,13 @@ Commands .. code-block:: shell - poetry run kg import + poetry run kg import --data-dir data 2. ``test-query``: To test that the above worked, run a built-in test query: .. code-block:: shell - poetry run kg test-query --database neo4j + poetry run kg test-query --database neo4j --data-dir data This should return something like (as per KGX data in the repo): @@ -32,7 +32,7 @@ Commands .. code-block:: shell - poetry run kg qna "give me the sorted (descending) frequency count nodes with relationships. Give me label and id. I want this as a table " + poetry run kg qna "give me the sorted (descending) frequency count nodes with relationships. Give me label and id. I want this as a table " --data-dir data This should return @@ -126,7 +126,7 @@ Commands .. code-block:: shell - poetry run kg chat --database neo4j + poetry run kg chat --database neo4j --data-dir data Gives you the following: @@ -208,7 +208,7 @@ Commands .. code-block:: shell - kg-chat $ poetry run kg chat + kg-chat $ poetry run kg chat --data-dir data Ask me about your data! : show me 20 edges with subject prefix = CHEBI @@ -421,7 +421,7 @@ Commands .. code-block:: shell - poetry run kg app + poetry run kg app --data-dir data This will start the app on http://localhost:8050/ which can be accessed in the browser. diff --git a/docs/setup.rst b/docs/setup.rst index 2cb2183..bd8c090 100644 --- a/docs/setup.rst +++ b/docs/setup.rst @@ -26,6 +26,8 @@ Setup Update the memory heaps as per your preference. +For Developers +--------------- 6. Clone this repository locally 7. Create a virtual environment of your choice and ``pip install poetry`` in it. @@ -38,3 +40,24 @@ Setup poetry install 9. Replace the ``data/nodes.tsv`` and ``data/edges.tsv`` file in the project with corresponding files of choice that needs to be queried against. + + + +For Users +---------- +10. Install the package from PyPI + + .. code-block:: shell + + pip install kg-chat + + or + .. code-block:: shell + + pip install poetry + poetry add kg-chat@latest + +.. note:: + + * The KGX files should have the names `nodes.tsv` and `edges.tsv`. + * The data directory must be provided to run every command. The data directory contains the `nodes.tsv` and `edges.tsv` files. diff --git a/src/kg_chat/cli.py b/src/kg_chat/cli.py index 2f209ce..bb8d22f 100644 --- a/src/kg_chat/cli.py +++ b/src/kg_chat/cli.py @@ -1,13 +1,14 @@ """Command line interface for kg-chat.""" import logging +from pathlib import Path from pprint import pprint +from typing import Union import click from kg_chat import __version__ from kg_chat.app import create_app -from kg_chat.constants import DATA_DIR from kg_chat.implementations.duckdb_implementation import DuckDBImplementation from kg_chat.implementations.neo4j_implementation import Neo4jImplementation from kg_chat.main import KnowledgeGraphChat @@ -28,7 +29,7 @@ "--data-dir", type=click.Path(exists=True, file_okay=False, dir_okay=True), help="Directory containing the data.", - default=DATA_DIR, + required=True, ) @@ -56,30 +57,33 @@ def main(verbose: int, quiet: bool): @main.command("import") @database_options @data_dir_option -def import_kg(database: str = "duckdb", data_dir: str = DATA_DIR): +def import_kg(database: str = "duckdb", data_dir: str = None): """Run the kg-chat's demo command.""" + if not data_dir: + raise ValueError("Data directory is required. This typically contains the KGX tsv files.") if database == "neo4j": - impl = Neo4jImplementation() - impl.load_kg(data_dir=data_dir) + impl = Neo4jImplementation(data_dir=data_dir) + impl.load_kg() elif database == "duckdb": - impl = DuckDBImplementation() - impl.load_kg(data_dir=data_dir) + impl = DuckDBImplementation(data_dir=data_dir) + impl.load_kg() else: raise ValueError(f"Database {database} not supported.") @main.command() +@data_dir_option @database_options -def test_query(database: str = "duckdb"): +def test_query(data_dir: Union[str, Path], database: str = "duckdb"): """Run the kg-chat's chat command.""" if database == "neo4j": - impl = Neo4jImplementation() + impl = Neo4jImplementation(data_dir=data_dir) query = "MATCH (n) RETURN n LIMIT 10" result = impl.execute_query(query) for record in result: print(record) elif database == "duckdb": - impl = DuckDBImplementation() + impl = DuckDBImplementation(data_dir=data_dir) query = "SELECT * FROM nodes LIMIT 10" result = impl.execute_query(query) for record in result: @@ -89,14 +93,15 @@ def test_query(database: str = "duckdb"): @main.command() +@data_dir_option @database_options -def show_schema(database: str = "duckdb"): +def show_schema(data_dir: Union[str, Path], database: str = "duckdb"): """Run the kg-chat's chat command.""" if database == "neo4j": - impl = Neo4jImplementation() + impl = Neo4jImplementation(data_dir=data_dir) impl.show_schema() elif database == "duckdb": - impl = DuckDBImplementation() + impl = DuckDBImplementation(data_dir=data_dir) impl.show_schema() else: raise ValueError(f"Database {database} not supported.") @@ -105,14 +110,15 @@ def show_schema(database: str = "duckdb"): @main.command() @database_options @click.argument("query", type=str, required=True) -def qna(query: str, database: str = "duckdb"): +@data_dir_option +def qna(query: str, data_dir: Union[str, Path], database: str = "duckdb"): """Run the kg-chat's chat command.""" if database == "neo4j": - impl = Neo4jImplementation() + impl = Neo4jImplementation(data_dir=data_dir) response = impl.get_human_response(query, impl) pprint(response) elif database == "duckdb": - impl = DuckDBImplementation() + impl = DuckDBImplementation(data_dir=data_dir) response = impl.get_human_response(query) pprint(response) else: @@ -120,15 +126,16 @@ def qna(query: str, database: str = "duckdb"): @main.command("chat") +@data_dir_option @database_options -def run_chat(database: str = "duckdb"): +def run_chat(data_dir: Union[str, Path], database: str = "duckdb"): """Run the kg-chat's chat command.""" if database == "neo4j": - impl = Neo4jImplementation() + impl = Neo4jImplementation(data_dir=data_dir) kgc = KnowledgeGraphChat(impl) kgc.chat() elif database == "duckdb": - impl = DuckDBImplementation() + impl = DuckDBImplementation(data_dir=data_dir) kgc = KnowledgeGraphChat(impl) kgc.chat() else: @@ -137,17 +144,19 @@ def run_chat(database: str = "duckdb"): @main.command("app") @click.option("--debug", is_flag=True, help="Run the app in debug mode.") +@data_dir_option @database_options def run_app( + data_dir: Union[str, Path], debug: bool = False, database: str = "duckdb", ): """Run the kg-chat's chat command.""" if database == "neo4j": - impl = Neo4jImplementation() + impl = Neo4jImplementation(data_dir=data_dir) kgc = KnowledgeGraphChat(impl) elif database == "duckdb": - impl = DuckDBImplementation() + impl = DuckDBImplementation(data_dir=data_dir) kgc = KnowledgeGraphChat(impl) else: raise ValueError(f"Database {database} not supported.") diff --git a/src/kg_chat/constants.py b/src/kg_chat/constants.py index 090ef56..50b329c 100644 --- a/src/kg_chat/constants.py +++ b/src/kg_chat/constants.py @@ -5,7 +5,6 @@ PWD = Path(__file__).parent.resolve() PROJ_DIR = PWD.parents[1] -DATA_DIR = PROJ_DIR / "data" GRAPH_OUTPUT_DIR = PWD / "graph_output" ASSETS_DIR = PWD / "assets" TEST_DIR = PROJ_DIR / "tests" @@ -21,6 +20,3 @@ NEO4J_URI = "bolt://localhost:7687" NEO4J_USERNAME = "neo4j" NEO4J_PASSWORD = "password" - - -DATABASE_DIR = DATA_DIR / "database" diff --git a/src/kg_chat/implementations/duckdb_implementation.py b/src/kg_chat/implementations/duckdb_implementation.py index f33a1d2..4516ddb 100644 --- a/src/kg_chat/implementations/duckdb_implementation.py +++ b/src/kg_chat/implementations/duckdb_implementation.py @@ -13,8 +13,6 @@ from sqlalchemy import create_engine from kg_chat.constants import ( - DATA_DIR, - DATABASE_DIR, OPEN_AI_MODEL, OPENAI_KEY, ) @@ -25,10 +23,15 @@ class DuckDBImplementation(DatabaseInterface): """Implementation of the DatabaseInterface for DuckDB.""" - def __init__(self): + def __init__(self, data_dir: Union[Path, str]): """Initialize the DuckDB database and the Langchain components.""" + if not data_dir: + raise ValueError("Data directory is required. This typically contains the KGX tsv files.") self.safe_mode = True - self.database_path = DATABASE_DIR / "kg_chat.db" + self.data_dir = Path(data_dir) + self.database_path = self.data_dir / "database/kg_chat.db" + if not self.database_path.exists(): + self.database_path.parent.mkdir(parents=True, exist_ok=True) self.conn = duckdb.connect(database=str(self.database_path)) self.llm = ChatOpenAI(model=OPEN_AI_MODEL, temperature=0, api_key=OPENAI_KEY) self.engine = create_engine(f"duckdb:///{self.database_path}") @@ -113,7 +116,7 @@ def execute_query_using_langchain(self, prompt: str): result = self.agent.invoke(prompt) return result["output"] - def load_kg(self, data_dir: Union[Path, str] = DATA_DIR): + def load_kg(self): """Load the Knowledge Graph into the database.""" def _load_kg(): @@ -189,9 +192,9 @@ def _load_kg(): return self.execute_unsafe_operation(_load_kg) - def _import_nodes(self, data_dir: Union[Path, str] = DATA_DIR): + def _import_nodes(self): columns_of_interest = ["id", "category", "name"] - nodes_filepath = Path(data_dir) / "nodes.tsv" + nodes_filepath = Path(self.data_dir) / "nodes.tsv" with open(nodes_filepath, "r") as nodes_file: header_line = nodes_file.readline().strip().split("\t") @@ -211,9 +214,9 @@ def _import_nodes(self, data_dir: Union[Path, str] = DATA_DIR): # Load data from temporary file into DuckDB self.conn.execute(f"COPY nodes FROM '{temp_nodes_file.name}' (DELIMITER '\t', HEADER)") - def _import_edges(self, data_dir: Union[Path, str] = DATA_DIR): + def _import_edges(self): edge_column_of_interest = ["subject", "predicate", "object"] - edges_filepath = Path(data_dir) / "edges.tsv" + edges_filepath = Path(self.data_dir) / "edges.tsv" with open(edges_filepath, "r") as edges_file: header_line = edges_file.readline().strip().split("\t") column_indexes = {col: idx for idx, col in enumerate(header_line) if col in edge_column_of_interest} diff --git a/src/kg_chat/implementations/neo4j_implementation.py b/src/kg_chat/implementations/neo4j_implementation.py index 056ebd0..2216196 100644 --- a/src/kg_chat/implementations/neo4j_implementation.py +++ b/src/kg_chat/implementations/neo4j_implementation.py @@ -12,7 +12,6 @@ from neo4j import GraphDatabase from kg_chat.constants import ( - DATA_DIR, DATALOAD_BATCH_SIZE, NEO4J_PASSWORD, NEO4J_URI, @@ -27,11 +26,20 @@ class Neo4jImplementation(DatabaseInterface): """Implementation of the DatabaseInterface for Neo4j.""" - def __init__(self, uri: str = NEO4J_URI, username: str = NEO4J_USERNAME, password: str = NEO4J_PASSWORD): + def __init__( + self, + data_dir: Union[str, Path], + uri: str = NEO4J_URI, + username: str = NEO4J_USERNAME, + password: str = NEO4J_PASSWORD, + ): """Initialize the Neo4j database and the Langchain components.""" + if not data_dir: + raise ValueError("Data directory is required. This typically contains the KGX tsv files.") self.driver = GraphDatabase.driver(uri, auth=(username, password)) self.graph = Neo4jGraph(url=uri, username=username, password=password) self.llm = ChatOpenAI(model=OPEN_AI_MODEL, temperature=0, api_key=OPENAI_KEY) + self.data_dir = Path(data_dir) self.chain = GraphCypherQAChain.from_llm( graph=self.graph, @@ -178,10 +186,10 @@ def show_schema(self): result = session.read_transaction(lambda tx: list(tx.run("CALL db.schema.visualization()"))) pprint(result) - def load_kg(self, data_dir: Union[str, Path] = DATA_DIR, block_size: int = DATALOAD_BATCH_SIZE): + def load_kg(self, block_size: int = DATALOAD_BATCH_SIZE): """Load the Knowledge Graph into the Neo4j database.""" - nodes_filepath = data_dir / "nodes.tsv" - edges_filepath = data_dir / "edges.tsv" + nodes_filepath = self.data_dir / "nodes.tsv" + edges_filepath = self.data_dir / "edges.tsv" def _load_kg(): # Clear the existing database diff --git a/tests/test_duckdb_impl.py b/tests/test_duckdb_impl.py index 72a68a0..65f86d5 100644 --- a/tests/test_duckdb_impl.py +++ b/tests/test_duckdb_impl.py @@ -3,19 +3,20 @@ from unittest.mock import call import pytest +from kg_chat.constants import TESTS_INPUT_DIR from kg_chat.implementations import DuckDBImplementation @pytest.fixture def db_impl(): """Fixture to initialize DuckDBImplementation.""" - return DuckDBImplementation() + return DuckDBImplementation(data_dir=TESTS_INPUT_DIR) # TODO # def test_init(mocker): # mock_connect = mocker.patch('duckdb.connect', return_value=MagicMock()) -# db_impl = DuckDBImplementation() +# db_impl = DuckDBImplementation(data_dir=TESTS_INPUT_DIR) # assert db_impl.safe_mode is True # assert db_impl.conn is not None # assert db_impl.llm is not None diff --git a/tests/test_neo4j_impl.py b/tests/test_neo4j_impl.py index 27fa832..0f10bcf 100644 --- a/tests/test_neo4j_impl.py +++ b/tests/test_neo4j_impl.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest +from kg_chat.constants import TESTS_INPUT_DIR from kg_chat.implementations import Neo4jImplementation @@ -19,7 +20,7 @@ def neo4j_impl(mock_from_llm, mock_chat_openai, mock_neo4j_graph, mock_driver): mock_chat_openai.return_value = MagicMock() mock_from_llm.return_value = MagicMock() - return Neo4jImplementation() + return Neo4jImplementation(data_dir=TESTS_INPUT_DIR) def test_toggle_safe_mode(neo4j_impl): @@ -43,7 +44,7 @@ def test_execute_query(mock_driver): mock_transaction = MagicMock() mock_driver.return_value.session.return_value.__enter__.return_value = mock_session mock_session.read_transaction.return_value = mock_transaction - neo4j_impl = Neo4jImplementation() + neo4j_impl = Neo4jImplementation(data_dir=TESTS_INPUT_DIR) query = "MATCH (n) RETURN n" result = neo4j_impl.execute_query(query) @@ -75,7 +76,7 @@ def test_clear_database(mock_driver): """Test clearing the database.""" mock_session = MagicMock() mock_driver.return_value.session.return_value.__enter__.return_value = mock_session - neo4j_impl = Neo4jImplementation() + neo4j_impl = Neo4jImplementation(data_dir=TESTS_INPUT_DIR) neo4j_impl.clear_database() mock_session.write_transaction.assert_called_once() @@ -85,7 +86,7 @@ def test_ensure_index(mock_driver): """Test ensuring that the index on :Node(id) exists.""" mock_session = MagicMock() mock_driver.return_value.session.return_value.__enter__.return_value = mock_session - neo4j_impl = Neo4jImplementation() + neo4j_impl = Neo4jImplementation(data_dir=TESTS_INPUT_DIR) neo4j_impl.ensure_index() mock_session.write_transaction.assert_called_once() @@ -115,7 +116,7 @@ def test_create_nodes(mock_driver): mock_session = MagicMock() mock_driver.return_value.session.return_value.__enter__.return_value = mock_session nodes = [{"id": "1", "category": "Person", "label": "John"}] - neo4j_impl = Neo4jImplementation() + neo4j_impl = Neo4jImplementation(data_dir=TESTS_INPUT_DIR) neo4j_impl.create_nodes(nodes) mock_session.write_transaction.assert_called_once() @@ -126,7 +127,7 @@ def test_create_edges(mock_driver): mock_session = MagicMock() mock_driver.return_value.session.return_value.__enter__.return_value = mock_session edges = [{"subject": "1", "predicate": "KNOWS", "object": "2"}] - neo4j_impl = Neo4jImplementation() + neo4j_impl = Neo4jImplementation(data_dir=TESTS_INPUT_DIR) neo4j_impl.create_edges(edges) mock_session.write_transaction.assert_called_once() @@ -138,7 +139,7 @@ def test_show_schema(mock_driver): mock_driver.return_value.session.return_value.__enter__.return_value = mock_session mock_transaction = MagicMock() mock_session.read_transaction.return_value = mock_transaction - neo4j_impl = Neo4jImplementation() + neo4j_impl = Neo4jImplementation(data_dir=TESTS_INPUT_DIR) neo4j_impl.show_schema() mock_session.read_transaction.assert_called_once()