-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,10 +3,14 @@ | |
from __future__ import annotations | ||
|
||
import re | ||
import typing | ||
from functools import singledispatchmethod | ||
|
||
import sqlglot as sg | ||
import sqlglot.expressions as sge | ||
from sqlglot.dialects import BigQuery | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong. |
||
from sqlglot.dialects.bigquery import _alias_ordered_group | ||
from sqlglot.helper import find_new_name | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
chelsea-lin
Author
Owner
|
||
|
||
import ibis.common.exceptions as com | ||
import ibis.expr.datatypes as dt | ||
|
@@ -27,6 +31,99 @@ | |
_NAME_REGEX = re.compile(r'[^!"$()*,./;?@[\\\]^`{}~\n]+') | ||
|
||
|
||
def _explode_to_unnest(expression: sge.Expression) -> sge.Expression:
    """Rewrite [POS]EXPLODE projections of a SELECT as joins against UNNEST.

    For every projection of ``expression`` that contains an ``sge.Explode``
    node (found with ``select.find``), the projection is replaced by a plain
    column reference to the explode alias, and an aliased ``sge.Unnest`` of
    the explode argument is joined onto the query. For a position-returning
    explode, a second column for the position alias is inserted right after
    the value column and the same alias is passed as the ``offset`` of the
    ``Unnest``.

    The join type is ``"LEFT"`` for the ``*Outer`` explode variants and
    ``"CROSS"`` otherwise.

    Parameters
    ----------
    expression
        Any sqlglot expression. Only ``sge.Select`` instances are rewritten.

    Returns
    -------
    sge.Expression
        The same ``expression`` object. Selects are mutated in place; other
        expression types are returned untouched.
    """
    if not isinstance(expression, sge.Select):
        return expression

    taken_select_names = set(expression.named_selects)

    def new_name(names: typing.Set[str], name: str) -> str:
        # Pick a name not present in `names` and reserve it immediately.
        name = find_new_name(names, name)
        names.add(name)
        return name

    # we use list here because expression.selects is mutated inside the loop
    for select in list(expression.selects):
        explode = select.find(sge.Explode)
        if not explode:
            continue

        pos_alias = ""
        explode_alias = ""

        if isinstance(select, sge.Alias):
            explode_alias = select.args["alias"]
            alias = select
        elif isinstance(select, sge.Aliases):
            pos_alias = select.aliases[0]
            explode_alias = select.aliases[1]
            alias = select.replace(sge.alias_(select.this, "", copy=False))
        else:
            # Bare projection: wrap a copy in an (initially unnamed) alias and
            # re-locate the explode node inside that copy.
            alias = select.replace(sge.alias_(select, ""))
            explode = alias.find(sge.Explode)
            assert explode

        is_posexplode = isinstance(explode, sge.Posexplode)
        explode_arg = explode.this

        # This ensures that we won't use [POS]EXPLODE's argument as a new selection
        if isinstance(explode_arg, sge.Column):
            taken_select_names.add(explode_arg.output_name)

        if not explode_alias:
            explode_alias = new_name(taken_select_names, "col")

        # Only generate a position alias when the caller did not supply one
        # (via `Aliases`); a user-provided alias must never be overwritten.
        # The name is reserved even for plain EXPLODE so that the aliases
        # generated for later projections stay the same as before.
        if not pos_alias:
            pos_alias = new_name(taken_select_names, "pos")

        alias.set("alias", sge.to_identifier(explode_alias))

        expressions = expression.expressions
        index = expressions.index(alias)
        expressions[index].replace(sge.column(explode_alias))

        offset = None
        if is_posexplode:
            # Re-fetch the projection list: the replace() above may have
            # swapped in a new list object on the Select.
            expressions = expression.expressions
            expressions.insert(
                index + 1,
                sge.column(pos_alias),
            )
            expression.set("expressions", expressions)
            offset = sge.to_identifier(pos_alias)

        is_outer = isinstance(explode, (sge.ExplodeOuter, sge.PosexplodeOuter))
        join_type = "LEFT" if is_outer else "CROSS"

        expression.join(
            sge.alias_(
                sge.Unnest(
                    expressions=[explode_arg.copy()],
                    offset=offset,
                ),
                "",
                table=[explode_alias],
            ),
            join_type=join_type,
            copy=False,
        )

    return expression
|
||
|
||
# Register the SELECT preprocessing pipeline on sqlglot's BigQuery generator.
# NOTE(review): the local _explode_to_unnest appears to take the slot of
# sqlglot's stock `sg.transforms.explode_to_unnest()` — confirm that the
# remaining entries mirror the upstream BigQuery defaults.
BigQuery.Generator.TRANSFORMS |= {
    sge.Select: sg.transforms.preprocess([
        # Local replacement for sg.transforms.explode_to_unnest().
        _explode_to_unnest,
        sg.transforms.eliminate_distinct_on,
        _alias_ordered_group,
        sg.transforms.eliminate_semi_and_anti_joins,
    ]),
}
|
||
|
||
@replace(p.WindowFunction(p.MinRank | p.DenseRank, y @ p.WindowFrame(start=None))) | ||
def exclude_unsupported_window_frame_from_rank(_, y): | ||
return ops.Subtract( | ||
|
@@ -729,6 +826,17 @@ def visit_CountDistinct(self, op, *, arg, where): | |
arg = self.if_(where, arg, NULL) | ||
return self.f.count(sge.Distinct(expressions=[arg])) | ||
|
||
@visit_node.register(ops.Unnest)
def visit_Unnest(self, op, *, arg, offset, preserve_empty):
    """Translate ``ops.Unnest`` into one of sqlglot's explode nodes.

    The node class is chosen from the truthiness of the two flags:
    ``offset`` selects the ``Posexplode`` family over ``Explode``, and
    ``preserve_empty`` selects the ``*Outer`` variant within that family.
    The chosen node wraps ``arg`` as its ``this`` argument.
    """
    if offset:
        node_type = sge.PosexplodeOuter if preserve_empty else sge.Posexplode
    else:
        node_type = sge.ExplodeOuter if preserve_empty else sge.Explode
    return node_type(this=arg)
|
||
@visit_node.register(ops.CountDistinctStar) | ||
@visit_node.register(ops.DateDiff) | ||
@visit_node.register(ops.ExtractAuthority) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
from __future__ import annotations | ||
|
||
import inspect | ||
from typing import TYPE_CHECKING, Callable | ||
from typing import TYPE_CHECKING, Callable, Optional | ||
|
||
from public import public | ||
|
||
|
@@ -266,7 +266,11 @@ def repeat(self, n: int | ir.IntegerValue) -> ArrayValue: | |
|
||
__mul__ = __rmul__ = repeat | ||
|
||
def unnest(self) -> ir.Value: | ||
def unnest( | ||
self, | ||
offset: Optional[bool] = None, | ||
preserve_empty: Optional[bool] = None, | ||
This comment has been minimized.
Sorry, something went wrong.
tswast
|
||
) -> ir.Value: | ||
"""Flatten an array into a column. | ||
::: {.callout-note} | ||
|
@@ -305,7 +309,7 @@ def unnest(self) -> ir.Value: | |
ir.Value | ||
Unnested array | ||
""" | ||
expr = ops.Unnest(self).to_expr() | ||
expr = ops.Unnest(self, offset, preserve_empty).to_expr() | ||
try: | ||
return expr.name(self.get_name()) | ||
except com.ExpressionError: | ||
|
Nit: In Google style, we generally import modules rather than individual classes or functions. This isn't a strict rule in Ibis, but it is generally followed there as well.
That said, I do see https://github.com/ibis-project/ibis/blob/d7dd8065dd5af2db1124b7aa753c61822e33c8c9/ibis/backends/datafusion/__init__.py#L15 so I don't think this would be totally out of place.