diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index 41095f04c99d..44eb7323943b 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -14,6 +14,7 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis import NA +from ibis.backends.datafusion import registry from ibis.expr.operations.udf import InputType from ibis.formats.pyarrow import PyArrowType @@ -937,34 +938,30 @@ def extract_hour(op, **kw): @translate.register(ops.ExtractMillisecond) def extract_millisecond(op, **kw): - def ms(array: pa.Array) -> pa.Array: - return pc.cast(pc.millisecond(array), pa.int32()) - - extract_milliseconds_udf = df.udf( - ms, - input_types=[PyArrowType.from_ibis(op.arg.dtype)], - return_type=PyArrowType.from_ibis(op.dtype), - volatility="immutable", - name="extract_milliseconds_udf", - ) arg = translate(op.arg, **kw) - return extract_milliseconds_udf(arg) + + if op.arg.dtype.is_date(): + return registry.UDFS["extract_millisecond_time"](arg) + elif op.arg.dtype.is_timestamp(): + return registry.UDFS["extract_millisecond_timestamp"](arg) + else: + raise com.OperationNotDefinedError( + f"The function is not defined for {type(op.arg)}" + ) @translate.register(ops.ExtractSecond) def extract_second(op, **kw): - def s(array: pa.Array) -> pa.Array: - return pc.cast(pc.second(array), pa.int32()) - - extract_seconds_udf = df.udf( - s, - input_types=[PyArrowType.from_ibis(op.arg.dtype)], - return_type=PyArrowType.from_ibis(op.dtype), - volatility="immutable", - name="extract_seconds_udf", - ) arg = translate(op.arg, **kw) - return extract_seconds_udf(arg) + + if op.arg.dtype.is_date(): + return registry.UDFS["extract_second_time"](arg) + elif op.arg.dtype.is_timestamp(): + return registry.UDFS["extract_second_timestamp"](arg) + else: + raise com.OperationNotDefinedError( + f"The function is not defined for {type(op.arg)}" + ) @translate.register(ops.ExtractDayOfYear) @@ -981,27 +978,16 @@ def extract_day_of_the_week_index(op, **kw): @translate.register(ops.DayOfWeekName) def extract_down(op, **kw): - def down(array: pa.Array) -> pa.Array: - return pc.choose( - pc.day_of_week(array), - "Monday", - "Tuesday", - "Wednesday", - "Thursday", - "Friday", - "Saturday", - "Sunday", - ) - - extract_down_udf = df.udf( - down, - input_types=[PyArrowType.from_ibis(op.arg.dtype)], - return_type=PyArrowType.from_ibis(op.dtype), - volatility="immutable", - name="extract_down_udf", - ) arg = translate(op.arg, **kw) - return extract_down_udf(arg) + + if op.arg.dtype.is_date(): + return registry.UDFS["extract_down_date"](arg) + elif op.arg.dtype.is_timestamp(): + return registry.UDFS["extract_down_timestamp"](arg) + else: + raise com.OperationNotDefinedError( + f"The function is not defined for {type(op.arg)}" + ) @translate.register(ops.Date) @@ -1018,32 +1004,27 @@ def extract_week_of_year(op, **kw): @translate.register(ops.ExtractMicrosecond) def extract_microsecond(op, **kw): - def us(array: pa.Array) -> pa.Array: - arr = pc.multiply(pc.millisecond(array), 1000) - return pc.cast(pc.add(pc.microsecond(array), arr), pa.int32()) - - extract_microseconds_udf = df.udf( - us, - input_types=[PyArrowType.from_ibis(op.arg.dtype)], - return_type=PyArrowType.from_ibis(op.dtype), - volatility="immutable", - name="extract_microseconds_udf", - ) arg = translate(op.arg, **kw) - return extract_microseconds_udf(arg) + + if op.arg.dtype.is_time(): + return registry.UDFS["extract_microsecond_time"](arg) + elif op.arg.dtype.is_timestamp(): + return registry.UDFS["extract_microsecond_timestamp"](arg) + else: + raise com.OperationNotDefinedError( + f"The function is not defined for {type(op.arg)}" + ) @translate.register(ops.ExtractEpochSeconds) def extract_epoch_seconds(op, **kw): - def epoch_seconds(array: pa.Array) -> pa.Array: - return pc.cast(pc.divide(pc.cast(array, pa.int64()), 1000_000), pa.int32()) - - extract_epoch_seconds_udf = df.udf( - epoch_seconds, - input_types=[PyArrowType.from_ibis(op.arg.dtype)], - return_type=PyArrowType.from_ibis(op.dtype), - volatility="immutable", - name="extract_epoch_seconds_udf", - ) arg = translate(op.arg, **kw) - return extract_epoch_seconds_udf(arg) + + if op.arg.dtype.is_time(): + return registry.UDFS["extract_epoch_seconds_time"](arg) + elif op.arg.dtype.is_timestamp(): + return registry.UDFS["extract_epoch_seconds_timestamp"](arg) + else: + raise com.OperationNotDefinedError( + f"The function is not defined for {type(op.arg)}" + ) diff --git a/ibis/backends/datafusion/registry.py b/ibis/backends/datafusion/registry.py new file mode 100644 index 000000000000..7b152d827b5a --- /dev/null +++ b/ibis/backends/datafusion/registry.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import datafusion as df +import pyarrow as pa +import pyarrow.compute as pc + +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops +from ibis.expr.decompile import _to_snake_case +from ibis.formats.pyarrow import PyArrowType + + +def create_udf(op, udf, input_types, volatility="immutable", name=None): + return df.udf( + udf, + input_types=list(map(PyArrowType.from_ibis, input_types)), + return_type=PyArrowType.from_ibis(op.dtype), + volatility=volatility, + name=_to_snake_case(op.__name__) if name is None else name, + ) + + +def extract_microsecond(array: pa.Array) -> pa.Array: + arr = pc.multiply(pc.millisecond(array), 1000) + return pc.cast(pc.add(pc.microsecond(array), arr), pa.int32()) + + +def epoch_seconds(array: pa.Array) -> pa.Array: + return pc.cast(pc.divide(pc.cast(array, pa.int64()), 1000_000), pa.int32()) + + +def extract_down(array: pa.Array) -> pa.Array: + return pc.choose( + pc.day_of_week(array), + "Monday", + "Tuesday", + "Wednesday", + "Thursday", + "Friday", + "Saturday", + "Sunday", + ) + + +def extract_second(array: pa.Array) -> pa.Array: + return pc.cast(pc.second(array), pa.int32()) + + +def extract_millisecond(array: pa.Array) -> pa.Array: + return pc.cast(pc.millisecond(array), pa.int32()) + + +UDFS = { + "extract_microseconds_time": create_udf( + ops.ExtractMicrosecond, + extract_microsecond, + input_types=[dt.time], + name="extract_microseconds_time", + ), + "extract_microsecond_timestamp": create_udf( + ops.ExtractMicrosecond, + extract_microsecond, + input_types=[dt.timestamp], + name="extract_microseconds_timestamp", + ), + "extract_epoch_seconds_time": create_udf( + ops.ExtractEpochSeconds, + epoch_seconds, + input_types=[dt.time], + name="extract_epoch_seconds_time", + ), + "extract_epoch_seconds_timestamp": create_udf( + ops.ExtractEpochSeconds, + epoch_seconds, + input_types=[dt.timestamp], + name="extract_epoch_seconds_timestamp", + ), + "extract_down_date": create_udf( + ops.DayOfWeekName, + extract_down, + input_types=[dt.date], + name="extract_down_date", + ), + "extract_down_timestamp": create_udf( + ops.DayOfWeekName, + extract_down, + input_types=[dt.timestamp], + name="extract_down_timestamp", + ), + "extract_second_time": create_udf( + ops.ExtractSecond, + extract_second, + input_types=[dt.time], + name="extract_second_time", + ), + "extract_second_timestamp": create_udf( + ops.ExtractSecond, + extract_second, + input_types=[dt.timestamp], + name="extract_second_timestamp", + ), + "extract_millisecond_time": create_udf( + ops.ExtractMillisecond, + extract_millisecond, + input_types=[dt.time], + name="extract_millisecond_time", + ), + "extract_millisecond_timestamp": create_udf( + ops.ExtractMillisecond, + extract_millisecond, + input_types=[dt.timestamp], + name="extract_millisecond_timestamp", + ), +}