perf: Utilize ORDER BY LIMIT over ROW_NUMBER where possible #1077

Merged (11 commits, Oct 18, 2024)
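For context, the optimization targets plans like a sorted head(): previously such plans compiled to a ROW_NUMBER() window function plus a WHERE filter, and after this change they can compile to a plain ORDER BY ... LIMIT. A minimal usage sketch (the project, table, and column names are illustrative, not from this PR):

```python
import bigframes.pandas as bpd

# An ordered head() is the canonical shape this PR speeds up.
df = bpd.read_gbq("my-project.my_dataset.events")  # illustrative table name
top5 = df.sort_values("ts").head(5)
top5.to_pandas()  # can now compile to `ORDER BY ts LIMIT 5` instead of a window filter
```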
11 changes: 9 additions & 2 deletions bigframes/core/compile/api.py
@@ -18,6 +18,7 @@
import google.cloud.bigquery as bigquery

import bigframes.core.compile.compiler as compiler
import bigframes.core.rewrite as rewrites

if TYPE_CHECKING:
import bigframes.core.nodes
@@ -42,6 +43,11 @@ def compile_unordered(
col_id_overrides: Mapping[str, str] = {},
) -> str:
"""Compile node into sql where rows are unsorted, and no ordering information is preserved."""
new_node, limit = rewrites.pullup_limit_from_slice(node)
if limit is not None:
return self._compiler.compile_ordered_ir(new_node).to_sql(
col_id_overrides=col_id_overrides, ordered=True, limit=limit
)
return self._compiler.compile_unordered_ir(node).to_sql(
col_id_overrides=col_id_overrides
)
@@ -53,8 +59,9 @@ def compile_ordered(
col_id_overrides: Mapping[str, str] = {},
) -> str:
"""Compile node into sql where rows are sorted with ORDER BY."""
return self._compiler.compile_ordered_ir(node).to_sql(
col_id_overrides=col_id_overrides, ordered=True
new_node, limit = rewrites.pullup_limit_from_slice(node)
return self._compiler.compile_ordered_ir(new_node).to_sql(
col_id_overrides=col_id_overrides, ordered=True, limit=limit
)

def compile_raw(
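Note the asymmetry in compile_unordered above: once a limit has been pulled above the slice, which rows survive depends on ordering, so the method falls back to the ordered compile path and threads the extracted limit through. The two SQL shapes being traded look roughly like this (hand-written SQL for illustration, not actual compiler output):

```python
# Hand-written BigQuery SQL sketching the two forms; real compiler output differs in detail.
row_number_form = """
SELECT * EXCEPT (rn)
FROM (
  SELECT *, ROW_NUMBER() OVER (ORDER BY ts) AS rn
  FROM dataset.t
)
WHERE rn <= 5
"""

order_by_limit_form = """
SELECT *
FROM dataset.t
ORDER BY ts
LIMIT 5
"""
```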
8 changes: 6 additions & 2 deletions bigframes/core/compile/compiled.py
@@ -943,8 +943,9 @@ def to_sql(
self,
col_id_overrides: typing.Mapping[str, str] = {},
ordered: bool = False,
limit: Optional[int] = None,
) -> str:
if ordered:
if ordered or limit:
# Need to bake ordering expressions into the selected column in order for our ordering clause builder to work.
baked_ir = self._bake_ordering()
sql = ibis_bigquery.Backend().compile(
@@ -969,7 +970,10 @@
order_by_clause = bigframes.core.sql.ordering_clause(
baked_ir._ordering.all_ordering_columns
)
sql += f"{order_by_clause}\n"
sql += f"\n{order_by_clause}"
if limit is not None:
assert isinstance(limit, int)
Collaborator: Should we raise a TypeError here instead of an assertion error?

I don't know how well our public APIs are validating the type of parameters. Not everyone is going to use a type checker.

Contributor (author): Yeah, in general it is probably best to assume that wrong types can sometimes reach even the lower levels of the code base. In particular, we want to be safe about what we are putting into the SQL strings. Switched to TypeError here.

sql += f"\nLIMIT {limit}"
else:
sql = ibis_bigquery.Backend().compile(
self._to_ibis_expr(
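The two string-building changes above (the relocated newline and the appended LIMIT) are easiest to see in isolation. A standalone sketch of the final assembly, including the TypeError guard the review above settled on (plain Python, not the bigframes internals; the helper name is made up):

```python
from typing import Optional

def append_ordering(sql: str, order_by_clause: str, limit: Optional[int]) -> str:
    # Newline goes *before* the clause so ORDER BY starts on its own line,
    # regardless of how the compiled SELECT ended.
    sql += f"\n{order_by_clause}"
    if limit is not None:
        # Validate the value before interpolating it into the SQL string.
        if not isinstance(limit, int):
            raise TypeError(f"limit must be an int, got {type(limit)}")
        sql += f"\nLIMIT {limit}"
    return sql

print(append_ordering("SELECT * FROM t", "ORDER BY ts", 5))
# SELECT * FROM t
# ORDER BY ts
# LIMIT 5
```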
19 changes: 16 additions & 3 deletions bigframes/core/compile/compiler.py
@@ -36,6 +36,7 @@
import bigframes.core.identifiers as ids
import bigframes.core.nodes as nodes
import bigframes.core.ordering as bf_ordering
import bigframes.core.rewrite as rewrites

if typing.TYPE_CHECKING:
import bigframes.core
@@ -48,20 +49,32 @@ class Compiler:
# In unstrict mode, ordering from ReadTable or after joins may be left ambiguous to improve query performance.
strict: bool = True
scalar_op_compiler = compile_scalar.ScalarOpCompiler()
enable_pruning: bool = False

def _preprocess(self, node: nodes.BigFrameNode):
if self.enable_pruning:
used_fields = frozenset(field.id for field in node.fields)
node = node.prune(used_fields)
node = functools.cache(rewrites.replace_slice_ops)(node)
return node

def compile_ordered_ir(self, node: nodes.BigFrameNode) -> compiled.OrderedIR:
ir = typing.cast(compiled.OrderedIR, self.compile_node(node, True))
ir = typing.cast(
compiled.OrderedIR, self.compile_node(self._preprocess(node), True)
)
if self.strict:
assert ir.has_total_order
return ir

def compile_unordered_ir(self, node: nodes.BigFrameNode) -> compiled.UnorderedIR:
return typing.cast(compiled.UnorderedIR, self.compile_node(node, False))
return typing.cast(
compiled.UnorderedIR, self.compile_node(self._preprocess(node), False)
)

def compile_peak_sql(
self, node: nodes.BigFrameNode, n_rows: int
) -> typing.Optional[str]:
return self.compile_unordered_ir(node).peek_sql(n_rows)
return self.compile_unordered_ir(self._preprocess(node)).peek_sql(n_rows)

# TODO: Remove cache when schema no longer requires compilation to derive schema (and therefore only compiles for execution)
@functools.lru_cache(maxsize=5000)
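The new _preprocess hook centralizes the plan rewrites that previously ran in the executor (see the bigframes/session/executor.py hunk below): optional column pruning, then slice replacement, applied uniformly to the ordered, unordered, and peek entry points. A toy model of that pipeline with stand-in types (not bigframes internals):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class ToyPlan:
    columns: frozenset
    has_slice: bool = False

def prune(plan: ToyPlan, used: frozenset) -> ToyPlan:
    # Drop columns nothing downstream reads.
    return replace(plan, columns=plan.columns & used)

def replace_slice_ops(plan: ToyPlan) -> ToyPlan:
    # Stand-in for the ROW_NUMBER()-based slice rewrite.
    return replace(plan, has_slice=False)

def preprocess(plan: ToyPlan, enable_pruning: bool = False) -> ToyPlan:
    if enable_pruning:
        plan = prune(plan, plan.columns)  # the real code derives `used` from the plan
    return replace_slice_ops(plan)

print(preprocess(ToyPlan(frozenset({"a", "b"}), has_slice=True), enable_pruning=True))
```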
38 changes: 38 additions & 0 deletions bigframes/core/rewrite.py
@@ -385,6 +385,44 @@ def common_selection_root(
return None


Contributor (inline review on pullup_limit_from_slice): Are we able to have a test to verify this process?

def pullup_limit_from_slice(
    root: nodes.BigFrameNode,
) -> Tuple[nodes.BigFrameNode, Optional[int]]:
"""
This is a BQ-sql specific optimization that can be helpful as ORDER BY LIMIT is more efficient than ROW_NUMBER() + WHERE.
"""
    if isinstance(root, nodes.SliceNode):
        # head case: slice(None, n, 1) with n > 0 is equivalent to LIMIT n
        if (
            (not root.start)
            and ((root.stop is not None) and root.stop > 0)
            and (root.step == 1)
        ):
            limit = root.stop
            new_root, prior_limit = pullup_limit_from_slice(root.child)
            if prior_limit is not None and prior_limit < limit:
                limit = prior_limit
            return new_root, limit
        # tail case: slice(None, -n, -1) is equivalent to a reversal, then LIMIT n
        if (
            (root.start in [None, -1])
            and ((root.stop is not None) and root.stop < 0)
            and (root.step == -1)
        ):
            limit = -root.stop
            new_root, prior_limit = pullup_limit_from_slice(root.child)
            if prior_limit is not None and prior_limit < limit:
                limit = prior_limit
            return nodes.ReversedNode(new_root), limit
    elif isinstance(root, nodes.UnaryNode) and root.row_preserving:
        new_child, limit = pullup_limit_from_slice(root.child)
        if limit is not None:
            return root.transform_children(lambda _: new_child), limit
    # Many ops don't support pulling the limit up through them: filter, agg, join, etc.
    return root, None
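Picking up the review question above: the rewrite can be verified at the plan-node level, without compiling to SQL at all. A rough sketch of such a test (the fixture and the exact SliceNode constructor arguments are assumptions, not from this PR):

```python
import bigframes.core.nodes as nodes
import bigframes.core.rewrite as rewrites

def test_pullup_limit_from_slice(leaf):  # `leaf` is a hypothetical plan fixture
    # Head case, df[:5]: the slice disappears and a limit of 5 is extracted.
    head = nodes.SliceNode(leaf, start=None, stop=5, step=1)
    new_root, limit = rewrites.pullup_limit_from_slice(head)
    assert limit == 5
    assert new_root == leaf

    # Negative-step slice: the rewrite keeps a reversal and extracts limit 3.
    tail = nodes.SliceNode(leaf, start=None, stop=-3, step=-1)
    new_root, limit = rewrites.pullup_limit_from_slice(tail)
    assert limit == 3
    assert new_root == nodes.ReversedNode(leaf)
```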


def replace_slice_ops(root: nodes.BigFrameNode) -> nodes.BigFrameNode:
# TODO: we want to pull up some slices into limit op if near root.
if isinstance(root, nodes.SliceNode):
2 changes: 0 additions & 2 deletions bigframes/session/executor.py
@@ -45,7 +45,6 @@
import bigframes.core.identifiers
import bigframes.core.nodes as nodes
import bigframes.core.ordering as order
import bigframes.core.rewrite as rewrites
import bigframes.core.schema
import bigframes.core.tree_properties as tree_properties
import bigframes.features
@@ -437,7 +436,6 @@ def _get_optimized_plan(self, node: nodes.BigFrameNode) -> nodes.BigFrameNode:
if ENABLE_PRUNING:
used_fields = frozenset(field.id for field in optimized_plan.fields)
optimized_plan = optimized_plan.prune(used_fields)
optimized_plan = rewrites.replace_slice_ops(optimized_plan)
return optimized_plan

def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue):