Commit a89b147: Add tests

Signed-off-by: Joaquin Anton Guirao <[email protected]>
jantonguirao committed Dec 5, 2024
1 parent 49ad5a1
Showing 2 changed files with 176 additions and 13 deletions.
30 changes: 17 additions & 13 deletions dali/python/nvidia/dali/plugin/pytorch/proxy/__init__.py
@@ -19,7 +19,6 @@
 from nvidia.dali.backend import TensorGPU, TensorListCPU, TensorListGPU
 from nvidia.dali import types, Pipeline
 from nvidia.dali.external_source import ExternalSource
-import ctypes
 import threading
 from queue import Empty
 from nvidia.dali.plugin.pytorch.torch_utils import to_torch_tensor
@@ -60,8 +59,7 @@ def __init__(self, proxy, inputs):
                 f"Unexpected number of inputs. Expected: {self.input_names}, got: {inputs}"
             )

-
-class DALIProxy:
+class _DALIProxy:
     def __init__(self, input_names, send_q):
         self.input_names = input_names
         # Shared queue with the server
@@ -75,7 +73,8 @@ def __init__(self, input_names, send_q):
     def worker_id(self):
         """Getter for 'worker_id'"""
         if self._worker_id is None:
-            self._worker_id = torchdata.get_worker_info().id
+            worker_info = torchdata.get_worker_info()
+            self._worker_id = worker_info.id if worker_info else threading.get_ident()
         return self._worker_id

     def schedule_batch(self, inputs):
@@ -98,8 +97,8 @@ def __call__(self, *inputs):
             )
         return PipelineRunRef(self, inputs)

-
 class DALIServer:
+
     def __init__(self, pipeline, input_names=None):
         """
         Initializes a new DALI proxy instance.
@@ -110,6 +109,7 @@ def __init__(self, pipeline, input_names=None):
         """
         assert isinstance(pipeline, Pipeline), f"Expected an NVIDIA DALI pipeline, got: {pipeline}"
         self.pipe = pipeline
+
         self.pipe_input_names = _external_source_node_names(self.pipe)
         if len(self.pipe_input_names) == 0:
             raise RuntimeError("The provided pipeline doesn't have any inputs")
@@ -133,11 +133,10 @@ def __init__(self, pipeline, input_names=None):
         self.thread = None
         # Cache
         self.cache_outputs = dict()
-        self.cache_inputs = dict()

     @property
     def proxy(self):
-        return DALIProxy(self.input_names, self.send_q)
+        return _DALIProxy(self.input_names, self.send_q)

     def next_outputs(self):
         # Get the information about the order of execution, so that we know which one is
@@ -163,15 +162,14 @@ def get_outputs(self, req_info):
         if req_info in self.cache_outputs:
             req_outputs = self.cache_outputs[req_info]
             del self.cache_outputs[req_info]
-            del self.cache_inputs[req_info]
+
         else:
             info = None
             # If not the data we are looking for, store it and keep processing until we find it
             while req_info != info:
                 info, processed_outputs = self.next_outputs()
                 if info == req_info:
                     req_outputs = processed_outputs
-                    del self.cache_inputs[req_info]
                 else:
                     self.cache_outputs[info] = processed_outputs
         # Unpack single element tuples
@@ -184,19 +182,17 @@ def thread_fn(self):
         """
         Asynchronous DALI thread that gets iteration data from the queue and schedules it
         for execution
         """
-        self.pipe.build()  # just in case

         while not self.thread_stop_event.is_set():
             try:
                 torch.cuda.nvtx.range_push("send_q.get")
                 info, inputs = self.send_q.get(timeout=5)
                 torch.cuda.nvtx.range_pop()
-                self.cache_inputs[info] = inputs
             except mp.TimeoutError:
                 continue
             except Empty:
                 continue
-            torch.cuda.nvtx.range_push(f"order_q.put {info}")
-            self.order_q.put(info)
-            torch.cuda.nvtx.range_pop()

             torch.cuda.nvtx.range_push(f"feed_input {info}")
             for idx, input_name in enumerate(self.input_names):
@@ -207,6 +203,10 @@ def thread_fn(self):
             self.pipe.schedule_run()
             torch.cuda.nvtx.range_pop()

+            torch.cuda.nvtx.range_push(f"order_q.put {info}")
+            self.order_q.put(info)
+            torch.cuda.nvtx.range_pop()
+
     def start_thread(self):
         """
         Starts the DALI pipeline thread
@@ -230,9 +230,13 @@ def stop_thread(self):

     def __enter__(self):
         self.start_thread()
+        return self

     def __exit__(self, exc_type, exc_value, tb):
         self.stop_thread()
+        if exc_type is not None:
+            print(f"An exception occurred: {exc_value}")
+        return False  # Return False to propagate exceptions

def _collate_pipeline_run_ref_fn(pipe_out, *, collate_fn_map=None):
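Taken together, the changes above make DALIServer a proper context manager (__enter__ now returns the server instance) and let _DALIProxy.worker_id fall back to threading.get_ident() when torchdata.get_worker_info() returns None, i.e. when the proxy callable runs in the main process rather than in a DataLoader worker. Below is a minimal usage sketch of the resulting API, based on the tests added in this commit; my_pipeline and the dataset path are hypothetical placeholders, not part of the commit.

# A sketch of the intended usage, assuming only the API shown in this diff.
import numpy as np
import torchvision.datasets as datasets
from nvidia.dali.plugin.pytorch import proxy as dali_proxy

def read_filepath(path):
    # Feed encoded file paths to the pipeline; DALI does the reading and decoding
    return np.frombuffer(path.encode(), dtype=np.int8)

pipe = my_pipeline(batch_size=4, num_threads=3, device_id=0)  # hypothetical pipeline

with dali_proxy.DALIServer(pipe) as dali_server:  # __enter__ now returns the server
    dataset = datasets.ImageFolder(
        "/path/to/images", transform=dali_server.proxy, loader=read_filepath
    )
    loader = dali_proxy.DataLoader(dali_server, dataset, batch_size=4, num_workers=0)
    for data, target in loader:
        pass  # outputs come from the DALI pipeline, not from torchvision transforms

With num_workers=0 there is no DataLoader worker process, which is exactly the case the new threading.get_ident() fallback in _DALIProxy.worker_id covers.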
159 changes: 159 additions & 0 deletions dali/test/python/test_dali_proxy.py
@@ -0,0 +1,159 @@
from nvidia.dali import pipeline_def, fn, types
import numpy as np
import os
import threading
from nose2.tools import params
from nose_utils import attr

def read_file(path):
    return np.fromfile(path, dtype=np.uint8)

def read_filepath(path):
    return np.frombuffer(path.encode(), dtype=np.int8)

dali_extra = os.environ['DALI_EXTRA_PATH']
jpeg = os.path.join(dali_extra, 'db', 'single', 'jpeg')
jpeg_113 = os.path.join(jpeg, '113')
test_files = [
    os.path.join(jpeg_113, f)
    for f in ['snail-4291306_1280.jpg', 'snail-4345504_1280.jpg', 'snail-4368154_1280.jpg']
]
test_input_filenames = [read_filepath(fname) for fname in test_files]

@pipeline_def(exec_dynamic=True)
def pipe_decoder(device):
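    # "images" is fed externally by DALIServer via feed_input. blocking=True makes a
    # pipeline run wait until its input is available, and no_copy=True avoids copying
    # the fed buffers.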
    filepaths = fn.external_source(name="images", no_copy=True, blocking=True)
    images = fn.io.file.read(filepaths)
    decoder_device = 'mixed' if device == 'gpu' else 'cpu'
    images = fn.decoders.image(images, device=decoder_device, output_type=types.RGB)
    images = fn.crop(images, crop=(224, 224))
    return images

@attr("pytorch")
@params(("cpu",), ("gpu",))
def test_dali_proxy_demo_basic_communication(device, debug=False):
    # This test illustrates the inter-process/inter-thread communication that takes place
    # when using DALI proxy. The code is not meant to be written like this by a user;
    # see `test_dali_proxy_torch_data_loader` below for the intended usage of the API.

    import torch
    from nvidia.dali.plugin.pytorch import proxy as dali_proxy

    threads = []
    batch_size = 4
    num_threads = 3
    device_id = 0
    nworkers = 3
    niter = 5
    pipe = pipe_decoder(
        device, batch_size=batch_size, num_threads=num_threads, device_id=device_id
    )

    # Run the server (and clean up on exit)
    with dali_proxy.DALIServer(pipe) as dali_server:

        # Create a bunch of worker threads that call the proxy callable sample by sample,
        # then call the collate function directly, which triggers a pipeline run on the server
        for _ in range(nworkers):

            def thread_fn(proxy_pipe_call):
                for _ in range(niter):
                    # The proxy call is run per sample
                    pipe_run_refs = [
                        proxy_pipe_call(test_input_filenames[i % len(test_input_filenames)])
                        for i in range(batch_size)
                    ]
                    # This forms a batch and sends it to DALI for execution
                    dali_proxy._collate_pipeline_run_ref_fn(pipe_run_refs)

            thread = threading.Thread(target=thread_fn, args=(dali_server.proxy,))
            threads.append(thread)
            thread.start()

        collected_data_info = {}
        for thread in threads:
            collected_data_info[thread.ident] = [None for _ in range(niter)]

        # On the main thread, we can query the server for new outputs
        for _ in range(nworkers * niter):
            info, outputs = dali_server.next_outputs()
            worker_id = info[0]
            data_idx = info[1]
            if debug:
                print(
                    f"worker_id={worker_id}, data_idx={data_idx}, "
                    f"data_shape={outputs[0].shape}"
                )
            assert worker_id in collected_data_info
            collected_data_info[worker_id][data_idx] = outputs
            assert len(outputs) == 1
            np.testing.assert_equal([batch_size, 224, 224, 3], outputs[0].shape)

        for thread in threads:
            thread.join()

    # Make sure we received all the data we expected
    for thread in threads:
        for data_idx in range(niter):
            (data,) = collected_data_info[thread.ident][data_idx]
            assert data is not None
            expected_device = (
                torch.device(type='cuda', index=device_id)
                if device == 'gpu'
                else torch.device('cpu')
            )
            np.testing.assert_equal(expected_device, data.device)


@pipeline_def
def rn50_train_pipe(dali_device="gpu"):
    rng = fn.random.coin_flip(probability=0.5)

    filepaths = fn.external_source(name="images", no_copy=True, blocking=True)
    jpegs = fn.io.file.read(filepaths)
    if dali_device == "gpu":
        decoder_device = "mixed"
        resize_device = "gpu"
    else:
        decoder_device = "cpu"
        resize_device = "cpu"

    images = fn.decoders.image_random_crop(
        jpegs,
        device=decoder_device,
        output_type=types.RGB,
        random_aspect_ratio=[0.75, 4.0 / 3.0],
        random_area=[0.08, 1.0],
    )

    images = fn.resize(
        images,
        device=resize_device,
        size=[224, 224],
        interp_type=types.INTERP_LINEAR,
        antialias=False,
    )

    # Make sure that from this point on we are processing on the GPU, regardless of
    # the dali_device parameter
    images = images.gpu()

    images = fn.flip(images, horizontal=rng)

    output = fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout='CHW',
        crop=(224, 224),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
    )
    return output


@attr("pytorch")
@params(("cpu",), ("gpu",))
def test_dali_proxy_torch_data_loader(device, debug=False):
    # Shows how DALI proxy is used in practice with a PyTorch data loader

    from nvidia.dali.plugin.pytorch import proxy as dali_proxy
    import torchvision.datasets as datasets

    batch_size = 4
    num_threads = 3
    device_id = 0
    nworkers = 4
    pipe = rn50_train_pipe(
        device, batch_size=batch_size, num_threads=num_threads, device_id=device_id
    )

    # Run the server (it also cleans up on scope exit)
    with dali_proxy.DALIServer(pipe) as dali_server:

        dataset = datasets.ImageFolder(
            jpeg,
            transform=dali_server.proxy,
            loader=read_filepath,
        )

        loader = dali_proxy.DataLoader(
            dali_server,
            dataset,
            batch_size=batch_size,
            num_workers=nworkers,
            drop_last=True,
        )

        for next_input, next_target in loader:
            np.testing.assert_equal([batch_size, 3, 224, 224], next_input.shape)
            np.testing.assert_equal([batch_size], next_target.shape)

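A note on the collate hook: the collate_fn_map keyword in _collate_pipeline_run_ref_fn's signature follows the convention of PyTorch's default_collate, which dispatches to per-type collate handlers. The sketch below shows how such a handler is typically registered; it assumes PyTorch's private torch.utils.data._utils.collate module and that PipelineRunRef is importable from the proxy module, and it is not the registration code of this commit.

# A sketch, not DALI's actual registration code: routing PipelineRunRef samples
# to the custom collate function through PyTorch's per-type collate dispatch.
from torch.utils.data._utils.collate import collate, default_collate_fn_map
from nvidia.dali.plugin.pytorch import proxy as dali_proxy

# Extend a copy of the default type-to-collate-function map
fn_map = dict(default_collate_fn_map)
fn_map[dali_proxy.PipelineRunRef] = dali_proxy._collate_pipeline_run_ref_fn

def collate_with_dali(batch):
    # Per-sample PipelineRunRef objects are grouped into one batched pipeline run
    return collate(batch, collate_fn_map=fn_map)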