refactor(flink): port to sqlglot
cpcloud committed Feb 9, 2024
1 parent 9b72fc0 commit 0b84087
Showing 72 changed files with 1,792 additions and 1,393 deletions.
89 changes: 29 additions & 60 deletions .github/workflows/ibis-backends.yml
@@ -177,22 +177,16 @@ jobs:
- oracle
services:
- oracle
# - name: flink
# title: Flink
# serial: true
# extras:
# - flink
# additional_deps:
# - apache-flink
# - pytest-split
# services:
# - flink
# - name: risingwave
# title: Risingwave
# services:
# - risingwave
# extras:
# - risingwave
- name: flink
title: Flink
serial: true
extras:
- flink
additional_deps:
- apache-flink
- pytest-split
services:
- flink
exclude:
- os: windows-latest
backend:
@@ -296,32 +290,15 @@ jobs:
- oracle
services:
- oracle
# - os: windows-latest
# backend:
# name: flink
# title: Flink
# serial: true
# extras:
# - flink
# services:
# - flink
# - python-version: "3.11"
# backend:
# name: flink
# title: Flink
# serial: true
# extras:
# - flink
# services:
# - flink
# - os: windows-latest
# backend:
# name: risingwave
# title: Risingwave
# services:
# - risingwave
# extras:
# - risingwave
- os: windows-latest
backend:
name: flink
title: Flink
serial: true
extras:
- flink
services:
- flink
- os: windows-latest
backend:
name: exasol
@@ -394,18 +371,18 @@ jobs:
# executes before common tests, they will fail with:
# org.apache.flink.table.api.ValidationException: Table `default_catalog`.`default_database`.`functional_alltypes` was not found.
# Therefore, we run backend-specific tests second to avoid this.
# - name: "run serial tests: ${{ matrix.backend.name }}"
# if: matrix.backend.serial && matrix.backend.name == 'flink'
# run: |
# just ci-check -m ${{ matrix.backend.name }} ibis/backends/tests
# just ci-check -m ${{ matrix.backend.name }} ibis/backends/flink/tests
# env:
# IBIS_EXAMPLES_DATA: ${{ runner.temp }}/examples-${{ matrix.backend.name }}-${{ matrix.os }}-${{ steps.install_python.outputs.python-version }}
# FLINK_REMOTE_CLUSTER_ADDR: localhost
# FLINK_REMOTE_CLUSTER_PORT: "8081"
- name: "run serial tests: ${{ matrix.backend.name }}"
if: matrix.backend.serial && matrix.backend.name == 'flink'
run: |
just ci-check -m ${{ matrix.backend.name }} ibis/backends/tests
just ci-check -m ${{ matrix.backend.name }} ibis/backends/flink/tests
env:
IBIS_EXAMPLES_DATA: ${{ runner.temp }}/examples-${{ matrix.backend.name }}-${{ matrix.os }}-${{ steps.install_python.outputs.python-version }}
FLINK_REMOTE_CLUSTER_ADDR: localhost
FLINK_REMOTE_CLUSTER_PORT: "8081"
#
- name: "run serial tests: ${{ matrix.backend.name }}"
if: matrix.backend.serial # && matrix.backend.name != 'flink'
if: matrix.backend.serial && matrix.backend.name != 'flink'
run: just ci-check -m ${{ matrix.backend.name }}
env:
IBIS_EXAMPLES_DATA: ${{ runner.temp }}/examples-${{ matrix.backend.name }}-${{ matrix.os }}-${{ steps.install_python.outputs.python-version }}
@@ -513,10 +490,6 @@ jobs:
- name: install poetry
run: python -m pip install --upgrade pip 'poetry==1.7.1'

- name: remove lonboard
# it requires a version of pandas that min versions are not compatible with
run: poetry remove lonboard

- name: install minimum versions
run: poetry add --lock --optional ${{ join(matrix.backend.deps, ' ') }}

@@ -596,10 +569,6 @@ jobs:
- name: install poetry
run: python -m pip install --upgrade pip 'poetry==1.7.1'

- name: remove lonboard
# it requires a version of pandas that pyspark is not compatible with
run: poetry remove lonboard

- name: install exact versions of pyspark, pandas and numpy
run: poetry add --lock ${{ join(matrix.deps, ' ') }}

2 changes: 1 addition & 1 deletion ibis/backends/base/sqlglot/compiler.py
@@ -129,7 +129,7 @@ def __getitem__(self, key: str) -> sge.Column:

def paren(expr):
"""Wrap a sqlglot expression in parentheses."""
return sge.Paren(this=expr)
return sge.Paren(this=sge.convert(expr))


def parenthesize(op, arg):
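The `sge.convert` call matters because `paren` can now receive plain Python values as well as sqlglot nodes. A minimal sketch of the difference, assuming only sqlglot itself:

```python
import sqlglot.expressions as sge

# sge.convert lifts plain Python values into sqlglot literal nodes,
# so parenthesizing a raw int or string now produces valid SQL.
assert sge.Paren(this=sge.convert(1)).sql() == "(1)"
assert sge.Paren(this=sge.convert("a")).sql() == "('a')"

# Existing expressions pass through convert unchanged.
col = sge.column("x")
assert sge.Paren(this=sge.convert(col)).sql() == "(x)"
```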
10 changes: 10 additions & 0 deletions ibis/backends/base/sqlglot/datatypes.py
@@ -1029,3 +1029,13 @@ def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
key_type = cls.from_ibis(dtype.key_type.copy(nullable=False))
value_type = cls.from_ibis(dtype.value_type)
return sge.DataType(this=typecode.MAP, expressions=[key_type, value_type])


class FlinkType(SqlglotType):
dialect = "flink"
default_decimal_precision = 38
default_decimal_scale = 18

@classmethod
def _from_ibis_Binary(cls, dtype: dt.Binary) -> sge.DataType:
return sge.DataType(this=sge.DataType.Type.VARBINARY)

Codecov warning: added line ibis/backends/base/sqlglot/datatypes.py#L1041 was not covered by tests.
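A rough usage sketch of the new mapper, assuming the `from_ibis` entry point inherited from `SqlglotType`; the rendered strings are illustrative:

```python
import ibis.expr.datatypes as dt
from ibis.backends.base.sqlglot.datatypes import FlinkType

# Flink models binary data as VARBINARY, so ibis Binary maps there.
assert FlinkType.from_ibis(dt.binary).sql() == "VARBINARY"

# Unparameterized decimals pick up the Flink-friendly defaults above.
assert FlinkType.from_ibis(dt.Decimal()).sql() == "DECIMAL(38, 18)"
```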
100 changes: 44 additions & 56 deletions ibis/backends/flink/__init__.py
@@ -1,23 +1,22 @@
from __future__ import annotations

import itertools
from functools import lru_cache
from typing import TYPE_CHECKING, Any

import pyarrow as pa
import sqlglot as sg
import sqlglot.expressions as sge

import ibis.common.exceptions as exc
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.backends.base import BaseBackend, CanCreateDatabase, NoUrl
from ibis.backends.base.sql.ddl import fully_qualified_re, is_fully_qualified
from ibis.backends.flink.compiler.core import FlinkCompiler
from ibis.backends.base import CanCreateDatabase, NoUrl
from ibis.backends.base.sqlglot import SQLGlotBackend
from ibis.backends.flink.compiler import FlinkCompiler
from ibis.backends.flink.ddl import (
CreateDatabase,
CreateTableFromConnector,
CreateView,
DropDatabase,
DropTable,
DropView,
@@ -38,9 +37,9 @@
from ibis.api import Watermark


class Backend(BaseBackend, CanCreateDatabase, NoUrl):
class Backend(SQLGlotBackend, CanCreateDatabase, NoUrl):
name = "flink"
compiler = FlinkCompiler
compiler = FlinkCompiler()
supports_temporary_tables = True
supports_python_udfs = True

@@ -71,6 +70,17 @@ def do_connect(self, table_env: TableEnvironment) -> None:
def raw_sql(self, query: str) -> TableResult:
return self._table_env.execute_sql(query)

def _metadata(self, query: str):
from pyflink.table.types import create_arrow_schema

table = self._table_env.sql_query(query)
schema = table.get_schema()
pa_schema = create_arrow_schema(
schema.get_field_names(), schema.get_field_data_types()
)
# sort of wasteful, but less code to write
return sch.Schema.from_pyarrow(pa_schema).items()

def list_databases(self, like: str | None = None) -> list[str]:
databases = self._table_env.list_databases()
return self._filter_with_like(databases, like)
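The pyflink-to-arrow-to-ibis hop in `_metadata` trades a little overhead for reuse of ibis's existing pyarrow schema conversion. A small sketch of the final step, assuming ibis's `Schema.from_pyarrow`:

```python
import pyarrow as pa

import ibis.expr.datatypes as dt
import ibis.expr.schema as sch

# Arrow fields are nullable by default, matching ibis's default dtypes.
pa_schema = pa.schema({"a": pa.int64(), "b": pa.string()})
assert sch.Schema.from_pyarrow(pa_schema) == sch.Schema({"a": dt.int64, "b": dt.string})
```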
@@ -207,21 +217,6 @@ def list_views(

return self._filter_with_like(views, like)

def _fully_qualified_name(
self,
name: str,
database: str | None = None,
catalog: str | None = None,
) -> str:
if name and is_fully_qualified(name):
return name

return sg.table(
name,
db=database or self.current_database,
catalog=catalog or self.current_catalog,
).sql(dialect="hive")

def table(
self,
name: str,
@@ -250,15 +245,12 @@ def table(
f"`database` must be a string; got {type(database)}"
)
schema = self.get_schema(name, catalog=catalog, database=database)
qualified_name = self._fully_qualified_name(name, catalog, database)
_, quoted, unquoted = fully_qualified_re.search(qualified_name).groups()
unqualified_name = quoted or unquoted
node = ops.DatabaseTable(
unqualified_name,
schema,
self,
namespace=ops.Namespace(schema=database, database=catalog),
) # TODO(chloeh13q): look into namespacing with catalog + db
name,
schema=schema,
source=self,
namespace=ops.Namespace(schema=catalog, database=database),
)
return node.to_expr()

def get_schema(
@@ -288,7 +280,9 @@ def get_schema(

from ibis.backends.flink.datatypes import get_field_data_types

qualified_name = self._fully_qualified_name(table_name, catalog, database)
qualified_name = sg.table(table_name, db=catalog, catalog=database).sql(
self.name
)
table = self._table_env.from_path(qualified_name)
pyflink_schema = table.get_schema()

@@ -305,12 +299,9 @@ def version(self) -> str:
return pyflink.version.__version__

def compile(
self,
expr: ir.Expr,
params: Mapping[ir.Expr, Any] | None = None,
**kwargs: Any,
self, expr: ir.Expr, params: Mapping[ir.Expr, Any] | None = None, **_: Any
) -> Any:
"""Compile an expression."""
"""Compile an Ibis expression to Flink."""
return super().compile(expr, params=params) # Discard `limit` and other kwargs.

def _to_sql(self, expr: ir.Expr, **kwargs: Any) -> str:
@@ -604,7 +595,9 @@ def create_view(
)

if isinstance(obj, pd.DataFrame):
qualified_name = self._fully_qualified_name(name, database, catalog)
qualified_name = sg.table(
name, db=database, catalog=catalog, quoted=self.compiler.quoted
).sql(self.name)
if schema:
table = self._table_env.from_pandas(
obj, FlinkRowSchema.from_ibis(schema)
Expand All @@ -617,15 +610,18 @@ def create_view(

elif isinstance(obj, ir.Table):
query_expression = self.compile(obj)
statement = CreateView(
name=name,
query_expression=query_expression,
database=database,
can_exist=force,
temporary=temp,
stmt = sge.Create(
kind="VIEW",
this=sg.table(
name, db=database, catalog=catalog, quoted=self.compiler.quoted
),
expression=query_expression,
exists=force,
properties=sge.Properties(expressions=[sge.TemporaryProperty()])
if temp
else None,
)
sql = statement.compile()
self.raw_sql(sql)
self.raw_sql(stmt.sql(self.name))

else:
raise exc.IbisError(f"Unsupported `obj` type: {type(obj)}")
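For reference, here is roughly what the sqlglot `Create` node above renders to, with hypothetical view and query names and sqlglot's default dialect standing in for the backend's registered Flink dialect:

```python
import sqlglot as sg
import sqlglot.expressions as sge

stmt = sge.Create(
    kind="VIEW",
    this=sg.table("v", db="d", quoted=True),
    expression=sg.parse_one("SELECT 1 AS x"),
    exists=True,  # renders IF NOT EXISTS
    properties=sge.Properties(expressions=[sge.TemporaryProperty()]),
)
print(stmt.sql())
# -> CREATE TEMPORARY VIEW IF NOT EXISTS "d"."v" AS SELECT 1 AS x (approximately)
```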
@@ -803,16 +799,6 @@ def read_json(
file_type="json", path=path, schema=schema, table_name=table_name
)

@classmethod
@lru_cache
def _get_operations(cls):
translator = cls.compiler.translator_class
return translator._registry.keys()

@classmethod
def has_operation(cls, operation: type[ops.Value]) -> bool:
return operation in cls._get_operations()

def insert(
self,
table_name: str,
@@ -946,7 +932,9 @@ def _from_ibis_table_to_pyflink_table(self, table: ir.Table) -> Table | None:
# `table` is not a registered table in Flink.
return None

qualified_name = self._fully_qualified_name(table_name)
qualified_name = sg.table(table_name, quoted=self.compiler.quoted).sql(
self.name
)
try:
return self._table_env.from_path(qualified_name)
except Py4JJavaError:
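The `sg.table` helper replaces the old regex-based `_fully_qualified_name`; a minimal sketch of its behavior under sqlglot's default dialect:

```python
import sqlglot as sg

# Unquoted parts join with dots; quoting is opt-in per call.
assert sg.table("t", db="d", catalog="c").sql() == "c.d.t"
assert sg.table("t", db="d", quoted=True).sql() == '"d"."t"'
```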
