From fb95086cef8af70c0dfda150cf3dcadda3acc642 Mon Sep 17 00:00:00 2001 From: "J.C. Zhong" Date: Thu, 22 Jun 2023 23:19:41 +0000 Subject: [PATCH 1/3] feat: add LLM powered text2sql support --- package.json | 2 +- querybook/config/querybook_public_config.yaml | 2 +- querybook/server/datasources/ai_assistant.py | 15 ++ .../server/lib/ai_assistant/ai_assistant.py | 19 ++ .../assistants/openai_assistant.py | 74 +++++- .../lib/ai_assistant/base_ai_assistant.py | 46 +++- .../components/AIAssistant/AIModeSelector.tsx | 41 +++ .../AIAssistant/QueryGenerationButton.tsx | 59 +++++ .../AIAssistant/QueryGenerationModal.scss | 60 +++++ .../AIAssistant/QueryGenerationModal.tsx | 233 ++++++++++++++++++ .../components/AIAssistant/TableSelector.tsx | 113 +++++++++ .../webapp/components/DataDoc/DataDoc.scss | 6 +- .../DataDocQueryCell/DataDocQueryCell.scss | 6 + .../DataDocQueryCell/DataDocQueryCell.tsx | 13 + .../TranspileQueryModal/QueryComparison.tsx | 46 ++-- querybook/webapp/stylesheets/_utilities.scss | 4 + 16 files changed, 711 insertions(+), 28 deletions(-) create mode 100644 querybook/webapp/components/AIAssistant/AIModeSelector.tsx create mode 100644 querybook/webapp/components/AIAssistant/QueryGenerationButton.tsx create mode 100644 querybook/webapp/components/AIAssistant/QueryGenerationModal.scss create mode 100644 querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx create mode 100644 querybook/webapp/components/AIAssistant/TableSelector.tsx diff --git a/package.json b/package.json index 3e09b5c57..20842d0f8 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "querybook", - "version": "3.24.0", + "version": "3.25.0", "description": "A Big Data Webapp", "private": true, "scripts": { diff --git a/querybook/config/querybook_public_config.yaml b/querybook/config/querybook_public_config.yaml index 2e7aa1383..466de81fd 100644 --- a/querybook/config/querybook_public_config.yaml +++ b/querybook/config/querybook_public_config.yaml @@ -6,7 +6,7 @@ ai_assistant: enabled: true query_generation: - enabled: false + enabled: true query_auto_fix: enabled: true diff --git a/querybook/server/datasources/ai_assistant.py b/querybook/server/datasources/ai_assistant.py index 2b85882d0..7614ef977 100644 --- a/querybook/server/datasources/ai_assistant.py +++ b/querybook/server/datasources/ai_assistant.py @@ -22,3 +22,18 @@ def query_auto_fix(query_execution_id): ) return Response(res_stream, mimetype="text/event-stream") + + +@register("/ai/generate_query/", custom_response=True) +def generate_sql_query( + query_engine_id: int, tables: list[str], question: str, data_cell_id: int = None +): + res_stream = ai_assistant.generate_sql_query( + query_engine_id=query_engine_id, + tables=tables, + question=question, + data_cell_id=data_cell_id, + user_id=current_user.id, + ) + + return Response(res_stream, mimetype="text/event-stream") diff --git a/querybook/server/lib/ai_assistant/ai_assistant.py b/querybook/server/lib/ai_assistant/ai_assistant.py index 7a74f431e..d7c882baf 100644 --- a/querybook/server/lib/ai_assistant/ai_assistant.py +++ b/querybook/server/lib/ai_assistant/ai_assistant.py @@ -31,3 +31,22 @@ def query_auto_fix(self, query_execution_id, user_id=None): "user_id": user_id, }, ) + + def generate_sql_query( + self, + query_engine_id: int, + tables: list[str], + question: str, + data_cell_id: int = None, + user_id=None, + ): + return self._get_streaming_result( + self._assisant.generate_sql_query, + { + "query_engine_id": query_engine_id, + "tables": tables, + "question": question, + "data_cell_id": data_cell_id, + "user_id": user_id, + }, + ) diff --git a/querybook/server/lib/ai_assistant/assistants/openai_assistant.py b/querybook/server/lib/ai_assistant/assistants/openai_assistant.py index e04395159..3fdb1fb3c 100644 --- a/querybook/server/lib/ai_assistant/assistants/openai_assistant.py +++ b/querybook/server/lib/ai_assistant/assistants/openai_assistant.py @@ -66,6 +66,11 @@ def query_auto_fix_prompt_template(self) -> str: "{error}\n\n" "===Table schemas\n" "{table_schemas}\n\n" + "===Response format\n" + "<@key-1@>\n" + "value-1\n\n" + "<@key-2@>\n" + "value-2\n\n" "===Response restrictions\n" "1. Only include SQL queries in the fixed_query section, no additional comments or information.\n" "2. If there isn't enough information or context to address the query error, you may leave the fixed_query section blank or provide a general suggestion instead.\n" @@ -83,8 +88,48 @@ def query_auto_fix_prompt_template(self) -> str: [system_message_prompt, human_message_prompt] ) - def generate_sql_query(self): - pass + @property + def generate_sql_query_prompt_template(self) -> str: + system_message_prompt = SystemMessage( + content=( + "You are a SQL expert that can help generating SQL query.\n\n" + "Please follow the format below for your response:\n" + "<@key-1@>\n" + "value-1\n\n" + "<@key-2@>\n" + "value-2\n\n" + ) + ) + human_template = ( + "Please help to generate a new SQL query or edit the original query for below question based ONLY on the given context. \n\n" + "===SQL Dialect\n" + "{dialect}\n\n" + "===Tables\n" + "{table_schemas}\n\n" + "===Original Query\n" + "{original_query}\n\n" + "===Question\n" + "{question}\n\n" + "===Response format\n" + "<@key-1@>\n" + "value-1\n\n" + "<@key-2@>\n" + "value-2\n\n" + "===Response Restrictions\n" + "1. If there is enough information and context to generate/edit the query, please respond only with the new query without any explanation.\n" + "2. If there isn't enough information or context to generate/edit the query, provide an explanation for the missing context.\n" + "===Example Response:\n" + "Example 1: Insufficient Context\n" + "<@explanation@>\n" + "An explanation of the missing context is provided here.\n\n" + "Example 2: Query Generation Possible\n" + "<@query@>\n" + "Generated SQL query based on provided context is provided here.\n\n" + ) + human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) + return ChatPromptTemplate.from_messages( + [system_message_prompt, human_message_prompt] + ) def _generate_title_from_query( self, query, stream=True, callback_handler=None, user_id=None @@ -122,3 +167,28 @@ def _query_auto_fix( ) ai_message = chat(messages) return ai_message.content + + def _generate_sql_query( + self, + language: str, + table_schemas: str, + question: str, + original_query: str, + stream, + callback_handler, + user_id=None, + ): + """Generate SQL query using OpenAI's chat model.""" + messages = self.generate_sql_query_prompt_template.format_prompt( + dialect=language, + question=question, + table_schemas=table_schemas, + original_query=original_query, + ).to_messages() + chat = ChatOpenAI( + **self._config, + streaming=stream, + callback_manager=CallbackManager([callback_handler]), + ) + ai_message = chat(messages) + return ai_message.content diff --git a/querybook/server/lib/ai_assistant/base_ai_assistant.py b/querybook/server/lib/ai_assistant/base_ai_assistant.py index ccc623f53..cbaa783c8 100644 --- a/querybook/server/lib/ai_assistant/base_ai_assistant.py +++ b/querybook/server/lib/ai_assistant/base_ai_assistant.py @@ -10,6 +10,8 @@ from lib.logger import get_logger from logic import query_execution as qe_logic from lib.query_analysis.lineage import process_query +from logic import admin as admin_logic +from logic import datadoc as datadoc_logic from logic import metastore as m_logic from models.query_execution import QueryExecution @@ -152,9 +154,49 @@ def _get_query_execution_error(self, query_execution: QueryExecution) -> str: return error[:1000] - @abstractmethod + @catch_error + @with_session def generate_sql_query( - self, metastore_id: int, query_engine_id: int, question: str, tables: list[str] + self, + query_engine_id: int, + tables: list[str], + question: str, + data_cell_id: int = None, + stream=True, + callback_handler: ChainStreamHandler = None, + user_id=None, + session=None, + ): + query_engine = admin_logic.get_query_engine_by_id( + query_engine_id, session=session + ) + table_schemas = self._generate_table_schema_prompt( + metastore_id=query_engine.metastore_id, table_names=tables, session=session + ) + + data_cell = datadoc_logic.get_data_cell_by_id(data_cell_id, session=session) + original_query = data_cell.context if data_cell else None + + return self._generate_sql_query( + language=query_engine.language, + table_schemas=table_schemas, + question=question, + original_query=original_query, + stream=stream, + callback_handler=callback_handler, + user_id=user_id, + ) + + @abstractmethod + def _generate_sql_query( + self, + language: str, + table_schemas: str, + question: str, + original_query: str = None, + stream=True, + callback_handler: ChainStreamHandler = None, + user_id=None, ): raise NotImplementedError() diff --git a/querybook/webapp/components/AIAssistant/AIModeSelector.tsx b/querybook/webapp/components/AIAssistant/AIModeSelector.tsx new file mode 100644 index 000000000..05294f9e6 --- /dev/null +++ b/querybook/webapp/components/AIAssistant/AIModeSelector.tsx @@ -0,0 +1,41 @@ +import React from 'react'; + +import { Dropdown } from 'ui/Dropdown/Dropdown'; +import { ListMenu } from 'ui/Menu/ListMenu'; + +export enum AIMode { + GENERATE = 'GENERATE', + EDIT = 'EDIT', +} + +interface IAIModeSelectorProps { + aiMode: AIMode; + aiModes: AIMode[]; + onModeSelect: (mode: AIMode) => any; +} + +export const AIModeSelector: React.FC = ({ + aiMode, + aiModes, + onModeSelect, +}) => { + const engineItems = aiModes.map((mode) => ({ + name: {mode}, + onClick: onModeSelect.bind(null, mode), + checked: aiMode === mode, + })); + + return ( +
{aiMode}
} + layout={['bottom', 'left']} + className="engine-selector-dropdown" + > + {engineItems.length > 1 && ( +
+ +
+ )} +
+ ); +}; diff --git a/querybook/webapp/components/AIAssistant/QueryGenerationButton.tsx b/querybook/webapp/components/AIAssistant/QueryGenerationButton.tsx new file mode 100644 index 000000000..614624de1 --- /dev/null +++ b/querybook/webapp/components/AIAssistant/QueryGenerationButton.tsx @@ -0,0 +1,59 @@ +import React, { useState } from 'react'; + +import PublicConfig from 'config/querybook_public_config.yaml'; +import { IQueryEngine } from 'const/queryEngine'; +import { IconButton } from 'ui/Button/IconButton'; + +import { QueryGenerationModal } from './QueryGenerationModal'; + +const AIAssistantConfig = PublicConfig.ai_assistant; + +interface IProps { + dataCellId: number; + query: string; + engineId: number; + queryEngines: IQueryEngine[]; + queryEngineById: Record; + onUpdateQuery?: (query: string) => void; + onUpdateEngineId: (engineId: number) => void; +} + +export const QueryGenerationButton = ({ + dataCellId, + query = '', + engineId, + queryEngines, + queryEngineById, + onUpdateQuery, + onUpdateEngineId, +}: IProps) => { + const [show, setShow] = useState(false); + + return ( + <> + {AIAssistantConfig.enabled && + AIAssistantConfig.query_generation.enabled && ( + setShow(true)} + /> + )} + {show && ( + setShow(false)} + /> + )} + + ); +}; diff --git a/querybook/webapp/components/AIAssistant/QueryGenerationModal.scss b/querybook/webapp/components/AIAssistant/QueryGenerationModal.scss new file mode 100644 index 000000000..f1392d1bf --- /dev/null +++ b/querybook/webapp/components/AIAssistant/QueryGenerationModal.scss @@ -0,0 +1,60 @@ +.QueryGenerationModal { + .title { + font-size: var(--text-size); + font-weight: var(--bold-font); + margin-bottom: var(--margin); + } + + .Modal-box { + .Modal-content { + min-height: 40vh; + } + + .Modal-bottom { + margin-top: 0 !important; + } + } + .question-bar { + flex: 1; + display: flex; + flex-direction: row; + align-items: center; + background-color: var(--bg-light); + border-radius: var(--border-radius-sm); + margin-top: 24px; + + &:hover { + background-color: var(--bg-hover); + } + + .stars-icon { + position: relative; + color: var(--icon); + padding: 8px; + + display: flex; + align-items: center; + } + + .ai-mode { + color: var(--color-pink-dark); + } + + .DebouncedInput { + flex: 1; + + input { + line-height: 1.8rem; + border-radius: 0; + background-color: transparent; + } + } + } + + .action-buttons { + display: flex; + flex-direction: row; + justify-content: flex-end; + margin-top: var(--margin); + } +} diff --git a/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx new file mode 100644 index 000000000..2a8ab4e4f --- /dev/null +++ b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx @@ -0,0 +1,233 @@ +import React, { useEffect, useState } from 'react'; + +import { QueryEngineSelector } from 'components/QueryRunButton/QueryRunButton'; +import { QueryComparison } from 'components/TranspileQueryModal/QueryComparison'; +import { IQueryEngine } from 'const/queryEngine'; +import { StreamStatus, useStream } from 'hooks/useStream'; +import { TableToken } from 'lib/sql-helper/sql-lexer'; +import { matchKeyPress } from 'lib/utils/keyboard'; +import { analyzeCode } from 'lib/web-worker'; +import { Button } from 'ui/Button/Button'; +import { DebouncedInput } from 'ui/DebouncedInput/DebouncedInput'; +import { Icon } from 'ui/Icon/Icon'; +import { Message } from 'ui/Message/Message'; +import { Modal } from 'ui/Modal/Modal'; +import { StyledText } from 'ui/StyledText/StyledText'; + +import { AIMode, AIModeSelector } from './AIModeSelector'; +import { TableSelector } from './TableSelector'; + +import './QueryGenerationModal.scss'; + +interface IProps { + dataCellId: number; + query: string; + engineId: number; + queryEngines: IQueryEngine[]; + queryEngineById: Record; + onUpdateQuery?: (query: string) => void; + onUpdateEngineId: (engineId: number) => void; + onHide: () => void; +} + +const useTablesInQuery = (query, language) => { + const [tables, setTables] = useState([]); + + useEffect(() => { + analyzeCode(query, 'autocomplete', language).then((codeAnalysis) => { + const tableReferences: TableToken[] = [].concat.apply( + [], + Object.values(codeAnalysis?.lineage.references ?? {}) + ); + setTables( + tableReferences.map(({ schema, name }) => `${schema}.${name}`) + ); + }); + }, [query, language]); + + return tables; +}; + +export const QueryGenerationModal = ({ + dataCellId, + query = '', + engineId, + queryEngines, + queryEngineById, + onUpdateQuery, + onUpdateEngineId, + onHide, +}: IProps) => { + const inputRef = React.useRef(null); + const tablesInQuery = useTablesInQuery( + query, + queryEngineById[engineId]?.language + ); + const [question, setQuestion] = useState(''); + const [tables, setTables] = useState(tablesInQuery); + const [aiMode, setAIMode] = useState( + !!query ? AIMode.EDIT : AIMode.GENERATE + ); + + useEffect(() => { + setTables([...new Set([...tablesInQuery, ...tables])]); + }, [tablesInQuery]); + + const { streamStatus, startStream, streamData } = useStream( + '/ds/ai/generate_query/', + { + query_engine_id: engineId, + tables: tables, + question: question, + data_cell_id: aiMode === AIMode.EDIT ? dataCellId : undefined, + } + ); + + const { explanation, query: newQuery } = streamData; + + const onKeyDown = (event: React.KeyboardEvent) => { + if (matchKeyPress(event, 'Enter')) { + startStream(); + inputRef.current.blur(); + } + }; + + const questionBarDOM = ( +
+ + + +
+ +
+ +
+ ); + + const bottomDOM = newQuery && ( +
+
+ ); + + return ( + +
+ +
+ + Please select query engine and table(s) to get started + +
+ +
+ +
+
+
+ + {tables.length > 0 && ( + <> + {questionBarDOM} + {explanation && ( +
{explanation}
+ )} + + {(query || newQuery) && ( +
+ +
+ )} + + )} +
+
+ ); +}; diff --git a/querybook/webapp/components/AIAssistant/TableSelector.tsx b/querybook/webapp/components/AIAssistant/TableSelector.tsx new file mode 100644 index 000000000..25e6ca1c8 --- /dev/null +++ b/querybook/webapp/components/AIAssistant/TableSelector.tsx @@ -0,0 +1,113 @@ +import React, { useCallback, useState } from 'react'; +import AsyncSelect, { Props as AsyncProps } from 'react-select/async'; + +import { + asyncReactSelectStyles, + makeReactSelectStyle, +} from 'lib/utils/react-select'; +import { SearchTableResource } from 'resource/search'; +import { overlayRoot } from 'ui/Overlay/Overlay'; +import { HoverIconTag } from 'ui/Tag/HoverIconTag'; + +interface ITableSelectProps { + metastoreId: number; + tableNames: string[]; + onTableNamesChange: (tableNames: string[]) => void; + usePortalMenu?: boolean; + + selectProps?: Partial>; + + // remove the selected table name after select + clearAfterSelect?: boolean; +} + +export const TableSelector: React.FunctionComponent = ({ + metastoreId, + tableNames, + onTableNamesChange, + usePortalMenu = true, + selectProps = {}, + clearAfterSelect = false, +}) => { + const [searchText, setSearchText] = useState(''); + const asyncSelectProps: Partial> = {}; + const tableReactSelectStyle = React.useMemo( + () => makeReactSelectStyle(usePortalMenu, asyncReactSelectStyles), + [usePortalMenu] + ); + if (usePortalMenu) { + asyncSelectProps.menuPortalTarget = overlayRoot; + } + if (clearAfterSelect) { + asyncSelectProps.value = null; + } + + const loadOptions = useCallback( + async (tableName: string) => { + const { data } = await SearchTableResource.searchConcise({ + metastore_id: metastoreId, + keywords: tableName, + }); + const filteredTableNames = data.results.filter( + (result) => + tableNames.indexOf(`${result.schema}.${result.name}`) === -1 + ); + const tableNameOptions = filteredTableNames.map( + ({ id, schema, name }) => ({ + value: id, + label: `${schema}.${name}`, + }) + ); + return tableNameOptions; + }, + [metastoreId, tableNames] + ); + + return ( +
+ { + const newTableName = option?.label ?? null; + if (newTableName == null) { + onTableNamesChange([]); + return; + } + const newTableNames = tableNames.concat(newTableName); + onTableNamesChange(newTableNames); + }} + loadOptions={loadOptions} + defaultOptions={[]} + inputValue={searchText} + onInputChange={(text) => setSearchText(text)} + noOptionsMessage={() => (searchText ? 'No table found.' : null)} + {...asyncSelectProps} + {...selectProps} + /> + {tableNames.length ? ( +
+ {tableNames.map((tableName) => ( +
+ { + const newTableNames = tableNames.filter( + (name) => name !== tableName + ); + onTableNamesChange(newTableNames); + }} + tooltip={tableName} + tooltipPos="right" + mini + highlighted + light + /> +
+ ))} +
+ ) : null} +
+ ); +}; diff --git a/querybook/webapp/components/DataDoc/DataDoc.scss b/querybook/webapp/components/DataDoc/DataDoc.scss index 6d7fe6fac..08b5521ac 100644 --- a/querybook/webapp/components/DataDoc/DataDoc.scss +++ b/querybook/webapp/components/DataDoc/DataDoc.scss @@ -147,7 +147,8 @@ .additional-dropdown-button, .add-snippet-wrapper, .query-editor-float-buttons-wrapper, - .chart-cell-controls { + .chart-cell-controls, + .QueryGenerationButton { opacity: 0; transition: opacity 0.2s ease-out; } @@ -168,7 +169,8 @@ .additional-dropdown-button, .add-snippet-wrapper, .query-editor-float-buttons-wrapper, - .chart-cell-controls { + .chart-cell-controls, + .QueryGenerationButton { opacity: 1; } } diff --git a/querybook/webapp/components/DataDocQueryCell/DataDocQueryCell.scss b/querybook/webapp/components/DataDocQueryCell/DataDocQueryCell.scss index 7d61a6c17..9b1e9eda0 100644 --- a/querybook/webapp/components/DataDocQueryCell/DataDocQueryCell.scss +++ b/querybook/webapp/components/DataDocQueryCell/DataDocQueryCell.scss @@ -130,5 +130,11 @@ top: 10px; right: 42px; } + + .QueryGenerationButton { + position: absolute; + left: -30px; + padding: 4px; + } } } diff --git a/querybook/webapp/components/DataDocQueryCell/DataDocQueryCell.tsx b/querybook/webapp/components/DataDocQueryCell/DataDocQueryCell.tsx index 4e4535e87..7e9586ee8 100644 --- a/querybook/webapp/components/DataDocQueryCell/DataDocQueryCell.tsx +++ b/querybook/webapp/components/DataDocQueryCell/DataDocQueryCell.tsx @@ -8,6 +8,7 @@ import React from 'react'; import toast from 'react-hot-toast'; import { connect } from 'react-redux'; +import { QueryGenerationButton } from 'components/AIAssistant/QueryGenerationButton'; import { DataDocQueryExecutions } from 'components/DataDocQueryExecutions/DataDocQueryExecutions'; import { QueryCellTitle } from 'components/QueryCellTitle/QueryCellTitle'; import { runQuery, transformQuery } from 'components/QueryComposer/RunQuery'; @@ -728,6 +729,18 @@ class DataDocQueryCellComponent extends React.PureComponent { const editorDOM = !queryCollapsed && (
+ = ({ fromQuery, toQuery, fromQueryTitle, toQueryTitle, disableHighlight, + hideEmptyQuery, }) => { const [addedRanges, removedRanges] = useMemo(() => { - if (disableHighlight) { + if (disableHighlight || (hideEmptyQuery && (!fromQuery || !toQuery))) { return [[], []]; } @@ -53,28 +55,32 @@ export const QueryComparison: React.FC<{ } } return [added, removed]; - }, [fromQuery, toQuery, disableHighlight]); + }, [fromQuery, toQuery, disableHighlight, hideEmptyQuery]); return (
-
- {fromQueryTitle && {fromQueryTitle}} - -
-
- {toQueryTitle && {toQueryTitle}} - -
+ {!(hideEmptyQuery && !fromQuery) && ( +
+ {fromQueryTitle && {fromQueryTitle}} + +
+ )} + {!(hideEmptyQuery && !toQuery) && ( +
+ {toQueryTitle && {toQueryTitle}} + +
+ )}
); }; diff --git a/querybook/webapp/stylesheets/_utilities.scss b/querybook/webapp/stylesheets/_utilities.scss index de26149a4..9c0c1a943 100644 --- a/querybook/webapp/stylesheets/_utilities.scss +++ b/querybook/webapp/stylesheets/_utilities.scss @@ -58,6 +58,10 @@ justify-content: flex-start; } +.flex-wrap { + flex-wrap: wrap; +} + .horizontal-space-between { display: flex; flex-direction: row; From 5e7a822ab5d1884cebaabfa3c6a2fca2ddc0d640 Mon Sep 17 00:00:00 2001 From: "J.C. Zhong" Date: Fri, 23 Jun 2023 01:56:35 +0000 Subject: [PATCH 2/3] comments --- querybook/server/datasources/ai_assistant.py | 5 +- .../server/lib/ai_assistant/ai_assistant.py | 4 +- .../lib/ai_assistant/base_ai_assistant.py | 7 +-- .../AIAssistant/QueryGenerationModal.tsx | 47 ++++++++++--------- ...Selector.tsx => TextToSQLModeSelector.tsx} | 22 ++++----- 5 files changed, 44 insertions(+), 41 deletions(-) rename querybook/webapp/components/AIAssistant/{AIModeSelector.tsx => TextToSQLModeSelector.tsx} (62%) diff --git a/querybook/server/datasources/ai_assistant.py b/querybook/server/datasources/ai_assistant.py index 7614ef977..12db5d875 100644 --- a/querybook/server/datasources/ai_assistant.py +++ b/querybook/server/datasources/ai_assistant.py @@ -3,6 +3,7 @@ from app.datasource import register from lib.ai_assistant import ai_assistant +from logic import datadoc as datadoc_logic @register("/ai/query_title/", custom_response=True) @@ -28,11 +29,13 @@ def query_auto_fix(query_execution_id): def generate_sql_query( query_engine_id: int, tables: list[str], question: str, data_cell_id: int = None ): + data_cell = datadoc_logic.get_data_cell_by_id(data_cell_id) + original_query = data_cell.context if data_cell else None res_stream = ai_assistant.generate_sql_query( query_engine_id=query_engine_id, tables=tables, question=question, - data_cell_id=data_cell_id, + original_query=original_query, user_id=current_user.id, ) diff --git a/querybook/server/lib/ai_assistant/ai_assistant.py b/querybook/server/lib/ai_assistant/ai_assistant.py index d7c882baf..6dc633161 100644 --- a/querybook/server/lib/ai_assistant/ai_assistant.py +++ b/querybook/server/lib/ai_assistant/ai_assistant.py @@ -37,7 +37,7 @@ def generate_sql_query( query_engine_id: int, tables: list[str], question: str, - data_cell_id: int = None, + original_query: str = None, user_id=None, ): return self._get_streaming_result( @@ -46,7 +46,7 @@ def generate_sql_query( "query_engine_id": query_engine_id, "tables": tables, "question": question, - "data_cell_id": data_cell_id, + "original_query": original_query, "user_id": user_id, }, ) diff --git a/querybook/server/lib/ai_assistant/base_ai_assistant.py b/querybook/server/lib/ai_assistant/base_ai_assistant.py index cbaa783c8..f8d883ebc 100644 --- a/querybook/server/lib/ai_assistant/base_ai_assistant.py +++ b/querybook/server/lib/ai_assistant/base_ai_assistant.py @@ -11,7 +11,6 @@ from logic import query_execution as qe_logic from lib.query_analysis.lineage import process_query from logic import admin as admin_logic -from logic import datadoc as datadoc_logic from logic import metastore as m_logic from models.query_execution import QueryExecution @@ -161,7 +160,7 @@ def generate_sql_query( query_engine_id: int, tables: list[str], question: str, - data_cell_id: int = None, + original_query: str = None, stream=True, callback_handler: ChainStreamHandler = None, user_id=None, @@ -173,10 +172,6 @@ def generate_sql_query( table_schemas = self._generate_table_schema_prompt( metastore_id=query_engine.metastore_id, table_names=tables, session=session ) - - data_cell = datadoc_logic.get_data_cell_by_id(data_cell_id, session=session) - original_query = data_cell.context if data_cell else None - return self._generate_sql_query( language=query_engine.language, table_schemas=table_schemas, diff --git a/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx index 2a8ab4e4f..e52796aec 100644 --- a/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx +++ b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx @@ -1,4 +1,5 @@ -import React, { useEffect, useState } from 'react'; +import { uniq } from 'lodash'; +import React, { useCallback, useEffect, useState } from 'react'; import { QueryEngineSelector } from 'components/QueryRunButton/QueryRunButton'; import { QueryComparison } from 'components/TranspileQueryModal/QueryComparison'; @@ -14,8 +15,8 @@ import { Message } from 'ui/Message/Message'; import { Modal } from 'ui/Modal/Modal'; import { StyledText } from 'ui/StyledText/StyledText'; -import { AIMode, AIModeSelector } from './AIModeSelector'; import { TableSelector } from './TableSelector'; +import { TextToSQLMode, TextToSQLModeSelector } from './TextToSQLModeSelector'; import './QueryGenerationModal.scss'; @@ -66,11 +67,11 @@ export const QueryGenerationModal = ({ const [question, setQuestion] = useState(''); const [tables, setTables] = useState(tablesInQuery); const [aiMode, setAIMode] = useState( - !!query ? AIMode.EDIT : AIMode.GENERATE + !!query ? TextToSQLMode.EDIT : TextToSQLMode.GENERATE ); useEffect(() => { - setTables([...new Set([...tablesInQuery, ...tables])]); + setTables(uniq([...tablesInQuery, ...tables])); }, [tablesInQuery]); const { streamStatus, startStream, streamData } = useStream( @@ -79,18 +80,22 @@ export const QueryGenerationModal = ({ query_engine_id: engineId, tables: tables, question: question, - data_cell_id: aiMode === AIMode.EDIT ? dataCellId : undefined, + data_cell_id: + aiMode === TextToSQLMode.EDIT ? dataCellId : undefined, } ); const { explanation, query: newQuery } = streamData; - const onKeyDown = (event: React.KeyboardEvent) => { - if (matchKeyPress(event, 'Enter')) { - startStream(); - inputRef.current.blur(); - } - }; + const onKeyDown = useCallback( + (event: React.KeyboardEvent) => { + if (matchKeyPress(event, 'Enter')) { + startStream(); + inputRef.current.blur(); + } + }, + [startStream] + ); const questionBarDOM = (
@@ -105,12 +110,12 @@ export const QueryGenerationModal = ({ />
- @@ -123,7 +128,7 @@ export const QueryGenerationModal = ({ transparent={false} inputProps={{ placeholder: - aiMode === AIMode.GENERATE + aiMode === TextToSQLMode.GENERATE ? 'Ask AI to generate a new query' : 'Ask AI to edit the query', type: 'text', @@ -135,7 +140,7 @@ export const QueryGenerationModal = ({
); - const bottomDOM = newQuery && ( + const bottomDOM = newQuery && streamStatus === StreamStatus.FINISHED && (
); @@ -213,7 +216,9 @@ export const QueryGenerationModal = ({
any; +interface IProps { + selectedMode: TextToSQLMode; + modes: TextToSQLMode[]; + onModeSelect: (mode: TextToSQLMode) => any; } -export const AIModeSelector: React.FC = ({ - aiMode, - aiModes, +export const TextToSQLModeSelector: React.FC = ({ + selectedMode, + modes, onModeSelect, }) => { - const engineItems = aiModes.map((mode) => ({ + const engineItems = modes.map((mode) => ({ name: {mode}, onClick: onModeSelect.bind(null, mode), - checked: aiMode === mode, + checked: selectedMode === mode, })); return (
{aiMode}
} + customButtonRenderer={() =>
{selectedMode}
} layout={['bottom', 'left']} className="engine-selector-dropdown" > From 5a6b291eaa4e9aa0eb46452a84391096126273cf Mon Sep 17 00:00:00 2001 From: "J.C. Zhong" Date: Fri, 23 Jun 2023 03:28:39 +0000 Subject: [PATCH 3/3] comments --- .../AIAssistant/QueryGenerationModal.scss | 2 +- .../AIAssistant/QueryGenerationModal.tsx | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/querybook/webapp/components/AIAssistant/QueryGenerationModal.scss b/querybook/webapp/components/AIAssistant/QueryGenerationModal.scss index f1392d1bf..d948453cc 100644 --- a/querybook/webapp/components/AIAssistant/QueryGenerationModal.scss +++ b/querybook/webapp/components/AIAssistant/QueryGenerationModal.scss @@ -36,7 +36,7 @@ align-items: center; } - .ai-mode { + .text2sql-mode { color: var(--color-pink-dark); } diff --git a/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx index e52796aec..b59a58ade 100644 --- a/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx +++ b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx @@ -66,7 +66,7 @@ export const QueryGenerationModal = ({ ); const [question, setQuestion] = useState(''); const [tables, setTables] = useState(tablesInQuery); - const [aiMode, setAIMode] = useState( + const [textToSQLMode, setTextToSQLMode] = useState( !!query ? TextToSQLMode.EDIT : TextToSQLMode.GENERATE ); @@ -81,7 +81,7 @@ export const QueryGenerationModal = ({ tables: tables, question: question, data_cell_id: - aiMode === TextToSQLMode.EDIT ? dataCellId : undefined, + textToSQLMode === TextToSQLMode.EDIT ? dataCellId : undefined, } ); @@ -109,15 +109,15 @@ export const QueryGenerationModal = ({ size={18} /> -
+