Skip to content

Commit

Permalink
fix(eval): Rollback 1.11a0 change + namespace issues
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Nov 3, 2021
1 parent 540e135 commit ba743f5
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion core/conversion/evaluators/eval_macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
auto b = args.at(n->input(1)).unwrapToString(); \
return operation; \
} else { \
TRTORCH_THROW_ERROR( \
TORCHTRT_THROW_ERROR( \
"Unimplemented data type for " \
<< node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
return {}; \
Expand Down
2 changes: 1 addition & 1 deletion core/conversion/evaluators/eval_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ at::Tensor createTensorFromList(
/// Gets shape of tensor to be created
auto sizes = compute_sizes(data);
checkListInputType(elem_type, sizes.size() == 1 && sizes[0] == 0);
at::ScalarType initial_scalar_type = c10::scalarTypeFromJitType(*elem_type);
at::ScalarType initial_scalar_type = c10::scalarTypeFromJitType(elem_type);
if (initial_scalar_type == at::ScalarType::Double) {
initial_scalar_type = at::typeMetaToScalarType(c10::get_default_dtype());
}
Expand Down
16 changes: 8 additions & 8 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,8 +524,8 @@ TEST(Evaluators, EqStrResultIsTrueEvaluatesCorrectly) {
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}
Expand All @@ -541,8 +541,8 @@ TEST(Evaluators, EqStrResultIsFalseEvaluatesCorrectly) {
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}
Expand All @@ -558,8 +558,8 @@ TEST(Evaluators, AndBoolResultIsTrueEvaluatesCorrectly) {
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}
Expand All @@ -575,8 +575,8 @@ TEST(Evaluators, AndBoolResultIsFalseEvaluatesCorrectly) {
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

0 comments on commit ba743f5

Please sign in to comment.