diff --git a/src/function/table/arrow.cpp b/src/function/table/arrow.cpp index 9601e41d28c8..306e57d30dcc 100644 --- a/src/function/table/arrow.cpp +++ b/src/function/table/arrow.cpp @@ -266,6 +266,7 @@ unique_ptr ProduceArrowScan(const ArrowScanFunctionData auto &schema = *function.schema_root.arrow_schema.children[col_idx]; parameters.projected_columns.projection_map[idx] = schema.name; parameters.projected_columns.columns.emplace_back(schema.name); + parameters.projected_columns.filter_to_col[idx] = col_idx; } } parameters.filters = filters; diff --git a/src/include/duckdb/function/table/arrow.hpp b/src/include/duckdb/function/table/arrow.hpp index fae51810aae2..df6e49953835 100644 --- a/src/include/duckdb/function/table/arrow.hpp +++ b/src/include/duckdb/function/table/arrow.hpp @@ -33,6 +33,8 @@ struct ArrowInterval { struct ArrowProjectedColumns { unordered_map projection_map; vector columns; + // Map from filter index to column index + unordered_map filter_to_col; }; struct ArrowStreamParameters { diff --git a/tools/pythonpkg/src/arrow/arrow_array_stream.cpp b/tools/pythonpkg/src/arrow/arrow_array_stream.cpp index baf5ee1be856..f7093df8a126 100644 --- a/tools/pythonpkg/src/arrow/arrow_array_stream.cpp +++ b/tools/pythonpkg/src/arrow/arrow_array_stream.cpp @@ -66,11 +66,12 @@ py::object PythonTableArrowArrayStreamFactory::ProduceScanner(py::object &arrow_ auto filters = parameters.filters; auto &column_list = parameters.projected_columns.columns; + auto &filter_to_col = parameters.projected_columns.filter_to_col; bool has_filter = filters && !filters->filters.empty(); py::list projection_list = py::cast(column_list); if (has_filter) { - auto filter = - TransformFilter(*filters, parameters.projected_columns.projection_map, client_properties, arrow_table); + auto filter = TransformFilter(*filters, parameters.projected_columns.projection_map, filter_to_col, + client_properties, arrow_table); if (column_list.empty()) { return arrow_scanner(arrow_obj_handle, py::arg("filter") = filter); } else { @@ -176,7 +177,7 @@ string ConvertTimestampUnit(ArrowDateTimeType unit) { case ArrowDateTimeType::SECONDS: return "s"; default: - throw NotImplementedException("DatetimeType not recognized in ConvertTimestampUnit"); + throw NotImplementedException("DatetimeType not recognized in ConvertTimestampUnit: %d", (int)unit); } } @@ -365,12 +366,13 @@ py::object TransformFilterRecursive(TableFilter *filter, const string &column_na py::object PythonTableArrowArrayStreamFactory::TransformFilter(TableFilterSet &filter_collection, std::unordered_map &columns, + unordered_map filter_to_col, const ClientProperties &config, const ArrowTableType &arrow_table) { auto filters_map = &filter_collection.filters; auto it = filters_map->begin(); D_ASSERT(columns.find(it->first) != columns.end()); - auto &arrow_type = *arrow_table.GetColumns().at(it->first); + auto &arrow_type = *arrow_table.GetColumns().at(filter_to_col.at(it->first)); py::object expression = TransformFilterRecursive(it->second.get(), columns[it->first], config.time_zone, arrow_type); while (it != filters_map->end()) { diff --git a/tools/pythonpkg/src/include/duckdb_python/arrow/arrow_array_stream.hpp b/tools/pythonpkg/src/include/duckdb_python/arrow/arrow_array_stream.hpp index 4ad36b6e606a..e6769f63b1b3 100644 --- a/tools/pythonpkg/src/include/duckdb_python/arrow/arrow_array_stream.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/arrow/arrow_array_stream.hpp @@ -76,6 +76,7 @@ class PythonTableArrowArrayStreamFactory { private: //! We transform a TableFilterSet to an Arrow Expression Object static py::object TransformFilter(TableFilterSet &filters, std::unordered_map &columns, + unordered_map filter_to_col, const ClientProperties &client_properties, const ArrowTableType &arrow_table); static py::object ProduceScanner(py::object &arrow_scanner, py::handle &arrow_obj_handle, diff --git a/tools/pythonpkg/tests/fast/arrow/test_filter_pushdown.py b/tools/pythonpkg/tests/fast/arrow/test_filter_pushdown.py index 21aaab9e56c2..74630f4a74cf 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_filter_pushdown.py +++ b/tools/pythonpkg/tests/fast/arrow/test_filter_pushdown.py @@ -504,6 +504,32 @@ def test_filter_pushdown_integers(self, duckdb_cursor, data_type, value, create_ actual = duckdb_cursor.execute("select * from arrow_table where i = ?", (value,)).fetchall() assert expected == actual + def test_9371(self, duckdb_cursor, tmp_path): + import datetime + import pathlib + + # connect to an in-memory database + duckdb_cursor.execute("SET TimeZone='UTC';") + base_path = tmp_path / "parquet_folder" + base_path.mkdir(exist_ok=True) + file_path = base_path / "test.parquet" + + duckdb_cursor.execute("SET TimeZone='UTC';") + + # Example data + dt = datetime.datetime(2023, 8, 29, 1, tzinfo=datetime.timezone.utc) + + my_arrow_table = pa.Table.from_pydict({'ts': [dt, dt, dt], 'value': [1, 2, 3]}) + df = my_arrow_table.to_pandas() + df = df.set_index("ts") # SET INDEX! (It all works correctly when the index is not set) + df.to_parquet(str(file_path)) + + my_arrow_dataset = ds.dataset(str(file_path)) + res = duckdb_cursor.execute("SELECT * FROM my_arrow_dataset WHERE ts = ?", parameters=[dt]).arrow() + output = duckdb_cursor.sql("select * from res").fetchall() + expected = [(1, dt), (2, dt), (3, dt)] + assert output == expected + @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_date(self, duckdb_cursor, create_table): duckdb_cursor.execute(