From 053b5eecb582a93bc92e6793768fc66d2c5d5046 Mon Sep 17 00:00:00 2001
From: Dane Pitkin <48041712+danepitkin@users.noreply.github.com>
Date: Thu, 11 May 2023 14:26:01 -0400
Subject: [PATCH] GH-35040: [Python] Pyarrow scalar cast should use compute kernel (#35395)

### Rationale for this change

Scalar cast should use the compute kernel, just like Array cast does, instead of its own custom implementation.

### Are these changes tested?

Added test cases for GH-35370, GH-34901, and GH-35040.

### Are there any user-facing changes?

The Scalar.cast() API is extended (new `safe`, `options`, and `memory_pool` arguments) and remains backwards compatible.
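A minimal usage sketch of the extended API (illustrative only, not part of the patch; it assumes a pyarrow build that includes this change and mirrors the behavior exercised in the new tests):

```python
import pyarrow as pa

s = pa.scalar(1.5, type=pa.float64())

# Existing usage keeps working: a plain, safe cast.
s.cast(pa.float32())

# New keyword arguments are forwarded to pyarrow.compute.cast(), e.g. an
# unsafe cast that truncates 1.5 instead of raising ArrowInvalid.
s.cast(pa.int64(), safe=False)  # -> <pyarrow.Int64Scalar: 1>
```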
""" - cdef: - DataType type = ensure_type(target_type) - shared_ptr[CScalar] result - - with nogil: - result = GetResultValue(self.wrapped.get().CastTo(type.sp_type)) - - return Scalar.wrap(result) + return _pc().cast(self, target_type, safe=safe, + options=options, memory_pool=memory_pool) def validate(self, *, full=False): """ diff --git a/python/pyarrow/tests/test_scalars.py b/python/pyarrow/tests/test_scalars.py index 7b00acd07242d..ca2d29e5dac25 100644 --- a/python/pyarrow/tests/test_scalars.py +++ b/python/pyarrow/tests/test_scalars.py @@ -295,6 +295,35 @@ def test_cast(): pa.scalar('foo').cast('int32') +def test_cast_timestamp_to_string(): + # GH-35370 + pytest.importorskip("pytz") + import pytz + dt = datetime.datetime(2000, 1, 1, 0, 0, 0, tzinfo=pytz.utc) + ts = pa.scalar(dt, type=pa.timestamp("ns", tz="UTC")) + assert ts.cast(pa.string()) == pa.scalar('2000-01-01 00:00:00.000000000Z') + + +def test_cast_float_to_int(): + # GH-35040 + float_scalar = pa.scalar(1.5, type=pa.float64()) + unsafe_cast = float_scalar.cast(pa.int64(), safe=False) + expected_unsafe_cast = pa.scalar(1, type=pa.int64()) + assert unsafe_cast == expected_unsafe_cast + with pytest.raises(pa.ArrowInvalid): + float_scalar.cast(pa.int64()) # verify default is safe cast + + +def test_cast_int_to_float(): + # GH-34901 + int_scalar = pa.scalar(18014398509481983, type=pa.int64()) + unsafe_cast = int_scalar.cast(pa.float64(), safe=False) + expected_unsafe_cast = pa.scalar(18014398509481983.0, type=pa.float64()) + assert unsafe_cast == expected_unsafe_cast + with pytest.raises(pa.ArrowInvalid): + int_scalar.cast(pa.float64()) # verify default is safe cast + + @pytest.mark.pandas def test_timestamp(): import pandas as pd