
Adding the new device API, fixing a nested dict issue in the existing compile phase, adding new lowering pass for BN #288

Merged: 8 commits, Jan 25, 2021
2 changes: 1 addition & 1 deletion .bazelversion
@@ -1 +1 @@
-3.7.0
+4.0.0
6 changes: 3 additions & 3 deletions WORKSPACE
@@ -79,10 +79,10 @@ http_archive(

http_archive(
name = "tensorrt",
urls = ["https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/7.2.1/tars/TensorRT-7.2.1.6.Ubuntu-18.04.x86_64-gnu.cuda-11.0.cudnn8.0.tar.gz",],
urls = ["https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/7.2.2/tars/TensorRT-7.2.2.3.Ubuntu-18.04.x86_64-gnu.cuda-11.0.cudnn8.0.tar.gz",],
build_file = "@//third_party/tensorrt/archive:BUILD",
sha256 = "8def6b03b0c8c3751f560df21b3e99668ae05aab5140b1d38b8e51e4a0ffbbb8",
strip_prefix = "TensorRT-7.2.1.6"
strip_prefix = "TensorRT-7.2.2.3",
sha256 = "b5c325e38e1d92ce1ce92ca8b54ede9c224bf128c9a53eb0b9022f1ee4313ee0"
)

####################################################################################
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
@@ -40,6 +40,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
passes::Conv2DToConvolution(g);
passes::Conv3DToConvolution(g);
passes::FuseAddMMBranches(g);
+passes::RemoveBNDimCheck(g);
torch::jit::EliminateCommonSubexpression(g);
// torch::jit::UnrollLoops(g);
torch::jit::EliminateCommonSubexpression(g);
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
@@ -18,6 +18,7 @@ cc_library(
"exception_elimination.cpp",
"fuse_addmm_branches.cpp",
"fuse_flatten_linear.cpp",
"remove_bn_dim_check.cpp",
"remove_contiguous.cpp",
"remove_dropout.cpp",
"remove_to.cpp",
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
@@ -12,6 +12,7 @@ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
+void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveTo(std::shared_ptr<torch::jit::Graph> graph);
89 changes: 89 additions & 0 deletions core/lowering/passes/remove_bn_dim_check.cpp
@@ -0,0 +1,89 @@
#include "torch/csrc/jit/ir/alias_analysis.h"
#include "torch/csrc/jit/jit_log.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/guard_elimination.h"
#include "torch/csrc/jit/passes/peephole.h"
#include "torch/csrc/jit/runtime/graph_executor.h"

#include "core/util/prelude.h"

#include <vector>

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {
namespace {
using namespace torch::jit;
struct BNDimCheckRemoval {
BNDimCheckRemoval(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {}

void run() {
findBNDimCheckNodes(graph_->block());
torch::jit::EliminateDeadCode(graph_);
LOG_GRAPH("Post aten::addmm branch fusion: " << *graph_);
}

private:
bool isBNDimCheckNodes(Node* n) {
/// Check if this Node hosts a pattern like so:
/// %290 : bool = aten::ne(%289, %9)
/// = prim::If(%290)
/// block0():
/// %291 : str = aten::format(%10, %289)
/// = prim::RaiseException(%291)
/// -> ()
/// block1():
/// -> ()

if (n->blocks().size() != 2) {
return false;
}
auto arm1 = n->blocks()[0];
auto arm2 = n->blocks()[1];
if (arm1->outputs().size() != 0 || arm2->outputs().size() != 0) {
// Make sure that the node doesn't actually produce any Value that are
// used by other nodes
return false;
}

auto arm1_start = arm1->nodes().begin();

if ((*arm1_start)->kind() != c10::Symbol::fromQualString("aten::format") ||
(*(++arm1_start))->kind() != prim::RaiseException || (*(++arm1_start))->kind() != prim::Return) {
// Make sure that block0 is solely just the exception and the return
return false;
}

if ((*(arm2->nodes().begin()))->kind() != prim::Return) {
// Make sure that block1 is solely the return
return false;
}

return true;
}

void findBNDimCheckNodes(Block* b) {
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
auto n = *it;
if (n->kind() == prim::If && isBNDimCheckNodes(n)) {
LOG_GRAPH("Found that node " << *n << " is an batch norm dim check node (EliminateChecks)" << std::endl);
it.destroyCurrent();
}
}
}

std::shared_ptr<Graph> graph_;
};
} // namespace

void RemoveBNDimCheck(std::shared_ptr<Graph> graph) {
BNDimCheckRemoval bndcr(std::move(graph));
bndcr.run();
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
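For context, the guard this pass removes is produced by the input-rank check that PyTorch's batch norm modules run on every call. A minimal sketch of the Python-level origin of the pattern (paraphrased from torch.nn.modules.batchnorm; not code from this PR):

import torch

class DimCheck(torch.nn.Module):
    # Isolates the nn.BatchNorm2d input check that TorchScript lowers to the
    # aten::ne -> prim::If -> aten::format / prim::RaiseException pattern above.
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))
        return input

scripted = torch.jit.script(DimCheck())
print(scripted.graph)  # shows the If/RaiseException block that RemoveBNDimCheck deletes

Since engines are built for fixed, already-validated input shapes, the guard is dead weight at conversion time, so the pass simply deletes the prim::If node.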
36 changes: 20 additions & 16 deletions docsrc/tutorials/use_from_pytorch.rst
@@ -32,31 +32,35 @@ at the documentation for the TRTorch ``TensorRTCompileSpec`` API.
.. code-block:: python

-    spec = {
-        "forward": trtorch.TensorRTCompileSpec({
-            "input_shapes": [[1, 3, 300, 300]],
-            "op_precision": torch.half,
-            "refit": False,
-            "debug": False,
-            "strict_types": False,
-            "allow_gpu_fallback": True,
-            "device_type": "gpu",
-            "capability": trtorch.EngineCapability.default,
-            "num_min_timing_iters": 2,
-            "num_avg_timing_iters": 1,
-            "max_batch_size": 0,
-        })
-    }
+    spec = {
+        "forward":
+            trtorch.TensorRTCompileSpec({
+                "input_shapes": [[1, 3, 300, 300]],
+                "op_precision": torch.half,
+                "refit": False,
+                "debug": False,
+                "strict_types": False,
+                "device": {
+                    "device_type": trtorch.DeviceType.GPU,
+                    "gpu_id": 0,
+                    "allow_gpu_fallback": True
+                },
+                "capability": trtorch.EngineCapability.default,
+                "num_min_timing_iters": 2,
+                "num_avg_timing_iters": 1,
+                "max_batch_size": 0,
+            })
+    }

Now to compile with TRTorch, provide the target module and the spec dictionary to ``torch._C._jit_to_backend``

.. code-block:: python

-    trt_model = torch._C._jit_to_tensorrt(script_model._c, spec)
+    trt_model = torch._C._jit_to_backend("tensorrt", script_model, spec)

To run the model, explicitly call the function of the method you want to run (rather than calling the module itself, as you would in standard PyTorch)

.. code-block:: python

-    input = torch.randn((1, 3, 300, 300).to("cuda").to(torch.half)
+    input = torch.randn((1, 3, 300, 300)).to("cuda").to(torch.half)
print(trt_model.forward(input))

19 changes: 7 additions & 12 deletions py/trtorch/_compile_spec.py
@@ -147,10 +147,6 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
assert isinstance(compile_spec["strict_types"], bool)
info.strict_types = compile_spec["strict_types"]

if "allow_gpu_fallback" in compile_spec:
assert isinstance(compile_spec["allow_gpu_fallback"], bool)
info.allow_gpu_fallback = compile_spec["allow_gpu_fallback"]

if "device" in compile_spec:
info.device = _parse_device(compile_spec["device"])

@@ -177,7 +173,7 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
return info


-def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
+def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.CompileSpec:
"""
Utility to create a formatted spec dictionary for using the PyTorch TensorRT backend

@@ -235,14 +231,13 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.CompileSpec:
ir.set_max(i.max)
backend_spec.append_input_range(ir)

-    for i in parsed_spec.device:
-        ir = torch.classes.tensorrt.Device()
-        ir.set_device_type(i.device_type)
-        ir.set_gpu_id(i.gpu_id)
-        ir.set_dla_core(i.dla_core)
-        ir.set_allow_gpu_fallback(i.allow_gpu_fallback)
-        backend_spec.set_device(ir)
+    d = torch.classes.tensorrt.Device()
+    d.set_device_type(int(parsed_spec.device.device_type))
+    d.set_gpu_id(parsed_spec.device.gpu_id)
+    d.set_dla_core(parsed_spec.device.dla_core)
+    d.set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback)
+
+    backend_spec.set_device(d)
backend_spec.set_op_precision(int(parsed_spec.op_precision))
backend_spec.set_refit(parsed_spec.refit)
backend_spec.set_debug(parsed_spec.debug)
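For reference, the nested "device" dictionary that _parse_device consumes after this change looks like the following (a sketch based on the updated docs in this PR; the dla_core key is an assumption drawn from the Device fields registered below):

import torch
import trtorch

compile_spec = {
    "input_shapes": [[1, 3, 300, 300]],
    "op_precision": torch.half,
    "device": {
        "device_type": trtorch.DeviceType.GPU,  # enum replaces the old "gpu" string
        "gpu_id": 0,
        "dla_core": 0,  # assumption: optional DLA core index
        "allow_gpu_fallback": True,
    },
}
spec = trtorch.TensorRTCompileSpec(compile_spec)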
12 changes: 11 additions & 1 deletion py/trtorch/csrc/register_tensorrt_classes.cpp
@@ -3,22 +3,32 @@
namespace trtorch {
namespace backend {
namespace {
-void RegisterTRTCompileSpec() {

#define ADD_FIELD_GET_SET_REGISTRATION(registry, class_name, field_name) \
(registry).def("set_" #field_name, &class_name::set_##field_name); \
(registry).def("get_" #field_name, &class_name::get_##field_name);

+void RegisterTRTCompileSpec() {
static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
torch::class_<trtorch::pyapi::InputRange>("tensorrt", "InputRange").def(torch::init<>());

ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, min);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, opt);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, max);

+static auto TRTORCH_UNUSED TRTDeviceTSRegistration =
+    torch::class_<trtorch::pyapi::Device>("tensorrt", "Device").def(torch::init<>());
+
+ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, device_type);
+ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, gpu_id);
+ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core);
+ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback);

static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
.def(torch::init<>())
.def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
.def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);

ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
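With tensorrt.Device registered as a TorchBind class, the Python layer can construct and populate it directly, which is what the _compile_spec.py change above does. A minimal sketch (it assumes the trtorch extension is loaded, since loading it is what registers the tensorrt.* classes):

import torch
import trtorch  # loading the extension registers torch.classes.tensorrt.*

d = torch.classes.tensorrt.Device()
d.set_device_type(0)  # setters take int64 so the class stays TorchBind-compatible
d.set_gpu_id(0)
d.set_dla_core(0)
d.set_allow_gpu_fallback(True)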
2 changes: 1 addition & 1 deletion py/trtorch/csrc/tensorrt_backend.cpp
@@ -46,7 +46,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue processed_mod, c10::
auto method = mod.get_method(method_name);
auto g = method.graph();

-auto raw_spec = it->value().toGenericDict().at(it->key()).toCustomClass<trtorch::pyapi::CompileSpec>();
+auto raw_spec = it->value().toCustomClass<trtorch::pyapi::CompileSpec>();
LOG_DEBUG(raw_spec->stringify());
auto cfg = raw_spec->toInternalCompileSpec();
auto convert_cfg = std::move(cfg.convert_info);
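This one-line change is the nested-dict fix from the PR title: to_backend hands the backend a flat mapping from method names to compile specs, so the value for "forward" is already the CompileSpec custom class rather than another dict keyed by the method name. A sketch of the corrected call shape (the model and spec values are assumptions based on the updated docs and tests):

import torch
import torchvision.models as models
import trtorch

scripted_model = torch.jit.script(models.resnet18(pretrained=True))
# The value is the spec object itself, not {"forward": {"forward": spec}} as the
# old toGenericDict().at(it->key()) lookup effectively assumed.
spec = {"forward": trtorch.TensorRTCompileSpec({"input_shapes": [[1, 3, 300, 300]]})}
trt_mod = torch._C._jit_to_backend("tensorrt", scripted_model, spec)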
30 changes: 17 additions & 13 deletions py/trtorch/csrc/tensorrt_classes.h
@@ -17,6 +17,16 @@ namespace pyapi {
return field_name; \
}

+// TODO: Make this error message more informative
+#define ADD_ENUM_GET_SET(field_name, type, max_val) \
+  void set_##field_name(int64_t val) { \
+    TRTORCH_CHECK(val < max_val, "Invalid enum value for field"); \
+    field_name = static_cast<type>(val); \
+  } \
+  int64_t get_##field_name() { \
+    return static_cast<int64_t>(field_name); \
+  }

struct InputRange : torch::CustomClassHolder {
std::vector<int64_t> min;
std::vector<int64_t> opt;
@@ -59,7 +69,7 @@ struct Device : torch::CustomClassHolder {
allow_gpu_fallback(false) // allow_gpu_fallback
{}

-ADD_FIELD_GET_SET(device_type, DeviceType);
+ADD_ENUM_GET_SET(device_type, DeviceType, 1);
Review comment (Contributor): What is "1" here? Can we use an enumeration?

Reply (Collaborator, Author): This macro creates a getter and setter pair that returns integers so that the function is TorchBind-compatible. The 1 is the maximum allowable value, so that you don't get invalid ones.
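To make the reply concrete, the generated pair behaves roughly like this from Python (a sketch; it assumes the trtorch extension is loaded and that kGPU is enumerator 0 of DeviceType):

import torch
import trtorch  # registers the torch.classes.tensorrt.* TorchBind classes

d = torch.classes.tensorrt.Device()
d.set_device_type(0)  # accepted: 0 < max_val (1), stored as DeviceType(0)
print(d.get_device_type())  # -> 0, returned as int64 for TorchBind compatibility
# d.set_device_type(2)  # would trip TRTORCH_CHECK, since 2 is not < 1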

ADD_FIELD_GET_SET(gpu_id, int64_t);
ADD_FIELD_GET_SET(dla_core, int64_t);
ADD_FIELD_GET_SET(allow_gpu_fallback, bool);
@@ -77,28 +87,22 @@ enum class EngineCapability : int8_t {
std::string to_str(EngineCapability value);
nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value);

-// TODO: Make this error message more informative
-#define ADD_ENUM_GET_SET(field_name, type, max_val) \
-  void set_##field_name(int64_t val) { \
-    TRTORCH_CHECK(val < max_val, "Invalid enum value for field"); \
-    field_name = static_cast<type>(val); \
-  } \
-  int64_t get_##field_name() { \
-    return static_cast<int64_t>(field_name); \
-  }

struct CompileSpec : torch::CustomClassHolder {
core::CompileSpec toInternalCompileSpec();
std::string stringify();
void appendInputRange(const c10::intrusive_ptr<InputRange>& ir) {
input_ranges.push_back(*ir);
}

-ADD_ENUM_GET_SET(op_precision, DataType, 3);
+void setDeviceIntrusive(const c10::intrusive_ptr<Device>& d) {
+  device = *d;
+}
+
+ADD_ENUM_GET_SET(op_precision, DataType, 2);
ADD_FIELD_GET_SET(refit, bool);
ADD_FIELD_GET_SET(debug, bool);
ADD_FIELD_GET_SET(strict_types, bool);
-ADD_ENUM_GET_SET(capability, EngineCapability, 3);
+ADD_ENUM_GET_SET(capability, EngineCapability, 2);
Review comment (Contributor): How about using enumerations instead of hardcoded numbers?

Reply (Collaborator, Author): We could do something like a static cast to int of the enums; sounds like a good idea.
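A sketch of the reviewer's idea as seen from the Python side (an assumption, not code from this PR; it presumes trtorch.EngineCapability exposes safe_dla as its last member):

import trtorch

# Derive the bound from the enum itself instead of hardcoding 2:
max_capability = int(trtorch.EngineCapability.safe_dla)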

ADD_FIELD_GET_SET(num_min_timing_iters, int64_t);
ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t);
ADD_FIELD_GET_SET(workspace_size, int64_t);
3 changes: 2 additions & 1 deletion tests/py/BUILD
@@ -17,7 +17,8 @@ py_test(
] + select({
":aarch64_linux": [
"test_api_dla.py"
-    ]
+    ],
+    "//conditions:default" : []
}),
deps = [
requirement("torchvision")
11 changes: 7 additions & 4 deletions tests/py/test_to_backend_api.py
@@ -19,8 +19,11 @@ def setUp(self):
"refit": False,
"debug": False,
"strict_types": False,
"allow_gpu_fallback": True,
"device_type": "gpu",
"device": {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andi4191 This struct looks correct to you right?

Copy link
Contributor

@andi4191 andi4191 Jan 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you are missing dla_core.

"device_type": trtorch.DeviceType.GPU,
"gpu_id": 0,
"allow_gpu_fallback": True
},
"capability": trtorch.EngineCapability.default,
"num_min_timing_iters": 2,
"num_avg_timing_iters": 1,
@@ -29,14 +32,14 @@
}

def test_to_backend_lowering(self):
-trt_mod = torch._C._jit_to_tensorrt(self.scripted_model._c, {"forward": self.spec})
+trt_mod = torch._C._jit_to_backend("tensorrt", self.scripted_model, self.spec)
same = (trt_mod.forward(self.input) - self.scripted_model(self.input)).abs().max()
self.assertTrue(same < 2e-3)


def test_suite():
suite = unittest.TestSuite()
-suite.addTest(TestToBackendLowering.parametrize(TestToBackendLowering, model=models.mobilenet_v2(pretrained=True)))
+suite.addTest(TestToBackendLowering.parametrize(TestToBackendLowering, model=models.resnet18(pretrained=True)))

return suite
