Skip to content

Commit

Permalink
fix: Fix testcases using old InputRange API
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 committed Jul 26, 2021
1 parent da2cbc0 commit ff87956
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 52 deletions.
42 changes: 1 addition & 41 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ GraphAndMapping ConstructFallbackGraph(
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
std::vector<ir::Input> inputs;
for (auto& shape : seg_block.in_shape()) {
inputs.push_back(ir::InputRange(shape));
inputs.push_back(ir::Input(shape));
}
// update the input ranges for each segments
convert_cfg.inputs = inputs;
Expand Down Expand Up @@ -332,46 +332,6 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
return mod;
}

// <<<<<<< HEAD
// =======
// std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
// // add global graph's input to old_to_new_g mapping
// for (auto input : g->inputs()) {
// util::getOrAddInputForValue(input, new_g, old_to_new_g);
// }
// for (auto& seg_block : segmented_blocks) {
// std::string cur_block_target =
// seg_block.target() == partitioning::SegmentedBlock::kTensorRT ? "TensorRT" : "Torch";
// LOG_INFO(*seg_block.g() << "(Sub Graph" << cur_block_target << "Block)\n");
// std::ostringstream trt_engine_id;
// trt_engine_id << reinterpret_cast<const int*>(&seg_block);
// if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
// std::vector<ir::Input> inputs;
// for (auto& shape : seg_block.in_shape()) {
// inputs.push_back(ir::Input(shape));
// }
// // update the input ranges for each segments
// convert_cfg.inputs = inputs;
// auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
// auto temp_g = std::make_shared<torch::jit::Graph>();
// auto device_spec = convert_cfg.engine_settings.device;
// auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
// AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
//
// seg_block.update_graph(temp_g);
// AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
// } else {
// AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
// }
// }
//
// for (auto& output : g->outputs()) {
// new_g->registerOutput(old_to_new_g[output]);
// }
//
// LOG_INFO(*new_g << "(FallbackGraph)\n");
//
// >>>>>>> master
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
new_mod.type()->addMethod(new_method);
Expand Down
4 changes: 2 additions & 2 deletions tests/core/partitioning/test_conditionals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) {
return;
}

std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
trtorch::core::CompileSpec cfg(input_ranges);
std::vector<trtorch::core::ir::Input> inputs{trtorch::core::ir::Input({3, 3, 16, 16})};
trtorch::core::CompileSpec cfg(inputs);
cfg.partition_info.enabled = true;
torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
auto g = new_mod.get_method("forward").graph();
Expand Down
2 changes: 1 addition & 1 deletion tests/core/partitioning/test_shape_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) {
inputs.push_back(trtorch::core::ir::Input({16, 32, 3, 3}));
inputs.push_back(trtorch::core::ir::Input({16}));

std::unordered_map<torch::jit::Value*, trtorch::core::ir::InputRange> inputs_map;
std::unordered_map<torch::jit::Value*, trtorch::core::ir::Input> inputs_map;
for (size_t i = 0; i < g->inputs().size(); ++i) {
inputs_map.insert({g->inputs()[i], inputs[i]});
}
Expand Down
8 changes: 0 additions & 8 deletions tests/modules/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,10 @@
"model": torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True),
"path": "both"
},
"fcn_resnet101": {
"model": torch.hub.load('pytorch/vision:v0.9.0', 'fcn_resnet101', pretrained=True),
"path": "script"
},
"ssd": {
"model": torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math="fp32"),
"path": "trace"
},
"faster_rcnn": {
"model": models.detection.fasterrcnn_resnet50_fpn(pretrained=True),
"path": "script"
},
"efficientnet_b0": {
"model": timm.create_model('efficientnet_b0', pretrained=True),
"path": "script"
Expand Down

0 comments on commit ff87956

Please sign in to comment.