Skip to content

Commit

Permalink
refactor(datafusion): create registry of time udfs to create them onl…
Browse files Browse the repository at this point in the history
…y once
  • Loading branch information
mesejo authored and cpcloud committed Oct 11, 2023
1 parent 70df318 commit 9ed0a89
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 65 deletions.
111 changes: 46 additions & 65 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import NA
from ibis.backends.datafusion import registry
from ibis.expr.operations.udf import InputType
from ibis.formats.pyarrow import PyArrowType

Expand Down Expand Up @@ -937,34 +938,30 @@ def extract_hour(op, **kw):

@translate.register(ops.ExtractMillisecond)
def extract_millisecond(op, **kw):
def ms(array: pa.Array) -> pa.Array:
return pc.cast(pc.millisecond(array), pa.int32())

extract_milliseconds_udf = df.udf(
ms,
input_types=[PyArrowType.from_ibis(op.arg.dtype)],
return_type=PyArrowType.from_ibis(op.dtype),
volatility="immutable",
name="extract_milliseconds_udf",
)
arg = translate(op.arg, **kw)
return extract_milliseconds_udf(arg)

if op.arg.dtype.is_date():
return registry.UDFS["extract_millisecond_time"](arg)
elif op.arg.dtype.is_timestamp():
return registry.UDFS["extract_millisecond_timestamp"](arg)
else:
raise com.OperationNotDefinedError(
f"The function is not defined for {type(op.arg)}"
)


@translate.register(ops.ExtractSecond)
def extract_second(op, **kw):
def s(array: pa.Array) -> pa.Array:
return pc.cast(pc.second(array), pa.int32())

extract_seconds_udf = df.udf(
s,
input_types=[PyArrowType.from_ibis(op.arg.dtype)],
return_type=PyArrowType.from_ibis(op.dtype),
volatility="immutable",
name="extract_seconds_udf",
)
arg = translate(op.arg, **kw)
return extract_seconds_udf(arg)

if op.arg.dtype.is_date():
return registry.UDFS["extract_second_time"](arg)
elif op.arg.dtype.is_timestamp():
return registry.UDFS["extract_second_timestamp"](arg)
else:
raise com.OperationNotDefinedError(
f"The function is not defined for {type(op.arg)}"
)


@translate.register(ops.ExtractDayOfYear)
Expand All @@ -981,27 +978,16 @@ def extract_day_of_the_week_index(op, **kw):

@translate.register(ops.DayOfWeekName)
def extract_down(op, **kw):
def down(array: pa.Array) -> pa.Array:
return pc.choose(
pc.day_of_week(array),
"Monday",
"Tuesday",
"Wednesday",
"Thursday",
"Friday",
"Saturday",
"Sunday",
)

extract_down_udf = df.udf(
down,
input_types=[PyArrowType.from_ibis(op.arg.dtype)],
return_type=PyArrowType.from_ibis(op.dtype),
volatility="immutable",
name="extract_down_udf",
)
arg = translate(op.arg, **kw)
return extract_down_udf(arg)

if op.arg.dtype.is_date():
return registry.UDFS["extract_down_date"](arg)
elif op.arg.dtype.is_timestamp():
return registry.UDFS["extract_down_timestamp"](arg)
else:
raise com.OperationNotDefinedError(
f"The function is not defined for {type(op.arg)}"
)


@translate.register(ops.Date)
Expand All @@ -1018,32 +1004,27 @@ def extract_week_of_year(op, **kw):

@translate.register(ops.ExtractMicrosecond)
def extract_microsecond(op, **kw):
def us(array: pa.Array) -> pa.Array:
arr = pc.multiply(pc.millisecond(array), 1000)
return pc.cast(pc.add(pc.microsecond(array), arr), pa.int32())

extract_microseconds_udf = df.udf(
us,
input_types=[PyArrowType.from_ibis(op.arg.dtype)],
return_type=PyArrowType.from_ibis(op.dtype),
volatility="immutable",
name="extract_microseconds_udf",
)
arg = translate(op.arg, **kw)
return extract_microseconds_udf(arg)

if op.arg.dtype.is_time():
return registry.UDFS["extract_microsecond_time"](arg)
elif op.arg.dtype.is_timestamp():
return registry.UDFS["extract_microsecond_timestamp"](arg)
else:
raise com.OperationNotDefinedError(
f"The function is not defined for {type(op.arg)}"
)


@translate.register(ops.ExtractEpochSeconds)
def extract_epoch_seconds(op, **kw):
def epoch_seconds(array: pa.Array) -> pa.Array:
return pc.cast(pc.divide(pc.cast(array, pa.int64()), 1000_000), pa.int32())

extract_epoch_seconds_udf = df.udf(
epoch_seconds,
input_types=[PyArrowType.from_ibis(op.arg.dtype)],
return_type=PyArrowType.from_ibis(op.dtype),
volatility="immutable",
name="extract_epoch_seconds_udf",
)
arg = translate(op.arg, **kw)
return extract_epoch_seconds_udf(arg)

if op.arg.dtype.is_time():
return registry.UDFS["extract_epoch_seconds_time"](arg)
elif op.arg.dtype.is_timestamp():
return registry.UDFS["extract_epoch_seconds_timestamp"](arg)
else:
raise com.OperationNotDefinedError(
f"The function is not defined for {type(op.arg)}"
)
114 changes: 114 additions & 0 deletions ibis/backends/datafusion/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from __future__ import annotations

import datafusion as df
import pyarrow as pa
import pyarrow.compute as pc

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.expr.decompile import _to_snake_case
from ibis.formats.pyarrow import PyArrowType


def create_udf(op, udf, input_types, volatility="immutable", name=None):
return df.udf(
udf,
input_types=list(map(PyArrowType.from_ibis, input_types)),
return_type=PyArrowType.from_ibis(op.dtype),
volatility=volatility,
name=_to_snake_case(op.__name__) if name is None else name,
)


def extract_microsecond(array: pa.Array) -> pa.Array:
arr = pc.multiply(pc.millisecond(array), 1000)
return pc.cast(pc.add(pc.microsecond(array), arr), pa.int32())


def epoch_seconds(array: pa.Array) -> pa.Array:
return pc.cast(pc.divide(pc.cast(array, pa.int64()), 1000_000), pa.int32())


def extract_down(array: pa.Array) -> pa.Array:
return pc.choose(
pc.day_of_week(array),
"Monday",
"Tuesday",
"Wednesday",
"Thursday",
"Friday",
"Saturday",
"Sunday",
)


def extract_second(array: pa.Array) -> pa.Array:
return pc.cast(pc.second(array), pa.int32())


def extract_millisecond(array: pa.Array) -> pa.Array:
return pc.cast(pc.millisecond(array), pa.int32())


UDFS = {
"extract_microseconds_time": create_udf(
ops.ExtractMicrosecond,
extract_microsecond,
input_types=[dt.time],
name="extract_microseconds_time",
),
"extract_microsecond_timestamp": create_udf(
ops.ExtractMicrosecond,
extract_microsecond,
input_types=[dt.timestamp],
name="extract_microseconds_timestamp",
),
"extract_epoch_seconds_time": create_udf(
ops.ExtractEpochSeconds,
epoch_seconds,
input_types=[dt.time],
name="extract_epoch_seconds_time",
),
"extract_epoch_seconds_timestamp": create_udf(
ops.ExtractEpochSeconds,
epoch_seconds,
input_types=[dt.timestamp],
name="extract_epoch_seconds_timestamp",
),
"extract_down_date": create_udf(
ops.DayOfWeekName,
extract_down,
input_types=[dt.date],
name="extract_down_date",
),
"extract_down_timestamp": create_udf(
ops.DayOfWeekName,
extract_down,
input_types=[dt.timestamp],
name="extract_down_timestamp",
),
"extract_second_time": create_udf(
ops.ExtractSecond,
extract_second,
input_types=[dt.time],
name="extract_second_time",
),
"extract_second_timestamp": create_udf(
ops.ExtractSecond,
extract_second,
input_types=[dt.timestamp],
name="extract_second_timestamp",
),
"extract_millisecond_time": create_udf(
ops.ExtractMillisecond,
extract_millisecond,
input_types=[dt.time],
name="extract_millisecond_time",
),
"extract_millisecond_timestamp": create_udf(
ops.ExtractMillisecond,
extract_millisecond,
input_types=[dt.timestamp],
name="extract_millisecond_timestamp",
),
}

0 comments on commit 9ed0a89

Please sign in to comment.