Add dynamic conversion path to aten::mul evaluator
mfeliz-cruise committed Apr 7, 2023
1 parent 78b571c commit 02d502c
Showing 3 changed files with 47 additions and 2 deletions.
10 changes: 10 additions & 0 deletions core/conversion/evaluators/aten.cpp
@@ -455,6 +455,16 @@ auto aten_registrations TORCHTRT_UNUSED =
 .evaluator(
     {c10::Symbol::fromQualString("aten::mul"),
      [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+       if (!constTypesOnly(args)) {
+         auto a = args.at(n->input(0)).ITensorOrFreeze(ctx);
+         auto b = args.at(n->input(1)).ITensorOrFreeze(ctx);
+         auto mul =
+             converters::add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, a, b, util::node_info(n));
+         TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);
+         auto out = ctx->AssociateValueAndTensor(n->outputs()[0], mul->getOutput(0));
+         LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+         return {};
+       }
        if (args.at(n->input(0)).IValue()->isInt()) {
          auto a = args.at(n->input(0)).unwrapToInt();
          auto b = args.at(n->input(1)).unwrapToInt();
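The new branch keys off constTypesOnly(args): when any argument of aten::mul is only available as a runtime ITensor (for example, a dimension produced by aten::size under dynamic shapes), the evaluator freezes both operands to ITensors and emits a TensorRT elementwise kPROD layer instead of failing, and returning the empty optional signals that the output was already bound via AssociateValueAndTensor. A minimal sketch of the gating idea, assuming a simple scan over the evaluator's kwargs map (the real constTypesOnly is defined elsewhere in aten.cpp and may differ):

    // Sketch only: returns true when every argument is a static IValue,
    // i.e. no operand is a runtime nvinfer1::ITensor. The helper's shape
    // is assumed; the actual one lives in core/conversion/evaluators.
    bool constTypesOnly(kwargs& args) {
      for (auto& a : args) {
        if (a.second.isITensor()) {
          return false; // value only known at engine run time
        }
      }
      return true;
    }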
7 changes: 5 additions & 2 deletions core/conversion/var/Var.cpp
@@ -92,15 +92,18 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
 }
 
   TORCHTRT_CHECK(
-      isITensor() || (isIValue() && (ptr_.ivalue->isTensor() || ptr_.ivalue->isCustomClass())),
-      "Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name());
+      isITensor() ||
+          (isIValue() && (ptr_.ivalue->isTensor() || ptr_.ivalue->isScalar() || ptr_.ivalue->isCustomClass())),
+      "Requested either IValue containing a Tensor, Scalar or ITensor, however Var type is " << type_name());
 
   nvinfer1::ITensor* out;
 
   if (isIValue()) {
     if (ptr_.ivalue->isTensor()) {
       auto tensor = ptr_.ivalue->toTensor();
       out = converters::tensor_to_const(ctx, tensor);
+    } else if (ptr_.ivalue->isScalar()) {
+      out = converters::scalar_to_tensor(ctx, ptr_.ivalue->toScalar());
     } else {
       // The split converter generates a c10::IValue which holds a TensorContainer.
       auto output_container = ptr_.ivalue->toCustomClass<TensorContainer>();
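The new isScalar() branch is what lets the evaluator above freeze a plain int or float operand into a TensorRT constant. A rough sketch of what converters::scalar_to_tensor presumably does, reusing tensor_to_const (hypothetical body and name; the real helper lives in core/conversion/converters and its dtype handling may differ):

    // Sketch under stated assumptions: wrap the scalar in a one-element
    // at::Tensor and register it as a TensorRT constant layer.
    nvinfer1::ITensor* scalar_to_tensor_sketch(ConversionCtx* ctx, const at::Scalar& s) {
      at::Tensor t = s.isIntegral(/*includeBool=*/false)
          ? torch::tensor({s.to<int64_t>()}, torch::kInt32) // assume int32, since TRT lacks int64 here
          : torch::tensor({s.to<double>()}, torch::kFloat);
      return converters::tensor_to_const(ctx, t);
    }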
32 changes: 32 additions & 0 deletions tests/cpp/test_dynamic_size.cpp
@@ -87,5 +87,37 @@ TEST(Converters, ATenResizeGetItemDynShapeCorrectly) {

   auto trt = trt_results[0].reshape(jit_results[0].sizes());
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
+
+TEST(Converters, ATenResizeGetItemDynShapeMulCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+            %2 : int = prim::Constant[value=0]()
+            %3 : int = prim::Constant[value=-1]()
+            %4 : int = prim::Constant[value=2]()
+            %size.1 : int[] = aten::size(%x.1)
+            %37 : int = aten::__getitem__(%size.1, %2)
+            %38 : int = aten::mul(%37, %4)
+            %39 : int[] = prim::ListConstruct(%38, %3)
+            %7 : Tensor = aten::reshape(%x.1, %39)
+            return (%7))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {16, 16, 16}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
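For the {16, 16, 16} input, the graph computes size(x)[0] * 2 = 32 and reshapes the 4096-element tensor to {32, -1}, i.e. {32, 128}. In the RunGraphEngineDynamic path the first dimension reaches aten::mul as an ITensor rather than a constant int, so this test only passes once the evaluator's new dynamic branch emits the elementwise product.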
