Skip to content

Commit

Permalink
refactor(polars): update the polars backend to use the new relational…
Browse files Browse the repository at this point in the history
… abstractions
  • Loading branch information
kszucs committed Jan 4, 2024
1 parent eb31002 commit d76bc34
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 115 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/ibis-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ jobs:
title: Datafusion
extras:
- datafusion
# - name: polars
# title: Polars
# extras:
# - polars
# - deltalake
- name: polars
title: Polars
extras:
- polars
- deltalake
# - name: mysql
# title: MySQL
# services:
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/pandas/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def split_join_predicates(left, right, predicates, only_equality=True):

@replace(ops.JoinChain)
def rewrite_join(_, **kwargs):
# TODO(kszucs): JoinTable.index can be used as a prefix
prefixes = {}
prefixes[_.first] = prefix = str(len(prefixes))
left = PandasRename.from_prefix(_.first, prefix)
Expand Down
26 changes: 14 additions & 12 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.backends.base import BaseBackend, Database
from ibis.backends.pandas.rewrites import (
bind_unbound_table,
replace_parameter,
rewrite_join,
)
from ibis.backends.polars.compiler import translate
from ibis.backends.polars.datatypes import dtype_to_polars, schema_from_polars
from ibis.common.patterns import Replace
from ibis.util import gen_name, normalize_filename

if TYPE_CHECKING:
Expand Down Expand Up @@ -379,20 +383,18 @@ def has_operation(cls, operation: type[ops.Value]) -> bool:
def compile(
self, expr: ir.Expr, params: Mapping[ir.Expr, object] | None = None, **_: Any
):
node = expr.op()
ctx = self._context

if params:
if params is None:
params = dict()
else:
params = {param.op(): value for param, value in params.items()}
rule = Replace(
ops.ScalarParameter,
lambda _: ops.Literal(value=params[_], dtype=_.dtype),
)
node = node.replace(rule)
expr = node.to_expr()

node = expr.as_table().op()
return translate(node, ctx=ctx)
node = node.replace(
rewrite_join | replace_parameter | bind_unbound_table,
context={"params": params, "backend": self},
)

return translate(node, ctx=self._context)

def _get_schema_using_query(self, query: str) -> sch.Schema:
return schema_from_polars(
Expand Down
195 changes: 126 additions & 69 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.pandas.rewrites import PandasAsofJoin, PandasJoin, PandasRename
from ibis.backends.polars.datatypes import dtype_to_polars, schema_from_polars
from ibis.expr.operations.udf import InputType
from ibis.util import gen_name


def _expr_method(expr, op, methods):
Expand Down Expand Up @@ -59,7 +61,7 @@ def table(op, **_):

@translate.register(ops.DummyTable)
def dummy_table(op, **kw):
selections = [translate(arg, **kw) for arg in op.values]
selections = [translate(arg, **kw) for name, arg in op.values.items()]
return pl.DataFrame().lazy().select(selections)


Expand Down Expand Up @@ -181,7 +183,7 @@ def _cast(op, strict=True, **kw):
return arg.cast(typ, strict=strict)


@translate.register(ops.TableColumn)
@translate.register(ops.Field)
def column(op, **_):
return pl.col(op.name)

Expand All @@ -196,45 +198,41 @@ def sort_key(op, **kw):
return arg.sort(reverse=descending) # pragma: no cover


@translate.register(ops.Selection)
def selection(op, **kw):
lf = translate(op.table, **kw)

if op.predicates:
predicates = map(partial(translate, **kw), op.predicates)
predicate = reduce(operator.and_, predicates)
lf = lf.filter(predicate)
@translate.register(ops.Project)
def project(op, **kw):
lf = translate(op.parent, **kw)

selections = []
unnests = []
for arg in op.selections:
if isinstance(arg, ops.TableNode):
for name in arg.schema.names:
column = ops.TableColumn(table=arg, name=name)
selections.append(translate(column, **kw))
elif (
isinstance(arg, ops.Alias) and isinstance(unnest := arg.arg, ops.Unnest)
) or isinstance(unnest := arg, ops.Unnest):
name = arg.name
for name, arg in op.values.items():
if isinstance(arg, ops.Unnest):
unnests.append(name)
selections.append(translate(unnest.arg, **kw).alias(name))
translated = translate(arg.arg, **kw)
elif isinstance(arg, ops.Value):
selections.append(translate(arg, **kw))
translated = translate(arg, **kw)
else:
raise com.TranslationError(
"Polars backend is unable to compile selection with "
f"operation type of {type(arg)}"
)
selections.append(translated.alias(name))

if selections:
lf = lf.select(selections)

if unnests:
lf = lf.explode(*unnests)

if op.sort_keys:
by = [key.name for key in op.sort_keys]
descending = [key.descending for key in op.sort_keys]
return lf


@translate.register(ops.Sort)
def sort(op, **kw):
lf = translate(op.parent, **kw)

if op.keys:
by = [key.name for key in op.keys]
descending = [key.descending for key in op.keys]
try:
lf = lf.sort(by, descending=descending)
except TypeError: # pragma: no cover
Expand All @@ -243,6 +241,18 @@ def selection(op, **kw):
return lf


@translate.register(ops.Filter)
def filter_(op, **kw):
    """Translate an ``ops.Filter`` node.

    The parent relation is translated first; if the node carries any
    predicates, each one is translated and they are combined with ``&``
    into a single condition passed to ``.filter`` on the translated parent.
    With no predicates the translated parent is returned unchanged.
    """
    frame = translate(op.parent, **kw)
    if not op.predicates:
        return frame

    # Fold the translated predicates left-to-right into one conjunction.
    combined = None
    for pred in op.predicates:
        condition = translate(pred, **kw)
        combined = condition if combined is None else combined & condition

    return frame.filter(combined)


@translate.register(ops.Limit)
def limit(op, **kw):
if (n := op.n) is not None and not isinstance(n, int):
Expand All @@ -251,75 +261,99 @@ def limit(op, **kw):
if not isinstance(offset := op.offset, int):
raise NotImplementedError("Dynamic offset not supported")

lf = translate(op.table, **kw)
lf = translate(op.parent, **kw)
return lf.slice(offset, n)


@translate.register(ops.Aggregation)
@translate.register(ops.Aggregate)
def aggregation(op, **kw):
lf = translate(op.table, **kw)
lf = translate(op.parent, **kw)

if op.predicates:
lf = lf.filter(
reduce(
operator.and_,
map(partial(translate, **kw), op.predicates),
if op.groups:
# project first to handle computed group by columns
lf = (
lf.with_columns(
[translate(arg, **kw).alias(name) for name, arg in op.groups.items()]
)
.group_by(list(op.groups.keys()))
.agg
)

# project first to handle computed group by columns
lf = lf.with_columns([translate(arg, **kw) for arg in op.by])

if op.by:
lf = lf.group_by([pl.col(by.name) for by in op.by]).agg
else:
lf = lf.select

if op.metrics:
metrics = [translate(arg, **kw).alias(arg.name) for arg in op.metrics]
metrics = [translate(arg, **kw).alias(name) for name, arg in op.metrics.items()]
lf = lf(metrics)

return lf


_join_types = {
ops.InnerJoin: "inner",
ops.LeftJoin: "left",
ops.RightJoin: "right",
ops.OuterJoin: "outer",
ops.LeftAntiJoin: "anti",
ops.LeftSemiJoin: "semi",
}
@translate.register(PandasRename)
def rename(op, **kw):
    """Translate a ``PandasRename`` node.

    Translates ``op.parent`` and calls ``.rename`` on the result with
    ``op.mapping``.
    """
    return translate(op.parent, **kw).rename(op.mapping)


@translate.register(ops.Join)
@translate.register(PandasJoin)
def join(op, **kw):
how = op.how
left = translate(op.left, **kw)
right = translate(op.right, **kw)

if isinstance(op, ops.RightJoin):
# workaround required for https://github.com/pola-rs/polars/issues/13130
prefix = gen_name("on")
left_on = {f"{prefix}_{i}": translate(v, **kw) for i, v in enumerate(op.left_on)}
right_on = {f"{prefix}_{i}": translate(v, **kw) for i, v in enumerate(op.right_on)}
left = left.with_columns(**left_on)
right = right.with_columns(**right_on)
on = list(left_on.keys())

if how == "right":
how = "left"
left, right = right, left
else:
how = _join_types[type(op)]

left_on, right_on = [], []
for pred in op.predicates:
if isinstance(pred, ops.Equals):
left_on.append(translate(pred.left, **kw))
right_on.append(translate(pred.right, **kw))
else:
raise com.TranslationError(
"Polars backend is unable to compile join predicate "
f"with operation type of {type(pred)}"
)
joined = left.join(right, on=on, how=how)
joined = joined.drop(columns=on)

return joined

return left.join(right, left_on=left_on, right_on=right_on, how=how)

@translate.register(PandasAsofJoin)
def asof_join(op, **kw):
    """Translate a ``PandasAsofJoin`` node into a polars ``join_asof`` call.

    Both inputs are translated, the expressions in ``op.left_on`` /
    ``op.right_on`` and ``op.left_by`` / ``op.right_by`` are added to the
    respective frames as temporary columns, the frames are joined with a
    strategy chosen from ``op.operator``, and the temporary columns are
    dropped from the result.

    Raises
    ------
    NotImplementedError
        If ``op.operator`` is not one of ``ops.Less``, ``ops.LessEqual``,
        ``ops.Greater``, ``ops.GreaterEqual`` or ``ops.Equals``.
    """
    left = translate(op.left, **kw)
    right = translate(op.right, **kw)

    # workaround required for https://github.com/pola-rs/polars/issues/13130
    # The key expressions are materialized under generated column names that
    # are identical on both sides, so the join can refer to them by name.
    on, by = gen_name("on"), gen_name("by")
    left_on = {f"{on}_{i}": translate(v, **kw) for i, v in enumerate(op.left_on)}
    right_on = {f"{on}_{i}": translate(v, **kw) for i, v in enumerate(op.right_on)}
    left_by = {f"{by}_{i}": translate(v, **kw) for i, v in enumerate(op.left_by)}
    right_by = {f"{by}_{i}": translate(v, **kw) for i, v in enumerate(op.right_by)}

    left = left.with_columns(**left_on, **left_by)
    right = right.with_columns(**right_on, **right_by)

    # From here on `on` and `by` are the lists of temporary column names.
    on = list(left_on.keys())
    by = list(left_by.keys())

    if op.operator in {ops.Less, ops.LessEqual}:
        direction = "forward"
    elif op.operator in {ops.Greater, ops.GreaterEqual}:
        direction = "backward"
    elif op.operator == ops.Equals:
        direction = "nearest"
    else:
        # Report the node's operator; the bare name `operator` would refer to
        # the standard-library module and produce a meaningless message.
        raise NotImplementedError(
            f"Operator {op.operator} not supported for asof join"
        )

    # Only a single `on` column is passed to `join_asof` below.
    assert len(on) == 1, "asof join requires exactly one `on` key"
    joined = left.join_asof(right, on=on[0], by=by, strategy=direction)
    # Remove the temporary key columns so they do not leak into the output.
    joined = joined.drop(columns=on + by)
    return joined


@translate.register(ops.DropNa)
def dropna(op, **kw):
lf = translate(op.table, **kw)
lf = translate(op.parent, **kw)

if op.subset is None:
subset = None
Expand All @@ -337,10 +371,28 @@ def dropna(op, **kw):

@translate.register(ops.FillNa)
def fillna(op, **kw):
table = translate(op.table, **kw)
table = translate(op.parent, **kw)

columns = []
for name, dtype in op.table.schema.items():

repls = op.replacements

if isinstance(repls, Mapping):

def get_replacement(name):
repl = repls.get(name)
if repl is not None:
return repl.value
else:
return None

else:
value = repls.value

def get_replacement(_):
return value

for name, dtype in op.parent.schema.items():
column = pl.col(name)
if isinstance(op.replacements, Mapping):
value = op.replacements.get(name)
Expand Down Expand Up @@ -422,11 +474,11 @@ def greatest(op, **kw):
return pl.max_horizontal(arg)


@translate.register(ops.InColumn)
@translate.register(ops.InSubquery)
def in_column(op, **kw):
value = translate(op.value, **kw)
options = translate(op.options, **kw)
return value.is_in(options)
needle = translate(op.needle, **kw)
return needle.is_in(value)


@translate.register(ops.InValues)
Expand Down Expand Up @@ -734,7 +786,7 @@ def correlation(op, **kw):

@translate.register(ops.Distinct)
def distinct(op, **kw):
table = translate(op.table, **kw)
table = translate(op.parent, **kw)
return table.unique()


Expand Down Expand Up @@ -1163,6 +1215,11 @@ def execute_self_reference(op, **kw):
return translate(op.table, **kw)


@translate.register(ops.JoinTable)
def execute_join_table(op, **kw):
    """Translate an ``ops.JoinTable`` node by translating its parent.

    The node contributes nothing of its own here: the result is exactly the
    translation of ``op.parent``.
    """
    parent = translate(op.parent, **kw)
    return parent


@translate.register(ops.CountDistinctStar)
def execute_count_distinct_star(op, **kw):
arg = pl.struct(*op.arg.schema.names)
Expand Down
Loading

0 comments on commit d76bc34

Please sign in to comment.