Code review fixes
Signed-off-by: Joaquin Anton Guirao <[email protected]>
jantonguirao committed Dec 9, 2024
1 parent a89b147 commit 8f55c0b
Showing 6 changed files with 246 additions and 155 deletions.
48 changes: 1 addition & 47 deletions dali/python/nvidia/dali/plugin/pytorch/__init__.py
@@ -34,59 +34,13 @@
from . import fn # noqa: F401
from . import proxy # noqa: F401

from nvidia.dali.plugin.pytorch.torch_utils import to_torch_type
from nvidia.dali.plugin.pytorch.torch_utils import to_torch_type, feed_ndarray
from nvidia.dali.plugin.pytorch._torch_function import TorchPythonFunction as TorchPythonFunction

_internal._adjust_operator_module(TorchPythonFunction, sys.modules[__name__], [])

ops._wrap_op(TorchPythonFunction, "fn", __name__)


def feed_ndarray(
dali_tensor: Union[TensorCPU, TensorGPU, TensorListCPU, TensorListGPU],
arr: torch.Tensor,
cuda_stream: Union[torch.cuda.Stream, Any, None] = None,
) -> torch.Tensor:
"""
Copy contents of DALI tensor to PyTorch's Tensor.
Parameters
----------
dali_tensor : nvidia.dali.backend.TensorCPU or nvidia.dali.backend.TensorGPU
Tensor from which to copy
arr : torch.Tensor
Destination of the copy
cuda_stream : torch.cuda.Stream, cudaStream_t or any value that can be cast to cudaStream_t.
CUDA stream to be used for the copy
(if not provided, an internal user stream will be selected)
In most cases, using pytorch's current stream is expected (for example,
if we are copying to a tensor allocated with torch.zeros(...))
"""
dali_type = to_torch_type[dali_tensor.dtype]

assert dali_type == arr.dtype, (
"The element type of DALI Tensor/TensorList"
" doesn't match the element type of the target PyTorch Tensor: "
"{} vs {}".format(dali_type, arr.dtype)
)
assert dali_tensor.shape() == list(
arr.size()
), "Shapes do not match: DALI tensor has size {0}, but PyTorch Tensor has size {1}".format(
dali_tensor.shape(), list(arr.size())
)

non_blocking = cuda_stream is not None
cuda_stream = types._raw_cuda_stream_ptr(cuda_stream)

# turn raw int to a c void pointer
c_type_pointer = ctypes.c_void_p(arr.data_ptr())
if isinstance(dali_tensor, (TensorGPU, TensorListGPU)):
dali_tensor.copy_to_external(c_type_pointer, cuda_stream, non_blocking=non_blocking)
else:
dali_tensor.copy_to_external(c_type_pointer)
return arr


class DALIGenericIterator(_DaliBaseIterator):
"""
General DALI iterator for PyTorch. It can return any number of
88 changes: 45 additions & 43 deletions dali/python/nvidia/dali/plugin/pytorch/proxy/__init__.py
@@ -15,14 +15,15 @@
import torch
import torch.multiprocessing as mp
from torch.utils import data as torchdata
from torch.utils.data._utils.collate import collate
from nvidia.dali.backend import TensorGPU, TensorListCPU, TensorListGPU
from torch.utils.data._utils.collate import default_collate_fn_map
from nvidia.dali.backend import TensorGPU
from nvidia.dali import types, Pipeline
from nvidia.dali.external_source import ExternalSource
import threading
from queue import Empty
from nvidia.dali.plugin.pytorch.torch_utils import to_torch_tensor

from nvidia.dali.plugin.pytorch.torch_utils import to_torch_type, feed_ndarray
import tree
import warnings

def _external_source_node_names(pipeline):
if not pipeline._py_graph_built:
@@ -150,13 +151,31 @@ def next_outputs(self):
outputs = self.pipe.outputs()
torch.cuda.nvtx.range_pop()

# Return information about the iteration, together with the data
processed_outputs = tuple(
[to_torch_tensor(output, device_id=self.pipe.device_id) for output in outputs]
)
is_exec_dynamic = self.pipe.exec_dynamic
if not is_exec_dynamic:
processed_outputs = []
for output in outputs:
tensor = output.as_tensor()
torch_dtype = to_torch_type[tensor.dtype]
if isinstance(tensor, TensorGPU):
torch_device = torch.device("cuda", self.pipe.device_id)
else:
torch_device = torch.device("cpu")
processed_output = torch.empty(
tensor.shape(),
dtype=torch_dtype,
device=torch_device,
)
cuda_stream = torch.cuda.current_stream(device=torch_device) if isinstance(tensor, TensorGPU) else None
feed_ndarray(tensor, processed_output, cuda_stream=cuda_stream)
processed_outputs.append(processed_output)
processed_outputs = tuple(processed_outputs)
else:
processed_outputs = tuple([torch.from_dlpack(output.as_tensor()) for output in outputs])
return (info, processed_outputs)
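
Read as a whole, the new conversion logic is easier to follow as a standalone helper. The sketch below is illustrative only; the helper name dali_output_to_torch and the pipe argument are not part of the commit. It mirrors the two branches above: an explicit copy through feed_ndarray for the classic executor, and a zero-copy DLPack wrap when exec_dynamic is enabled.

import torch
from nvidia.dali.backend import TensorGPU
from nvidia.dali.plugin.pytorch.torch_utils import to_torch_type, feed_ndarray

def dali_output_to_torch(output, pipe):
    # `output` is one element of pipe.outputs(); `pipe` is a built DALI Pipeline.
    tensor = output.as_tensor()
    if pipe.exec_dynamic:
        # Dynamic executor: wrap the DALI tensor through DLPack, no copy.
        return torch.from_dlpack(tensor)
    # Classic executor: allocate a matching torch tensor and copy into it.
    if isinstance(tensor, TensorGPU):
        device = torch.device("cuda", pipe.device_id)
        stream = torch.cuda.current_stream(device=device)
    else:
        device = torch.device("cpu")
        stream = None
    out = torch.empty(tensor.shape(), dtype=to_torch_type[tensor.dtype], device=device)
    feed_ndarray(tensor, out, cuda_stream=stream)
    return out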

def get_outputs(self, req_info):
def get_outputs(self, pipe_out_ref: DALIPipelineOutputRef):
req_info = pipe_out_ref.info
req_outputs = None
# If the data was already read, just return it (and clear the cache entry)
if req_info in self.cache_outputs:
@@ -235,7 +254,7 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, tb):
self.stop_thread()
if exc_type is not None:
print(f"An exception occurred: {exc_value}")
warnings.warn(f"An exception occurred: {exc_value}", category=UserWarning)
return False # Return False to propagate exceptions


@@ -251,17 +270,12 @@ def _collate_pipeline_run_ref_fn(pipe_out, *, collate_fn_map=None):
assert proxy == elem.proxy
for idx, input_ref in enumerate(elem.inputs):
inputs[idx].append(input_ref)
return proxy.schedule_batch(inputs)
ret = proxy.schedule_batch(inputs)
return ret


def _custom_collate(batch):
"""
Subscribe a special collate function for PipelineRunRef, that handles the scheduling
of the iteration on the fly
"""
collate_fn_map = torchdata._utils.collate.default_collate_fn_map
collate_fn_map.update({PipelineRunRef: _collate_pipeline_run_ref_fn})
return collate(batch, collate_fn_map=collate_fn_map)
# In-place modify `default_collate_fn_map` to handle PipelineRunRef
default_collate_fn_map.update({PipelineRunRef: _collate_pipeline_run_ref_fn})
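
The line above relies on PyTorch's extensible collate machinery: default_collate dispatches on element type through default_collate_fn_map, so registering PipelineRunRef there lets the stock collate function schedule the DALI iteration. A minimal, self-contained illustration of the same registration pattern, using a made-up MyRef type in place of PipelineRunRef:

import torch
from torch.utils.data import default_collate
from torch.utils.data._utils.collate import default_collate_fn_map

class MyRef:
    def __init__(self, value):
        self.value = value

def _collate_myref_fn(batch, *, collate_fn_map=None):
    # Receives the whole list of MyRef elements for the batch and combines them.
    return torch.tensor([elem.value for elem in batch])

default_collate_fn_map.update({MyRef: _collate_myref_fn})

print(default_collate([MyRef(1), MyRef(2), MyRef(3)]))  # tensor([1, 2, 3])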


class DataLoader(torchdata.dataloader.DataLoader):
@@ -270,47 +284,35 @@ class DataLoader(torchdata.dataloader.DataLoader):
processing asynchronously with regards to the training.
"""

class DALIMultiProcessingDataLoaderIter(torchdata.dataloader._MultiProcessingDataLoaderIter):
class _Iter(torchdata.dataloader._MultiProcessingDataLoaderIter):
"""
Data loader iterator used by the DALI proxy data loader
"""

def __init__(self, loader):
super().__init__(loader)
self.loader = loader

def _next_data(self):
data = super()._next_data()
if not hasattr(data, "__iter__"):
print(
warnings.warn(
"Warning: Non iterable returned from dataloader. Please "
" review the code, since it usually indicates a bug in the pipeline."
" review the code, since it usually indicates a bug in the pipeline.",
category=UserWarning
)
data = [data]
for data_idx, data_elem in enumerate(data):
# If loader returns a dictionary the iterator iterates over its keys.
# We need to access a value. Probably need to address more cases.
if isinstance(data, dict):
if isinstance(data[data_elem], DALIPipelineOutputRef):
data[data_elem] = self.loader.dali_server.get_outputs(data[data_elem].info)
elif isinstance(data_elem, DALIPipelineOutputRef):
data[data_idx] = self.loader.dali_server.get_outputs(data_elem.info)
if self.loader.dali_server.thread is None:
raise RuntimeError("DALI server is not running")
data = tree.map_structure(
lambda entry:
self.loader.dali_server.get_outputs(entry) if isinstance(entry, DALIPipelineOutputRef) else entry,
data
)
return data
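
For context, the tree module imported at the top of this file is the dm-tree package; tree.map_structure applies a function to every leaf of an arbitrarily nested structure (lists, tuples, dicts) while preserving its shape, which is what lets _next_data replace DALIPipelineOutputRef entries wherever they appear in the batch. A minimal example:

import tree  # dm-tree

nested = {"a": [1, 2], "b": (3, {"c": 4})}
doubled = tree.map_structure(lambda x: x * 2, nested)
# doubled == {"a": [2, 4], "b": (6, {"c": 8})}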

def __init__(self, dali_server, *args, **kwargs):
if "collate_fn" in kwargs and kwargs["collate_fn"] is not None:
print(
"Warning: Make sure to handle PipelineRunRef when providing"
" a custom collate_fn (see collate_pipeline_run_ref_fn)"
)
else:
kwargs["collate_fn"] = _custom_collate
super().__init__(*args, **kwargs)
self.dali_server = dali_server

def _get_iterator(self) -> "_BaseDataLoaderIter":
if self.num_workers == 0:
return torchdata.dataloader._SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return DataLoader.DALIMultiProcessingDataLoaderIter(self)
return DataLoader._Iter(self)
58 changes: 30 additions & 28 deletions dali/python/nvidia/dali/plugin/pytorch/torch_utils.py
@@ -15,7 +15,9 @@
import torch
import ctypes
from nvidia.dali import types
from nvidia.dali.tensors import TensorListGPU, TensorListCPU, TensorGPU
from nvidia.dali.tensors import TensorListGPU, TensorListCPU, TensorGPU, TensorCPU
from typing import Union, Optional

Check notice (Code scanning / CodeQL): Unused import. Import of 'Optional' is not used.
from typing import Any, Dict, List

Check notice (Code scanning / CodeQL): Unused import. Import of 'Dict' is not used. Import of 'List' is not used.

to_torch_type = {
types.DALIDataType.FLOAT: torch.float32,
@@ -30,46 +32,46 @@
}


def to_torch_tensor(tensor_or_tl, device_id=0):
def feed_ndarray(
dali_tensor: Union[TensorCPU, TensorGPU, TensorListCPU, TensorListGPU],
arr: torch.Tensor,
cuda_stream: Union[torch.cuda.Stream, Any, None] = None,
) -> torch.Tensor:
"""
Copy contents of DALI tensor to PyTorch's Tensor.
Parameters
----------
`tensor_or_tl` : TensorGPU or TensorListGPU
`arr` : torch.Tensor
dali_tensor : nvidia.dali.backend.TensorCPU or nvidia.dali.backend.TensorGPU
Tensor from which to copy
arr : torch.Tensor
Destination of the copy
`cuda_stream` : torch.cuda.Stream, cudaStream_t or any value that can be cast to cudaStream_t.
cuda_stream : torch.cuda.Stream, cudaStream_t or any value that can be cast to cudaStream_t.
CUDA stream to be used for the copy
(if not provided, an internal user stream will be selected)
In most cases, using pytorch's current stream is expected (for example,
if we are copying to a tensor allocated with torch.zeros(...))
"""
if isinstance(tensor_or_tl, (TensorListGPU, TensorListCPU)):
dali_tensor = tensor_or_tl.as_tensor()
else:
dali_tensor = tensor_or_tl

if isinstance(dali_tensor, (TensorGPU)):
torch_device = torch.device("cuda", device_id)
else:
torch_device = torch.device("cpu")
dali_type = to_torch_type[dali_tensor.dtype]

out_torch = torch.empty(
dali_tensor.shape(),
dtype=to_torch_type[dali_tensor.dtype],
device=torch_device,
assert dali_type == arr.dtype, (
"The element type of DALI Tensor/TensorList"
" doesn't match the element type of the target PyTorch Tensor: "
"{} vs {}".format(dali_type, arr.dtype)
)
assert dali_tensor.shape() == list(
arr.size()
), "Shapes do not match: DALI tensor has size {0}, but PyTorch Tensor has size {1}".format(
dali_tensor.shape(), list(arr.size())
)

non_blocking = cuda_stream is not None
cuda_stream = types._raw_cuda_stream_ptr(cuda_stream)

# turn raw int to a c void pointer
c_type_pointer = ctypes.c_void_p(out_torch.data_ptr())
if isinstance(dali_tensor, (TensorGPU)):
non_blocking = True
cuda_stream = torch.cuda.current_stream(device=torch_device)
cuda_stream = types._raw_cuda_stream(cuda_stream)
stream = None if cuda_stream is None else ctypes.c_void_p(cuda_stream)
tensor_or_tl.copy_to_external(c_type_pointer, stream, non_blocking)
c_type_pointer = ctypes.c_void_p(arr.data_ptr())
if isinstance(dali_tensor, (TensorGPU, TensorListGPU)):
dali_tensor.copy_to_external(c_type_pointer, cuda_stream, non_blocking=non_blocking)
else:
tensor_or_tl.copy_to_external(c_type_pointer)

return out_torch
dali_tensor.copy_to_external(c_type_pointer)
return arr
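
A short usage sketch for the relocated feed_ndarray (illustrative, not part of the commit; it assumes a built GPU DALI Pipeline named pipe): allocate a torch tensor with matching shape and dtype, then copy on PyTorch's current stream.

import torch
from nvidia.dali.plugin.pytorch.torch_utils import to_torch_type, feed_ndarray

outputs = pipe.outputs()                 # pipe is assumed to exist and to produce GPU outputs
dali_tensor = outputs[0].as_tensor()
dst = torch.empty(
    dali_tensor.shape(),
    dtype=to_torch_type[dali_tensor.dtype],
    device=torch.device("cuda", pipe.device_id),
)
feed_ndarray(dali_tensor, dst, cuda_stream=torch.cuda.current_stream(device=dst.device))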