From 601cd136bec9d65f661ac508a9f6b77effbb7bd3 Mon Sep 17 00:00:00 2001 From: mfatihaktas Date: Fri, 15 Dec 2023 13:38:48 -0500 Subject: [PATCH] test(flink): deep dive on the tests marked for Flink in test_json.py --- ibis/backends/flink/__init__.py | 112 ++++++++++++++++++++++++++- ibis/backends/tests/test_param.py | 2 - ibis/backends/tests/test_register.py | 13 +--- 3 files changed, 114 insertions(+), 13 deletions(-) diff --git a/ibis/backends/flink/__init__.py b/ibis/backends/flink/__init__.py index 7aa374be3b837..be21a8e56fa4d 100644 --- a/ibis/backends/flink/__init__.py +++ b/ibis/backends/flink/__init__.py @@ -22,9 +22,11 @@ InsertSelect, RenameTable, ) +from ibis.util import gen_name, normalize_filename if TYPE_CHECKING: from collections.abc import Mapping + from pathlib import Path import pandas as pd import pyarrow as pa @@ -119,9 +121,10 @@ def drop_database( def list_tables( self, like: str | None = None, - temp: bool = False, + *, database: str | None = None, catalog: str | None = None, + temp: bool = False, ) -> list[str]: """Return the list of table/view names. @@ -198,7 +201,7 @@ def _fully_qualified_name( database: str | None = None, catalog: str | None = None, ) -> str: - if is_fully_qualified(name): + if name and is_fully_qualified(name): return name return sg.table( @@ -635,6 +638,111 @@ def drop_view( sql = statement.compile() self._exec_sql(sql) + def _get_dataframe_from_path(self, path: str | Path) -> pd.DataFrame: + import glob + + import pandas as pd + + dataframe_list = [] + path_str = str(path) + path_normalized = normalize_filename(path) + for file_path in glob.glob(path_normalized): + if path_str.startswith(("parquet://", "parq://")) or path_str.endswith( + ("parq", "parquet") + ): + dataframe = pd.read_parquet(file_path) + elif path_str.startswith("csv://") or path_str.endswith(("csv", "csv.gz")): + dataframe = pd.read_csv(file_path) + elif path_str.endswith("json"): + dataframe = pd.read_json(file_path, lines=True) + else: + raise ValueError(f"Unsupported file_path: {file_path}") + + dataframe_list.append(dataframe) + + return pd.concat(dataframe_list, ignore_index=True, sort=False) + + def read_file( + self, + file_type: str, + path: str | Path, + table_name: str | None = None, + ) -> ir.Table: + """Register a file as a table in the current database. + + Parameters + ---------- + file_type + File type, e.g., parquet, csv, json. + path + The data source. + table_name + An optional name to use for the created table. This defaults to + a sequentially generated name. + + Returns + ------- + ir.Table + The just-registered table + """ + obj = self._get_dataframe_from_path(path) + table_name = table_name or gen_name(f"read_{file_type}") + return self.create_table(table_name, obj, temp=True, overwrite=True) + + def read_parquet(self, path: str | Path, table_name: str | None = None) -> ir.Table: + """Register a parquet file as a table in the current database. + + Parameters + ---------- + path + The data source. + table_name + An optional name to use for the created table. This defaults to + a sequentially generated name. + + Returns + ------- + ir.Table + The just-registered table + """ + return self.read_file(file_type="parquet", path=path, table_name=table_name) + + def read_csv(self, path: str | Path, table_name: str | None = None) -> ir.Table: + """Register a csv file as a table in the current database. + + Parameters + ---------- + path + The data source. + table_name + An optional name to use for the created table. 
This defaults to + a sequentially generated name. + + Returns + ------- + ir.Table + The just-registered table + """ + return self.read_file(file_type="csv", path=path, table_name=table_name) + + def read_json(self, path: str | Path, table_name: str | None = None) -> ir.Table: + """Register a json file as a table in the current database. + + Parameters + ---------- + path + The data source. + table_name + An optional name to use for the created table. This defaults to + a sequentially generated name. + + Returns + ------- + ir.Table + The just-registered table + """ + return self.read_file(file_type="json", path=path, table_name=table_name) + @classmethod @lru_cache def _get_operations(cls): diff --git a/ibis/backends/tests/test_param.py b/ibis/backends/tests/test_param.py index dc7a898be44e1..452536197320c 100644 --- a/ibis/backends/tests/test_param.py +++ b/ibis/backends/tests/test_param.py @@ -77,7 +77,6 @@ def test_scalar_param_array(con): [ "datafusion", "impala", - "flink", "postgres", "pyspark", "druid", @@ -244,7 +243,6 @@ def test_scalar_param_date(backend, alltypes, value): "exasol", ] ) -@pytest.mark.notimpl(["flink"], "WIP") def test_scalar_param_nested(con): param = ibis.param("struct>>>") value = OrderedDict([("x", [OrderedDict([("y", [1.0, 2.0, 3.0])])])]) diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index 75a121f578d85..b73dc3541fa8e 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -395,9 +395,7 @@ def test_register_garbage(con, monkeypatch): ("functional_alltypes.parquet", "funk_all"), ], ) -@pytest.mark.notyet( - ["flink", "impala", "mssql", "mysql", "postgres", "sqlite", "trino"] -) +@pytest.mark.notyet(["impala", "mssql", "mysql", "postgres", "sqlite", "trino"]) def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name): pq = pytest.importorskip("pyarrow.parquet") @@ -427,7 +425,7 @@ def ft_data(data_dir): @pytest.mark.notyet( - ["flink", "impala", "mssql", "mysql", "pandas", "postgres", "sqlite", "trino"] + ["impala", "mssql", "mysql", "pandas", "postgres", "sqlite", "trino"] ) def test_read_parquet_glob(con, tmp_path, ft_data): pq = pytest.importorskip("pyarrow.parquet") @@ -446,7 +444,7 @@ def test_read_parquet_glob(con, tmp_path, ft_data): @pytest.mark.notyet( - ["flink", "impala", "mssql", "mysql", "pandas", "postgres", "sqlite", "trino"] + ["impala", "mssql", "mysql", "pandas", "postgres", "sqlite", "trino"] ) def test_read_csv_glob(con, tmp_path, ft_data): pc = pytest.importorskip("pyarrow.csv") @@ -469,7 +467,6 @@ def test_read_csv_glob(con, tmp_path, ft_data): "clickhouse", "dask", "datafusion", - "flink", "impala", "mssql", "mysql", @@ -522,9 +519,7 @@ def num_diamonds(data_dir): "in_table_name", [param(None, id="default"), param("fancy_stones", id="file_name")], ) -@pytest.mark.notyet( - ["flink", "impala", "mssql", "mysql", "postgres", "sqlite", "trino"] -) +@pytest.mark.notyet(["impala", "mssql", "mysql", "postgres", "sqlite", "trino"]) def test_read_csv(con, data_dir, in_table_name, num_diamonds): fname = "diamonds.csv" with pushd(data_dir / "csv"):
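
For reference, a minimal usage sketch of the read_* helpers introduced above. The environment setup follows the usual pyflink pattern and the file paths and table names are illustrative placeholders, not values taken from this patch:

    from pyflink.datastream import StreamExecutionEnvironment
    from pyflink.table import StreamTableEnvironment

    import ibis

    # Assumed setup: a local streaming TableEnvironment handed to the Flink backend.
    env = StreamExecutionEnvironment.get_execution_environment()
    table_env = StreamTableEnvironment.create(env)
    con = ibis.flink.connect(table_env)

    # Each helper loads the file(s) into a pandas DataFrame via
    # _get_dataframe_from_path and registers the result as a temporary table.
    diamonds = con.read_csv("data/diamonds.csv", table_name="diamonds")  # placeholder path
    funk_all = con.read_parquet("data/functional_alltypes.parquet")      # placeholder path
    astro = con.read_json("data/astronauts.json")                        # placeholder path

    print(diamonds.columns)

Note that read_json goes through pandas.read_json(..., lines=True), so it expects newline-delimited JSON files, in line with the test_json.py deep dive mentioned in the commit subject.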