Skip to content

Commit

Permalink
Use only Constant nodes for enums, add more tests
Browse files Browse the repository at this point in the history
Signed-off-by: Krzysztof Lecki <[email protected]>
  • Loading branch information
klecki committed Apr 12, 2024
1 parent 3f40adc commit e3e9be0
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 8 deletions.
3 changes: 3 additions & 0 deletions dali/python/nvidia/dali/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,9 @@ def Constant(value, dtype=None, shape=None, layout=None, device=None, **kwargs):
device is not None
or (_is_compatible_array_type(value) and not _is_true_scalar(value))
or isinstance(value, (list, tuple))
# we force true scalar enums through a Constant node rather than using ScalarConstant
# as they do not support any arithmetic operations
or isinstance(value, (DALIDataType, DALIImageType, DALIInterpType))
or not _is_scalar_shape(shape)
or kwargs
or layout is not None
Expand Down
41 changes: 33 additions & 8 deletions dali/test/python/operator_2/test_enum_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,27 @@
from nvidia.dali import fn, pipeline_def, types

import numpy as np
import tree

from nose_utils import assert_raises
from nose2.tools import params

def test_enum_constant_capture():

@params(
*[
# Automatic promotion
lambda value, dtype: fn.copy(value),
# Explicit conversion to constant op
lambda value, dtype: types.Constant(value=value, dtype=dtype),
# Detection of type from value
lambda value, dtype: types.Constant(value=value),
# Explicit type when passed the underlying numeric value of the enum
lambda value, dtype: types.Constant(
value=tree.map_structure(lambda v: v.value, value), dtype=dtype
),
]
)
def test_enum_constant_capture(converter):
batch_size = 2

scalar_v = types.DALIDataType.INT16
Expand All @@ -29,24 +47,31 @@ def test_enum_constant_capture():

@pipeline_def(batch_size=batch_size, device_id=0, num_threads=4)
def enum_constant_pipe():
scalar = fn.copy(scalar_v)
tensor = fn.copy(list_v)
scalar = converter(scalar_v, types.DALIDataType.DATA_TYPE)
tensor = converter(list_v, types.DALIDataType.INTERP_TYPE)

scalar_as_int = fn.cast(scalar, dtype=types.DALIDataType.INT32)
tensor_as_int = fn.cast(tensor, dtype=types.DALIDataType.INT32)
return scalar, tensor, scalar_as_int, tensor_as_int

pipe = enum_constant_pipe()
pipe.build()
# Compare the cast values with Python values
scalar, tensor, scalar_as_int, tensor_as_int = pipe.run()
assert scalar.dtype == types.DALIDataType.DATA_TYPE
assert scalar.shape() == [()] * batch_size, f"{scalar.shape}"
assert tensor.dtype == types.DALIDataType.INTERP_TYPE
print(tensor.shape)
assert tensor.shape() == [(3,)] * batch_size
# Compare the cast values with Python values
for i in range(batch_size):
assert np.array_equal(np.array(scalar_as_int[i]), np.array(scalar_v.value))
assert np.array_equal(
np.array(tensor_as_int[i]), np.array([elem.value for elem in list_v], dtype=np.int32)
)
assert np.array_equal(np.array(tensor_as_int[i]), np.array([elem.value for elem in list_v]))
with assert_raises(
TypeError,
glob="DALI enum types cannot be used with buffer protocol*"
"use `nvidia.dali.fn.cast` to convert",
):
print(scalar)


def test_scalar_constant():
print(types.ScalarConstant(types.DALIDataType.INT16))
36 changes: 36 additions & 0 deletions dali/test/python/operator_2/test_random_choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,12 @@ def choice_pipe():
"Data type float is not supported for 0D inputs. Supported types are: "
"uint8, uint16, uint32, uint64, int8, int16, int32, int64",
),
(
(types.DALIInterpType.INTERP_CUBIC,),
{},
"Data type DALIInterpType is not supported for 0D inputs. Supported types are: "
"uint8, uint16, uint32, uint64, int8, int16, int32, int64",
),
(
(5,),
{"p": np.array([0.25, 0.5, 0.25])},
Expand All @@ -270,3 +276,33 @@ def choice_pipe():
pipe = choice_pipe()
pipe.build()
pipe.run()


def test_enum_choice():
batch_size = 8

interps_to_sample = [types.DALIInterpType.INTERP_LINEAR, types.DALIInterpType.INTERP_CUBIC]

@pipeline_def(batch_size=batch_size, device_id=0, num_threads=4)
def choice_pipeline():
interp = fn.random.choice(interps_to_sample, shape=[100])
interp_as_int = fn.cast(interp, dtype=types.INT32)
imgs = fn.resize(
fn.random.uniform(range=[0, 255], dtype=types.UINT8, shape=(100, 100, 3)),
size=(25, 25),
interp_type=interp[0],
)
return interp, interp_as_int, imgs

pipe = choice_pipeline()
pipe.build()
(interp, interp_as_int, imgs) = pipe.run()
assert interp.dtype == types.DALIDataType.INTERP_TYPE
for i in range(batch_size):
check_sample(
np.array(interp_as_int[i]),
size=(100,),
a=np.array([v.value for v in interps_to_sample]),
p=None,
idx=i,
)

0 comments on commit e3e9be0

Please sign in to comment.