Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add LLM-powered text2sql support #1276

Merged
merged 3 commits into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "querybook",
"version": "3.24.0",
"version": "3.25.0",
"description": "A Big Data Webapp",
"private": true,
"scripts": {
Expand Down
2 changes: 1 addition & 1 deletion querybook/config/querybook_public_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ ai_assistant:
enabled: true

query_generation:
enabled: false
enabled: true

query_auto_fix:
enabled: true
15 changes: 15 additions & 0 deletions querybook/server/datasources/ai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,18 @@ def query_auto_fix(query_execution_id):
)

return Response(res_stream, mimetype="text/event-stream")


@register("/ai/generate_query/", custom_response=True)
def generate_sql_query(
    query_engine_id: int, tables: list[str], question: str, data_cell_id: int = None
):
    """Stream an LLM-generated SQL query back to the client.

    Args:
        query_engine_id: Engine whose dialect/metastore the query targets.
        tables: Fully qualified table names to provide as schema context.
        question: Natural-language description of the desired query.
        data_cell_id: Optional data-doc cell id; when given, the cell's current
            query is used as the "original query" for edit-style generation.
            NOTE(review): a GET query string has a length limit, so the original
            query is looked up server-side from the cell instead of being
            passed in the request.

    Returns:
        A streaming ``text/event-stream`` Response of generation tokens.
    """
    res_stream = ai_assistant.generate_sql_query(
        query_engine_id=query_engine_id,
        tables=tables,
        question=question,
        data_cell_id=data_cell_id,
        user_id=current_user.id,
    )

    return Response(res_stream, mimetype="text/event-stream")
19 changes: 19 additions & 0 deletions querybook/server/lib/ai_assistant/ai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,22 @@ def query_auto_fix(self, query_execution_id, user_id=None):
"user_id": user_id,
},
)

def generate_sql_query(
    self,
    query_engine_id: int,
    tables: list[str],
    question: str,
    data_cell_id: int = None,
    user_id=None,
):
    """Forward a text-to-SQL request to the underlying assistant as a stream.

    Packages the call arguments and hands them, together with the
    assistant's ``generate_sql_query`` callable, to the streaming helper.
    """
    request_kwargs = {
        "query_engine_id": query_engine_id,
        "tables": tables,
        "question": question,
        "data_cell_id": data_cell_id,
        "user_id": user_id,
    }
    # NOTE(review): "_assisant" (sic) matches the attribute name used by the
    # sibling methods in this class — confirm before renaming.
    return self._get_streaming_result(
        self._assisant.generate_sql_query, request_kwargs
    )
74 changes: 72 additions & 2 deletions querybook/server/lib/ai_assistant/assistants/openai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ def query_auto_fix_prompt_template(self) -> str:
"{error}\n\n"
"===Table schemas\n"
"{table_schemas}\n\n"
"===Response format\n"
"<@key-1@>\n"
"value-1\n\n"
"<@key-2@>\n"
"value-2\n\n"
"===Response restrictions\n"
"1. Only include SQL queries in the fixed_query section, no additional comments or information.\n"
"2. If there isn't enough information or context to address the query error, you may leave the fixed_query section blank or provide a general suggestion instead.\n"
Expand All @@ -83,8 +88,48 @@ def query_auto_fix_prompt_template(self) -> str:
[system_message_prompt, human_message_prompt]
)

@property
def generate_sql_query_prompt_template(self) -> ChatPromptTemplate:
    """Chat prompt template for text-to-SQL generation.

    Builds a system message describing the expected ``<@key@>``-delimited
    response format, plus a human message template with placeholders for
    ``dialect``, ``table_schemas``, ``original_query`` and ``question``.

    Returns:
        ChatPromptTemplate combining the system and human messages.
        (Annotation fixed: the previous ``-> str`` did not match the
        actual return value.)
    """
    system_message_prompt = SystemMessage(
        content=(
            "You are a SQL expert that can help generating SQL query.\n\n"
            "Please follow the format below for your response:\n"
            "<@key-1@>\n"
            "value-1\n\n"
            "<@key-2@>\n"
            "value-2\n\n"
        )
    )
    human_template = (
        "Please help to generate a new SQL query or edit the original query for below question based ONLY on the given context. \n\n"
        "===SQL Dialect\n"
        "{dialect}\n\n"
        "===Tables\n"
        "{table_schemas}\n\n"
        "===Original Query\n"
        "{original_query}\n\n"
        "===Question\n"
        "{question}\n\n"
        "===Response format\n"
        "<@key-1@>\n"
        "value-1\n\n"
        "<@key-2@>\n"
        "value-2\n\n"
        "===Response Restrictions\n"
        "1. If there is enough information and context to generate/edit the query, please respond only with the new query without any explanation.\n"
        "2. If there isn't enough information or context to generate/edit the query, provide an explanation for the missing context.\n"
        "===Example Response:\n"
        "Example 1: Insufficient Context\n"
        "<@explanation@>\n"
        "An explanation of the missing context is provided here.\n\n"
        "Example 2: Query Generation Possible\n"
        "<@query@>\n"
        "Generated SQL query based on provided context is provided here.\n\n"
    )
    human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
    return ChatPromptTemplate.from_messages(
        [system_message_prompt, human_message_prompt]
    )

def _generate_title_from_query(
self, query, stream=True, callback_handler=None, user_id=None
Expand Down Expand Up @@ -122,3 +167,28 @@ def _query_auto_fix(
)
ai_message = chat(messages)
return ai_message.content

def _generate_sql_query(
    self,
    language: str,
    table_schemas: str,
    question: str,
    original_query: str,
    stream,
    callback_handler,
    user_id=None,
):
    """Generate SQL query using OpenAI's chat model.

    Formats the text-to-SQL prompt template with the given dialect,
    schemas, original query and question, then invokes a (possibly
    streaming) ChatOpenAI model and returns the message content.
    """
    prompt_messages = self.generate_sql_query_prompt_template.format_prompt(
        dialect=language,
        table_schemas=table_schemas,
        original_query=original_query,
        question=question,
    ).to_messages()

    chat_model = ChatOpenAI(
        **self._config,
        streaming=stream,
        callback_manager=CallbackManager([callback_handler]),
    )
    response_message = chat_model(prompt_messages)
    return response_message.content
46 changes: 44 additions & 2 deletions querybook/server/lib/ai_assistant/base_ai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from lib.logger import get_logger
from logic import query_execution as qe_logic
from lib.query_analysis.lineage import process_query
from logic import admin as admin_logic
from logic import datadoc as datadoc_logic
from logic import metastore as m_logic
from models.query_execution import QueryExecution

Expand Down Expand Up @@ -152,9 +154,49 @@ def _get_query_execution_error(self, query_execution: QueryExecution) -> str:

return error[:1000]

@catch_error
@with_session
def generate_sql_query(
    self,
    query_engine_id: int,
    tables: list[str],
    question: str,
    data_cell_id: int = None,
    stream=True,
    callback_handler: ChainStreamHandler = None,
    user_id=None,
    session=None,
):
    """Resolve engine/schema context and delegate to `_generate_sql_query`.

    Args:
        query_engine_id: Engine providing dialect (`language`) and metastore.
        tables: Table names whose schemas are inlined into the prompt.
        question: Natural-language request from the user.
        data_cell_id: Optional data-doc cell; its `context` (the current
            query text) becomes the original query for edit-style generation.
        stream: Whether the provider should stream tokens.
        callback_handler: Stream handler receiving generated tokens.
        user_id: Requesting user, forwarded to the provider implementation.
        session: DB session injected by `@with_session`.
    """
    query_engine = admin_logic.get_query_engine_by_id(
        query_engine_id, session=session
    )
    table_schemas = self._generate_table_schema_prompt(
        metastore_id=query_engine.metastore_id, table_names=tables, session=session
    )

    # Only hit the datadoc store when a cell id was actually provided;
    # previously the lookup ran unconditionally with a None id.
    original_query = None
    if data_cell_id is not None:
        data_cell = datadoc_logic.get_data_cell_by_id(data_cell_id, session=session)
        original_query = data_cell.context if data_cell else None

    return self._generate_sql_query(
        language=query_engine.language,
        table_schemas=table_schemas,
        question=question,
        original_query=original_query,
        stream=stream,
        callback_handler=callback_handler,
        user_id=user_id,
    )

@abstractmethod
def _generate_sql_query(
    self,
    language: str,
    table_schemas: str,
    question: str,
    original_query: str = None,
    stream=True,
    callback_handler: ChainStreamHandler = None,
    user_id=None,
):
    """Provider-specific text-to-SQL generation hook.

    Concrete assistants (e.g. the OpenAI implementation) must override
    this to produce a query for `question` in the given SQL `language`,
    using `table_schemas` as context and optionally editing
    `original_query`. `stream`/`callback_handler` control token streaming.
    """
    raise NotImplementedError()

Expand Down
41 changes: 41 additions & 0 deletions querybook/webapp/components/AIAssistant/AIModeSelector.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import React from 'react';

import { Dropdown } from 'ui/Dropdown/Dropdown';
import { ListMenu } from 'ui/Menu/ListMenu';

// Mode of the AI query assistant: GENERATE for a brand-new query,
// EDIT for modifying an existing one (see the "generate/edit query"
// button tooltip in QueryGenerationButton).
export enum AIMode {
    GENERATE = 'GENERATE',
    EDIT = 'EDIT',
}

interface IAIModeSelectorProps {
    // Currently selected mode, rendered as the dropdown button label.
    aiMode: AIMode;
    // All modes offered; the menu is only shown when there is more than one.
    aiModes: AIMode[];
    // Invoked with the mode the user picks.
    onModeSelect: (mode: AIMode) => any;
}

// Dropdown for choosing the AI assistant mode (e.g. GENERATE vs EDIT).
// The current mode doubles as the dropdown's button label; the option
// list is rendered only when more than one mode is available.
export const AIModeSelector: React.FC<IAIModeSelectorProps> = ({
    aiMode,
    aiModes,
    onModeSelect,
}) => {
    const engineItems = aiModes.map((mode) => ({
        name: <span>{mode}</span>,
        onClick: onModeSelect.bind(null, mode),
        checked: aiMode === mode,
    }));

    return (
        <Dropdown
            customButtonRenderer={() => <div>{aiMode}</div>}
            layout={['bottom', 'left']}
            className="engine-selector-dropdown"
        >
            {engineItems.length > 1 && (
                <div className="engine-selector-wrapper">
                    <ListMenu items={engineItems} type="select" />
                </div>
            )}
        </Dropdown>
    );
};
59 changes: 59 additions & 0 deletions querybook/webapp/components/AIAssistant/QueryGenerationButton.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import React, { useState } from 'react';

import PublicConfig from 'config/querybook_public_config.yaml';
import { IQueryEngine } from 'const/queryEngine';
import { IconButton } from 'ui/Button/IconButton';

import { QueryGenerationModal } from './QueryGenerationModal';

const AIAssistantConfig = PublicConfig.ai_assistant;

interface IProps {
    // Data-doc cell the generated query belongs to.
    dataCellId: number;
    // Current query text; empty string highlights the button in accent color.
    query: string;
    // Currently selected engine id, forwarded to the generation modal.
    engineId: number;
    queryEngines: IQueryEngine[];
    queryEngineById: Record<number, IQueryEngine>;
    // Optional: receives the generated/edited query text.
    onUpdateQuery?: (query: string) => void;
    onUpdateEngineId: (engineId: number) => void;
}

// Star icon button that opens the AI query-generation modal.
// Rendered only when the AI assistant and its query_generation feature
// are both enabled in the public config.
export const QueryGenerationButton = ({
    dataCellId,
    query = '',
    engineId,
    queryEngines,
    queryEngineById,
    onUpdateQuery,
    onUpdateEngineId,
}: IProps) => {
    const [modalVisible, setModalVisible] = useState(false);

    const featureEnabled =
        AIAssistantConfig.enabled && AIAssistantConfig.query_generation.enabled;

    return (
        <>
            {featureEnabled && (
                <IconButton
                    className="QueryGenerationButton"
                    icon="Stars"
                    size={18}
                    tooltip="AI: generate/edit query"
                    color={!query ? 'accent' : undefined}
                    onClick={() => setModalVisible(true)}
                />
            )}
            {modalVisible && (
                <QueryGenerationModal
                    dataCellId={dataCellId}
                    query={query}
                    engineId={engineId}
                    queryEngines={queryEngines}
                    queryEngineById={queryEngineById}
                    onUpdateQuery={onUpdateQuery}
                    onUpdateEngineId={onUpdateEngineId}
                    onHide={() => setModalVisible(false)}
                />
            )}
        </>
    );
};
60 changes: 60 additions & 0 deletions querybook/webapp/components/AIAssistant/QueryGenerationModal.scss
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Styles for the AI query-generation modal dialog.
.QueryGenerationModal {
    .title {
        font-size: var(--text-size);
        font-weight: var(--bold-font);
        margin-bottom: var(--margin);
    }

    .Modal-box {
        .Modal-content {
            min-height: 40vh;
        }

        .Modal-bottom {
            // Override the default modal bottom spacing.
            margin-top: 0 !important;
        }
    }
    // Horizontal input row: stars icon + mode label + question input.
    .question-bar {
        flex: 1;
        display: flex;
        flex-direction: row;
        align-items: center;
        background-color: var(--bg-light);
        border-radius: var(--border-radius-sm);
        margin-top: 24px;

        &:hover {
            background-color: var(--bg-hover);
        }

        .stars-icon {
            position: relative;
            color: var(--icon);
            padding: 8px;

            display: flex;
            align-items: center;
        }

        .ai-mode {
            color: var(--color-pink-dark);
        }

        .DebouncedInput {
            flex: 1;

            input {
                line-height: 1.8rem;
                border-radius: 0;
                // Blend into the question bar's own background.
                background-color: transparent;
            }
        }
    }

    // Right-aligned row of modal action buttons.
    .action-buttons {
        display: flex;
        flex-direction: row;
        justify-content: flex-end;
        margin-top: var(--margin);
    }
}
Loading