Skip to content

Commit

Permalink
test(ir): made core expression tests passing again
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Nov 26, 2023
1 parent a572c99 commit 7025d1c
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 10 deletions.
1 change: 1 addition & 0 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def __init__(self, first, rest, fields):
_check_integrity(fields.values(), allowed_parents)
super().__init__(first=first, rest=rest, fields=fields)

# TODO(kszucs): the fields should be changed to be nullable
@attribute
def schema(self):
return Schema({k: v.dtype for k, v in self.fields.items()})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
name="employee",
schema={"first_name": "string", "last_name": "string", "id": "int64"},
)

result = employee.select(employee.first_name).filter(
f = employee.filter(
employee.first_name.isin(
(
ibis.literal("Graham"),
Expand All @@ -17,3 +16,5 @@
)
)
)

result = f.select(f.first_name)
6 changes: 3 additions & 3 deletions ibis/expr/tests/test_newrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def test_join():
assert isinstance(joined.op(), JoinChain)
assert isinstance(joined.op().to_expr(), ir.JoinExpr)

result = joined.finish()
result = joined._finish()
assert isinstance(joined, ir.TableExpr)
assert isinstance(joined.op(), JoinChain)
assert isinstance(joined.op().to_expr(), ir.JoinExpr)
Expand Down Expand Up @@ -531,7 +531,7 @@ def test_chained_join():
c = ibis.table(name="c", schema={"e": "int64", "f": "string"})

joined = a.join(b, [a.a == b.c]).join(c, [a.a == c.e])
result = joined.finish()
result = joined._finish()
assert result.op() == JoinChain(
first=a,
rest=[
Expand Down Expand Up @@ -576,7 +576,7 @@ def test_chained_join_referencing_intermediate_table():
abc = ab.join(c, [ab.a == c.e])
assert isinstance(abc, ir.JoinExpr)

result = abc.finish()
result = abc._finish()
assert result.op() == JoinChain(
first=a,
rest=[JoinLink("inner", b, [a.a == b.c]), JoinLink("inner", c, [a.a == c.e])],
Expand Down
7 changes: 3 additions & 4 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4422,7 +4422,7 @@ def join(

# TODO(kszucs): clean this up
preds = dict(enumerate(preds))
preds = dereference_values(self._tables, preds)
preds = dereference_values(self._tables(), preds)
preds = list(preds.values())

# TODO(kszucs): factor this out into a separate function, e.g. defereference_values_from()
Expand Down Expand Up @@ -4459,7 +4459,7 @@ def select(self, *args, **kwargs):
# with a field referencing one of the relations in the join chain
fields = {ops.Field(self, k): v for k, v in self.op().fields.items()}
values = {k: v.replace(fields, filter=ops.Value) for k, v in values.items()}
values = dereference_values(self._tables, values)
values = dereference_values(self._tables(), values)
# TODO(kszucs): add reduction conversion here detect_foreign_values(values)?

return self._finish(values)
Expand All @@ -4474,7 +4474,6 @@ def order_by(self, *keys):
"""Order the join by the given keys."""
return self._finish().order_by(*keys)

@property
def _tables(self) -> Iterator[ops.TableNode]:
node = self.op()
yield node.first
Expand All @@ -4489,7 +4488,7 @@ def _finish(self, fields: Mapping[str, ops.Field] | None = None) -> ir.Table:
# raise on collisions
collisions = []
fields = frozenset(self.op().fields.values())
for rel in self._tables:
for rel in self._tables():
for k in rel.schema:
f = ops.Field(rel, k)
if f not in fields:
Expand Down
2 changes: 1 addition & 1 deletion ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1695,7 +1695,7 @@ def test_join_lname_rname_still_collide():
t3 = ibis.table({"id": "int64", "col1": "int64", "col2": "int64"})

with pytest.raises(com.IntegrityError):
t1.left_join(t2, "id").left_join(t3, "id").finish()
t1.left_join(t2, "id").left_join(t3, "id")._finish()

# assert "`['col1_right', 'col2_right', 'id_right']`" in str(rec.value)
# assert "`lname='', rname='{name}_right'`" in str(rec.value)
Expand Down

0 comments on commit 7025d1c

Please sign in to comment.