feat(//core/partitioning): Refactor top level partitioning API, fix a bug with lowering linear to addmm. Add python tests

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed Apr 30, 2021
1 parent 7be368f commit abc63f6
Showing 9 changed files with 48 additions and 11 deletions.
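For orientation before the per-file diffs: the commit renames the fallback option to forced_fallback_ops and surfaces it through the Python frontend. Below is a minimal usage sketch assembled from the spec keys documented and tested in this commit; the model choice and input shape mirror the new test and are illustrative, not part of the diff itself.

import torch
import torchvision.models as models
import trtorch

model = models.resnet18(pretrained=True).eval().to("cuda")
scripted_model = torch.jit.script(model)

compile_spec = {
    "input_shapes": [(1, 3, 224, 224)],
    "torch_fallback": {
        "enabled": True,
        "forced_fallback_ops": ["aten::max_pool2d"],  # ops forced to stay in PyTorch
        "min_block_size": 1
    }
}

trt_mod = trtorch.compile(scripted_model, compile_spec)
out = trt_mod(torch.randn((1, 3, 224, 224)).to("cuda"))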
2 changes: 1 addition & 1 deletion core/lowering/passes/linear_to_addmm.cpp
@@ -25,7 +25,7 @@ void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
        %weight = aten::t(%weight_t)
        %mm: Tensor = aten::matmul(%input, %weight)
        %b_f: Tensor = trt::const(%bias)
-       %out: Tensor = aten::add_(%b_f, %mm, %1)
+       %out: Tensor = aten::add(%b_f, %mm, %1)
        return (%out))IR";
  std::string fused_linear_bias_none = R"IR(
    graph(%input, %weight_t):
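For context on this lowering fix (not part of the diff): aten::add_ is the in-place variant of aten::add, so the removed pattern wrote the matmul result into its first argument — here the bias constant — while the functional form leaves it untouched. A small PyTorch sketch of that difference; the tensor values are just illustrative:

import torch

bias = torch.ones(3)
mm = torch.full((3,), 2.0)

out = bias.add_(mm)   # in-place: also overwrites `bias`
print(bias)           # tensor([3., 3., 3.]) -- the "constant" was mutated

bias = torch.ones(3)
out = bias.add(mm)    # functional: `bias` is left untouched
print(bias)           # tensor([1., 1., 1.])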
2 changes: 1 addition & 1 deletion cpp/api/include/trtorch/trtorch.h
@@ -397,7 +397,7 @@ struct TRTORCH_API CompileSpec {
    uint64_t min_block_size = 1;

    /// A list of names of operations that will explicitly run in PyTorch
-   std::vector<std::string> forced_fallback_operators;
+   std::vector<std::string> forced_fallback_ops;

    /**
     * @brief Construct a default Torch Fallback object, fallback will be off
2 changes: 1 addition & 1 deletion cpp/api/src/compile_spec.cpp
@@ -98,7 +98,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
  internal.convert_info.engine_settings.max_batch_size = external.max_batch_size;
  internal.partition_info.enabled = external.torch_fallback.enabled;
  internal.partition_info.min_block_size = external.torch_fallback.min_block_size;
- internal.partition_info.forced_fallback_operators = external.torch_fallback.forced_fallback_operators;
+ internal.partition_info.forced_fallback_operators = external.torch_fallback.forced_fallback_ops;

  switch (external.device.device_type) {
    case CompileSpec::Device::DeviceType::kDLA:
6 changes: 3 additions & 3 deletions py/trtorch/_compile_spec.py
@@ -134,9 +134,9 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> trtorch._C.TorchFall
        assert isinstance(fallback_info["min_block_size"], int)
        info.min_block_size = fallback_info["min_block_size"]

-   if "forced_fallback_operators" in fallback_info:
-       assert isinstance(fallback_info["forced_fallback_operators"], list)
-       info.forced_fallback_operators = fallback_info["forced_fallback_operators"]
+   if "forced_fallback_ops" in fallback_info:
+       assert isinstance(fallback_info["forced_fallback_ops"], list)
+       info.forced_fallback_operators = fallback_info["forced_fallback_ops"]

    return info
7 changes: 7 additions & 0 deletions py/trtorch/_compiler.py
@@ -49,6 +49,13 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
                "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
                "workspace_size": 0, # Maximum size of workspace given to TensorRT
                "max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
+               "torch_fallback": {
+                   "enabled": True,
+                   "forced_fallback_ops": [
+                       "aten::max_pool2d"
+                   ],
+                   "min_block_size": 1
+               }
            }

        Input Sizes can be specified as torch sizes, tuples or lists. Op precisions can be specified using
5 changes: 3 additions & 2 deletions py/trtorch/csrc/tensorrt_classes.cpp
@@ -93,7 +93,7 @@ nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) {
 }

 core::CompileSpec CompileSpec::toInternalCompileSpec() {
-  std::vector<core::conversion::InputRange> internal_input_ranges;
+  std::vector<core::ir::InputRange> internal_input_ranges;
   for (auto i : input_ranges) {
     internal_input_ranges.push_back(i.toInternalInputRange());
   }
@@ -132,6 +132,7 @@ std::string CompileSpec::stringify() {
   for (auto i : input_ranges) {
     ss << to_str(i);
   }
+  std::string enabled = torch_fallback.enabled ? "True" : "False";
   ss << " ]" << std::endl;
   ss << " \"Op Precision\": " << to_str(op_precision) << std::endl;
   ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl;
@@ -149,7 +150,7 @@
   ss << " \"Max Batch Size\": " << max_batch_size << std::endl;
   ss << " \"Truncate long and double\": " << truncate_long_and_double << std::endl;
   ss << " \"Torch Fallback: {" << std::endl;
-  ss << " \"enabled\": " << torch_fallback.enabled ? "True" : "False" << std::endl;
+  ss << " \"enabled\": " << enabled << std::endl;
   ss << " \"min_block_size\": " << torch_fallback.min_block_size << std::endl;
   ss << " \"forced_fallback_operators\": [" << std::endl;
   for (auto i : torch_fallback.forced_fallback_operators) {
4 changes: 2 additions & 2 deletions py/trtorch/csrc/tensorrt_classes.h
@@ -32,8 +32,8 @@ struct InputRange : torch::CustomClassHolder {
  std::vector<int64_t> opt;
  std::vector<int64_t> max;

- core::conversion::InputRange toInternalInputRange() {
-   return core::conversion::InputRange(min, opt, max);
+ core::ir::InputRange toInternalInputRange() {
+   return core::ir::InputRange(min, opt, max);
  }

  ADD_FIELD_GET_SET(min, std::vector<int64_t>);
2 changes: 1 addition & 1 deletion tests/core/lowering/test_linear_to_addmm.cpp
@@ -19,7 +19,7 @@ TEST(LoweringPasses, LinearToAddMM) {
        %weight = aten::t(%weight_t)
        %mm: Tensor = aten::matmul(%flat, %weight)
        %b_f: Tensor = trt::const(%bias)
-       %out: Tensor = aten::add_(%b_f, %mm, %1)
+       %out: Tensor = aten::add(%b_f, %mm, %1)
        return (%out))IR";

  trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
29 changes: 29 additions & 0 deletions tests/py/test_api.py
@@ -46,6 +46,34 @@ def test_compile_script(self):
        self.assertTrue(same < 2e-3)


+class TestFallbackToTorch(ModelTestCase):
+
+    def setUp(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.scripted_model = torch.jit.script(self.model)
+
+    def test_compile_script(self):
+        compile_spec = {
+            "input_shapes": [self.input.shape],
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+                "dla_core": 0,
+                "allow_gpu_fallback": False,
+                "disable_tf32": False
+            },
+            "torch_fallback": {
+                "enabled": True,
+                "forced_fallback_ops": ["aten::max_pool2d"],
+                "min_block_size": 1
+            }
+        }
+
+        trt_mod = trtorch.compile(self.scripted_model, compile_spec)
+        same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
+        self.assertTrue(same < 2e-3)
+
+
 class TestPTtoTRTtoPT(ModelTestCase):

     def setUp(self):
@@ -106,6 +134,7 @@ def test_suite():
    suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet18(pretrained=True)))
    suite.addTest(TestCompile.parametrize(TestCompile, model=models.mobilenet_v2(pretrained=True)))
    suite.addTest(TestPTtoTRTtoPT.parametrize(TestPTtoTRTtoPT, model=models.mobilenet_v2(pretrained=True)))
+   suite.addTest(TestFallbackToTorch.parametrize(TestFallbackToTorch, model=models.resnet18(pretrained=True)))
    suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))

    return suite
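To run the updated suite locally, a minimal sketch: it assumes tests/py is the working directory, that test_api.py imports as a module, and that the pretrained torchvision weights can be downloaded; none of this invocation is part of the commit.

import unittest
from test_api import test_suite  # test_suite() is defined in the file changed above

runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(test_suite())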
