diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 89395fb3d197e..388b5783a1752 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -29,7 +29,7 @@ If applicable, add screenshots to help explain your problem. (please complete the following information): -- superset version: [e.g. `v0.29`, `master`, `commit`] +- superset version: `superset version` - python version: `python --version` - node.js version: `node -v` - npm version: `npm -v` diff --git a/superset/assets/spec/javascripts/sqllab/QueryAutoRefresh_spec.jsx b/superset/assets/spec/javascripts/sqllab/QueryAutoRefresh_spec.jsx new file mode 100644 index 0000000000000..527ecd680980b --- /dev/null +++ b/superset/assets/spec/javascripts/sqllab/QueryAutoRefresh_spec.jsx @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import React from 'react'; +import { shallow } from 'enzyme'; +import sinon from 'sinon'; +import thunk from 'redux-thunk'; +import configureStore from 'redux-mock-store'; + +import QueryAutoRefresh from '../../../src/SqlLab/components/QueryAutoRefresh'; +import { initialState, runningQuery } from './fixtures'; + +describe('QueryAutoRefresh', () => { + const middlewares = [thunk]; + const mockStore = configureStore(middlewares); + const sqlLab = { + ...initialState.sqlLab, + queries: { + ryhMUZCGb: runningQuery, + }, + }; + const state = { + ...initialState, + sqlLab, + + }; + const store = mockStore(state); + + const getWrapper = () => ( + shallow(, { + context: { store }, + }).dive()); + + let wrapper; + + it('shouldCheckForQueries', () => { + wrapper = getWrapper(); + expect(wrapper.instance().shouldCheckForQueries()).toBe(true); + }); + + it('setUserOffline', () => { + wrapper = getWrapper(); + const spy = sinon.spy(wrapper.instance().props.actions, 'setUserOffline'); + + // state not changed + wrapper.setState({ + offline: false, + }); + expect(spy.called).toBe(false); + + // state is changed + wrapper.setState({ + offline: true, + }); + expect(spy.callCount).toBe(1); + }); +}); diff --git a/superset/assets/spec/javascripts/sqllab/SqlEditor_spec.jsx b/superset/assets/spec/javascripts/sqllab/SqlEditor_spec.jsx index 046b2e69cf89f..17bb4d83be2ad 100644 --- a/superset/assets/spec/javascripts/sqllab/SqlEditor_spec.jsx +++ b/superset/assets/spec/javascripts/sqllab/SqlEditor_spec.jsx @@ -20,10 +20,19 @@ import React from 'react'; import { shallow } from 'enzyme'; import { defaultQueryEditor, initialState, queries, table } from './fixtures'; +import { + SQL_EDITOR_GUTTER_HEIGHT, + SQL_EDITOR_GUTTER_MARGIN, + SQL_TOOLBAR_HEIGHT, +} from '../../../src/SqlLab/constants'; +import AceEditorWrapper from '../../../src/SqlLab/components/AceEditorWrapper'; import LimitControl 
from '../../../src/SqlLab/components/LimitControl'; +import SouthPane from '../../../src/SqlLab/components/SouthPane'; import SqlEditor from '../../../src/SqlLab/components/SqlEditor'; import SqlEditorLeftBar from '../../../src/SqlLab/components/SqlEditorLeftBar'; +const MOCKED_SQL_EDITOR_HEIGHT = 500; + describe('SqlEditor', () => { const mockedProps = { actions: {}, @@ -40,7 +49,7 @@ describe('SqlEditor', () => { }; beforeAll(() => { - jest.spyOn(SqlEditor.prototype, 'getSqlEditorHeight').mockImplementation(() => 500); + jest.spyOn(SqlEditor.prototype, 'getSqlEditorHeight').mockImplementation(() => MOCKED_SQL_EDITOR_HEIGHT); }); it('is valid', () => { @@ -52,6 +61,33 @@ describe('SqlEditor', () => { const wrapper = shallow(); expect(wrapper.find(SqlEditorLeftBar)).toHaveLength(1); }); + it('render an AceEditorWrapper', () => { + const wrapper = shallow(); + expect(wrapper.find(AceEditorWrapper)).toHaveLength(1); + }); + it('render an SouthPane', () => { + const wrapper = shallow(); + expect(wrapper.find(SouthPane)).toHaveLength(1); + }); + it('does not overflow the editor window', () => { + const wrapper = shallow(); + const totalSize = parseFloat(wrapper.find(AceEditorWrapper).props().height) + + wrapper.find(SouthPane).props().height + + SQL_TOOLBAR_HEIGHT + + (SQL_EDITOR_GUTTER_MARGIN * 2) + + SQL_EDITOR_GUTTER_HEIGHT; + expect(totalSize).toEqual(MOCKED_SQL_EDITOR_HEIGHT); + }); + it('does not overflow the editor window after resizing', () => { + const wrapper = shallow(); + wrapper.setState({ height: 450 }); + const totalSize = parseFloat(wrapper.find(AceEditorWrapper).props().height) + + wrapper.find(SouthPane).props().height + + SQL_TOOLBAR_HEIGHT + + (SQL_EDITOR_GUTTER_MARGIN * 2) + + SQL_EDITOR_GUTTER_HEIGHT; + expect(totalSize).toEqual(450); + }); it('render a LimitControl with default limit', () => { const defaultQueryLimit = 101; const updatedProps = { ...mockedProps, defaultQueryLimit }; diff --git a/superset/assets/spec/javascripts/sqllab/fixtures.js b/superset/assets/spec/javascripts/sqllab/fixtures.js index f43f43f550a68..99e740c338237 100644 --- a/superset/assets/spec/javascripts/sqllab/fixtures.js +++ b/superset/assets/spec/javascripts/sqllab/fixtures.js @@ -223,6 +223,20 @@ export const queries = [ type: 'STRING', }, ], + selected_columns: [ + { + is_date: true, + is_dim: false, + name: 'ds', + type: 'STRING', + }, + { + is_date: false, + is_dim: true, + name: 'gender', + type: 'STRING', + }, + ], data: [{ col1: 0, col2: 1 }, { col1: 2, col2: 3 }], }, }, @@ -264,7 +278,7 @@ export const queryWithBadColumns = { ...queries[0], results: { data: queries[0].results.data, - columns: [ + selected_columns: [ { is_date: true, is_dim: false, @@ -366,11 +380,13 @@ export const runningQuery = { id: 'ryhMUZCGb', progress: 90, state: 'running', + startDttm: Date.now() - 500, }; export const cachedQuery = Object.assign({}, queries[0], { cached: true }); export const initialState = { sqlLab: { + offline: false, alerts: [], queries: {}, databases: {}, diff --git a/superset/assets/src/SqlLab/components/ExploreResultsButton.jsx b/superset/assets/src/SqlLab/components/ExploreResultsButton.jsx index 2394c6713a942..5ac1e2570df1a 100644 --- a/superset/assets/src/SqlLab/components/ExploreResultsButton.jsx +++ b/superset/assets/src/SqlLab/components/ExploreResultsButton.jsx @@ -85,8 +85,8 @@ class ExploreResultsButton extends React.PureComponent { } getColumns() { const props = this.props; - if (props.query && props.query.results && props.query.results.columns) { - return 
props.query.results.columns; + if (props.query && props.query.results && props.query.results.selected_columns) { + return props.query.results.selected_columns; } return []; } @@ -97,7 +97,7 @@ class ExploreResultsButton extends React.PureComponent { const re1 = /^[A-Za-z_]\w*$/; // starts with char or _, then only alphanum const re2 = /__\d+$/; // does not finish with __ and then a number which screams dup col name - return this.props.query.results.columns.map(col => col.name) + return this.props.query.results.selected_columns.map(col => col.name) .filter(col => !re1.test(col) || re2.test(col)); } datasourceName() { diff --git a/superset/assets/src/SqlLab/components/QueryAutoRefresh.jsx b/superset/assets/src/SqlLab/components/QueryAutoRefresh.jsx index 6704d393fe754..3fcab31a64de0 100644 --- a/superset/assets/src/SqlLab/components/QueryAutoRefresh.jsx +++ b/superset/assets/src/SqlLab/components/QueryAutoRefresh.jsx @@ -30,9 +30,20 @@ const MAX_QUERY_AGE_TO_POLL = 21600000; const QUERY_TIMEOUT_LIMIT = 10000; class QueryAutoRefresh extends React.PureComponent { + constructor(props) { + super(props); + this.state = { + offline: props.offline, + }; + } componentWillMount() { this.startTimer(); } + componentDidUpdate(prevProps) { + if (prevProps.offline !== this.state.offline) { + this.props.actions.setUserOffline(this.state.offline); + } + } componentWillUnmount() { this.stopTimer(); } @@ -70,12 +81,12 @@ class QueryAutoRefresh extends React.PureComponent { if (Object.keys(json).length > 0) { this.props.actions.refreshQueries(json); } - this.props.actions.setUserOffline(false); - }).catch(() => { - this.props.actions.setUserOffline(true); - }); + this.setState({ offline: false }); + }).catch(() => { + this.setState({ offline: true }); + }); } else { - this.props.actions.setUserOffline(false); + this.setState({ offline: false }); } } render() { @@ -83,6 +94,7 @@ class QueryAutoRefresh extends React.PureComponent { } } QueryAutoRefresh.propTypes = { + offline: PropTypes.bool.isRequired, queries: PropTypes.object.isRequired, actions: PropTypes.object.isRequired, queriesLastUpdate: PropTypes.number.isRequired, @@ -90,6 +102,7 @@ QueryAutoRefresh.propTypes = { function mapStateToProps({ sqlLab }) { return { + offline: sqlLab.offline, queries: sqlLab.queries, queriesLastUpdate: sqlLab.queriesLastUpdate, }; diff --git a/superset/assets/src/SqlLab/components/ResultSet.jsx b/superset/assets/src/SqlLab/components/ResultSet.jsx index 443eb16b29998..7ac04b3292c51 100644 --- a/superset/assets/src/SqlLab/components/ResultSet.jsx +++ b/superset/assets/src/SqlLab/components/ResultSet.jsx @@ -207,6 +207,9 @@ export default class ResultSet extends React.PureComponent { data = results.data; } if (data && data.length > 0) { + const expandedColumns = results.expanded_columns + ? 
results.expanded_columns.map(col => col.name) + : []; return ( {this.renderControls.bind(this)()} @@ -216,6 +219,7 @@ export default class ResultSet extends React.PureComponent { orderedColumnKeys={results.columns.map(col => col.name)} height={height} filterText={this.state.searchText} + expandedColumns={expandedColumns} /> ); diff --git a/superset/assets/src/SqlLab/components/SqlEditor.jsx b/superset/assets/src/SqlLab/components/SqlEditor.jsx index a4aabb73fd49a..84e51064595dd 100644 --- a/superset/assets/src/SqlLab/components/SqlEditor.jsx +++ b/superset/assets/src/SqlLab/components/SqlEditor.jsx @@ -31,6 +31,7 @@ import { import Split from 'react-split'; import { t } from '@superset-ui/translation'; import debounce from 'lodash/debounce'; +import throttle from 'lodash/throttle'; import Button from '../../components/Button'; import LimitControl from './LimitControl'; @@ -43,17 +44,20 @@ import Timer from '../../components/Timer'; import Hotkeys from '../../components/Hotkeys'; import SqlEditorLeftBar from './SqlEditorLeftBar'; import AceEditorWrapper from './AceEditorWrapper'; -import { STATE_BSSTYLE_MAP } from '../constants'; +import { + STATE_BSSTYLE_MAP, + SQL_EDITOR_GUTTER_HEIGHT, + SQL_EDITOR_GUTTER_MARGIN, + SQL_TOOLBAR_HEIGHT, +} from '../constants'; import RunQueryActionButton from './RunQueryActionButton'; import { FeatureFlag, isFeatureEnabled } from '../../featureFlags'; const SQL_EDITOR_PADDING = 10; -const SQL_TOOLBAR_HEIGHT = 51; -const GUTTER_HEIGHT = 5; -const GUTTER_MARGIN = 3; const INITIAL_NORTH_PERCENT = 30; const INITIAL_SOUTH_PERCENT = 70; const VALIDATION_DEBOUNCE_MS = 600; +const WINDOW_RESIZE_THROTTLE_MS = 100; const propTypes = { actions: PropTypes.object.isRequired, @@ -83,6 +87,8 @@ class SqlEditor extends React.PureComponent { this.state = { autorun: props.queryEditor.autorun, ctas: '', + northPercent: INITIAL_NORTH_PERCENT, + southPercent: INITIAL_SOUTH_PERCENT, sql: props.queryEditor.sql, }; this.sqlEditorRef = React.createRef(); @@ -103,6 +109,10 @@ class SqlEditor extends React.PureComponent { this.requestValidation.bind(this), VALIDATION_DEBOUNCE_MS, ); + this.handleWindowResize = throttle( + this.handleWindowResize.bind(this), + WINDOW_RESIZE_THROTTLE_MS, + ); } componentWillMount() { if (this.state.autorun) { @@ -116,6 +126,11 @@ class SqlEditor extends React.PureComponent { // the south pane so it gets rendered properly // eslint-disable-next-line react/no-did-mount-set-state this.setState({ height: this.getSqlEditorHeight() }); + + window.addEventListener('resize', this.handleWindowResize); + } + componentWillUnmount() { + window.removeEventListener('resize', this.handleWindowResize); } onResizeStart() { // Set the heights on the ace editor and the ace content area after drag starts @@ -124,8 +139,7 @@ class SqlEditor extends React.PureComponent { document.getElementsByClassName('ace_content')[0].style.height = '100%'; } onResizeEnd([northPercent, southPercent]) { - this.setState(this.getAceEditorAndSouthPaneHeights( - this.state.height, northPercent, southPercent)); + this.setState({ northPercent, southPercent }); if (this.northPaneRef.current && this.northPaneRef.current.clientHeight) { this.props.actions.persistEditorHeight(this.props.queryEditor, @@ -149,9 +163,11 @@ class SqlEditor extends React.PureComponent { // given the height of the sql editor, north pane percent and south pane percent. 
getAceEditorAndSouthPaneHeights(height, northPercent, southPercent) { return { - aceEditorHeight: height * northPercent / 100 - (GUTTER_HEIGHT / 2 + GUTTER_MARGIN) + aceEditorHeight: height * northPercent / 100 + - (SQL_EDITOR_GUTTER_HEIGHT / 2 + SQL_EDITOR_GUTTER_MARGIN) - SQL_TOOLBAR_HEIGHT, - southPaneHeight: height * southPercent / 100 - (GUTTER_HEIGHT / 2 + GUTTER_MARGIN), + southPaneHeight: height * southPercent / 100 + - (SQL_EDITOR_GUTTER_HEIGHT / 2 + SQL_EDITOR_GUTTER_MARGIN), }; } getHotkeyConfig() { @@ -194,9 +210,12 @@ class SqlEditor extends React.PureComponent { setQueryLimit(queryLimit) { this.props.actions.queryEditorSetQueryLimit(this.props.queryEditor, queryLimit); } + handleWindowResize() { + this.setState({ height: this.getSqlEditorHeight() }); + } elementStyle(dimension, elementSize, gutterSize) { return { - [dimension]: `calc(${elementSize}% - ${gutterSize + GUTTER_MARGIN}px)`, + [dimension]: `calc(${elementSize}% - ${gutterSize + SQL_EDITOR_GUTTER_MARGIN}px)`, }; } requestValidation() { @@ -257,15 +276,18 @@ class SqlEditor extends React.PureComponent { queryPane() { const hotkeys = this.getHotkeyConfig(); const { aceEditorHeight, southPaneHeight } = this.getAceEditorAndSouthPaneHeights( - this.state.height, INITIAL_NORTH_PERCENT, INITIAL_SOUTH_PERCENT); + this.state.height, + this.state.northPercent, + this.state.southPercent, + ); return ( @@ -277,7 +299,7 @@ class SqlEditor extends React.PureComponent { queryEditor={this.props.queryEditor} sql={this.props.queryEditor.sql} tables={this.props.tables} - height={`${this.state.aceEditorHeight || aceEditorHeight}px`} + height={`${aceEditorHeight}px`} hotkeys={hotkeys} /> {this.renderEditorBottomBar(hotkeys)} @@ -286,7 +308,7 @@ class SqlEditor extends React.PureComponent { editorQueries={this.props.editorQueries} dataPreviewQueries={this.props.dataPreviewQueries} actions={this.props.actions} - height={this.state.southPaneHeight || southPaneHeight} + height={southPaneHeight} /> ); diff --git a/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx b/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx index 43ea4873a0f73..f389641117467 100644 --- a/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx +++ b/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx @@ -33,9 +33,10 @@ const propTypes = { }; const defaultProps = { - tables: [], actions: {}, + height: 500, offline: false, + tables: [], }; export default class SqlEditorLeftBar extends React.PureComponent { diff --git a/superset/assets/src/SqlLab/constants.js b/superset/assets/src/SqlLab/constants.js index 3bf8ce0bf3fba..4dd2118d7901e 100644 --- a/superset/assets/src/SqlLab/constants.js +++ b/superset/assets/src/SqlLab/constants.js @@ -43,3 +43,8 @@ export const TIME_OPTIONS = [ '90 days ago', '1 year ago', ]; + +// SqlEditor layout constants +export const SQL_EDITOR_GUTTER_HEIGHT = 5; +export const SQL_EDITOR_GUTTER_MARGIN = 3; +export const SQL_TOOLBAR_HEIGHT = 51; diff --git a/superset/assets/src/SqlLab/main.less b/superset/assets/src/SqlLab/main.less index 822e7d84ce4dc..c7be8fbc20c6a 100644 --- a/superset/assets/src/SqlLab/main.less +++ b/superset/assets/src/SqlLab/main.less @@ -238,6 +238,7 @@ div.Workspace { .schemaPane { flex: 0 0 300px; + max-width: 300px; transition: all .3s ease-in-out; } diff --git a/superset/assets/src/components/FilterableTable/FilterableTable.jsx b/superset/assets/src/components/FilterableTable/FilterableTable.jsx index 702f2b96493f8..e4f21a514d5e6 100644 --- 
a/superset/assets/src/components/FilterableTable/FilterableTable.jsx +++ b/superset/assets/src/components/FilterableTable/FilterableTable.jsx @@ -44,6 +44,7 @@ const propTypes = { overscanRowCount: PropTypes.number, rowHeight: PropTypes.number, striped: PropTypes.bool, + expandedColumns: PropTypes.array, }; const defaultProps = { @@ -52,6 +53,7 @@ const defaultProps = { overscanRowCount: 10, rowHeight: 32, striped: true, + expandedColumns: [], }; export default class FilterableTable extends PureComponent { @@ -141,7 +143,15 @@ export default class FilterableTable extends PureComponent { return (
- {label} + -1 + ? 'header-style-disabled' + : '' + } + > + {label} + {sortBy === dataKey && } diff --git a/superset/assets/src/components/FilterableTable/FilterableTableStyles.css b/superset/assets/src/components/FilterableTable/FilterableTableStyles.css index 7a0d3ba0ea7d3..f24df737e9bc2 100644 --- a/superset/assets/src/components/FilterableTable/FilterableTableStyles.css +++ b/superset/assets/src/components/FilterableTable/FilterableTableStyles.css @@ -76,4 +76,7 @@ overflow: hidden; text-overflow: ellipsis; white-space: nowrap; -} \ No newline at end of file +} +.header-style-disabled { + color: #aaa; +} diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 47abbf2a091bd..553c0b9ddcfd7 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -51,9 +51,16 @@ def __init__( is_prequery: bool = False, columns: List[str] = None, orderby: List[List] = None, + relative_start: str = app.config.get('DEFAULT_RELATIVE_START_TIME', 'today'), + relative_end: str = app.config.get('DEFAULT_RELATIVE_END_TIME', 'today'), ): self.granularity = granularity - self.from_dttm, self.to_dttm = utils.get_since_until(time_range, time_shift) + self.from_dttm, self.to_dttm = utils.get_since_until( + relative_start=relative_start, + relative_end=relative_end, + time_range=time_range, + time_shift=time_shift, + ) self.is_timeseries = is_timeseries self.time_range = time_range self.time_shift = utils.parse_human_timedelta(time_shift) diff --git a/superset/config.py b/superset/config.py index a3e27635b6324..5e35c1def6122 100644 --- a/superset/config.py +++ b/superset/config.py @@ -599,8 +599,13 @@ class CeleryConfig(object): DOCUMENTATION_URL = None # What is the Last N days relative in the time selector to: -# 'today' means it is midnight (00:00:00) of today in the local timezone +# 'today' means it is midnight (00:00:00) in the local timezone # 'now' means it is relative to the query issue time +# If both start and end time is set to now, this will make the time +# filter a moving window. By only setting the end time to now, +# start time will be set to midnight, while end will be relative to +# the query issue time. +DEFAULT_RELATIVE_START_TIME = 'today' DEFAULT_RELATIVE_END_TIME = 'today' # Is epoch_s/epoch_ms datetime format supposed to be considered since UTC ? 
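
The new DEFAULT_RELATIVE_START_TIME setting above pairs with the existing DEFAULT_RELATIVE_END_TIME to control how "Last N ..." time filters are anchored. A minimal, hypothetical superset_config.py override (not part of this diff) illustrating the moving-window behavior described in the comment:

    # Hypothetical superset_config.py snippet (illustration only).
    # With both anchors set to 'now', a "Last 7 days" filter becomes a moving
    # window: since = query time - 7 days, until = query time. Leaving the
    # start at the default 'today' keeps `since` anchored to midnight instead,
    # while only DEFAULT_RELATIVE_END_TIME = 'now' moves the end of the window.
    DEFAULT_RELATIVE_START_TIME = 'now'
    DEFAULT_RELATIVE_END_TIME = 'now'
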
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index f178db458bbf1..de9f4d1fd397f 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -18,6 +18,7 @@ from collections import namedtuple, OrderedDict from datetime import datetime import logging +from typing import Optional, Union from flask import escape, Markup from flask_appbuilder import Model @@ -32,11 +33,12 @@ from sqlalchemy.orm import backref, relationship from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import column, literal_column, table, text -from sqlalchemy.sql.expression import TextAsFrom +from sqlalchemy.sql.expression import Label, TextAsFrom import sqlparse from superset import app, db, security_manager from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric +from superset.db_engine_specs import TimestampExpression from superset.jinja_context import get_template_processor from superset.models.annotations import Annotation from superset.models.core import Database @@ -140,8 +142,14 @@ def get_time_filter(self, start_dttm, end_dttm): l.append(col <= text(self.dttm_sql_literal(end_dttm, is_epoch_in_utc))) return and_(*l) - def get_timestamp_expression(self, time_grain): - """Getting the time component of the query""" + def get_timestamp_expression(self, time_grain: Optional[str]) \ + -> Union[TimestampExpression, Label]: + """ + Return a SQLAlchemy Core element representation of self to be used in a query. + + :param time_grain: Optional time grain, e.g. P1Y + :return: A TimeExpression object wrapped in a Label if supported by db + """ label = utils.DTTM_ALIAS db = self.table.database @@ -150,16 +158,12 @@ def get_timestamp_expression(self, time_grain): if not self.expression and not time_grain and not is_epoch: sqla_col = column(self.column_name, type_=DateTime) return self.table.make_sqla_column_compatible(sqla_col, label) - grain = None - if time_grain: - grain = db.grains_dict().get(time_grain) - if not grain: - raise NotImplementedError( - f'No grain spec for {time_grain} for database {db.database_name}') - col = db.db_engine_spec.get_timestamp_column(self.expression, self.column_name) - expr = db.db_engine_spec.get_time_expr(col, pdf, time_grain, grain) - sqla_col = literal_column(expr, type_=DateTime) - return self.table.make_sqla_column_compatible(sqla_col, label) + if self.expression: + col = literal_column(self.expression) + else: + col = column(self.column_name) + time_expr = db.db_engine_spec.get_timestamp_expr(col, pdf, time_grain) + return self.table.make_sqla_column_compatible(time_expr, label) @classmethod def import_obj(cls, i_column): diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 04efef78b8f37..89e677b0136ea 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -28,7 +28,7 @@ The general idea is to use static classes and an inheritance scheme. 
""" -from collections import namedtuple +from collections import namedtuple, OrderedDict import hashlib import inspect import logging @@ -36,19 +36,20 @@ import re import textwrap import time -from typing import List, Tuple +from typing import Dict, List, Optional, Set, Tuple from urllib import parse from flask import g from flask_babel import lazy_gettext as _ import pandas import sqlalchemy as sqla -from sqlalchemy import Column, select, types +from sqlalchemy import Column, DateTime, select, types from sqlalchemy.engine import create_engine from sqlalchemy.engine.base import Engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.result import RowProxy from sqlalchemy.engine.url import make_url +from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import quoted_name, text from sqlalchemy.sql.expression import ColumnClause from sqlalchemy.sql.expression import TextAsFrom @@ -90,6 +91,24 @@ } +class TimestampExpression(ColumnClause): + def __init__(self, expr: str, col: ColumnClause, **kwargs): + """Sqlalchemy class that can be can be used to render native column elements + respeting engine-specific quoting rules as part of a string-based expression. + + :param expr: Sql expression with '{col}' denoting the locations where the col + object will be rendered. + :param col: the target column + """ + super().__init__(expr, **kwargs) + self.col = col + + +@compiles(TimestampExpression) +def compile_timegrain_expression(element: TimestampExpression, compiler, **kw): + return element.name.replace('{col}', compiler.process(element.col, **kw)) + + def _create_time_grains_tuple(time_grains, time_grain_functions, blacklist): ret_list = [] blacklist = blacklist if blacklist else [] @@ -112,7 +131,7 @@ class BaseEngineSpec(object): """Abstract class for database engine specific configurations""" engine = 'base' # str as defined in sqlalchemy.engine.engine - time_grain_functions: dict = {} + time_grain_functions: Dict[Optional[str], str] = {} time_groupby_inline = False limit_method = LimitMethod.FORCE_LIMIT time_secondary_columns = False @@ -125,16 +144,31 @@ class BaseEngineSpec(object): try_remove_schema_from_table_name = True @classmethod - def get_time_expr(cls, expr, pdf, time_grain, grain): + def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str], + time_grain: Optional[str]) -> TimestampExpression: + """ + Construct a TimeExpression to be used in a SQLAlchemy query. + + :param col: Target column for the TimeExpression + :param pdf: date format (seconds or milliseconds) + :param time_grain: time grain, e.g. 
P1Y for 1 year + :return: TimestampExpression object + """ + if time_grain: + time_expr = cls.time_grain_functions.get(time_grain) + if not time_expr: + raise NotImplementedError( + f'No grain spec for {time_grain} for database {cls.engine}') + else: + time_expr = '{col}' + # if epoch, translate to DATE using db specific conf if pdf == 'epoch_s': - expr = cls.epoch_to_dttm().format(col=expr) + time_expr = time_expr.replace('{col}', cls.epoch_to_dttm()) elif pdf == 'epoch_ms': - expr = cls.epoch_ms_to_dttm().format(col=expr) + time_expr = time_expr.replace('{col}', cls.epoch_ms_to_dttm()) - if grain: - expr = grain.function.format(col=expr) - return expr + return TimestampExpression(time_expr, col, type_=DateTime) @classmethod def get_time_grains(cls): @@ -160,6 +194,12 @@ def fetch_data(cls, cursor, limit): return cursor.fetchmany(limit) return cursor.fetchall() + @classmethod + def expand_data(cls, + columns: List[dict], + data: List[dict]) -> Tuple[List[dict], List[dict], List[dict]]: + return columns, data, [] + @classmethod def alter_new_orm_column(cls, orm_col): """Allow altering default column attributes when first detected/added @@ -489,13 +529,6 @@ def truncate_label(cls, label): label = label[:cls.max_column_name_length] return label - @staticmethod - def get_timestamp_column(expression, column_name): - """Return the expression if defined, otherwise return column_name. Some - engines require forcing quotes around column name, in which case this method - can be overridden.""" - return expression or column_name - class PostgresBaseEngineSpec(BaseEngineSpec): """ Abstract class for Postgres 'like' databases """ @@ -543,16 +576,6 @@ def get_table_names(cls, inspector, schema): tables.extend(inspector.get_foreign_table_names(schema)) return sorted(tables) - @staticmethod - def get_timestamp_column(expression, column_name): - """Postgres is unable to identify mixed case column names unless they - are quoted.""" - if expression: - return expression - elif column_name.lower() != column_name: - return f'"{column_name}"' - return column_name - class SnowflakeEngineSpec(PostgresBaseEngineSpec): engine = 'snowflake' @@ -794,7 +817,7 @@ class MySQLEngineSpec(BaseEngineSpec): 'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))', } - type_code_map: dict = {} # loaded from get_datatype only if needed + type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed @classmethod def convert_dttm(cls, target_type, dttm): @@ -874,20 +897,16 @@ def get_view_names(cls, inspector, schema): return [] @classmethod - def _create_column_info(cls, column: RowProxy, name: str, data_type: str) -> dict: + def _create_column_info(cls, name: str, data_type: str) -> dict: """ Create column info object - :param column: column object :param name: column name :param data_type: column data type :return: column info object """ return { 'name': name, - 'type': data_type, - # newer Presto no longer includes this column - 'nullable': getattr(column, 'Null', True), - 'default': None, + 'type': f'{data_type}', } @classmethod @@ -926,13 +945,20 @@ def _split_data_type(cls, data_type: str, delimiter: str) -> List[str]: r'{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)'.format(delimiter), data_type) @classmethod - def _parse_structural_column(cls, column: RowProxy, result: List[dict]) -> None: + def _parse_structural_column(cls, + parent_column_name: str, + parent_data_type: str, + result: List[dict]) -> None: """ Parse a row or array column - :param column: column :param result: list tracking the results """ - 
full_data_type = '{} {}'.format(column.Column, column.Type) + formatted_parent_column_name = parent_column_name + # Quote the column name if there is a space + if ' ' in parent_column_name: + formatted_parent_column_name = f'"{parent_column_name}"' + full_data_type = f'{formatted_parent_column_name} {parent_data_type}' + original_result_len = len(result) # split on open parenthesis ( to get the structural # data type and its component types data_types = cls._split_data_type(full_data_type, r'\(') @@ -947,8 +973,9 @@ def _parse_structural_column(cls, column: RowProxy, result: List[dict]) -> None: stack.pop() elif cls._has_nested_data_types(inner_type): # split on comma , to get individual data types - single_fields = cls._split_data_type(inner_type, ', ') + single_fields = cls._split_data_type(inner_type, ',') for single_field in single_fields: + single_field = single_field.strip() # If component type starts with a comma, the first single field # will be an empty string. Disregard this empty string. if not single_field: @@ -961,13 +988,13 @@ def _parse_structural_column(cls, column: RowProxy, result: List[dict]) -> None: stack.append((field_info[0], field_info[1])) full_parent_path = cls._get_full_name(stack) result.append(cls._create_column_info( - column, full_parent_path, + full_parent_path, presto_type_map[field_info[1]]())) else: # otherwise this field is a basic data type full_parent_path = cls._get_full_name(stack) column_name = '{}.{}'.format(full_parent_path, field_info[0]) result.append(cls._create_column_info( - column, column_name, presto_type_map[field_info[1]]())) + column_name, presto_type_map[field_info[1]]())) # If the component type ends with a structural data type, do not pop # the stack. We have run across a structural data type within the # overall structural data type. Otherwise, we have completely parsed @@ -983,6 +1010,11 @@ def _parse_structural_column(cls, column: RowProxy, result: List[dict]) -> None: # Because it is an array of a basic data type. We have finished # parsing the structural data type and can move on. 
stack.pop() + # Unquote the column name if necessary + if formatted_parent_column_name != parent_column_name: + for index in range(original_result_len, len(result)): + result[index]['name'] = result[index]['name'].replace( + formatted_parent_column_name, parent_column_name) @classmethod def _show_columns( @@ -1019,7 +1051,11 @@ def get_columns( try: # parse column if it is a row or array if 'array' in column.Type or 'row' in column.Type: - cls._parse_structural_column(column, result) + structural_column_index = len(result) + cls._parse_structural_column(column.Column, column.Type, result) + result[structural_column_index]['nullable'] = getattr( + column, 'Null', True) + result[structural_column_index]['default'] = None continue else: # otherwise column is a basic data type column_type = presto_type_map[column.Type]() @@ -1027,7 +1063,10 @@ def get_columns( logging.info('Did not recognize type {} of column {}'.format( column.Type, column.Column)) column_type = types.NullType - result.append(cls._create_column_info(column, column.Column, column_type)) + column_info = cls._create_column_info(column.Column, column_type) + column_info['nullable'] = getattr(column, 'Null', True) + column_info['default'] = None + result.append(column_info) return result @classmethod @@ -1072,18 +1111,12 @@ def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]: return column_clauses @classmethod - def _filter_presto_cols(cls, cols: List[dict]) -> List[dict]: + def _filter_out_array_nested_cols( + cls, cols: List[dict]) -> Tuple[List[dict], List[dict]]: """ - We want to filter out columns that correspond to array content because expanding - arrays would require us to use unnest and join. This can lead to a large, - complicated, and slow query. - - Example: select array_content - from TABLE - cross join UNNEST(array_column) as t(array_content); - - We know which columns to skip because cols is a list provided to us in a specific - order where a structural column is positioned right before its content. + Filter out columns that correspond to array content. We know which columns to + skip because cols is a list provided to us in a specific order where a structural + column is positioned right before its content. Example: Column Name: ColA, Column Data Type: array(row(nest_obj int)) cols = [ ..., ColA, ColA.nest_obj, ... ] @@ -1091,23 +1124,26 @@ def _filter_presto_cols(cls, cols: List[dict]) -> List[dict]: When we run across an array, check if subsequent column names start with the array name and skip them. 
:param cols: columns - :return: filtered list of columns + :return: filtered list of columns and list of array columns and its nested fields """ filtered_cols = [] - curr_array_col_name = '' + array_cols = [] + curr_array_col_name = None for col in cols: # col corresponds to an array's content and should be skipped if curr_array_col_name and col['name'].startswith(curr_array_col_name): + array_cols.append(col) continue # col is an array so we need to check if subsequent # columns correspond to the array's contents elif str(col['type']) == 'ARRAY': curr_array_col_name = col['name'] + array_cols.append(col) filtered_cols.append(col) else: - curr_array_col_name = '' + curr_array_col_name = None filtered_cols.append(col) - return filtered_cols + return filtered_cols, array_cols @classmethod def select_star(cls, my_db, table_name: str, engine: Engine, schema: str = None, @@ -1120,7 +1156,9 @@ def select_star(cls, my_db, table_name: str, engine: Engine, schema: str = None, """ presto_cols = cols if show_cols: - presto_cols = cls._filter_presto_cols(cols) + dot_regex = r'\.(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)' + presto_cols = [ + col for col in presto_cols if not re.search(dot_regex, col['name'])] return super(PrestoEngineSpec, cls).select_star( my_db, table_name, engine, schema, limit, show_cols, indent, latest_partition, presto_cols, @@ -1168,6 +1206,373 @@ def get_all_datasource_names(cls, db, datasource_type: str) \ schema=row['table_schema'], table=row['table_name'])) return datasource_names + @classmethod + def _build_column_hierarchy(cls, + columns: List[dict], + parent_column_types: List[str], + column_hierarchy: dict) -> None: + """ + Build a graph where the root node represents a column whose data type is in + parent_column_types. A node's children represent that column's nested fields + :param columns: list of columns + :param parent_column_types: list of data types that decide what columns can + be root nodes + :param column_hierarchy: dictionary representing the graph + """ + if len(columns) == 0: + return + root = columns.pop(0) + root_info = {'type': root['type'], 'children': []} + column_hierarchy[root['name']] = root_info + while columns: + column = columns[0] + # If the column name does not start with the root's name, + # then this column is not a nested field + if not column['name'].startswith(f"{root['name']}."): + break + # If the column's data type is one of the parent types, + # then this column may have nested fields + if str(column['type']) in parent_column_types: + cls._build_column_hierarchy(columns, parent_column_types, + column_hierarchy) + root_info['children'].append(column['name']) + continue + else: # The column is a nested field + root_info['children'].append(column['name']) + columns.pop(0) + + @classmethod + def _create_row_and_array_hierarchy( + cls, selected_columns: List[dict]) -> Tuple[dict, dict, List[dict]]: + """ + Build graphs where the root node represents a row or array and its children + are that column's nested fields + :param selected_columns: columns selected in a query + :return: graph representing a row, graph representing an array, and a list + of all the nested fields + """ + row_column_hierarchy: OrderedDict = OrderedDict() + array_column_hierarchy: OrderedDict = OrderedDict() + expanded_columns: List[dict] = [] + for column in selected_columns: + if column['type'].startswith('ROW'): + parsed_row_columns: List[dict] = [] + cls._parse_structural_column(column['name'], + column['type'].lower(), + parsed_row_columns) + expanded_columns = 
expanded_columns + parsed_row_columns[1:] + filtered_row_columns, array_columns = cls._filter_out_array_nested_cols( + parsed_row_columns) + cls._build_column_hierarchy(filtered_row_columns, + ['ROW'], + row_column_hierarchy) + cls._build_column_hierarchy(array_columns, + ['ROW', 'ARRAY'], + array_column_hierarchy) + elif column['type'].startswith('ARRAY'): + parsed_array_columns: List[dict] = [] + cls._parse_structural_column(column['name'], + column['type'].lower(), + parsed_array_columns) + expanded_columns = expanded_columns + parsed_array_columns[1:] + cls._build_column_hierarchy(parsed_array_columns, + ['ROW', 'ARRAY'], + array_column_hierarchy) + return row_column_hierarchy, array_column_hierarchy, expanded_columns + + @classmethod + def _create_empty_row_of_data(cls, columns: List[dict]) -> dict: + """ + Create an empty row of data + :param columns: list of columns + :return: dictionary representing an empty row of data + """ + return {column['name']: '' for column in columns} + + @classmethod + def _expand_row_data(cls, datum: dict, column: str, column_hierarchy: dict) -> None: + """ + Separate out nested fields and its value in a row of data + :param datum: row of data + :param column: row column name + :param column_hierarchy: dictionary tracking structural columns and its + nested fields + """ + if column in datum: + row_data = datum[column] + row_children = column_hierarchy[column]['children'] + if row_data and len(row_data) != len(row_children): + raise Exception('The number of data values and number of nested' + 'fields are not equal') + elif row_data: + for index, data_value in enumerate(row_data): + datum[row_children[index]] = data_value + else: + for row_child in row_children: + datum[row_child] = '' + + @classmethod + def _split_array_columns_by_process_state( + cls, array_columns: List[str], + array_column_hierarchy: dict, + datum: dict) -> Tuple[List[str], Set[str]]: + """ + Take a list of array columns and split them according to whether or not we are + ready to process them from a data set + :param array_columns: list of array columns + :param array_column_hierarchy: graph representing array columns + :param datum: row of data + :return: list of array columns ready to be processed and set of array columns + not ready to be processed + """ + array_columns_to_process = [] + unprocessed_array_columns = set() + child_array = None + for array_column in array_columns: + if array_column in datum: + array_columns_to_process.append(array_column) + elif str(array_column_hierarchy[array_column]['type']) == 'ARRAY': + child_array = array_column + unprocessed_array_columns.add(child_array) + elif child_array and array_column.startswith(child_array): + unprocessed_array_columns.add(array_column) + return array_columns_to_process, unprocessed_array_columns + + @classmethod + def _convert_data_list_to_array_data_dict( + cls, data: List[dict], array_columns_to_process: List[str]) -> dict: + """ + Pull out array data from rows of data into a dictionary where the key represents + the index in the data list and the value is the array data values + Example: + data = [ + {'ColumnA': [1, 2], 'ColumnB': 3}, + {'ColumnA': [11, 22], 'ColumnB': 3} + ] + data dictionary = { + 0: [{'ColumnA': [1, 2]], + 1: [{'ColumnA': [11, 22]] + } + :param data: rows of data + :param array_columns_to_process: array columns we want to pull out + :return: data dictionary + """ + array_data_dict = {} + for data_index, datum in enumerate(data): + all_array_datum = {} + for array_column in 
array_columns_to_process: + all_array_datum[array_column] = datum[array_column] + array_data_dict[data_index] = [all_array_datum] + return array_data_dict + + @classmethod + def _process_array_data(cls, + data: List[dict], + all_columns: List[dict], + array_column_hierarchy: dict) -> dict: + """ + Pull out array data that is ready to be processed into a dictionary. + The key refers to the index in the original data set. The value is + a list of data values. Initially this list will contain just one value, + the row of data that corresponds to the index in the original data set. + As we process arrays, we will pull out array values into separate rows + and append them to the list of data values. + Example: + Original data set = [ + {'ColumnA': [1, 2], 'ColumnB': [3]}, + {'ColumnA': [11, 22], 'ColumnB': [33]} + ] + all_array_data (intially) = { + 0: [{'ColumnA': [1, 2], 'ColumnB': [3}], + 1: [{'ColumnA': [11, 22], 'ColumnB': [33]}] + } + all_array_data (after processing) = { + 0: [ + {'ColumnA': 1, 'ColumnB': 3}, + {'ColumnA': 2, 'ColumnB': ''}, + ], + 1: [ + {'ColumnA': 11, 'ColumnB': 33}, + {'ColumnA': 22, 'ColumnB': ''}, + ], + } + :param data: rows of data + :param all_columns: list of columns + :param array_column_hierarchy: graph representing array columns + :return: dictionary representing processed array data + """ + array_columns = list(array_column_hierarchy.keys()) + # Determine what columns are ready to be processed. This is necessary for + # array columns that contain rows with nested arrays. We first process + # the outer arrays before processing inner arrays. + array_columns_to_process, \ + unprocessed_array_columns = cls._split_array_columns_by_process_state( + array_columns, array_column_hierarchy, data[0]) + + # Pull out array data that is ready to be processed into a dictionary. 
+ all_array_data = cls._convert_data_list_to_array_data_dict( + data, array_columns_to_process) + + for original_data_index, expanded_array_data in all_array_data.items(): + for array_column in array_columns: + if array_column in unprocessed_array_columns: + continue + # Expand array values that are rows + if str(array_column_hierarchy[array_column]['type']) == 'ROW': + for array_value in expanded_array_data: + cls._expand_row_data(array_value, + array_column, + array_column_hierarchy) + continue + array_data = expanded_array_data[0][array_column] + array_children = array_column_hierarchy[array_column] + # This is an empty array of primitive data type + if not array_data and not array_children['children']: + continue + # Pull out complex array values into its own row of data + elif array_data and array_children['children']: + for array_index, data_value in enumerate(array_data): + if array_index >= len(expanded_array_data): + empty_data = cls._create_empty_row_of_data(all_columns) + expanded_array_data.append(empty_data) + for index, datum_value in enumerate(data_value): + array_child = array_children['children'][index] + expanded_array_data[array_index][array_child] = datum_value + # Pull out primitive array values into its own row of data + elif array_data: + for array_index, data_value in enumerate(array_data): + if array_index >= len(expanded_array_data): + empty_data = cls._create_empty_row_of_data(all_columns) + expanded_array_data.append(empty_data) + expanded_array_data[array_index][array_column] = data_value + # This is an empty array with nested fields + else: + for index, array_child in enumerate(array_children['children']): + for array_value in expanded_array_data: + array_value[array_child] = '' + return all_array_data + + @classmethod + def _consolidate_array_data_into_data(cls, + data: List[dict], + array_data: dict) -> None: + """ + Consolidate data given a list representing rows of data and a dictionary + representing expanded array data + Example: + Original data set = [ + {'ColumnA': [1, 2], 'ColumnB': [3]}, + {'ColumnA': [11, 22], 'ColumnB': [33]} + ] + array_data = { + 0: [ + {'ColumnA': 1, 'ColumnB': 3}, + {'ColumnA': 2, 'ColumnB': ''}, + ], + 1: [ + {'ColumnA': 11, 'ColumnB': 33}, + {'ColumnA': 22, 'ColumnB': ''}, + ], + } + Final data set = [ + {'ColumnA': 1, 'ColumnB': 3}, + {'ColumnA': 2, 'ColumnB': ''}, + {'ColumnA': 11, 'ColumnB': 33}, + {'ColumnA': 22, 'ColumnB': ''}, + ] + :param data: list representing rows of data + :param array_data: dictionary representing expanded array data + :return: list where data and array_data are combined + """ + data_index = 0 + original_data_index = 0 + while data_index < len(data): + data[data_index].update(array_data[original_data_index][0]) + array_data[original_data_index].pop(0) + data[data_index + 1:data_index + 1] = array_data[original_data_index] + data_index = data_index + len(array_data[original_data_index]) + 1 + original_data_index = original_data_index + 1 + + @classmethod + def _remove_processed_array_columns(cls, + unprocessed_array_columns: Set[str], + array_column_hierarchy: dict) -> None: + """ + Remove keys representing array columns that have already been processed + :param unprocessed_array_columns: list of unprocessed array columns + :param array_column_hierarchy: graph representing array columns + """ + array_columns = list(array_column_hierarchy.keys()) + for array_column in array_columns: + if array_column in unprocessed_array_columns: + continue + else: + del array_column_hierarchy[array_column] + + 
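
Taken together, the hierarchy-building and array-processing helpers above are orchestrated by the expand_data override added just below. A rough usage sketch under assumed, simplified inputs (the column name, type string, and flattened output shape are illustrative, following the expand_data docstring example rather than a verified run):

    from superset.db_engine_specs import PrestoEngineSpec

    # Toy Presto result set: one array column of integers.
    columns = [{'name': 'arr_col', 'type': 'ARRAY(BIGINT)'}]
    data = [{'arr_col': [1, 2]}, {'arr_col': [3]}]

    all_columns, rows, expanded_columns = PrestoEngineSpec.expand_data(columns, data)
    # Expected shape (per the docstring): each array element is pulled into its
    # own row, e.g. roughly rows == [{'arr_col': 1}, {'arr_col': 2}, {'arr_col': 3}]
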
@classmethod + def expand_data(cls, + columns: List[dict], + data: List[dict]) -> Tuple[List[dict], List[dict], List[dict]]: + """ + We do not immediately display rows and arrays clearly in the data grid. This + method separates out nested fields and data values to help clearly display + structural columns. + + Example: ColumnA is a row(nested_obj varchar) and ColumnB is an array(int) + Original data set = [ + {'ColumnA': ['a1'], 'ColumnB': [1, 2]}, + {'ColumnA': ['a2'], 'ColumnB': [3, 4]}, + ] + Expanded data set = [ + {'ColumnA': ['a1'], 'ColumnA.nested_obj': 'a1', 'ColumnB': 1}, + {'ColumnA': '', 'ColumnA.nested_obj': '', 'ColumnB': 2}, + {'ColumnA': ['a2'], 'ColumnA.nested_obj': 'a2', 'ColumnB': 3}, + {'ColumnA': '', 'ColumnA.nested_obj': '', 'ColumnB': 4}, + ] + :param columns: columns selected in the query + :param data: original data set + :return: list of all columns(selected columns and their nested fields), + expanded data set, listed of nested fields + """ + all_columns: List[dict] = [] + # Get the list of all columns (selected fields and their nested fields) + for column in columns: + if column['type'].startswith('ARRAY') or column['type'].startswith('ROW'): + cls._parse_structural_column(column['name'], + column['type'].lower(), + all_columns) + else: + all_columns.append(column) + + # Build graphs where the root node is a row or array and its children are that + # column's nested fields + row_column_hierarchy,\ + array_column_hierarchy,\ + expanded_columns = cls._create_row_and_array_hierarchy(columns) + + # Pull out a row's nested fields and their values into separate columns + ordered_row_columns = row_column_hierarchy.keys() + for datum in data: + for row_column in ordered_row_columns: + cls._expand_row_data(datum, row_column, row_column_hierarchy) + + while array_column_hierarchy: + array_columns = list(array_column_hierarchy.keys()) + # Determine what columns are ready to be processed. 
+ array_columns_to_process,\ + unprocessed_array_columns = cls._split_array_columns_by_process_state( + array_columns, array_column_hierarchy, data[0]) + all_array_data = cls._process_array_data(data, + all_columns, + array_column_hierarchy) + # Consolidate the original data set and the expanded array data + cls._consolidate_array_data_into_data(data, all_array_data) + # Remove processed array columns from the graph + cls._remove_processed_array_columns(unprocessed_array_columns, + array_column_hierarchy) + + return all_columns, data, expanded_columns + @classmethod def extra_table_metadata(cls, database, table_name, schema_name): indexes = database.get_indexes(table_name, schema_name) @@ -1812,20 +2217,21 @@ class PinotEngineSpec(BaseEngineSpec): inner_joins = False supports_column_aliases = False - _time_grain_to_datetimeconvert = { + # Pinot does its own conversion below + time_grain_functions: Dict[Optional[str], str] = { 'PT1S': '1:SECONDS', 'PT1M': '1:MINUTES', 'PT1H': '1:HOURS', 'P1D': '1:DAYS', - 'P1Y': '1:YEARS', + 'P1W': '1:WEEKS', 'P1M': '1:MONTHS', + 'P0.25Y': '3:MONTHS', + 'P1Y': '1:YEARS', } - # Pinot does its own conversion below - time_grain_functions = {k: None for k in _time_grain_to_datetimeconvert.keys()} - @classmethod - def get_time_expr(cls, expr, pdf, time_grain, grain): + def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str], + time_grain: Optional[str]) -> TimestampExpression: is_epoch = pdf in ('epoch_s', 'epoch_ms') if not is_epoch: raise NotImplementedError('Pinot currently only supports epochs') @@ -1834,11 +2240,12 @@ def get_time_expr(cls, expr, pdf, time_grain, grain): # We are not really converting any time units, just bucketing them. seconds_or_ms = 'MILLISECONDS' if pdf == 'epoch_ms' else 'SECONDS' tf = f'1:{seconds_or_ms}:EPOCH' - granularity = cls._time_grain_to_datetimeconvert.get(time_grain) + granularity = cls.time_grain_functions.get(time_grain) if not granularity: raise NotImplementedError('No pinot grain spec for ' + str(time_grain)) # In pinot the output is a string since there is no timestamp column like pg - return f'DATETIMECONVERT({expr}, "{tf}", "{tf}", "{granularity}")' + time_expr = f'DATETIMECONVERT({{col}}, "{tf}", "{tf}", "{granularity}")' + return TimestampExpression(time_expr, col) @classmethod def make_select_compatible(cls, groupby_exprs, select_exprs): diff --git a/superset/models/core.py b/superset/models/core.py index 047a3ddb11b11..b379af7caba2b 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -1029,21 +1029,13 @@ def grains(self): """Defines time granularity database-specific expressions. The idea here is to make it easy for users to change the time grain - form a datetime (maybe the source grain is arbitrary timestamps, daily + from a datetime (maybe the source grain is arbitrary timestamps, daily or 5 minutes increments) to another, "truncated" datetime. Since each database has slightly different but similar datetime functions, this allows a mapping between database engines and actual functions. 
""" return self.db_engine_spec.get_time_grains() - def grains_dict(self): - """Allowing to lookup grain by either label or duration - - For backward compatibility""" - d = {grain.duration: grain for grain in self.grains()} - d.update({grain.label: grain for grain in self.grains()}) - return d - def get_extra(self): extra = {} if self.extra: diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 6a5ee8ebc53ae..86e171ba44d77 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -279,10 +279,17 @@ def execute_sql_statements( latest_partition=False) query.end_time = now_as_float() + selected_columns = cdf.columns or [] + data = cdf.data or [] + all_columns, data, expanded_columns = db_engine_spec.expand_data( + selected_columns, data) + payload.update({ 'status': query.status, - 'data': cdf.data if cdf.data else [], - 'columns': cdf.columns if cdf.columns else [], + 'data': data, + 'columns': all_columns, + 'selected_columns': selected_columns, + 'expanded_columns': expanded_columns, 'query': query.to_dict(), }) diff --git a/superset/utils/core.py b/superset/utils/core.py index 2defa70dd179e..df2550040d947 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -237,14 +237,14 @@ def parse_human_datetime(s): # when time is not extracted, we 'reset to midnight' if parsed_flags & 2 == 0: parsed_dttm = parsed_dttm.replace(hour=0, minute=0, second=0) - dttm = dttm_from_timtuple(parsed_dttm.utctimetuple()) + dttm = dttm_from_timetuple(parsed_dttm.utctimetuple()) except Exception as e: logging.exception(e) raise ValueError("Couldn't parse date string [{}]".format(s)) return dttm -def dttm_from_timtuple(d: struct_time) -> datetime: +def dttm_from_timetuple(d: struct_time) -> datetime: return datetime( d.tm_year, d.tm_mon, d.tm_mday, d.tm_hour, d.tm_min, d.tm_sec) @@ -306,7 +306,7 @@ def parse_human_timedelta(s: str): True """ cal = parsedatetime.Calendar() - dttm = dttm_from_timtuple(datetime.now().timetuple()) + dttm = dttm_from_timetuple(datetime.now().timetuple()) d = cal.parse(s or '', dttm)[0] d = datetime(d.tm_year, d.tm_mon, d.tm_mday, d.tm_hour, d.tm_min, d.tm_sec) return d - dttm @@ -939,6 +939,7 @@ def get_since_until(time_range: Optional[str] = None, since: Optional[str] = None, until: Optional[str] = None, time_shift: Optional[str] = None, + relative_start: Optional[str] = None, relative_end: Optional[str] = None) -> Tuple[datetime, datetime]: """Return `since` and `until` date time tuple from string representations of time_range, since, until and time_shift. 
@@ -965,13 +966,14 @@ def get_since_until(time_range: Optional[str] = None, """ separator = ' : ' + relative_start = parse_human_datetime(relative_start if relative_start else 'today') relative_end = parse_human_datetime(relative_end if relative_end else 'today') common_time_frames = { - 'Last day': (relative_end - relativedelta(days=1), relative_end), # noqa: T400 - 'Last week': (relative_end - relativedelta(weeks=1), relative_end), # noqa: T400 - 'Last month': (relative_end - relativedelta(months=1), relative_end), # noqa: E501, T400 - 'Last quarter': (relative_end - relativedelta(months=3), relative_end), # noqa: E501, T400 - 'Last year': (relative_end - relativedelta(years=1), relative_end), # noqa: T400 + 'Last day': (relative_start - relativedelta(days=1), relative_end), # noqa: T400 + 'Last week': (relative_start - relativedelta(weeks=1), relative_end), # noqa: E501, T400 + 'Last month': (relative_start - relativedelta(months=1), relative_end), # noqa: E501, T400 + 'Last quarter': (relative_start - relativedelta(months=3), relative_end), # noqa: E501, T400 + 'Last year': (relative_start - relativedelta(years=1), relative_end), # noqa: E501, T400 } if time_range: @@ -988,10 +990,10 @@ def get_since_until(time_range: Optional[str] = None, else: rel, num, grain = time_range.split() if rel == 'Last': - since = relative_end - relativedelta(**{grain: int(num)}) # noqa: T400 + since = relative_start - relativedelta(**{grain: int(num)}) # noqa: T400 until = relative_end else: # rel == 'Next' - since = relative_end + since = relative_start until = relative_end + relativedelta(**{grain: int(num)}) # noqa: T400 else: since = since or '' diff --git a/superset/viz.py b/superset/viz.py index c193daa29e4a6..a62bf2e03365f 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -61,6 +61,7 @@ config = app.config stats_logger = config.get('STATS_LOGGER') +relative_start = config.get('DEFAULT_RELATIVE_START_TIME', 'today') relative_end = config.get('DEFAULT_RELATIVE_END_TIME', 'today') METRIC_KEYS = [ @@ -276,7 +277,8 @@ def query_obj(self): # default order direction order_desc = form_data.get('order_desc', True) - since, until = utils.get_since_until(relative_end=relative_end, + since, until = utils.get_since_until(relative_start=relative_start, + relative_end=relative_end, time_range=form_data.get('time_range'), since=form_data.get('since'), until=form_data.get('until')) @@ -802,7 +804,8 @@ def get_data(self, df): values[str(v / 10**9)] = obj.get(metric) data[metric] = values - start, end = utils.get_since_until(relative_end=relative_end, + start, end = utils.get_since_until(relative_start=relative_start, + relative_end=relative_end, time_range=form_data.get('time_range'), since=form_data.get('since'), until=form_data.get('until')) diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index 0372366a2a48e..44919143d8c02 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -17,15 +17,16 @@ import inspect from unittest import mock -from sqlalchemy import column, select, table -from sqlalchemy.dialects.mssql import pymssql +from sqlalchemy import column, literal_column, select, table +from sqlalchemy.dialects import mssql, oracle, postgresql from sqlalchemy.engine.result import RowProxy from sqlalchemy.types import String, UnicodeText from superset import db_engine_specs from superset.db_engine_specs import ( BaseEngineSpec, BQEngineSpec, HiveEngineSpec, MssqlEngineSpec, - MySQLEngineSpec, OracleEngineSpec, PrestoEngineSpec, + MySQLEngineSpec, 
OracleEngineSpec, PinotEngineSpec, PostgresEngineSpec, + PrestoEngineSpec, ) from superset.models.core import Database from .base_tests import SupersetTestCase @@ -349,7 +350,14 @@ def test_presto_get_simple_row_column(self): ('column_name.nested_obj', 'FLOAT')] self.verify_presto_column(presto_column, expected_results) - def test_presto_get_simple_row_column_with_tricky_name(self): + def test_presto_get_simple_row_column_with_name_containing_whitespace(self): + presto_column = ('column name', 'row(nested_obj double)', '') + expected_results = [ + ('column name', 'ROW'), + ('column name.nested_obj', 'FLOAT')] + self.verify_presto_column(presto_column, expected_results) + + def test_presto_get_simple_row_column_with_tricky_nested_field_name(self): presto_column = ('column_name', 'row("Field Name(Tricky, Name)" double)', '') expected_results = [ ('column_name', 'ROW'), @@ -398,13 +406,286 @@ def test_presto_get_fields(self): self.assertEqual(actual_result.element.name, expected_result['name']) self.assertEqual(actual_result.name, expected_result['label']) - def test_presto_filter_presto_cols(self): + def test_presto_filter_out_array_nested_cols(self): cols = [ {'name': 'column', 'type': 'ARRAY'}, {'name': 'column.nested_obj', 'type': 'FLOAT'}] - actual_results = PrestoEngineSpec._filter_presto_cols(cols) - expected_results = [cols[0]] - self.assertEqual(actual_results, expected_results) + actual_filtered_cols,\ + actual_array_cols = PrestoEngineSpec._filter_out_array_nested_cols(cols) + expected_filtered_cols = [{'name': 'column', 'type': 'ARRAY'}] + self.assertEqual(actual_filtered_cols, expected_filtered_cols) + self.assertEqual(actual_array_cols, cols) + + def test_presto_create_row_and_array_hierarchy(self): + cols = [ + {'name': 'row_column', + 'type': 'ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR)'}, + {'name': 'array_column', + 'type': 'ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))'}] + actual_row_col_hierarchy,\ + actual_array_col_hierarchy,\ + actual_expanded_cols = PrestoEngineSpec._create_row_and_array_hierarchy(cols) + expected_row_col_hierarchy = { + 'row_column': { + 'type': 'ROW', + 'children': ['row_column.nested_obj1', 'row_column.nested_row'], + }, + 'row_column.nested_row': { + 'type': 'ROW', + 'children': ['row_column.nested_row.nested_obj2']}, + } + expected_array_col_hierarchy = { + 'array_column': { + 'type': 'ARRAY', + 'children': ['array_column.nested_array'], + }, + 'array_column.nested_array': { + 'type': 'ARRAY', + 'children': ['array_column.nested_array.nested_obj']}, + } + expected_expanded_cols = [ + {'name': 'row_column.nested_obj1', 'type': 'VARCHAR'}, + {'name': 'row_column.nested_row', 'type': 'ROW'}, + {'name': 'row_column.nested_row.nested_obj2', 'type': 'VARCHAR'}, + {'name': 'array_column.nested_array', 'type': 'ARRAY'}, + {'name': 'array_column.nested_array.nested_obj', 'type': 'VARCHAR'}] + self.assertEqual(actual_row_col_hierarchy, expected_row_col_hierarchy) + self.assertEqual(actual_array_col_hierarchy, expected_array_col_hierarchy) + self.assertEqual(actual_expanded_cols, expected_expanded_cols) + + def test_presto_expand_row_data(self): + datum = {'row_col': [1, 'a']} + row_column = 'row_col' + row_col_hierarchy = { + 'row_col': { + 'type': 'ROW', + 'children': ['row_col.nested_int', 'row_col.nested_str'], + }, + } + PrestoEngineSpec._expand_row_data(datum, row_column, row_col_hierarchy) + expected_datum = { + 'row_col': [1, 'a'], 'row_col.nested_int': 1, 'row_col.nested_str': 'a', + } + self.assertEqual(datum, 
expected_datum) + + def test_split_array_columns_by_process_state(self): + array_cols = ['array_column', 'array_column.nested_array'] + array_col_hierarchy = { + 'array_column': { + 'type': 'ARRAY', + 'children': ['array_column.nested_array'], + }, + 'array_column.nested_array': { + 'type': 'ARRAY', + 'children': ['array_column.nested_array.nested_obj']}, + } + datum = {'array_column': [[[1], [2]]]} + actual_array_cols_to_process, actual_unprocessed_array_cols = \ + PrestoEngineSpec._split_array_columns_by_process_state( + array_cols, array_col_hierarchy, datum) + expected_array_cols_to_process = ['array_column'] + expected_unprocessed_array_cols = {'array_column.nested_array'} + self.assertEqual(actual_array_cols_to_process, expected_array_cols_to_process) + self.assertEqual(actual_unprocessed_array_cols, expected_unprocessed_array_cols) + + def test_presto_convert_data_list_to_array_data_dict(self): + data = [ + {'array_column': [1, 2], 'int_column': 3}, + {'array_column': [11, 22], 'int_column': 33}, + ] + array_columns_to_process = ['array_column'] + actual_array_data_dict = PrestoEngineSpec._convert_data_list_to_array_data_dict( + data, array_columns_to_process) + expected_array_data_dict = { + 0: [{'array_column': [1, 2]}], + 1: [{'array_column': [11, 22]}]} + self.assertEqual(actual_array_data_dict, expected_array_data_dict) + + def test_presto_process_array_data(self): + data = [ + {'array_column': [[1], [2]], 'int_column': 3}, + {'array_column': [[11], [22]], 'int_column': 33}, + ] + all_columns = [ + {'name': 'array_column', 'type': 'ARRAY'}, + {'name': 'array_column.nested_row', 'type': 'BIGINT'}, + {'name': 'int_column', 'type': 'BIGINT'}, + ] + array_column_hierarchy = { + 'array_column': { + 'type': 'ARRAY', + 'children': ['array_column.nested_row'], + }, + } + actual_array_data = PrestoEngineSpec._process_array_data( + data, all_columns, array_column_hierarchy) + expected_array_data = { + 0: [ + {'array_column': [[1], [2]], 'array_column.nested_row': 1}, + {'array_column': '', 'array_column.nested_row': 2, 'int_column': ''}, + ], + 1: [ + {'array_column': [[11], [22]], 'array_column.nested_row': 11}, + {'array_column': '', 'array_column.nested_row': 22, 'int_column': ''}, + ], + } + self.assertEqual(actual_array_data, expected_array_data) + + def test_presto_consolidate_array_data_into_data(self): + data = [ + {'arr_col': [[1], [2]], 'int_col': 3}, + {'arr_col': [[11], [22]], 'int_col': 33}, + ] + array_data = { + 0: [ + {'arr_col': [[1], [2]], 'arr_col.nested_row': 1}, + {'arr_col': '', 'arr_col.nested_row': 2, 'int_col': ''}, + ], + 1: [ + {'arr_col': [[11], [22]], 'arr_col.nested_row': 11}, + {'arr_col': '', 'arr_col.nested_row': 22, 'int_col': ''}, + ], + } + PrestoEngineSpec._consolidate_array_data_into_data(data, array_data) + expected_data = [ + {'arr_col': [[1], [2]], 'arr_col.nested_row': 1, 'int_col': 3}, + {'arr_col': '', 'arr_col.nested_row': 2, 'int_col': ''}, + {'arr_col': [[11], [22]], 'arr_col.nested_row': 11, 'int_col': 33}, + {'arr_col': '', 'arr_col.nested_row': 22, 'int_col': ''}, + ] + self.assertEqual(data, expected_data) + + def test_presto_remove_processed_array_columns(self): + array_col_hierarchy = { + 'array_column': { + 'type': 'ARRAY', + 'children': ['array_column.nested_array'], + }, + 'array_column.nested_array': { + 'type': 'ARRAY', + 'children': ['array_column.nested_array.nested_obj']}, + } + unprocessed_array_cols = {'array_column.nested_array'} + PrestoEngineSpec._remove_processed_array_columns( + unprocessed_array_cols, 
array_col_hierarchy) + expected_array_col_hierarchy = { + 'array_column.nested_array': { + 'type': 'ARRAY', + 'children': ['array_column.nested_array.nested_obj']}, + } + self.assertEqual(array_col_hierarchy, expected_array_col_hierarchy) + + def test_presto_expand_data_with_simple_structural_columns(self): + cols = [ + {'name': 'row_column', 'type': 'ROW(NESTED_OBJ VARCHAR)'}, + {'name': 'array_column', 'type': 'ARRAY(BIGINT)'}] + data = [ + {'row_column': ['a'], 'array_column': [1, 2, 3]}, + {'row_column': ['b'], 'array_column': [4, 5, 6]}] + actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data( + cols, data) + expected_cols = [ + {'name': 'row_column', 'type': 'ROW'}, + {'name': 'row_column.nested_obj', 'type': 'VARCHAR'}, + {'name': 'array_column', 'type': 'ARRAY'}] + expected_data = [ + {'row_column': ['a'], 'row_column.nested_obj': 'a', 'array_column': 1}, + {'row_column': '', 'row_column.nested_obj': '', 'array_column': 2}, + {'row_column': '', 'row_column.nested_obj': '', 'array_column': 3}, + {'row_column': ['b'], 'row_column.nested_obj': 'b', 'array_column': 4}, + {'row_column': '', 'row_column.nested_obj': '', 'array_column': 5}, + {'row_column': '', 'row_column.nested_obj': '', 'array_column': 6}] + expected_expanded_cols = [ + {'name': 'row_column.nested_obj', 'type': 'VARCHAR'}] + self.assertEqual(actual_cols, expected_cols) + self.assertEqual(actual_data, expected_data) + self.assertEqual(actual_expanded_cols, expected_expanded_cols) + + def test_presto_expand_data_with_complex_row_columns(self): + cols = [ + {'name': 'row_column', + 'type': 'ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR)'}] + data = [ + {'row_column': ['a1', ['a2']]}, + {'row_column': ['b1', ['b2']]}] + actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data( + cols, data) + expected_cols = [ + {'name': 'row_column', 'type': 'ROW'}, + {'name': 'row_column.nested_obj1', 'type': 'VARCHAR'}, + {'name': 'row_column.nested_row', 'type': 'ROW'}, + {'name': 'row_column.nested_row.nested_obj2', 'type': 'VARCHAR'}] + expected_data = [ + {'row_column': ['a1', ['a2']], + 'row_column.nested_obj1': 'a1', + 'row_column.nested_row': ['a2'], + 'row_column.nested_row.nested_obj2': 'a2'}, + {'row_column': ['b1', ['b2']], + 'row_column.nested_obj1': 'b1', + 'row_column.nested_row': ['b2'], + 'row_column.nested_row.nested_obj2': 'b2'}] + expected_expanded_cols = [ + {'name': 'row_column.nested_obj1', 'type': 'VARCHAR'}, + {'name': 'row_column.nested_row', 'type': 'ROW'}, + {'name': 'row_column.nested_row.nested_obj2', 'type': 'VARCHAR'}] + self.assertEqual(actual_cols, expected_cols) + self.assertEqual(actual_data, expected_data) + self.assertEqual(actual_expanded_cols, expected_expanded_cols) + + def test_presto_expand_data_with_complex_array_columns(self): + cols = [ + {'name': 'int_column', 'type': 'BIGINT'}, + {'name': 'array_column', + 'type': 'ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))'}] + data = [ + {'int_column': 1, 'array_column': [[[['a'], ['b']]], [[['c'], ['d']]]]}, + {'int_column': 2, 'array_column': [[[['e'], ['f']]], [[['g'], ['h']]]]}] + actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data( + cols, data) + expected_cols = [ + {'name': 'int_column', 'type': 'BIGINT'}, + {'name': 'array_column', 'type': 'ARRAY'}, + {'name': 'array_column.nested_array', 'type': 'ARRAY'}, + {'name': 'array_column.nested_array.nested_obj', 'type': 'VARCHAR'}] + expected_data = [ + {'int_column': 1, + 'array_column': [[[['a'], ['b']]], 
[[['c'], ['d']]]], + 'array_column.nested_array': [['a'], ['b']], + 'array_column.nested_array.nested_obj': 'a'}, + {'int_column': '', + 'array_column': '', + 'array_column.nested_array': '', + 'array_column.nested_array.nested_obj': 'b'}, + {'int_column': '', + 'array_column': '', + 'array_column.nested_array': [['c'], ['d']], + 'array_column.nested_array.nested_obj': 'c'}, + {'int_column': '', + 'array_column': '', + 'array_column.nested_array': '', + 'array_column.nested_array.nested_obj': 'd'}, + {'int_column': 2, + 'array_column': [[[['e'], ['f']]], [[['g'], ['h']]]], + 'array_column.nested_array': [['e'], ['f']], + 'array_column.nested_array.nested_obj': 'e'}, + {'int_column': '', + 'array_column': '', + 'array_column.nested_array': '', + 'array_column.nested_array.nested_obj': 'f'}, + {'int_column': '', + 'array_column': '', + 'array_column.nested_array': [['g'], ['h']], + 'array_column.nested_array.nested_obj': 'g'}, + {'int_column': '', + 'array_column': '', + 'array_column.nested_array': '', + 'array_column.nested_array.nested_obj': 'h'}] + expected_expanded_cols = [ + {'name': 'array_column.nested_array', 'type': 'ARRAY'}, + {'name': 'array_column.nested_array.nested_obj', 'type': 'VARCHAR'}] + self.assertEqual(actual_cols, expected_cols) + self.assertEqual(actual_data, expected_data) + self.assertEqual(actual_expanded_cols, expected_expanded_cols) def test_hive_get_view_names_return_empty_list(self): self.assertEquals([], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY)) @@ -451,7 +732,7 @@ def assert_type(type_string, type_expected): assert_type('NTEXT', UnicodeText) def test_mssql_where_clause_n_prefix(self): - dialect = pymssql.dialect() + dialect = mssql.dialect() spec = MssqlEngineSpec str_col = column('col', type_=spec.get_sqla_column_type('VARCHAR(10)')) unicode_col = column('unicode_col', type_=spec.get_sqla_column_type('NTEXT')) @@ -462,7 +743,9 @@ def test_mssql_where_clause_n_prefix(self): where(unicode_col == 'abc') query = str(sel.compile(dialect=dialect, compile_kwargs={'literal_binds': True})) - query_expected = "SELECT col, unicode_col \nFROM tbl \nWHERE col = 'abc' AND unicode_col = N'abc'" # noqa + query_expected = 'SELECT col, unicode_col \n' \ + 'FROM tbl \n' \ + "WHERE col = 'abc' AND unicode_col = N'abc'" self.assertEqual(query, query_expected) def test_get_table_names(self): @@ -483,3 +766,51 @@ def test_get_table_names(self): pg_result = db_engine_specs.PostgresEngineSpec.get_table_names( schema='schema', inspector=inspector) self.assertListEqual(pg_result_expected, pg_result) + + def test_pg_time_expression_literal_no_grain(self): + col = literal_column('COALESCE(a, b)') + expr = PostgresEngineSpec.get_timestamp_expr(col, None, None) + result = str(expr.compile(dialect=postgresql.dialect())) + self.assertEqual(result, 'COALESCE(a, b)') + + def test_pg_time_expression_literal_1y_grain(self): + col = literal_column('COALESCE(a, b)') + expr = PostgresEngineSpec.get_timestamp_expr(col, None, 'P1Y') + result = str(expr.compile(dialect=postgresql.dialect())) + self.assertEqual(result, "DATE_TRUNC('year', COALESCE(a, b))") + + def test_pg_time_expression_lower_column_no_grain(self): + col = column('lower_case') + expr = PostgresEngineSpec.get_timestamp_expr(col, None, None) + result = str(expr.compile(dialect=postgresql.dialect())) + self.assertEqual(result, 'lower_case') + + def test_pg_time_expression_lower_case_column_sec_1y_grain(self): + col = column('lower_case') + expr = PostgresEngineSpec.get_timestamp_expr(col, 'epoch_s', 'P1Y') + result = 
str(expr.compile(dialect=postgresql.dialect())) + self.assertEqual(result, "DATE_TRUNC('year', (timestamp 'epoch' + lower_case * interval '1 second'))") # noqa + + def test_pg_time_expression_mixed_case_column_1y_grain(self): + col = column('MixedCase') + expr = PostgresEngineSpec.get_timestamp_expr(col, None, 'P1Y') + result = str(expr.compile(dialect=postgresql.dialect())) + self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")") + + def test_mssql_time_expression_mixed_case_column_1y_grain(self): + col = column('MixedCase') + expr = MssqlEngineSpec.get_timestamp_expr(col, None, 'P1Y') + result = str(expr.compile(dialect=mssql.dialect())) + self.assertEqual(result, 'DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)') + + def test_oracle_time_expression_reserved_keyword_1m_grain(self): + col = column('decimal') + expr = OracleEngineSpec.get_timestamp_expr(col, None, 'P1M') + result = str(expr.compile(dialect=oracle.dialect())) + self.assertEqual(result, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')") + + def test_pinot_time_expression_sec_1m_grain(self): + col = column('tstamp') + expr = PinotEngineSpec.get_timestamp_expr(col, 'epoch_s', 'P1M') + result = str(expr.compile()) + self.assertEqual(result, 'DATETIMECONVERT(tstamp, "1:SECONDS:EPOCH", "1:SECONDS:EPOCH", "1:MONTHS")') # noqa diff --git a/tests/model_tests.py b/tests/model_tests.py index 0fe03de932a28..53e53cc5f6516 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -109,47 +109,6 @@ def test_select_star(self): LIMIT 100""") assert sql.startswith(expected) - def test_grains_dict(self): - uri = 'mysql://root@localhost' - database = Database(sqlalchemy_uri=uri) - d = database.grains_dict() - self.assertEquals(d.get('day').function, 'DATE({col})') - self.assertEquals(d.get('P1D').function, 'DATE({col})') - self.assertEquals(d.get('Time Column').function, '{col}') - - def test_postgres_expression_time_grain(self): - uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset' - database = Database(sqlalchemy_uri=uri) - pdf, time_grain = '', 'P1D' - expression, column_name = 'COALESCE(lowercase_col, "MixedCaseCol")', '' - grain = database.grains_dict().get(time_grain) - col = database.db_engine_spec.get_timestamp_column(expression, column_name) - grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain) - grain_expr_expected = grain.function.replace('{col}', expression) - self.assertEqual(grain_expr, grain_expr_expected) - - def test_postgres_lowercase_col_time_grain(self): - uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset' - database = Database(sqlalchemy_uri=uri) - pdf, time_grain = '', 'P1D' - expression, column_name = '', 'lowercase_col' - grain = database.grains_dict().get(time_grain) - col = database.db_engine_spec.get_timestamp_column(expression, column_name) - grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain) - grain_expr_expected = grain.function.replace('{col}', column_name) - self.assertEqual(grain_expr, grain_expr_expected) - - def test_postgres_mixedcase_col_time_grain(self): - uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset' - database = Database(sqlalchemy_uri=uri) - pdf, time_grain = '', 'P1D' - expression, column_name = '', 'MixedCaseCol' - grain = database.grains_dict().get(time_grain) - col = database.db_engine_spec.get_timestamp_column(expression, column_name) - grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain) - grain_expr_expected = grain.function.replace('{col}', f'"{column_name}"') - 
self.assertEqual(grain_expr, grain_expr_expected) - def test_single_statement(self): main_db = get_main_database(db.session) @@ -217,24 +176,6 @@ def test_get_timestamp_expression_epoch(self): self.assertEquals(compiled, 'DATE(from_unixtime(DATE_ADD(ds, 1)))') ds_col.expression = prev_ds_expr - def test_get_timestamp_expression_backward(self): - tbl = self.get_table_by_name('birth_names') - ds_col = tbl.get_column('ds') - - ds_col.expression = None - ds_col.python_date_format = None - sqla_literal = ds_col.get_timestamp_expression('day') - compiled = '{}'.format(sqla_literal.compile()) - if tbl.database.backend == 'mysql': - self.assertEquals(compiled, 'DATE(ds)') - - ds_col.expression = None - ds_col.python_date_format = None - sqla_literal = ds_col.get_timestamp_expression('Time Column') - compiled = '{}'.format(sqla_literal.compile()) - if tbl.database.backend == 'mysql': - self.assertEquals(compiled, 'ds') - def query_with_expr_helper(self, is_timeseries, inner_join=True): tbl = self.get_table_by_name('birth_names') ds_col = tbl.get_column('ds') diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 40dedafaef597..a39631b832274 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -43,7 +43,9 @@ def mock_parse_human_datetime(s): - if s in ['now', 'today']: + if s == 'now': + return datetime(2016, 11, 7, 9, 30, 10) + elif s == 'today': return datetime(2016, 11, 7) elif s == 'yesterday': return datetime(2016, 11, 6) @@ -51,6 +53,8 @@ def mock_parse_human_datetime(s): return datetime(2016, 11, 8) elif s == 'Last year': return datetime(2015, 11, 7) + elif s == 'Last week': + return datetime(2015, 10, 31) elif s == 'Last 5 months': return datetime(2016, 6, 7) elif s == 'Next 5 months': @@ -600,7 +604,7 @@ def test_get_since_until(self): self.assertEqual(result, expected) result = get_since_until(' : now') - expected = None, datetime(2016, 11, 7) + expected = None, datetime(2016, 11, 7, 9, 30, 10) self.assertEqual(result, expected) result = get_since_until('yesterday : tomorrow') @@ -636,7 +640,19 @@ def test_get_since_until(self): self.assertEqual(result, expected) result = get_since_until(time_range='5 days : now') - expected = datetime(2016, 11, 2), datetime(2016, 11, 7) + expected = datetime(2016, 11, 2), datetime(2016, 11, 7, 9, 30, 10) + self.assertEqual(result, expected) + + result = get_since_until('Last week', relative_end='now') + expected = datetime(2016, 10, 31), datetime(2016, 11, 7, 9, 30, 10) + self.assertEqual(result, expected) + + result = get_since_until('Last week', relative_start='now') + expected = datetime(2016, 10, 31, 9, 30, 10), datetime(2016, 11, 7) + self.assertEqual(result, expected) + + result = get_since_until('Last week', relative_start='now', relative_end='now') + expected = datetime(2016, 10, 31, 9, 30, 10), datetime(2016, 11, 7, 9, 30, 10) self.assertEqual(result, expected) with self.assertRaises(ValueError):