Skip to content

Commit

Permalink
[NNVM][TOPI] Add FTVMCompute for matmul (apache#1239)
Browse files Browse the repository at this point in the history
  • Loading branch information
nhynes authored and sergei-mironov committed Aug 8, 2018
1 parent 6a274ef commit ae580ac
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 4 deletions.
3 changes: 3 additions & 0 deletions nnvm/python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def schedule_dense(_, outs, target):

reg.register_pattern("dense", OpPattern.OUT_ELEMWISE_FUSABLE)

#matmul
reg.register_pattern("matmul", OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_schedule("matmul", _fschedule_injective)

# conv2d
@reg.register_compute("conv2d")
Expand Down
13 changes: 13 additions & 0 deletions nnvm/src/top/tensor/matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@
* \file matrix_op.cc
* \brief Matrix operators
*/
#include <topi/nn.h>
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/top/tensor.h>
#include "../op_common.h"
#include "../elemwise_op_common.h"

namespace nnvm {
namespace top {

using namespace nnvm::compiler;

DMLC_REGISTER_PARAMETER(MatMulParam);

inline bool DotShape(const nnvm::NodeAttrs& attrs,
Expand Down Expand Up @@ -93,6 +97,15 @@ NNVM_REGISTER_OP(matmul)
.set_attr<FInferShape>("FInferShape", DotShape)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", DotCorrectLayout)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const MatMulParam& param = nnvm::get<MatMulParam>(attrs.parsed);
return Array<Tensor>{
topi::matmul(inputs[0], inputs[1], param.transpose_a, param.transpose_b)
};
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
Expand Down
6 changes: 3 additions & 3 deletions topi/include/topi/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,14 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the matmult operation
* \return A Tensor whose op member is the matmul operation
*/
inline tvm::Tensor matmult(const tvm::Tensor& A,
inline tvm::Tensor matmul(const tvm::Tensor& A,
const tvm::Tensor& B,
bool trans_a = false,
bool trans_b = false,
std::string name = "tensor",
std::string tag = kMatMult) {
std::string tag = kMatMul) {
tvm::Array<tvm::Expr> output_shape{A->shape[trans_a ? 1 : 0],
B->shape[trans_b ? 0 : 1]};
auto k = tvm::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
Expand Down
2 changes: 1 addition & 1 deletion topi/include/topi/tags.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ constexpr auto kInjective = "injective";
constexpr auto kCommReduce = "comm_reduce";
constexpr auto kCommReduceIdx = "comm_reduce_idx";
constexpr auto kBroadcast = "broadcast";
constexpr auto kMatMult = "matmult";
constexpr auto kMatMul = "matmul";
constexpr auto kConv2dNCHW = "conv2d_nchw";
constexpr auto kConv2dHWCN = "conv2d_hwcn";
constexpr auto kDepthwiseConv2dNCHW = "depthwise_conv2d_nchw";
Expand Down

0 comments on commit ae580ac

Please sign in to comment.