From cb3ff243925a3e6e61c26004ece69b6ab8e312fd Mon Sep 17 00:00:00 2001
From: mfatihaktas
Date: Mon, 22 Jan 2024 23:54:00 -0500
Subject: [PATCH] feat(flink): make schema arg optional for read_parquet()

---
 ibis/backends/flink/__init__.py       | 21 ++++++++--
 ibis/backends/flink/tests/conftest.py | 55 +++++++++++++++++++++++++++
 ibis/backends/flink/tests/test_ddl.py | 28 +++++++-------
 ibis/backends/tests/test_register.py  | 28 +++++++-------
 4 files changed, 101 insertions(+), 31 deletions(-)

diff --git a/ibis/backends/flink/__init__.py b/ibis/backends/flink/__init__.py
index 7e1af1bbeb1cf..623772a9c7342 100644
--- a/ibis/backends/flink/__init__.py
+++ b/ibis/backends/flink/__init__.py
@@ -675,7 +675,7 @@ def _read_file(
         schema: sch.Schema | None = None,
         table_name: str | None = None,
     ) -> ir.Table:
-        """Register a file as a table in the current database.
+        """Register a file/directory as a table in the current database.
 
         Parameters
         ----------
@@ -740,8 +740,21 @@ def read_parquet(
         -------
         ir.Table
             The just-registered table
-        """
+        """
+        if schema is None:
+            import pyarrow as pa
+            import pyarrow_hotfix  # noqa: F401
+
+            # Note: To support reading from a directory, we discussed using
+            # pyarrow_schema = pyarrow.dataset.dataset(path, format="parquet").schema
+            # [https://github.com/ibis-project/ibis/pull/8070#discussion_r1467046023]
+            # We decided to drop this as it might lead to silent errors when
+            # pyarrow.dataset infers a different schema than what flink would infer
+            # due to partitioning.
+            pyarrow_schema = pa.parquet.read_metadata(path).schema.to_arrow_schema()
+
+            schema = sch.Schema.from_pyarrow(pyarrow_schema)
 
         return self._read_file(
             file_type="parquet", path=path, schema=schema, table_name=table_name
         )
@@ -752,7 +765,7 @@ def read_csv(
         schema: sch.Schema | None = None,
         table_name: str | None = None,
     ) -> ir.Table:
-        """Register a csv file as a table in the current database.
+        """Register a csv file/directory as a table in the current database.
 
         Parameters
         ----------
@@ -780,7 +793,7 @@ def read_json(
         schema: sch.Schema | None = None,
         table_name: str | None = None,
     ) -> ir.Table:
-        """Register a json file as a table in the current database.
+        """Register a json file/directory as a table in the current database.
 
         Parameters
         ----------
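Taken together, the __init__.py hunks make schema optional for read_parquet() and infer it from the parquet footer when it is omitted. A minimal usage sketch of the new call path; the table-environment setup, file path, and table name below are assumptions for illustration, not taken from this patch:

    from pyflink.table import EnvironmentSettings, TableEnvironment

    import ibis

    # Build a Flink table environment and connect the ibis Flink backend to it.
    table_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())
    con = ibis.flink.connect(table_env)

    # No schema argument: it is inferred from the file's footer via
    # pyarrow.parquet.read_metadata() (hypothetical path and table name).
    t = con.read_parquet("/data/functional_alltypes.parquet", table_name="alltypes")
    print(t.schema())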
diff --git a/ibis/backends/flink/tests/conftest.py b/ibis/backends/flink/tests/conftest.py
index 27f6f5f3cba08..6a4c4c2525b22 100644
--- a/ibis/backends/flink/tests/conftest.py
+++ b/ibis/backends/flink/tests/conftest.py
@@ -11,6 +11,45 @@
 from ibis.backends.tests.data import array_types, json_types, struct_types, win
 
 
+def _download_jar_for_package(
+    package_name: str,
+    jar_name: str,
+    jar_url: str,
+):
+    import os
+    from importlib import metadata
+
+    import requests
+
+    # Find the path to the package lib
+    try:
+        distribution = metadata.distribution(package_name)
+        lib_path = distribution.locate_file("")
+    except metadata.PackageNotFoundError as exc:
+        raise SystemError(f"Package {package_name} is not installed") from exc
+
+    # Check if the JAR already exists
+    jar_path = os.path.join(lib_path, "pyflink/lib", f"{jar_name}.jar")
+    if os.path.exists(jar_path):
+        return jar_path
+
+    # Download the JAR
+    response = requests.get(jar_url, stream=True)
+    if response.status_code != 200:
+        raise SystemError(
+            f"Failed to download the JAR file \n"
+            f"\t jar_url= {jar_url} \n"
+            f"\t response.status_code= {response.status_code}"
+        )
+
+    # Save the JAR
+    with open(jar_path, "wb") as jar_file:
+        for chunk in response.iter_content(chunk_size=128):
+            jar_file.write(chunk)
+
+    return jar_path
+
+
 class TestConf(BackendTest):
     force_sort = True
     stateful = False
@@ -46,6 +85,22 @@ def connect(*, tmpdir, worker_id, **kw: Any):
         )
         env = StreamExecutionEnvironment(j_stream_execution_environment)
+
+        # Download the two JARs needed for parquet support in Flink.
+        # Note: It is not ideal to do "test ops" in code here.
+        flink_sql_parquet_jar_path = _download_jar_for_package(
+            package_name="apache-flink",
+            jar_name="flink-sql-parquet-1.18.1",
+            jar_url="https://repo1.maven.org/maven2/org/apache/flink/flink-sql-parquet/1.18.1/flink-sql-parquet-1.18.1.jar",
+        )
+        flink_shaded_hadoop_jar_path = _download_jar_for_package(
+            package_name="apache-flink",
+            jar_name="flink-shaded-hadoop-2-uber-2.8.3-10.0",
+            jar_url="https://repo1.maven.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar",
+        )
+        env.add_jars(f"file://{flink_sql_parquet_jar_path}")
+        env.add_jars(f"file://{flink_shaded_hadoop_jar_path}")
+
         stream_table_env = StreamTableEnvironment.create(env)
         table_config = stream_table_env.get_config()
         table_config.set("table.local-time-zone", "UTC")
diff --git a/ibis/backends/flink/tests/test_ddl.py b/ibis/backends/flink/tests/test_ddl.py
index 44742bc7c8213..dbd772e115492 100644
--- a/ibis/backends/flink/tests/test_ddl.py
+++ b/ibis/backends/flink/tests/test_ddl.py
@@ -2,7 +2,6 @@
 
 import os
 import tempfile
-from pathlib import Path
 
 import pandas as pd
 import pandas.testing as tm
@@ -463,22 +462,25 @@ def test_read_csv(con, awards_players_schema, csv_source_configs, table_name):
 
 
 @pytest.mark.parametrize("table_name", ["new_table", None])
-def test_read_parquet(con, data_dir, tmp_path, table_name, functional_alltypes_schema):
-    fname = Path("functional_alltypes.parquet")
-    fname = Path(data_dir) / "parquet" / fname.name
+@pytest.mark.parametrize("schema", ["functional_alltypes_schema", None])
+def test_read_parquet(con, data_dir, table_name, schema, functional_alltypes_schema):
+    if schema == "functional_alltypes_schema":
+        schema = functional_alltypes_schema
+
+    path = data_dir.joinpath("parquet", "functional_alltypes.parquet")
     table = con.read_parquet(
-        path=tmp_path / fname.name,
-        schema=functional_alltypes_schema,
+        path=path,
+        schema=schema,
         table_name=table_name,
     )
-    try:
-        if table_name is None:
-            table_name = table.get_name()
-        assert table_name in con.list_tables()
-        assert table.schema() == functional_alltypes_schema
-    finally:
-        con.drop_table(table_name)
+    if table_name is None:
+        table_name = table.get_name()
+    assert table_name in con.list_tables()
+    if schema:
+        assert table.schema() == schema
+
+    con.drop_table(table_name)
 
     assert table_name not in con.list_tables()
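Because the two JARs are fetched into pyflink's own lib directory during test setup, a quick sanity check is to list that directory after a test session. This is only a sketch, not part of the patch, and assumes the apache-flink package is installed in the active environment:

    from importlib import metadata
    from pathlib import Path

    # pyflink loads its runtime jars from <site-packages>/pyflink/lib; the two
    # downloaded files should appear here next to the bundled ones.
    pyflink_lib = Path(metadata.distribution("apache-flink").locate_file("")) / "pyflink" / "lib"
    print(sorted(p.name for p in pyflink_lib.glob("*.jar")))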
diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py
index 8564a4d4a2820..cdf8aa44e0611 100644
--- a/ibis/backends/tests/test_register.py
+++ b/ibis/backends/tests/test_register.py
@@ -12,6 +12,7 @@
 
 import ibis
 from ibis.backends.conftest import TEST_TABLES, is_older_than
+from ibis.backends.tests.errors import Py4JJavaError
 
 if TYPE_CHECKING:
     from collections.abc import Iterator
@@ -405,13 +406,17 @@ def test_register_garbage(con, monkeypatch):
     ],
 )
 @pytest.mark.notyet(
-    ["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"]
+    ["impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"]
+)
+@pytest.mark.notimpl(
+    ["flink"],
+    raises=Py4JJavaError,
+    reason="Flink TaskManager cannot access the test files on the test instance",
 )
 def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name):
     pq = pytest.importorskip("pyarrow.parquet")
 
-    fname = Path(fname)
-    fname = Path(data_dir) / "parquet" / fname.name
+    fname = Path(data_dir) / "parquet" / fname
 
     table = pq.read_table(fname)
     pq.write_table(table, tmp_path / fname.name)
@@ -436,17 +441,12 @@ def ft_data(data_dir):
 
 
 @pytest.mark.notyet(
-    [
-        "flink",
-        "impala",
-        "mssql",
-        "mysql",
-        "pandas",
-        "postgres",
-        "risingwave",
-        "sqlite",
-        "trino",
-    ]
+    ["impala", "mssql", "mysql", "pandas", "postgres", "risingwave", "sqlite", "trino"]
+)
+@pytest.mark.notimpl(
+    ["flink"],
+    raises=ValueError,
+    reason="read_parquet() does not support glob",
 )
 def test_read_parquet_glob(con, tmp_path, ft_data):
     pq = pytest.importorskip("pyarrow.parquet")
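A sketch of the caveat behind the comment in the read_parquet() hunk, which is why the pyarrow.dataset route was dropped: footer-based inference on a single file and dataset-level inference on a partitioned directory can disagree. The paths and partitioning layout below are hypothetical:

    import pyarrow.dataset as ds
    import pyarrow.parquet as pq

    # Footer of one concrete file: partition columns encoded only in the
    # directory layout (e.g. year=2024/...) are not part of this schema.
    file_schema = pq.read_metadata("data/year=2024/part-0.parquet").schema.to_arrow_schema()

    # Dataset-level inference walks the directory and, with hive partitioning,
    # adds the partition columns, so the two schemas can silently differ.
    dir_schema = ds.dataset("data", format="parquet", partitioning="hive").schema

    print(file_schema.equals(dir_schema))  # may be False for partitioned data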