Skip to content

Commit

Permalink
fix(dbt-derived-metrics): Support raw metrics and Jinja syntax (#277)
Browse files Browse the repository at this point in the history
* fix(dbt-derived-metrics): Support raw metrics and Jinja syntax
* Making pylint happy
* Fixing test coverage
* Now making pre-commit happy
  • Loading branch information
Vitor-Avila authored Apr 5, 2024
1 parent 9d904e5 commit 8207993
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/preset_cli/api/clients/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@ class MetricSchema(PostelSchema):
calculation_method = fields.String()
expression = fields.String()
dialect = fields.String()
skip_parsing = fields.Boolean(allow_none=True)


class MFMetricType(str, Enum):
Expand Down
38 changes: 35 additions & 3 deletions src/preset_cli/cli/superset/sync/dbt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@

import json
import logging
import re
from collections import defaultdict
from typing import Dict, List, Optional, Set

import sqlglot
from sqlglot import Expression, exp, parse_one
from sqlglot import Expression, ParseError, exp, parse_one
from sqlglot.expressions import (
Alias,
Case,
Expand Down Expand Up @@ -49,6 +50,7 @@
}


# pylint: disable=too-many-locals
def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) -> str:
"""
Return a SQL expression for a given dbt metric using sqlglot.
Expand Down Expand Up @@ -87,7 +89,19 @@ def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) ->
return f"COUNT(DISTINCT {sql})"

if type_ in {"expression", "derived"}:
expression = sqlglot.parse_one(sql, dialect=metric["dialect"])
try:
expression = sqlglot.parse_one(sql, dialect=metric["dialect"])
except ParseError:
for parent_metric in metric["depends_on"]:
parent_metric_name = parent_metric.split(".")[-1]
pattern = r"\b" + re.escape(parent_metric_name) + r"\b"
parent_metric_syntax = get_metric_expression(
parent_metric_name,
metrics,
)
sql = re.sub(pattern, parent_metric_syntax, sql)
return sql

tokens = expression.find_all(exp.Column)

for token in tokens:
Expand Down Expand Up @@ -192,7 +206,11 @@ def get_metric_definition(
kwargs = meta.pop("superset", {})

return {
"expression": get_metric_expression(metric_name, metric_map),
"expression": (
get_metric_expression(metric_name, metric_map)
if not metric.get("skip_parsing")
else metric.get("expression") or metric.get("sql")
),
"metric_name": metric_name,
"metric_type": (metric.get("type") or metric.get("calculation_method")),
"verbose_name": metric.get("label", metric_name),
Expand All @@ -212,6 +230,20 @@ def get_superset_metrics_per_model(
superset_metrics = defaultdict(list)
for metric in og_metrics:
metric_models = get_metric_models(metric["unique_id"], og_metrics)

# dbt supports creating derived metrics with raw syntax
if len(metric_models) == 0:
try:
metric_models.add(metric["meta"]["superset"].pop("model"))
metric["skip_parsing"] = True
except KeyError:
_logger.warning(
"Metric %s cannot be calculated because it's not associated with any model."
" Please specify the model under metric.meta.superset.model.",
metric["name"],
)
continue

if len(metric_models) != 1:
_logger.warning(
"Metric %s cannot be calculated because it depends on multiple models: %s",
Expand Down
133 changes: 132 additions & 1 deletion tests/cli/superset/sync/dbt/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Tests for metrics.
"""

# pylint: disable=line-too-long
# pylint: disable=line-too-long, too-many-lines

from typing import Dict

Expand Down Expand Up @@ -908,3 +908,134 @@ def test_get_superset_metrics_per_model() -> None:
},
],
}


def test_get_superset_metrics_per_model_og_derived(
caplog: pytest.CaptureFixture[str],
) -> None:
"""
Tests for the ``get_superset_metrics_per_model`` function
with derived OG metrics.
"""
og_metric_schema = MetricSchema()

og_metrics = [
og_metric_schema.load(
{
"name": "sales",
"unique_id": "sales",
"depends_on": ["orders"],
"calculation_method": "sum",
"expression": "1",
},
),
og_metric_schema.load(
{
"name": "derived_metric_missing_model_info",
"unique_id": "derived_metric_missing_model_info",
"depends_on": [],
"calculation_method": "derived",
"expression": "price_each * 1.2",
},
),
og_metric_schema.load(
{
"name": "derived_metric_model_from_meta",
"unique_id": "derived_metric_model_from_meta",
"depends_on": [],
"calculation_method": "derived",
"expression": "(SUM(price_each)) * 1.2",
"meta": {"superset": {"model": "customers"}},
},
),
og_metric_schema.load(
{
"name": "derived_metric_with_jinja",
"unique_id": "derived_metric_with_jinja",
"depends_on": [],
"calculation_method": "derived",
"expression": """
SUM(
{% for x in filter_values('x_values') %}
{{ + x_values }}
{% endfor %}
)
""",
"meta": {"superset": {"model": "customers"}},
},
),
og_metric_schema.load(
{
"name": "derived_metric_with_jinja_and_other_metric",
"unique_id": "derived_metric_with_jinja_and_other_metric",
"depends_on": ["sales"],
"dialect": "postgres",
"calculation_method": "derived",
"expression": """
SUM(
{% for x in filter_values('x_values') %}
{{ my_sales + sales }}
{% endfor %}
)
""",
},
),
]

result = get_superset_metrics_per_model(og_metrics, [])
output_content = caplog.text
assert (
"Metric derived_metric_missing_model_info cannot be calculated because it's not associated with any model"
in output_content
)

assert result == {
"customers": [
{
"expression": "(SUM(price_each)) * 1.2",
"metric_name": "derived_metric_model_from_meta",
"metric_type": "derived",
"verbose_name": "derived_metric_model_from_meta",
"description": "",
"extra": "{}",
},
{
"expression": """
SUM(
{% for x in filter_values('x_values') %}
{{ + x_values }}
{% endfor %}
)
""",
"metric_name": "derived_metric_with_jinja",
"metric_type": "derived",
"verbose_name": "derived_metric_with_jinja",
"description": "",
"extra": "{}",
},
],
"orders": [
{
"description": "",
"expression": "SUM(1)",
"extra": "{}",
"metric_name": "sales",
"metric_type": "sum",
"verbose_name": "sales",
},
{
"expression": """
SUM(
{% for x in filter_values('x_values') %}
{{ my_sales + SUM(1) }}
{% endfor %}
)
""",
"metric_name": "derived_metric_with_jinja_and_other_metric",
"metric_type": "derived",
"verbose_name": "derived_metric_with_jinja_and_other_metric",
"description": "",
"extra": "{}",
},
],
}

0 comments on commit 8207993

Please sign in to comment.