diff --git a/ibis/backends/flink/__init__.py b/ibis/backends/flink/__init__.py
index 7aa374be3b837..e38ca6c822ad0 100644
--- a/ibis/backends/flink/__init__.py
+++ b/ibis/backends/flink/__init__.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from functools import lru_cache
+from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
 import sqlglot as sg
@@ -22,6 +23,7 @@
     InsertSelect,
     RenameTable,
 )
+from ibis.util import gen_name, normalize_filename
 
 if TYPE_CHECKING:
     from collections.abc import Mapping
@@ -119,9 +121,10 @@ def drop_database(
     def list_tables(
         self,
         like: str | None = None,
-        temp: bool = False,
+        *,
         database: str | None = None,
         catalog: str | None = None,
+        temp: bool = False,
     ) -> list[str]:
         """Return the list of table/view names.
 
@@ -198,7 +201,7 @@ def _fully_qualified_name(
         database: str | None = None,
         catalog: str | None = None,
     ) -> str:
-        if is_fully_qualified(name):
+        if name and is_fully_qualified(name):
             return name
 
         return sg.table(
@@ -635,6 +638,199 @@ def drop_view(
         sql = statement.compile()
         self._exec_sql(sql)
 
+    def register(
+        self,
+        source: str
+        | Path
+        | pa.Table
+        | pd.DataFrame,  # TODO (mehmet): pa.RecordBatch | pa.Dataset
+        table_name: str | None = None,
+        **kwargs: Any,
+    ) -> ir.Table:
+        """Register a data set with `table_name` located at `source`.
+
+        Parameters
+        ----------
+        source
+            The data source(s). May be a path (or list of paths) to
+            parquet/csv files, a pandas dataframe, or a pyarrow table.
+        table_name
+            The name of the table.
+        kwargs
+            Flink-specific keyword arguments.
+
+        Examples
+        --------
+        Register a csv:
+
+        >>> import ibis
+        >>> conn = ibis.flink.connect(table_env)
+        >>> conn.register("path/to/data.csv", "my_table")
+        >>> conn.table("my_table")
+
+        Register a PyArrow table:
+
+        >>> import pyarrow as pa
+        >>> tab = pa.table({"x": [1, 2, 3]})
+        >>> conn.register(tab, "my_table")
+        >>> conn.table("my_table")
+
+        Register a pandas DataFrame:
+
+        >>> import pandas as pd
+        >>> df = pd.DataFrame({"x": [1, 2, 3]})
+        >>> conn.register(df, "my_table")
+        >>> conn.table("my_table")
+        """
+        import pandas as pd
+        import pyarrow as pa
+
+        if isinstance(source, (str, Path)):
+            obj = self._get_dataframe_from_path(path=source)
+
+        elif isinstance(source, list) and all(
+            isinstance(s, (str, Path)) for s in source
+        ):
+            obj = pd.concat(
+                [self._get_dataframe_from_path(path) for path in source],
+                ignore_index=True,
+                sort=False,
+            )
+
+        elif isinstance(source, (pa.Table, pd.DataFrame)):
+            obj = source
+
+        else:
+            raise ValueError(f"Unsupported source type: {type(source)}")
+
+        # TODO (mehmet): This block was added to make the following tests pass:
+        # - test_register_csv()
+        # - test_register_iterator_parquet()
+        # Implicitly setting `table_name` (when it is None) looks strange to me.
+        if table_name is None:
+            source_list = source
+            if not isinstance(source, list):
+                source_list = [source]
+
+            for source in source_list:
+                if not isinstance(source, (str, Path)):
+                    continue
+
+                source = str(source)
+                if source.startswith(("parquet://", "parq://")) or source.endswith(
+                    ("parq", "parquet")
+                ):
+                    table_name = table_name or gen_name("read_parquet")
+                elif source.startswith("csv://") or source.endswith(("csv", "csv.gz")):
+                    table_name = table_name or gen_name("read_csv")
+                elif source.endswith("json"):
+                    table_name = table_name or gen_name("read_json")
+
+        table_name = table_name or gen_name("default")
+        return self.create_table(table_name, obj, temp=True, overwrite=True)
+
+    def _get_dataframe_from_path(self, path: str | Path) -> pd.DataFrame:
+        import glob
+
+        import pandas as pd
+
+        dataframe_list = []
+        path_str = str(path)
+        path_normalized = normalize_filename(path)
+        for file_path in glob.glob(path_normalized):
+            if path_str.startswith(("parquet://", "parq://")) or path_str.endswith(
+                ("parq", "parquet")
+            ):
+                dataframe = pd.read_parquet(file_path)
+            elif path_str.startswith("csv://") or path_str.endswith(("csv", "csv.gz")):
+                dataframe = pd.read_csv(file_path)
+            elif path_str.endswith("json"):
+                dataframe = pd.read_json(file_path, lines=True)
+            else:
+                raise ValueError(f"Unsupported file_path: {file_path}")
+
+            dataframe_list.append(dataframe)
+
+        return pd.concat(dataframe_list, ignore_index=True, sort=False)
+
+    def read_file(
+        self,
+        file_type: str,
+        path: str | Path,
+        table_name: str | None = None,
+    ) -> ir.Table:
+        """Register a file as a table in the current database.
+
+        Parameters
+        ----------
+        file_type
+            File type, e.g., parquet, csv, json.
+        path
+            The data source.
+        table_name
+            An optional name to use for the created table. This defaults to
+            a sequentially generated name.
+
+        Returns
+        -------
+        ir.Table
+            The just-registered table
+        """
+        obj = self._get_dataframe_from_path(path)
+        table_name = table_name or gen_name(f"read_{file_type}")
+        return self.create_table(table_name, obj, temp=True, overwrite=True)
+
+    def read_parquet(self, path: str | Path, table_name: str | None = None) -> ir.Table:
+        """Register a parquet file as a table in the current database.
+
+        Parameters
+        ----------
+        path
+            The data source.
+        table_name
+            An optional name to use for the created table. This defaults to
+            a sequentially generated name.
+
+        Returns
+        -------
+        ir.Table
+            The just-registered table
+        """
+        return self.read_file(file_type="parquet", path=path, table_name=table_name)
+
+    def read_csv(self, path: str | Path, table_name: str | None = None) -> ir.Table:
+        """Register a csv file as a table in the current database.
+
+        Parameters
+        ----------
+        path
+            The data source.
+        table_name
+            An optional name to use for the created table. This defaults to
+            a sequentially generated name.
+
+        Returns
+        -------
+        ir.Table
+            The just-registered table
+        """
+        return self.read_file(file_type="csv", path=path, table_name=table_name)
+
+    def read_json(self, path: str | Path, table_name: str | None = None) -> ir.Table:
+        """Register a json file as a table in the current database.
+
+        Parameters
+        ----------
+        path
+            The data source.
+        table_name
+            An optional name to use for the created table. This defaults to
+            a sequentially generated name.
+
+        Returns
+        -------
+        ir.Table
+            The just-registered table
+        """
+        return self.read_file(file_type="json", path=path, table_name=table_name)
+
     @classmethod
     @lru_cache
     def _get_operations(cls):
diff --git a/ibis/backends/flink/registry.py b/ibis/backends/flink/registry.py
index 97fca3ff65c04..a668739f15495 100644
--- a/ibis/backends/flink/registry.py
+++ b/ibis/backends/flink/registry.py
@@ -324,6 +324,7 @@ def _date_sub(translator: ExprTranslator, op: ops.temporal.DateSub) -> str:
 
 def _extract_epoch_seconds(translator: ExprTranslator, op: ops.Node) -> str:
     arg = translator.translate(op.arg)
+    # return f"UNIX_TIMESTAMP(CAST({arg} AS STRING), 'yyyy-MM-dd HH:mm:ss.SSS')"
     return f"UNIX_TIMESTAMP(CAST({arg} AS STRING))"
 
 
diff --git a/ibis/backends/tests/test_param.py b/ibis/backends/tests/test_param.py
index 61b6c7ea17e77..89c32783838b4 100644
--- a/ibis/backends/tests/test_param.py
+++ b/ibis/backends/tests/test_param.py
@@ -86,7 +86,6 @@ def test_scalar_param_array(con):
     [
         "datafusion",
         "impala",
-        "flink",
         "postgres",
         "pyspark",
         "druid",
@@ -251,7 +250,6 @@ def test_scalar_param_date(backend, alltypes, value):
         "exasol",
     ]
 )
-@pytest.mark.notimpl(["flink"], "WIP")
 def test_scalar_param_nested(con):
     param = ibis.param("struct<x: array<struct<y: array<double>>>>")
     value = OrderedDict([("x", [OrderedDict([("y", [1.0, 2.0, 3.0])])])])
diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py
index 75a121f578d85..0caca60863132 100644
--- a/ibis/backends/tests/test_register.py
+++ b/ibis/backends/tests/test_register.py
@@ -85,7 +85,6 @@ def gzip_csv(data_dir, tmp_path):
         "bigquery",
         "clickhouse",
         "dask",
-        "flink",
         "impala",
         "mssql",
         "mysql",
@@ -111,7 +110,6 @@ def test_register_csv(con, data_dir, fname, in_table_name, out_table_name):
         "bigquery",
         "clickhouse",
         "dask",
-        "flink",
         "impala",
         "mssql",
         "mysql",
@@ -134,7 +132,6 @@ def test_register_csv_gz(con, data_dir, gzip_csv):
         "bigquery",
         "clickhouse",
         "dask",
-        "flink",
         "impala",
         "mssql",
         "mysql",
@@ -190,7 +187,6 @@ def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]:
         "bigquery",
         "clickhouse",
         "dask",
-        "flink",
         "impala",
         "mssql",
         "mysql",
@@ -226,7 +222,6 @@ def test_register_parquet(
         "clickhouse",
         "dask",
         "datafusion",
-        "flink",
         "impala",
         "mssql",
         "mysql",
@@ -265,7 +260,6 @@ def test_register_iterator_parquet(
         "bigquery",
         "clickhouse",
         "dask",
-        "flink",
         "impala",
         "mssql",
         "mysql",
@@ -295,7 +289,6 @@ def test_register_pandas(con):
         "bigquery",
         "clickhouse",
         "dask",
-        "flink",
         "impala",
         "mssql",
         "mysql",
@@ -320,7 +313,6 @@ def test_register_pyarrow_tables(con):
         "bigquery",
         "clickhouse",
         "dask",
-        "flink",
         "impala",
         "mssql",
         "mysql",
@@ -395,9 +387,7 @@ def test_register_garbage(con, monkeypatch):
         ("functional_alltypes.parquet", "funk_all"),
     ],
 )
-@pytest.mark.notyet(
-    ["flink", "impala", "mssql", "mysql", "postgres", "sqlite", "trino"]
-)
+@pytest.mark.notyet(["impala", "mssql", "mysql", "postgres", "sqlite", "trino"])
 def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name):
     pq = pytest.importorskip("pyarrow.parquet")
 
@@ -427,7 +417,7 @@ def ft_data(data_dir):
 
 
 @pytest.mark.notyet(
-    ["flink", "impala", "mssql", "mysql", "pandas", "postgres", "sqlite", "trino"]
+    ["impala", "mssql", "mysql", "pandas", "postgres", "sqlite", "trino"]
 )
 def test_read_parquet_glob(con, tmp_path, ft_data):
     pq = pytest.importorskip("pyarrow.parquet")
@@ -446,7 +436,7 @@ def test_read_parquet_glob(con, tmp_path, ft_data):
 
 
 @pytest.mark.notyet(
-    ["flink", "impala", "mssql", "mysql", "pandas", "postgres", "sqlite", "trino"]
+    ["impala", "mssql", "mysql", "pandas", "postgres", "sqlite", "trino"]
"mysql", "pandas", "postgres", "sqlite", "trino"] ) def test_read_csv_glob(con, tmp_path, ft_data): pc = pytest.importorskip("pyarrow.csv") @@ -469,7 +459,6 @@ def test_read_csv_glob(con, tmp_path, ft_data): "clickhouse", "dask", "datafusion", - "flink", "impala", "mssql", "mysql", @@ -522,9 +511,7 @@ def num_diamonds(data_dir): "in_table_name", [param(None, id="default"), param("fancy_stones", id="file_name")], ) -@pytest.mark.notyet( - ["flink", "impala", "mssql", "mysql", "postgres", "sqlite", "trino"] -) +@pytest.mark.notyet(["impala", "mssql", "mysql", "postgres", "sqlite", "trino"]) def test_read_csv(con, data_dir, in_table_name, num_diamonds): fname = "diamonds.csv" with pushd(data_dir / "csv"): diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 57371718907e2..32f52a2ecdf95 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -349,6 +349,13 @@ def test_timestamp_extract_milliseconds(backend, alltypes, df): reason="test was adjusted to work with pandas 2.1 output; pyspark doesn't support pandas 2", ) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +# @pytest.mark.broken( +# ["flink"], +# raises=AssertionError, +# reason=( +# "Series values are different (100.0 %)" +# ), +# ) def test_timestamp_extract_epoch_seconds(backend, alltypes, df): expr = alltypes.timestamp_col.epoch_seconds().name("tmp") result = expr.execute()