diff --git a/dali/pipeline/operator/builtin/external_source.cc b/dali/pipeline/operator/builtin/external_source.cc index 20cc6f4cbe4..eeea0b77e88 100644 --- a/dali/pipeline/operator/builtin/external_source.cc +++ b/dali/pipeline/operator/builtin/external_source.cc @@ -92,6 +92,13 @@ The memory location must match the specified ``device`` parameter of the operato For the CPU, the provided memory can be one contiguous buffer or a list of contiguous Tensors. For the GPU, to avoid extra copy, the provided buffer must be contiguous. If you provide a list of separate Tensors, there will be an additional copy made internally, consuming both memory -and bandwidth.)code", false); +and bandwidth.)code", false) + .AddOptionalArg("dtype", R"code(Input data type. + +The operator will validate that the fetched data is of the provided type. +If the argument is omitted or ``DALIDataType.NO_TYPE`` is passed, the operator will infer +the type based on the provided data. + +This argument will be required starting from DALI 2.0.)code", DALI_NO_TYPE); } // namespace dali diff --git a/dali/pipeline/operator/builtin/external_source.h b/dali/pipeline/operator/builtin/external_source.h index dc36fdff7f6..8a599da8379 100644 --- a/dali/pipeline/operator/builtin/external_source.h +++ b/dali/pipeline/operator/builtin/external_source.h @@ -218,6 +218,8 @@ class ExternalSource : public Operator, virtual public BatchSizeProvide blocking_(spec.GetArgument("blocking")), no_copy_(spec.GetArgument("no_copy")), device_id_(spec.GetArgument("device_id")), + dtype_(spec.GetArgument("dtype")), + previous_dtype_(DALIDataType::DALI_NO_TYPE), sync_worker_(device_id_, false) { output_name_ = spec.Output(0); sync_worker_.WaitForInit(); @@ -502,6 +504,19 @@ class ExternalSource : public Operator, virtual public BatchSizeProvide OperatorBase::max_batch_size_ >= static_cast(batch.num_samples()), make_string("Data list provided to ExternalSource needs to have batch_size <= ", OperatorBase::max_batch_size_, ", found ", batch.num_samples(), " samples.")); + + DALI_ENFORCE( + dtype_ == DALI_NO_TYPE || dtype_ == batch.type(), + make_string("ExternalSource expected data of type ", TypeTable::GetTypeInfo(dtype_).name(), + " and got: ", batch.type_info().name())); + + DALI_ENFORCE(previous_dtype_ == DALI_NO_TYPE || previous_dtype_ == batch.type(), + make_string("Type of the data fed to the external source has changed from the " + "previous iteration. Type in the previous iteration was ", + TypeTable::GetTypeInfo(previous_dtype_).name(), + " and the current type is ", batch.type_info().name(), ".")); + previous_dtype_ = batch.type(); + // Note: If we create a GPU source, we will need to figure // out what stream we want to do this copy in. CPU we can // pass anything as it is ignored. @@ -537,6 +552,8 @@ class ExternalSource : public Operator, virtual public BatchSizeProvide bool blocking_ = true; bool no_copy_ = false; int device_id_; + DALIDataType dtype_; + DALIDataType previous_dtype_; /* * now it only indicates that there is data in the ExternalSource, in the future diff --git a/dali/python/nvidia/dali/external_source.py b/dali/python/nvidia/dali/external_source.py index 200a7eb169a..e3c0302946b 100644 --- a/dali/python/nvidia/dali/external_source.py +++ b/dali/python/nvidia/dali/external_source.py @@ -291,6 +291,15 @@ class ExternalSource(): output. If the list has fewer than ``num_outputs`` elements, only the first outputs have the layout set, the rest of the outputs don't have a layout set. +`dtype` : `nvidia.dali.types.DALIDataType` or list/tuple thereof, optional + Input data type. + + The operator will validate that the fetched data is of the provided type. + If the argument is omitted or :const:`DALIDataType.NO_TYPE` is passed, the operator will infer + the type from the provided data. + + This argument will be required starting from DALI 2.0. + `cuda_stream` : optional, ``cudaStream_t`` or an object convertible to ``cudaStream_t``, such as ``cupy.cuda.Stream`` or ``torch.cuda.Stream`` The CUDA stream is used to copy data to the GPU or from a GPU source. @@ -385,13 +394,14 @@ class ExternalSource(): """ def __init__( - self, source=None, num_outputs=None, *, cycle=None, layout=None, name=None, + self, source=None, num_outputs=None, *, cycle=None, layout=None, dtype=None, name=None, device="cpu", cuda_stream=None, use_copy_kernel=None, batch=None, parallel=None, no_copy=None, prefetch_queue_depth=None, batch_info=None, **kwargs): self._schema = _b.GetSchema("ExternalSource") self._spec = _b.OpSpec("ExternalSource") self._device = device self._layout = layout + self._dtype = dtype self._cuda_stream = cuda_stream self._use_copy_kernel = use_copy_kernel @@ -414,6 +424,8 @@ def __init__( self._batch_info = batch_info self._spec.AddArg("device", device) + if dtype is not None: + self._spec.AddArg("dtype", dtype) for key, value in kwargs.items(): self._spec.AddArg(key, value) @@ -434,7 +446,7 @@ def preserve(self): return False def __call__( - self, *, source=None, cycle=None, name=None, layout=None, cuda_stream=None, + self, *, source=None, cycle=None, name=None, layout=None, dtype=None, cuda_stream=None, use_copy_kernel=None, batch=None, parallel=None, no_copy=None, prefetch_queue_depth=None, batch_info=None, **kwargs): "" @@ -525,6 +537,12 @@ def __call__( else: layout = self._layout + if self._dtype is not None: + if dtype is not None: + raise RuntimeError("``dtype`` already specified in constructor.") + else: + dtype = self._dtype + if self._cuda_stream is not None: if cuda_stream is not None: raise RuntimeError("``cuda_stream`` already specified in constructor.") @@ -570,6 +588,16 @@ def __call__( op_instance._layout = layout else: op_instance._layout = None + + if dtype is not None: + if isinstance(dtype, (list, tuple)): + op_instance._dtype = dtype[i] if i < len(dtype) else nvidia.dali.types.DALIDataType.NO_TYPE + else: + op_instance._dtype = dtype + else: + op_instance._dtype = nvidia.dali.types.DALIDataType.NO_TYPE + + op_instance._batch = batch group.append(op_instance) @@ -588,6 +616,7 @@ def __call__( op_instance._group = _ExternalSourceGroup( callback, source_desc, False, [op_instance], **group_common_kwargs) op_instance._layout = layout + op_instance._dtype = dtype op_instance._batch = batch op_instance.generate_outputs() @@ -615,7 +644,7 @@ def _has_external_source(pipeline): def external_source(source = None, num_outputs = None, *, cycle = None, name = None, device = "cpu", layout = None, - cuda_stream = None, use_copy_kernel = None, batch = True, **kwargs): + dtype = None, cuda_stream = None, use_copy_kernel = None, batch = True, **kwargs): """Creates a data node which is populated with data from a Python source. The data can be provided by the ``source`` function or iterable, or it can be provided by ``pipeline.feed_input(name, data, layout, cuda_stream)`` inside ``pipeline.iter_setup``. @@ -645,7 +674,7 @@ def external_source(source = None, num_outputs = None, *, cycle = None, name = N "``external_source`` nodes.") op = ExternalSource(device = device, num_outputs = num_outputs, source = source, - cycle = cycle, layout = layout, cuda_stream = cuda_stream, + cycle = cycle, layout = layout, dtype = dtype, cuda_stream = cuda_stream, use_copy_kernel = use_copy_kernel, batch = batch, **kwargs) return op(name = name) diff --git a/dali/test/python/test_external_source_dali.py b/dali/test/python/test_external_source_dali.py index 662c4088080..8a11d3ecc5b 100644 --- a/dali/test/python/test_external_source_dali.py +++ b/dali/test/python/test_external_source_dali.py @@ -18,6 +18,7 @@ from nvidia.dali.pipeline import Pipeline from test_utils import check_batch from nose_utils import raises +from nvidia.dali.types import DALIDataType def build_src_pipe(device, layout = None): if layout is None: @@ -189,3 +190,43 @@ def test_epoch_idx(): yield _test_epoch_idx, batch_size, epoch_size, batch_cb, batch_info, True sample_cb = SampleCb(batch_size, epoch_size) yield _test_epoch_idx, batch_size, epoch_size, sample_cb, None, False + + +def test_dtype_arg(): + batch_size = 2 + src_data = [ + [np.ones((120, 120, 3), dtype=np.uint8)]*batch_size + ] + src_pipe = Pipeline(batch_size, 1, 0) + src_ext = fn.external_source(source=src_data, device='cpu', dtype=DALIDataType.UINT8) + src_pipe.set_outputs(src_ext) + src_pipe.build() + src_pipe.run() + + +@raises(RuntimeError, glob="ExternalSource expected data of type uint8 and got: float") +def test_incorrect_dtype_arg(): + batch_size = 2 + src_data = [ + [np.ones((120, 120, 3), dtype=np.float32)]*batch_size + ] + src_pipe = Pipeline(batch_size, 1, 0) + src_ext = fn.external_source(source=src_data, device='cpu', dtype=DALIDataType.UINT8) + src_pipe.set_outputs(src_ext) + src_pipe.build() + src_pipe.run() + +@raises(RuntimeError, glob="Type of the data fed to the external source has changed from the previous iteration. " + "Type in the previous iteration was float and the current type is uint8.") +def test_changing_dtype(): + batch_size = 2 + src_data = [ + [np.ones((120, 120, 3), dtype=np.float32)]*batch_size, + [np.ones((120, 120, 3), dtype=np.uint8)]*batch_size + ] + src_pipe = Pipeline(batch_size, 1, 0) + src_ext = fn.external_source(source=src_data, device='cpu') + src_pipe.set_outputs(src_ext) + src_pipe.build() + src_pipe.run() + src_pipe.run()