Skip to content

Commit

Permalink
Merge pull request #13 from Knowledge-Graph-Hub/rag
Browse files — browse the repository at this point in the history
Enabling RAG
  • Loading branch information
hrshdhgd authored Sep 13, 2024
2 parents f6753d3 + 0358183 commit 33cea84
Show file tree
Hide file tree
Showing 14 changed files with 2,703 additions and 837 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,4 @@ dmypy.json
data/database/*.db
tests/input/database/*.db
knowledge_graph.html
vector_store/chroma.sqlite3
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ LLM-based chatbot that queries and visualizes [`KGX`](https://github.com/biolink
| OpenAI | - `gpt-4o-2024-08-06`<br>- `gpt-4o-mini`<br>- `gpt-4o-mini-2024-07-18`<br>- `gpt-4o-2024-05-13`<br>- `gpt-4o`<br>- `gpt-4-turbo-2024-04-09`<br>- `gpt-4-turbo`<br>- `gpt-4-turbo-preview` |
| Anthropic | - `claude-3-5-sonnet-20240620`<br>- `claude-3-opus-20240229`<br>- `claude-3-sonnet-20240229`<br>- `claude-3-haiku-20240307` |
| Ollama | - `llama3.1` |
| [LBNL-hosted models via CBORG](https://cborg.lbl.gov) | - `lbl/llama-3` (actually 3.1 (405b))<br>- `openai/gpt-4o-mini`<br>- `anthropic/claude-haiku`<br>- `anthropic/claude-sonnet`<br>- `anthropic/claude-opus` |
| [LBNL-hosted models via CBORG](https://cborg.lbl.gov) | - `lbl/cborg-chat:latest`<br>- `lbl/cborg-chat-nano:latest`<br>- `lbl/cborg-coder:latest`<br>- `openai/chatgpt:latest`<br>- `anthropic/claude:latest`<br>- `google/gemini:latest` |


## **:warning:**
Expand Down
2,950 changes: 2,211 additions & 739 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ langchain-anthropic = "^0.1.20"
neo4j = "^5.22.0"
pyvis = "^0.3.2"
dash = "^2.17.1"
duckdb = "^1.0.0"
duckdb = "1.0.0"
duckdb-engine = "^0.13.0"
langchain-community = "^0.2.10"
langchain-chroma = "^0.1.3"


[tool.poetry.group.dev.dependencies]
Expand Down
24 changes: 19 additions & 5 deletions src/kg_chat/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Command line interface for kg-chat."""

import logging
import shutil
from pathlib import Path
from pprint import pprint
from typing import Union
Expand All @@ -9,7 +10,7 @@

from kg_chat import __version__
from kg_chat.app import create_app
from kg_chat.constants import OPEN_AI_MODEL
from kg_chat.constants import OPEN_AI_MODEL, VECTOR_DB_PATH, VECTOR_STORE
from kg_chat.main import KnowledgeGraphChat
from kg_chat.utils import (
get_anthropic_models,
Expand Down Expand Up @@ -42,6 +43,13 @@
help="Directory containing the data.",
required=True,
)
docs_option = click.option(
"--docs",
type=click.Path(exists=True, file_okay=True, dir_okay=True),
help="Path to a document or directory of only documents.",
required=False,
default=None,
)
llm_provider_option = click.option(
"--llm-provider",
type=click.Choice(ALL_AVAILABLE_PROVIDERS, case_sensitive=False),
Expand Down Expand Up @@ -94,13 +102,19 @@ def list_models():
@main.command("import")
@database_options
@data_dir_option
@docs_option
@llm_provider_option
def import_kg(database: str = "duckdb", data_dir: str = None, llm_provider: str = "openai"):
def import_kg(database: str = "duckdb", data_dir: str = None, docs: str = None, llm_provider: str = "openai"):
"""Run the kg-chat's import command."""
if not data_dir:
raise ValueError("Data directory is required. This typically contains the KGX tsv files.")
if docs and VECTOR_DB_PATH.exists():
for item in VECTOR_STORE.iterdir():
if item.is_dir():
shutil.rmtree(item)
else:
item.unlink()

config = get_llm_config(llm_provider)
impl = get_database_impl(database, data_dir=data_dir, llm_config=config)
impl = get_database_impl(database, data_dir=data_dir, doc_dir_or_file=docs, llm_config=config)
impl.load_kg()


Expand Down
6 changes: 5 additions & 1 deletion src/kg_chat/config/llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ class LLMConfig(BaseModel):
class OpenAIConfig(LLMConfig):
"""Configuration for OpenAI LLM model."""

pass
def __init__(self, **data):
"""Initialize the OpenAI LLM configuration."""
super().__init__(**data)
if self.model.startswith("o1"):
self.temperature = 1.0


class OllamaConfig(LLMConfig):
Expand Down
6 changes: 5 additions & 1 deletion src/kg_chat/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@
OPEN_AI_MODEL = "gpt-4o-mini"
ANTHROPIC_MODEL = "claude-3-5-sonnet-20240620"
OLLAMA_MODEL = "llama3.1" #! not all models support tools (tool calling)
CBORG_MODEL = "lbl/llama-3"
CBORG_MODEL = "anthropic/claude:latest"

DATALOAD_BATCH_SIZE = 5000 # Adjust the batch size as needed

# Set environment variables for Neo4j connection
NEO4J_URI = "bolt://localhost:7687"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "password"


VECTOR_STORE = PROJ_DIR / "vector_store"
VECTOR_DB_PATH = VECTOR_STORE / "chroma.sqlite3"
91 changes: 70 additions & 21 deletions src/kg_chat/implementations/duckdb_implementation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Implementation of the DatabaseInterface for DuckDB."""

import logging
import tempfile
import time
from pathlib import Path
Expand All @@ -8,20 +9,29 @@

import duckdb
from langchain.agents.agent import AgentExecutor, AgentType
from langchain.tools.retriever import create_retriever_tool
from langchain_community.agent_toolkits import SQLDatabaseToolkit, create_sql_agent
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_ollama import ChatOllama
from sqlalchemy import Engine, create_engine

from kg_chat.config.llm_config import LLMConfig
from kg_chat.constants import VECTOR_DB_PATH
from kg_chat.interface.database_interface import DatabaseInterface
from kg_chat.utils import get_agent_prompt_template, llm_factory, structure_query
from kg_chat.utils import (
create_vectorstore,
get_exisiting_vectorstore,
llm_factory,
structure_query,
)

logger = logging.getLogger(__name__)


class DuckDBImplementation(DatabaseInterface):
"""Implementation of the DatabaseInterface for DuckDB."""

def __init__(self, data_dir: Union[Path, str], llm_config: LLMConfig):
def __init__(self, data_dir: Union[Path, str], llm_config: LLMConfig, doc_dir_or_file: Union[Path, str] = None):
"""Initialize the DuckDB database and the Langchain components."""
if not data_dir:
raise ValueError("Data directory is required. This typically contains the KGX tsv files.")
Expand All @@ -36,6 +46,24 @@ def __init__(self, data_dir: Union[Path, str], llm_config: LLMConfig):
self.engine: Engine = create_engine(f"duckdb:///{self.database_path}")
self.db = SQLDatabase(self.engine, view_support=True)
self.toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
self.tools = self.toolkit.get_tools()
if VECTOR_DB_PATH.exists() and get_exisiting_vectorstore():
vectorstore = get_exisiting_vectorstore()
retriever = vectorstore.as_retriever(search_kwargs={"k": 1})
rag_tool = create_retriever_tool(retriever, "VectorStoreRetriever", "Vector Store Retriever")

self.tools.append(rag_tool)
elif doc_dir_or_file:
vectorstore = create_vectorstore(doc_dir_or_file=doc_dir_or_file)
retriever = vectorstore.as_retriever(search_kwargs={"k": 1})
rag_tool = create_retriever_tool(retriever, "VectorStoreRetriever", "Vector Store Retriever")

self.tools.append(rag_tool)
else:
logger.info("No vectorstore found or documents provided. Skipping RAG tool creation.")

self.tool_names = [tool.name for tool in self.tools]

self.agent: AgentExecutor = create_sql_agent(
llm=self.llm,
verbose=True,
Expand All @@ -45,7 +73,6 @@ def __init__(self, data_dir: Union[Path, str], llm_config: LLMConfig):
return_intermediate_steps=True,
handle_parsing_errors=True,
),
extra_tools=self.toolkit.get_tools(),
)

def toggle_safe_mode(self, enabled: bool):
Expand Down Expand Up @@ -84,7 +111,13 @@ def _clear():

def get_human_response(self, prompt: str):
"""Get a human response from the database."""
response = self.agent.invoke(prompt)
response = self.agent.invoke(
{
"input": prompt,
"tools": self.tools,
"tool_names": self.tool_names,
}
)
return response["output"]

def get_structured_response(self, prompt: str):
Expand All @@ -93,16 +126,25 @@ def get_structured_response(self, prompt: str):
if "show me" in prompt.lower():
self.llm.format = "json"

structured_query = get_agent_prompt_template().format(
input=prompt,
tools=self.toolkit.get_tools(),
tool_names=[tool.name for tool in self.toolkit.get_tools()],
agent_scratchpad=None,
)
else:
structured_query = {"input": structure_query(prompt)}

response = self.agent.invoke(structured_query)
# tool_names = [tool.name for tool in self.toolkit.get_tools()] + ["kg_retriever"]

# structured_query = get_sql_agent_prompt_template().format(
# input=prompt,
# tools=self.tools,
# tool_names=tool_names,
# agent_scratchpad=None,
# )
# else:
# structured_query = {"input": structure_query(prompt)}

# response = self.agent.invoke(structured_query)
response = self.agent.invoke(
{
"input": structure_query(prompt),
"tools": self.tools,
"tool_names": self.tool_names,
}
)
return response["output"]

def create_edges(self, edges):
Expand Down Expand Up @@ -134,8 +176,15 @@ def show_schema(self):

def execute_query_using_langchain(self, prompt: str):
"""Execute a query against the database using Langchain."""
result = self.agent.invoke(prompt)
return result["output"]
# response = self.agent.invoke(prompt)
response = self.agent.invoke(
{
"input": prompt,
"tools": self.tools,
"tool_names": self.tool_names,
}
)
return response["output"]

def load_kg(self):
"""Load the Knowledge Graph into the database."""
Expand Down Expand Up @@ -245,15 +294,15 @@ def _import_edges(self):
with open(edges_filepath, "r") as edges_file:
header_line = edges_file.readline().strip().split("\t")
column_indexes = {col: idx for idx, col in enumerate(header_line) if col in edge_column_of_interest}

with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_edges_file:
temp_edges_file.write("\t".join(edge_column_of_interest) + "\n")
for line in edges_file:
columns = line.strip().split("\t")
subject = columns[column_indexes["subject"]]
predicate = columns[column_indexes["predicate"]]
object = columns[column_indexes["object"]]
temp_edges_file.write(f"{subject}\t{predicate}\t{object}\n")
if len(columns) == (max(column_indexes.values()) + 1):
subject = columns[column_indexes["subject"]]
predicate = columns[column_indexes["predicate"]]
object = columns[column_indexes["object"]]
temp_edges_file.write(f"{subject}\t{predicate}\t{object}\n")
temp_edges_file.flush()

# Load data from temporary file into DuckDB
Expand Down
Loading

0 comments on commit 33cea84

Please sign in to comment.