Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

SNOW-966003: Fix Arrow return value for zero-row queries #1801

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/snowflake/connector/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)

from typing_extensions import Self
from warnings import warn

from snowflake.connector.result_batch import create_batches_from_response
from snowflake.connector.result_set import ResultSet
Expand Down Expand Up @@ -1117,14 +1118,31 @@ def fetch_arrow_batches(self) -> Iterator[Table]:
)
return self._result_set._fetch_arrow_batches()

@overload
def fetch_arrow_all(self, force_return_table: Literal[False] = False) -> None:
    ...

@overload
def fetch_arrow_all(self, force_return_table: Literal[True] = True) -> Table:
    ...

def fetch_arrow_all(self, force_return_table: bool = False) -> Table | None:
    """Fetch all rows of the result set as a single Arrow ``Table``.

    Args:
        force_return_table: When True, a query that returns zero rows
            yields an empty pyarrow Table (with a schema using the
            highest bit length for each column) instead of ``None``.
            Future behavior will act as if ``force_return_table=True``,
            similar to the ``fetch_pandas_all`` method.

    Returns:
        The full result as a pyarrow Table, or ``None`` when the query
        returned zero rows and ``force_return_table`` is False.

    Raises:
        NotSupportedError: If the query result format is not ``arrow``.
    """
    # NOTE(review): the changelog should also be updated for this
    # behavior change (new force_return_table parameter).
    self.check_can_use_arrow_resultset()
    if self._prefetch_hook is not None:
        self._prefetch_hook()
    if self._query_result_format != "arrow":
        raise NotSupportedError
    self._log_telemetry_job_data(TelemetryField.ARROW_FETCH_ALL, TelemetryData.TRUE)
    return self._result_set._fetch_arrow_all(force_return_table=force_return_table)

def fetch_pandas_batches(self, **kwargs: Any) -> Iterator[DataFrame]:
"""Fetches a single Arrow Table."""
Expand Down
23 changes: 20 additions & 3 deletions src/snowflake/connector/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,16 @@
from concurrent.futures import Future
from concurrent.futures.thread import ThreadPoolExecutor
from logging import getLogger
from typing import TYPE_CHECKING, Any, Callable, Deque, Iterable, Iterator
from typing import (
TYPE_CHECKING,
Any,
Callable,
Deque,
Iterable,
Iterator,
Literal,
overload,
)

from .constants import IterUnit
from .errors import NotSupportedError
Expand Down Expand Up @@ -164,13 +173,21 @@ def _fetch_arrow_batches(
self._can_create_arrow_iter()
return self._create_iter(iter_unit=IterUnit.TABLE_UNIT, structure="arrow")

@overload
def _fetch_arrow_all(self, force_return_table: Literal[False] = False) -> None:
    ...

@overload
def _fetch_arrow_all(self, force_return_table: Literal[True] = True) -> Table:
    ...

def _fetch_arrow_all(self, force_return_table: bool = False) -> Table | None:
    """Fetches a single Arrow Table from all of the ``ResultBatch``.

    Args:
        force_return_table: When True and the result set contains no
            rows, return an empty Table built from the first batch's
            schema instead of ``None``.
    """
    tables = list(self._fetch_arrow_batches())
    if tables:
        return pa.concat_tables(tables)
    # Zero-row result: either surface an empty table carrying the result
    # schema, or preserve the legacy behavior of returning None.
    return self.batches[0].to_arrow() if force_return_table else None

def _fetch_pandas_batches(self, **kwargs) -> Iterator[DataFrame]:
"""Fetches Pandas dataframes in batches, where batch refers to Snowflake Chunk.
Expand Down
22 changes: 16 additions & 6 deletions test/integ/pandas/test_arrow_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,9 +697,7 @@ def validate_pandas(
assert abs(c_case - c_new) < epsilon, (
"{} row, {} column: original value is {}, "
"new value is {}, epsilon is {} \
values are not equal".format(
i, j, cases[i], c_new, epsilon
)
values are not equal".format(i, j, cases[i], c_new, epsilon)
)


Expand Down Expand Up @@ -831,9 +829,7 @@ def fetch_pandas(conn_cnx, sql, row_count, col_count, method="one"):
# verify the correctness
# only do it when fetch one dataframe
if method == "one":
assert (
df_old.shape == df_new.shape
), "the shape of old dataframe is {}, the shape of new dataframe is {}, \
assert df_old.shape == df_new.shape, "the shape of old dataframe is {}, the shape of new dataframe is {}, \
shapes are not equal".format(
df_old.shape, df_new.shape
)
Expand Down Expand Up @@ -1167,6 +1163,20 @@ def test_simple_arrow_fetch(conn_cnx):
assert lo == rowcount


def test_arrow_zero_rows(conn_cnx):
    """A zero-row query yields an empty Table under force_return_table=True
    and None under the default behavior."""
    with conn_cnx() as cnx, cnx.cursor() as cur:
        cur.execute(SQL_ENABLE_ARROW)
        cur.execute("select 1::NUMBER(38,0) limit 0")
        result = cur.fetch_arrow_all(force_return_table=True)
        # Snowflake will return an integer dtype with maximum bit-length if
        # no rows are returned
        assert result.schema[0].type == pyarrow.int64()
        # test default behavior
        cur.execute("select 1::NUMBER(38,0) limit 0")
        assert cur.fetch_arrow_all(force_return_table=False) is None


@pytest.mark.parametrize("fetch_fn_name", ["to_arrow", "to_pandas", "create_iter"])
@pytest.mark.parametrize("pass_connection", [True, False])
def test_sessions_used(conn_cnx, fetch_fn_name, pass_connection):
Expand Down
Loading