
Commit

Fix slow infer type (mlflow#13912)
Signed-off-by: serena-ruan <[email protected]>
serena-ruan authored Nov 29, 2024
1 parent fabfcf3 commit 4800875
Showing 4 changed files with 104 additions and 108 deletions.
76 changes: 29 additions & 47 deletions mlflow/types/schema.py
@@ -2,7 +2,6 @@

import builtins
import datetime as dt
import importlib.util
import json
import string
from abc import ABC, abstractmethod
@@ -28,6 +27,13 @@
"or str for the '{arg_name}' argument, but got {passed_type}"
)

try:
import pyspark # noqa: F401

HAS_PYSPARK = True
except ImportError:
HAS_PYSPARK = False

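The hunk above replaces the per-call `importlib.util.find_spec("pyspark")` probe with a one-time import guard stored in `HAS_PYSPARK`. A minimal sketch of why this matters on hot code paths, assuming only the standard library (illustrative, not repository code):

```python
# Rough cost comparison: find_spec walks sys.meta_path finders on every call,
# while reading a precomputed module-level flag is a plain name lookup.
import importlib.util
import timeit

per_call = timeit.timeit(
    "importlib.util.find_spec('pyspark') is not None",
    setup="import importlib.util",
    number=10_000,
)

HAS_PYSPARK = importlib.util.find_spec("pyspark") is not None  # computed once at import
cached = timeit.timeit("HAS_PYSPARK", globals={"HAS_PYSPARK": HAS_PYSPARK}, number=10_000)

print(f"find_spec per call: {per_call:.4f}s  vs  cached flag: {cached:.4f}s")
```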

class DataType(Enum):
"""
@@ -96,48 +102,21 @@ def to_python(self):
return self._python_type

@classmethod
def is_boolean(cls, value):
return type(value) in DataType.boolean.get_all_types()

@classmethod
def is_integer(cls, value):
return type(value) in DataType.integer.get_all_types()

@classmethod
def is_long(cls, value):
return type(value) in DataType.long.get_all_types()

@classmethod
def is_float(cls, value):
return type(value) in DataType.float.get_all_types()

@classmethod
def is_double(cls, value):
return type(value) in DataType.double.get_all_types()

@classmethod
def is_string(cls, value):
return type(value) in DataType.string.get_all_types()

@classmethod
def is_binary(cls, value):
return type(value) in DataType.binary.get_all_types()
def check_type(cls, data_type, value):
types = [data_type.to_numpy(), data_type.to_pandas(), data_type.to_python()]
if data_type.name == "datetime":
types.extend([np.datetime64, dt.datetime])
if data_type.name == "binary":
types.append(bytearray)
if type(value) in types:
return True
if HAS_PYSPARK:
return isinstance(value, type(data_type.to_spark()))
return False

@classmethod
def is_datetime(cls, value):
return type(value) in DataType.datetime.get_all_types()

def get_all_types(self):
types = [self.to_numpy(), self.to_pandas(), self.to_python()]
if importlib.util.find_spec("pyspark") is not None:
types.append(self.to_spark())
if self.name == "datetime":
types.extend([np.datetime64, dt.datetime])
if self.name == "binary":
# This is to support identifying bytearrays as binary data
# for pandas DataFrame schema inference
types.extend([bytearray])
return types
def all_types(cls):
return list(DataType.__members__.values())

@classmethod
def get_spark_types(cls):
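Taken together, this hunk folds the seven `is_<type>()` classmethods and the per-call `get_all_types()` helper into a single `check_type(data_type, value)` classmethod (with the pyspark comparison gated behind `HAS_PYSPARK`) plus an `all_types()` enumerator. A hedged usage sketch, assuming this commit of mlflow and numpy are installed:

```python
import numpy as np
from mlflow.types import DataType

assert DataType.check_type(DataType.long, 1)                # builtin int maps to long
assert DataType.check_type(DataType.integer, np.int32(1))   # exact numpy type match
assert not DataType.check_type(DataType.integer, True)      # bool is not treated as an integer
assert DataType.datetime in DataType.all_types()            # all_types() lists the enum members
```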
@@ -1225,19 +1204,22 @@ def enforce_param_datatype(cls, name, value, dtype: DataType):
)

# Always convert to python native type for params
if getattr(DataType, f"is_{dtype.name}")(value):
return DataType[dtype.name].to_python()(value)
if DataType.check_type(dtype, value):
return dtype.to_python()(value)

if (
(
DataType.is_integer(value)
DataType.check_type(DataType.integer, value)
and dtype in (DataType.long, DataType.float, DataType.double)
)
or (DataType.is_long(value) and dtype in (DataType.float, DataType.double))
or (DataType.is_float(value) and dtype == DataType.double)
or (
DataType.check_type(DataType.long, value)
and dtype in (DataType.float, DataType.double)
)
or (DataType.check_type(DataType.float, value) and dtype == DataType.double)
):
try:
return DataType[dtype.name].to_python()(value)
return dtype.to_python()(value)
except ValueError as e:
raise MlflowException.invalid_parameter_value(
f"Failed to convert value {value} from type {type(value).__name__} "
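The `enforce_param_datatype` hunk keeps the same widening rules as before, only routed through `check_type`: an `integer` value may be widened to `long`, `float`, or `double`; a `long` to `float` or `double`; and a `float` to `double`, with no narrowing in the other direction. A minimal re-statement of those rules, assuming the `DataType` semantics from this commit (illustrative only, not the repository implementation):

```python
import numpy as np
from mlflow.types import DataType

# Allowed widenings, mirroring the conditions in the hunk above.
WIDENINGS = {
    DataType.integer: (DataType.long, DataType.float, DataType.double),
    DataType.long: (DataType.float, DataType.double),
    DataType.float: (DataType.double,),
}

def can_coerce(value, target: DataType) -> bool:
    """True if `value` already matches `target` or may be widened to it."""
    if DataType.check_type(target, value):
        return True
    return any(
        DataType.check_type(src, value) and target in allowed
        for src, allowed in WIDENINGS.items()
    )

assert can_coerce(np.int32(3), DataType.double)         # integer -> double is allowed
assert not can_coerce(np.float64(1.0), DataType.float)  # double -> float narrowing is not
```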
50 changes: 29 additions & 21 deletions mlflow/types/utils.py
@@ -10,6 +10,7 @@
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.types import DataType
from mlflow.types.schema import (
HAS_PYSPARK,
AnyType,
Array,
ColSpec,
@@ -194,27 +195,34 @@ def _infer_array_datatype(data: Union[list, np.ndarray]) -> Optional[Array]:
return result


# datetime is not included here
SCALAR_TO_DATATYPE_MAPPING = {
bool: DataType.boolean,
np.bool_: DataType.boolean,
int: DataType.long,
np.int64: DataType.long,
np.int32: DataType.integer,
float: DataType.double,
np.float64: DataType.double,
np.float32: DataType.float,
str: DataType.string,
np.str_: DataType.string,
object: DataType.string,
bytes: DataType.binary,
np.bytes_: DataType.binary,
bytearray: DataType.binary,
}


def _infer_scalar_datatype(data) -> DataType:
if DataType.is_boolean(data):
return DataType.boolean
# Order of is_long & is_integer matters
# as both of their python_types are int
if DataType.is_long(data):
return DataType.long
if DataType.is_integer(data):
return DataType.integer
# Order of is_double & is_float matters
# as both of their python_types are float
if DataType.is_double(data):
return DataType.double
if DataType.is_float(data):
return DataType.float
if DataType.is_string(data):
return DataType.string
if DataType.is_binary(data):
return DataType.binary
if DataType.is_datetime(data):
if data_type := SCALAR_TO_DATATYPE_MAPPING.get(type(data)):
return data_type
if DataType.check_type(DataType.datetime, data):
return DataType.datetime
if HAS_PYSPARK:
for data_type in DataType.all_types():
if isinstance(data, type(data_type.to_spark())):
return data_type
raise MlflowException.invalid_parameter_value(
f"Data {data} is not one of the supported DataType"
)
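The old cascade of `is_long` / `is_integer` / `is_double` / `is_float` checks had to be carefully ordered because the int-backed and float-backed types share a Python type; the new `SCALAR_TO_DATATYPE_MAPPING` sidesteps that with a single exact-type dictionary lookup, leaving only datetime (and, when available, pyspark types) to be probed explicitly. A standalone sketch of the same strategy (assumed behavior, not repository code):

```python
import numpy as np

# Exact-type lookup: builtin int deliberately maps to long, builtin float to double.
_SCALAR_MAP = {bool: "boolean", np.bool_: "boolean", int: "long", np.int32: "integer",
               float: "double", np.float32: "float", str: "string"}

def infer_scalar(value) -> str:
    if (name := _SCALAR_MAP.get(type(value))) is not None:  # O(1), no ordering concerns
        return name
    raise ValueError(f"unsupported scalar: {value!r}")

assert infer_scalar(1) == "long"
assert infer_scalar(np.int32(1)) == "integer"
assert infer_scalar(True) == "boolean"   # bool is matched by exact type, not isinstance
```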
@@ -798,11 +806,11 @@ def _infer_type_and_shape(value):
raise MlflowException.invalid_parameter_value(
f"Expected parameters to be 1D array or scalar, got {ndim}D array",
)
if all(DataType.is_datetime(v) for v in value):
if all(DataType.check_type(DataType.datetime, v) for v in value):
return DataType.datetime, (-1,)
value_type = _infer_numpy_dtype(np.array(value).dtype)
return value_type, (-1,)
elif DataType.is_datetime(value):
elif DataType.check_type(DataType.datetime, value):
return DataType.datetime, None
elif np.isscalar(value):
try:
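`_infer_type_and_shape` is exercised through params inference on model signatures; a hedged end-to-end example, assuming mlflow at this commit (the exact repr of the resulting schema may differ):

```python
import datetime
import mlflow

signature = mlflow.models.infer_signature(
    model_input="hello",
    params={"cutoffs": [datetime.datetime(2023, 6, 26), datetime.datetime(2023, 6, 27)]},
)
# The datetime list should infer as a 1-D datetime param (shape (-1,)).
print(signature.params)
```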
30 changes: 15 additions & 15 deletions tests/pyfunc/test_pyfunc_schema_enforcement.py
@@ -82,19 +82,19 @@ def param_schema_basic():
class PythonModelWithBasicParams(mlflow.pyfunc.PythonModel):
def predict(self, context, model_input, params=None):
assert isinstance(params, dict)
assert DataType.is_string(params["str_param"])
assert DataType.is_integer(params["int_param"])
assert DataType.is_boolean(params["bool_param"])
assert DataType.is_double(params["double_param"])
assert DataType.is_float(params["float_param"])
assert DataType.is_long(params["long_param"])
assert DataType.is_datetime(params["datetime_param"])
assert isinstance(params["str_param"], str)
assert isinstance(params["int_param"], int)
assert isinstance(params["bool_param"], bool)
assert isinstance(params["double_param"], float)
assert isinstance(params["float_param"], float)
assert isinstance(params["long_param"], int)
assert isinstance(params["datetime_param"], datetime.datetime)
assert isinstance(params["str_list"], list)
assert all(DataType.is_string(x) for x in params["str_list"])
assert all(isinstance(x, str) for x in params["str_list"])
assert isinstance(params["bool_list"], list)
assert all(DataType.is_boolean(x) for x in params["bool_list"])
assert all(isinstance(x, bool) for x in params["bool_list"])
assert isinstance(params["double_array"], list)
assert all(DataType.is_double(x) for x in params["double_array"])
assert all(isinstance(x, float) for x in params["double_array"])
return params


@@ -114,11 +114,11 @@ def sample_params_with_arrays():
class PythonModelWithArrayParams(mlflow.pyfunc.PythonModel):
def predict(self, context, model_input, params=None):
assert isinstance(params, dict)
assert all(DataType.is_integer(x) for x in params["int_array"])
assert all(DataType.is_double(x) for x in params["double_array"])
assert all(DataType.is_float(x) for x in params["float_array"])
assert all(DataType.is_long(x) for x in params["long_array"])
assert all(DataType.is_datetime(x) for x in params["datetime_array"])
assert all(isinstance(x, int) for x in params["int_array"])
assert all(isinstance(x, float) for x in params["double_array"])
assert all(isinstance(x, float) for x in params["float_array"])
assert all(isinstance(x, int) for x in params["long_array"])
assert all(isinstance(x, datetime.datetime) for x in params["datetime_array"])
return params


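The assertion changes in these tests follow from the schema-enforcement behavior above: by the time `predict` runs, each param has already been coerced to its Python-native type by `enforce_param_datatype`, so plain `isinstance` checks are sufficient. A small illustration of that coercion assumption (not repository code):

```python
import numpy as np
from mlflow.types import DataType

# A numpy double supplied as a param arrives in predict() as a plain Python float.
coerced = DataType.double.to_python()(np.float64(1.0))
assert isinstance(coerced, float) and type(coerced) is float
```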
56 changes: 31 additions & 25 deletions tests/types/test_schema.py
@@ -761,33 +761,39 @@ def test_enforce_tensor_spec_variable_signature():


def test_datatype_type_check():
assert DataType.is_string("string")
assert DataType.check_type(DataType.string, "string")

assert DataType.is_integer(1)
assert DataType.is_integer(np.int32(1))
assert not DataType.is_integer(np.int64(1))
integer_type = DataType.integer
assert DataType.check_type(integer_type, 1)
assert DataType.check_type(integer_type, np.int32(1))
assert not DataType.check_type(integer_type, np.int64(1))
# Note that isinstance(True, int) returns True
assert not DataType.is_integer(True)

assert DataType.is_long(1)
assert DataType.is_long(np.int64(1))
assert not DataType.is_long(np.int32(1))

assert DataType.is_boolean(True)
assert DataType.is_boolean(np.bool_(True))
assert not DataType.is_boolean(1)

assert DataType.is_double(1.0)
assert DataType.is_double(np.float64(1.0))
assert not DataType.is_double(np.float32(1.0))

assert DataType.is_float(1.0)
assert DataType.is_float(np.float32(1.0))
assert not DataType.is_float(np.float64(1.0))

assert DataType.is_datetime(datetime.date(2023, 6, 26))
assert DataType.is_datetime(np.datetime64("2023-06-26 00:00:00"))
assert not DataType.is_datetime("2023-06-26 00:00:00")
assert not DataType.check_type(integer_type, True)

long_type = DataType.long
assert DataType.check_type(long_type, 1)
assert DataType.check_type(long_type, np.int64(1))
assert not DataType.check_type(long_type, np.int32(1))

bool_type = DataType.boolean
assert DataType.check_type(bool_type, True)
assert DataType.check_type(bool_type, np.bool_(True))
assert not DataType.check_type(bool_type, 1)

double_type = DataType.double
assert DataType.check_type(double_type, 1.0)
assert DataType.check_type(double_type, np.float64(1.0))
assert not DataType.check_type(double_type, np.float32(1.0))

float_type = DataType.float
assert DataType.check_type(float_type, 1.0)
assert DataType.check_type(float_type, np.float32(1.0))
assert not DataType.check_type(float_type, np.float64(1.0))

datetime_type = DataType.datetime
assert DataType.check_type(datetime_type, datetime.date(2023, 6, 26))
assert DataType.check_type(datetime_type, np.datetime64("2023-06-26 00:00:00"))
assert not DataType.check_type(datetime_type, "2023-06-26 00:00:00")


def test_param_schema_find_duplicates():
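One branch of `check_type` that the updated test does not cover is the pyspark fallback; a hedged sketch of what it is expected to do when pyspark is installed (assumed behavior based on the diff, not an excerpt from the test suite):

```python
from pyspark.sql.types import IntegerType

from mlflow.types import DataType

# When the exact numpy/pandas/python types do not match, check_type falls back to
# an isinstance comparison against the corresponding Spark type instance.
assert DataType.check_type(DataType.integer, IntegerType())
```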
