Skip to content

Commit

Permalink
feat(aten::zeros): Implement aten::zeros evaluator
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Jun 1, 2020
1 parent 60df888 commit 670817c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
1 change: 1 addition & 0 deletions core/conversion/evaluators/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ cc_library(
srcs = [
"NodeEvaluatorRegistry.cpp",
"prim.cpp",
"aten.cpp"
],
deps = [
"//core/util:prelude",
Expand Down
36 changes: 36 additions & 0 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/constants.h"
#include "ATen/core/functional.h"
#include "ATen/core/ivalue.h"
#include "ATen/core/List.h"
#include "ATen/core/stack.h"
#include "c10/util/intrusive_ptr.h"
#include "torch/torch.h"

#include "core/conversion/evaluators/evaluators.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace evaluators {
namespace {

auto aten_registrations = RegisterNodeEvaluators()
.evaluator({
c10::Symbol::fromQualString("aten::zeros"),
// aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto options = torch::TensorOptions()
.dtype(c10::ScalarType(args.at(&(n->output()[1])).unwrapToInt()))
.layout(torch::kStrided)
.device(torch::kCUDA);

auto out_tensor = torch::zeros(args.at(&(n->output()[0])).unwrapToIntList().vec(), options);
return out_tensor;
}
});
}
} // namespace evaluators
} // namespace conversion
} // namespace core
} // namespace trtorch

0 comments on commit 670817c

Please sign in to comment.