Merge branch 'main' into feat/flink/map-support
mfatihaktas authored Feb 23, 2024
2 parents 2d9f240 + de174a2 commit ac20e5d
Showing 27 changed files with 164 additions and 80 deletions.
28 changes: 19 additions & 9 deletions docs/how-to/extending/sql.qmd
@@ -34,7 +34,7 @@ t = ibis.examples.penguins.fetch(backend=con, table_name="penguins") # <2>
1. Connect to an in-memory DuckDB database
2. Read in the `penguins` example with our DuckDB database, and name it `penguins`

## `Table.sql`
## [`Table.sql`](../../reference/expression-tables.qmd#ibis.expr.types.relations.Table.sql)

At the highest level there's the `Table.sql` method. This method allows you to
run arbitrary `SELECT` statements against a table expression:
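The full example is collapsed in this view; a minimal sketch of the kind of query it demonstrates, reusing the `penguins` table fetched above (the column names are assumptions, not part of the diff):

```{python}
# Hypothetical sketch: the SQL string refers to the table expression by the
# name it was given ("penguins"); the result is again an Ibis table expression.
expr = t.sql("SELECT species, island, count(*) AS n FROM penguins GROUP BY 1, 2")
expr
```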
@@ -119,14 +119,18 @@ another.

## `Backend.sql`

::: {.callout-tip}
## `Backend.sql` supports the `dialect` argument.
:::

There's also the `Backend.sql` method, which can handle arbitrary `SELECT`
statements as well and returns an Ibis table expression.

The main difference with `Table.sql` is that `Backend.sql` **can only refer to
tables that already exist in the database**, because the API is defined on
`Backend` instances.
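
For illustration (not part of the diff), a hedged sketch of the distinction, using the DuckDB connection from the setup above:

```{python}
# Hypothetical sketch: Backend.sql runs against a table that already exists in
# the backend ("penguins" here), and also accepts the `dialect` argument.
expr = con.sql("SELECT species, count(*) AS n FROM penguins GROUP BY 1")
expr
```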

After the `Backend.sql` call, however, you're able to mix and match similar
After calling `Backend.sql`, however, you're able to mix and match similar
to `Table.sql`:

```{python}
@@ -147,10 +151,6 @@ to `Table.sql`:
)
```

::: {.callout-tip}
## `Backend.sql` also supports the `dialect` argument.
:::

## `Backend.raw_sql`

At the lowest level there's `Backend.raw_sql` which is for those situations
@@ -161,16 +161,26 @@ modeled as a table expression.
with the SQL statement's execution.

::: {.callout-caution}
## You **must** close the cursor returned from `raw_sql` to avoid leaking resources
## You may need to close the cursor returned from `raw_sql` to avoid leaking resources

Failure to do so can result in a variety of errors and hard-to-debug behaviors.

For DDL statements, you may not need to close the cursor since DDL statements
do not produce results.

Failure to do results in variety of errors and hard-to-debug behaviors.
Depending on the backend you may have to experiment to see when closing the
cursor is necessary.

In most cases a cursor returned from a `SELECT` statement requires a call to
`close()`.

The easiest way to do this is to use a context manager:

```{python}
from contextlib import closing
with closing(con.raw_sql("CREATE TEMP TABLE my_table AS SELECT * FROM RANGE(10)")) as c:
with closing(con.raw_sql("SELECT * FROM RANGE(10)")) as c:
    ...  # do something with c if necessary
```
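
As a hedged aside (not part of the diff), a DDL statement borrowed from the old example typically returns no rows, so on many backends its cursor can be ignored; close it if in doubt:

```{python}
# Hypothetical sketch: DDL produces no result set, so closing is often optional
# here -- but this is backend-dependent.
con.raw_sql("CREATE TEMP TABLE my_table AS SELECT * FROM RANGE(10)")
```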
:::
7 changes: 7 additions & 0 deletions ibis/backends/flink/compiler.py
@@ -143,6 +143,13 @@ def _minimize_spec(start, end, spec):
        and end.following
    ):
        return None
    elif (
        isinstance(getattr(end, "value", None), ops.Cast)
        and end.value.arg.value == 0
        and end.following
    ):
        spec.args["end"] = "CURRENT ROW"
        spec.args["end_side"] = None
    return spec

def visit_TumbleWindowingTVF(self, op, *, table, time_col, window_size, offset):
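
A hedged sketch (not part of the diff) of the effect of the new branch, assuming the Flink dialect is available to `ibis.to_sql`; compare the snapshot changes below:

```python
# Hypothetical sketch: a frame ending at following=0 should now compile to
# "CURRENT ROW" rather than "CAST(0 AS ...) following" for the Flink dialect.
import ibis

t = ibis.table({"f": "float64"}, name="table")
expr = t.f.sum().over(ibis.window(preceding=1000, following=0, order_by=t.f))
print(ibis.to_sql(expr, dialect="flink"))
```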
@@ -6,7 +6,7 @@ SELECT
FROM (
SELECT
`t0`.*
FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '15' MINUTE)) AS `t0`
FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '15' MINUTE(2))) AS `t0`
) AS `t1`
GROUP BY
`t1`.`window_start`,
@@ -29,7 +29,7 @@ FROM (
FROM (
SELECT
`t0`.*
FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '600' SECOND)) AS `t0`
FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '600' SECOND(3))) AS `t0`
) AS `t1`
) AS `t2`
) AS `t3`
@@ -1,5 +1,5 @@
SELECT
`t0`.*
FROM TABLE(
CUMULATE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '10' SECOND, INTERVAL '1' MINUTE)
CUMULATE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '10' SECOND(2), INTERVAL '1' MINUTE(2))
) AS `t0`
@@ -1,3 +1,5 @@
SELECT
`t0`.*
FROM TABLE(HOP(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '1' MINUTE, INTERVAL '15' MINUTE)) AS `t0`
FROM TABLE(
HOP(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '1' MINUTE(2), INTERVAL '15' MINUTE(2))
) AS `t0`
@@ -1,3 +1,3 @@
SELECT
`t0`.*
FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '15' MINUTE)) AS `t0`
FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '15' MINUTE(2))) AS `t0`

6 files were deleted.

@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`f`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST RANGE BETWEEN INTERVAL '500' MINUTE preceding AND CAST(0 AS INTERVAL MINUTE) following) AS `Sum(f)`
SUM(`t0`.`f`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST RANGE BETWEEN INTERVAL '500' MINUTE(3) preceding AND CURRENT ROW) AS `Sum(f)`
FROM `table` AS `t0`
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`f`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST ROWS BETWEEN 1000 preceding AND CAST(0 AS SMALLINT) following) AS `Sum(f)`
SUM(`t0`.`f`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST ROWS BETWEEN 1000 preceding AND CURRENT ROW) AS `Sum(f)`
FROM `table` AS `t0`
26 changes: 15 additions & 11 deletions ibis/backends/snowflake/__init__.py
@@ -106,17 +106,21 @@ def _from_url(self, url: str, **kwargs):
"""

url = urlparse(url)
database, schema = url.path[1:].split("/", 1)
query_params = parse_qs(url.query)
(warehouse,) = query_params.pop("warehouse", (None,))
connect_args = {
"user": url.username,
"password": url.password or "",
"account": url.hostname,
"warehouse": warehouse,
"database": database or "",
"schema": schema or "",
}
if url.path:
database, schema = url.path[1:].split("/", 1)
query_params = parse_qs(url.query)
(warehouse,) = query_params.pop("warehouse", (None,))
connect_args = {
"user": url.username,
"password": url.password or "",
"account": url.hostname,
"warehouse": warehouse,
"database": database or "",
"schema": schema or "",
}
else:
connect_args = {}
query_params = {}

for name, value in query_params.items():
if len(value) > 1:
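A hedged sketch (not part of the diff) of the URL shapes the new `if url.path:` branch distinguishes, using only the standard library:

```python
from urllib.parse import urlparse

# A bare URL such as "snowflake://" has an empty path, so the old
# `database, schema = url.path[1:].split("/", 1)` would fail to unpack;
# the new branch skips database/schema parsing entirely in that case.
assert urlparse("snowflake://").path == ""
assert urlparse("snowflake://user:pw@account/db/schema").path[1:].split("/", 1) == ["db", "schema"]
```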
3 changes: 3 additions & 0 deletions ibis/backends/snowflake/tests/test_client.py
@@ -301,3 +301,6 @@ def test_compile_does_not_make_requests(con, mocker):
def test_no_argument_connection():
    con = ibis.snowflake.connect()
    assert con.list_tables() is not None

    con = ibis.connect("snowflake://")
    assert con.list_tables() is not None
4 changes: 2 additions & 2 deletions ibis/backends/sql/__init__.py
@@ -22,7 +22,7 @@

import ibis.expr.datatypes as dt
from ibis.backends.sql.compiler import SQLGlotCompiler
from ibis.common.typing import SupportsSchema
from ibis.expr.schema import SchemaLike


class SQLBackend(BaseBackend):
@@ -137,7 +137,7 @@ def _log(self, sql: str) -> None:
def sql(
    self,
    query: str,
    schema: SupportsSchema | None = None,
    schema: SchemaLike | None = None,
    dialect: str | None = None,
) -> ir.Table:
    query = self._transpile_sql(query, dialect=dialect)
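A hedged sketch (not part of the diff) of how the `schema` parameter might be used; the dict form is an assumption based on how Ibis schemas are commonly constructed:

```python
# Hypothetical sketch: supply a schema when the backend cannot infer one,
# and optionally transpile the query from another dialect.
expr = con.sql(
    "SELECT 1 AS x, 'a' AS y",
    schema={"x": "int64", "y": "string"},
    dialect="duckdb",
)
```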
41 changes: 41 additions & 0 deletions ibis/backends/sql/dialects.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import math

import sqlglot.expressions as sge
from sqlglot import transforms
@@ -68,6 +69,45 @@ class Generator(Postgres.Generator):
}


def _calculate_precision(interval_value: int) -> int:
    """Calculate interval precision.

    FlinkSQL interval data types use leading precision and fractional-seconds
    precision. Because the leading precision defaults to 2, we need to
    specify a different precision when the value exceeds 2 digits.

    (see
    https://learn.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literals)
    """
    # log10(interval_value) + 1 is equivalent to len(str(interval_value)), but is
    # significantly faster and more memory-efficient
    if interval_value == 0:
        return 0
    if interval_value < 0:
        raise ValueError(
            f"Expecting value to be a non-negative integer, got {interval_value}"
        )
    return int(math.log10(interval_value)) + 1


def _interval_with_precision(self, e):
    """Format interval with precision."""
    arg = e.args["this"].this
    formatted_arg = arg
    with contextlib.suppress(AttributeError):
        formatted_arg = arg.sql(self.dialect)

    unit = e.args["unit"]
    # when formatting interval scalars, need to quote arg and add precision
    if isinstance(arg, str):
        formatted_arg = f"'{formatted_arg}'"
        prec = _calculate_precision(int(arg))
        prec = max(prec, 2)
        unit += f"({prec})"

    return f"INTERVAL {formatted_arg} {unit}"


class Flink(Hive):
    class Generator(Hive.Generator):
        TYPE_MAPPING = Hive.Generator.TYPE_MAPPING.copy() | {

@@ -91,6 +131,7 @@ class Generator(Hive.Generator):
            sge.DayOfYear: rename_func("dayofyear"),
            sge.DayOfWeek: rename_func("dayofweek"),
            sge.DayOfMonth: rename_func("dayofmonth"),
            sge.Interval: _interval_with_precision,
        }

    class Tokenizer(Hive.Tokenizer):
36 changes: 33 additions & 3 deletions ibis/backends/tests/test_window.py
@@ -453,7 +453,7 @@ def test_grouped_bounded_following_window(backend, alltypes, df, preceding, foll


@pytest.mark.parametrize(
"window_fn",
"window_fn, window_size",
[
param(
lambda t: ibis.window(
@@ -462,6 +462,7 @@ def test_grouped_bounded_following_window(backend, alltypes, df, preceding, foll
group_by=[t.string_col],
order_by=[t.id],
),
3,
id="preceding-2-following-0",
),
param(
@@ -471,26 +472,55 @@ def test_grouped_bounded_following_window(backend, alltypes, df, preceding, foll
group_by=[t.string_col],
order_by=[t.id],
),
3,
id="preceding-2-following-0-tuple",
),
param(
lambda t: ibis.trailing_window(
preceding=2, group_by=[t.string_col], order_by=[t.id]
),
3,
id="trailing-2",
),
param(
lambda t: ibis.window(
# snowflake doesn't allow windows larger than 1000
preceding=999,
following=0,
group_by=[t.string_col],
order_by=[t.id],
),
1000,
id="large-preceding-999-following-0",
),
param(
lambda t: ibis.window(
preceding=1000, following=0, group_by=[t.string_col], order_by=[t.id]
),
1001,
marks=[
pytest.mark.notyet(
["snowflake"],
raises=SnowflakeProgrammingError,
reason="Windows larger than 1000 are not allowed",
)
],
id="large-preceding-1000-following-0",
),
],
)
@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError)
def test_grouped_bounded_preceding_window(backend, alltypes, df, window_fn):
def test_grouped_bounded_preceding_window(
backend, alltypes, df, window_fn, window_size
):
window = window_fn(alltypes)
expr = alltypes.mutate(val=alltypes.double_col.sum().over(window))

result = expr.execute().set_index("id").sort_index()
gdf = df.sort_values("id").groupby("string_col")
expected = (
df.assign(
val=gdf.double_col.rolling(3, min_periods=1)
val=gdf.double_col.rolling(window_size, min_periods=1)
.sum()
.sort_index(level=1)
.reset_index(drop=True)
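A hedged sketch (not part of the diff) of the pandas computation the test uses for expected values: an Ibis window with `preceding=N, following=0` corresponds to a pandas rolling window of size `N + 1`:

```python
import pandas as pd

# Hypothetical illustration: rolling window of size 3 = current row plus the
# two preceding rows within each group.
df = pd.DataFrame({"g": ["a", "a", "a", "b"], "x": [1.0, 2.0, 3.0, 4.0]})
rolled = df.groupby("g").x.rolling(3, min_periods=1).sum().reset_index(drop=True)
print(rolled.tolist())  # [1.0, 3.0, 6.0, 4.0]
```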
14 changes: 3 additions & 11 deletions ibis/common/typing.py
@@ -4,17 +4,7 @@
import sys
from abc import abstractmethod
from itertools import zip_longest
from types import ModuleType # noqa: F401
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,  # noqa: F401
    Optional,
    TypeVar,
    Union,
    get_args,
    get_origin,
)
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, get_args, get_origin
from typing import get_type_hints as _get_type_hints

from ibis.common.bases import Abstract
@@ -23,6 +13,7 @@
if TYPE_CHECKING:
from typing_extensions import Self


if sys.version_info >= (3, 10):
from types import UnionType
from typing import TypeAlias
@@ -144,6 +135,7 @@ def get_bound_typevars(obj: Any) -> dict[TypeVar, tuple[str, type]]:
Examples
--------
>>> from typing import Generic
>>> class MyStruct(Generic[T, U]):
...     a: T
...     b: U
