feat: Add DataType inference from Python types (#3555)
Closes #3549

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
jaychia and Jay Chia authored Dec 12, 2024
1 parent 6ae4e77 commit da6f499
Showing 4 changed files with 111 additions and 10 deletions.
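
In short: APIs that take a DataType (Expression.cast, Expression.apply, and the @udf decorator) now also accept plain Python types, which are converted via the new DataType._infer_type. A minimal usage sketch of the user-facing difference (from_pydict, col, and select are existing daft APIs assumed here, not part of this diff):

import daft
from daft import col
from daft.datatype import DataType

df = daft.from_pydict({"a": ["1", "2", "3"]})

# Previously, a DataType instance was required:
df.select(col("a").cast(DataType.int64()))

# With this commit, a plain Python type is inferred to the equivalent DataType:
df.select(col("a").cast(int))  # inferred as DataType.int64()
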
40 changes: 39 additions & 1 deletion daft/datatype.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union

from daft.context import get_context
from daft.daft import ImageMode, PyDataType, PyTimeUnit
@@ -83,6 +83,40 @@ def __init__(self) -> None:
"use a creator method like DataType.int32() or use DataType.from_arrow_type(pa_type)"
)

@classmethod
def _infer_type(cls, user_provided_type: DataTypeLike) -> DataType:
from typing import get_args, get_origin

if isinstance(user_provided_type, DataType):
return user_provided_type
elif isinstance(user_provided_type, dict):
return DataType.struct({k: DataType._infer_type(user_provided_type[k]) for k in user_provided_type})
elif get_origin(user_provided_type) is not None:
origin_type = get_origin(user_provided_type)
if origin_type is list:
child_type = get_args(user_provided_type)[0]
return DataType.list(DataType._infer_type(child_type))
elif origin_type is dict:
(key_type, val_type) = get_args(user_provided_type)
return DataType.map(DataType._infer_type(key_type), DataType._infer_type(val_type))
else:
raise ValueError(f"Unrecognized Python origin type, cannot convert to Daft type: {origin_type}")
elif isinstance(user_provided_type, type):
if user_provided_type is str:
return DataType.string()
elif user_provided_type is int:
return DataType.int64()
elif user_provided_type is float:
return DataType.float64()
elif user_provided_type is bytes:
return DataType.binary()
elif user_provided_type is object:
return DataType.python()
else:
raise ValueError(f"Unrecognized Python type, cannot convert to Daft type: {user_provided_type}")
else:
raise ValueError(f"Unable to infer Daft DataType for provided value: {user_provided_type}")

@staticmethod
def _from_pydatatype(pydt: PyDataType) -> DataType:
dt = DataType.__new__(DataType)
@@ -538,6 +572,10 @@ def __hash__(self) -> int:
return self._dtype.__hash__()


# Type alias for a union of types that can be inferred into a DataType
DataTypeLike = Union[DataType, type]


_EXT_TYPE_REGISTERED = False
_STATIC_DAFT_EXTENSION = None

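Based on the _infer_type branches above, the inference rules map Python types to Daft types roughly as follows (a sketch mirroring the new tests; note that a dict instance becomes a struct, while a subscripted dict[...] becomes a map):

from daft.datatype import DataType

# Plain builtins map to fixed-width Daft types; object falls back to a Python object type.
assert DataType._infer_type(str) == DataType.string()
assert DataType._infer_type(int) == DataType.int64()
assert DataType._infer_type(float) == DataType.float64()
assert DataType._infer_type(bytes) == DataType.binary()
assert DataType._infer_type(object) == DataType.python()

# A dict instance is treated as a struct schema, inferring each field recursively.
assert DataType._infer_type({"id": int, "name": str}) == DataType.struct(
    {"id": DataType.int64(), "name": DataType.string()}
)

# Subscripted generics (Python 3.9+ builtins or typing.List/Dict) become list/map types.
assert DataType._infer_type(list[str]) == DataType.list(DataType.string())
assert DataType._infer_type(dict[str, int]) == DataType.map(DataType.string(), DataType.int64())
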
16 changes: 10 additions & 6 deletions daft/expressions/expressions.py
@@ -36,7 +36,7 @@
from daft.daft import udf as _udf
from daft.daft import url_download as _url_download
from daft.daft import utf8_count_matches as _utf8_count_matches
from daft.datatype import DataType, TimeUnit
from daft.datatype import DataType, DataTypeLike, TimeUnit
from daft.dependencies import pa
from daft.expressions.testing import expr_structurally_equal
from daft.logical.schema import Field, Schema
@@ -542,7 +542,7 @@ def alias(self, name: builtins.str) -> Expression:
expr = self._expr.alias(name)
return Expression._from_pyexpr(expr)

def cast(self, dtype: DataType) -> Expression:
def cast(self, dtype: DataTypeLike) -> Expression:
"""Casts an expression to the given datatype if possible.
The following combinations of datatype casting are valid:
@@ -622,8 +622,10 @@ def cast(self, dtype: DataType) -> Expression:
Returns:
Expression: Expression with the specified new datatype
"""
assert isinstance(dtype, DataType)
expr = self._expr.cast(dtype._dtype)
assert isinstance(dtype, (DataType, type))
inferred_dtype = DataType._infer_type(dtype)

expr = self._expr.cast(inferred_dtype._dtype)
return Expression._from_pyexpr(expr)

def ceil(self) -> Expression:
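
With the signature loosened to DataTypeLike, cast accepts anything _infer_type understands. A short sketch (from_pydict and with_column are existing daft APIs assumed here, not part of this diff):

import daft
from daft import col
from daft.datatype import DataType

df = daft.from_pydict({"price": ["1.5", "2.25"]})

# Both forms should be equivalent after this change:
df = df.with_column("price_f64", col("price").cast(float))               # inferred as DataType.float64()
df = df.with_column("price_str", col("price").cast(DataType.string()))   # passing a DataType still works
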
@@ -999,7 +1001,7 @@ def if_else(self, if_true: Expression, if_false: Expression) -> Expression:
if_false = Expression._to_expression(if_false)
return Expression._from_pyexpr(self._expr.if_else(if_true._expr, if_false._expr))

def apply(self, func: Callable, return_dtype: DataType) -> Expression:
def apply(self, func: Callable, return_dtype: DataTypeLike) -> Expression:
"""Apply a function on each value in a given expression.
.. NOTE::
@@ -1039,6 +1041,8 @@ def apply(self, func: Callable, return_dtype: DataType) -> Expression:
"""
from daft.udf import UDF

inferred_return_dtype = DataType._infer_type(return_dtype)

def batch_func(self_series):
return [func(x) for x in self_series.to_pylist()]

@@ -1050,7 +1054,7 @@ def batch_func(self_series):
return UDF(
inner=batch_func,
name=name,
return_dtype=return_dtype,
return_dtype=inferred_return_dtype,
)(self)

def is_null(self) -> Expression:
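apply gets the same treatment: return_dtype may now be a plain Python type, and it is run through _infer_type before being handed to the underlying UDF machinery. A sketch (again assuming daft's from_pydict and with_column APIs):

import daft
from daft import col

df = daft.from_pydict({"name": ["alice", "bob"]})

# return_dtype=str is inferred as DataType.string()
df = df.with_column("shouted", col("name").apply(lambda s: s.upper() + "!", return_dtype=str))
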
7 changes: 4 additions & 3 deletions daft/udf.py
@@ -6,7 +6,7 @@
from typing import Any, Callable, Dict, Optional, Tuple, Union

from daft.daft import PyDataType, ResourceRequest
from daft.datatype import DataType
from daft.datatype import DataType, DataTypeLike
from daft.dependencies import np, pa
from daft.expressions import Expression
from daft.series import PySeries, Series
@@ -394,7 +394,7 @@ def __hash__(self) -> int:

def udf(
*,
return_dtype: DataType,
return_dtype: DataTypeLike,
num_cpus: float | None = None,
num_gpus: float | None = None,
memory_bytes: int | None = None,
@@ -511,6 +511,7 @@ def udf(
Returns:
Callable[[UserDefinedPyFuncLike], UDF]: UDF decorator - converts a user-provided Python function into a UDF that can be called on Expressions
"""
inferred_return_dtype = DataType._infer_type(return_dtype)

def _udf(f: UserDefinedPyFuncLike) -> UDF:
# Grab a name for the UDF. It **should** be unique.
@@ -534,7 +535,7 @@ def _udf(f: UserDefinedPyFuncLike) -> UDF:
return UDF(
inner=f,
name=name,
return_dtype=return_dtype,
return_dtype=inferred_return_dtype,
resource_request=resource_request,
batch_size=batch_size,
)
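The @udf decorator's return_dtype flows through the same inference, so subscripted generics work as well. A sketch assuming Python 3.9+ for the list[str] syntax and daft's top-level udf export:

import daft
from daft import col, udf

@udf(return_dtype=list[str])
def repeat_twice(texts):
    # The UDF receives a Series; return one list of strings per input row.
    return [[t, t] for t in texts.to_pylist()]

df = daft.from_pydict({"text": ["a", "b"]})
df = df.with_column("repeated", repeat_twice(col("text")))
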
58 changes: 58 additions & 0 deletions tests/test_datatypes.py
@@ -1,6 +1,8 @@
from __future__ import annotations

import copy
import sys
from typing import Dict, List

import pytest

@@ -30,3 +32,59 @@
def test_datatype_pickling(dtype) -> None:
copy_dtype = copy.deepcopy(dtype)
assert copy_dtype == dtype


@pytest.mark.parametrize(
["source", "expected"],
[
(str, DataType.string()),
(int, DataType.int64()),
(float, DataType.float64()),
(bytes, DataType.binary()),
(object, DataType.python()),
(
{"foo": str, "bar": int},
DataType.struct({"foo": DataType.string(), "bar": DataType.int64()}),
),
],
)
def test_datatype_parsing(source, expected):
assert DataType._infer_type(source) == expected


# These tests are only valid for more modern versions of Python, but can't be skipped in the conventional
# way either, because we cannot even run the subscripting at import time
if sys.version_info >= (3, 9):

@pytest.mark.parametrize(
["source", "expected"],
[
# These tests must be run on later versions of Python that allow for subscripting of types
(list[str], DataType.list(DataType.string())),
(dict[str, int], DataType.map(DataType.string(), DataType.int64())),
(
{"foo": list[str], "bar": int},
DataType.struct({"foo": DataType.list(DataType.string()), "bar": DataType.int64()}),
),
(list[list[str]], DataType.list(DataType.list(DataType.string()))),
],
)
def test_subscripted_datatype_parsing(source, expected):
assert DataType._infer_type(source) == expected


@pytest.mark.parametrize(
["source", "expected"],
[
# These tests must be run on later versions of Python that allow for subscripting of types
(List[str], DataType.list(DataType.string())),
(Dict[str, int], DataType.map(DataType.string(), DataType.int64())),
(
{"foo": List[str], "bar": int},
DataType.struct({"foo": DataType.list(DataType.string()), "bar": DataType.int64()}),
),
(List[List[str]], DataType.list(DataType.list(DataType.string()))),
],
)
def test_legacy_subscripted_datatype_parsing(source, expected):
assert DataType._infer_type(source) == expected
