diff --git a/superset/jinja_context.py b/superset/jinja_context.py index c159a667ee4ed..13a639df7bd78 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -17,9 +17,11 @@ """Defines the templating context for SQL Lab""" import json import re +from datetime import datetime from functools import lru_cache, partial from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypedDict, Union +import dateutil from flask import current_app, g, has_request_context, request from flask_babel import gettext as _ from jinja2 import DebugUndefined @@ -486,6 +488,19 @@ def process_template(self, sql: str, **kwargs: Any) -> str: class JinjaTemplateProcessor(BaseTemplateProcessor): + def _parse_datetime(self, dttm: str) -> Optional[datetime]: + """ + Try to parse a datetime and default to None in the worst case. + + Since this may have been rendered by different engines, the datetime may + vary slightly in format. We try to make it consistent, and if all else + fails, just return None. + """ + try: + return dateutil.parser.parse(dttm) + except dateutil.parser.ParserError: + return None + def set_context(self, **kwargs: Any) -> None: super().set_context(**kwargs) extra_cache = ExtraCache( @@ -494,6 +509,23 @@ def set_context(self, **kwargs: Any) -> None: removed_filters=self._removed_filters, dialect=self._database.get_dialect(), ) + + from_dttm = ( + self._parse_datetime(dttm) + if (dttm := self._context.get("from_dttm")) + else None + ) + to_dttm = ( + self._parse_datetime(dttm) + if (dttm := self._context.get("to_dttm")) + else None + ) + + dataset_macro_with_context = partial( + dataset_macro, + from_dttm=from_dttm, + to_dttm=to_dttm, + ) self._context.update( { "url_param": partial(safe_proxy, extra_cache.url_param), @@ -502,7 +534,7 @@ def set_context(self, **kwargs: Any) -> None: "cache_key_wrapper": partial(safe_proxy, extra_cache.cache_key_wrapper), "filter_values": partial(safe_proxy, extra_cache.filter_values), "get_filters": partial(safe_proxy, extra_cache.get_filters), - "dataset": partial(safe_proxy, dataset_macro), + "dataset": partial(safe_proxy, dataset_macro_with_context), } ) @@ -638,12 +670,18 @@ def dataset_macro( dataset_id: int, include_metrics: bool = False, columns: Optional[list[str]] = None, + from_dttm: Optional[datetime] = None, + to_dttm: Optional[datetime] = None, ) -> str: """ Given a dataset ID, return the SQL that represents it. The generated SQL includes all columns (including computed) by default. Optionally the user can also request metrics to be included, and columns to group by. + + The from_dttm and to_dttm parameters are filled in from filter values in explore + views, and we take them to make those properties available to jinja templates in + the underlying dataset. """ # pylint: disable=import-outside-toplevel from superset.daos.dataset import DatasetDAO @@ -659,6 +697,8 @@ def dataset_macro( "filter": [], "metrics": metrics if include_metrics else None, "columns": columns, + "from_dttm": from_dttm, + "to_dttm": to_dttm, } sqla_query = dataset.get_query_str_extended(query_obj, mutate=False) sql = sqla_query.sql