diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py
index 2432cc84fc2c3..78253bc5bc6bc 100644
--- a/nnvm/python/nnvm/top/nn.py
+++ b/nnvm/python/nnvm/top/nn.py
@@ -73,6 +73,9 @@ def schedule_dense(_, outs, target):
 
 reg.register_pattern("dense", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+#matmul
+reg.register_pattern("matmul", OpPattern.OUT_ELEMWISE_FUSABLE)
+reg.register_schedule("matmul", _fschedule_injective)
 
 # conv2d
 @reg.register_compute("conv2d")
diff --git a/nnvm/src/top/tensor/matrix_op.cc b/nnvm/src/top/tensor/matrix_op.cc
index d28097b1090d5..c881e683a6c56 100644
--- a/nnvm/src/top/tensor/matrix_op.cc
+++ b/nnvm/src/top/tensor/matrix_op.cc
@@ -3,9 +3,11 @@
  * \file matrix_op.cc
  * \brief Matrix operators
  */
+#include <nnvm/compiler/op_attr_types.h>
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
+#include <topi/nn.h>
 #include <nnvm/top/tensor.h>
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
@@ -13,6 +15,8 @@
 namespace nnvm {
 namespace top {
 
+using namespace nnvm::compiler;
+
 DMLC_REGISTER_PARAMETER(MatMulParam);
 
 inline bool DotShape(const nnvm::NodeAttrs& attrs,
@@ -93,6 +97,15 @@ NNVM_REGISTER_OP(matmul)
 .set_attr<FInferShape>("FInferShape", DotShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
 .set_attr<FCorrectLayout>("FCorrectLayout", DotCorrectLayout)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const MatMulParam& param = nnvm::get<MatMulParam>(attrs.parsed);
+    return Array<Tensor>{
+      topi::matmul(inputs[0], inputs[1], param.transpose_a, param.transpose_b)
+    };
+  })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h
index 2459eb515707c..ee3101c4cc183 100644
--- a/topi/include/topi/nn.h
+++ b/topi/include/topi/nn.h
@@ -214,14 +214,14 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
  * \param name The name of the operation
  * \param tag The tag to mark the operation
  *
- * \return A Tensor whose op member is the matmult operation
+ * \return A Tensor whose op member is the matmul operation
  */
-inline tvm::Tensor matmult(const tvm::Tensor& A,
+inline tvm::Tensor matmul(const tvm::Tensor& A,
                           const tvm::Tensor& B,
                           bool trans_a = false,
                           bool trans_b = false,
                           std::string name = "tensor",
-                          std::string tag = kMatMult) {
+                          std::string tag = kMatMul) {
   tvm::Array<tvm::Expr> output_shape{A->shape[trans_a ? 1 : 0],
                                      B->shape[trans_b ? 0 : 1]};
   auto k = tvm::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
diff --git a/topi/include/topi/tags.h b/topi/include/topi/tags.h
index 8ba9955be0504..8c92644d96d3c 100644
--- a/topi/include/topi/tags.h
+++ b/topi/include/topi/tags.h
@@ -15,7 +15,7 @@ constexpr auto kInjective = "injective";
 constexpr auto kCommReduce = "comm_reduce";
 constexpr auto kCommReduceIdx = "comm_reduce_idx";
 constexpr auto kBroadcast = "broadcast";
-constexpr auto kMatMult = "matmult";
+constexpr auto kMatMul = "matmul";
 constexpr auto kConv2dNCHW = "conv2d_nchw";
 constexpr auto kConv2dHWCN = "conv2d_hwcn";
 constexpr auto kDepthwiseConv2dNCHW = "depthwise_conv2d_nchw";