Skip to content

Commit

Permalink
feat(trino): port to sqlglot
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Feb 12, 2024
1 parent 29b5b53 commit 9c5a907
Show file tree
Hide file tree
Showing 56 changed files with 3,850 additions and 1,985 deletions.
37 changes: 14 additions & 23 deletions .github/workflows/ibis-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,12 @@ jobs:
# - freetds-dev
# - unixodbc-dev
# - tdsodbc
# - name: trino
# title: Trino
# extras:
# - trino
# - postgres
# services:
# - trino
- name: trino
title: Trino
extras:
- trino
services:
- trino
# - name: druid
# title: Druid
# extras:
Expand Down Expand Up @@ -257,15 +256,14 @@ jobs:
# - freetds-dev
# - unixodbc-dev
# - tdsodbc
# - os: windows-latest
# backend:
# name: trino
# title: Trino
# services:
# - trino
# extras:
# - trino
# - postgres
- os: windows-latest
backend:
name: trino
title: Trino
services:
- trino
extras:
- trino
# - os: windows-latest
# backend:
# name: druid
Expand Down Expand Up @@ -694,13 +692,6 @@ jobs:
# title: SQLite
# extras:
# - sqlite
# - name: trino
# title: Trino
# services:
# - trino
# extras:
# - trino
# - postgres
# - name: oracle
# title: Oracle
# serial: true
Expand Down
64 changes: 64 additions & 0 deletions ibis/backends/base/sqlglot/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,70 @@ class TrinoType(SqlglotType):
}
)

@classmethod
def _from_ibis_Interval(cls, dtype: dt.Interval) -> sge.DataType:
assert dtype.unit is not None, "interval unit cannot be None"
if (short := dtype.unit.short) in ("Y", "Q", "M"):
return sge.DataType(
this=typecode.INTERVAL,
expressions=[
sge.IntervalSpan(
this=sge.Var(this="YEAR"), expression=sge.Var(this="MONTH")
)
],
)
elif short in ("D", "h", "m", "s", "ms", "us", "ns"):
return sge.DataType(
this=typecode.INTERVAL,
expressions=[
sge.IntervalSpan(
this=sge.Var(this="DAY"), expression=sge.Var(this="SECOND")
)
],
)
else:
raise NotImplementedError(
f"Trino does not support {dtype.unit.name} intervals"
)

@classmethod
def _from_sqlglot_UBIGINT(cls):
return dt.Decimal(precision=19, scale=0, nullable=cls.default_nullable)

@classmethod
def _from_ibis_UInt64(cls, dtype):
return sge.DataType(
this=typecode.DECIMAL,
expressions=[
sge.DataTypeParam(this=sge.convert(19)),
sge.DataTypeParam(this=sge.convert(0)),
],
)

@classmethod
def _from_sqlglot_UINT(cls):
return dt.Int64(nullable=cls.default_nullable)

@classmethod
def _from_ibis_UInt32(cls, dtype):
return sge.DataType(this=typecode.BIGINT)

@classmethod
def _from_sqlglot_USMALLINT(cls):
return dt.Int32(nullable=cls.default_nullable)

@classmethod
def _from_ibis_UInt16(cls, dtype):
return sge.DataType(this=typecode.INT)

@classmethod
def _from_sqlglot_UTINYINT(cls):
return dt.Int16(nullable=cls.default_nullable)

@classmethod
def _from_ibis_UInt8(cls, dtype):
return sge.DataType(this=typecode.SMALLINT)


class DruidType(SqlglotType):
# druid doesn't have a sophisticated type system and hive is close enough
Expand Down
47 changes: 47 additions & 0 deletions ibis/backends/base/sqlglot/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,21 @@
import toolz
from public import public

import ibis.common.exceptions as com
import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.annotations import attribute
from ibis.common.collections import FrozenDict # noqa: TCH001
from ibis.common.deferred import var
from ibis.common.patterns import Object, replace
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.rewrites import p
from ibis.expr.schema import Schema

x = var("x")
y = var("y")


@public
class Select(ops.Relation):
Expand Down Expand Up @@ -140,3 +145,45 @@ def sqlize(node):
)
step2 = step1.replace(merge_select_select)
return step2


@replace(p.WindowFunction(p.First(x, y)))
def rewrite_first_to_first_value(_, x, y):
"""Rewrite Ibis's first to first_value when used in a window function."""
if y is not None:
raise com.UnsupportedOperationError(
"`first` with `where` is unsupported in a window function"
)
return _.copy(func=ops.FirstValue(x))


@replace(p.WindowFunction(p.Last(x, y)))
def rewrite_last_to_last_value(_, x, y):
"""Rewrite Ibis's last to last_value when used in a window function."""
if y is not None:
raise com.UnsupportedOperationError(
"`last` with `where` is unsupported in a window function"
)
return _.copy(func=ops.LastValue(x))


@replace(p.WindowFunction(frame=y @ p.WindowFrame(order_by=())))
def rewrite_empty_order_by_window(_, y):
import ibis

return _.copy(frame=y.copy(order_by=(ibis.NA,)))


@replace(p.WindowFunction(p.RowNumber | p.NTile, y))
def exclude_unsupported_window_frame_from_row_number(_, y):
return ops.Subtract(_.copy(frame=y.copy(start=None, end=None)), 1)


@replace(
p.WindowFunction(
p.Lag | p.Lead | p.PercentRank | p.CumeDist | p.Any | p.All,
y @ p.WindowFrame(start=None),
)
)
def exclude_unsupported_window_frame_from_ops(_, y):
return _.copy(frame=y.copy(start=None, end=0, order_by=y.order_by or (ops.NULL,)))
1 change: 0 additions & 1 deletion ibis/backends/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,6 @@ def ddl_con(ddl_backend):
"postgres",
"risingwave",
"sqlite",
"trino",
)
),
scope="session",
Expand Down
5 changes: 5 additions & 0 deletions ibis/backends/tests/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,8 @@
from snowflake.connector.errors import ProgrammingError as SnowflakeProgrammingError
except ImportError:
SnowflakeProgrammingError = None

try:
from trino.exceptions import TrinoUserError
except ImportError:
TrinoUserError = None
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
SELECT
"t0"."id",
"t0"."bool_col"
FROM "functional_alltypes" AS "t0"
LIMIT 11
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
SELECT
"t0"."id",
"t0"."bool_col"
FROM "functional_alltypes" AS "t0"
LIMIT 11
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT
SUM("t0"."bigint_col") AS "Sum(bigint_col)"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
SELECT
*
FROM (
SELECT
"t0"."id",
"t0"."bool_col"
FROM "functional_alltypes" AS "t0"
LIMIT 10
) AS "t2"
LIMIT 11
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
SELECT
CASE t0.continent
CASE "t0"."continent"
WHEN 'NA'
THEN 'North America'
WHEN 'SA'
Expand All @@ -15,8 +15,8 @@ SELECT
WHEN 'AN'
THEN 'Antarctica'
ELSE 'Unknown continent'
END AS cont,
SUM(t0.population) AS total_pop
FROM countries AS t0
END AS "cont",
SUM("t0"."population") AS "total_pop"
FROM "countries" AS "t0"
GROUP BY
1
18 changes: 7 additions & 11 deletions ibis/backends/tests/snapshots/test_sql/test_isin_bug/trino/out.sql
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
SELECT
t0.x IN (
"t0"."x" IN (
SELECT
t1.x
FROM (
SELECT
t0.x AS x
FROM t AS t0
WHERE
t0.x > 2
) AS t1
) AS "InColumn(x, x)"
FROM t AS t0
"t0"."x"
FROM "t" AS "t0"
WHERE
"t0"."x" > 2
) AS "InSubquery(x)"
FROM "t" AS "t0"
Loading

0 comments on commit 9c5a907

Please sign in to comment.