Changes in TRT to support TF-SSD-MN #149
Conversation
tests/python/relay/test_tensorrt.py (outdated)

```python
mod.set_input(**graph_params)
mod.run(**input_dict)
results = [mod.get_output(i) for i in range(mod.get_num_outputs())]
trt_vm_exec = relay.create_executor("vm", mod=mod, ctx=tvm.gpu(0), target="cuda")
```
We want to verify both graph runtime and VM here
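For instance, something along these lines could check the two against each other — a sketch only, assuming TVM 0.7-era APIs and that `mod` here is the Relay IRModule (the snippet above reuses the name `mod` for the graph runtime module as well):

```python
from tvm.runtime.vm import VirtualMachine

# Sketch: compile the same Relay module for the VM and check that its
# output matches the graph runtime results collected above.
with relay.build_config(opt_level=3):
    vm_exec = relay.vm.compile(mod, target="cuda", params=graph_params)
vm = VirtualMachine(vm_exec, tvm.gpu(0))
vm_result = vm.invoke("main", **input_dict)
tvm.testing.assert_allclose(
    results[0].asnumpy(), vm_result.asnumpy(), rtol=1e-3, atol=1e-3
)
```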
tests/python/relay/test_tensorrt.py (outdated)

```python
check_trt_used(mod)

with relay.build_config(opt_level=3):
    exec = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
```
Let's also test graph runtime
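A rough sketch of the graph runtime counterpart (assuming TVM ~0.7 APIs; the input name "data" and `i_data` are placeholders):

```python
from tvm.contrib import graph_runtime

# Sketch: build and run the same module on the graph runtime so both
# executors are exercised by the test.
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target="llvm", params=params)
gmod = graph_runtime.create(graph, lib, tvm.cpu(0))
gmod.set_input(**params)
gmod.set_input("data", i_data)  # placeholder input name
gmod.run()
graph_result = gmod.get_output(0)
```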
tests/python/relay/test_tensorrt.py (outdated)

```python
dtype = 'float32'
input_shape = (1, 3, 224, 224)
i_data = np.random.uniform(-1, 1, input_shape).astype(dtype)
for model in models:
    latency[model], res = test_model(model, i_data, input_shape, dtype, use_trt=True)
    _, ref_res = test_model(model, i_data, input_shape, dtype, use_trt=False, num_iteration=1)
    tvm.testing.assert_allclose(res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-3)

for model in models:
    print(model, latency[model])

def test_tensorrt_serialize():
```
Let's make sure not to remove any previous tests; we should only be adding additional tests for VM. We still need the test for graph runtime serialization.
Thanks, I am adding the graph runtime code as well.
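For the serialization path, a minimal graph runtime sketch could look like this (TVM ~0.7 APIs; the file names and input name are placeholders):

```python
from tvm.contrib import graph_runtime

# Sketch: build, serialize the compiled artifacts to disk, reload them,
# and run through the graph runtime.
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target="cuda", params=params)
lib.export_library("compiled.so")
with open("graph.json", "w") as f:
    f.write(graph)
with open("params.bin", "wb") as f:
    f.write(relay.save_param_dict(params))

loaded_lib = tvm.runtime.load_module("compiled.so")
loaded_graph = open("graph.json").read()
loaded_params = bytearray(open("params.bin", "rb").read())
gmod = graph_runtime.create(loaded_graph, loaded_lib, tvm.gpu(0))
gmod.load_params(loaded_params)
gmod.run(data=i_data)  # "data" is a placeholder input name
```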
tests/python/relay/test_tensorrt.py (outdated)

```python
def run_vm(code, lib):
    vm_exec = tvm.runtime.vm.Executable.load_exec(code, lib)
    vm = VirtualMachine(vm_exec, tvm.cpu(0))
    result = vm.invoke("main", *i_data)
```
i_data is used here before it is declared
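One straightforward fix is to pass the inputs in explicitly, for example (a sketch):

```python
def run_vm(code, lib, inputs):
    # Load the serialized VM executable and run it on the given inputs,
    # instead of relying on an i_data defined later in the file.
    vm_exec = tvm.runtime.vm.Executable.load_exec(code, lib)
    vm = VirtualMachine(vm_exec, tvm.cpu(0))
    return vm.invoke("main", *inputs)
```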
@rohanmukh The CI is fixed now, could you please rebase and push?
python/tvm/relay/tensorrt.py (outdated)

```diff
@@ -152,11 +201,13 @@ def register_tensorrt_annotations(trt_version, use_implicit_batch=True):
     # _register_external_op_helper("split")
     # _register_external_op_helper("slice_like")

-    @tvm.ir.register_op_attr("add", "target.tensorrt")
+    # @tvm.ir.register_op_attr("add", "target.tensorrt")
```
Can you remove the commented-out lines?
python/tvm/relay/tensorrt.py (outdated)

```diff
@@ -165,9 +216,17 @@ def add_whitelist_fn(attrs, args):  # pylint: disable=unused-variable
     ):
         print("add: bug in TRT with adding batched constants.")
         return False

+    # Skip this add op in TRT to avoid accuracy mismatch
+    if all([a == b for a, b in zip(args[0].checked_type.shape, [1, 546, 1, 1])]) and all(
```
`list(map(int, args[x].checked_type.shape)) == [1, 546, 1, 1]` is slightly cleaner
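In context, the suggestion might read as follows — a sketch only, since the diff truncates the second clause, so the `args[1]` check and the `return True` fallthrough are assumptions:

```python
def add_whitelist_fn(attrs, args):  # sketch; mirrors the diff above
    # Compare concrete shapes directly instead of zipping element by
    # element; skip this add op in TRT to avoid accuracy mismatch.
    if (list(map(int, args[0].checked_type.shape)) == [1, 546, 1, 1]
            and list(map(int, args[1].checked_type.shape)) == [1, 546, 1, 1]):
        return False
    return True
```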
python/tvm/relay/tensorrt.py (outdated)

```diff
@@ -412,6 +476,20 @@ def pad_whitelist_fn(attrs, args):  # pylint: disable=unused-variable
         return False
     return True

+_register_external_dynamic_check_func("nn.dense", dense_whitelist_fn)
```
Can we have all of the `_register_external_dynamic_check_func` calls in the same place at the end?
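That is, collecting them into one block at the end of the module, e.g. (a sketch; the op/function pairings are taken from the diffs above):

```python
# Sketch: keep every dynamic-check registration together at module end.
_register_external_dynamic_check_func("add", add_whitelist_fn)
_register_external_dynamic_check_func("nn.dense", dense_whitelist_fn)
_register_external_dynamic_check_func("nn.pad", pad_whitelist_fn)
```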
```diff
@@ -123,6 +162,16 @@ def _func_wrapper(attrs, args):
     return _func_wrapper

+def _register_external_dynamic_check_func(op_name, func):
```
It might be better to register this as a decorator so we don't have to call `_register_external_dynamic_check_func` every time. Not required though.
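A sketch of the decorator variant the reviewer has in mind (the name `register_trt_dynamic_check` is hypothetical):

```python
def register_trt_dynamic_check(op_name):
    """Hypothetical decorator wrapper around _register_external_dynamic_check_func."""
    def _decorator(func):
        _register_external_dynamic_check_func(op_name, func)
        return func
    return _decorator

@register_trt_dynamic_check("nn.dense")
def dense_whitelist_fn(attrs, args):  # pylint: disable=unused-variable
    ...
```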
Thanks @trevor-m, let me fix these.
Thanks @rohanmukh!
* tests in test_tensorrt.py ported to use relay VM along with graph runtime
* updates in tensorrt.py to offload dynamic shapes in TRT to VM
* updates in tensorrt.py to run TF-SSD-MN
* resolved conflicts