Skip to content

Commit

Permalink
Make removal of unbound expressions from GROUP BY optional based on d…
Browse files Browse the repository at this point in the history
…ialect (#290)
  • Loading branch information
KonstantAnxiety authored and github-actions[bot] committed Feb 7, 2024
1 parent 259d2f9 commit e13ff23
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 8 deletions.
5 changes: 5 additions & 0 deletions lib/dl_connector_ydb/dl_connector_ydb/formula/connector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ydb.sqlalchemy import YqlDialect as SAYqlDialect

from dl_formula.connectors.base.connector import FormulaConnector
from dl_query_processing.compilation.query_mutator import RemoveConstFromGroupByFormulaAtomicQueryMutator

from dl_connector_ydb.formula.constants import YqlDialect as YqlDialectNS
from dl_connector_ydb.formula.definitions.all import DEFINITIONS
Expand All @@ -11,3 +12,7 @@ class YQLFormulaConnector(FormulaConnector):
dialects = YqlDialectNS.YQL
op_definitions = DEFINITIONS
sa_dialect = SAYqlDialect()

@classmethod
def registration_hook(cls) -> None:
    """Opt this connector's dialects into GROUP BY constant removal.

    Passes ``cls.dialects`` to
    ``RemoveConstFromGroupByFormulaAtomicQueryMutator.register_dialect``.
    NOTE(review): presumably called once while the connector is being
    registered -- confirm against ``FormulaConnector``.
    """
    register = RemoveConstFromGroupByFormulaAtomicQueryMutator.register_dialect
    register(cls.dialects)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import abc
from typing import (
Callable,
ClassVar,
List,
Sequence,
Set,
Expand Down Expand Up @@ -193,6 +194,36 @@ def mutate_formula_list(
return new_formula_list


@attr.s
class RemoveConstFromGroupByFormulaAtomicQueryMutator(AtomicQueryFormulaListMutatorBase):
    """Strip unbound expressions from GROUP BY, but only for opted-in dialects.

    Unbound expressions are those for which
    ``is_bound_only_to(formula_obj, NodeSet())`` is true, i.e. ones that
    don't refer to source fields. A dialect opts in through
    ``register_dialect``; for any other dialect, and for every query part
    other than GROUP BY, the formula list is passed through untouched.
    """

    # Class-wide registry of dialects for which the removal is enabled.
    # Mutated in place by ``register_dialect``.
    _applicable_dialects: ClassVar[set[DialectCombo]] = set()

    # Dialect that this mutator instance checks against the registry.
    _dialects: DialectCombo = attr.ib()

    def match_query(self, compiled_query: CompiledQuery) -> bool:
        """Return ``True`` unconditionally: every query is considered."""
        return True  # Apply to all

    def mutate_formula_list(
        self,
        formula_list: List[_COMPILED_FLA_TV],
        query_part: QueryPart,
    ) -> List[_COMPILED_FLA_TV]:
        """Return ``formula_list`` without its unbound items when applicable.

        The input list object itself is returned when ``query_part`` is not
        GROUP BY or when this instance's dialect has not been registered;
        otherwise a new list is built and the input is left unmodified.
        """
        # Only the GROUP BY section is ever filtered.
        if query_part != QueryPart.group_by:
            return formula_list
        # The current dialect must have been registered explicitly.
        if self._dialects not in self._applicable_dialects:
            return formula_list

        bound_items = []
        for candidate in formula_list:
            # Bound to nothing at all => unbound expression, skip it.
            if is_bound_only_to(candidate.formula_obj, NodeSet()):
                continue
            bound_items.append(candidate)
        return bound_items

    @classmethod
    def register_dialect(cls, dialects: DialectCombo) -> None:
        """Enable the removal for every dialect in ``dialects.to_list(with_self=True)``."""
        cls._applicable_dialects.update(dialects.to_list(with_self=True))


def contains_inconsistent_aggregations(
node: formula_nodes.FormulaItem,
dimensions: list[formula_nodes.FormulaItem],
Expand Down Expand Up @@ -271,14 +302,8 @@ def mutate_query(self, compiled_query: CompiledQuery) -> CompiledQuery:
)
compiled_query = mutator.mutate_query(compiled_query)

# Don't group by unbound expressions (ones that don't refer to source fields)
compiled_query = compiled_query.clone(
group_by=[
group_by_item
for group_by_item in compiled_query.group_by
if not is_bound_only_to(group_by_item.formula_obj, NodeSet())
]
)
group_by_xonst_mutator = RemoveConstFromGroupByFormulaAtomicQueryMutator(self._dialect)
compiled_query = group_by_xonst_mutator.mutate_query(compiled_query)

filter_mutator = IgnoreFormulaAtomicQueryMutator(ignore_formula_checks=[formula_is_true])
compiled_query = filter_mutator.mutate_query(compiled_query)
Expand Down

0 comments on commit e13ff23

Please sign in to comment.