Skip to content

Commit

Permalink
fix: Handling of column types for Presto, Trino, et al. (apache#28653)
Browse files (browse the repository at this point in the history)
  • Loading branch information
john-bodley authored and michael-s-molina committed May 30, 2024
1 parent f5224f9 commit 0f27d61
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 34 deletions.
6 changes: 5 additions & 1 deletion superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]:
pattern = re.compile(r"(?P<type>\w+)\((?P<children>.*)\)")
if not column["type"]:
raise ValueError
match = pattern.match(column["type"])
match = pattern.match(cast(str, column["type"]))
if not match:
raise Exception( # pylint: disable=broad-exception-raised
f"Unable to parse column type {column['type']}"
Expand Down Expand Up @@ -509,6 +509,10 @@ def where_latest_partition( # pylint: disable=too-many-arguments
for col_name, value in zip(col_names, values):
col_type = column_type_by_name.get(col_name)

if isinstance(col_type, str):
col_type_class = getattr(types, col_type, None)
col_type = col_type_class() if col_type_class else None

if isinstance(col_type, types.DATE):
col_type = Date()
elif isinstance(col_type, types.TIMESTAMP):
Expand Down
5 changes: 4 additions & 1 deletion superset/superset_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
from datetime import datetime
from typing import Any, Literal, Optional, TYPE_CHECKING, TypedDict, Union

from sqlalchemy.sql.type_api import TypeEngine
from typing_extensions import NotRequired
from werkzeug.wrappers import Response

if TYPE_CHECKING:
from superset.utils.core import GenericDataType

SQLType = Union[TypeEngine, type[TypeEngine]]


class LegacyMetric(TypedDict):
label: Optional[str]
Expand Down Expand Up @@ -73,7 +76,7 @@ class ResultSetColumnType(TypedDict):

name: str # legacy naming convention keeping this for backwards compatibility
column_name: str
type: Optional[str]
type: Optional[Union[SQLType, str]]
is_dttm: Optional[bool]
type_generic: NotRequired[Optional["GenericDataType"]]

Expand Down
59 changes: 30 additions & 29 deletions tests/unit_tests/db_engine_specs/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from sqlalchemy import sql, text, types
from sqlalchemy.engine.url import make_url

from superset.superset_typing import ResultSetColumnType
from superset.sql_parse import Table
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
Expand Down Expand Up @@ -115,39 +115,40 @@ def test_get_schema_from_engine_params() -> None:
@pytest.mark.parametrize(
["column_type", "column_value", "expected_value"],
[
(types.DATE(), "2023-05-01", "DATE '2023-05-01'"),
(types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"),
(types.VARCHAR(), "2023-05-01", "'2023-05-01'"),
(types.INT(), 1234, "1234"),
("DATE", "2023-05-01", "DATE '2023-05-01'"),
("TIMESTAMP", "2023-05-01", "TIMESTAMP '2023-05-01'"),
("VARCHAR", "2023-05-01", "'2023-05-01'"),
("INT", 1234, "1234"),
],
)
def test_where_latest_partition(
mock_latest_partition, column_type, column_value: Any, expected_value: str
mock_latest_partition,
column_type: str,
column_value: Any,
expected_value: str,
) -> None:
"""
Test the ``where_latest_partition`` method
"""
from superset.db_engine_specs.presto import PrestoEngineSpec as spec
from superset.db_engine_specs.presto import PrestoEngineSpec

mock_latest_partition.return_value = (["partition_key"], [column_value])

query = sql.select(text("* FROM table"))
columns: list[ResultSetColumnType] = [
{
"column_name": "partition_key",
"name": "partition_key",
"type": column_type,
"is_dttm": False,
}
]

expected = f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}"""
result = spec.where_latest_partition(
"table", mock.MagicMock(), mock.MagicMock(), query, columns
)
assert result is not None
actual = result.compile(
dialect=PrestoDialect(), compile_kwargs={"literal_binds": True}
assert (
str(
PrestoEngineSpec.where_latest_partition( # type: ignore
database=mock.MagicMock(),
table=Table("table"),
query=sql.select(text("* FROM table")),
columns=[
{
"column_name": "partition_key",
"name": "partition_key",
"type": column_type,
"is_dttm": False,
}
],
).compile(
dialect=PrestoDialect(),
compile_kwargs={"literal_binds": True},
)
)
== f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}"""
)

assert str(actual) == expected
51 changes: 48 additions & 3 deletions tests/unit_tests/db_engine_specs/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@
import json
from datetime import datetime
from typing import Any, Optional
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import pandas as pd
import pytest
from pytest_mock import MockerFixture
from requests.exceptions import ConnectionError as RequestsConnectionError
from sqlalchemy import types
from sqlalchemy import sql, text, types
from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError
from trino.sqlalchemy import datatype
from trino.sqlalchemy.dialect import TrinoDialect

import superset.config
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
Expand All @@ -37,7 +38,8 @@
SupersetDBAPIOperationalError,
SupersetDBAPIProgrammingError,
)
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
Expand Down Expand Up @@ -548,3 +550,46 @@ def test_get_dbapi_exception_mapping():
assert mapping.get(TrinoExternalError) == SupersetDBAPIOperationalError
assert mapping.get(RequestsConnectionError) == SupersetDBAPIConnectionError
assert mapping.get(Exception) is None


@patch("superset.db_engine_specs.trino.TrinoEngineSpec.latest_partition")
@pytest.mark.parametrize(
    ["column_type", "column_value", "expected_value"],
    [
        (types.DATE(), "2023-05-01", "DATE '2023-05-01'"),
        (types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"),
        (types.VARCHAR(), "2023-05-01", "'2023-05-01'"),
        (types.INT(), 1234, "1234"),
    ],
)
def test_where_latest_partition(
    mock_latest_partition,
    column_type: SQLType,
    column_value: Any,
    expected_value: str,
) -> None:
    """
    Test ``TrinoEngineSpec.where_latest_partition`` with SQLAlchemy column types.

    ``latest_partition`` is patched to report a single ``partition_key``
    column holding ``column_value``. The query returned by
    ``where_latest_partition`` is compiled with the Trino dialect using literal
    binds, and its ``WHERE`` clause must render the value as ``expected_value``.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    mock_latest_partition.return_value = (["partition_key"], [column_value])

    # The only column handed to the engine spec; its ``type`` is the
    # parametrized SQLAlchemy type instance.
    partition_column: ResultSetColumnType = {
        "column_name": "partition_key",
        "name": "partition_key",
        "type": column_type,
        "is_dttm": False,
    }

    filtered_query = TrinoEngineSpec.where_latest_partition(  # type: ignore
        database=MagicMock(),
        table=Table("table"),
        query=sql.select(text("* FROM table")),
        columns=[partition_column],
    )

    # Render with inlined literals so the comparison sees the final SQL text
    # rather than bind-parameter placeholders.
    rendered = filtered_query.compile(  # type: ignore
        dialect=TrinoDialect(),
        compile_kwargs={"literal_binds": True},
    )

    assert (
        str(rendered)
        == f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""
    )

0 comments on commit 0f27d61

Please sign in to comment.