apacheGH-35040: [Python] Pyarrow scalar cast should use compute kernel (apache#35395)

### Rationale for this change

Scalar cast should use the compute kernel, just like Array casts do, instead of its own custom implementation.
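
A minimal sketch (not taken from the patch itself) of the behavior this unifies, using the timestamp-to-string cast from GH-35370:

```python
import datetime
import pyarrow as pa
import pyarrow.compute as pc

ts = pa.scalar(datetime.datetime(2000, 1, 1), type=pa.timestamp("ns"))

# With this change, Scalar.cast routes through the same compute kernel
# as pc.cast, so the two calls below are equivalent.
assert ts.cast(pa.string()) == pc.cast(ts, pa.string())

# Arrays already took the kernel path; scalars now match that behavior.
arr = pa.array([datetime.datetime(2000, 1, 1)], type=pa.timestamp("ns"))
assert arr.cast(pa.string()) == pc.cast(arr, pa.string())
```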

### Are these changes tested?

Added test cases for apacheGH-35370, apacheGH-34901, and apacheGH-35040

### Are there any user-facing changes?

The Scalar.cast() API is enhanced and remains backwards compatible; see the usage sketch after the issue link below.
* Closes: apache#35040
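
Illustrative only (these exact calls are not part of the diff), assuming the new signature shown in `python/pyarrow/scalar.pxi` below:

```python
import pyarrow as pa
import pyarrow.compute as pc

s = pa.scalar(1.5, type=pa.float64())

# Existing single-argument calls keep working unchanged.
s.cast(pa.float32())

# The new keyword arguments mirror pyarrow.compute.cast: an unsafe cast
# truncates instead of raising ArrowInvalid.
assert s.cast(pa.int64(), safe=False) == pa.scalar(1, type=pa.int64())

# CastOptions can also be passed explicitly instead of a target type.
opts = pc.CastOptions(target_type=pa.int64(), allow_float_truncate=True)
assert s.cast(options=opts) == pa.scalar(1, type=pa.int64())
```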

Authored-by: Dane Pitkin <[email protected]>
Signed-off-by: Alenka Frim <[email protected]>
danepitkin authored and ArgusLi committed May 15, 2023
1 parent e1316ea commit 71306b7
Showing 2 changed files with 43 additions and 12 deletions.
26 changes: 14 additions & 12 deletions python/pyarrow/scalar.pxi
@@ -67,27 +67,29 @@ cdef class Scalar(_Weakrefable):
"""
return self.wrapped.get().is_valid

def cast(self, object target_type):
def cast(self, object target_type=None, safe=None, options=None, memory_pool=None):
"""
Attempt a safe cast to target data type.
Cast scalar value to another data type.
See :func:`pyarrow.compute.cast` for usage.
Parameters
----------
target_type : DataType or string coercible to DataType
The type to cast the scalar to.
target_type : DataType, default None
Type to cast scalar to.
safe : boolean, default True
Whether to check for conversion errors such as overflow.
options : CastOptions, default None
Additional checks passed by CastOptions
memory_pool : MemoryPool, optional
memory pool to use for allocations during function execution.
Returns
-------
scalar : A Scalar of the given target data type.
"""
cdef:
DataType type = ensure_type(target_type)
shared_ptr[CScalar] result

with nogil:
result = GetResultValue(self.wrapped.get().CastTo(type.sp_type))

return Scalar.wrap(result)
return _pc().cast(self, target_type, safe=safe,
options=options, memory_pool=memory_pool)

def validate(self, *, full=False):
"""
29 changes: 29 additions & 0 deletions python/pyarrow/tests/test_scalars.py
@@ -295,6 +295,35 @@ def test_cast():
pa.scalar('foo').cast('int32')


def test_cast_timestamp_to_string():
# GH-35370
pytest.importorskip("pytz")
import pytz
dt = datetime.datetime(2000, 1, 1, 0, 0, 0, tzinfo=pytz.utc)
ts = pa.scalar(dt, type=pa.timestamp("ns", tz="UTC"))
assert ts.cast(pa.string()) == pa.scalar('2000-01-01 00:00:00.000000000Z')


def test_cast_float_to_int():
# GH-35040
float_scalar = pa.scalar(1.5, type=pa.float64())
unsafe_cast = float_scalar.cast(pa.int64(), safe=False)
expected_unsafe_cast = pa.scalar(1, type=pa.int64())
assert unsafe_cast == expected_unsafe_cast
with pytest.raises(pa.ArrowInvalid):
float_scalar.cast(pa.int64()) # verify default is safe cast


def test_cast_int_to_float():
# GH-34901
int_scalar = pa.scalar(18014398509481983, type=pa.int64())
unsafe_cast = int_scalar.cast(pa.float64(), safe=False)
expected_unsafe_cast = pa.scalar(18014398509481983.0, type=pa.float64())
assert unsafe_cast == expected_unsafe_cast
with pytest.raises(pa.ArrowInvalid):
int_scalar.cast(pa.float64()) # verify default is safe cast


@pytest.mark.pandas
def test_timestamp():
import pandas as pd
