Skip to content

Commit

Permalink
fix: Remove BASE_AXIS from pre-query (#29084)
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley authored and eschutho committed Jul 24, 2024
1 parent cece3bf commit 9163886
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 17 deletions.
Binary file added null_byte.csv
Binary file not shown.
10 changes: 5 additions & 5 deletions superset/common/query_context_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
get_column_names_from_columns,
get_column_names_from_metrics,
get_metric_names,
get_xaxis_label,
get_x_axis_label,
normalize_dttm_col,
TIME_COMPARISON,
)
Expand Down Expand Up @@ -399,7 +399,7 @@ def processing_time_offsets( # pylint: disable=too-many-locals,too-many-stateme
for offset in query_object.time_offsets:
try:
# pylint: disable=line-too-long
# Since the xaxis is also a column name for the time filter, xaxis_label will be set as granularity
# Since the x-axis is also a column name for the time filter, x_axis_label will be set as granularity
# these query object are equivalent:
# 1) { granularity: 'dttm_col', time_range: '2020 : 2021', time_offsets: ['1 year ago']}
# 2) { columns: [
Expand All @@ -414,9 +414,9 @@ def processing_time_offsets( # pylint: disable=too-many-locals,too-many-stateme
)
query_object_clone.to_dttm = get_past_or_future(offset, outer_to_dttm)

xaxis_label = get_xaxis_label(query_object.columns)
x_axis_label = get_x_axis_label(query_object.columns)
query_object_clone.granularity = (
query_object_clone.granularity or xaxis_label
query_object_clone.granularity or x_axis_label
)
except ValueError as ex:
raise QueryObjectValidationError(str(ex)) from ex
Expand Down Expand Up @@ -450,7 +450,7 @@ def processing_time_offsets( # pylint: disable=too-many-locals,too-many-stateme
query_object_clone.filter = [
flt
for flt in query_object_clone.filter
if flt.get("col") != xaxis_label
if flt.get("col") != x_axis_label
]

# `offset` is added to the hash function
Expand Down
6 changes: 3 additions & 3 deletions superset/common/query_object_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
DatasourceDict,
DatasourceType,
FilterOperator,
get_xaxis_label,
get_x_axis_label,
QueryObjectFilterClause,
)

Expand Down Expand Up @@ -122,9 +122,9 @@ def _process_time_range(
# Use the temporal filter as the time range.
# if the temporal filters uses x-axis as the temporal filter
# then use it or use the first temporal filter
xaxis_label = get_xaxis_label(columns or [])
x_axis_label = get_x_axis_label(columns)
match_flt = [
flt for flt in temporal_flt if flt.get("col") == xaxis_label
flt for flt in temporal_flt if flt.get("col") == x_axis_label
]
if match_flt:
time_range = cast(str, match_flt[0].get("val"))
Expand Down
2 changes: 1 addition & 1 deletion superset/common/utils/time_range_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_since_until_from_query_object(
"""
this function will return since and until by tuple if
1) the time_range is in the query object.
2) the xaxis column is in the columns field
2) the x-axis column is in the columns field
and its corresponding `temporal_range` filter is in the adhoc filters.
:param query_object: a valid query object
:return: since and until by tuple
Expand Down
3 changes: 2 additions & 1 deletion superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
from superset.utils.core import (
GenericDataType,
get_column_name,
get_non_base_axis_columns,
get_user_id,
is_adhoc_column,
MediumText,
Expand Down Expand Up @@ -2083,7 +2084,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
"filter": filter,
"orderby": orderby,
"extras": extras,
"columns": columns,
"columns": get_non_base_axis_columns(columns),
"order_desc": True,
}

Expand Down
21 changes: 14 additions & 7 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,16 +1056,23 @@ def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]:
)


def is_base_axis(column: Column) -> bool:
return is_adhoc_column(column) and column.get("columnType") == "BASE_AXIS"


def get_base_axis_columns(columns: list[Column] | None) -> list[Column]:
return [column for column in columns or [] if is_base_axis(column)]


def get_non_base_axis_columns(columns: list[Column] | None) -> list[Column]:
return [column for column in columns or [] if not is_base_axis(column)]


def get_base_axis_labels(columns: list[Column] | None) -> tuple[str, ...]:
axis_cols = [
col
for col in columns or []
if is_adhoc_column(col) and col.get("columnType") == "BASE_AXIS"
]
return tuple(get_column_name(col) for col in axis_cols)
return tuple(get_column_name(column) for column in get_base_axis_columns(columns))


def get_xaxis_label(columns: list[Column] | None) -> str | None:
def get_x_axis_label(columns: list[Column] | None) -> str | None:
labels = get_base_axis_labels(columns)
return labels[0] if labels else None

Expand Down

0 comments on commit 9163886

Please sign in to comment.