Merge pull request #892 from NVIDIA/support_aten_extend
Support aten extend
peri044 authored Feb 25, 2022
2 parents 0230cc6 + 97eb4eb commit ef62f6b
Showing 4 changed files with 61 additions and 5 deletions.
25 changes: 25 additions & 0 deletions core/conversion/evaluators/aten.cpp
@@ -285,6 +285,31 @@ auto aten_registrations TORCHTRT_UNUSED =
EvalOptions().validSchemas({
"aten::append.t(t[](a!) self, t(c -> *) el) -> (t[](a!))",
})})
.evaluator({c10::Symbol::fromQualString("aten::extend"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
if (args.at(n->input(0)).IValue()->isList() && args.at(n->input(1)).IValue()->isList()) {
c10::IValue* self_ptr = args.at(n->input(0)).IValueMut();
auto self = self_ptr->to<c10::List<c10::IValue>>();
auto other = args.at(n->input(1)).IValue()->to<c10::List<c10::IValue>>();
const int64_t other_size = other.size();

// Modify value in place
for (int64_t i = 0; i < other_size; i++) {
self.push_back(other.get(i));
}

*self_ptr = c10::IValue(self);
return {};
} else {
TORCHTRT_THROW_ERROR(
"Unimplemented data type for aten::extend.t evaluator: "
<< args.at(n->input(0)).IValue()->type()->str() << ", "
<< args.at(n->input(1)).IValue()->type()->str());
}
},
EvalOptions().validSchemas({
"aten::extend.t(t[](a!) self, t[] other) -> ()",
})})
.evaluator({c10::Symbol::fromQualString("aten::neg"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto el = args.at(n->input(0)).unwrapToInt();
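
The new evaluator folds aten::extend at conversion time: it appends every element of the second list to the first and writes the grown list back through a mutable IValue pointer. Below is a minimal standalone sketch of the same c10::List handling, for illustration only (it is not part of the change and assumes only that libtorch headers are available):

#include <torch/script.h>
#include <iostream>

int main() {
  // Two generic lists wrapped in IValues, standing in for the evaluator's two inputs.
  c10::IValue self{c10::List<int64_t>({1, 2, 3})};
  c10::IValue other{c10::List<int64_t>({4, 5})};

  auto self_list = self.to<c10::List<int64_t>>();
  auto other_list = other.to<c10::List<int64_t>>();

  // Append each element of `other` to `self`, mirroring the evaluator's loop.
  for (size_t i = 0; i < other_list.size(); ++i) {
    self_list.push_back(other_list.get(i));
  }

  // Write the grown list back, as the evaluator does with *self_ptr = c10::IValue(self).
  self = c10::IValue(self_list);

  std::cout << self.to<c10::List<int64_t>>().size() << std::endl;  // prints 5
  return 0;
}
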
8 changes: 6 additions & 2 deletions core/conversion/var/Var.cpp
@@ -13,7 +13,7 @@ Var::Var() {
type_ = Type::kNone;
}

-Var::Var(const torch::jit::IValue* p) : type_(Type::kIValue) {
+Var::Var(torch::jit::IValue* p) : type_(Type::kIValue) {
ptr_.ivalue = p;
}

@@ -56,7 +56,7 @@ Var& Var::operator=(const Var& a) {
return (*this);
}

-Var& Var::operator=(const torch::jit::IValue* in) {
+Var& Var::operator=(torch::jit::IValue* in) {
ptr_.ivalue = in;
type_ = Type::kIValue;
return (*this);
@@ -116,6 +116,10 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
}

const torch::jit::IValue* Var::IValue() const {
+  return IValueMut();
+}
+
+torch::jit::IValue* Var::IValueMut() const {
TORCHTRT_CHECK(isIValue(), "Requested IValue from Var, however Var type is " << type_name());
if (type_ == Type::kIValue) {
return ptr_.ivalue;
7 changes: 4 additions & 3 deletions core/conversion/var/Var.h
@@ -17,13 +17,14 @@ class Var : torch::CustomClassHolder {
enum Type { kITensor, kIValue, kNone };

Var();
-  Var(const torch::jit::IValue* p);
+  Var(torch::jit::IValue* p);
Var(nvinfer1::ITensor* p);
Var(const Var& a);
Var& operator=(const Var& a);
-  Var& operator=(const torch::jit::IValue* in);
+  Var& operator=(torch::jit::IValue* in);
Var& operator=(nvinfer1::ITensor* in);
const torch::jit::IValue* IValue() const;
+  torch::jit::IValue* IValueMut() const;
nvinfer1::ITensor* ITensor() const;

// TODO: Can we consolidate this in a way that prevents requesting invalid
@@ -63,7 +64,7 @@

private:
union VarContainer {
-    const torch::jit::IValue* ivalue;
+    torch::jit::IValue* ivalue;
nvinfer1::ITensor* tensor;
void* none;
};
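
The Var changes drop the const qualifier from the stored IValue pointer and split the accessor in two: IValue() keeps returning a const pointer for read-only callers, while the new IValueMut() hands out a mutable pointer so list-mutating evaluators such as aten::extend can modify the wrapped value in place. A simplified, hypothetical analogue of that split (IValueHolder is not the real Var class, which also wraps nvinfer1::ITensor):

#include <torch/script.h>

// Hypothetical stand-in for Var's IValue handling: the pointer is stored
// non-const, read-only callers get a const view, and mutating evaluators
// get the mutable view.
class IValueHolder {
 public:
  explicit IValueHolder(torch::jit::IValue* p) : ivalue_(p) {}

  const torch::jit::IValue* IValue() const {
    return IValueMut();  // delegate, as Var::IValue() now does
  }

  torch::jit::IValue* IValueMut() const {
    return ivalue_;
  }

 private:
  torch::jit::IValue* ivalue_;  // non-const, matching the VarContainer union change
};

int main() {
  torch::jit::IValue value{c10::List<int64_t>({1, 2, 3})};
  IValueHolder holder{&value};

  // A mutating caller (e.g. the aten::extend evaluator) grows the list in place.
  auto list = holder.IValueMut()->to<c10::List<int64_t>>();
  list.push_back(4);
  *holder.IValueMut() = torch::jit::IValue(list);

  return 0;
}
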
26 changes: 26 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
@@ -303,6 +303,32 @@ TEST(Evaluators, FloorFloatIntEvaluatesCorrectly) {
ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, ATenExtendEvaluatesCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%2 : int = prim::Constant[value=0]()
%3 : Tensor[] = prim::ListConstruct(%0)
%4 : Tensor[] = prim::ListConstruct(%1)
aten::extend(%3, %4)
%5 : Tensor = aten::cat(%3, %2)
return (%5))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in0 = at::randint(1, 10, {3, 4}, {at::kCUDA});
auto in1 = at::randint(1, 10, {5, 4}, {at::kCUDA});

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in0, in1});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in0, in1});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Evaluators, ATenAppendWithITensorEvaluatesCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
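
The new test builds two single-tensor lists, extends the first with the second, and checks that concatenating the extended list gives the same result through TorchScript and through the TensorRT engine. A plain libtorch sketch of the computation being verified (illustration only; it runs on CPU and uses std::vector in place of the IR's ListConstruct nodes):

#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
  auto in0 = torch::randint(1, 10, {3, 4});
  auto in1 = torch::randint(1, 10, {5, 4});

  std::vector<torch::Tensor> list{in0};                  // %3 : ListConstruct(%0)
  std::vector<torch::Tensor> other{in1};                 // %4 : ListConstruct(%1)
  list.insert(list.end(), other.begin(), other.end());   // aten::extend(%3, %4)

  auto out = torch::cat(list, /*dim=*/0);                // aten::cat(%3, %2)
  std::cout << out.sizes() << std::endl;                 // expected: [8, 4]
  return 0;
}
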
