Skip to content

Commit

Permalink
fix(aten::zeros): verify zeros produces a tensor correctly
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Feb 20, 2021
1 parent 1c9dfe2 commit 00d2d0c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
7 changes: 6 additions & 1 deletion core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,16 @@ auto aten_registrations TRTORCH_UNUSED =
// Device? device=None, bool? pin_memory=None) -> (Tensor)
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto options = torch::TensorOptions()
.dtype(c10::ScalarType(args.at(n->output(1)).unwrapToInt()))
.layout(torch::kStrided)
.device(torch::kCUDA);

if (!args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone()) {
options = options.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
}

auto out_tensor = torch::zeros(args.at(n->input(0)).unwrapToIntList().vec(), options);
std::cout << out_tensor << std::endl;
std::cout << out_tensor.sizes() << std::endl;
return out_tensor;
}})
.evaluator({c10::Symbol::fromQualString("aten::slice"),
Expand Down
19 changes: 19 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,23 @@ TEST(Evaluators, DivFloatEvaluatesCorrectly) {
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, ZerosEvaluatesCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : None = prim::Constant() # :0:0
%3 : int[] = aten::size(%x.1) # <string>:7:9
%z.1 : Tensor = aten::zeros(%3, %2, %2, %2, %2) # experiments/test_zeros.py:8:12
return (%z.1))IR";

auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {in});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});

ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
}

0 comments on commit 00d2d0c

Please sign in to comment.