diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 69f10c75f596b..cc85a92b7fae4 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -50,6 +50,13 @@ def __init__(self, name, field_names, function): self.name = name +class CustomPostAggregator(Postaggregator): + """A way to allow users to specify completely custom PostAggregators""" + def __init__(self, name, post_aggregator): + self.name = name + self.post_aggregator = post_aggregator + + class DruidCluster(Model, AuditMixinNullable): """ORM object referencing the Druid clusters""" @@ -690,6 +697,75 @@ def granularity(period_name, timezone=None, origin=None): period_name).total_seconds() * 1000 return granularity + @staticmethod + def _metrics_and_post_aggs(metrics, metrics_dict): + all_metrics = [] + post_aggs = {} + + def recursive_get_fields(_conf): + _type = _conf.get('type') + _field = _conf.get('field') + _fields = _conf.get('fields') + + field_names = [] + if _type in ['fieldAccess', 'hyperUniqueCardinality', + 'quantile', 'quantiles']: + field_names.append(_conf.get('fieldName', '')) + + if _field: + field_names += recursive_get_fields(_field) + + if _fields: + for _f in _fields: + field_names += recursive_get_fields(_f) + + return list(set(field_names)) + + for metric_name in metrics: + metric = metrics_dict[metric_name] + if metric.metric_type != 'postagg': + all_metrics.append(metric_name) + else: + mconf = metric.json_obj + all_metrics += recursive_get_fields(mconf) + all_metrics += mconf.get('fieldNames', []) + if mconf.get('type') == 'javascript': + post_aggs[metric_name] = JavascriptPostAggregator( + name=mconf.get('name', ''), + field_names=mconf.get('fieldNames', []), + function=mconf.get('function', '')) + elif mconf.get('type') == 'quantile': + post_aggs[metric_name] = Quantile( + mconf.get('name', ''), + mconf.get('probability', ''), + ) + elif mconf.get('type') == 'quantiles': + post_aggs[metric_name] = Quantiles( + mconf.get('name', ''), + mconf.get('probabilities', ''), + ) + elif mconf.get('type') == 'fieldAccess': + post_aggs[metric_name] = Field(mconf.get('name')) + elif mconf.get('type') == 'constant': + post_aggs[metric_name] = Const( + mconf.get('value'), + output_name=mconf.get('name', '') + ) + elif mconf.get('type') == 'hyperUniqueCardinality': + post_aggs[metric_name] = HyperUniqueCardinality( + mconf.get('name') + ) + elif mconf.get('type') == 'arithmetic': + post_aggs[metric_name] = Postaggregator( + mconf.get('fn', "/"), + mconf.get('fields', []), + mconf.get('name', '')) + else: + post_aggs[metric_name] = CustomPostAggregator( + mconf.get('name', ''), + mconf) + return all_metrics, post_aggs + def values_for_column(self, column_name, limit=10000): @@ -749,61 +825,10 @@ def run_query( # noqa / druid query_str = "" metrics_dict = {m.metric_name: m for m in self.metrics} - all_metrics = [] - post_aggs = {} columns_dict = {c.column_name: c for c in self.columns} - def recursive_get_fields(_conf): - _fields = _conf.get('fields', []) - field_names = [] - for _f in _fields: - _type = _f.get('type') - if _type in ['fieldAccess', 'hyperUniqueCardinality']: - field_names.append(_f.get('fieldName')) - elif _type == 'arithmetic': - field_names += recursive_get_fields(_f) - return list(set(field_names)) - - for metric_name in metrics: - metric = metrics_dict[metric_name] - if metric.metric_type != 'postagg': - all_metrics.append(metric_name) - else: - mconf = metric.json_obj - all_metrics += recursive_get_fields(mconf) - all_metrics += mconf.get('fieldNames', []) - if mconf.get('type') == 'javascript': - post_aggs[metric_name] = JavascriptPostAggregator( - name=mconf.get('name', ''), - field_names=mconf.get('fieldNames', []), - function=mconf.get('function', '')) - elif mconf.get('type') == 'quantile': - post_aggs[metric_name] = Quantile( - mconf.get('name', ''), - mconf.get('probability', ''), - ) - elif mconf.get('type') == 'quantiles': - post_aggs[metric_name] = Quantiles( - mconf.get('name', ''), - mconf.get('probabilities', ''), - ) - elif mconf.get('type') == 'fieldAccess': - post_aggs[metric_name] = Field(mconf.get('name')) - elif mconf.get('type') == 'constant': - post_aggs[metric_name] = Const( - mconf.get('value'), - output_name=mconf.get('name', '') - ) - elif mconf.get('type') == 'hyperUniqueCardinality': - post_aggs[metric_name] = HyperUniqueCardinality( - mconf.get('name') - ) - else: - post_aggs[metric_name] = Postaggregator( - mconf.get('fn', "/"), - mconf.get('fields', []), - mconf.get('name', '')) + all_metrics, post_aggs = self._metrics_and_post_aggs(metrics, metrics_dict) aggregations = OrderedDict() for m in self.metrics: diff --git a/tests/druid_tests.py b/tests/druid_tests.py index d7b93dee0638e..637afe984ce02 100644 --- a/tests/druid_tests.py +++ b/tests/druid_tests.py @@ -11,8 +11,8 @@ from mock import Mock, patch from superset import db, sm, security -from superset.connectors.druid.models import DruidCluster, DruidDatasource -from superset.connectors.druid.models import PyDruid +from superset.connectors.druid.models import DruidMetric, DruidCluster, DruidDatasource +from superset.connectors.druid.models import PyDruid, Quantile, Postaggregator from .base_tests import SupersetTestCase @@ -38,7 +38,7 @@ "metric1": { "type": "longSum", "name": "metric1", - "fieldName": "metric1"} + "fieldName": "metric1"}, }, "size": 300000, "numRows": 5000000 @@ -318,6 +318,77 @@ def test_sync_druid_perm(self, PyDruid): permission=permission, view_menu=view_menu).first() assert pv is not None + def test_metrics_and_post_aggs(self): + """ + Test generation of metrics and post-aggregations from an initial list + of superset metrics (which may include the results of either). This + primarily tests that specifying a post-aggregator metric will also + require the raw aggregation of the associated druid metric column. + """ + metrics_dict = { + 'unused_count': DruidMetric( + metric_name='unused_count', + verbose_name='COUNT(*)', + metric_type='count', + json=json.dumps({'type': 'count', 'name': 'unused_count'})), + 'some_sum': DruidMetric( + metric_name='some_sum', + verbose_name='SUM(*)', + metric_type='sum', + json=json.dumps({'type': 'sum', 'name': 'sum'})), + 'a_histogram': DruidMetric( + metric_name='a_histogram', + verbose_name='APPROXIMATE_HISTOGRAM(*)', + metric_type='approxHistogramFold', + json=json.dumps({'type': 'approxHistogramFold', 'name': 'a_histogram'})), + 'aCustomMetric': DruidMetric( + metric_name='aCustomMetric', + verbose_name='MY_AWESOME_METRIC(*)', + metric_type='aCustomType', + json=json.dumps({'type': 'customMetric', 'name': 'aCustomMetric'})), + 'quantile_p95': DruidMetric( + metric_name='quantile_p95', + verbose_name='P95(*)', + metric_type='postagg', + json=json.dumps({ + 'type': 'quantile', + 'probability': 0.95, + 'name': 'p95', + 'fieldName': 'a_histogram'})), + 'aCustomPostAgg': DruidMetric( + metric_name='aCustomPostAgg', + verbose_name='CUSTOM_POST_AGG(*)', + metric_type='postagg', + json=json.dumps({ + 'type': 'customPostAgg', + 'name': 'aCustomPostAgg', + 'field': { + 'type': 'fieldAccess', + 'fieldName': 'aCustomMetric'}})), + } + + metrics = ['some_sum'] + all_metrics, post_aggs = DruidDatasource._metrics_and_post_aggs( + metrics, metrics_dict) + + assert all_metrics == ['some_sum'] + assert post_aggs == {} + + metrics = ['quantile_p95'] + all_metrics, post_aggs = DruidDatasource._metrics_and_post_aggs( + metrics, metrics_dict) + + result_postaggs = set(['quantile_p95']) + assert all_metrics == ['a_histogram'] + assert set(post_aggs.keys()) == result_postaggs + + metrics = ['aCustomPostAgg'] + all_metrics, post_aggs = DruidDatasource._metrics_and_post_aggs( + metrics, metrics_dict) + + result_postaggs = set(['aCustomPostAgg']) + assert all_metrics == ['aCustomMetric'] + assert set(post_aggs.keys()) == result_postaggs if __name__ == '__main__':