Skip to content

Commit

Permalink
fix: add get_column function for Query obj (#21691)
Browse files Browse the repository at this point in the history
  • Loading branch information
hughhhh authored Oct 5, 2022
1 parent 7b66e0b commit 51c54b3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
3 changes: 2 additions & 1 deletion superset/common/query_context_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ def _get_timestamp_format(
# Query datasource didn't support `get_column`
and hasattr(datasource, "get_column")
and (col := datasource.get_column(label))
and col.is_dttm
# todo(hugh) standardize column object in Query datasource
and (col.get("is_dttm") if isinstance(col, dict) else col.is_dttm)
)
dttm_cols = [
DateColumn(
Expand Down
15 changes: 11 additions & 4 deletions superset/models/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
)
from superset.sql_parse import CtasMethod, ParsedQuery, Table
from superset.sqllab.limiting_factor import LimitingFactor
from superset.superset_typing import ResultSetColumnType
from superset.utils.core import GenericDataType, QueryStatus, user_label

if TYPE_CHECKING:
Expand Down Expand Up @@ -183,7 +182,7 @@ def sql_tables(self) -> List[Table]:
return list(ParsedQuery(self.sql).tables)

@property
def columns(self) -> List[ResultSetColumnType]:
def columns(self) -> List[Dict[str, Any]]:
bool_types = ("BOOL",)
num_types = (
"DOUBLE",
Expand Down Expand Up @@ -217,7 +216,7 @@ def columns(self) -> List[ResultSetColumnType]:
computed_column["column_name"] = col.get("name")
computed_column["groupby"] = True
columns.append(computed_column)
return columns # type: ignore
return columns

@property
def data(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -288,7 +287,7 @@ def offset(self) -> int:
def main_dttm_col(self) -> Optional[str]:
for col in self.columns:
if col.get("is_dttm"):
return col.get("column_name") # type: ignore
return col.get("column_name")
return None

@property
Expand Down Expand Up @@ -332,6 +331,14 @@ def tracking_url(self) -> Optional[str]:
def tracking_url(self, value: str) -> None:
self.tracking_url_raw = value

def get_column(self, column_name: Optional[str]) -> Optional[Dict[str, Any]]:
if not column_name:
return None
for col in self.columns:
if col.get("column_name") == column_name:
return col
return None


class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
"""ORM model for SQL query"""
Expand Down

0 comments on commit 51c54b3

Please sign in to comment.