Skip to content

Commit

Permalink
fix: Convert datetime object to a SQL expression string
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley committed May 16, 2023
1 parent 6baa552 commit 3e36553
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 57 deletions.
45 changes: 0 additions & 45 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,21 +330,6 @@ def get_sqla_col(
def datasource(self) -> RelationshipProperty:
return self.table

def get_time_filter(
self,
start_dttm: Optional[DateTime] = None,
end_dttm: Optional[DateTime] = None,
label: Optional[str] = "__time",
template_processor: Optional[BaseTemplateProcessor] = None,
) -> ColumnElement:
col = self.get_sqla_col(label=label, template_processor=template_processor)
l = []
if start_dttm:
l.append(col >= self.table.text(self.dttm_sql_literal(start_dttm)))
if end_dttm:
l.append(col < self.table.text(self.dttm_sql_literal(end_dttm)))
return and_(*l)

def get_timestamp_expression(
self,
time_grain: Optional[str],
Expand Down Expand Up @@ -379,36 +364,6 @@ def get_timestamp_expression(
time_expr = self.db_engine_spec.get_timestamp_expr(col, pdf, time_grain)
return self.table.make_sqla_column_compatible(time_expr, label)

def dttm_sql_literal(self, dttm: DateTime) -> str:
"""Convert datetime object to a SQL expression string"""
sql = (
self.db_engine_spec.convert_dttm(self.type, dttm, db_extra=self.db_extra)
if self.type
else None
)

if sql:
return sql

tf = self.python_date_format

# Fallback to the default format (if defined).
if not tf:
tf = self.db_extra.get("python_date_format_by_column_name", {}).get(
self.column_name
)

if tf:
if tf in ["epoch_ms", "epoch_s"]:
seconds_since_epoch = int(dttm.timestamp())
if tf == "epoch_s":
return str(seconds_since_epoch)
return str(seconds_since_epoch * 1000)
return f"'{dttm.strftime(tf)}'"

# TODO(john-bodley): SIP-15 will explicitly require a type conversion.
return f"""'{dttm.strftime("%Y-%m-%d %H:%M:%S.%f")}'"""

@property
def data(self) -> Dict[str, Any]:
attrs = (
Expand Down
15 changes: 5 additions & 10 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1269,17 +1269,12 @@ def _get_top_groups(

return or_(*groups)

def dttm_sql_literal(
self,
col: "TableColumn",
dttm: sa.DateTime,
col_type: Optional[str],
) -> str:
def dttm_sql_literal(self, dttm: datetime, col: "TableColumn") -> str:
"""Convert datetime object to a SQL expression string"""

sql = (
self.db_engine_spec.convert_dttm(col_type, dttm, db_extra=None)
if col_type
self.db_engine_spec.convert_dttm(col.type, dttm, db_extra=self.db_extra)
if col.type
else None
)

Expand Down Expand Up @@ -1330,14 +1325,14 @@ def get_time_filter( # pylint: disable=too-many-arguments
l.append(
col
>= self.db_engine_spec.get_text_clause(
self.dttm_sql_literal(time_col, start_dttm, time_col.type)
self.dttm_sql_literal(start_dttm, time_col)
)
)
if end_dttm:
l.append(
col
< self.db_engine_spec.get_text_clause(
self.dttm_sql_literal(time_col, end_dttm, time_col.type)
self.dttm_sql_literal(end_dttm, time_col)
)
)
return and_(*l)
Expand Down
66 changes: 65 additions & 1 deletion tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@
# under the License.

# pylint: disable=import-outside-toplevel

import json
from datetime import datetime
from typing import List, Optional

import pytest
from pytest_mock import MockFixture
from sqlalchemy.engine.reflection import Inspector

from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.models.core import Database


def test_get_metrics(mocker: MockFixture) -> None:
"""
Expand Down Expand Up @@ -143,3 +148,62 @@ class OldDBEngineSpec(BaseEngineSpec):
).db_engine_spec
== OldDBEngineSpec
)


@pytest.mark.parametrize(
"dttm,col,database,result",
[
(
datetime(2023, 1, 1, 1, 23, 45, 600000),
TableColumn(python_date_format="epoch_s"),
Database(),
"1672536225",
),
(
datetime(2023, 1, 1, 1, 23, 45, 600000),
TableColumn(python_date_format="epoch_ms"),
Database(),
"1672536225000",
),
(
datetime(2023, 1, 1, 1, 23, 45, 600000),
TableColumn(python_date_format="%Y-%m-%d"),
Database(),
"'2023-01-01'",
),
(
datetime(2023, 1, 1, 1, 23, 45, 600000),
TableColumn(column_name="ds"),
Database(
extra=json.dumps(
{
"python_date_format_by_column_name": {
"ds": "%Y-%m-%d",
},
},
),
sqlalchemy_uri="foo://",
),
"'2023-01-01'",
),
(
datetime(2023, 1, 1, 1, 23, 45, 600000),
TableColumn(),
Database(sqlalchemy_uri="foo://"),
"'2023-01-01 01:23:45.600000'",
),
(
datetime(2023, 1, 1, 1, 23, 45, 600000),
TableColumn(type="TimeStamp"),
Database(sqlalchemy_uri="trino://"),
"TIMESTAMP '2023-01-01 01:23:45.600000'",
),
],
)
def test_dttm_sql_literal(
dttm: datetime,
col: TableColumn,
database: Database,
result: str,
) -> None:
assert SqlaTable(database=database).dttm_sql_literal(dttm, col) == result
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# Remember to start celery workers to run celery tests, e.g.
# celery --app=superset.tasks.celery_app:app worker -Ofair -c 2
[testenv]
basepython = python3.8
basepython = python3.10
ignore_basepython_conflict = true
commands =
superset db upgrade
Expand Down

0 comments on commit 3e36553

Please sign in to comment.