Skip to content

Commit

Permalink
refactor(impala): port to sqlglot
Browse files Browse the repository at this point in the history
BREAKING CHANGE: Impala UDFs no longer require explicit registration. Remove any calls to `Function.register`. If you were passing `database` to `Function.register`, pass that to `scalar_function` or `aggregate_function` as appropriate.
  • Loading branch information
cpcloud committed Jan 17, 2024
1 parent 3727798 commit b2cfd65
Show file tree
Hide file tree
Showing 316 changed files with 2,760 additions and 1,369 deletions.
44 changes: 22 additions & 22 deletions .github/workflows/ibis-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,16 @@ jobs:
- postgres
sys-deps:
- libgeos-dev
# - name: impala
# title: Impala
# extras:
# - impala
# services:
# - impala
# - kudu
# sys-deps:
# - cmake
# - ninja-build
- name: impala
title: Impala
extras:
- impala
services:
- impala
- kudu
sys-deps:
- cmake
- ninja-build
# - name: mssql
# title: MS SQL Server
# extras:
Expand Down Expand Up @@ -223,18 +223,18 @@ jobs:
- postgres
sys-deps:
- libgeos-dev
# - os: windows-latest
# backend:
# name: impala
# title: Impala
# extras:
# - impala
# services:
# - impala
# - kudu
# sys-deps:
# - cmake
# - ninja-build
- os: windows-latest
backend:
name: impala
title: Impala
extras:
- impala
services:
- impala
- kudu
sys-deps:
- cmake
- ninja-build
# - os: windows-latest
# backend:
# name: mssql
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/base/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import pandas as pd
import pyarrow as pa

raise RuntimeError("Temporarily make the SQL backends dysfunctional")

__all__ = ["BaseSQLBackend"]

Expand Down
6 changes: 3 additions & 3 deletions ibis/backends/base/sql/ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _pieces(self):
yield self._storage()
yield self._location()
yield "AS"
yield self.select.compile()
yield self.select

def _partitioned_by(self):
if self.partition is not None:
Expand All @@ -212,7 +212,7 @@ def __init__(self, table_name, select, database=None, can_exist=False):
@property
def _pieces(self):
yield "AS"
yield self.select.compile()
yield self.select

@property
def _prefix(self):
Expand Down Expand Up @@ -352,7 +352,7 @@ def compile(self):
else:
partition = ""

select_query = self.select.compile()
select_query = self.select
scoped_name = self._get_scoped_name(self.table_name, self.database)
return f"{cmd} {scoped_name}{partition}\n{select_query}"

Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/base/sql/registry/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,10 +355,8 @@ def _floor(t, op):
ops.IfElse: fixed_arity("if", 3),
ops.Between: between,
ops.InValues: binary_infix.in_values,
ops.InSubquery: binary_infix.in_column,
ops.SimpleCase: case.simple_case,
ops.SearchedCase: case.searched_case,
ops.Field: table_column,
ops.DateAdd: timestamp.timestamp_op("date_add"),
ops.DateSub: timestamp.timestamp_op("date_sub"),
ops.DateDiff: timestamp.timestamp_op("datediff"),
Expand Down
18 changes: 14 additions & 4 deletions ibis/backends/base/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,14 @@ def visit_Window(self, op, *, how, func, start, end, group_by, order_by):
end_value = end.get("value", "UNBOUNDED")
end_side = end.get("side", "FOLLOWING")

if getattr(start_value, "this", None) == "0":
start_value = "CURRENT ROW"
start_side = None

if getattr(end_value, "this", None) == "0":
end_value = "CURRENT ROW"
end_side = None

spec = sge.WindowSpec(
kind=how.upper(),
start=start_value,
Expand Down Expand Up @@ -1004,10 +1012,12 @@ def visit_JoinLink(self, op, *, how, table, predicates):
"cross": "cross",
"outer": "outer",
}
assert predicates
return sge.Join(
this=table, side=sides[how], kind=kinds[how], on=sg.and_(*predicates)
)
assert (
predicates or how == "cross"
), "expected non-empty predicates when not a cross join"

on = sg.and_(*predicates) if predicates else None
return sge.Join(this=table, side=sides[how], kind=kinds[how], on=on)

@staticmethod
def _gen_valid_name(name: str) -> str:
Expand Down
8 changes: 8 additions & 0 deletions ibis/backends/base/sqlglot/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,15 @@ class SQLiteType(SqlglotType):
dialect = "sqlite"


class ImpalaType(SqlglotType):
dialect = "impala"

default_decimal_precision = 9
default_decimal_scale = 0


class PySparkType(SqlglotType):
dialect = "spark"

default_decimal_precision = 38
default_decimal_scale = 18
6 changes: 2 additions & 4 deletions ibis/backends/base/sqlglot/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,8 @@ def rewrite_last_to_last_value(_, x, y):


@replace(p.WindowFunction(frame=y @ p.WindowFrame(order_by=())))
def rewrite_empty_order_by_window(_, y):
import ibis

return _.copy(frame=y.copy(order_by=(ibis.NA,)))
def rewrite_empty_order_by_window(_, y, **__):
return _.copy(frame=y.copy(order_by=(ops.NULL,)))


@replace(p.WindowFunction(p.RowNumber | p.NTile, y))
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import starmap

import sqlglot as sg
import sqlglot.expressions as sge
from sqlglot import exp, transforms
from sqlglot.dialects import Postgres
from sqlglot.dialects.dialect import rename_func
Expand Down Expand Up @@ -102,6 +103,8 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
return self.f.date_trunc("day", value.isoformat())
elif dtype.is_binary():
return sg.exp.HexString(this=value.hex())
elif dtype.is_uuid():
return sge.convert(str(value))
else:
return None

Expand Down
Loading

0 comments on commit b2cfd65

Please sign in to comment.