Skip to content

Commit

Permalink
refactor(ir): split the relational operations
Browse files Browse the repository at this point in the history
Rationale and history
---------------------
In the last couple of years we have been constantly refactoring the
internals to make it easier to work with. Although we have made great
progress, the current codebase is still hard to maintain and extend.
One example of that complexity is the try to remove the `Projector`
class in #7430. I had to realize that we are unable to improve the
internals in smaller incremental steps, we need to make a big leap
forward to make the codebase maintainable in the long run.

One of the hotspots of problems is the `analysis.py` module which tries
to bridge the gap between the user-facing API and the internal
representation. Part of its complexity is caused by loose integrity
checks in the internal representation, allowing various ways to
represent the same operation. This makes it hard to inspect, reason
about and optimize the relational operations. In addition to that, it
makes much harder to implement the backends since more branching is
required to cover all the variations.

We have always been aware of these problems, and actually we had several
attempts to solve them the same way this PR does. However, we never
managed to actually split the relational operations, we always hit
roadblocks to maintain compatibility with the current test suite.
Actually we were unable to even understand those issues because of the
complexity of the codebase and number of indirections between the API,
analysis functions and the internal representation.

But(!) finally we managed to prototype a new IR in #7580 along with
implementations for the majority of the backends, including `various SQL
backends` and `pandas`. After successfully validating the viability of
the new IR, we split the PR into smaller pieces which can be
individually reviewed. This PR is the first step of that process, it
introduces the new IR and the new API. The next steps will be to
implement the remaining backends on top of the new IR.

Changes in this commit
----------------------
- Split the `ops.Selection` and `ops.Aggregration` nodes into proper
  relational algebra operations.
- Almost entirely remove `analysis.py` with the technical debt
  accumulated over the years.
- More flexible window frame binding: if an unbound analytical function
  is used with a window containing references to a relation then
  `.over()` is now able to bind the window frame to the relation.
- Introduce a new API-level technique to dereference columns to the
  target relation(s).
- Revamp the subquery handling to be more robust and to support more
  use cases with strict validation, now we have `ScalarSubquery`,
  `ExistsSubquery`, and `InSubquery` nodes which can only be used in
  the appropriate context.
- Use way stricter integrity checks for all the relational operations,
  most of the time enforcing that all the value inputs of the node must
  originate from the parent relation the node depends on.
- Introduce a new `JoinChain` operations to represent multiple joins in
  a single operation followed by a projection attached to the same
  relation. This enabled to solve several outstanding issues with the
  join handling (including the notorious chain join issue).
- Use straightforward rewrite rules collected in `rewrites.py` to
  reinterpret user input so that the new operations can be constructed,
  even with the strict integrity checks.
- Provide a set of simplification rules to reorder and squash the
  relational operations into a more compact form.
- Use mappings to represent projections, eliminating the need of
  internally storing `ops.Alias` nodes. In addition to that table nodes
  in projections are not allowed anymore, the columns are expanded to
  the same mapping making the semantics clear.
- Uniform handling of the various kinds of inputs for all the API
  methods using a generic `bind()` function.

Advantages of the new IR
------------------------
- The operations are much simpler with clear semantics.
- The operations are easier to reason about and to optimize.
- The backends can easily lower the internal representation to a
  backend-specific form before compilation/execution, so the lowered
  form can be easily inspected, debugged, and optimized.
- The API is much closer to the users' mental model, thanks to the
  dereferencing technique.
- The backend implementation can be greatly simplified due to the
  simpler internal representation and strict integrity checks. As an
  example the pandas backend can be slimmed down by 4k lines of code
  while being more robust and easier to maintain.

Disadvantages of the new IR
---------------------------
- The backends must be rewritten to support the new internal
  representation.
  • Loading branch information
kszucs committed Feb 5, 2024
1 parent 0d2d414 commit 13a4f74
Show file tree
Hide file tree
Showing 75 changed files with 3,533 additions and 2,393 deletions.
442 changes: 4 additions & 438 deletions ibis/expr/analysis.py

Large diffs are not rendered by default.

13 changes: 10 additions & 3 deletions ibis/expr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import builtins
import datetime
import functools
import itertools
import numbers
import operator
from collections import Counter
Expand Down Expand Up @@ -303,6 +304,9 @@ def schema(
return sch.Schema.from_tuples(zip(names, types))


_table_names = (f"unbound_table_{i:d}" for i in itertools.count())


def table(
schema: SupportsSchema | None = None,
name: str | None = None,
Expand Down Expand Up @@ -333,9 +337,12 @@ def table(
a int64
b string
"""
if isinstance(schema, type) and name is None:
name = schema.__name__
return ops.UnboundTable(schema=schema, name=name).to_expr()
if name is None:
if isinstance(schema, type):
name = schema.__name__
else:
name = next(_table_names)
return ops.UnboundTable(name=name, schema=schema).to_expr()


@lazy_singledispatch
Expand Down
43 changes: 37 additions & 6 deletions ibis/expr/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@
import ibis.expr.rules as rlz
import ibis.expr.types as ir
from ibis import util
from ibis.common.annotations import annotated
from ibis.common.annotations import annotated, attribute
from ibis.common.deferred import Deferred, Resolver, deferrable
from ibis.common.exceptions import IbisInputError
from ibis.common.grounds import Concrete
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.operations.relations import Relation # noqa: TCH001
from ibis.expr.types.relations import bind_expr

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -146,6 +144,25 @@ class WindowBuilder(Builder):
orderings: VarTuple[Union[str, Resolver, ops.Value]] = ()
max_lookback: Optional[ops.Value[dt.Interval]] = None

@attribute
def _table(self):
inputs = (
self.start,
self.end,
*self.groupings,
*self.orderings,
self.max_lookback,
)
valuerels = (v.relations for v in inputs if isinstance(v, ops.Value))
relations = frozenset().union(*valuerels)
if len(relations) == 0:
return None
elif len(relations) == 1:
(table,) = relations
return table
else:
raise IbisInputError("Window frame can only depend on a single relation")

def _maybe_cast_boundary(self, boundary, dtype):
if boundary.dtype == dtype:
return boundary
Expand Down Expand Up @@ -214,9 +231,23 @@ def lookback(self, value) -> Self:
return self.copy(max_lookback=value)

@annotated
def bind(self, table: Relation):
groupings = bind_expr(table.to_expr(), self.groupings)
orderings = bind_expr(table.to_expr(), self.orderings)
def bind(self, table: Optional[ops.Relation]):
table = table or self._table
if table is None:
raise IbisInputError("Unable to bind window frame to a table")

table = table.to_expr()

def bind_value(value):
if isinstance(value, str):
return table._get_column(value)
elif isinstance(value, Resolver):
return value.resolve({"_": table})
else:
return value

groupings = map(bind_value, self.groupings)
orderings = map(bind_value, self.orderings)
if self.how == "rows":
return ops.RowsWindowFrame(
table=table,
Expand Down
141 changes: 90 additions & 51 deletions ibis/expr/decompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.common.graph import Graph
from ibis.expr.rewrites import simplify
from ibis.util import experimental

_method_overrides = {
Expand All @@ -31,11 +32,8 @@
ops.ExtractYear: "year",
ops.Intersection: "intersect",
ops.IsNull: "isnull",
ops.LeftAntiJoin: "anti_join",
ops.LeftSemiJoin: "semi_join",
ops.Lowercase: "lower",
ops.RegexSearch: "re_search",
ops.SelfReference: "view",
ops.StartsWith: "startswith",
ops.StringContains: "contains",
ops.StringSQLILike: "ilike",
Expand Down Expand Up @@ -87,7 +85,6 @@ def translate(op, *args, **kwargs):


@translate.register(ops.Value)
@translate.register(ops.TableNode)
def value(op, *args, **kwargs):
method = _get_method_name(op)
kwargs = [(k, v) for k, v in kwargs.items() if v is not None]
Expand Down Expand Up @@ -125,44 +122,80 @@ def _try_unwrap(stmt):
if len(stmt) == 1:
return stmt[0]
else:
return f"[{', '.join(stmt)}]"
stmt = map(str, stmt)
values = ", ".join(stmt)
return f"[{values}]"


def _wrap_alias(values, rendered):
result = []
for k, v in values.items():
text = rendered[k]
if v.name != k:
text = f"{text}.name({k!r})"
result.append(text)
return result


def _inline(args):
return ", ".join(map(str, args))


@translate.register(ops.Project)
def project(op, parent, values):
out = f"{parent}"
if not values:
return out

@translate.register(ops.Selection)
def selection(op, table, selections, predicates, sort_keys):
out = f"{table}"
if selections:
out = f"{out}.select({_try_unwrap(selections)})"
values = _wrap_alias(op.values, values)
return f"{out}.select({_inline(values)})"


@translate.register(ops.Filter)
def filter_(op, parent, predicates):
out = f"{parent}"
if predicates:
out = f"{out}.filter({_try_unwrap(predicates)})"
if sort_keys:
out = f"{out}.order_by({_try_unwrap(sort_keys)})"
out = f"{out}.filter({_inline(predicates)})"
return out


@translate.register(ops.Aggregation)
def aggregation(op, table, by, metrics, predicates, having, sort_keys):
out = f"{table}"
if predicates:
out = f"{out}.filter({_try_unwrap(predicates)})"
if by:
out = f"{out}.group_by({_try_unwrap(by)})"
if having:
out = f"{out}.having({_try_unwrap(having)})"
if metrics:
out = f"{out}.aggregate({_try_unwrap(metrics)})"
if sort_keys:
out = f"{out}.order_by({_try_unwrap(sort_keys)})"
@translate.register(ops.Sort)
def sort(op, parent, keys):
out = f"{parent}"
if keys:
out = f"{out}.order_by({_inline(keys)})"
return out


@translate.register(ops.Join)
def join(op, left, right, predicates):
method = _get_method_name(op)
return f"{left}.{method}({right}, {_try_unwrap(predicates)})"
@translate.register(ops.Aggregate)
def aggregation(op, parent, groups, metrics):
groups = _wrap_alias(op.groups, groups)
metrics = _wrap_alias(op.metrics, metrics)
if groups and metrics:
return f"{parent}.aggregate([{_inline(metrics)}], by=[{_inline(groups)}])"
elif metrics:
return f"{parent}.aggregate([{_inline(metrics)}])"
else:
raise ValueError("No metrics to aggregate")


@translate.register(ops.SelfReference)
def self_reference(op, parent, identifier):
return parent


@translate.register(ops.JoinLink)
def join_link(op, table, predicates, how):
return f".{how}_join({table}, {_try_unwrap(predicates)})"


@translate.register(ops.JoinChain)
def join(op, first, rest, values):
calls = "".join(rest)
return f"{first}{calls}"


@translate.register(ops.SetOp)
@translate.register(ops.Set)
def union(op, left, right, distinct):
method = _get_method_name(op)
if distinct:
Expand All @@ -172,16 +205,16 @@ def union(op, left, right, distinct):


@translate.register(ops.Limit)
def limit(op, table, n, offset):
def limit(op, parent, n, offset):
if offset:
return f"{table}.limit({n}, {offset})"
return f"{parent}.limit({n}, {offset})"
else:
return f"{table}.limit({n})"
return f"{parent}.limit({n})"


@translate.register(ops.TableColumn)
def table_column(op, table, name):
return f"{table}.{name}"
@translate.register(ops.Field)
def table_column(op, rel, name):
return f"{rel}.{name}"


@translate.register(ops.SortKey)
Expand Down Expand Up @@ -292,14 +325,22 @@ def isin(op, value, options):


class CodeContext:
always_assign = (ops.ScalarParameter, ops.UnboundTable, ops.Aggregation)
always_ignore = (ops.TableColumn, dt.Primitive, dt.Variadic, dt.Temporal)
always_assign = (ops.ScalarParameter, ops.UnboundTable, ops.Aggregate)
always_ignore = (
ops.SelfReference,
ops.Field,
dt.Primitive,
dt.Variadic,
dt.Temporal,
)
shorthands = {
ops.Aggregation: "agg",
ops.Aggregate: "agg",
ops.Literal: "lit",
ops.ScalarParameter: "param",
ops.Selection: "proj",
ops.TableNode: "t",
ops.Project: "p",
ops.Relation: "r",
ops.Filter: "f",
ops.Sort: "s",
}

def __init__(self, assign_result_to="result"):
Expand All @@ -308,7 +349,7 @@ def __init__(self, assign_result_to="result"):

def variable_for(self, node):
klass = type(node)
if isinstance(node, ops.TableNode) and isinstance(node, ops.Named):
if isinstance(node, ops.Relation) and hasattr(node, "name"):
name = node.name
elif klass in self.shorthands:
name = self.shorthands[klass]
Expand Down Expand Up @@ -345,7 +386,7 @@ def render(self, node, code, n_dependents):

@experimental
def decompile(
node: ops.Node | ir.Expr,
expr: ir.Expr,
render_import: bool = True,
assign_result_to: str = "result",
format: bool = False,
Expand All @@ -354,7 +395,7 @@ def decompile(
Parameters
----------
node
expr
node or expression to decompile
render_import
Whether to add `import ibis` to the result.
Expand All @@ -368,13 +409,11 @@ def decompile(
str
Equivalent Python source code for `node`.
"""
if isinstance(node, ir.Expr):
node = node.op()
elif not isinstance(node, ops.Node):
raise TypeError(
f"Expected ibis expression or operation, got {type(node).__name__}"
)
if not isinstance(expr, ir.Expr):
raise TypeError(f"Expected ibis expression, got {type(expr).__name__}")

node = expr.op()
node = simplify(node)
out = io.StringIO()
ctx = CodeContext(assign_result_to=assign_result_to)
dependents = Graph(node).invert()
Expand Down
Loading

0 comments on commit 13a4f74

Please sign in to comment.