feat(flink): make schema arg optional for read_parquet()
mfatihaktas committed Feb 14, 2024
1 parent 580536c commit cb3ff24
Showing 4 changed files with 101 additions and 31 deletions.
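
For context, a minimal usage sketch of what this commit enables: the schema argument of the Flink backend's read_parquet() may now be omitted, in which case the schema is inferred from the file's parquet metadata. The local environment setup and file path below are hypothetical and not part of the diff.

from pyflink.table import EnvironmentSettings, TableEnvironment

import ibis

# Hypothetical local Flink setup -- illustration only.
table_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())
con = ibis.flink.connect(table_env)

# schema= is now optional; it is inferred from the parquet footer.
t = con.read_parquet("data/functional_alltypes.parquet", table_name="t")
print(t.schema())
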
21 changes: 17 additions & 4 deletions ibis/backends/flink/__init__.py
@@ -675,7 +675,7 @@ def _read_file(
        schema: sch.Schema | None = None,
        table_name: str | None = None,
    ) -> ir.Table:
        """Register a file as a table in the current database.
        """Register a file/directory as a table in the current database.

        Parameters
        ----------
@@ -740,8 +740,21 @@ def read_parquet(
        -------
        ir.Table
            The just-registered table
        """
        if schema is None:
            import pyarrow as pa
            import pyarrow_hotfix # noqa: F401

            # Note: To support reading from a directory, we discussed using
            # pyarrow_schema = pyarrow.dataset.dataset(path, format="parquet").schema
            # [https://github.com/ibis-project/ibis/pull/8070#discussion_r1467046023]
            # We decided to drop this as it might lead to silent errors when
            # pyarrow.dataset infers a different schema than what flink would infer
            # due to partitioning.
            pyarrow_schema = pa.parquet.read_metadata(path).schema.to_arrow_schema()

            schema = sch.Schema.from_pyarrow(pyarrow_schema)

        return self._read_file(
            file_type="parquet", path=path, schema=schema, table_name=table_name
        )
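
The inference above amounts to reading just the parquet footer with pyarrow and converting the resulting arrow schema into an ibis schema. A standalone sketch of that path, assuming pyarrow is installed; the file path is hypothetical and the sch alias mirrors the backend's import of ibis.expr.schema.

import pyarrow.parquet as pq

import ibis.expr.schema as sch

# Read only the file metadata (no row data) and convert to an ibis schema.
arrow_schema = pq.read_metadata("data/functional_alltypes.parquet").schema.to_arrow_schema()
ibis_schema = sch.Schema.from_pyarrow(arrow_schema)
print(ibis_schema)
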
@@ -752,7 +765,7 @@ def read_csv(
        schema: sch.Schema | None = None,
        table_name: str | None = None,
    ) -> ir.Table:
        """Register a csv file as a table in the current database.
        """Register a csv file/directory as a table in the current database.

        Parameters
        ----------
@@ -780,7 +793,7 @@ def read_json(
        schema: sch.Schema | None = None,
        table_name: str | None = None,
    ) -> ir.Table:
        """Register a json file as a table in the current database.
        """Register a json file/directory as a table in the current database.

        Parameters
        ----------
55 changes: 55 additions & 0 deletions ibis/backends/flink/tests/conftest.py
@@ -11,6 +11,45 @@
from ibis.backends.tests.data import array_types, json_types, struct_types, win


def _download_jar_for_package(
    package_name: str,
    jar_name: str,
    jar_url: str,
):
    import os
    from importlib import metadata

    import requests

    # Find the path to package lib
    try:
        distribution = metadata.distribution(package_name)
        lib_path = distribution.locate_file("")
    except metadata.PackageNotFoundError:
        lib_path = None

    # Check if the JAR already exists
    jar_path = os.path.join(lib_path, "pyflink/lib", f"{jar_name}.jar")
    if os.path.exists(jar_path):
        return jar_path

    # Download the JAR
    response = requests.get(jar_url, stream=True)
    if response.status_code != 200:
        raise SystemError(
            f"Failed to download the JAR file \n"
            f"\t jar_url= {jar_url} \n"
            f"\t response.status_code= {response.status_code}"
        )

    # Save the JAR
    with open(jar_path, "wb") as jar_file:
        for chunk in response.iter_content(chunk_size=128):
            jar_file.write(chunk)

    return jar_path


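As an aside, a hedged sketch of a slightly more defensive variant of the helper above (hypothetical name, same idea, not part of the commit): it fails early when the package is missing, surfaces HTTP errors through requests' own exception, and writes in larger chunks.

def _download_jar_for_package_defensive(package_name: str, jar_name: str, jar_url: str) -> str:
    import os
    from importlib import metadata

    import requests

    try:
        lib_path = metadata.distribution(package_name).locate_file("")
    except metadata.PackageNotFoundError as exc:
        raise RuntimeError(f"{package_name} is not installed") from exc

    jar_path = os.path.join(lib_path, "pyflink/lib", f"{jar_name}.jar")
    if os.path.exists(jar_path):
        return jar_path

    response = requests.get(jar_url, stream=True, timeout=60)
    response.raise_for_status()  # raise requests.HTTPError on a non-2xx response

    with open(jar_path, "wb") as jar_file:
        for chunk in response.iter_content(chunk_size=1 << 20):  # 1 MiB chunks
            jar_file.write(chunk)

    return jar_path
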
class TestConf(BackendTest):
    force_sort = True
    stateful = False
@@ -46,6 +85,22 @@ def connect(*, tmpdir, worker_id, **kw: Any):
        )

        env = StreamExecutionEnvironment(j_stream_execution_environment)

        # Download the two JARs needed for parquet support in Flink.
        # Note: It is not ideal to do "test ops" in code here.
        flink_sql_parquet_jar_path = _download_jar_for_package(
            package_name="apache-flink",
            jar_name="flink-sql-parquet-1.18.1",
            jar_url="https://repo1.maven.org/maven2/org/apache/flink/flink-sql-parquet/1.18.1/flink-sql-parquet-1.18.1.jar",
        )
        flink_shaded_hadoop_jar_path = _download_jar_for_package(
            package_name="apache-flink",
            jar_name="flink-shaded-hadoop-2-uber-2.8.3-10.0",
            jar_url="https://repo1.maven.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar",
        )
        env.add_jars(f"file://{flink_sql_parquet_jar_path}")
        env.add_jars(f"file://{flink_shaded_hadoop_jar_path}")

        stream_table_env = StreamTableEnvironment.create(env)
        table_config = stream_table_env.get_config()
        table_config.set("table.local-time-zone", "UTC")
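
For reference, Flink can also pick up these JARs through its pipeline.jars option on the table config instead of add_jars() on the execution environment. A hedged sketch with hypothetical local paths:

from pyflink.table import EnvironmentSettings, TableEnvironment

# Alternative: point Flink at the JARs via the pipeline.jars config option.
table_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())
table_env.get_config().set(
    "pipeline.jars",
    "file:///tmp/flink-sql-parquet-1.18.1.jar;"
    "file:///tmp/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar",
)
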
28 changes: 15 additions & 13 deletions ibis/backends/flink/tests/test_ddl.py
@@ -2,7 +2,6 @@

import os
import tempfile
from pathlib import Path

import pandas as pd
import pandas.testing as tm
@@ -463,22 +462,25 @@ def test_read_csv(con, awards_players_schema, csv_source_configs, table_name):


@pytest.mark.parametrize("table_name", ["new_table", None])
def test_read_parquet(con, data_dir, tmp_path, table_name, functional_alltypes_schema):
    fname = Path("functional_alltypes.parquet")
    fname = Path(data_dir) / "parquet" / fname.name
@pytest.mark.parametrize("schema", ["functional_alltypes_schema", None])
def test_read_parquet(con, data_dir, table_name, schema, functional_alltypes_schema):
    if schema == "functional_alltypes_schema":
        schema = functional_alltypes_schema

    path = data_dir.joinpath("parquet", "functional_alltypes.parquet")
    table = con.read_parquet(
        path=tmp_path / fname.name,
        schema=functional_alltypes_schema,
        path=path,
        schema=schema,
        table_name=table_name,
    )

    try:
        if table_name is None:
            table_name = table.get_name()
        assert table_name in con.list_tables()
        assert table.schema() == functional_alltypes_schema
    finally:
        con.drop_table(table_name)
    if table_name is None:
        table_name = table.get_name()
    assert table_name in con.list_tables()
    if schema:
        assert table.schema() == schema

    con.drop_table(table_name)
    assert table_name not in con.list_tables()


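A hedged aside on the parametrization above: rather than passing the fixture's name as a string and comparing against it inside the test, the fixture could be resolved lazily with pytest's request.getfixturevalue(). A sketch, not part of the commit:

import pytest


@pytest.mark.parametrize("table_name", ["new_table", None])
@pytest.mark.parametrize("schema_fixture", ["functional_alltypes_schema", None])
def test_read_parquet(con, data_dir, table_name, schema_fixture, request):
    # None means "let the backend infer the schema from the parquet metadata".
    schema = request.getfixturevalue(schema_fixture) if schema_fixture else None
    table = con.read_parquet(
        path=data_dir.joinpath("parquet", "functional_alltypes.parquet"),
        schema=schema,
        table_name=table_name,
    )
    # ...same assertions and cleanup as in the test above.
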
28 changes: 14 additions & 14 deletions ibis/backends/tests/test_register.py
@@ -12,6 +12,7 @@

import ibis
from ibis.backends.conftest import TEST_TABLES, is_older_than
from ibis.backends.tests.errors import Py4JJavaError

if TYPE_CHECKING:
    from collections.abc import Iterator
@@ -405,13 +406,17 @@ def test_register_garbage(con, monkeypatch):
    ],
)
@pytest.mark.notyet(
    ["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"]
    ["impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"]
)
@pytest.mark.notimpl(
    ["flink"],
    raises=Py4JJavaError,
    reason="Flink TaskManager can not access the test files on the test instance",
)
def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name):
    pq = pytest.importorskip("pyarrow.parquet")

    fname = Path(fname)
    fname = Path(data_dir) / "parquet" / fname.name
    fname = Path(data_dir) / "parquet" / fname
    table = pq.read_table(fname)

    pq.write_table(table, tmp_path / fname.name)
@@ -436,17 +441,12 @@ def ft_data(data_dir):


@pytest.mark.notyet(
    [
        "flink",
        "impala",
        "mssql",
        "mysql",
        "pandas",
        "postgres",
        "risingwave",
        "sqlite",
        "trino",
    ]
    ["impala", "mssql", "mysql", "pandas", "postgres", "risingwave", "sqlite", "trino"]
)
@pytest.mark.notimpl(
    ["flink"],
    raises=ValueError,
    reason="read_parquet() does not support glob",
)
def test_read_parquet_glob(con, tmp_path, ft_data):
    pq = pytest.importorskip("pyarrow.parquet")
