Skip to content

Commit

Permalink
feat(//core/conversion/evaluators): A whole bunch of new evaluators
Browse files Browse the repository at this point in the history
Adds evaluators for:
- aten::eq
- aten::ne
- aten::lt
- aten::gt
- aten::le
- aten::ge
- aten::add
- aten::sub
- aten::mul
- aten::Bool
- aten::Float
- aten::__not__
- aten::__is__
- aten::__isnot__
- aten::numel
- aten::dim
- aten::div
- aten::floordiv
- aten::floor
- aten::warn
- prim::min
- prim::max
- prim::shape
- prim::unchecked_cast
- prim::Uninitalized
- prim::RaiseException

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Jun 11, 2020
1 parent 6cce381 commit 7466b8a
Show file tree
Hide file tree
Showing 5 changed files with 519 additions and 45 deletions.
3 changes: 2 additions & 1 deletion core/conversion/evaluators/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ cc_library(
srcs = [
"NodeEvaluatorRegistry.cpp",
"prim.cpp",
"aten.cpp"
"aten.cpp",
"eval_macros.h"
],
deps = [
"//core/util:prelude",
Expand Down
4 changes: 4 additions & 0 deletions core/conversion/evaluators/NodeEvaluatorRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class NodeEvaluatorRegistry {
public:
void RegisterEvaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
LOG_DEBUG("Registering evaluator for " << node_kind.toQualString());
auto iter = evaluator_lut_.find(node_kind);
if (iter != evaluator_lut_.end()) {
TRTORCH_THROW_ERROR("Attempting to override already registered evaluator " << node_kind.toQualString() << ", merge implementations instead");
}
evaluator_lut_[node_kind] = std::move(eval_reg);
}

Expand Down
Loading

0 comments on commit 7466b8a

Please sign in to comment.