Skip to content

Commit

Permalink
Refs #28333 -- Added partial support for filtering against window fun…
Browse files Browse the repository at this point in the history
…ctions.

Adds support for joint predicates against window annotations through
subquery wrapping while maintaining errors for disjointed filter
attempts.

The "qualify" wording was used to refer to predicates against window
annotations as it's the name of a specialized Snowflake extension to
SQL that is to window functions what HAVING is to aggregates.

While not complete, the implementation should cover most of the common
use cases for filtering against window functions without requiring
the complex subquery pushdown and predicate re-aliasing machinery to
deal with disjointed predicates against columns, aggregates, and window
functions.

A complete disjointed filtering implementation should likely be
deferred until proper QUALIFY support lands or the ORM gains a proper
subquery pushdown interface.
  • Loading branch information
charettes authored and felixxm committed Aug 15, 2022
1 parent f3f9d03 commit f387d02
Show file tree
Hide file tree
Showing 9 changed files with 486 additions and 64 deletions.
11 changes: 7 additions & 4 deletions django/db/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ def as_sql(self):
# the SQLDeleteCompiler's default implementation when multiple tables
# are involved since MySQL/MariaDB will generate a more efficient query
# plan than when using a subquery.
where, having = self.query.where.split_having()
if self.single_alias or having:
# DELETE FROM cannot be used when filtering against aggregates
# since it doesn't allow for GROUP BY and HAVING clauses.
where, having, qualify = self.query.where.split_having_qualify(
must_group_by=self.query.group_by is not None
)
if self.single_alias or having or qualify:
# DELETE FROM cannot be used when filtering against aggregates or
# window functions as it doesn't allow for GROUP BY/HAVING clauses
# and the subquery wrapping (necessary to emulate QUALIFY).
return super().as_sql()
result = [
"DELETE %s FROM"
Expand Down
15 changes: 13 additions & 2 deletions django/db/models/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,7 @@ class ResolvedOuterRef(F):
"""

contains_aggregate = False
contains_over_clause = False

def as_sql(self, *args, **kwargs):
raise ValueError(
Expand Down Expand Up @@ -1210,6 +1211,12 @@ def as_sql(self, *args, **kwargs):
return "", ()
return super().as_sql(*args, **kwargs)

def get_group_by_cols(self):
    """
    Return the GROUP BY columns of every source expression, flattened into
    a single list in source-expression order.
    """
    return [
        col
        for order_by in self.get_source_expressions()
        for col in order_by.get_group_by_cols()
    ]


@deconstructible(path="django.db.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
Expand Down Expand Up @@ -1631,7 +1638,6 @@ class Window(SQLiteNumericMixin, Expression):
# be introduced in the query as a result is not desired.
contains_aggregate = False
contains_over_clause = True
filterable = False

def __init__(
self,
Expand Down Expand Up @@ -1733,7 +1739,12 @@ def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self)

def get_group_by_cols(self, alias=None):
    """
    Return the columns this window expression contributes to GROUP BY.

    The columns of ``partition_by`` come first, followed by the columns of
    ``order_by``; an empty list is returned when neither is set. ``alias``
    is accepted but not used.
    """
    # A stale unconditional ``return []`` used to precede this logic, which
    # made the partition/order columns unreachable.
    group_by_cols = []
    if self.partition_by:
        group_by_cols.extend(self.partition_by.get_group_by_cols())
    if self.order_by is not None:
        group_by_cols.extend(self.order_by.get_group_by_cols())
    return group_by_cols


class WindowFrame(Expression):
Expand Down
1 change: 1 addition & 0 deletions django/db/models/fields/related_lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

class MultiColSource:
contains_aggregate = False
contains_over_clause = False

def __init__(self, alias, targets, sources, field):
self.targets, self.sources, self.field, self.alias = (
Expand Down
76 changes: 75 additions & 1 deletion django/db/models/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
from django.db.models.functions import Cast, Random
from django.db.models.lookups import Lookup
from django.db.models.query_utils import select_related_descend
from django.db.models.sql.constants import (
CURSOR,
Expand Down Expand Up @@ -73,7 +74,9 @@ def pre_sql_setup(self, with_col_aliases=False):
"""
self.setup_query(with_col_aliases=with_col_aliases)
order_by = self.get_order_by()
self.where, self.having = self.query.where.split_having()
self.where, self.having, self.qualify = self.query.where.split_having_qualify(
must_group_by=self.query.group_by is not None
)
extra_select = self.get_extra_select(order_by, self.select)
self.has_extra_select = bool(extra_select)
group_by = self.get_group_by(self.select + extra_select, order_by)
Expand Down Expand Up @@ -584,6 +587,74 @@ def get_combinator_sql(self, combinator, all):
params.extend(part)
return result, params

def get_qualify_sql(self):
    """
    Build the SQL that emulates a QUALIFY clause through subquery wrapping.

    The WHERE and HAVING predicates are attached to an inner query cloned
    from the current one, and the predicates against window functions are
    applied by an outer ``SELECT ... FROM (<inner>) WHERE <qualify>``.
    ``self.qualify`` is rebound to a copy whose expressions are ``Ref``s to
    aliases of the inner query.

    Return a ``(result, params)`` pair: a list of SQL fragments and a list
    of parameters.
    """
    inner_where = [node for node in (self.where, self.having) if node]
    inner_query = self.query.clone()
    inner_query.subquery = True
    inner_query.where = inner_query.where.__class__(inner_where)
    # Window function references that were masked via values() and alias()
    # have to be selected by the inner query so the outer one can filter
    # against them. Aliases added for that purpose are masked again in the
    # `if unmasked_aliases` branch below to avoid fetching their data.
    selected = self.get_select(with_col_aliases=True)[0]
    alias_by_expr = {expr: alias for expr, _, alias in selected}
    unmasked_aliases = set()
    alias_replacements = {}
    pending = list(self.qualify.leaves())
    while pending:
        expr = pending.pop()
        known_alias = alias_by_expr.get(expr) or alias_replacements.get(expr)
        if known_alias:
            # Already selected (or already handled): reuse its alias.
            alias_replacements[expr] = known_alias
            continue
        if isinstance(expr, Lookup):
            # Lookups are not selected themselves; visit their operands.
            pending.extend(expr.get_source_expressions())
            continue
        new_alias = f"qual{len(unmasked_aliases)}"
        unmasked_aliases.add(new_alias)
        inner_query.add_annotation(expr, new_alias)
        alias_replacements[expr] = new_alias
    self.qualify = self.qualify.replace_expressions(
        {expr: Ref(alias, expr) for expr, alias in alias_replacements.items()}
    )
    inner_compiler = inner_query.get_compiler(
        self.using, elide_empty=self.elide_empty
    )
    # The limits must be applied to the outer query to avoid pruning results
    # too eagerly. Unique aliasing of the selected columns is forced to
    # avoid collisions and make rhs predicates referencing easier.
    inner_sql, inner_params = inner_compiler.as_sql(
        with_limits=False,
        with_col_aliases=True,
    )
    qualify_sql, qualify_params = self.compile(self.qualify)
    quote_name = self.connection.ops.quote_name
    result = [
        "SELECT * FROM (",
        inner_sql,
        ")",
        quote_name("qualify"),
        "WHERE",
        qualify_sql,
    ]
    if unmasked_aliases:
        # Select aliases unmasked for filtering purposes must be masked
        # back: only the originally selected columns are returned.
        cols = ", ".join([quote_name(alias) for alias in alias_by_expr.values()])
        result = [
            "SELECT",
            cols,
            "FROM (",
            *result,
            ")",
            quote_name("qualify_mask"),
        ]
    return result, list(inner_params) + qualify_params

def as_sql(self, with_limits=True, with_col_aliases=False):
"""
Create the SQL for this query. Return the SQL string and list of
Expand Down Expand Up @@ -614,6 +685,9 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
result, params = self.get_combinator_sql(
combinator, self.query.combinator_all
)
elif self.qualify:
result, params = self.get_qualify_sql()
order_by = None
else:
distinct_fields, distinct_params = self.get_distinct()
# This must come after 'select', 'ordering', and 'distinct'
Expand Down
95 changes: 75 additions & 20 deletions django/db/models/sql/where.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,48 +35,81 @@ class WhereNode(tree.Node):
resolved = False
conditional = True

def split_having(self, negated=False):
def split_having_qualify(self, negated=False, must_group_by=False):
"""
Return two possibly None nodes: one for those parts of self that
should be included in the WHERE clause and one for those parts of
self that must be included in the HAVING clause.
Return three possibly None nodes: one for those parts of self that
should be included in the WHERE clause, one for those parts of self
that must be included in the HAVING clause, and one for those parts
that refer to window functions.
"""
if not self.contains_aggregate:
return self, None
if not self.contains_aggregate and not self.contains_over_clause:
return self, None, None
in_negated = negated ^ self.negated
# If the effective connector is OR or XOR and this node contains an
# aggregate, then we need to push the whole branch to HAVING clause.
may_need_split = (
# Whether or not children must be connected in the same filtering
# clause (WHERE > HAVING > QUALIFY) to maintain logical semantics.
must_remain_connected = (
(in_negated and self.connector == AND)
or (not in_negated and self.connector == OR)
or self.connector == XOR
)
if may_need_split and self.contains_aggregate:
return None, self
if (
must_remain_connected
and self.contains_aggregate
and not self.contains_over_clause
):
# It's much cheaper to short-circuit and stash everything in the
# HAVING clause than split children if possible.
return None, self, None
where_parts = []
having_parts = []
qualify_parts = []
for c in self.children:
if hasattr(c, "split_having"):
where_part, having_part = c.split_having(in_negated)
if hasattr(c, "split_having_qualify"):
where_part, having_part, qualify_part = c.split_having_qualify(
in_negated, must_group_by
)
if where_part is not None:
where_parts.append(where_part)
if having_part is not None:
having_parts.append(having_part)
if qualify_part is not None:
qualify_parts.append(qualify_part)
elif c.contains_over_clause:
qualify_parts.append(c)
elif c.contains_aggregate:
having_parts.append(c)
else:
where_parts.append(c)
if must_remain_connected and qualify_parts:
# Disjunctive heterogeneous predicates can be pushed down to
# qualify as long as no conditional aggregation is involved.
if not where_parts or (where_parts and not must_group_by):
return None, None, self
elif where_parts:
# In theory this should only be enforced when dealing with
# where_parts containing predicates against multi-valued
# relationships that could affect aggregation results but this
# is complex to infer properly.
raise NotImplementedError(
"Heterogeneous disjunctive predicates against window functions are "
"not implemented when performing conditional aggregation."
)
where_node = (
self.create(where_parts, self.connector, self.negated)
if where_parts
else None
)
having_node = (
self.create(having_parts, self.connector, self.negated)
if having_parts
else None
)
where_node = (
self.create(where_parts, self.connector, self.negated)
if where_parts
qualify_node = (
self.create(qualify_parts, self.connector, self.negated)
if qualify_parts
else None
)
return where_node, having_node
return where_node, having_node, qualify_node

def as_sql(self, compiler, connection):
"""
Expand Down Expand Up @@ -183,6 +216,14 @@ def relabeled_clone(self, change_map):
clone.relabel_aliases(change_map)
return clone

def replace_expressions(self, replacements):
    """
    Return a copy of this node with expressions swapped per ``replacements``.

    When the node itself maps to a truthy replacement, that replacement is
    returned as is. Otherwise a new node with the same connector and
    negation is created and populated with the result of calling
    ``replace_expressions()`` on each child, in order.
    """
    substitute = replacements.get(self)
    if substitute:
        return substitute
    clone = self.create(connector=self.connector, negated=self.negated)
    clone.children.extend(
        [child.replace_expressions(replacements) for child in self.children]
    )
    return clone

@classmethod
def _contains_aggregate(cls, obj):
if isinstance(obj, tree.Node):
Expand Down Expand Up @@ -231,6 +272,10 @@ def output_field(self):

return BooleanField()

@property
def _output_field_or_none(self):
    """Return ``output_field`` unchanged."""
    return self.output_field

def select_format(self, compiler, sql, params):
# Wrap filters with a CASE WHEN expression if a database backend
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
Expand All @@ -245,19 +290,28 @@ def get_db_converters(self, connection):
def get_lookup(self, lookup):
return self.output_field.get_lookup(lookup)

def leaves(self):
    """
    Yield, depth first and in order, every descendant of this node that is
    not itself a ``WhereNode``.
    """
    for node in self.children:
        if not isinstance(node, WhereNode):
            yield node
            continue
        yield from node.leaves()


class NothingNode:
    """A node that matches nothing."""

    # Flags read when splitting a filter tree: this node holds neither an
    # aggregate nor a window function.
    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, compiler=None, connection=None):
        """Raise ``EmptyResultSet`` unconditionally; both arguments are ignored."""
        raise EmptyResultSet


class ExtraWhere:
# The contents are a black box - assume no aggregates are used.
# The contents are a black box - assume no aggregates or windows are used.
contains_aggregate = False
contains_over_clause = False

def __init__(self, sqls, params):
self.sqls = sqls
Expand All @@ -269,9 +323,10 @@ def as_sql(self, compiler=None, connection=None):


class SubqueryConstraint:
# Even if aggregates would be used in a subquery, the outer query isn't
# interested in those.
# Even if aggregates or windows would be used in a subquery,
# the outer query isn't interested in those.
contains_aggregate = False
contains_over_clause = False

def __init__(self, alias, columns, targets, query_object):
self.alias = alias
Expand Down
31 changes: 25 additions & 6 deletions docs/ref/models/expressions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -741,12 +741,6 @@ instead they are part of the selected columns.

.. class:: Window(expression, partition_by=None, order_by=None, frame=None, output_field=None)

.. attribute:: filterable

Defaults to ``False``. The SQL standard disallows referencing window
functions in the ``WHERE`` clause and Django raises an exception when
constructing a ``QuerySet`` that would do that.

.. attribute:: template

Defaults to ``%(expression)s OVER (%(window)s)``. If only the
Expand Down Expand Up @@ -819,6 +813,31 @@ to reduce repetition::
>>> ),
>>> )

Filtering against window functions is supported, except when the lookups are
disjunctive (using ``OR`` or ``XOR`` as a connector) and the queryset performs
aggregation.

For example, a query that relies on aggregation and has an ``OR``-ed filter
against a window function and a field is not supported. Applying combined
predicates post-aggregation could cause rows that would normally be excluded
from groups to be included::

>>> qs = Movie.objects.annotate(
>>> category_rank=Window(
>>> Rank(), partition_by='category', order_by='-rating'
>>> ),
>>> scenes_count=Count('actors'),
>>> ).filter(
>>> Q(category_rank__lte=3) | Q(title__contains='Batman')
>>> )
>>> list(qs)
NotImplementedError: Heterogeneous disjunctive predicates against window functions
are not implemented when performing conditional aggregation.

.. versionchanged:: 4.2

Support for filtering against window functions was added.

Among Django's built-in database backends, MySQL 8.0.2+, PostgreSQL, and Oracle
support window expressions. Support for different window expression features
varies among the different databases. For example, the options in
Expand Down
4 changes: 3 additions & 1 deletion docs/releases/4.2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ Migrations
Models
~~~~~~

* ...
* ``QuerySet`` now extensively supports filtering against
:ref:`window-functions` with the exception of disjunctive filter lookups
against window functions when performing aggregation.

Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~
Expand Down
Loading

0 comments on commit f387d02

Please sign in to comment.