Skip to content

Commit

Permalink
[Relay] enable blocking format in x86 conv2d and fold scale axis (apa…
Browse files Browse the repository at this point in the history
  • Loading branch information
Menooker authored and Trevor Morris committed Jun 18, 2020
1 parent 2770f22 commit d563bdb
Show file tree
Hide file tree
Showing 5 changed files with 507 additions and 247 deletions.
10 changes: 10 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import logging

import re
import topi
from tvm.te import SpecializedCondition
from .generic import *
from .. import op as _op

logger = logging.getLogger('strategy')

_NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$")

@schedule_injective.register("cpu")
def schedule_injective_cpu(attrs, outs, target):
"""schedule injective ops for x86"""
Expand Down Expand Up @@ -96,6 +100,9 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
wrap_compute_conv2d(topi.x86.conv2d_nchw),
wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
name="conv2d_nchw.x86")
elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc
assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio
return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
logger.warning("For x86 target, NCHW layout is recommended for conv2d.")
Expand Down Expand Up @@ -128,6 +135,9 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.generic")
elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc
assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio
return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.")
Expand Down
2 changes: 2 additions & 0 deletions src/relay/op/tensor/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
namespace tvm {
namespace relay {

extern Expr MakeReshape(Expr data, Array<Integer> newshape);

template <typename AttrType>
bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
Expand Down
151 changes: 118 additions & 33 deletions src/relay/transforms/fold_scale_axis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/relay/transform.h>
#include <tvm/tir/data_layout.h>

#include "../op/tensor/transform.h"
#include "pass_util.h"
#include "pattern_util.h"

Expand All @@ -39,6 +40,7 @@ namespace relay {
*
* Use namespace to reduce potential naming conflict.
*/

namespace fold_scale_axis {

using runtime::TypedPackedFunc;
Expand Down Expand Up @@ -305,6 +307,41 @@ class ForwardPrep : private ExprVisitor {
}
};

static bool IsIntInArray(const Array<Integer>& axis, int v) {
for (size_t i = 0; i < axis.size(); i++) {
if (axis[i] == v) return true;
}
return false;
}

static Expr ReshapeToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
const Array<Integer>& axis) {
Array<Integer> arr;
for (size_t i = 0; i < shape.size(); i++) {
if (IsIntInArray(axis, i)) {
auto node = shape[i].as<IntImmNode>();
if (!node) {
// if the shape is not a constant, use normal transform
return Expr();
}
arr.push_back(node->value);
} else {
arr.push_back(1);
}
}
return MakeReshape(scale, std::move(arr));
}

// if only one axis, use expand dim. Else, use reshape
static Expr ReshapeOrExpandToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
const Array<Integer>& axis) {
if (axis.size() > 1) {
return ReshapeToMatchAxis(scale, shape, axis);
} else {
return ExpandBiasToMatchAxis(scale, shape.size(), axis);
}
}

//----------------------------------------------
// Per operator defs for FScaleAxisForward
//----------------------------------------------
Expand Down Expand Up @@ -365,15 +402,21 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
if (slhs != nullptr) {
CHECK(srhs == nullptr);
CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes));
Expr scale = ExpandBiasToMatchAxis(slhs->scale, tlhs->shape.size(), slhs->axes);
Expr scale = ReshapeOrExpandToMatchAxis(slhs->scale, tlhs->shape, slhs->axes);
if (!scale.defined()) {
return Expr();
}
Expr rhs = Divide(new_args[1], scale);
rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args);
rnode->scale = slhs->scale;
rnode->axes = slhs->axes;
} else {
CHECK(srhs != nullptr);
CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
Expr scale = ExpandBiasToMatchAxis(srhs->scale, trhs->shape.size(), srhs->axes);
Expr scale = ReshapeOrExpandToMatchAxis(srhs->scale, trhs->shape, srhs->axes);
if (!scale.defined()) {
return Expr();
}
Expr lhs = Divide(new_args[0], scale);
rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args);
rnode->scale = srhs->scale;
Expand Down Expand Up @@ -445,7 +488,6 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {

CHECK_GE(c_big_axis, 0);
Message none = NullValue<Message>();
AxesSet data_axes = NullValue<AxesSet>();
// For now, we only support simple pattern (no folded weight/data)
// More general layout can be supported under the current framework.
// By using a unified layout transformation.
Expand All @@ -454,12 +496,17 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
data_axes = {c_big_axis};
}
if (data_axes.defined()) {
return {Message(data_axes, false), none};
if (param->groups == 1 || is_depthwise_conv2d) {
auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
(ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout
Array<Integer> arr{c_big_axis};
if (c_small_axis >= 0) {
arr.push_back(c_small_axis);
}
return {Message(arr, false), none};
}
}
return {none, none};
}
Expand All @@ -478,12 +525,14 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
CHECK(sdata->axes.size() == 1 && c_big_axis == sdata->axes[0]->value);
int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));

bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
CHECK(is_simple || is_blocking);

// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
Expand All @@ -493,11 +542,26 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,

// match the ic_axis
if (is_depthwise_conv2d) {
Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_oc_axis});
weight = Multiply(weight, scale);
if (is_simple) {
Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ko_axis});
weight = Multiply(weight, scale);
} else {
weight = Multiply(weight,
ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape,
{big_ko_axis, small_ko_axis}));
if (!weight.defined()) return Expr();
}

} else {
Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ic_axis});
weight = Multiply(weight, scale);
if (is_simple) {
Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ki_axis});
weight = Multiply(weight, scale);
} else {
weight = Multiply(weight,
ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape,
{big_ki_axis, small_ki_axis}));
if (!weight.defined()) return Expr();
}
}
// return transformed conv2d
return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
Expand Down Expand Up @@ -752,14 +816,20 @@ Expr AddSubBackwardTransform(const Call& call, const Message& message, const Exp
CHECK(equal(message->axes, lhs_message->axes));
Expr lhs = transformer->Transform(call->args[0], message, scale);
Expr rhs = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
Expr rhs_scale = ExpandBiasToMatchAxis(scale, tlhs->shape.size(), message->axes);
Expr rhs_scale = ReshapeOrExpandToMatchAxis(scale, tlhs->shape, message->axes);
if (!rhs_scale.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
rhs = Multiply(rhs, rhs_scale);
return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else if (rhs_message.defined()) {
CHECK(equal(message->axes, rhs_message->axes));
Expr lhs = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
Expr rhs = transformer->Transform(call->args[1], message, scale);
Expr lhs_scale = ExpandBiasToMatchAxis(scale, trhs->shape.size(), message->axes);
Expr lhs_scale = ReshapeOrExpandToMatchAxis(scale, trhs->shape, message->axes);
if (!lhs_scale.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
lhs = Multiply(lhs, lhs_scale);
return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else {
Expand Down Expand Up @@ -829,13 +899,19 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 &&
kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
return Message({c_big_axis}, false);
} else {
return NullValue<Message>();
if (param->groups == 1 || is_depthwise_conv2d) {
auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
(ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout
Array<Integer> arr{c_big_axis};
if (c_small_axis >= 0) {
arr.push_back(c_small_axis);
}
return Message(arr, false);
}
}
return NullValue<Message>();
}

// Conv2D consumes the scale axis during transformation.
Expand All @@ -852,19 +928,28 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1);
CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
CHECK(message->axes.size() == 1 && c_big_axis == message->axes[0]->value);

int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
CHECK(param->groups == 1 || is_depthwise_conv2d);
bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
CHECK(is_simple || is_blocking);

Expr data = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
Expr weight = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
// scale on input for deptwise.
Expr wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_oc_axis});
Expr wscale;
if (is_simple) {
wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_ko_axis});
} else {
wscale = ReshapeToMatchAxis(scale, weight->type_as<TensorTypeNode>()->shape,
{big_ko_axis, small_ko_axis});
if (!wscale.defined()) return transformer->NormalCallTransform(call.operator->());
}
weight = Multiply(weight, wscale);
return Call(call->op, {data, weight}, call->attrs, call->type_args);
}
Expand Down
Loading

0 comments on commit d563bdb

Please sign in to comment.