Skip to content

Commit

Permalink
feat(bigquery): adding offset and preserved_empty parameters to unnest
Browse files Browse the repository at this point in the history
  • Loading branch information
chelsea-lin committed Feb 8, 2024
1 parent 6a017f1 commit d480cc2
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 3 deletions.
108 changes: 108 additions & 0 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
from __future__ import annotations

import re
import typing
from functools import singledispatchmethod

import sqlglot as sg
import sqlglot.expressions as sge
from sqlglot.dialects import BigQuery

This comment has been minimized.

Copy link
@tswast

tswast Feb 8, 2024

Nit: In Google-style we generally import modules, not classes/functions. This isn't a strict rule in Ibis but is also generally the case.

That said, I do see https://github.com/ibis-project/ibis/blob/d7dd8065dd5af2db1124b7aa753c61822e33c8c9/ibis/backends/datafusion/__init__.py#L15 so I don't think this would be totally out of place.

This comment has been minimized.

Copy link
@chelsea-lin

chelsea-lin Feb 12, 2024

Author Owner

Good to know. Thanks!

from sqlglot.dialects.bigquery import _alias_ordered_group
from sqlglot.helper import find_new_name

This comment has been minimized.

Copy link
@tswast

tswast Feb 8, 2024

I wasn't sure if this was a public API, but it does appear to be OK.

https://sqlglot.com/sqlglot/helper.html#find_new_name

This comment has been minimized.

Copy link
@chelsea-lin

chelsea-lin Feb 12, 2024

Author Owner

Yes! Once the sqlglot fixes are onboarded, we won't need these changes anymore.


import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
Expand All @@ -27,6 +31,99 @@
_NAME_REGEX = re.compile(r'[^!"$()*,./;?@[\\\]^`{}~\n]+')


def _explode_to_unnest(expression: sge.Expression) -> sge.Expression:
"""TODO

This comment has been minimized.

Copy link
@tswast

tswast Feb 8, 2024

It would be good to explain why we are different from the version in sqlglot upstream and what has changed.

This comment has been minimized.

Copy link
@chelsea-lin

chelsea-lin Feb 12, 2024

Author Owner

Yes! Once the sqlglot fixes are onboarded, we won't need these changes anymore.

"""
if isinstance(expression, sge.Select):

taken_select_names = set(expression.named_selects)

def new_name(names: typing.Set[str], name: str) -> str:
    """Pick a name derived from `name` via `find_new_name` and reserve it.

    The chosen name is added to `names` (mutated in place) before it is
    returned, so later calls with the same set will not hand it out again.
    """
    fresh = find_new_name(names, name)
    names.add(fresh)
    return fresh

# we use list here because expression.selects is mutated inside the loop
for select in list(expression.selects):
explode = select.find(sge.Explode)

if explode:
pos_alias = ""
explode_alias = ""

if isinstance(select, sge.Alias):
explode_alias = select.args["alias"]
alias = select
elif isinstance(select, sge.Aliases):
pos_alias = select.aliases[0]
explode_alias = select.aliases[1]
alias = select.replace(sge.alias_(select.this, "", copy=False))
else:
alias = select.replace(sge.alias_(select, ""))
explode = alias.find(sge.Explode)
assert explode

is_posexplode = isinstance(explode, sge.Posexplode)
explode_arg = explode.this

# This ensures that we won't use [POS]EXPLODE's argument as a new selection
if isinstance(explode_arg, sge.Column):
taken_select_names.add(explode_arg.output_name)

if not explode_alias:
explode_alias = new_name(taken_select_names, "col")

if is_posexplode:
pos_alias = new_name(taken_select_names, "pos")

if not pos_alias:
pos_alias = new_name(taken_select_names, "pos")

alias.set("alias", sge.to_identifier(explode_alias))

expressions = expression.expressions
index = expressions.index(alias)
expressions[index].replace(sge.column(explode_alias))

offset = None
if is_posexplode:
expressions = expression.expressions
expressions.insert(
index + 1,
sge.column(pos_alias),
)
expression.set("expressions", expressions)
offset = sge.to_identifier(pos_alias)

join_type = "LEFT" if isinstance(explode, sge.ExplodeOuter) or isinstance(explode, sge.PosexplodeOuter) else "CROSS"

This comment has been minimized.

Copy link
@tswast

tswast Feb 8, 2024

Interesting. So these existed already?

This comment has been minimized.

Copy link
@chelsea-lin

chelsea-lin Feb 12, 2024

Author Owner

Yes, sqlglot already defines these extended Explode classes to support offset=True and preserve_empty=True. However, preserve_empty did not work well in the explode_to_unnest function; that can be fixed by tobymao/sqlglot#2941.


expression.join(
sge.alias_(
sge.Unnest(
expressions=[explode_arg.copy()],
offset=offset,
),
"",
table=[explode_alias],
),
join_type=join_type,
copy=False,
)

return expression


# Merge a `sge.Select` entry into the BigQuery generator's TRANSFORMS mapping
# (replacing any entry already registered under that key). The entry is a
# `preprocess` pipeline built from the transforms listed below, in this order.
BigQuery.Generator.TRANSFORMS |= {
sge.Select: sg.transforms.preprocess([
# sqlglot's own explode_to_unnest transform is intentionally left disabled;
# the module-local `_explode_to_unnest` on the next line takes its place.
# sg.transforms.explode_to_unnest(),
_explode_to_unnest,
sg.transforms.eliminate_distinct_on,
_alias_ordered_group,
sg.transforms.eliminate_semi_and_anti_joins,
]),
}


@replace(p.WindowFunction(p.MinRank | p.DenseRank, y @ p.WindowFrame(start=None)))
def exclude_unsupported_window_frame_from_rank(_, y):
return ops.Subtract(
Expand Down Expand Up @@ -729,6 +826,17 @@ def visit_CountDistinct(self, op, *, arg, where):
arg = self.if_(where, arg, NULL)
return self.f.count(sge.Distinct(expressions=[arg]))

@visit_node.register(ops.Unnest)
def visit_Unnest(self, op, *, arg, offset, preserve_empty):
    """Compile an `ops.Unnest` node into a sqlglot explode expression.

    The expression class is picked from the truthiness of the two flags:

    * neither flag set           -> ``sge.Explode``
    * only ``preserve_empty``    -> ``sge.ExplodeOuter``
    * only ``offset``            -> ``sge.Posexplode``
    * both flags set             -> ``sge.PosexplodeOuter``

    The selected class is instantiated with ``this=arg`` and returned.
    """
    if offset:
        node_type = sge.PosexplodeOuter if preserve_empty else sge.Posexplode
    else:
        node_type = sge.ExplodeOuter if preserve_empty else sge.Explode
    return node_type(this=arg)

@visit_node.register(ops.CountDistinctStar)
@visit_node.register(ops.DateDiff)
@visit_node.register(ops.ExtractAuthority)
Expand Down
2 changes: 2 additions & 0 deletions ibis/expr/operations/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ class ArrayFilter(Value):
@public
class Unnest(Value):
arg: Value[dt.Array]
offset: Optional[bool] = None
preserve_empty: Optional[bool] = None

shape = ds.columnar

Expand Down
10 changes: 7 additions & 3 deletions ibis/expr/types/arrays.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, Optional

from public import public

Expand Down Expand Up @@ -266,7 +266,11 @@ def repeat(self, n: int | ir.IntegerValue) -> ArrayValue:

__mul__ = __rmul__ = repeat

def unnest(self) -> ir.Value:
def unnest(
self,
offset: Optional[bool] = None,
preserve_empty: Optional[bool] = None,

This comment has been minimized.

Copy link
@tswast

tswast Feb 8, 2024

I think this can be just bool with default value of False

preserve_empty: bool = False

Also, we'd need to document the new arguments.

This comment has been minimized.

Copy link
@chelsea-lin

chelsea-lin Feb 12, 2024

Author Owner

Got it. Will fix in the official changes.

) -> ir.Value:
"""Flatten an array into a column.
::: {.callout-note}
Expand Down Expand Up @@ -305,7 +309,7 @@ def unnest(self) -> ir.Value:
ir.Value
Unnested array
"""
expr = ops.Unnest(self).to_expr()
expr = ops.Unnest(self, offset, preserve_empty).to_expr()
try:
return expr.name(self.get_name())
except com.ExpressionError:
Expand Down

0 comments on commit d480cc2

Please sign in to comment.