diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 0b150686c63b0..bc826f15a0d44 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -1015,7 +1015,10 @@ def run_query( # noqa / druid orderby=None, extras=None, # noqa columns=None, phase=2, client=None, form_data=None, - order_desc=True): + order_desc=True, + prequeries=None, + is_prequery=False, + ): """Runs a query against Druid and returns a dataframe. """ # TODO refactor into using a TBD Query object diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index e0952288ffdcb..15b8cbbe7918c 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -334,6 +334,8 @@ def get_query_str(self, query_obj): ) logging.info(sql) sql = sqlparse.format(sql, reindent=True) + if query_obj['is_prequery']: + query_obj['prequeries'].append(sql) return sql def get_sqla_table(self): @@ -369,7 +371,10 @@ def get_sqla_query( # sqla extras=None, columns=None, form_data=None, - order_desc=True): + order_desc=True, + prequeries=None, + is_prequery=False, + ): """Querying any sqla table from this common interface""" template_kwargs = { 'from_dttm': from_dttm, @@ -528,37 +533,73 @@ def get_sqla_query( # sqla if is_timeseries and \ timeseries_limit and groupby and not time_groupby_inline: - # some sql dialects require for order by expressions - # to also be in the select clause -- others, e.g. vertica, - # require a unique inner alias - inner_main_metric_expr = main_metric_expr.label('mme_inner__') - inner_select_exprs += [inner_main_metric_expr] - subq = select(inner_select_exprs) - subq = subq.select_from(tbl) - inner_time_filter = dttm_col.get_time_filter( - inner_from_dttm or from_dttm, - inner_to_dttm or to_dttm, - ) - subq = subq.where(and_(*(where_clause_and + [inner_time_filter]))) - subq = subq.group_by(*inner_groupby_exprs) - - ob = inner_main_metric_expr - if timeseries_limit_metric: - timeseries_limit_metric = metrics_dict.get(timeseries_limit_metric) - ob = timeseries_limit_metric.sqla_col - direction = desc if order_desc else asc - subq = subq.order_by(direction(ob)) - subq = subq.limit(timeseries_limit) - - on_clause = [] - for i, gb in enumerate(groupby): - on_clause.append( - groupby_exprs[i] == column(gb + '__')) - - tbl = tbl.join(subq.alias(), and_(*on_clause)) + if self.database.db_engine_spec.inner_joins: + # some sql dialects require for order by expressions + # to also be in the select clause -- others, e.g. vertica, + # require a unique inner alias + inner_main_metric_expr = main_metric_expr.label('mme_inner__') + inner_select_exprs += [inner_main_metric_expr] + subq = select(inner_select_exprs) + subq = subq.select_from(tbl) + inner_time_filter = dttm_col.get_time_filter( + inner_from_dttm or from_dttm, + inner_to_dttm or to_dttm, + ) + subq = subq.where(and_(*(where_clause_and + [inner_time_filter]))) + subq = subq.group_by(*inner_groupby_exprs) + + ob = inner_main_metric_expr + if timeseries_limit_metric: + timeseries_limit_metric = metrics_dict.get(timeseries_limit_metric) + ob = timeseries_limit_metric.sqla_col + direction = desc if order_desc else asc + subq = subq.order_by(direction(ob)) + subq = subq.limit(timeseries_limit) + + on_clause = [] + for i, gb in enumerate(groupby): + on_clause.append( + groupby_exprs[i] == column(gb + '__')) + + tbl = tbl.join(subq.alias(), and_(*on_clause)) + else: + # run subquery to get top groups + subquery_obj = { + 'prequeries': prequeries, + 'is_prequery': True, + 'is_timeseries': False, + 'row_limit': timeseries_limit, + 'groupby': groupby, + 'metrics': metrics, + 'granularity': granularity, + 'from_dttm': inner_from_dttm or from_dttm, + 'to_dttm': inner_to_dttm or to_dttm, + 'filter': filter, + 'orderby': orderby, + 'extras': extras, + 'columns': columns, + 'form_data': form_data, + 'order_desc': True, + } + result = self.query(subquery_obj) + dimensions = [c for c in result.df.columns if c not in metrics] + top_groups = self._get_top_groups(result.df, dimensions) + qry = qry.where(top_groups) return qry.select_from(tbl) + def _get_top_groups(self, df, dimensions): + cols = {col.column_name: col for col in self.columns} + groups = [] + for unused, row in df.iterrows(): + group = [] + for dimension in dimensions: + col_obj = cols.get(dimension) + group.append(col_obj.sqla_col == row[dimension]) + groups.append(and_(*group)) + + return or_(*groups) + def query(self, query_obj): qry_start_dttm = datetime.now() sql = self.get_query_str(query_obj) @@ -573,6 +614,12 @@ def query(self, query_obj): error_message = ( self.database.db_engine_spec.extract_error_message(e)) + # if this is a main query with prequeries, combine them together + if not query_obj['is_prequery']: + query_obj['prequeries'].append(sql) + sql = ';\n\n'.join(query_obj['prequeries']) + sql += ';' + return QueryResult( status=status, df=df, diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 158d9d6cfc9cc..5284a1b580b51 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -62,6 +62,7 @@ class BaseEngineSpec(object): time_groupby_inline = False limit_method = LimitMethod.FETCH_MANY time_secondary_columns = False + inner_joins = True @classmethod def fetch_data(cls, cursor, limit): @@ -1221,6 +1222,7 @@ class DruidEngineSpec(BaseEngineSpec): """Engine spec for Druid.io""" engine = 'druid' limit_method = LimitMethod.FETCH_MANY + inner_joins = False engines = { diff --git a/superset/views/core.py b/superset/views/core.py index 00254b4ca27ba..30d647d112fd4 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -989,6 +989,12 @@ def get_query_string_response(self, viz_obj): query = viz_obj.datasource.get_query_str(query_obj) except Exception as e: return json_error_response(e) + + if query_obj['prequeries']: + query_obj['prequeries'].append(query) + query = ';\n\n'.join(query_obj['prequeries']) + query += ';' + return Response( json.dumps({ 'query': query, diff --git a/superset/viz.py b/superset/viz.py index 6551577de15c4..2b7df0e900fa1 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -200,6 +200,8 @@ def query_obj(self): 'timeseries_limit_metric': timeseries_limit_metric, 'form_data': form_data, 'order_desc': order_desc, + 'prequeries': [], + 'is_prequery': False, } return d