Skip to content

Commit

Permalink
feat(flink): make schema arg optional for read_parquet()
Browse files Browse the repository at this point in the history
  • Loading branch information
mfatihaktas committed Jan 23, 2024
1 parent 0f6d45d commit b3a239d
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 13 deletions.
17 changes: 14 additions & 3 deletions ibis/backends/flink/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,9 +677,20 @@ def _read_file(
If `schema` is None.
"""
if schema is None:
raise ValueError(
f"`schema` must be explicitly provided when calling `read_{file_type}`"
)
if file_type == "parquet":
import pyarrow as pa

try:
pyarrow_schema = pa.parquet.read_schema(path)
except FileNotFoundError:
raise ValueError(f"No file found at {path}")

schema = sch.Schema.from_pyarrow(pyarrow_schema)

else:
raise ValueError(
f"`schema` must be explicitly provided when calling `read_{file_type}`"
)

table_name = table_name or gen_name(f"read_{file_type}")
tbl_properties = {
Expand Down
14 changes: 7 additions & 7 deletions ibis/backends/flink/tests/test_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import os
import tempfile
from pathlib import Path

import pandas as pd
import pyarrow as pa
Expand Down Expand Up @@ -511,19 +510,20 @@ def test_read_csv(con, awards_players_schema, csv_source_configs, table_name):


@pytest.mark.parametrize("table_name", ["new_table", None])
def test_read_parquet(con, data_dir, tmp_path, table_name):
fname = Path("functional_alltypes.parquet")
fname = Path(data_dir) / "parquet" / fname.name
@pytest.mark.parametrize("schema", [_functional_alltypes_schema, None])
def test_read_parquet(con, data_dir, table_name, schema):
path = data_dir.joinpath("parquet", "functional_alltypes.parquet")
table = con.read_parquet(
path=tmp_path / fname.name,
schema=_functional_alltypes_schema,
path=path,
schema=schema,
table_name=table_name,
)

if table_name is None:
table_name = table.get_name()
assert table_name in con.list_tables()
assert table.schema() == _functional_alltypes_schema
if schema:
assert table.schema() == schema

con.drop_table(table_name)
assert table_name not in con.list_tables()
Expand Down
7 changes: 4 additions & 3 deletions ibis/backends/tests/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import ibis
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.errors import Py4JJavaError

if TYPE_CHECKING:
from collections.abc import Iterator
Expand Down Expand Up @@ -398,8 +399,8 @@ def test_register_garbage(con, monkeypatch):
@pytest.mark.notyet(["impala", "mssql", "mysql", "postgres", "sqlite", "trino"])
@pytest.mark.notimpl(
["flink"],
raises=ValueError,
reason="read_parquet() missing required argument: 'schema'",
raises=Py4JJavaError,
reason="Parquet format jar for Flink is not in the path",
)
def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name):
pq = pytest.importorskip("pyarrow.parquet")
Expand Down Expand Up @@ -435,7 +436,7 @@ def ft_data(data_dir):
@pytest.mark.notimpl(
["flink"],
raises=ValueError,
reason="read_parquet() missing required argument: 'schema'",
reason="read_parquet() does not support glob",
)
def test_read_parquet_glob(con, tmp_path, ft_data):
pq = pytest.importorskip("pyarrow.parquet")
Expand Down

0 comments on commit b3a239d

Please sign in to comment.