diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 255fcd344e..c7ef450844 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -462,16 +462,8 @@ def to_arrow_scalar(value: Any, arrow_type: pyarrow.DataType) -> Any: def from_arrow_scalar(arrow_value: pyarrow.Scalar) -> Any: - """Converts arrow scalar into Python type. Currently adds "UTC" to naive date times and converts all others to UTC""" - row_value = arrow_value.as_py() - # dates are not represented as datetimes but I see connector-x represents - # datetimes as dates and keeping the exact time inside. probably a bug - # but can be corrected this way - if isinstance(row_value, date) and not isinstance(row_value, datetime): - row_value = pendulum.from_timestamp(arrow_value.cast(pyarrow.int64()).as_py() / 1000) - elif isinstance(row_value, datetime): - row_value = pendulum.instance(row_value).in_tz("UTC") - return row_value + """Converts arrow scalar into Python type.""" + return arrow_value.as_py() TNewColumns = Sequence[Tuple[int, pyarrow.Field, Callable[[pyarrow.Table], Iterable[Any]]]] diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py index 40734095cf..6d77758d59 100644 --- a/dlt/extract/incremental/__init__.py +++ b/dlt/extract/incremental/__init__.py @@ -30,7 +30,6 @@ coerce_value, py_type_to_sc_type, ) -from dlt.common.utils import without_none from dlt.extract.exceptions import IncrementalUnboundError from dlt.extract.incremental.exceptions import ( diff --git a/dlt/extract/incremental/transform.py b/dlt/extract/incremental/transform.py index 1d213e26c2..a59675ce03 100644 --- a/dlt/extract/incremental/transform.py +++ b/dlt/extract/incremental/transform.py @@ -113,6 +113,19 @@ def __call__( row: TDataItem, ) -> Tuple[bool, bool, bool]: ... + @staticmethod + def _adapt_if_datetime(row_value: Any, last_value: Any) -> Any: + # For datetime cursor, ensure the value is a timezone aware datetime. + # The object saved in state will always be a tz aware pendulum datetime so this ensures values are comparable + if ( + isinstance(row_value, datetime) + and row_value.tzinfo is None + and isinstance(last_value, datetime) + and last_value.tzinfo is not None + ): + row_value = pendulum.instance(row_value).in_tz("UTC") + return row_value + @property def deduplication_disabled(self) -> bool: """Skip deduplication when length of the key is 0 or if lag is applied.""" @@ -185,19 +198,9 @@ def __call__( return None, False, False else: return row, False, False - last_value = self.last_value last_value_func = self.last_value_func - - # For datetime cursor, ensure the value is a timezone aware datetime. - # The object saved in state will always be a tz aware pendulum datetime so this ensures values are comparable - if ( - isinstance(row_value, datetime) - and row_value.tzinfo is None - and isinstance(last_value, datetime) - and last_value.tzinfo is not None - ): - row_value = pendulum.instance(row_value).in_tz("UTC") + row_value = self._adapt_if_datetime(row_value, last_value) # Check whether end_value has been reached # Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value @@ -354,13 +357,8 @@ def __call__( # TODO: Json path support. For now assume the cursor_path is a column name cursor_path = self.cursor_path - - # The new max/min value try: - # NOTE: datetimes are always pendulum in UTC - row_value = from_arrow_scalar(self.compute(tbl[cursor_path])) cursor_data_type = tbl.schema.field(cursor_path).type - row_value_scalar = to_arrow_scalar(row_value, cursor_data_type) except KeyError as e: raise IncrementalCursorPathMissing( self.resource_name, @@ -371,6 +369,12 @@ def __call__( " must be a column name.", ) from e + # The new max/min value + row_value_scalar = self.compute( + tbl[cursor_path] + ) # to_arrow_scalar(row_value, cursor_data_type) + row_value = self._adapt_if_datetime(from_arrow_scalar(row_value_scalar), self.last_value) + if tbl.schema.field(cursor_path).nullable: tbl_without_null, tbl_with_null = self._process_null_at_cursor_path(tbl) tbl = tbl_without_null diff --git a/tests/libs/pyarrow/test_pyarrow.py b/tests/libs/pyarrow/test_pyarrow.py index 07e8d3428d..bec4db2634 100644 --- a/tests/libs/pyarrow/test_pyarrow.py +++ b/tests/libs/pyarrow/test_pyarrow.py @@ -1,4 +1,4 @@ -from datetime import timezone, datetime, timedelta # noqa: I251 +from datetime import timezone, datetime, date, timedelta # noqa: I251 from copy import deepcopy from typing import List, Any @@ -109,25 +109,24 @@ def test_to_arrow_scalar() -> None: assert dt_converted == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc) -def test_from_arrow_scalar() -> None: +def test_arrow_type_coercion() -> None: + # coerce UTC python dt into naive arrow dt naive_dt = get_py_arrow_timestamp(6, tz=None) - sc_dt = to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32), naive_dt) - - # this value is like UTC - py_dt = from_arrow_scalar(sc_dt) - assert isinstance(py_dt, pendulum.DateTime) - # and we convert to explicit UTC - assert py_dt == datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone.utc) - - # converts to UTC - berlin_dt = get_py_arrow_timestamp(6, tz="Europe/Berlin") - sc_dt = to_arrow_scalar( - datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), berlin_dt - ) + sc_dt = to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone.utc), naive_dt) + # does not convert to pendulum py_dt = from_arrow_scalar(sc_dt) - assert isinstance(py_dt, pendulum.DateTime) - assert py_dt.tzname() == "UTC" - assert py_dt == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc) + assert not isinstance(py_dt, pendulum.DateTime) + assert isinstance(py_dt, datetime) + assert py_dt.tzname() is None + + # coerce datetime into date + py_date = pa.date32() + sc_date = to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone.utc), py_date) + assert from_arrow_scalar(sc_date) == date(2021, 1, 1) + + py_date = pa.date64() + sc_date = to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone.utc), py_date) + assert from_arrow_scalar(sc_date) == date(2021, 1, 1) def _row_at_index(table: pa.Table, index: int) -> List[Any]: diff --git a/tests/libs/test_parquet_writer.py b/tests/libs/test_parquet_writer.py index b6a25c5db5..406b72c8c4 100644 --- a/tests/libs/test_parquet_writer.py +++ b/tests/libs/test_parquet_writer.py @@ -12,7 +12,6 @@ from dlt.common.schema.utils import new_column from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.time import ensure_pendulum_datetime -from dlt.common.libs.pyarrow import from_arrow_scalar from tests.common.data_writers.utils import get_writer from tests.cases import (