Skip to content

Commit

Permalink
Add dtype argument and check dtype consistency in ExternalSource (NVIDIA#3562)

Browse files Browse the repository at this point in the history

Add dtype argument to the external source. 
Add error if data type fed to the external source changes from iteration to iteration.
Signed-off-by: Rafal <[email protected]>
  • Loading branch information
banasraf authored and cyyever committed Jan 23, 2022
1 parent 0de9ac5 commit 35d95ec
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 5 deletions.
9 changes: 8 additions & 1 deletion dali/pipeline/operator/builtin/external_source.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ The memory location must match the specified ``device`` parameter of the operato
For the CPU, the provided memory can be one contiguous buffer or a list of contiguous Tensors.
For the GPU, to avoid extra copy, the provided buffer must be contiguous. If you provide a list
of separate Tensors, there will be an additional copy made internally, consuming both memory
and bandwidth.)code", false);
and bandwidth.)code", false)
.AddOptionalArg("dtype", R"code(Input data type.
The operator will validate that the fetched data is of the provided type.
If the argument is omitted or ``DALIDataType.NO_TYPE`` is passed, the operator will infer
the type based on the provided data.
This argument will be required starting from DALI 2.0.)code", DALI_NO_TYPE);

} // namespace dali
17 changes: 17 additions & 0 deletions dali/pipeline/operator/builtin/external_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ class ExternalSource : public Operator<Backend>, virtual public BatchSizeProvide
blocking_(spec.GetArgument<bool>("blocking")),
no_copy_(spec.GetArgument<bool>("no_copy")),
device_id_(spec.GetArgument<int>("device_id")),
dtype_(spec.GetArgument<DALIDataType>("dtype")),
previous_dtype_(DALIDataType::DALI_NO_TYPE),
sync_worker_(device_id_, false) {
output_name_ = spec.Output(0);
sync_worker_.WaitForInit();
Expand Down Expand Up @@ -502,6 +504,19 @@ class ExternalSource : public Operator<Backend>, virtual public BatchSizeProvide
OperatorBase::max_batch_size_ >= static_cast<int>(batch.num_samples()),
make_string("Data list provided to ExternalSource needs to have batch_size <= ",
OperatorBase::max_batch_size_, ", found ", batch.num_samples(), " samples."));

DALI_ENFORCE(
dtype_ == DALI_NO_TYPE || dtype_ == batch.type(),
make_string("ExternalSource expected data of type ", TypeTable::GetTypeInfo(dtype_).name(),
" and got: ", batch.type_info().name()));

DALI_ENFORCE(previous_dtype_ == DALI_NO_TYPE || previous_dtype_ == batch.type(),
make_string("Type of the data fed to the external source has changed from the "
"previous iteration. Type in the previous iteration was ",
TypeTable::GetTypeInfo(previous_dtype_).name(),
" and the current type is ", batch.type_info().name(), "."));
previous_dtype_ = batch.type();

// Note: If we create a GPU source, we will need to figure
// out what stream we want to do this copy in. CPU we can
// pass anything as it is ignored.
Expand Down Expand Up @@ -537,6 +552,8 @@ class ExternalSource : public Operator<Backend>, virtual public BatchSizeProvide
bool blocking_ = true;
bool no_copy_ = false;
int device_id_;
DALIDataType dtype_;
DALIDataType previous_dtype_;

/*
* now it only indicates that there is data in the ExternalSource, in the future
Expand Down
37 changes: 33 additions & 4 deletions dali/python/nvidia/dali/external_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,15 @@ class ExternalSource():
output. If the list has fewer than ``num_outputs`` elements, only the first
outputs have the layout set, the rest of the outputs don't have a layout set.
`dtype` : `nvidia.dali.types.DALIDataType` or list/tuple thereof, optional
Input data type.
The operator will validate that the fetched data is of the provided type.
If the argument is omitted or :const:`DALIDataType.NO_TYPE` is passed, the operator will infer
the type from the provided data.
This argument will be required starting from DALI 2.0.
`cuda_stream` : optional, ``cudaStream_t`` or an object convertible to ``cudaStream_t``, such as ``cupy.cuda.Stream`` or ``torch.cuda.Stream``
The CUDA stream is used to copy data to the GPU or from a GPU source.
Expand Down Expand Up @@ -385,13 +394,14 @@ class ExternalSource():
"""

def __init__(
self, source=None, num_outputs=None, *, cycle=None, layout=None, name=None,
self, source=None, num_outputs=None, *, cycle=None, layout=None, dtype=None, name=None,
device="cpu", cuda_stream=None, use_copy_kernel=None, batch=None, parallel=None,
no_copy=None, prefetch_queue_depth=None, batch_info=None, **kwargs):
self._schema = _b.GetSchema("ExternalSource")
self._spec = _b.OpSpec("ExternalSource")
self._device = device
self._layout = layout
self._dtype = dtype
self._cuda_stream = cuda_stream
self._use_copy_kernel = use_copy_kernel

Expand All @@ -414,6 +424,8 @@ def __init__(
self._batch_info = batch_info

self._spec.AddArg("device", device)
if dtype is not None:
self._spec.AddArg("dtype", dtype)
for key, value in kwargs.items():
self._spec.AddArg(key, value)

Expand All @@ -434,7 +446,7 @@ def preserve(self):
return False

def __call__(
self, *, source=None, cycle=None, name=None, layout=None, cuda_stream=None,
self, *, source=None, cycle=None, name=None, layout=None, dtype=None, cuda_stream=None,
use_copy_kernel=None, batch=None, parallel=None, no_copy=None,
prefetch_queue_depth=None, batch_info=None, **kwargs):
""
Expand Down Expand Up @@ -525,6 +537,12 @@ def __call__(
else:
layout = self._layout

if self._dtype is not None:
if dtype is not None:
raise RuntimeError("``dtype`` already specified in constructor.")
else:
dtype = self._dtype

if self._cuda_stream is not None:
if cuda_stream is not None:
raise RuntimeError("``cuda_stream`` already specified in constructor.")
Expand Down Expand Up @@ -570,6 +588,16 @@ def __call__(
op_instance._layout = layout
else:
op_instance._layout = None

if dtype is not None:
if isinstance(dtype, (list, tuple)):
op_instance._dtype = dtype[i] if i < len(dtype) else nvidia.dali.types.DALIDataType.NO_TYPE
else:
op_instance._dtype = dtype
else:
op_instance._dtype = nvidia.dali.types.DALIDataType.NO_TYPE


op_instance._batch = batch

group.append(op_instance)
Expand All @@ -588,6 +616,7 @@ def __call__(
op_instance._group = _ExternalSourceGroup(
callback, source_desc, False, [op_instance], **group_common_kwargs)
op_instance._layout = layout
op_instance._dtype = dtype
op_instance._batch = batch
op_instance.generate_outputs()

Expand Down Expand Up @@ -615,7 +644,7 @@ def _has_external_source(pipeline):


def external_source(source = None, num_outputs = None, *, cycle = None, name = None, device = "cpu", layout = None,
cuda_stream = None, use_copy_kernel = None, batch = True, **kwargs):
dtype = None, cuda_stream = None, use_copy_kernel = None, batch = True, **kwargs):
"""Creates a data node which is populated with data from a Python source.
The data can be provided by the ``source`` function or iterable, or it can be provided by
``pipeline.feed_input(name, data, layout, cuda_stream)`` inside ``pipeline.iter_setup``.
Expand Down Expand Up @@ -645,7 +674,7 @@ def external_source(source = None, num_outputs = None, *, cycle = None, name = N
"``external_source`` nodes.")

op = ExternalSource(device = device, num_outputs = num_outputs, source = source,
cycle = cycle, layout = layout, cuda_stream = cuda_stream,
cycle = cycle, layout = layout, dtype = dtype, cuda_stream = cuda_stream,
use_copy_kernel = use_copy_kernel, batch = batch, **kwargs)
return op(name = name)

Expand Down
41 changes: 41 additions & 0 deletions dali/test/python/test_external_source_dali.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from nvidia.dali.pipeline import Pipeline
from test_utils import check_batch
from nose_utils import raises
from nvidia.dali.types import DALIDataType

def build_src_pipe(device, layout = None):
if layout is None:
Expand Down Expand Up @@ -189,3 +190,43 @@ def test_epoch_idx():
yield _test_epoch_idx, batch_size, epoch_size, batch_cb, batch_info, True
sample_cb = SampleCb(batch_size, epoch_size)
yield _test_epoch_idx, batch_size, epoch_size, sample_cb, None, False


def test_dtype_arg():
    # A correct ``dtype`` hint should be accepted: building and running the
    # pipeline with matching uint8 data must not raise.
    batch_size = 2
    sample = np.ones((120, 120, 3), dtype=np.uint8)
    pipe = Pipeline(batch_size, 1, 0)
    es_out = fn.external_source(source=[[sample] * batch_size], device='cpu',
                                dtype=DALIDataType.UINT8)
    pipe.set_outputs(es_out)
    pipe.build()
    pipe.run()


@raises(RuntimeError, glob="ExternalSource expected data of type uint8 and got: float")
def test_incorrect_dtype_arg():
    # Declaring UINT8 while feeding float32 data must be rejected at run time.
    batch_size = 2
    sample = np.ones((120, 120, 3), dtype=np.float32)
    pipe = Pipeline(batch_size, 1, 0)
    es_out = fn.external_source(source=[[sample] * batch_size], device='cpu',
                                dtype=DALIDataType.UINT8)
    pipe.set_outputs(es_out)
    pipe.build()
    pipe.run()

@raises(RuntimeError, glob="Type of the data fed to the external source has changed from the previous iteration. "
        "Type in the previous iteration was float and the current type is uint8.")
def test_changing_dtype():
    # With no explicit ``dtype``, switching the input type between
    # iterations (float32 -> uint8) must raise on the second run.
    batch_size = 2
    float_batch = [np.ones((120, 120, 3), dtype=np.float32)] * batch_size
    uint8_batch = [np.ones((120, 120, 3), dtype=np.uint8)] * batch_size
    pipe = Pipeline(batch_size, 1, 0)
    es_out = fn.external_source(source=[float_batch, uint8_batch], device='cpu')
    pipe.set_outputs(es_out)
    pipe.build()
    pipe.run()
    pipe.run()

0 comments on commit 35d95ec

Please sign in to comment.