diff --git a/dali/python/nvidia/dali/types.py b/dali/python/nvidia/dali/types.py index 7c7ba2d49d2..84ab2547907 100644 --- a/dali/python/nvidia/dali/types.py +++ b/dali/python/nvidia/dali/types.py @@ -640,6 +640,9 @@ def Constant(value, dtype=None, shape=None, layout=None, device=None, **kwargs): device is not None or (_is_compatible_array_type(value) and not _is_true_scalar(value)) or isinstance(value, (list, tuple)) + # we force true scalar enums through a Constant node rather than using ScalarConstant + # as they do not support any arithmetic operations + or isinstance(value, (DALIDataType, DALIImageType, DALIInterpType)) or not _is_scalar_shape(shape) or kwargs or layout is not None diff --git a/dali/test/python/operator_2/test_enum_types.py b/dali/test/python/operator_2/test_enum_types.py index b1dcba76404..1e4f2a55505 100644 --- a/dali/test/python/operator_2/test_enum_types.py +++ b/dali/test/python/operator_2/test_enum_types.py @@ -15,9 +15,27 @@ from nvidia.dali import fn, pipeline_def, types import numpy as np +import tree +from nose_utils import assert_raises +from nose2.tools import params -def test_enum_constant_capture(): + +@params( + *[ + # Automatic promotion + lambda value, dtype: fn.copy(value), + # Explicit conversion to constant op + lambda value, dtype: types.Constant(value=value, dtype=dtype), + # Detection of type from value + lambda value, dtype: types.Constant(value=value), + # Explicit type when passed the underlying numeric value of the enum + lambda value, dtype: types.Constant( + value=tree.map_structure(lambda v: v.value, value), dtype=dtype + ), + ] +) +def test_enum_constant_capture(converter): batch_size = 2 scalar_v = types.DALIDataType.INT16 @@ -29,8 +47,8 @@ def test_enum_constant_capture(): @pipeline_def(batch_size=batch_size, device_id=0, num_threads=4) def enum_constant_pipe(): - scalar = fn.copy(scalar_v) - tensor = fn.copy(list_v) + scalar = converter(scalar_v, types.DALIDataType.DATA_TYPE) + tensor = converter(list_v, types.DALIDataType.INTERP_TYPE) scalar_as_int = fn.cast(scalar, dtype=types.DALIDataType.INT32) tensor_as_int = fn.cast(tensor, dtype=types.DALIDataType.INT32) @@ -38,15 +56,22 @@ def enum_constant_pipe(): pipe = enum_constant_pipe() pipe.build() - # Compare the cast values with Python values scalar, tensor, scalar_as_int, tensor_as_int = pipe.run() assert scalar.dtype == types.DALIDataType.DATA_TYPE assert scalar.shape() == [()] * batch_size, f"{scalar.shape}" assert tensor.dtype == types.DALIDataType.INTERP_TYPE - print(tensor.shape) assert tensor.shape() == [(3,)] * batch_size + # Compare the cast values with Python values for i in range(batch_size): assert np.array_equal(np.array(scalar_as_int[i]), np.array(scalar_v.value)) - assert np.array_equal( - np.array(tensor_as_int[i]), np.array([elem.value for elem in list_v], dtype=np.int32) - ) + assert np.array_equal(np.array(tensor_as_int[i]), np.array([elem.value for elem in list_v])) + with assert_raises( + TypeError, + glob="DALI enum types cannot be used with buffer protocol*" + "use `nvidia.dali.fn.cast` to convert", + ): + print(scalar) + + +def test_scalar_constant(): + print(types.ScalarConstant(types.DALIDataType.INT16)) diff --git a/dali/test/python/operator_2/test_random_choice.py b/dali/test/python/operator_2/test_random_choice.py index ebdce522eda..6dd6ccb2ce5 100644 --- a/dali/test/python/operator_2/test_random_choice.py +++ b/dali/test/python/operator_2/test_random_choice.py @@ -252,6 +252,12 @@ def choice_pipe(): "Data type float is not supported for 0D inputs. Supported types are: " "uint8, uint16, uint32, uint64, int8, int16, int32, int64", ), + ( + (types.DALIInterpType.INTERP_CUBIC,), + {}, + "Data type DALIInterpType is not supported for 0D inputs. Supported types are: " + "uint8, uint16, uint32, uint64, int8, int16, int32, int64", + ), ( (5,), {"p": np.array([0.25, 0.5, 0.25])}, @@ -270,3 +276,33 @@ def choice_pipe(): pipe = choice_pipe() pipe.build() pipe.run() + + +def test_enum_choice(): + batch_size = 8 + + interps_to_sample = [types.DALIInterpType.INTERP_LINEAR, types.DALIInterpType.INTERP_CUBIC] + + @pipeline_def(batch_size=batch_size, device_id=0, num_threads=4) + def choice_pipeline(): + interp = fn.random.choice(interps_to_sample, shape=[100]) + interp_as_int = fn.cast(interp, dtype=types.INT32) + imgs = fn.resize( + fn.random.uniform(range=[0, 255], dtype=types.UINT8, shape=(100, 100, 3)), + size=(25, 25), + interp_type=interp[0], + ) + return interp, interp_as_int, imgs + + pipe = choice_pipeline() + pipe.build() + (interp, interp_as_int, imgs) = pipe.run() + assert interp.dtype == types.DALIDataType.INTERP_TYPE + for i in range(batch_size): + check_sample( + np.array(interp_as_int[i]), + size=(100,), + a=np.array([v.value for v in interps_to_sample]), + p=None, + idx=i, + )