Skip to content

Commit

Permalink
feat(sqllab): Format sql (apache#25344)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinpark authored Nov 3, 2023
1 parent 915aaeb commit 3aab80a
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 4 deletions.
13 changes: 13 additions & 0 deletions superset-frontend/src/SqlLab/actions/sqlLab.js
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,19 @@ export function queryEditorSetSql(queryEditor, sql) {
return { type: QUERY_EDITOR_SET_SQL, queryEditor, sql };
}

export function formatQuery(queryEditor) {
return function (dispatch, getState) {
const { sql } = getUpToDateQuery(getState(), queryEditor);
return SupersetClient.post({
endpoint: `/api/v1/sqllab/format_sql/`,
body: JSON.stringify({ sql }),
headers: { 'Content-Type': 'application/json' },
}).then(({ json }) => {
dispatch(queryEditorSetSql(queryEditor, json.result));
});
};
}

export function queryEditorSetAndSaveSql(targetQueryEditor, sql) {
return function (dispatch, getState) {
const queryEditor = getUpToDateQuery(getState(), targetQueryEditor);
Expand Down
17 changes: 17 additions & 0 deletions superset-frontend/src/SqlLab/actions/sqlLab.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import fetchMock from 'fetch-mock';
import configureMockStore from 'redux-mock-store';
import thunk from 'redux-thunk';
import shortid from 'shortid';
import { waitFor } from '@testing-library/react';
import * as uiCore from '@superset-ui/core';
import * as actions from 'src/SqlLab/actions/sqlLab';
import { LOG_EVENT } from 'src/logger/actions';
Expand Down Expand Up @@ -127,6 +128,22 @@ describe('async actions', () => {
});
});

describe('formatQuery', () => {
const formatQueryEndpoint = 'glob:*/api/v1/sqllab/format_sql/';
const expectedSql = 'SELECT 1';
fetchMock.post(formatQueryEndpoint, { result: expectedSql });

test('posts to the correct url', async () => {
const store = mockStore(initialState);
store.dispatch(actions.formatQuery(query, queryId));
await waitFor(() =>
expect(fetchMock.calls(formatQueryEndpoint)).toHaveLength(1),
);
expect(store.getActions()[0].type).toBe(actions.QUERY_EDITOR_SET_SQL);
expect(store.getActions()[0].sql).toBe(expectedSql);
});
});

describe('fetchQueryResults', () => {
const makeRequest = () => {
const request = actions.fetchQueryResults(query);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,10 @@ const AceEditorWrapper = ({
};

const onChangeText = (text: string) => {
setSql(text);
onChange(text);
if (text !== sql) {
setSql(text);
onChange(text);
}
};

const { data: annotations } = useAnnotations({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export enum KeyboardShortcut {
CMD_OPT_F = 'cmd+opt+f',
CTRL_F = 'ctrl+f',
CTRL_H = 'ctrl+h',
CTRL_SHIFT_F = 'ctrl+shift+f',
}

export const KEY_MAP = {
Expand All @@ -49,6 +50,7 @@ export const KEY_MAP = {
[KeyboardShortcut.CTRL_Q]: userOS === 'Windows' ? t('New tab') : undefined,
[KeyboardShortcut.CTRL_T]: userOS !== 'Windows' ? t('New tab') : undefined,
[KeyboardShortcut.CTRL_P]: t('Previous Line'),
[KeyboardShortcut.CTRL_SHIFT_F]: t('Format SQL'),
// default ace editor shortcuts
[KeyboardShortcut.CMD_F]: userOS === 'MacOS' ? t('Find') : undefined,
[KeyboardShortcut.CTRL_F]: userOS !== 'MacOS' ? t('Find') : undefined,
Expand Down
18 changes: 16 additions & 2 deletions superset-frontend/src/SqlLab/components/SqlEditor/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ import {
scheduleQuery,
setActiveSouthPaneTab,
updateSavedQuery,
formatQuery,
} from 'src/SqlLab/actions/sqlLab';
import {
STATE_TYPE_MAP,
Expand Down Expand Up @@ -305,6 +306,10 @@ const SqlEditor: React.FC<Props> = ({
[ctas, database, defaultQueryLimit, dispatch, queryEditor],
);

const formatCurrentQuery = useCallback(() => {
dispatch(formatQuery(queryEditor));
}, [dispatch, queryEditor]);

const stopQuery = useCallback(() => {
if (latestQuery && ['running', 'pending'].indexOf(latestQuery.state) >= 0) {
dispatch(postStopQuery(latestQuery));
Expand Down Expand Up @@ -384,8 +389,16 @@ const SqlEditor: React.FC<Props> = ({
}),
func: stopQuery,
},
{
name: 'formatQuery',
key: KeyboardShortcut.CTRL_SHIFT_F,
descr: KEY_MAP[KeyboardShortcut.CTRL_SHIFT_F],
func: () => {
formatCurrentQuery();
},
},
];
}, [dispatch, queryEditor.sql, startQuery, stopQuery]);
}, [dispatch, queryEditor.sql, startQuery, stopQuery, formatCurrentQuery]);

const hotkeys = useMemo(() => {
// Get all hotkeys including ace editor hotkeys
Expand Down Expand Up @@ -602,7 +615,7 @@ const SqlEditor: React.FC<Props> = ({
? t('Schedule the query periodically')
: t('You must run the query successfully first');
return (
<Menu css={{ width: theme.gridUnit * 44 }}>
<Menu css={{ width: theme.gridUnit * 50 }}>
<Menu.Item css={{ display: 'flex', justifyContent: 'space-between' }}>
{' '}
<span>{t('Autocomplete')}</span>{' '}
Expand All @@ -622,6 +635,7 @@ const SqlEditor: React.FC<Props> = ({
/>
</Menu.Item>
)}
<Menu.Item onClick={formatCurrentQuery}>{t('Format SQL')}</Menu.Item>
{!isEmpty(scheduledQueriesConf) && (
<Menu.Item>
<ScheduleQueryButton
Expand Down
50 changes: 50 additions & 0 deletions superset/sqllab/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from urllib import parse

import simplejson as json
import sqlparse
from flask import request, Response
from flask_appbuilder import permission_name
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from marshmallow import ValidationError
Expand All @@ -46,6 +48,7 @@
from superset.sqllab.schemas import (
EstimateQueryCostSchema,
ExecutePayloadSchema,
FormatQueryPayloadSchema,
QueryExecutionResponseSchema,
sql_lab_get_results_schema,
SQLLabBootstrapSchema,
Expand Down Expand Up @@ -78,6 +81,7 @@ class SqlLabRestApi(BaseSupersetApi):

estimate_model_schema = EstimateQueryCostSchema()
execute_model_schema = ExecutePayloadSchema()
format_model_schema = FormatQueryPayloadSchema()

apispec_parameter_schemas = {
"sql_lab_get_results_schema": sql_lab_get_results_schema,
Expand Down Expand Up @@ -185,6 +189,52 @@ def estimate_query_cost(self) -> Response:
result = command.run()
return self.response(200, result=result)

@expose("/format_sql/", methods=("POST",))
@statsd_metrics
@protect()
@permission_name("read")
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" f".format",
log_to_statsd=False,
)
def format_sql(self) -> FlaskResponse:
"""Format the SQL query.
---
post:
summary: Format SQL code
requestBody:
description: SQL query
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/FormatQueryPayloadSchema'
responses:
200:
description: Format SQL result
content:
application/json:
schema:
type: object
properties:
result:
type: string
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
403:
$ref: '#/components/responses/403'
500:
$ref: '#/components/responses/500'
"""
try:
model = self.format_model_schema.load(request.json)
result = sqlparse.format(model["sql"], reindent=True, keyword_case="upper")
return self.response(200, result=result)
except ValidationError as error:
return self.response_400(message=error.messages)

@expose("/export/<string:client_id>/")
@protect()
@statsd_metrics
Expand Down
4 changes: 4 additions & 0 deletions superset/sqllab/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class EstimateQueryCostSchema(Schema):
)


class FormatQueryPayloadSchema(Schema):
sql = fields.String(required=True)


class ExecutePayloadSchema(Schema):
database_id = fields.Integer(required=True)
sql = fields.String(required=True)
Expand Down
13 changes: 13 additions & 0 deletions tests/integration_tests/sql_lab/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,19 @@ def test_estimate_valid_request(self):
self.assertDictEqual(resp_data, success_resp)
self.assertEqual(rv.status_code, 200)

def test_format_sql_request(self):
self.login()

data = {"sql": "select 1 from my_table"}
rv = self.client.post(
"/api/v1/sqllab/format_sql/",
json=data,
)
success_resp = {"result": "SELECT 1\nFROM my_table"}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, success_resp)
self.assertEqual(rv.status_code, 200)

@mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
def test_execute_required_params(self):
self.login()
Expand Down

0 comments on commit 3aab80a

Please sign in to comment.