Skip to content

Commit

Permalink
feat(ir): add default implementation of pretty formatting nodes (#8880)
Browse files Browse the repository at this point in the history
We have custom nodes for example to lower expressions to certain
backends. Pretty printing these are especially useful but currently we
raise for unknown node types.
  • Loading branch information
kszucs authored Apr 4, 2024
1 parent 38e7e14 commit a696c70
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 49 deletions.
14 changes: 9 additions & 5 deletions ibis/expr/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def __repr__(self):


@public
def pretty(expr: ir.Expr, scope: Optional[dict[str, ir.Expr]] = None):
def pretty(expr: ops.Node | ir.Expr, scope: Optional[dict[str, ir.Expr]] = None):
"""Pretty print an expression.
Parameters
Expand All @@ -186,10 +186,13 @@ def pretty(expr: ir.Expr, scope: Optional[dict[str, ir.Expr]] = None):
str
A pretty printed representation of the expression.
"""
if not isinstance(expr, ir.Expr):
raise TypeError(f"Expected an expression, got {type(expr)}")
if isinstance(expr, ir.Expr):
node = expr.op()
elif isinstance(expr, ops.Node):
node = expr
else:
raise TypeError(f"Expected an expression or a node, got {type(expr)}")

node = expr.op()
refs = {}
refcnt = itertools.count()
variables = {v.op(): k for k, v in (scope or {}).items()}
Expand Down Expand Up @@ -224,7 +227,8 @@ def mapper(op, _, **kwargs):

@functools.singledispatch
def fmt(op, **kwargs):
raise NotImplementedError(f"no pretty printer for {type(op)}")
top = f"{op.__class__.__name__}\n"
return top + render_fields(kwargs, 1)


@fmt.register(ops.Relation)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
r0 := UnboundTable: t
a int64

ValueList
values:
1
2.0
'three'
r0.a
96 changes: 52 additions & 44 deletions ibis/expr/tests/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,20 @@

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.format
import ibis.expr.operations as ops
import ibis.legacy.udf.vectorized as udf
from ibis import util

# easier to switch implementation if needed
fmt = repr
from ibis.expr.format import fmt, pretty


@pytest.mark.parametrize("cls", [ops.PhysicalTable, ops.Relation])
def test_tables_have_format_value_rules(cls):
assert cls in ibis.expr.format.fmt.registry
assert cls in fmt.registry


def test_format_table_column(alltypes, snapshot):
# GH #507
result = fmt(alltypes.f)
result = repr(alltypes.f)
assert "float64" in result
snapshot.assert_match(result, "repr.txt")

Expand All @@ -32,14 +29,14 @@ def test_format_projection(alltypes, snapshot):
# This should produce a ref to the projection
proj = alltypes[["c", "a", "f"]]
expr = proj["a"]
result = fmt(expr)
result = repr(expr)
snapshot.assert_match(result, "repr.txt")


def test_format_table_with_empty_schema(snapshot):
# GH #6837
schema = ibis.table({}, name="t")
result = fmt(schema)
result = repr(schema)
snapshot.assert_match(result, "repr.txt")


Expand All @@ -55,7 +52,7 @@ def test_table_type_output(snapshot):
)

expr = foo.dept_id == foo.view().dept_id
result = fmt(expr)
result = repr(expr)
assert "UnboundTable: foo" in result
snapshot.assert_match(result, "repr.txt")

Expand All @@ -68,7 +65,7 @@ def test_aggregate_arg_names(alltypes, snapshot):
metrics = [t.c.sum().name("c"), t.d.mean().name("d")]

expr = t.group_by(by_exprs).aggregate(metrics)
result = fmt(expr)
result = repr(expr)
assert "metrics" in result
assert "groups" in result

Expand Down Expand Up @@ -103,7 +100,7 @@ def test_format_multiple_join_with_projection(snapshot):
view = j2[[filtered, table2["value1"], table3["value2"]]]

# it works!
result = fmt(view)
result = repr(view)
snapshot.assert_match(result, "repr.txt")


Expand All @@ -117,7 +114,7 @@ def test_memoize_filtered_table(snapshot):
t = airlines[airlines.dest.isin(dests)]
delay_filter = t.dest.topk(10, by=t.arrdelay.mean())

result = fmt(delay_filter)
result = repr(delay_filter)
snapshot.assert_match(result, "repr.txt")


Expand All @@ -126,8 +123,8 @@ def test_named_value_expr_show_name(alltypes, snapshot):
expr2 = expr.name("baz")

# it works!
result = fmt(expr)
result2 = fmt(expr2)
result = repr(expr)
result2 = repr(expr2)

assert "baz" not in result
assert "baz" in result2
Expand Down Expand Up @@ -157,14 +154,14 @@ def test_memoize_filtered_tables_in_join(snapshot):
cond = left.region == right.region
joined = left.join(right, cond)[left, right.total.name("right_total")]

result = fmt(joined)
result = repr(joined)
snapshot.assert_match(result, "repr.txt")


def test_argument_repr_shows_name(snapshot):
t = ibis.table([("fakecolname1", "int64")], name="fakename2")
expr = t.fakecolname1.nullif(2)
result = fmt(expr)
result = repr(expr)

assert "fakecolname1" in result
assert "fakename2" in result
Expand All @@ -182,7 +179,7 @@ def test_scalar_parameter_formatting():
def test_same_column_multiple_aliases(snapshot):
table = ibis.table([("col", "int64")], name="t")
expr = table[table.col.name("fakealias1"), table.col.name("fakealias2")]
result = fmt(expr)
result = repr(expr)

assert "UnboundTable: t" in result
assert "col int64" in result
Expand All @@ -193,7 +190,7 @@ def test_same_column_multiple_aliases(snapshot):

def test_scalar_parameter_repr():
value = ibis.param(dt.timestamp).name("value")
assert fmt(value) == "value: $(timestamp)"
assert repr(value) == "value: $(timestamp)"


def test_repr_exact(snapshot):
Expand All @@ -205,7 +202,7 @@ def test_repr_exact(snapshot):
name="t",
).mutate(col4=lambda t: t.col2.length())

result = fmt(table)
result = repr(table)
snapshot.assert_match(result, "repr.txt")


Expand All @@ -218,7 +215,7 @@ def test_complex_repr(snapshot):
.aggregate(y=lambda t: t.a.sum())
.limit(10)
)
result = fmt(t)
result = repr(t)

snapshot.assert_match(result, "repr.txt")

Expand All @@ -245,20 +242,20 @@ def test_schema_truncation(monkeypatch, snapshot):

monkeypatch.setattr(ibis.options.repr, "table_columns", 0)
with pytest.raises(ValueError):
fmt(t)
repr(t)

monkeypatch.setattr(ibis.options.repr, "table_columns", 1)
result = fmt(t)
result = repr(t)
assert util.VERTICAL_ELLIPSIS not in result
snapshot.assert_match(result, "repr1.txt")

monkeypatch.setattr(ibis.options.repr, "table_columns", 8)
result = fmt(t)
result = repr(t)
assert util.VERTICAL_ELLIPSIS in result
snapshot.assert_match(result, "repr8.txt")

monkeypatch.setattr(ibis.options.repr, "table_columns", 1000)
result = fmt(t)
result = repr(t)
assert util.VERTICAL_ELLIPSIS not in result
snapshot.assert_match(result, "repr_all.txt")

Expand All @@ -271,15 +268,15 @@ def test_table_count_expr(snapshot):
join_cnt = t1.join(t2, t1.a == t2.a).count()
union_cnt = ibis.union(t1, t2).count()

snapshot.assert_match(fmt(cnt), "cnt_repr.txt")
snapshot.assert_match(fmt(join_cnt), "join_repr.txt")
snapshot.assert_match(fmt(union_cnt), "union_repr.txt")
snapshot.assert_match(repr(cnt), "cnt_repr.txt")
snapshot.assert_match(repr(join_cnt), "join_repr.txt")
snapshot.assert_match(repr(union_cnt), "union_repr.txt")


def test_window_no_group_by(snapshot):
t = ibis.table(dict(a="int64", b="string"), name="t")
expr = t.a.mean().over(ibis.window(preceding=0))
result = fmt(expr)
result = repr(expr)

assert "group_by=[]" not in result
snapshot.assert_match(result, "repr.txt")
Expand All @@ -289,7 +286,7 @@ def test_window_group_by(snapshot):
t = ibis.table(dict(a="int64", b="string"), name="t")
expr = t.a.mean().over(ibis.window(group_by=t.b))

result = fmt(expr)
result = repr(expr)
assert "start=0" not in result
assert "group_by=[r0.b]" in result
snapshot.assert_match(result, "repr.txt")
Expand All @@ -299,13 +296,13 @@ def test_fillna(snapshot):
t = ibis.table(dict(a="int64", b="string"), name="t")

expr = t.fillna({"a": 3})
snapshot.assert_match(fmt(expr), "fillna_dict_repr.txt")
snapshot.assert_match(repr(expr), "fillna_dict_repr.txt")

expr = t[["a"]].fillna(3)
snapshot.assert_match(fmt(expr), "fillna_int_repr.txt")
snapshot.assert_match(repr(expr), "fillna_int_repr.txt")

expr = t[["b"]].fillna("foo")
snapshot.assert_match(fmt(expr), "fillna_str_repr.txt")
snapshot.assert_match(repr(expr), "fillna_str_repr.txt")


def test_asof_join(snapshot):
Expand All @@ -315,7 +312,7 @@ def test_asof_join(snapshot):
right, left.value == right.value2
)

result = fmt(joined)
result = repr(joined)
snapshot.assert_match(result, "repr.txt")


Expand All @@ -330,7 +327,7 @@ def test_two_inner_joins(snapshot):
right, left.value == right.value2
)

result = fmt(joined)
result = repr(joined)
snapshot.assert_match(result, "repr.txt")


Expand All @@ -347,7 +344,7 @@ def multi_output_udf(v):
return v.sum(), v.mean()

expr = table.aggregate(multi_output_udf(table["col"]).destructure())
result = fmt(expr)
result = repr(expr)

assert "sum: StructField(ReductionVectorizedUDF" in result
assert "mean: StructField(ReductionVectorizedUDF" in result
Expand All @@ -360,41 +357,41 @@ def multi_output_udf(v):
)
def test_format_literal(literal, typ, output):
expr = ibis.literal(literal, type=typ)
assert fmt(expr) == output
assert repr(expr) == output


def test_format_dummy_table(snapshot):
t = ops.DummyTable({"foo": ibis.array([1]).cast("array<int8>")}).to_expr()

result = fmt(t)
result = repr(t)
snapshot.assert_match(result, "repr.txt")


def test_format_in_memory_table(snapshot):
t = ibis.memtable([(1, 2), (3, 4), (5, 6)], columns=["x", "y"])
expr = t.x.sum() + t.y.sum()

result = fmt(expr)
result = repr(expr)
assert "InMemoryTable" in result
snapshot.assert_match(result, "repr.txt")


def test_format_unbound_table_namespace(snapshot):
t = ibis.table(name="bork", schema=(("a", "int"), ("b", "int")))

result = fmt(t)
result = repr(t)
snapshot.assert_match(result, "repr.txt")

t = ibis.table(name="bork", schema=(("a", "int"), ("b", "int")), database="bork")

result = fmt(t)
result = repr(t)
snapshot.assert_match(result, "reprdb.txt")

t = ibis.table(
name="bork", schema=(("a", "int"), ("b", "int")), catalog="ork", database="bork"
)

result = fmt(t)
result = repr(t)
snapshot.assert_match(result, "reprcatdb.txt")


Expand All @@ -413,7 +410,7 @@ def values(self):

table = MyRelation(alltypes, kind="foo").to_expr()
expr = table[table, table.a.name("a2")]
result = fmt(expr)
result = repr(expr)

snapshot.assert_match(result, "repr.txt")

Expand All @@ -431,7 +428,7 @@ def shape(self):
return self.arg.shape

expr = Inc(alltypes.a).to_expr().name("incremented")
result = fmt(expr)
result = repr(expr)
last_line = result.splitlines()[-1]

assert "Inc" in result
Expand All @@ -449,11 +446,22 @@ def test_format_show_variables(monkeypatch, alltypes, snapshot):
sub = projected.a - projected.b
expr = add * sub

result = fmt(expr)
result = repr(expr)

assert "projected.a" in result
assert "projected.b" in result
assert "filtered" in result
assert "ordered" in result

snapshot.assert_match(result, "repr.txt")


def test_default_format_implementation(snapshot):
class ValueList(ops.Node):
values: tuple[ops.Value, ...]

t = ibis.table([("a", "int64")], name="t")
vl = ValueList((1, 2.0, "three", t.a))
result = pretty(vl)

snapshot.assert_match(result, "repr.txt")

0 comments on commit a696c70

Please sign in to comment.