TensorRT Initial commit #22131

Merged: 40 commits, Sep 16, 2022

Changes from 1 commit

Commits:
1109810
Initial commit
azhurkevich Jul 1, 2022
dcccde4
Format
yeandy Jul 8, 2022
b29b9f6
Fix read of onnx file
yeandy Jul 8, 2022
f9517d6
TensorRT Object Detection example
azhurkevich Jul 8, 2022
b6b1798
Fix copyright; Add back pycuda
yeandy Jul 13, 2022
23c14f1
Addressing noted and resolved comments
azhurkevich Jul 13, 2022
4d507b7
Addressing noted and resolved comments
azhurkevich Jul 13, 2022
bb099d5
Addressing noted and resolved comments
azhurkevich Jul 13, 2022
e509716
Keeping mem alloc out of inference call + CUDA Python replaced PyCUDA…
azhurkevich Jul 14, 2022
d2e5694
Addressing comments
azhurkevich Jul 15, 2022
d61688b
Fixing header
azhurkevich Jul 18, 2022
8055026
Format and lint
yeandy Jul 18, 2022
4b6b133
Format and lint tensorrt example
yeandy Jul 18, 2022
b7157d4
Sort imports
yeandy Jul 18, 2022
918bf95
Format docstrings
yeandy Jul 19, 2022
ff83c38
Addressing some new comments
azhurkevich Jul 26, 2022
382b885
Addressing some new comments
azhurkevich Jul 26, 2022
a79e91f
Refactor tensorrt imports
yeandy Aug 1, 2022
6b94cc4
Address PR comment
yeandy Aug 2, 2022
caea4c2
Merge pull request #1 from yeandy/tensorrt_runinference
azhurkevich Aug 2, 2022
1b27035
Merge branch 'master' into tensorrt_runinference
yeandy Aug 2, 2022
5d52984
Adding stream sync
azhurkevich Aug 16, 2022
93a1d98
Adding stream sync
azhurkevich Aug 16, 2022
ff1433e
Getting rid of optional argument validation
azhurkevich Aug 18, 2022
bea148f
Setting experimental status
azhurkevich Aug 24, 2022
ec0f276
Raising exception intead of exiting
azhurkevich Aug 30, 2022
93eb45f
Update sdks/python/apache_beam/examples/inference/README.md
azhurkevich Aug 30, 2022
70fd62f
Update sdks/python/apache_beam/examples/inference/README.md
azhurkevich Aug 30, 2022
8ac1d10
gradle task for tensor RT example E2E test
AnandInguva Aug 10, 2022
4fabc51
Update input path and docker image
AnandInguva Sep 6, 2022
700e528
Add no_use_multiple_sdk_containers
AnandInguva Sep 6, 2022
811ee31
Add disk_size_gb for the example to work on Dataflow with Custom Image
AnandInguva Sep 7, 2022
72f89a7
Add inference test to Python 3.8. Also add TODO for py37, py39
AnandInguva Sep 7, 2022
59b8784
Apply suggestions from code review
AnandInguva Sep 7, 2022
a9b0417
Merge pull request #2 from AnandInguva/tensort-test
AnandInguva Sep 7, 2022
89eeadd
Fix lint issues
yeandy Sep 8, 2022
fbacd43
Guard TensorRT context with a reentrant lock.
tvalentyn Sep 14, 2022
a994b80
Add dockerfile
yeandy Sep 16, 2022
bee8af2
Apache license
damccorm Sep 16, 2022
2c417db
Apache license
damccorm Sep 16, 2022
136 changes: 81 additions & 55 deletions sdks/python/apache_beam/ml/inference/tensorrt_inference.py
@@ -20,8 +20,7 @@

import logging
import numpy as np
-import pycuda.autoinit # pylint: disable=unused-import
-import pycuda.driver as cuda
+from cuda import cuda
import sys
import tensorrt as trt
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple
@@ -80,9 +79,70 @@ def _validate_inference_args(inference_args):
        'engines do not need extra arguments in their execute_v2() call.')


+class TensorRTEngine:
+  def __init__(self, engine: trt.ICudaEngine):
+    """Implementation of the TensorRTEngine class which handles
+    allocations associated with TensorRT engine.
+
+    Example Usage::
+
+      TensorRTEngine(engine)
+
+    Args:
+      engine: trt.ICudaEngine object that contains TensorRT engine
+    """
+    self.engine = engine
+    self.context = engine.create_execution_context()
+    self.inputs = []
+    self.outputs = []
+    self.gpu_allocations = []
+    self.cpu_allocations = []
+
+    # Setup I/O bindings
+    for i in range(self.engine.num_bindings):
+      is_input = False
+      if self.engine.binding_is_input(i):
+        is_input = True
+      name = self.engine.get_binding_name(i)
+      dtype = self.engine.get_binding_dtype(i)
+      shape = self.engine.get_binding_shape(i)
+      if is_input:
+        batch_size = shape[0]
+      size = np.dtype(trt.nptype(dtype)).itemsize
+      for s in shape:
+        size *= s
+      err, allocation = cuda.cuMemAlloc(size)
+      binding = {
+          'index': i,
+          'name': name,
+          'dtype': np.dtype(trt.nptype(dtype)),
+          'shape': list(shape),
+          'allocation': allocation,
+          'size': size
+      }
+      self.gpu_allocations.append(allocation)
+      if self.engine.binding_is_input(i):
+        self.inputs.append(binding)
+      else:
+        self.outputs.append(binding)
+
+    assert self.context
+    assert batch_size > 0
+    assert len(self.inputs) > 0
+    assert len(self.outputs) > 0
+    assert len(self.gpu_allocations) > 0
+
+    for output in self.outputs:
+      self.cpu_allocations.append(np.zeros(output['shape'], output['dtype']))
+
+  def get_engine_attrs(self):
+    """Returns TensorRT engine attributes."""
+    return self.engine, self.context, self.inputs, self.outputs, self.gpu_allocations, self.cpu_allocations


class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
                                              PredictionResult,
-                                             trt.ICudaEngine]):
+                                             TensorRTEngine]):
  def __init__(self, min_batch_size: int, max_batch_size: int, **kwargs):
    """Implementation of the ModelHandler interface for TensorRT.

@@ -119,22 +179,24 @@ def batch_elements_kwargs(self):
        'max_batch_size': self.max_batch_size
    }

-  def load_model(self) -> trt.ICudaEngine:
+  def load_model(self) -> TensorRTEngine:
    """Loads and initializes a TensorRT engine for processing."""
-    return _load_engine(self.engine_path)
+    engine = _load_engine(self.engine_path)
+    return TensorRTEngine(engine)

  def load_onnx(self) -> Tuple[trt.INetworkDefinition, trt.Builder]:
    """Loads and parses an onnx model for processing."""
    return _load_onnx(self.onnx_path)

-  def build_engine(self, network: trt.INetworkDefinition, builder: trt.Builder) -> trt.ICudaEngine:
+  def build_engine(self, network: trt.INetworkDefinition, builder: trt.Builder) -> TensorRTEngine:
    """Build an engine according to parsed/created network."""
-    return _build_engine(network, builder)
+    engine = _build_engine(network, builder)
+    return TensorRTEngine(engine)

  def run_inference(

tvalentyn (Contributor), Jul 26, 2022:

Is this implementation thread-safe?
Beam SDKs can start multiple concurrent threads within sdk_worker process, which will be processing the data. Are there any concerns with that from hardware or software standpoint?

Contributor Author:

@pranavm-nvidia any comments?

Reply:

Looks thread-safe. It would be good to share the engine between threads so you don't end up duplicating the model weights (maybe that can be future work). Each thread will still need a separate execution context.

Contributor Author:

Ok, leave it for the future. Big thanks @pranavm-nvidia, any tips where I can read more about how to properly implement this?

Reply:

@azhurkevich I'm not sure how multi-threading is set up in Beam, but the basic idea would be:

import threading

def run_inference(context, inputs, outputs, stream):
    ... # Copy HtoD
    context.execute_async_v2(...)
    ... # Copy DtoH
    stream.synchronize()

engine = ...

# Engine can be shared across threads, but need 1 context per thread
context0 = engine.create_execution_context()
context1 = engine.create_execution_context()

... # Allocate I/O buffers for each thread

t0 = threading.Thread(target=run_inference, args=(context0, inputs0, outputs0, stream0))
t1 = threading.Thread(target=run_inference, args=(context1, inputs1, outputs1, stream1))
t0.start()
t1.start()
t0.join()
t1.join()

Reply:

Sounds like threading.Lock should work here? Unless there is a separate API provided by Beam which would be preferred?

Contributor Author:

@tvalentyn any thoughts?

Contributor:

We could use threading.RLock, so that the same thread can also acquire it if it somehow becomes necessary. I haven't looked very closely into where to place this lock.
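
For illustration, a minimal runnable sketch of that reentrant-lock pattern, with no TensorRT dependency; SharedContext and the other names below are stand-ins for this discussion, not code from this PR:

import threading

class SharedContext:
    """Stand-in for a shared execution context guarded by a reentrant lock."""
    def __init__(self):
        self._lock = threading.RLock()
        self.calls = 0

    def execute(self, batch):
        # RLock lets the same thread re-acquire the lock if a nested call
        # ever needs to; other threads still wait their turn.
        with self._lock:
            self.calls += 1
            return [x * 2 for x in batch]

context = SharedContext()

def worker(batch):
    context.execute(batch)

threads = [
    threading.Thread(target=worker, args=([i, i + 1],)) for i in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(context.calls)  # 4: every batch ran, serialized on the shared context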

Contributor Author:

@tvalentyn Do we have this mechanism that you are mentioning in the PyTorch example? It would be easier for me to understand what you would like to see if there is a reference.

Contributor:

No, I took a closer look; see suggestions. I haven't tested it.

      self,
      batch: np.ndarray,
-      engine: trt.ICudaEngine,
+      engine: TensorRTEngine,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """
@@ -152,55 +214,19 @@ def run_inference(
      An Iterable of type PredictionResult.
    """
    _validate_inference_args(inference_args)
-    stream = cuda.Stream()

-    context = engine.create_execution_context()
-    assert context
-    # Setup I/O bindings
-    inputs = []
-    outputs = []
-    allocations = []
-    for i in range(engine.num_bindings):
-      is_input = False
-      if engine.binding_is_input(i):
-        is_input = True
-      name = engine.get_binding_name(i)
-      dtype = engine.get_binding_dtype(i)
-      shape = engine.get_binding_shape(i)
-      if is_input:
-        batch_size = shape[0]
-      size = np.dtype(trt.nptype(dtype)).itemsize
-      for s in shape:
-        size *= s
-      allocation = cuda.mem_alloc(size)
-      binding = {
-          'index': i,
-          'name': name,
-          'dtype': np.dtype(trt.nptype(dtype)),
-          'shape': list(shape),
-          'allocation': allocation,
-      }
-      allocations.append(allocation)
-      if engine.binding_is_input(i):
-        inputs.append(binding)
-      else:
-        outputs.append(binding)

-    assert batch_size > 0
-    assert len(inputs) > 0
-    assert len(outputs) > 0
-    assert len(allocations) > 0
-    # Prepare the output data
-    predictions = []
-    for output in outputs:
-      predictions.append(np.zeros(output['shape'], output['dtype']))
+    #Create CUDA Stream
+    err, stream = cuda.cuStreamCreate(0)
+    engine, context, inputs, outputs, gpu_allocations, cpu_allocations = engine.get_engine_attrs()
    # Process I/O and execute the network
-    cuda.memcpy_htod_async(inputs[0]['allocation'], np.ascontiguousarray(batch), stream)
-    context.execute_async_v2(allocations, stream.handle)
-    for output in range(len(predictions)):
-      cuda.memcpy_dtoh_async(predictions[output], outputs[output]['allocation'], stream)
+    err, = cuda.cuMemcpyHtoDAsync(inputs[0]['allocation'], np.ascontiguousarray(batch), inputs[0]['size'], stream)
+    context.execute_async_v2(gpu_allocations, stream)
+    for output in range(len(cpu_allocations)):
+      err, = cuda.cuMemcpyDtoHAsync(cpu_allocations[output], outputs[output]['allocation'], outputs[output]['size'], stream)
+    # Destroy CUDA Stream
+    err, = cuda.cuStreamDestroy(stream)

    return [
-        PredictionResult(x, [prediction[idx] for prediction in predictions])
+        PredictionResult(x, [prediction[idx] for prediction in cpu_allocations])
        for idx,
        x in enumerate(batch)
    ]
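
As a hedged usage sketch (not part of this diff), the handler above could be plugged into a Beam pipeline through RunInference; the engine_path keyword and the dummy input shape are assumptions for illustration, chosen to match whatever input binding the serialized engine expects:

import numpy as np

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorrt_inference import TensorRTEngineHandlerNumPy

# engine_path is assumed to be forwarded through the handler's **kwargs.
engine_handler = TensorRTEngineHandlerNumPy(
    min_batch_size=1,
    max_batch_size=1,
    engine_path='gs://my-bucket/model.trt')

with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | 'CreateExamples' >> beam.Create(
            [np.zeros((1, 3, 224, 224), dtype=np.float32)])
        | 'TensorRTRunInference' >> RunInference(engine_handler)
        | 'PrintResults' >> beam.Map(print))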