Skip to content

Commit

Permalink
chore: move t.optimize() to simplify() function under rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Dec 19, 2023
1 parent 9a4844d commit 402b90f
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 81 deletions.
4 changes: 3 additions & 1 deletion ibis/expr/decompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.common.graph import Graph
from ibis.expr.rewrites import simplify
from ibis.util import experimental

_method_overrides = {
Expand Down Expand Up @@ -411,7 +412,8 @@ def decompile(
if not isinstance(expr, ir.Expr):
raise TypeError(f"Expected ibis expression, got {type(expr).__name__}")

node = expr.optimize().op()
node = expr.op()
node = simplify(node)
out = io.StringIO()
ctx = CodeContext(assign_result_to=assign_result_to)
dependents = Graph(node).invert()
Expand Down
10 changes: 10 additions & 0 deletions ibis/expr/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,13 @@ def reorder_filter_project(_, y):
projs = {k: v.replace(rule) for k, v in y.values.items()}

return ops.Project(inner, projs)


def simplify(node):
    """Apply the relation rewrite rules to ``node`` and return the result.

    Each rule is applied as its own ``node.replace`` pass, in this order:
    ``reorder_filter_project`` (twice), then
    ``subsequent_projects | subsequent_filters``, then
    ``complete_reprojection``.
    """
    # TODO(kszucs): add a utility to the graph module to do rewrites in multiple
    # passes after each other
    passes = (
        reorder_filter_project,
        reorder_filter_project,
        subsequent_projects | subsequent_filters,
        complete_reprojection,
    )
    for rule in passes:
        node = node.replace(rule)
    return node
55 changes: 0 additions & 55 deletions ibis/expr/tests/test_newrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,6 @@ def test_select_full_reprojection():
},
)

t1_opt = t1.optimize()
assert t1_opt.op() == t.op()


def test_subsequent_selections_with_field_names():
t1 = t.select("bool_col", "int_col", "float_col")
Expand Down Expand Up @@ -224,14 +221,6 @@ def test_subsequent_selections_with_field_names():
},
)

t2_opt = t2.optimize()
assert t2_opt.op() == Project(
parent=t, values={"bool_col": t.bool_col, "int_col": t.int_col}
)

t3_opt = t3.optimize()
assert t3_opt.op() == Project(parent=t, values={"bool_col": t.bool_col})


def test_subsequent_selections_field_dereferencing():
t1 = t.select(t.bool_col, t.int_col, t.float_col)
Expand All @@ -245,7 +234,6 @@ def test_subsequent_selections_field_dereferencing():
)

t2 = t1.select(t1.bool_col, t1.int_col)
t2_opt = t2.optimize()
assert t1.select(t1.bool_col, t.int_col).equals(t2)
assert t1.select(t.bool_col, t.int_col).equals(t2)
assert t2.op() == Project(
Expand All @@ -255,16 +243,8 @@ def test_subsequent_selections_field_dereferencing():
"int_col": t1.int_col,
},
)
assert t2_opt.op() == Project(
parent=t,
values={
"bool_col": t.bool_col,
"int_col": t.int_col,
},
)

t3 = t2.select(t2.bool_col)
t3_opt = t3.optimize()
assert t2.select(t1.bool_col).equals(t3)
assert t2.select(t.bool_col).equals(t3)
assert t3.op() == Project(
Expand All @@ -273,7 +253,6 @@ def test_subsequent_selections_field_dereferencing():
"bool_col": t2.bool_col,
},
)
assert t3_opt.op() == Project(parent=t, values={"bool_col": t.bool_col})

u1 = t.select(t.bool_col, t.int_col, t.float_col)
assert u1.op() == Project(
Expand Down Expand Up @@ -324,7 +303,6 @@ def test_subsequent_selections_value_dereferencing():
)

t2 = t1.select(t1.bool_col, t1.int_col, t1.float_col)
t2_opt = t2.optimize()
assert t2.op() == Project(
parent=t1,
values={
Expand All @@ -333,23 +311,13 @@ def test_subsequent_selections_value_dereferencing():
"float_col": t1.float_col,
},
)
assert t2_opt.op() == Project(
parent=t,
values={
"bool_col": ~t.bool_col,
"int_col": t.int_col + 1,
"float_col": t.float_col * 3,
},
)

t3 = t2.select(
t2.bool_col,
t2.int_col,
float_col=t2.float_col * 2,
another_col=t1.float_col - 1,
)

t3_opt = t3.optimize()
assert t3.op() == Project(
parent=t2,
values={
Expand All @@ -359,15 +327,6 @@ def test_subsequent_selections_value_dereferencing():
"another_col": t2.float_col - 1,
},
)
assert t3_opt.op() == Project(
parent=t,
values={
"bool_col": ~t.bool_col,
"int_col": t.int_col + 1,
"float_col": (t.float_col * 3) * 2,
"another_col": (t.float_col * 3) - 1,
},
)


def test_where():
Expand Down Expand Up @@ -466,9 +425,6 @@ def test_subsequent_filter():
expected = Filter(f1, predicates=[f1.int_col > 0])
assert f2.op() == expected

f2_opt = f2.optimize()
assert f2_opt.op() == Filter(t, predicates=[t.bool_col, t.int_col > 0])


def test_project_before_and_after_filter():
t1 = t.select(
Expand Down Expand Up @@ -501,17 +457,6 @@ def test_project_before_and_after_filter():
},
)

t2_opt = t2.optimize(enable_reordering=False)
assert t2_opt.op() == Filter(parent=t1, predicates=[t1.bool_col])

t3_opt = t3.optimize(enable_reordering=False)
assert t3_opt.op() == Filter(parent=t1, predicates=[t1.bool_col, t1.int_col > 0])

t4_opt = t4.optimize(enable_reordering=False)
assert t4_opt.op() == Project(
parent=t3_opt, values={"bool_col": t3_opt.bool_col, "int_col": t3_opt.int_col}
)


# TODO(kszucs): add test for failing integrity checks
def test_join():
Expand Down
104 changes: 104 additions & 0 deletions ibis/expr/tests/test_rewrites.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from __future__ import annotations

import ibis
import ibis.expr.operations as ops
from ibis.expr.rewrites import simplify

# Table expression shared by the simplify() tests in this module.
t = ibis.table(
name="t",
schema={
"bool_col": "boolean",
"int_col": "int64",
"float_col": "float64",
"string_col": "string",
},
)


def test_simplify_full_reprojection():
    """Selecting the whole of ``t`` simplifies back to the node of ``t``."""
    reprojected = t.select(t)
    simplified = simplify(reprojected.op())
    assert simplified == t.op()


def test_simplify_subsequent_field_selections():
    """Chained plain-column selections simplify to one ``Project`` on ``t``."""
    three_cols = t.select(t.bool_col, t.int_col, t.float_col)
    expected_three = ops.Project(
        parent=t,
        values={
            "bool_col": t.bool_col,
            "int_col": t.int_col,
            "float_col": t.float_col,
        },
    )
    assert three_cols.op() == expected_three

    # Selecting from the first projection still resolves against ``t``.
    two_cols = three_cols.select(three_cols.bool_col, three_cols.int_col)
    expected_two = ops.Project(
        parent=t,
        values={
            "bool_col": t.bool_col,
            "int_col": t.int_col,
        },
    )
    assert simplify(two_cols.op()) == expected_two

    one_col = two_cols.select(two_cols.bool_col)
    expected_one = ops.Project(parent=t, values={"bool_col": t.bool_col})
    assert simplify(one_col.op()) == expected_one


def test_simplify_subsequent_value_selections():
    """Chained selections of computed values simplify to one ``Project``."""
    computed = t.select(
        bool_col=~t.bool_col,
        int_col=t.int_col + 1,
        float_col=t.float_col * 3,
    )

    # Re-selecting the computed columns inlines their defining expressions.
    reselected = computed.select(
        computed.bool_col, computed.int_col, computed.float_col
    )
    expected_reselected = ops.Project(
        parent=t,
        values={
            "bool_col": ~t.bool_col,
            "int_col": t.int_col + 1,
            "float_col": t.float_col * 3,
        },
    )
    assert simplify(reselected.op()) == expected_reselected

    # Expressions built on top of computed columns are composed onto ``t``.
    extended = reselected.select(
        reselected.bool_col,
        reselected.int_col,
        float_col=reselected.float_col * 2,
        another_col=computed.float_col - 1,
    )
    expected_extended = ops.Project(
        parent=t,
        values={
            "bool_col": ~t.bool_col,
            "int_col": t.int_col + 1,
            "float_col": (t.float_col * 3) * 2,
            "another_col": (t.float_col * 3) - 1,
        },
    )
    assert simplify(extended.op()) == expected_extended


def test_simplify_subsequent_filters():
    """Two stacked filters simplify to one ``Filter`` with both predicates."""
    first = t.filter(t.bool_col)
    second = first.filter(t.int_col > 0)

    expected = ops.Filter(t, predicates=[t.bool_col, t.int_col > 0])
    assert simplify(second.op()) == expected


def test_simplify_project_filter_project():
    """Project -> filter -> filter -> project simplifies to filter-then-project."""
    projected = t.select(
        bool_col=~t.bool_col,
        int_col=t.int_col + 1,
        float_col=t.float_col * 3,
    )
    filtered_once = projected.filter(projected.bool_col)
    filtered_twice = filtered_once.filter(filtered_once.int_col > 0)
    final = filtered_twice.select(filtered_twice.bool_col, filtered_twice.int_col)

    # Expected shape: one filter on ``t`` with the inlined predicates,
    # followed by a single projection of the computed columns.
    expected_filter = ops.Filter(
        parent=t, predicates=[~t.bool_col, t.int_col + 1 > 0]
    ).to_expr()
    expected_project = ops.Project(
        parent=expected_filter,
        values={
            "bool_col": ~expected_filter.bool_col,
            "int_col": expected_filter.int_col + 1,
        },
    ).to_expr()

    simplified = simplify(final.op())
    assert simplified == expected_project.op()
19 changes: 0 additions & 19 deletions ibis/expr/types/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,25 +352,6 @@ def compile(
self, limit=limit, timecontext=timecontext, params=params
)

def optimize(self, enable_reordering: bool = True) -> ir.Expr:
    """Rewrite this expression's node and return it wrapped as an expression.

    Parameters
    ----------
    enable_reordering
        When true, ``reorder_filter_project`` is applied twice before the
        other rewrite rules; when false, those two passes are skipped.

    Returns
    -------
    ir.Expr
        A new expression wrapping the rewritten node.
    """
    from ibis.expr.rewrites import (
        complete_reprojection,
        subsequent_filters,
        subsequent_projects,
        reorder_filter_project,
    )

    passes = []
    if enable_reordering:
        passes.extend([reorder_filter_project, reorder_filter_project])
    passes.append(subsequent_projects | subsequent_filters)
    passes.append(complete_reprojection)

    rewritten = self.op()
    for rule in passes:
        rewritten = rewritten.replace(rule)

    # return with a new expression wrapping the optimized node
    return rewritten.to_expr()

@experimental
def to_pyarrow_batches(
self,
Expand Down
9 changes: 6 additions & 3 deletions ibis/tests/expr/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import ibis
import ibis.common.exceptions as com
import ibis.expr.operations as ops
from ibis.expr.rewrites import simplify

# Place to collect esoteric expression analysis bugs and tests

Expand Down Expand Up @@ -207,7 +208,8 @@ def test_mutation_fusion_no_overwrite():
result = result.mutate(col2=t["col"] + 2)
result = result.mutate(col3=t["col"] + 3)

assert result.optimize().op() == ops.Project(
simplified = simplify(result.op())
assert simplified == ops.Project(
parent=t,
values={
"col": t["col"],
Expand All @@ -234,7 +236,8 @@ def test_mutation_fusion_overwrite():
# unable to dereference the column since result doesn't contain it anymore
result.mutate(col4=t["col"] + 4)

assert result.optimize().op() == ops.Project(
simplified = simplify(result.op())
assert simplified == ops.Project(
parent=t,
values={
"col": t["col"] - 1,
Expand Down Expand Up @@ -264,7 +267,7 @@ def test_select_filter_mutate_fusion():
filt = ops.Filter(parent=t, predicates=[t.col.isnan()]).to_expr()
proj = ops.Project(parent=filt, values={"col": filt.col.cast("int32")}).to_expr()

t3_opt = t3.optimize()
t3_opt = simplify(t3.op()).to_expr()
assert t3_opt.equals(proj)


Expand Down
8 changes: 5 additions & 3 deletions ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ibis.common.deferred import Deferred
from ibis.common.exceptions import ExpressionError, IntegrityError, RelationError
from ibis.expr import api
from ibis.expr.rewrites import simplify
from ibis.expr.types import Column, Table
from ibis.tests.util import assert_equal, assert_pickle_roundtrip

Expand Down Expand Up @@ -358,13 +359,13 @@ def test_add_predicate_coalesce(table):
pred1 = table["a"] > 5
pred2 = table["b"] > 0

result = table[pred1][pred2].optimize()
result = simplify(table[pred1][pred2].op()).to_expr()
expected = table.filter([pred1, pred2])
assert_equal(result, expected)

# 59, if we are not careful, we can obtain broken refs
subset = table[pred1]
result = subset.filter([subset["b"] > 0]).optimize()
result = simplify(subset.filter([subset["b"] > 0]).op()).to_expr()
assert_equal(result, expected)


Expand Down Expand Up @@ -1671,7 +1672,8 @@ def test_mutate_chain():
assert isinstance(values["b"], ops.Field)
assert values["b"].rel == two.op()

assert three.optimize().op() == ops.Project(
three_opt = simplify(three.op())
assert three_opt == ops.Project(
parent=one,
values={
"a": one.a.fillna("Short Term"),
Expand Down

0 comments on commit 402b90f

Please sign in to comment.