feat(//core/conversion/evaluators): Allow ITensors to be wrapped in IValues
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed May 26, 2020
1 parent 8c26a1b commit 619e345
Showing 12 changed files with 159 additions and 48 deletions.
2 changes: 1 addition & 1 deletion core/conversion/BUILD
@@ -19,7 +19,7 @@ cc_library(
     ],
     deps = [
         "@tensorrt//:nvinfer",
-        "//core/conversion/arg",
+        "//core/conversion/var",
         "//core/conversion/conversionctx",
         "//core/conversion/converters",
         "//core/conversion/evaluators",
8 changes: 5 additions & 3 deletions core/conversion/conversion.cpp
@@ -1,7 +1,7 @@
 #include <sstream>
 
 #include "core/util/prelude.h"
-#include "core/conversion/arg/Arg.h"
+#include "core/conversion/var/Var.h"
 #include "core/conversion/conversion.h"
 #include "core/conversion/converters/converters.h"
 #include "core/conversion/evaluators/evaluators.h"
@@ -35,6 +35,8 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::
     }
     if (ctx->evaluated_value_map.find(eval_in) != ctx->evaluated_value_map.end()) {
       eval_args[eval_in] = &(ctx->evaluated_value_map[eval_in]);
+    } else if (ctx->value_tensor_map.find(eval_in) != ctx->value_tensor_map.end()) {
+      eval_args[eval_in] = ctx->value_tensor_map[eval_in];
     } else if (evaluators::shouldEvalAtConversionTime(eval_in->node())) {
       auto result = EvaluateNode(ctx, eval_in->node(), level++, limit);
       if (result) {
@@ -82,8 +84,8 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
         ctx->AssociateValueAndIValue(input, eval.value());
         node_args.push_back(&(ctx->evaluated_value_map[input]));
       } else {
-        LOG_DEBUG(ctx->logger, "Found the value is None");;
-        node_args.push_back(Arg());
+        LOG_DEBUG(ctx->logger, "Found the value is None");
+        node_args.push_back(Var());
       }
     } else {
       // Node input has not been converted yet or is a prim op
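For orientation: Var (the renamed Arg) is a tagged holder that can reference either a const torch::jit::IValue* or an nvinfer1::ITensor*, which is what lets the new value_tensor_map branch above hand ITensors to evaluators. Below is a simplified, hypothetical sketch of such a holder, built only from the pieces of the Var API visible in this diff; the real class lives in //core/conversion/var and is richer.

// Simplified, hypothetical sketch of a Var-like tagged holder; the real Var
// in //core/conversion/var also offers unwrapToInt(), unwrapToTensor(), etc.
#include "NvInfer.h"
#include "torch/csrc/jit/ir/ir.h"

struct VarSketch {
  enum Type { kIValue, kITensor, kNone };

  VarSketch() : type_(kNone) {}  // the "None" case pushed in AddLayer above
  VarSketch(const torch::jit::IValue* iv) : ivalue_(iv), type_(kIValue) {}
  VarSketch(nvinfer1::ITensor* t) : tensor_(t), type_(kITensor) {}

  Type type() const { return type_; }
  bool isIValue() const { return type_ == kIValue; }
  const torch::jit::IValue* IValue() const { return ivalue_; }
  nvinfer1::ITensor* ITensor() const { return tensor_; }

 private:
  const torch::jit::IValue* ivalue_ = nullptr;
  nvinfer1::ITensor* tensor_ = nullptr;
  Type type_;
};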
1 change: 0 additions & 1 deletion core/conversion/conversionctx/ConversionCtx.h
@@ -4,7 +4,6 @@
 #include <unordered_map>
 #include <memory>
 
-//#include "ATen/ATen.h"
 #include "torch/csrc/jit/ir/ir.h"
 #include "NvInfer.h"
 
4 changes: 2 additions & 2 deletions core/conversion/converters/converters.h
@@ -7,15 +7,15 @@
 #include "ATen/core/function_schema.h"
 
 #include "core/util/prelude.h"
-#include "core/conversion/arg/Arg.h"
+#include "core/conversion/var/Var.h"
 #include "core/conversion/conversionctx/ConversionCtx.h"
 
 namespace trtorch {
 namespace core {
 namespace conversion {
 namespace converters {
 
-typedef std::vector<Arg> args;
+typedef std::vector<Var> args;
 typedef std::function<bool(ConversionCtx*, const torch::jit::Node*, args&)> OpConverter;
 struct ConversionPattern {
   std::string signature;
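Since converters share the same Var type, a hypothetical converter body under the renamed typedef would look like the following; registration boilerplate and actual TensorRT layer construction are elided, and ITensor()/unwrapToInt() are Var methods that appear elsewhere in this commit.

// Hypothetical converter under the new Var-based args vector (sketch only).
auto example_converter = [](ConversionCtx* ctx, const torch::jit::Node* n, args& arguments) -> bool {
  nvinfer1::ITensor* self = arguments[0].ITensor();  // tensor inputs arrive as ITensors
  int64_t dim = arguments[1].unwrapToInt();          // scalar inputs arrive as boxed IValues
  // ... add layers to the network held by ctx using self and dim ...
  return true;
};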
2 changes: 2 additions & 0 deletions core/conversion/evaluators/BUILD
@@ -18,6 +18,8 @@ cc_library(
     ],
     deps = [
         "//core/util:prelude",
+        "//core/conversion/var",
+        "//core/conversion/tensorcontainer",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],
2 changes: 1 addition & 1 deletion core/conversion/evaluators/NodeEvaluatorRegistry.cpp
@@ -57,7 +57,7 @@ bool shouldEvalAtConversionTime(const torch::jit::Node* n) {
   return get_evaluator_registry().EvalAtConversionTime(n);
 }
 
-c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, const kwargs& args) {
+c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args) {
   auto evaluator = get_evaluator_registry().GetEvaluator(n->kind());
   return evaluator(n, args);
 }
32 changes: 24 additions & 8 deletions core/conversion/evaluators/evaluators.h
@@ -2,29 +2,45 @@
 
 #include <string>
 #include <map>
+#include <set>
 
 #include "torch/csrc/jit/ir/ir.h"
 
+#include "core/conversion/tensorcontainer/TensorContainer.h"
+#include "core/conversion/var/Var.h"
+
 namespace trtorch {
 namespace core {
 namespace conversion {
 namespace evaluators {
 
-typedef std::map<const torch::jit::Value*, const torch::jit::IValue*> kwargs;
-
-// NOTE: The input args are a dictionary of Value -> IValue this means
-// inputs will not be repeated. We did this so writing encoders
-// is similar to converters (where a dictionary makes more sense)
-// This mean that you should iterate over node inputs vs. the args
+typedef std::map<const torch::jit::Value*, Var> kwargs;
+
+inline bool constTypesOnly(kwargs& args) {
+  std::set<Var::Type> types;
+  for (auto a : args) {
+    if (a.second.type() == Var::kITensor) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// NOTE: The input args are a dictionary of Value -> Var. This means
+// inputs will not be repeated. We did this because, while in the case
+// of converters we have the function schema to lay out argument order,
+// evaluators don't use the schema; they use node kind as the key, so it
+// is easier to use the node itself to pull out arguments.
+// This means that you should iterate over node inputs vs. the args
+// when writing evaluators.
-typedef std::function<c10::optional<torch::jit::IValue>(const torch::jit::Node*, const kwargs&)> NodeEvaluator;
+typedef std::function<c10::optional<torch::jit::IValue>(const torch::jit::Node*, kwargs&)> NodeEvaluator;
 
 struct EvalRegistration {
   torch::jit::NodeKind kind;
   NodeEvaluator evaluator;
 };
 
-c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, const kwargs& args);
+c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args);
 bool shouldEvalAtConversionTime(const torch::jit::Node* n);
 void register_node_evaluator(torch::jit::NodeKind node_kind, NodeEvaluator evaluator);
 void register_node_evaluator(EvalRegistration r);
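To make the new signature concrete, here is a hypothetical evaluator written against the updated kwargs; the node kind and arithmetic are invented for illustration, but the pattern (iterate node inputs, bail out via constTypesOnly when an ITensor is present) follows the NOTE above.

// Hypothetical example only; assumes the headers and namespace
// (trtorch::core::conversion::evaluators) from this commit.
auto example_registrations = RegisterNodeEvaluators()
  .evaluator({
    torch::jit::aten::mul,  // illustrative node kind
    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
      if (!constTypesOnly(args)) {
        return {};  // an input is still a TensorRT ITensor; nothing to fold here
      }
      int64_t product = 1;
      for (auto in : n->inputs()) {  // iterate inputs, not args (see NOTE)
        product *= args.at(in).unwrapToInt();
      }
      return c10::optional<torch::jit::IValue>(torch::jit::IValue(product));
    }
  });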
77 changes: 46 additions & 31 deletions core/conversion/evaluators/prim.cpp
Expand Up @@ -4,6 +4,7 @@
#include "ATen/core/ivalue.h"
#include "ATen/core/List.h"
#include "ATen/core/stack.h"
#include "c10/util/intrusive_ptr.h"

#include "core/conversion/evaluators/evaluators.h"

@@ -16,51 +17,65 @@ namespace {
 auto prim_registrations = RegisterNodeEvaluators()
   .evaluator({
     torch::jit::prim::Constant,
-    [](const torch::jit::Node* n, const kwargs& args) -> c10::optional<torch::jit::IValue> {
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
       if (n->output()->type()->kind() == at::FunctionType::Kind) {
         return {};
       }
       return torch::jit::toIValue(n->output());
     }
   }).evaluator({
     torch::jit::prim::ListConstruct,
-    [](const torch::jit::Node* n, const kwargs& args) -> c10::optional<torch::jit::IValue> {
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
       const auto num_inputs = n->inputs().size();
-      c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
-      if (torch::jit::IntType::get() == lt->getElementType()) {
-        c10::List<int64_t> list;
-        list.reserve(num_inputs);
-        for (auto in : n->inputs()) {
-          list.emplace_back(std::move(args.at(in)->to<int64_t>()));
-        }
-        return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
-      } else if (torch::jit::FloatType::get() == lt->getElementType()) {
-        c10::List<double> list;
-        list.reserve(num_inputs);
-        for (auto in : n->inputs()) {
-          list.emplace_back(std::move(args.at(in)->to<double>()));
-        }
-        return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
-      } else if (lt->getElementType() == torch::jit::BoolType::get()) {
-        c10::List<bool> list;
-        list.reserve(num_inputs);
-        for (auto in : n->inputs()) {
-          list.emplace_back(std::move(args.at(in)->to<bool>()));
-        }
-        return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
-      } else if (lt->getElementType()->isSubtypeOf(torch::jit::TensorType::get())) {
-        c10::List<at::Tensor> list;
-        list.reserve(num_inputs);
-        for (auto in : n->inputs()) {
-          list.emplace_back(std::move(args.at(in)->toTensor()));
-        }
-        return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
-      } else {
-        c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
-        c10::TypePtr elementType = lt->getElementType();
-        auto list = c10::impl::GenericList(elementType);
-        list.reserve(num_inputs);
-        for (auto in : n->inputs()) {
-          list.emplace_back(std::move(*(args.at(in))));
+      if (constTypesOnly(args)) {
+        c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
+        if (torch::jit::IntType::get() == lt->getElementType()) {
+          c10::List<int64_t> list;
+          list.reserve(num_inputs);
+          for (auto in : n->inputs()) {
+            list.emplace_back(std::move(args.at(in).unwrapToInt()));
+          }
+          return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
+        } else if (torch::jit::FloatType::get() == lt->getElementType()) {
+          c10::List<double> list;
+          list.reserve(num_inputs);
+          for (auto in : n->inputs()) {
+            list.emplace_back(std::move(args.at(in).unwrapToDouble()));
+          }
+          return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
+        } else if (lt->getElementType() == torch::jit::BoolType::get()) {
+          c10::List<bool> list;
+          list.reserve(num_inputs);
+          for (auto in : n->inputs()) {
+            list.emplace_back(std::move(args.at(in).unwrapToBool()));
+          }
+          return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
+        } else if (lt->getElementType()->isSubtypeOf(torch::jit::TensorType::get())) {
+          c10::List<at::Tensor> list;
+          list.reserve(num_inputs);
+          for (auto in : n->inputs()) {
+            if (args.at(in).isIValue()) {
+              list.emplace_back(std::move(args.at(in).unwrapToTensor()));
+            }
+          }
+          return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
+        } else {
+          c10::TypePtr elementType = lt->getElementType();
+          auto list = c10::impl::GenericList(elementType);
+          list.reserve(num_inputs);
+          for (auto in : n->inputs()) {
+            list.emplace_back(std::move(*(args.at(in).IValue())));
+          }
+          return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
+        }
+      } else {
+        c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
+        c10::TypePtr elementType = lt->getElementType();
+        auto list = c10::impl::GenericList(elementType);
+        list.reserve(num_inputs);
+        for (auto in : n->inputs()) {
+          auto x = torch::make_custom_class<TensorContainer>(reinterpret_cast<int64_t>(args.at(in).ITensor()));
+          list.emplace_back(std::move(x));
         }
         return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
       }
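On the consumer side, a converter receiving one of these generic lists would pull each ITensor back out of its boxed form. A hedged sketch, assuming c10::IValue::toListRef() and toCustomClass<T>() are available in the targeted libtorch:

// Sketch: walking a list produced by the non-const ListConstruct path above
// and recovering the raw ITensor pointers from their TensorContainer boxes.
std::vector<nvinfer1::ITensor*> unwrap_tensor_list(const c10::IValue& list_ivalue) {
  std::vector<nvinfer1::ITensor*> tensors;
  for (const auto& elem : list_ivalue.toListRef()) {
    auto container = elem.toCustomClass<trtorch::core::conversion::TensorContainer>();
    tensors.push_back(container->tensor());
  }
  return tensors;
}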
36 changes: 36 additions & 0 deletions core/conversion/tensorcontainer/BUILD
@@ -0,0 +1,36 @@
package(default_visibility = ["//visibility:public"])

config_setting(
name = "use_pre_cxx11_abi",
values = {
"define": "abi=pre_cxx11_abi",
}
)

cc_library(
name = "tensorcontainer",
hdrs = [
"TensorContainer.h",
],
srcs = [
"TensorContainer.cpp",
],
deps = [
"@tensorrt//:nvinfer",
"//core/util:prelude",
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
}),
alwayslink = True,
)

load("@rules_pkg//:pkg.bzl", "pkg_tar")

pkg_tar(
name = "include",
package_dir = "core/conversion/tensorcontainer/",
srcs = [
"TensorContainer.h",
],
)
16 changes: 16 additions & 0 deletions core/conversion/tensorcontainer/TensorContainer.cpp
@@ -0,0 +1,16 @@
#include "core/conversion/tensorcontainer/TensorContainer.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace {

static auto tensor_container =
torch::class_<TensorContainer>("_eval_ivalue_types", "TensorContainer")
.def(torch::init<int64_t>())
.def("clone", &TensorContainer::clone);

} // namespace
} // conversion
} // core
} // trtorch
25 changes: 25 additions & 0 deletions core/conversion/tensorcontainer/TensorContainer.h
@@ -0,0 +1,25 @@
#pragma once

#include "NvInfer.h"
#include "torch/custom_class.h"

namespace trtorch {
namespace core {
namespace conversion {

struct TensorContainer : torch::CustomClassHolder {
int64_t tensor_;
TensorContainer(int64_t init) : tensor_(init) {}

c10::intrusive_ptr<TensorContainer> clone() const {
return c10::make_intrusive<TensorContainer>(tensor_);
}

nvinfer1::ITensor* tensor() {
return reinterpret_cast<nvinfer1::ITensor*>(tensor_);
}
};

} // conversion
} // core
} // trtorch
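Taken with the registration in TensorContainer.cpp above, this header supports a simple round trip; a minimal sketch, assuming an ITensor* owned by the TensorRT network currently being built and that IValue::toCustomClass<T>() is available:

// Minimal wrap/unwrap round trip for the ITensor-in-IValue trick.
// The container is a non-owning handle: the ITensor's lifetime is still
// managed by the INetworkDefinition it came from.
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "torch/custom_class.h"

using trtorch::core::conversion::TensorContainer;

c10::IValue wrap(nvinfer1::ITensor* t) {
  // Box the raw pointer as an int64_t handle inside an IValue.
  return torch::make_custom_class<TensorContainer>(reinterpret_cast<int64_t>(t));
}

nvinfer1::ITensor* unwrap(const c10::IValue& iv) {
  // Recover the handle and cast it back to the original pointer type.
  return iv.toCustomClass<TensorContainer>()->tensor();
}

Storing the pointer as an int64_t keeps the member a TorchScript-representable type, which is why the constructor and field use an integer rather than the pointer type itself.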
2 changes: 1 addition & 1 deletion core/execution/register_trt_op.cpp
@@ -22,7 +22,7 @@ std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai
     auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
     auto shape = core::util::toVec(dims);
     contig_inputs.push_back(inputs[i].view(shape).contiguous());
-    LOG_DEBUG("In shape: " << shape);
+    LOG_DEBUG("Input shape: " << dims);
     ctx->setBindingDimensions(i, dims);
     gpu_handles.push_back(contig_inputs.back().data_ptr());
   }
