refactor(flink): port to sqlglot
cpcloud committed Feb 11, 2024
1 parent 3c72303 commit f6dc380
Showing 78 changed files with 1,436 additions and 1,805 deletions.
108 changes: 50 additions & 58 deletions .github/workflows/ibis-backends.yml
@@ -177,22 +177,28 @@ jobs:
- oracle
services:
- oracle
# - name: flink
# title: Flink
# serial: true
# extras:
# - flink
# additional_deps:
# - apache-flink
# - pytest-split
# services:
# - flink
# - name: risingwave
# title: Risingwave
# services:
# - risingwave
# extras:
# - risingwave
- name: flink
title: Flink
serial: true
extras:
- flink
additional_deps:
- apache-flink
services:
- flink
include:
- os: ubuntu-latest
python-version: "3.10"
backend:
name: flink
title: Flink
serial: true
extras:
- flink
additional_deps:
- apache-flink
services:
- flink
exclude:
- os: windows-latest
backend:
@@ -296,32 +302,29 @@ jobs:
- oracle
services:
- oracle
# - os: windows-latest
# backend:
# name: flink
# title: Flink
# serial: true
# extras:
# - flink
# services:
# - flink
# - python-version: "3.11"
# backend:
# name: flink
# title: Flink
# serial: true
# extras:
# - flink
# services:
# - flink
# - os: windows-latest
# backend:
# name: risingwave
# title: Risingwave
# services:
# - risingwave
# extras:
# - risingwave
- os: windows-latest
backend:
name: flink
title: Flink
serial: true
extras:
- flink
additional_deps:
- apache-flink
services:
- flink
- os: ubuntu-latest
python-version: "3.11"
backend:
name: flink
title: Flink
serial: true
extras:
- flink
additional_deps:
- apache-flink
services:
- flink
- os: windows-latest
backend:
name: exasol
@@ -390,29 +393,18 @@ jobs:
IBIS_TEST_IMPALA_PORT: 21050
IBIS_EXAMPLES_DATA: ${{ runner.temp }}/examples-${{ matrix.backend.name }}-${{ matrix.os }}-${{ steps.install_python.outputs.python-version }}

# FIXME(deepyaman): If some backend-specific test, in test_ddl.py,
# executes before common tests, they will fail with:
# org.apache.flink.table.api.ValidationException: Table `default_catalog`.`default_database`.`functional_alltypes` was not found.
# Therefore, we run backend-specific tests second to avoid this.
# - name: "run serial tests: ${{ matrix.backend.name }}"
# if: matrix.backend.serial && matrix.backend.name == 'flink'
# run: |
# just ci-check -m ${{ matrix.backend.name }} ibis/backends/tests
# just ci-check -m ${{ matrix.backend.name }} ibis/backends/flink/tests
# env:
# IBIS_EXAMPLES_DATA: ${{ runner.temp }}/examples-${{ matrix.backend.name }}-${{ matrix.os }}-${{ steps.install_python.outputs.python-version }}
# FLINK_REMOTE_CLUSTER_ADDR: localhost
# FLINK_REMOTE_CLUSTER_PORT: "8081"
#
- name: "run serial tests: ${{ matrix.backend.name }}"
if: matrix.backend.serial # && matrix.backend.name != 'flink'
if: matrix.backend.serial
run: just ci-check -m ${{ matrix.backend.name }}
env:
FLINK_REMOTE_CLUSTER_ADDR: localhost
FLINK_REMOTE_CLUSTER_PORT: "8081"
IBIS_EXAMPLES_DATA: ${{ runner.temp }}/examples-${{ matrix.backend.name }}-${{ matrix.os }}-${{ steps.install_python.outputs.python-version }}

- name: check that no untracked files were produced
shell: bash
run: git checkout poetry.lock pyproject.toml && ! git status --porcelain | tee /dev/stderr | grep .
run: |
! git status --porcelain | tee /dev/stderr | grep .
- name: upload code coverage
if: success()
9 changes: 7 additions & 2 deletions ibis/backends/base/__init__.py
@@ -1277,9 +1277,14 @@ def _transpile_sql(self, query: str, *, dialect: str | None = None) -> str:


@functools.cache
def _get_backend_names() -> frozenset[str]:
def _get_backend_names(*, exclude: tuple[str] = ()) -> frozenset[str]:
"""Return the set of known backend names.
Parameters
----------
exclude
Exclude these backend names from the result
Notes
-----
This function returns a frozenset to prevent cache pollution.
@@ -1293,7 +1298,7 @@ def _get_backend_names() -> frozenset[str]:
entrypoints = importlib.metadata.entry_points()["ibis.backends"]
else:
entrypoints = importlib.metadata.entry_points(group="ibis.backends")
return frozenset(ep.name for ep in entrypoints)
return frozenset(ep.name for ep in entrypoints).difference(exclude)


def connect(resource: Path | str, **kwargs: Any) -> BaseBackend:
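A minimal sketch of the new exclude keyword, assuming an ibis installation that exposes this private helper (illustrative only; the names below come from the diff above):

    from ibis.backends.base import _get_backend_names

    # All registered "ibis.backends" entry points, returned as a frozenset
    # so that functools.cache callers cannot mutate the cached value.
    all_backends = _get_backend_names()

    # The same set with flink filtered out, e.g. for parametrizing tests
    # that should skip a backend. exclude must be hashable (a tuple)
    # because the function is wrapped in functools.cache.
    without_flink = _get_backend_names(exclude=("flink",))
    assert "flink" not in without_flink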
2 changes: 1 addition & 1 deletion ibis/backends/base/sqlglot/compiler.py
@@ -129,7 +129,7 @@ def __getitem__(self, key: str) -> sge.Column:

def paren(expr):
"""Wrap a sqlglot expression in parentheses."""
return sge.Paren(this=expr)
return sge.Paren(this=sge.convert(expr))


def parenthesize(op, arg):
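Why the paren change wraps its argument in sge.convert: convert is a no-op for existing sqlglot expressions and lifts plain Python values into literal nodes, so paren now accepts both. A minimal sketch, assuming only a stock sqlglot install:

    import sqlglot.expressions as sge

    # convert() lifts raw Python values into Literal/Boolean/Null nodes
    # and passes sqlglot expressions through unchanged.
    print(sge.Paren(this=sge.convert(1)).sql())                # (1)
    print(sge.Paren(this=sge.convert("a")).sql())              # ('a')
    print(sge.Paren(this=sge.convert(sge.column("x"))).sql())  # (x)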
10 changes: 10 additions & 0 deletions ibis/backends/base/sqlglot/datatypes.py
@@ -1029,3 +1029,13 @@ def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
key_type = cls.from_ibis(dtype.key_type.copy(nullable=False))
value_type = cls.from_ibis(dtype.value_type)
return sge.DataType(this=typecode.MAP, expressions=[key_type, value_type])


class FlinkType(SqlglotType):
dialect = "flink"
default_decimal_precision = 38
default_decimal_scale = 18

@classmethod
def _from_ibis_Binary(cls, dtype: dt.Binary) -> sge.DataType:
return sge.DataType(this=sge.DataType.Type.VARBINARY)
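A minimal sketch of what the new FlinkType mapping produces, assuming ibis with this change installed; from_ibis is the SqlglotType classmethod that dispatches to the _from_ibis_* hooks, and "hive" stands in for the flink dialect name here since stock sqlglot may not know it:

    import ibis.expr.datatypes as dt
    from ibis.backends.base.sqlglot.datatypes import FlinkType

    # Binary is forced to VARBINARY for Flink.
    print(FlinkType.from_ibis(dt.binary).sql("hive"))     # VARBINARY

    # Decimals without explicit precision/scale pick up the class-level
    # defaults declared above (38, 18).
    print(FlinkType.from_ibis(dt.Decimal()).sql("hive"))  # DECIMAL(38, 18)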
105 changes: 45 additions & 60 deletions ibis/backends/flink/__init__.py
@@ -1,23 +1,22 @@
from __future__ import annotations

import itertools
from functools import lru_cache
from typing import TYPE_CHECKING, Any

import pyarrow as pa
import sqlglot as sg
import sqlglot.expressions as sge

import ibis.common.exceptions as exc
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.backends.base import BaseBackend, CanCreateDatabase, NoUrl
from ibis.backends.base.sql.ddl import fully_qualified_re, is_fully_qualified
from ibis.backends.flink.compiler.core import FlinkCompiler
from ibis.backends.base import CanCreateDatabase, NoUrl
from ibis.backends.base.sqlglot import SQLGlotBackend
from ibis.backends.flink.compiler import FlinkCompiler
from ibis.backends.flink.ddl import (
CreateDatabase,
CreateTableFromConnector,
CreateView,
DropDatabase,
DropTable,
DropView,
@@ -38,9 +37,9 @@
from ibis.api import Watermark


class Backend(BaseBackend, CanCreateDatabase, NoUrl):
class Backend(SQLGlotBackend, CanCreateDatabase, NoUrl):
name = "flink"
compiler = FlinkCompiler
compiler = FlinkCompiler()
supports_temporary_tables = True
supports_python_udfs = True

@@ -71,6 +70,17 @@ def do_connect(self, table_env: TableEnvironment) -> None:
def raw_sql(self, query: str) -> TableResult:
return self._table_env.execute_sql(query)

def _metadata(self, query: str):
from pyflink.table.types import create_arrow_schema

table = self._table_env.sql_query(query)
schema = table.get_schema()
pa_schema = create_arrow_schema(
schema.get_field_names(), schema.get_field_data_types()
)
# sort of wasteful, but less code to write
return sch.Schema.from_pyarrow(pa_schema).items()

def list_databases(self, like: str | None = None) -> list[str]:
databases = self._table_env.list_databases()
return self._filter_with_like(databases, like)
@@ -207,21 +217,6 @@ def list_views(

return self._filter_with_like(views, like)

def _fully_qualified_name(
self,
name: str,
database: str | None = None,
catalog: str | None = None,
) -> str:
if name and is_fully_qualified(name):
return name

return sg.table(
name,
db=database or self.current_database,
catalog=catalog or self.current_catalog,
).sql(dialect="hive")

def table(
self,
name: str,
@@ -250,15 +245,12 @@ def table(
f"`database` must be a string; got {type(database)}"
)
schema = self.get_schema(name, catalog=catalog, database=database)
qualified_name = self._fully_qualified_name(name, catalog, database)
_, quoted, unquoted = fully_qualified_re.search(qualified_name).groups()
unqualified_name = quoted or unquoted
node = ops.DatabaseTable(
unqualified_name,
schema,
self,
namespace=ops.Namespace(schema=database, database=catalog),
) # TODO(chloeh13q): look into namespacing with catalog + db
name,
schema=schema,
source=self,
namespace=ops.Namespace(schema=catalog, database=database),
)
return node.to_expr()

def get_schema(
@@ -288,7 +280,9 @@ def get_schema(

from ibis.backends.flink.datatypes import get_field_data_types

qualified_name = self._fully_qualified_name(table_name, catalog, database)
qualified_name = sg.table(table_name, db=catalog, catalog=database).sql(
self.name
)
table = self._table_env.from_path(qualified_name)
pyflink_schema = table.get_schema()

@@ -305,12 +299,9 @@ def version(self) -> str:
return pyflink.version.__version__

def compile(
self,
expr: ir.Expr,
params: Mapping[ir.Expr, Any] | None = None,
**kwargs: Any,
self, expr: ir.Expr, params: Mapping[ir.Expr, Any] | None = None, **_: Any
) -> Any:
"""Compile an expression."""
"""Compile an Ibis expression to Flink."""
return super().compile(expr, params=params) # Discard `limit` and other kwargs.

def _to_sql(self, expr: ir.Expr, **kwargs: Any) -> str:
@@ -604,7 +595,9 @@ def create_view(
)

if isinstance(obj, pd.DataFrame):
qualified_name = self._fully_qualified_name(name, database, catalog)
qualified_name = sg.table(
name, db=database, catalog=catalog, quoted=self.compiler.quoted
).sql(self.name)
if schema:
table = self._table_env.from_pandas(
obj, FlinkRowSchema.from_ibis(schema)
@@ -617,15 +610,18 @@

elif isinstance(obj, ir.Table):
query_expression = self.compile(obj)
statement = CreateView(
name=name,
query_expression=query_expression,
database=database,
can_exist=force,
temporary=temp,
stmt = sge.Create(
kind="VIEW",
this=sg.table(
name, db=database, catalog=catalog, quoted=self.compiler.quoted
),
expression=query_expression,
exists=force,
properties=sge.Properties(expressions=[sge.TemporaryProperty()])
if temp
else None,
)
sql = statement.compile()
self.raw_sql(sql)
self.raw_sql(stmt.sql(self.name))

else:
raise exc.IbisError(f"Unsupported `obj` type: {type(obj)}")
@@ -803,16 +799,6 @@ def read_json(
file_type="json", path=path, schema=schema, table_name=table_name
)

@classmethod
@lru_cache
def _get_operations(cls):
translator = cls.compiler.translator_class
return translator._registry.keys()

@classmethod
def has_operation(cls, operation: type[ops.Value]) -> bool:
return operation in cls._get_operations()

def insert(
self,
table_name: str,
@@ -852,12 +838,9 @@ def insert(
import pyarrow_hotfix # noqa: F401

if isinstance(obj, ir.Table):
expr = obj
ast = self.compiler.to_ast(expr)
select = ast.queries[0]
statement = InsertSelect(
table_name,
select,
self.compile(obj),
database=database,
catalog=catalog,
overwrite=overwrite,
@@ -946,7 +929,9 @@ def _from_ibis_table_to_pyflink_table(self, table: ir.Table) -> Table | None:
# `table` is not a registered table in Flink.
return None

qualified_name = self._fully_qualified_name(table_name)
qualified_name = sg.table(table_name, quoted=self.compiler.quoted).sql(
self.name
)
try:
return self._table_env.from_path(qualified_name)
except Py4JJavaError:
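The pattern running through this file: string-templated DDL is replaced by sqlglot expression trees rendered with .sql(self.name). A minimal sketch of how the CREATE VIEW statement in create_view above is assembled, assuming only a stock sqlglot install ("hive" stands in for the backend's registered dialect name):

    import sqlglot as sg
    import sqlglot.expressions as sge

    # The SELECT that backs the view; in the backend this comes from
    # self.compile(obj).
    query = sg.parse_one("SELECT 1 AS x")

    # CREATE [TEMPORARY] VIEW [IF NOT EXISTS] as an expression tree,
    # mirroring the sge.Create call in create_view.
    stmt = sge.Create(
        kind="VIEW",
        this=sg.table("v", db="d", quoted=True),
        expression=query,
        exists=True,
        properties=sge.Properties(expressions=[sge.TemporaryProperty()]),
    )
    print(stmt.sql("hive"))
    # CREATE TEMPORARY VIEW IF NOT EXISTS `d`.`v` AS SELECT 1 AS x
    # (exact rendering may vary across sqlglot versions)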
