diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index cd8f2fac56bf..11fcbaa01e57 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -5,7 +5,16 @@ import operator import re from keyword import iskeyword -from typing import TYPE_CHECKING, Callable, Iterable, Literal, Mapping, Sequence, Any +from typing import ( + TYPE_CHECKING, + Callable, + Iterable, + Literal, + Mapping, + Sequence, + Any, + Iterator, +) import toolz from public import public @@ -53,8 +62,8 @@ def _regular_join_method( ], ): def f( # noqa: D417 - self: Table, - right: Table, + self: ir.Table, + right: ir.Table, predicates: str | Sequence[ str | tuple[str | ir.Column, str | ir.Column] | ir.BooleanValue @@ -62,7 +71,7 @@ def f( # noqa: D417 *, lname: str = "", rname: str = "{name}_right", - ) -> Table: + ) -> ir.Table: """Perform a join between two tables. Parameters @@ -420,7 +429,8 @@ def __interactive_rich_console__(self, console, options): table = to_rich_table(self, width) return console.render(table, options=options) - def column(self, name): + def column(self, name: str | int) -> ir.Column: + """Get a column from the table.""" if isinstance(name, int): name = self.schema().name_at_position(name) return ops.Field(self, name).to_expr() @@ -4400,6 +4410,7 @@ def join( lname: str = "", rname: str = "{name}_right", ): + """Join with another table.""" from ibis.expr.analysis import flatten_predicates chain = self.op() @@ -4411,7 +4422,7 @@ def join( # TODO(kszucs): clean this up preds = dict(enumerate(preds)) - preds = dereference_values(self.tables(), preds) + preds = dereference_values(self._tables, preds) preds = list(preds.values()) # TODO(kszucs): factor this out into a separate function, e.g. defereference_values_from() @@ -4433,9 +4444,10 @@ def join( chain = chain.copy(rest=chain.rest + (link,), fields=fields) # return with a new JoinExpr wrapping the new join chain - return JoinExpr(chain) + return self.__class__(chain) def select(self, *args, **kwargs): + """Select expressions.""" # do the fields projection here # TODO(kszucs): need to do smarter binding here since references may # point to any of the relations in the join chain @@ -4447,35 +4459,37 @@ def select(self, *args, **kwargs): # with a field referencing one of the relations in the join chain fields = {ops.Field(self, k): v for k, v in self.op().fields.items()} values = {k: v.replace(fields, filter=ops.Value) for k, v in values.items()} - values = dereference_values(self.tables(), values) + values = dereference_values(self._tables, values) # TODO(kszucs): add reduction conversion here detect_foreign_values(values)? - return self.finish(values) + return self._finish(values) # TODO(kszucs): figure out a solution to automatically wrap all the # TableExpr methods including the docstrings and the signature def filter(self, *predicates): - return self.finish().filter(*predicates) + """Filter with `predicates`.""" + return self._finish().filter(*predicates) def order_by(self, *keys): - return self.finish().order_by(*keys) + """Order the join by the given keys.""" + return self._finish().order_by(*keys) - def tables(self): + @property + def _tables(self) -> Iterator[ops.TableNode]: node = self.op() - parents = [node.first] + yield node.first for join in node.rest: - parents.append(join.table) - return parents + yield join.table - def finish(self, fields=None): + def _finish(self, fields: Mapping[str, ops.Field] | None = None) -> ir.Table: node = self.op() if fields is None: # TODO(kszucs): clean this up with a nicer error message # raise if there are missing fields from either of the tables # raise on collisions collisions = [] - fields = set(self.op().fields.values()) - for rel in self.tables(): + fields = frozenset(self.op().fields.values()) + for rel in self._tables: for k in rel.schema: f = ops.Field(rel, k) if f not in fields: