Skip to content

Commit

Permalink
Merge pull request #205 from aws-samples/204-support-clickhouse-as-data-source
Browse files Browse the repository at this point in the history

support clickhouse db as datasource
  • Loading branch information
wzt1001 authored Jul 22, 2024
2 parents bd9b8bf + f752da4 commit fdb4b7f
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 40 deletions.
1 change: 1 addition & 0 deletions application/api/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class ErrorEnum(Enum):
NOT_SUPPORTED = {1001: "Your query statement is currently not supported by the system"}
INVAILD_BEDROCK_MODEL_ID = {1002: f"Invalid bedrock model id.Vaild ids:{BEDROCK_MODEL_IDS}"}
INVAILD_SESSION_ID = {1003: f"Invalid session id."}
PROFILE_NOT_FOUND = {1004: "Profile name not found."}
UNKNOWN_ERROR = {9999: "Unknown error."}

def get_code(self):
Expand Down
2 changes: 2 additions & 0 deletions application/api/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ def ask(question: Question) -> Answer:
log_info = ""

all_profiles = ProfileManagement.get_all_profiles_with_info()
if selected_profile not in all_profiles:
raise BizException(ErrorEnum.PROFILE_NOT_FOUND)
database_profile = all_profiles[selected_profile]

current_nlq_chain = NLQChain(selected_profile)
Expand Down
52 changes: 15 additions & 37 deletions application/nlq/data_access/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class RelationDatabase():
'mysql': 'mysql+pymysql',
'postgresql': 'postgresql+psycopg2',
'redshift': 'postgresql+psycopg2',
'starrocks': 'starrocks'
'starrocks': 'starrocks',
'clickhouse': 'clickhouse',
# Add more mappings here for other databases
}

Expand Down Expand Up @@ -42,43 +43,20 @@ def test_connection(cls, db_type, user, password, host, port, db_name) -> bool:

@classmethod
def get_all_schema_names_by_connection(cls, connection: ConnectConfigEntity):
schemas = []
if connection.db_type == 'postgresql':
db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host,
connection.db_port, connection.db_name)
engine = db.create_engine(db_url)
# with engine.connect() as conn:
# query = text("""
# SELECT nspname AS schema_name
# FROM pg_catalog.pg_namespace
# WHERE nspname !~ '^pg_' AND nspname <> 'information_schema' AND nspname <> 'public'
# AND has_schema_privilege(nspname, 'USAGE');
# """)
#
# # Executing the query
# result = conn.execute(query)
# schemas = [row['schema_name'] for row in result.mappings()]
# print(schemas)
inspector = sqlalchemy.inspect(engine)
schemas = inspector.get_schema_names()
elif connection.db_type == 'redshift':
db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host,
connection.db_port, connection.db_name)
engine = db.create_engine(db_url)
inspector = inspect(engine)
db_type = connection.db_type
db_url = cls.get_db_url(db_type, connection.db_user, connection.db_pwd, connection.db_host, connection.db_port,
connection.db_name)
engine = db.create_engine(db_url)
inspector = inspect(engine)

if db_type == 'postgresql':
schemas = [schema for schema in inspector.get_schema_names() if
schema not in ('pg_catalog', 'information_schema', 'public')]
elif db_type in ('redshift', 'mysql', 'starrocks', 'clickhouse'):
schemas = inspector.get_schema_names()
elif connection.db_type == 'mysql':
db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host,
connection.db_port, connection.db_name)
engine = db.create_engine(db_url)
database_connect = sqlalchemy.inspect(engine)
schemas = database_connect.get_schema_names()
elif connection.db_type == 'starrocks':
db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host,
connection.db_port, connection.db_name)
engine = db.create_engine(db_url)
database_connect = sqlalchemy.inspect(engine)
schemas = database_connect.get_schema_names()
else:
raise ValueError("Unsupported database type")

return schemas

@classmethod
Expand Down
1 change: 1 addition & 0 deletions application/pages/2_🪙_Data_Connection_Management.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
'postgresql': 'PostgreSQL',
'redshift': 'Redshift',
'starrocks': 'StarRocks',
'clickhouse': 'Clickhouse',
}


Expand Down
3 changes: 2 additions & 1 deletion application/requirements-api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ langchain-core~=0.1.30
sqlparse~=0.4.2
pandas==2.0.3
openpyxl
starrocks==1.0.6
starrocks==1.0.6
clickhouse-sqlalchemy==0.2.6
3 changes: 2 additions & 1 deletion application/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ sqlparse~=0.4.2
debugpy
pandas==2.0.3
openpyxl
starrocks==1.0.6
starrocks==1.0.6
clickhouse-sqlalchemy==0.2.6
5 changes: 5 additions & 0 deletions application/utils/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
question.When generating SQL, do not add double quotes or single quotes around table names. Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per StarRocks SQL.
Never query for all columns from a table.""".format(top_k=TOP_K)

CLICKHOUSE_DIALECT_PROMPT_CLAUDE3="""
You are a data analysis expert and proficient in Clickhouse. Given an input question, first create a syntactically correct Clickhouse query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per ClickHouse. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use today() function to get the current date, if the question involves "today". Pay attention to adapted to the table field type. Please follow the clickhouse syntax or function case specifications.If the field alias contains Chinese characters, please use double quotes to Wrap it.""".format(top_k=TOP_K)

AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3 = """You are a Amazon Redshift expert. Given an input question, first create a syntactically correct Redshift query to run, then look at the results of the query and return the answer to the input
question.When generating SQL, do not add double quotes or single quotes around table names. Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL.
Expand Down
4 changes: 3 additions & 1 deletion application/utils/prompts/generate_prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \
DEFAULT_DIALECT_PROMPT, AGENT_COT_EXAMPLE, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3, STARROCKS_DIALECT_PROMPT_CLAUDE3
DEFAULT_DIALECT_PROMPT, AGENT_COT_EXAMPLE, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3, STARROCKS_DIALECT_PROMPT_CLAUDE3, CLICKHOUSE_DIALECT_PROMPT_CLAUDE3
from utils.prompts import guidance_prompt
from utils.prompts import table_prompt
import logging
Expand Down Expand Up @@ -1909,6 +1909,8 @@ def generate_llm_prompt(ddl, hints, prompt_map, search_box, sql_examples=None, n
dialect_prompt = AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3
elif dialect == 'starrocks':
dialect_prompt = STARROCKS_DIALECT_PROMPT_CLAUDE3
elif dialect == 'clickhouse':
dialect_prompt = CLICKHOUSE_DIALECT_PROMPT_CLAUDE3
else:
dialect_prompt = DEFAULT_DIALECT_PROMPT

Expand Down

0 comments on commit fdb4b7f

Please sign in to comment.