feat(//core/conversion): Add support for aten::size with dynamic shaped models for Torchscript backend. #1647

Merged · 15 commits · Apr 6, 2023
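Summary: this PR threads the ConversionCtx through the evaluator API (EvalNode and every registered evaluator lambda now receive a ConversionCtx*), so that evaluators can emit TensorRT layers instead of only folding constants at conversion time. On that foundation, aten::size over a dynamic-shaped input lowers to an IShapeLayer (plus an IGatherLayer when a single dimension is requested) via the new dynamic_size_layer helper; prim::ListConstruct promotes scalar entries of shape lists to single-element ITensor constants so the list stays homogeneous; Var gains isITensorList/unwrapToITensorList to recover those tensors; and the reshape converter in shuffle.cpp feeds the assembled shape tensor to its IShuffleLayer at runtime.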
2 changes: 1 addition & 1 deletion core/conversion/conversion.cpp
@@ -68,7 +68,7 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::
return {};
}
}
- auto eval = evaluators::EvalNode(n, eval_args);
+ auto eval = evaluators::EvalNode(ctx, n, eval_args);
return eval;
}

36 changes: 24 additions & 12 deletions core/conversion/converters/impl/shuffle.cpp
@@ -70,25 +70,37 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
auto in = args[0].ITensorOrFreeze(ctx);
auto in_shape = util::toVec(in->getDimensions());
std::vector<int64_t> new_shape;
+ nvinfer1::ITensor* shape_tensor;
if (ctx->input_is_dynamic) {
-   new_shape = util::toVec(args[1].unwrapToIntList().vec());
-   int nbDynamicDims = 0;
-   for (size_t i = 0; i < new_shape.size(); i++) {
-     if (in_shape[i] == -1)
-       nbDynamicDims++;
-   }
-   if (nbDynamicDims > 1) {
-     TORCHTRT_THROW_ERROR(
-         "Resize is currently not supported when target shape contains more than one dynamic dimension");
-   }
+   LOG_DEBUG("Using dynamic version of reshape layer");
+   if (args[1].isITensorList()) {
+     LOG_DEBUG("Shape tensor is an ITensorList");
+     auto new_shape = args[1].unwrapToITensorList();
+     auto concat_layer = ctx->net->addConcatenation(new_shape.data(), new_shape.size());
+     TORCHTRT_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n);
+     concat_layer->setAxis(static_cast<int32_t>(0));
+     shape_tensor = concat_layer->getOutput(0);
+   } else if (args[1].isIntList()) {
+     LOG_DEBUG("Shape tensor is an IntList");
+     auto shape_vec = args[1].unwrapToIntList().vec();
+     shape_tensor = tensor_to_const(ctx, torch::tensor(shape_vec).to(torch::kI32));
+   } else {
+     LOG_ERROR(
+         "Invalid IValue type of " << args[1].IValue()->type()
+                                   << " detected for shape tensor from node: " << *n);
+   }
} else {
  new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
}

auto shuffle = ctx->net->addShuffle(*in);
- TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
- shuffle->setReshapeDimensions(util::toDims(new_shape));
shuffle->setName(util::node_info(n).c_str());
+ TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
+
+ if (ctx->input_is_dynamic) {
+   shuffle->setInput(1, *shape_tensor);
+ } else {
+   shuffle->setReshapeDimensions(util::toDims(new_shape));
+ }

auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
4 changes: 2 additions & 2 deletions core/conversion/evaluators/NodeEvaluatorRegistry.cpp
@@ -114,9 +114,9 @@ std::vector<std::string> getEvaluatorList() {
return get_evaluator_registry().GetRegisteredEvaluatorList();
}

- c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args) {
+ c10::optional<torch::jit::IValue> EvalNode(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
auto evaluator = get_evaluator_registry().GetEvaluator(n);
- return evaluator(n, args);
+ return evaluator(ctx, n, args);
}

void register_node_evaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
101 changes: 57 additions & 44 deletions core/conversion/evaluators/aten.cpp

Large diffs are not rendered by default.
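aten.cpp carries the actual aten::size registration. Since the diff is collapsed, the sketch below shows the dispatch it plausibly gains, inferred from the dynamic_size_layer helper introduced in eval_util.cpp further down; the body is an assumption, not the verbatim diff.

// Hypothetical sketch of the updated aten::size evaluator, not the verbatim diff
.evaluator(
    {c10::Symbol::fromQualString("aten::size"),
     [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
       auto tensor_var = args.at(n->input(0));
       if (tensor_var.isITensor() && ctx->input_is_dynamic) {
         // New path: emit IShapeLayer (+ IGatherLayer for a single dim) and
         // return the result as an ITensor wrapped in an IValue.
         return dynamic_size_layer(ctx, n, args);
       }
       // Static path: sizes are known at conversion time, fold to plain ints.
       auto tensor = tensor_var.unwrapToTensor();
       if (n->inputs().size() == 1) {
         return c10::IValue(tensor.sizes().vec());
       }
       return c10::IValue(tensor.size(args.at(n->input(1)).unwrapToInt()));
     }})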

252 changes: 126 additions & 126 deletions core/conversion/evaluators/eval_macros.h

Large diffs are not rendered by default.
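The eval_macros.h change is presumably mechanical: the DEFINE_GENERIC_TWO_INPUT_EVALUATOR-style macros that stamp out evaluator registrations gain the same ConversionCtx* parameter threading shown explicitly in prim.cpp below.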

45 changes: 44 additions & 1 deletion core/conversion/evaluators/eval_util.cpp
@@ -1,3 +1,4 @@
#include "core/conversion/evaluators/eval_util.h"
#include <ATen/ATen.h>
#include "ATen/InitialTensorOptions.h"
#include "ATen/core/List.h"
@@ -6,12 +6,54 @@
#include "ATen/core/jit_type.h"
#include "c10/util/irange.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

namespace torch_tensorrt {
namespace core {
namespace conversion {
namespace evaluators {

nvinfer1::ITensor* index_layer(
ConversionCtx* ctx,
const torch::jit::Node* n,
nvinfer1::ITensor* input_tensor,
int64_t index) {
// index to access needs to be an at::Tensor
at::Tensor indices = torch::tensor({index}).to(torch::kI32);
auto indices_out = converters::tensor_to_const(ctx, indices);

auto gather_layer = ctx->net->addGather(*input_tensor, *indices_out, 0);
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
auto indexed_tensor = gather_layer->getOutput(0);
return indexed_tensor;
}

c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
LOG_DEBUG("Using dynamic version of aten::size evaluator");
auto in = args.at(n->input(0)).ITensorOrFreeze(ctx);
LOG_DEBUG("Input dimensions: " << in->getDimensions());
auto shape_layer = ctx->net->addShape(*in);
TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
auto shape_1d_tensor = shape_layer->getOutput(0);

if (n->inputs().size() != 1) {
auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
auto dim = args.at(n->input(1)).unwrapToInt();
// Handle negative axis by referring to nbDims of input Tensor
dim = dim < 0 ? dim + maxDim : dim;
LOG_DEBUG("Dimension to select: " << dim);
shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim);
}

LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());

auto tensor_holder = TensorContainer();
tensor_holder.hold_tensor(shape_1d_tensor);
auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));

return shape_1d_ivalue;
}
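An nvinfer1::ITensor* is not a type the TorchScript interpreter understands, so it travels inside TensorContainer, a torch::CustomClassHolder. The round-trip, condensed from the function above and from Var::unwrapToITensorList below:

// Wrap: stash the lowered ITensor* in an IValue so it can flow through the
// evaluated value map like any other TorchScript value.
auto holder = TensorContainer();
holder.hold_tensor(shape_1d_tensor);
c10::IValue ival(c10::make_intrusive<TensorContainer>(holder));

// Unwrap: a downstream consumer recovers the ITensor* (see Var.cpp below).
nvinfer1::ITensor* recovered = ival.toCustomClass<TensorContainer>()->tensor();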

int64_t normalizeIndex(int64_t idx, int64_t list_size) {
if (idx < 0) {
// Handle negative indexing
@@ -128,7 +171,7 @@ void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
}

// TODO: Conditionally enable truncation based on user setting
- at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU) {
+ at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device) {
// This function is basically the same as the one in
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h; the difference is that Int and Float
// won't be upgraded to kDouble or kLong since we don't support those two types in conversion
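In practice integer scalars stay kInt and floating-point scalars stay kFloat; a sketch of the expected behavior, assuming standard ATen semantics:

at::Tensor ti = scalar_to_tensor(at::Scalar(int64_t(3)));
// ti.scalar_type() == at::kInt (the upstream helper would produce kLong)
at::Tensor tf = scalar_to_tensor(at::Scalar(0.5));
// tf.scalar_type() == at::kFloat (the upstream helper would produce kDouble)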
9 changes: 9 additions & 0 deletions core/conversion/evaluators/eval_util.h
@@ -1,12 +1,21 @@
#pragma once

#include "core/conversion/evaluators/evaluators.h"
#include "torch/csrc/jit/ir/ir.h"

namespace torch_tensorrt {
namespace core {
namespace conversion {
namespace evaluators {

nvinfer1::ITensor* index_layer(
ConversionCtx* ctx,
const torch::jit::Node* n,
nvinfer1::ITensor* input_tensor,
int64_t index);

c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args);

c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v);
at::Tensor createTensorFromList(
const torch::jit::IValue& data,
7 changes: 5 additions & 2 deletions core/conversion/evaluators/evaluators.h
@@ -6,6 +6,8 @@

#include "torch/csrc/jit/ir/ir.h"

#include "core/conversion/conversionctx/ConversionCtx.h"
#include "core/conversion/converters/converter_util.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/conversion/var/Var.h"

@@ -33,7 +35,8 @@ inline bool constTypesOnly(kwargs& args) {
// to use the node itself to pull out arguments.
// This means that you should iterate over node inputs vs. the args
// when writing evaluators
- typedef std::function<c10::optional<torch::jit::IValue>(const torch::jit::Node*, kwargs&)> NodeEvaluator;
+ typedef std::function<c10::optional<torch::jit::IValue>(ConversionCtx*, const torch::jit::Node*, kwargs&)>
+     NodeEvaluator;

struct EvalOptions {
std::set<c10::TypePtr> blacklisted_output_types;
@@ -72,7 +75,7 @@ struct EvalRegistration {
: kind(_kind), evaluator(_evaluator), options(_options){};
};

- c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args);
+ c10::optional<torch::jit::IValue> EvalNode(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args);
bool shouldEvalAtConversionTime(const torch::jit::Node* n);
std::vector<std::string> getEvaluatorList();
void register_node_evaluator(torch::jit::NodeKind node_kind, NodeEvaluator evaluator);
59 changes: 38 additions & 21 deletions core/conversion/evaluators/prim.cpp
@@ -1,12 +1,11 @@
#include <limits>

#include "torch/csrc/jit/ir/ir.h"
//#include "torch/csrc/jit/ir/constants.h"
#include "ATen/core/List.h"
#include "ATen/core/functional.h"
#include "ATen/core/ivalue.h"
#include "ATen/core/stack.h"
#include "c10/util/intrusive_ptr.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/torch.h"

#include "core/conversion/evaluators/eval_macros.h"
@@ -24,28 +23,28 @@ auto prim_registrations =
RegisterNodeEvaluators()
.evaluator(
{torch::jit::prim::Constant,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
if (n->output()->type()->kind() == at::FunctionType::Kind) {
return {};
}
return evaluators::toIValue(n->output());
}})
.evaluator(
{torch::jit::prim::NumToTensor,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
return evaluators::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
}})
.evaluator(
{torch::jit::prim::ListUnpack,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
const torch::jit::IValue* outputs = args.at(n->input()).IValue();
auto outputVec = outputs->toList().vec();
return std::move(c10::ivalue::Tuple::create(outputVec));
}})
.evaluator(
{torch::jit::prim::ListConstruct,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
const auto num_inputs = n->inputs().size();
if (constTypesOnly(args)) {
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
@@ -89,9 +88,8 @@ auto prim_registrations =
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
}
} else {
- c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
- c10::TypePtr elementType = lt->getElementType();
- auto list = c10::impl::GenericList(elementType);
+ // List would be of IValues (with ITensors embedded in them)
+ auto list = c10::impl::GenericList(c10::AnyType::get());
list.reserve(num_inputs);
for (auto in : n->inputs()) {
if (args.at(in).isITensor()) {
@@ -103,8 +101,27 @@
if (args.at(in).IValue()->isNone()) {
auto ival = torch::jit::IValue();
list.emplace_back(std::move(ival));
+ } else if (args.at(in).IValue()->isInt()) {
+   auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(
+       ctx, torch::tensor({args.at(in).unwrapToInt()}).to(torch::kI32));
+   auto tensor_holder = TensorContainer();
+   tensor_holder.hold_tensor(itensor);
+   auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+   list.emplace_back(std::move(ival));
+ } else if (args.at(in).IValue()->isDouble()) {
+   auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(
+       ctx, torch::tensor({args.at(in).unwrapToDouble()}).to(torch::kFloat));
+   auto tensor_holder = TensorContainer();
+   tensor_holder.hold_tensor(itensor);
+   auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+   list.emplace_back(std::move(ival));
} else {
-   list.emplace_back(std::move(args.at(in).unwrapToTensor()));
+   auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(
+       ctx, std::move(args.at(in).unwrapToTensor()));
+   auto tensor_holder = TensorContainer();
+   tensor_holder.hold_tensor(itensor);
+   auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+   list.emplace_back(std::move(ival));
}
}
}
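The net effect: under dynamic shape, a list such as the one built for x.view(x.size(0), -1) no longer mixes an ITensor (from aten::size) with bare ints. Every scalar entry is promoted to a single-element ITensor constant via tensor_to_const and wrapped in a TensorContainer, so downstream converters see a homogeneous ITensorList they can concatenate into a 1-D shape tensor.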
@@ -113,7 +130,7 @@
}})
.evaluator(
{c10::Symbol::fromQualString("prim::dtype"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto input = args.at(n->input(0));
if (input.isITensor()) {
auto trt_dtype = input.ITensor()->getType();
@@ -136,7 +153,7 @@
})})
.evaluator(
{c10::Symbol::fromQualString("prim::min"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
if (n->inputs().size() == 1) {
auto a = args.at(n->input(0)).unwrapToIntList();
int64_t min = std::numeric_limits<int64_t>::max();
@@ -198,7 +215,7 @@
})})
.evaluator(
{c10::Symbol::fromQualString("prim::max"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
if (n->inputs().size() == 1) {
auto a = args.at(n->input(0)).unwrapToIntList();
int64_t max = std::numeric_limits<int64_t>::min();
@@ -260,7 +277,7 @@
})})
.evaluator(
{c10::Symbol::fromQualString("prim::shape"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
LOG_WARNING("There may be undefined behavior using dynamic shape and prim::shape");
auto tensor_var = args.at(n->input(0));
if (tensor_var.isITensor()) {
@@ -274,7 +291,7 @@
EvalOptions().validSchemas({"prim::shape(Tensor a) -> (int[])"})})
.evaluator(
{torch::jit::prim::TupleConstruct,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
c10::IValue tuple = c10::ivalue::Tuple::create();
std::vector<c10::IValue> elems;
for (auto in : n->inputs()) {
@@ -292,7 +309,7 @@
}})
.evaluator(
{torch::jit::prim::TupleIndex,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
auto tuple = args.at(n->input(0)).IValue()->toTuple();
int64_t idx = args.at(n->input(1)).IValue()->toInt();
Expand All @@ -302,24 +319,24 @@ auto prim_registrations =
EvalOptions().validSchemas({"prim::TupleIndex(Any tup, int i) -> (Any)"})})
.evaluator(
{torch::jit::prim::TupleUnpack,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
auto output = args.at(n->input()).IValue()->toTuple();
return c10::optional<torch::jit::IValue>(std::move(output));
}})
.evaluator(
{c10::Symbol::fromQualString("prim::unchecked_cast"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
return *(args.at(n->input(0)).IValue());
}})
.evaluator(
{c10::Symbol::fromQualString("prim::Uninitialized"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
return c10::IValue::uninitialized();
}})
.evaluator(
{c10::Symbol::fromQualString("prim::RaiseException"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto exception = args.at(n->input(0)).IValue();
TORCHTRT_THROW_ERROR("Error from TorchScript: " << *exception);
return {};
@@ -328,4 +345,4 @@
} // namespace evaluators
} // namespace conversion
} // namespace core
} // namespace torch_tensorrt
25 changes: 25 additions & 0 deletions core/conversion/var/Var.cpp
@@ -146,6 +146,31 @@ bool Var::isITensor() const {
}
}

bool Var::isITensorList() {
// Unpack the Var as a List and check if each entry is a custom class since
// ITensors are stored in CustomClassHolder
auto ival_list = ptr_.ivalue->toList();
for (int i = 0; i < ival_list.size(); i++) {
if (!ival_list.get(i).isCustomClass()) {
return false;
}
}
return true;
}

std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList() {
TORCHTRT_CHECK(
isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
TORCHTRT_CHECK(isITensorList(), "Expected IValue to be an ITensorList");
auto ivalue_list = ptr_.ivalue->toList();
std::vector<nvinfer1::ITensor*> outputs;
for (int i = 0; i < ivalue_list.size(); i++) {
auto element = ivalue_list.get(i).toCustomClass<TensorContainer>()->tensor();
outputs.push_back(std::move(element));
}
return outputs;
}
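These two helpers are what allow the shuffle converter above to consume an evaluated prim::ListConstruct result. Condensed usage, mirroring the ITensorList branch of shuffle.cpp:

// Recover the per-dimension ITensors and concatenate them into a single
// 1-D runtime shape tensor (as the reshape converter above does).
if (args[1].isITensorList()) {
  std::vector<nvinfer1::ITensor*> dims = args[1].unwrapToITensorList();
  auto concat = ctx->net->addConcatenation(dims.data(), static_cast<int32_t>(dims.size()));
  concat->setAxis(0);
  nvinfer1::ITensor* shape_tensor = concat->getOutput(0);
}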

bool Var::isIValue() const {
if (type_ == Type::kIValue) {
return true;