Skip to content

Commit

Permalink
Improve support for DALI enum types
Browse files Browse the repository at this point in the history
Support for DALI enum types (DALIDataType, DALIImageType, DALIInterpType)
is added to Constant, Cast, Choice and Copy operators.

Thanks to Constant op support, the following syntax will work now:
```
fn.random.choice([types.DALIInterpType.INTERP_LINEAR, types.DALIInterpType.INTERP_NN])
```
allowing convenient selection of enum parameters.

Casting support is allowed only between non-fp types and enums.

An explicit error is raised when the enums are used with the buffer protocol
(conversion to numpy, printing, etc.) - Python expects a pointer-to-object
representation there, while we return them as C-style enums with a numeric
value under the hood.
As an alternative, we could allow direct access to the underlying data, but I
have chosen a slightly more restrictive approach for now.

Signed-off-by: Krzysztof Lecki <[email protected]>
  • Loading branch information
klecki committed Apr 10, 2024
1 parent dedcfae commit 56c8bba
Show file tree
Hide file tree
Showing 11 changed files with 200 additions and 41 deletions.
57 changes: 32 additions & 25 deletions dali/kernels/common/cast_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "dali/kernels/common/utils.h"
#include "dali/kernels/kernel.h"
#include "dali/kernels/dynamic_scratchpad.h"
#include "dali/pipeline/data/types.h"

namespace dali {
namespace kernels {
Expand Down Expand Up @@ -102,32 +103,38 @@ void CastGPU<Out, In>::Run(KernelContext &ctx,

#define INSTANTIATE_IMPL(Out, In) template struct DLL_PUBLIC CastGPU<Out, In>;

#define INSTANTIATE_FOREACH_INTYPE(Out) \
INSTANTIATE_IMPL(Out, bool); \
INSTANTIATE_IMPL(Out, uint8_t); \
INSTANTIATE_IMPL(Out, uint16_t); \
INSTANTIATE_IMPL(Out, uint32_t); \
INSTANTIATE_IMPL(Out, uint64_t); \
INSTANTIATE_IMPL(Out, int8_t); \
INSTANTIATE_IMPL(Out, int16_t); \
INSTANTIATE_IMPL(Out, int32_t); \
INSTANTIATE_IMPL(Out, int64_t); \
INSTANTIATE_IMPL(Out, float); \
INSTANTIATE_IMPL(Out, double); \
INSTANTIATE_IMPL(Out, dali::float16);

INSTANTIATE_FOREACH_INTYPE(bool); \
INSTANTIATE_FOREACH_INTYPE(uint8_t); \
INSTANTIATE_FOREACH_INTYPE(uint16_t); \
INSTANTIATE_FOREACH_INTYPE(uint32_t); \
INSTANTIATE_FOREACH_INTYPE(uint64_t); \
INSTANTIATE_FOREACH_INTYPE(int8_t); \
INSTANTIATE_FOREACH_INTYPE(int16_t); \
INSTANTIATE_FOREACH_INTYPE(int32_t); \
INSTANTIATE_FOREACH_INTYPE(int64_t); \
INSTANTIATE_FOREACH_INTYPE(float); \
INSTANTIATE_FOREACH_INTYPE(double); \
// For a fixed output type `Out`, expands to one INSTANTIATE_IMPL(Out, In) - i.e. an explicit
// instantiation of CastGPU<Out, In> - per listed input type `In`. The list covers bool, the
// fixed-width integer types, float, double, dali::float16 and the DALI enum types
// (DALIDataType, DALIImageType, DALIInterpType).
#define INSTANTIATE_FOREACH_INTYPE(Out) \
INSTANTIATE_IMPL(Out, bool); \
INSTANTIATE_IMPL(Out, uint8_t); \
INSTANTIATE_IMPL(Out, uint16_t); \
INSTANTIATE_IMPL(Out, uint32_t); \
INSTANTIATE_IMPL(Out, uint64_t); \
INSTANTIATE_IMPL(Out, int8_t); \
INSTANTIATE_IMPL(Out, int16_t); \
INSTANTIATE_IMPL(Out, int32_t); \
INSTANTIATE_IMPL(Out, int64_t); \
INSTANTIATE_IMPL(Out, float); \
INSTANTIATE_IMPL(Out, double); \
INSTANTIATE_IMPL(Out, dali::float16); \
INSTANTIATE_IMPL(Out, DALIDataType); \
INSTANTIATE_IMPL(Out, DALIImageType); \
INSTANTIATE_IMPL(Out, DALIInterpType);

INSTANTIATE_FOREACH_INTYPE(bool);
INSTANTIATE_FOREACH_INTYPE(uint8_t);
INSTANTIATE_FOREACH_INTYPE(uint16_t);
INSTANTIATE_FOREACH_INTYPE(uint32_t);
INSTANTIATE_FOREACH_INTYPE(uint64_t);
INSTANTIATE_FOREACH_INTYPE(int8_t);
INSTANTIATE_FOREACH_INTYPE(int16_t);
INSTANTIATE_FOREACH_INTYPE(int32_t);
INSTANTIATE_FOREACH_INTYPE(int64_t);
INSTANTIATE_FOREACH_INTYPE(float);
INSTANTIATE_FOREACH_INTYPE(double);
INSTANTIATE_FOREACH_INTYPE(dali::float16);
INSTANTIATE_FOREACH_INTYPE(DALIDataType);
INSTANTIATE_FOREACH_INTYPE(DALIImageType);
INSTANTIATE_FOREACH_INTYPE(DALIInterpType);

} // namespace cast
} // namespace kernels
Expand Down
8 changes: 7 additions & 1 deletion dali/operators/generic/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@

#include "dali/core/convert.h"
#include "dali/core/tensor_shape.h"
#include "dali/pipeline/data/types.h"
#include "dali/pipeline/operator/checkpointing/stateless_operator.h"

namespace dali {

#define CAST_ALLOWED_TYPES \
(bool, uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t, float16, float, \
double)
double, DALIDataType, DALIImageType, DALIInterpType)

template <typename Backend>
class Cast : public StatelessOperator<Backend> {
Expand Down Expand Up @@ -54,6 +55,11 @@ class Cast : public StatelessOperator<Backend> {
bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
const auto &input = ws.Input<Backend>(0);
DALIDataType out_type = is_cast_like_ ? ws.GetInputDataType(1) : dtype_arg_;
DALI_ENFORCE(!(IsEnum(input.type()) && IsFloatingPoint(out_type) ||
IsEnum(out_type) && IsFloatingPoint(input.type())),
make_string("Cannot cast from ", input.type(), " to ", out_type,
". Enums can only participate in casts with integral types, "
"but not floating point types."));
output_desc.resize(1);
output_desc[0].shape = input.shape();
output_desc[0].type = out_type;
Expand Down
5 changes: 3 additions & 2 deletions dali/operators/generic/constant.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
#include "dali/core/tensor_view.h"
#include "dali/core/static_switch.h"

#define CONSTANT_OP_SUPPORTED_TYPES \
(bool, int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t, float, float16)
// Parenthesized list of element types for the Constant operator: bool, the fixed-width
// integer types, float, float16 and the DALI enum types (DALIDataType, DALIImageType,
// DALIInterpType).
// NOTE(review): presumably consumed by a type-switch macro at the use site - confirm there.
#define CONSTANT_OP_SUPPORTED_TYPES \
(bool, int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t, float, float16, \
DALIDataType, DALIImageType, DALIInterpType)

namespace dali {

Expand Down
2 changes: 1 addition & 1 deletion dali/operators/random/choice.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t
#define DALI_CHOICE_1D_TYPES \
bool, uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t, float16, float, \
double
double, DALIDataType, DALIImageType, DALIInterpType

namespace dali {

Expand Down
3 changes: 3 additions & 0 deletions dali/operators/random/choice_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ a single value per sample is generated.
The type of the output matches the type of the input.
For scalar inputs, only integral types are supported, otherwise any type can be used.
The operator supports selection from an input containing elements of one of DALI enum types,
that is: :meth:`nvidia.dali.types.DALIDataType`, :meth:`nvidia.dali.types.DALIImageType`, or
:meth:`nvidia.dali.types.DALIInterpType`.
)code")
.NumInput(1, 2)
.InputDox(0, "a", "scalar or TensorList",
Expand Down
12 changes: 12 additions & 0 deletions dali/pipeline/data/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,18 @@ constexpr bool IsUnsigned(DALIDataType type) {
}
}


/**
 * @brief Checks whether a type id denotes one of the DALI enum types.
 *
 * @param type Type id to classify.
 * @return true for DALI_DATA_TYPE, DALI_IMAGE_TYPE and DALI_INTERP_TYPE; false for any other id.
 */
constexpr bool IsEnum(DALIDataType type) {
  return type == DALI_DATA_TYPE || type == DALI_IMAGE_TYPE || type == DALI_INTERP_TYPE;
}

template <DALIDataType id>
struct id2type_helper;

Expand Down
3 changes: 3 additions & 0 deletions dali/python/nvidia/dali/_backend_enums.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class DALIDataType(Enum):
FLOAT64 = ...
BOOL = ...
STRING = ...
DATA_TYPE = ...
IMAGE_TYPE = ...
INTERP_TYPE = ...

class DALIImageType(Enum):
RGB = ...
Expand Down
32 changes: 31 additions & 1 deletion dali/python/nvidia/dali/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,16 +519,45 @@ def _type_from_value_or_list(v):
has_floats = False
has_ints = False
has_bools = False
has_enums = False
enum_type = None
for x in v:
if isinstance(x, float):
has_floats = True
elif isinstance(x, bool):
has_bools = True
elif isinstance(x, int):
has_ints = True
elif isinstance(x, (DALIDataType, DALIImageType, DALIInterpType)):
has_enums = True
enum_type = type(x)
break
else:
raise TypeError("Unexpected type: " + str(type(x)))

if has_enums:
for x in v:
if not isinstance(x, enum_type):
raise TypeError(
f"Expected all elements of the input to be the "
f"same enum type: `{enum_type.__name__}` but got `{type(x).__name__}` "
f"for one of the elements."
)

if has_enums:
if issubclass(enum_type, DALIDataType):
return DALIDataType.DATA_TYPE
elif issubclass(enum_type, DALIImageType):
return DALIDataType.IMAGE_TYPE
elif issubclass(enum_type, DALIInterpType):
return DALIDataType.INTERP_TYPE
else:
raise TypeError(
f"Unexpected enum type: `{enum_type.__name__}`, expected one of: "
"`nvidia.dali.types.DALIDataType`, `nvidia.dali.types.DALIImageType`, "
"or `nvidia.dali.types.DALIInterpType`."
)

if has_floats:
return DALIDataType.FLOAT
if has_ints:
Expand Down Expand Up @@ -582,7 +611,8 @@ def Constant(value, dtype=None, shape=None, layout=None, device=None, **kwargs):
Args
----
value: `bool`, `int`, `float`, a `list` or `tuple` thereof or a `numpy.ndarray`
value: `bool`, `int`, `float`, `DALIDataType`, `DALIImageType`, `DALIInterpType`,
a `list` or `tuple` thereof or a `numpy.ndarray`
The constant value to wrap. If it is a scalar, it can be used as scalar
value in mathematical expressions. Otherwise, it will produce a constant
tensor node (optionally reshaped according to `shape` argument).
Expand Down
14 changes: 14 additions & 0 deletions dali/util/pybind.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <string>
#include "dali/pipeline/data/types.h"
#include "dali/pipeline/data/dltensor.h"
#include "dali/pipeline/operator/error_reporting.h"

namespace dali {

Expand Down Expand Up @@ -54,6 +55,19 @@ static std::string FormatStrFromType(DALIDataType type) {
return "=d";
case DALI_BOOL:
return "=?";
case DALI_DATA_TYPE:
case DALI_IMAGE_TYPE:
case DALI_INTERP_TYPE:
throw DaliTypeError(
"DALI enum types cannot be used with buffer protocol "
"when they are returned as Tensors or TensorLists from DALI pipeline."
"You can use `nvidia.dali.fn.cast` to convert those values to an integral type.");
// As an alternative, to allow the usage of tensors containing DALI enums (printing, use with
// buffer protocol, numpy conversion etc), we can return format specifier for the underlying
// type here. This would allow access to the actual numeric values, for example:
// case DALI_DATA_TYPE:
// return
// FormatStrFromType(TypeTable::GetTypeInfo<std::underlying_type_t<DALIDataType>>().id());
default:
break;
}
Expand Down
3 changes: 2 additions & 1 deletion docs/data_types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,10 @@ DALIDataType
:member-order: bysource
:exclude-members: name

.. autofunction:: to_numpy_type

DALIInterpType
^^^^^^^^^^^^^^
.. autofunction:: to_numpy_type
.. autoenum:: DALIInterpType
:members:
:undoc-members:
Expand Down
Loading

0 comments on commit 56c8bba

Please sign in to comment.