Skip to content

Commit

Permalink
Merge pull request #9377 from Tishj/pyarrow_timestamp_bugfix
Browse files Browse the repository at this point in the history
[PyArrow] Fix bug in timestamp pushdown
  • Loading branch information
Mytherin authored Oct 17, 2023
2 parents 2646836 + 42e22e3 commit 5ec85a7
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/function/table/arrow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ unique_ptr<ArrowArrayStreamWrapper> ProduceArrowScan(const ArrowScanFunctionData
auto &schema = *function.schema_root.arrow_schema.children[col_idx];
parameters.projected_columns.projection_map[idx] = schema.name;
parameters.projected_columns.columns.emplace_back(schema.name);
parameters.projected_columns.filter_to_col[idx] = col_idx;
}
}
parameters.filters = filters;
Expand Down
2 changes: 2 additions & 0 deletions src/include/duckdb/function/table/arrow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ struct ArrowInterval {
//! Projection information carried as `projected_columns` on the Arrow scan parameters.
//! It is filled in while the Arrow scan is being produced and read back by the stream
//! factory that builds the actual scanner.
struct ArrowProjectedColumns {
// Map from the producer's loop index to the Arrow schema name of the column
// NOTE(review): the key is presumably the position within the projected column list - confirm
unordered_map<idx_t, string> projection_map;
// Arrow schema names of the projected columns, in the order they were appended
vector<string> columns;
// Map from filter index to column index
// The column index is the position among the Arrow schema's children; consumers use it to
// look up the Arrow type of the column a filter applies to
unordered_map<idx_t, idx_t> filter_to_col;
};

struct ArrowStreamParameters {
Expand Down
10 changes: 6 additions & 4 deletions tools/pythonpkg/src/arrow/arrow_array_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,12 @@ py::object PythonTableArrowArrayStreamFactory::ProduceScanner(py::object &arrow_

auto filters = parameters.filters;
auto &column_list = parameters.projected_columns.columns;
auto &filter_to_col = parameters.projected_columns.filter_to_col;
bool has_filter = filters && !filters->filters.empty();
py::list projection_list = py::cast(column_list);
if (has_filter) {
auto filter =
TransformFilter(*filters, parameters.projected_columns.projection_map, client_properties, arrow_table);
auto filter = TransformFilter(*filters, parameters.projected_columns.projection_map, filter_to_col,
client_properties, arrow_table);
if (column_list.empty()) {
return arrow_scanner(arrow_obj_handle, py::arg("filter") = filter);
} else {
Expand Down Expand Up @@ -176,7 +177,7 @@ string ConvertTimestampUnit(ArrowDateTimeType unit) {
case ArrowDateTimeType::SECONDS:
return "s";
default:
throw NotImplementedException("DatetimeType not recognized in ConvertTimestampUnit");
throw NotImplementedException("DatetimeType not recognized in ConvertTimestampUnit: %d", (int)unit);
}
}

Expand Down Expand Up @@ -365,12 +366,13 @@ py::object TransformFilterRecursive(TableFilter *filter, const string &column_na

py::object PythonTableArrowArrayStreamFactory::TransformFilter(TableFilterSet &filter_collection,
std::unordered_map<idx_t, string> &columns,
unordered_map<idx_t, idx_t> filter_to_col,
const ClientProperties &config,
const ArrowTableType &arrow_table) {
auto filters_map = &filter_collection.filters;
auto it = filters_map->begin();
D_ASSERT(columns.find(it->first) != columns.end());
auto &arrow_type = *arrow_table.GetColumns().at(it->first);
auto &arrow_type = *arrow_table.GetColumns().at(filter_to_col.at(it->first));
py::object expression =
TransformFilterRecursive(it->second.get(), columns[it->first], config.time_zone, arrow_type);
while (it != filters_map->end()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class PythonTableArrowArrayStreamFactory {
private:
//! We transform a TableFilterSet to an Arrow Expression Object
//! \param filters           the set of table filters to convert; entries are keyed by filter index
//! \param columns           map from filter index to the name of the column the filter applies to
//! \param filter_to_col     map from filter index to column index, used to look up the Arrow type
//!                          of the filtered column in `arrow_table`
//! \param client_properties client settings; the time zone is forwarded to the per-filter conversion
//! \param arrow_table       Arrow type information for the columns of the scanned table
//! \return the resulting expression as a Python object
//! NOTE(review): `filter_to_col` is taken by value, so the map is copied on every call. A const
//! reference would avoid the copy, but the out-of-line definition must be changed in the same edit.
static py::object TransformFilter(TableFilterSet &filters, std::unordered_map<idx_t, string> &columns,
unordered_map<idx_t, idx_t> filter_to_col,
const ClientProperties &client_properties, const ArrowTableType &arrow_table);

static py::object ProduceScanner(py::object &arrow_scanner, py::handle &arrow_obj_handle,
Expand Down
26 changes: 26 additions & 0 deletions tools/pythonpkg/tests/fast/arrow/test_filter_pushdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,32 @@ def test_filter_pushdown_integers(self, duckdb_cursor, data_type, value, create_
actual = duckdb_cursor.execute("select * from arrow_table where i = ?", (value,)).fetchall()
assert expected == actual

def test_9371(self, duckdb_cursor, tmp_path):
    """Check that an equality filter on a timestamp column returns all matching rows
    when scanning a pyarrow dataset whose timestamp column was the pandas index.

    The data is written to parquet from a DataFrame indexed by ``ts``, re-opened as a
    pyarrow dataset, and queried with ``WHERE ts = ?``. Every row carries the same
    timestamp, so all three rows are expected back, with ``value`` listed before ``ts``.
    (Presumably a regression test for issue #9371, going by the test name.)
    """
    import datetime

    # Pin the session time zone so the timezone-aware timestamps compare and round-trip as UTC
    duckdb_cursor.execute("SET TimeZone='UTC';")

    base_path = tmp_path / "parquet_folder"
    base_path.mkdir(exist_ok=True)
    file_path = base_path / "test.parquet"

    # Example data
    dt = datetime.datetime(2023, 8, 29, 1, tzinfo=datetime.timezone.utc)

    my_arrow_table = pa.Table.from_pydict({'ts': [dt, dt, dt], 'value': [1, 2, 3]})
    df = my_arrow_table.to_pandas()
    df = df.set_index("ts")  # SET INDEX! (It all works correctly when the index is not set)
    df.to_parquet(str(file_path))

    # `my_arrow_dataset` and `res` are referenced by name inside the SQL strings below,
    # so these local variable names must not be changed.
    my_arrow_dataset = ds.dataset(str(file_path))
    res = duckdb_cursor.execute("SELECT * FROM my_arrow_dataset WHERE ts = ?", parameters=[dt]).arrow()
    output = duckdb_cursor.sql("select * from res").fetchall()
    expected = [(1, dt), (2, dt), (3, dt)]
    assert output == expected

@pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table])
def test_filter_pushdown_date(self, duckdb_cursor, create_table):
duckdb_cursor.execute(
Expand Down

0 comments on commit 5ec85a7

Please sign in to comment.