Skip to content

Commit

Permalink
fix(aten::flatten): Fixing flatten converter to handle dynamic batch
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Feb 23, 2021
1 parent 34f84df commit 00f2d78
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 12 deletions.
4 changes: 3 additions & 1 deletion core/conversion/converters/impl/shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ static auto shuffle_registrations TRTORCH_UNUSED =
auto end_dim = args[2].unwrapToInt();
auto in_shape = util::toVec(in->getDimensions());
std::vector<int64_t> out_shape;
if (ctx->input_is_dynamic) {
if (ctx->input_is_dynamic && in_shape[0] != -1) {
out_shape = std::vector<int64_t>({in_shape[0], -1});
} else if (ctx->input_is_dynamic && in_shape[0] == -1) {
out_shape = std::vector<int64_t>({-1, -1 * std::accumulate(std::begin(in_shape), std::end(in_shape), 1, std::multiplies<int64_t>())});
} else {
out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes().vec();
}
Expand Down
2 changes: 1 addition & 1 deletion tests/core/conversion/converters/test_pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectlyWithDynamicInput) {

auto trt_in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, false);

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
26 changes: 25 additions & 1 deletion tests/core/conversion/converters/test_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,31 @@ TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicInput) {

in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in});
auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in}, false);
auto trt = trt_results[0].reshape_as(jit_results[0]);

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}


TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicBatch) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=0]()
%2 : int = prim::Constant[value=1]()
%3 : Tensor = aten::flatten(%0, %1, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
auto trt = trt_results[0].reshape_as(jit_results[0]);

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
Expand Down
27 changes: 19 additions & 8 deletions tests/util/run_graph_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,29 @@ std::vector<core::conversion::InputRange> toInputRanges(std::vector<at::Tensor>
return std::move(a);
}

std::vector<core::conversion::InputRange> toInputRangesDynamic(std::vector<at::Tensor> ten) {
std::vector<core::conversion::InputRange> toInputRangesDynamic(std::vector<at::Tensor> ten, bool dynamic_batch) {
std::vector<core::conversion::InputRange> a;

for (auto i : ten) {
auto opt = core::util::toVec(i.sizes());

std::vector<int64_t> min_range(opt);
std::vector<int64_t> max_range(opt);
if (dynamic_batch) {
std::vector<int64_t> min_range(opt);
std::vector<int64_t> max_range(opt);

min_range[1] = ceil(opt[1] / 2.0);
max_range[1] = 2 * opt[1];
min_range[0] = ceil(opt[0] / 2.0);
max_range[0] = 2 * opt[0];

a.push_back(core::conversion::InputRange(min_range, opt, max_range));
a.push_back(core::conversion::InputRange(min_range, opt, max_range));
} else {
std::vector<int64_t> min_range(opt);
std::vector<int64_t> max_range(opt);

min_range[1] = ceil(opt[1] / 2.0);
max_range[1] = 2 * opt[1];

a.push_back(core::conversion::InputRange(min_range, opt, max_range));
}
}

return std::move(a);
Expand Down Expand Up @@ -63,9 +73,10 @@ std::vector<at::Tensor> RunGraphEngine(
std::vector<at::Tensor> RunGraphEngineDynamic(
std::shared_ptr<torch::jit::Graph>& g,
core::conversion::GraphParams& named_params,
std::vector<at::Tensor> inputs) {
std::vector<at::Tensor> inputs,
bool dynamic_batch) {
LOG_DEBUG("Running TRT version");
auto in = toInputRangesDynamic(inputs);
auto in = toInputRangesDynamic(inputs, dynamic_batch);
auto info = core::conversion::ConversionInfo(in);
info.engine_settings.workspace_size = 1 << 20;
std::string eng = core::conversion::ConvertBlockToEngine(g->block(), info, named_params);
Expand Down
3 changes: 2 additions & 1 deletion tests/util/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ std::vector<at::Tensor> RunGraphEngine(
std::vector<at::Tensor> RunGraphEngineDynamic(
std::shared_ptr<torch::jit::Graph>& g,
core::conversion::GraphParams& named_params,
std::vector<at::Tensor> inputs);
std::vector<at::Tensor> inputs,
bool dynamic_batch);

// Run the forward method of a module and return results
torch::jit::IValue RunModuleForward(torch::jit::Module& mod, std::vector<torch::jit::IValue> inputs);
Expand Down

0 comments on commit 00f2d78

Please sign in to comment.