From e35b7fc4bcdcfe008c5dfea60c2297b93dbff99e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Thu, 27 Aug 2020 11:32:40 -0700 Subject: [PATCH 1/3] [Relay][Training] Make AutoDiff thread through global function. (#6336) * save * lint * lint * fix warning * fix test * save --- src/printer/doc.cc | 2 +- src/relay/transforms/gradient.cc | 106 ++++++++++++++++++----- tests/python/relay/test_pass_gradient.py | 41 ++++++++- 3 files changed, 124 insertions(+), 25 deletions(-) diff --git a/src/printer/doc.cc b/src/printer/doc.cc index d487e3e7aa3e..ab1eddbe7d1e 100644 --- a/src/printer/doc.cc +++ b/src/printer/doc.cc @@ -129,7 +129,7 @@ Doc Doc::Indent(int indent, Doc doc) { } Doc Doc::StrLiteral(const std::string& value, std::string quote) { - // TODO(M.K.): add escape. + // TODO(@M.K.): add escape. Doc doc; return doc << quote << value << quote; } diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc index 7894c34de55d..9c472542cc91 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/gradient.cc @@ -72,7 +72,7 @@ Type WithGradientType(const Type&); Expr FirstOrderGradient(const Expr& e, const Optional& mod); Type WithGradientType(const Type& t) { - // TODO(M.K.): stricter checking + // TODO(@M.K.): stricter checking auto ty = t.as(); CHECK(ty) << "input should be a function"; return FuncType(ty->arg_types, TupleType({ty->ret_type, TupleType(ty->arg_types)}), {}, {}); @@ -85,7 +85,7 @@ Expr DeGlobal(const Optional& mod, const Expr& e) { if (mod.defined() && x) { BaseFunc base_func = mod.value()->Lookup(GetRef(x)); if (auto* n = base_func.as()) { - return n->body; + return GetRef(n); } else { return e; } @@ -338,11 +338,22 @@ Expr FirstOrderGradient(const Expr& re, const Optional& mod) { TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient").set_body_typed(FirstOrderGradient); +Type bpt = RelayRefType(FuncType({}, TupleType(Array()), {}, {})); + struct ReverseADType : TypeMutator { Type VisitType_(const TensorTypeNode* ttn) final { Type t = GetRef(ttn); return TupleType({t, RelayRefType(t)}); } + + Type VisitType_(const FuncTypeNode* ftn) final { + std::vector arg_types; + for (const auto& t : ftn->arg_types) { + arg_types.push_back(VisitType(t)); + } + arg_types.push_back(bpt); + return FuncType(arg_types, ftn->ret_type, ftn->type_params, ftn->type_constraints); + } }; Type ReverseType(const Type& t) { return ReverseADType()(t); } @@ -438,12 +449,18 @@ Expr BPEmpty() { struct ReverseAD : ExprMutator { using ADVarMap = std::unordered_map; - + using ADGlobalVarMap = std::unordered_map; + Optional mod; + // TODO(@M.K.) refactor AD to always use mod. Var bp; std::shared_ptr ad_vars; + std::shared_ptr ad_gvars; const OpAttrMap rev_map = Op::GetAttrMap("FPrimalGradient"); - explicit ReverseAD(const Var& bp, std::shared_ptr ad_vars) : bp(bp), ad_vars(ad_vars) {} + explicit ReverseAD(const Optional& mod, const Var& bp, + const std::shared_ptr& ad_vars, + const std::shared_ptr& ad_gvars) + : mod(mod), bp(bp), ad_vars(ad_vars), ad_gvars(ad_gvars) {} Expr VisitExpr_(const OpNode* op) final { LOG(FATAL) << "op should only be inside call"; @@ -481,9 +498,8 @@ struct ReverseAD : ExprMutator { Expr nbp = Function({}, LetList::With([&](LetList* ll) { // we need a new ReverseAD visitor to avoid clobbering the bp local var auto dup_bp = ll->Push(BPEmpty()); - ReverseAD dup_diff(dup_bp, ad_vars); - auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x))); - + auto dup_ad = + ll->Push(ReverseAD(mod, dup_bp, ad_vars, ad_gvars)(DeDup(x))); TransferGrads(call->checked_type(), ret, dup_ad, ll); ll->Push(Call(RefRead(dup_bp), {})); return Call(bpv, {}); @@ -518,22 +534,29 @@ struct ReverseAD : ExprMutator { orig_var->checked_type_ = call->checked_type(); auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll)); auto bpv = ll->Push(RefRead(bp)); - Expr nbp = Function({}, LetList::With([&](LetList* ll) { - tvm::Array rev = - rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll)); - CHECK(args.size() == rev.size()); - for (size_t i = 0; i < args.size(); ++i) { - UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll); - } - return Call(bpv, {}); - }), - TupleType::Empty(), {}); + Expr nbp_body = LetList::With([&](LetList* ll) { + tvm::Array rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll)); + CHECK(args.size() == rev.size()); + for (size_t i = 0; i < args.size(); ++i) { + UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll); + } + return Call(bpv, {}); + }); + Expr nbp = Function({}, nbp_body, TupleType::Empty(), {}); ll->Push(RefWrite(bp, transform::ToANormalForm(nbp))); // TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that. return ret; }); + } else if (call->op.as()) { + return ExprMutator::VisitExpr_(call); + } else { + std::vector args; + for (const auto& arg : call->args) { + args.push_back(VisitExpr(arg)); + } + args.push_back(bp); + return Call(VisitExpr(call->op), args); } - return ExprMutator::VisitExpr_(call); } Expr VisitExpr_(const ConstantNode* op) final { @@ -559,6 +582,39 @@ struct ReverseAD : ExprMutator { return ad_vars->at(var_ref); } + Expr VisitExpr_(const GlobalVarNode* op) final { + // todo: concatenating string to add attribute seems like a brittle hack. + // maybe get module indexed by a rose tree of string? + CHECK(mod.defined()); + auto orig_gv = GetRef(op); + if (ad_gvars->count(orig_gv) == 0) { + GlobalVar gv(op->name_hint + "_grad"); + (*ad_gvars)[orig_gv] = gv; + Function orig_f = Downcast(DeDup(mod.value()->Lookup(orig_gv))); + std::vector params; + for (const auto& p : orig_f->params) { + params.push_back(Downcast(VisitExpr(p))); + } + params.push_back(bp); + Expr body = VisitExpr(orig_f->body); + Function f(params, body, VisitType(orig_f->ret_type), orig_f->type_params, orig_f->attrs); + std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl; + mod.value()->Add(gv, f); + } + return ad_gvars->at(orig_gv); + } + + Expr VisitExpr_(const FunctionNode* op) final { + std::vector params; + for (const auto& var : op->params) { + params.push_back(Downcast(VisitExpr(var))); + } + auto new_bp = Var("bp", bpt); + params.push_back(new_bp); + return Function(params, ReverseAD(mod, new_bp, ad_vars, ad_gvars)(op->body), + VisitType(op->ret_type), op->type_params, op->attrs); + } + Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; } }; @@ -604,12 +660,16 @@ Expr Gradient(const Expr& re, const Optional& mod) { } CHECK(!MissingGrad(e)) << "input has operators with missing gradients"; Expr body = LetList::With([&](LetList* ll) { - Var bp = ll->Push(BPEmpty()); - Expr rev = ReverseAD(bp, std::make_shared())(e); - std::vector args; + Var bp = ll->Push(BPEmpty(), bpt); + Expr rev = ReverseAD(mod, bp, std::make_shared(), + std::make_shared())(e); + std::vector normal_args, args; for (const auto& p : f->params) { - args.push_back(ll->Push(Pair(p, RefCreate(ZerosLike(p))))); + auto x = ll->Push(Pair(p, RefCreate(ZerosLike(p)))); + normal_args.push_back(x); + args.push_back(x); } + args.push_back(bp); auto c = ll->Push(Call(rev, args)); std::function init_grad; init_grad = [&](const Expr& e, const Type& t) { @@ -626,7 +686,7 @@ Expr Gradient(const Expr& re, const Optional& mod) { init_grad(c, f->body->checked_type()); ll->Push(Call(RefRead(bp), {})); std::vector ret; - for (const auto& a : args) { + for (const auto& a : normal_args) { ret.push_back(RefRead(GetField(a, 1))); } std::function get_final_result; diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index 296d3e5e9354..b239ef4fc4a6 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -21,6 +21,7 @@ import tvm from tvm import te from tvm import relay +from tvm.relay import GlobalVar from tvm.relay.analysis import free_vars, free_type_vars from tvm.relay import create_executor, transform from tvm.relay.transform import gradient @@ -29,7 +30,7 @@ import tvm.relay.op as op -def test_id(): +def test_fo_id(): shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -44,6 +45,21 @@ def test_id(): tvm.testing.assert_allclose(forward.asnumpy(), x.asnumpy()) tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy())) +def test_id(): + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + x = relay.var("x", t) + func = relay.Function([x], x) + func = run_infer_type(func) + back_func = run_infer_type(gradient(func)) + assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) + ex = create_executor() + x = rand(dtype, *shape) + forward, (grad,) = ex.evaluate(back_func)(x) + tvm.testing.assert_allclose(forward.asnumpy(), x.asnumpy()) + tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy())) + def test_relu(): shape = (10, 10) @@ -341,5 +357,28 @@ def test_no_duplication(): counts = count_ops(gr) assert counts['nn.dense'] == 3, "We expect 3 dense (1 forward, two backward)" + +def test_global_function(): + m = tvm.IRModule() + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + x = relay.Var('x', t) + d = GlobalVar('double') + m[d] = relay.Function([x], x + x) + y = relay.Var('y', t) + q = GlobalVar('q') + m[q] = relay.Function([y], d(d(y))) + g = GlobalVar('grad') + m[g] = tvm.relay.transform.gradient(q, m) + back_func = m[g] + assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) + ex = create_executor(mod=m) + x = rand(dtype, *shape) + forward, (grad,) = ex.evaluate(back_func)(x) + tvm.testing.assert_allclose(forward.asnumpy(), 4 * x.asnumpy()) + tvm.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy())) + + if __name__ == "__main__": pytest.main([__file__]) From 30cd2302e4078b3a8787e30d70fd79e5b729ec82 Mon Sep 17 00:00:00 2001 From: mbaret <55580676+mbaret@users.noreply.github.com> Date: Thu, 27 Aug 2020 22:03:45 +0100 Subject: [PATCH 2/3] [BYOC][ETHOSN] Add support for quantized convolution (#6335) * [BYOC][ETHOSN] Add support for quantized convolution This PR adds support for quantized convolution. This includes mapping it via a composite function and all the necessary methods to convert from Relay to the APIs in Support Library. Co-authored-by: Leo Blonk Co-authored-by: Tristan O'Connor * Fix padding change Change-Id: I0794b0ac6190478e2d1b858ad0dd90f37fc0207b * Add docs to Tvm2Npu methods Change-Id: Iab865619b449a3d0dd6bb0dbdcb198acd529fc4e * Remove generate tests Change-Id: I51f90499f7ce82a1ce49f0731d3d50627e1d0225 Co-authored-by: Leo Blonk Co-authored-by: Tristan O'Connor --- python/tvm/relay/op/contrib/ethosn.py | 26 +++ src/relay/backend/contrib/ethosn/codegen.cc | 43 +++- .../backend/contrib/ethosn/codegen_ethosn.h | 1 + .../backend/contrib/ethosn/ethosn_api.cc | 190 ++++++++++++++++ src/relay/backend/contrib/ethosn/ethosn_api.h | 22 ++ .../contrib/test_ethosn/infrastructure.py | 2 + .../python/contrib/test_ethosn/test_conv2d.py | 204 ++++++++++++++++++ 7 files changed, 486 insertions(+), 2 deletions(-) create mode 100644 tests/python/contrib/test_ethosn/test_conv2d.py diff --git a/python/tvm/relay/op/contrib/ethosn.py b/python/tvm/relay/op/contrib/ethosn.py index de70297a7889..a93b0e5fc58c 100644 --- a/python/tvm/relay/op/contrib/ethosn.py +++ b/python/tvm/relay/op/contrib/ethosn.py @@ -18,7 +18,9 @@ """Arm(R) Ethos(TM) -N NPU supported operators.""" from enum import Enum import tvm.ir +from ...dataflow_pattern import wildcard, is_op, is_constant from ... import qnn as _qnn +from .register import register_pattern_table from . import _ethosn as support @@ -40,6 +42,30 @@ def ethosn_available(): return Available.SW_AND_HW if hw else Available.SW_ONLY +@register_pattern_table("ethos-n") +def pattern_table(): + """Get the Ethos-N compiler pattern table.""" + def qnn_conv_pattern(): + pattern = is_op('nn.pad')(wildcard()) | wildcard() + pattern = is_op('qnn.conv2d')( + pattern, is_constant(), is_constant(), is_constant(), is_constant(), is_constant()) + pattern = is_op('nn.bias_add')(pattern, is_constant()) + pattern = is_op('qnn.requantize')( + pattern, is_constant(), is_constant(), is_constant(), is_constant()) + return pattern + + def check_conv2d(extract): + """Check if a conv2d is supported by Ethos-N.""" + if not ethosn_available(): + return False + + return support.conv2d(extract) + + return [ + ("ethos-n.qnn_conv2d", qnn_conv_pattern(), check_conv2d), + ] + + @tvm.ir.register_op_attr("qnn.concatenate", "target.ethos-n") def qnn_concatenate(attrs, args): """Check if a concatenate is supported by Ethos-N.""" diff --git a/src/relay/backend/contrib/ethosn/codegen.cc b/src/relay/backend/contrib/ethosn/codegen.cc index f66eb94cb20f..58cd5bf1dd44 100644 --- a/src/relay/backend/contrib/ethosn/codegen.cc +++ b/src/relay/backend/contrib/ethosn/codegen.cc @@ -50,6 +50,16 @@ bool IsEthosnOp(const Call& call, const std::string& op_name) { } } +bool IsEthosnFunc(const Call& call, const std::string& op_name) { + if (call->op->IsInstance()) { + Function func = Downcast(call->op); + CHECK(func.defined()); + auto name_node = func->GetAttr(attr::kComposite); + return name_node.value() == op_name; + } + return false; +} + std::map> InferTensorsVisitor::Infer(const Expr& expr) { tensor_table_.clear(); CHECK(expr->checked_type().defined()); @@ -69,7 +79,11 @@ void InferTensorsVisitor::InferCall(const CallNode* cn) { EthosnError err; Call call = GetRef(cn); // Determine call -> NPU mapping - if (IsEthosnOp(call, "qnn.concatenate")) { + if (IsEthosnFunc(call, "ethos-n.qnn_conv2d")) { + ConvolutionParams params; + err += EthosnAPI::QnnConv2d(cn->op.as()->body, ¶ms); + tensor_table_[cn->args[0]] = {params.activation_info}; + } else if (IsEthosnOp(call, "qnn.concatenate")) { ConcatenateParams params; err = EthosnAPI::Concatenate(call, ¶ms); tensor_table_[cn->args[0]] = params.input_infos; @@ -181,7 +195,10 @@ sl::TensorsAndId ConstructNetworkVisitor::HandleCall(const CallNode* cn) { sl::TensorAndId tensor; sl::TensorsAndId tensors; // Determine call -> NPU mapping - if (IsEthosnOp(call, "qnn.concatenate")) { + if (IsEthosnFunc(call, "ethos-n.qnn_conv2d")) { + if ((err = MakeConvolutionLayer(call, &tensor))) ReportFatalError(call, err); + return MakeOps(tensor); + } else if (IsEthosnOp(call, "qnn.concatenate")) { if ((err = MakeConcatenateLayer(call, &tensor))) ReportFatalError(call, err); return MakeOps(tensor); } else if (IsEthosnOp(call, "split")) { @@ -227,6 +244,28 @@ void ConstructNetworkVisitor::VisitLeaf(const Expr& expr) { if (!expr->IsInstance()) MixedModeVisitor::VisitLeaf(expr); } +EthosnError ConstructNetworkVisitor::MakeConvolutionLayer(const Call& call, + sl::TensorAndId* out) { + ConvolutionParams params; + if (auto err = EthosnAPI::QnnConv2d(call->op.as()->body, ¶ms)) { + return err; + } + + auto activation = operand_table_[call->args[0]][0]; + auto weights = AddConstant(network_, params.weights_info, params.raw_weights).tensor; + auto bias = AddConstant(network_, params.bias_info, params.raw_bias).tensor; + try { + if (params.is_depthwise) { + *out = AddDepthwiseConvolution(network_, *activation, *bias, *weights, params.conv_info); + } else { + *out = AddConvolution(network_, *activation, *bias, *weights, params.conv_info); + } + } catch (const sl::NotSupportedException& e) { + return EthosnError(e.what()); + } + return EthosnError(); +} + EthosnError ConstructNetworkVisitor::MakeConcatenateLayer(const Call& call, sl::TensorAndId* out) { ConcatenateParams params; diff --git a/src/relay/backend/contrib/ethosn/codegen_ethosn.h b/src/relay/backend/contrib/ethosn/codegen_ethosn.h index 714a22d22027..7d1fe9cd5ea9 100644 --- a/src/relay/backend/contrib/ethosn/codegen_ethosn.h +++ b/src/relay/backend/contrib/ethosn/codegen_ethosn.h @@ -197,6 +197,7 @@ class ConstructNetworkVisitor : public MixedModeVisitor, private ErrorReportingP void VisitLeaf(const Expr& expr) final; // Make a support library operand from a Call + EthosnError MakeConvolutionLayer(const Call& call, sl::TensorAndId* out); EthosnError MakeConcatenateLayer(const Call& call, sl::TensorAndId* out); EthosnError MakeSplitLayer(const Call& call, sl::TensorsAndId* outs); diff --git a/src/relay/backend/contrib/ethosn/ethosn_api.cc b/src/relay/backend/contrib/ethosn/ethosn_api.cc index d92e35afeea0..b7cac6504ac6 100644 --- a/src/relay/backend/contrib/ethosn/ethosn_api.cc +++ b/src/relay/backend/contrib/ethosn/ethosn_api.cc @@ -40,6 +40,105 @@ namespace relay { namespace contrib { namespace ethosn { +EthosnError EthosnAPI::QnnConv2d(const Expr& expr, ConvolutionParams* params) { + Call requantize = Downcast(expr); + Call bias_add = Downcast(requantize->args[0]); + Call conv = Downcast(bias_add->args[0]); + Call pad; + if (conv->args[0]->IsInstance() && + Downcast(conv->args[0])->op == Op::Get("nn.pad")) + pad = Downcast(conv->args[0]); + const auto& conv_attr = conv->attrs.as(); + + // Extract the quantization params from the arguments + int input_zero_point; + int kernel_zero_point; + int output_zero_point; + float input_scale; + float kernel_scale; + float output_scale; + EthosnError err = AsConstant(conv->args[2], &input_zero_point); + err += AsConstant(conv->args[3], &kernel_zero_point); + err += AsConstant(requantize->args[4], &output_zero_point); + err += AsConstant(conv->args[4], &input_scale); + err += AsConstant(conv->args[5], &kernel_scale); + err += AsConstant(requantize->args[3], &output_scale); + + // Convert quantization params + sl::QuantizationInfo data_q_info; + sl::QuantizationInfo weights_q_info; + sl::QuantizationInfo bias_q_info; + sl::QuantizationInfo output_q_info; + err += Tvm2Npu(input_zero_point, input_scale, &data_q_info); + err += Tvm2Npu(kernel_zero_point, kernel_scale, &weights_q_info); + err += Tvm2Npu(0, data_q_info.m_Scale * weights_q_info.m_Scale, &bias_q_info); + err += Tvm2Npu(output_zero_point, output_scale, &output_q_info); + + // Convert convolution attributes + sl::Padding padding; + if (pad.defined()) { + Tvm2Npu(conv_attr->padding, &padding); + // Don't support both standalone operator padding and attribute defined padding + if (padding != sl::Padding({0, 0, 0, 0})) { + err += EthosnError( + ErrStrm() << "both op and attr padding exist, must be either op/attr only or no padding"); + } + err += Tvm2Npu(pad->attrs.as()->pad_width, &padding); + } else { + err += Tvm2Npu(conv_attr->padding, &padding); + } + sl::Stride stride; + err += Tvm2Npu(conv_attr->strides, &stride); + // Dilation is not supported + std::array dilation = {1, 1, 1, 1}; + AsArray(conv_attr->dilation, &dilation); + if (conv_attr->dilation.size() != 2 || dilation[0] != 1 || dilation[1] != 1) { + err += + EthosnError(ErrStrm() << "dilation=" << conv_attr->dilation << ", dilation must = [1, 1]"); + } + // Create convolution info + params->conv_info = sl::ConvolutionInfo(padding, stride, output_q_info); + + // Create data info + const TensorTypeNode* data_dtype; + if (pad.defined()) { + data_dtype = pad->args[0]->checked_type().as(); + } else { + data_dtype = conv->args[0]->checked_type().as(); + } + sl::TensorShape activation_tensor_shape; + sl::DataType activation_data_type; + err += Tvm2Npu(data_dtype->shape, &activation_tensor_shape); + err += Tvm2Npu(data_dtype->dtype, &activation_data_type); + params->activation_info = sl::TensorInfo(activation_tensor_shape, activation_data_type, + sl::DataFormat::NHWC, data_q_info); + + // Create weights info + params->is_depthwise = conv_attr->channels.defined() && + tvm::tir::ExprDeepEqual()(conv_attr->channels, conv_attr->groups) && + conv_attr->groups != 1; + + const auto* weights_dtype = conv->args[1]->checked_type().as(); + sl::TensorShape weights_tensor_shape; + sl::DataType weights_data_type; + sl::DataFormat weights_data_format; + // Ignore the error here because weights don't have a batch axis + Tvm2Npu(weights_dtype->shape, &weights_tensor_shape); + err += Tvm2Npu(weights_dtype->dtype, &weights_data_type); + err += Tvm2Npu(params->is_depthwise ? "HWIM" : "HWIO", &weights_data_format); + params->weights_info = + sl::TensorInfo(weights_tensor_shape, weights_data_type, weights_data_format, weights_q_info); + params->raw_weights = conv->args[1].as()->data->data; + + // Create bias info + params->bias_info = sl::TensorInfo( + {1, 1, 1, params->is_depthwise ? weights_tensor_shape[2] : weights_tensor_shape[3]}, + sl::DataType::INT32_QUANTIZED, sl::DataFormat::NHWC, bias_q_info); + params->raw_bias = bias_add->args[1].as()->data->data; + + return err; +} + EthosnError EthosnAPI::Concatenate(const Expr& expr, ConcatenateParams* params) { Call call = Downcast(expr); const auto& attrs = call->attrs.as(); @@ -107,6 +206,60 @@ EthosnError EthosnAPI::Split(const Expr& expr, SplitParams* params) { return err; } +EthosnError EthosnAPI::Tvm2Npu(const Array& padding, sl::Padding* npu_padding) { + std::array dim; + if (EthosnError err = AsArray(padding, &dim)) { + return err; + } + switch (padding.size()) { + case 1: + *npu_padding = sl::Padding(dim[0], dim[0], dim[0], dim[0]); + break; + case 2: + // Height, width -> top, bottom, left, right + *npu_padding = sl::Padding(dim[0], dim[0], dim[1], dim[1]); + break; + case 4: + // Top, left, bottom, right -> top, bottom, left, right + *npu_padding = sl::Padding(dim[0], dim[2], dim[1], dim[3]); + break; + default: + return EthosnError(ErrStrm() << "padding tuple size=" << padding.size() + << ", padding tuple size must be {1, 2, 4}"); + } + return EthosnError(); +} + +EthosnError EthosnAPI::Tvm2Npu(const Array& strides, sl::Stride* npu_stride) { + if (strides.size() != 2) { + return EthosnError(ErrStrm() << "stride size=" << strides.size() << ", stride size must = 2"); + } + std::array dim; + if (EthosnError err = AsArray(strides, &dim)) { + return err; + } + *npu_stride = sl::Stride(dim[1], dim[0]); + return EthosnError(); +} + +EthosnError EthosnAPI::Tvm2Npu(const std::string& dformat, sl::DataFormat* data_format) { + if (dformat == "NCHW") { + *data_format = sl::DataFormat::NCHW; + return EthosnError(); + } else if (dformat == "NHWC") { + *data_format = sl::DataFormat::NHWC; + return EthosnError(); + } else if (dformat == "HWIO") { + *data_format = sl::DataFormat::HWIO; + return EthosnError(); + } else if (dformat == "HWIM") { + *data_format = sl::DataFormat::HWIM; + return EthosnError(); + } + return EthosnError(ErrStrm() << "format=" << dformat + << ", format must be {NCHW, NHWC, HWIO, HWIM}"); +} + EthosnError EthosnAPI::Tvm2Npu(const Array& shape, sl::TensorShape* npu_shape) { EthosnError err = AsArray(shape, npu_shape); if (npu_shape->front() != 1) { @@ -128,6 +281,29 @@ EthosnError EthosnAPI::Tvm2Npu(const tvm::DataType& dtype, sl::DataType* data_ty return EthosnError(ErrStrm() << "dtype=\'" << dtype << "\', dtype must be either uint8 or int32"); } +EthosnError EthosnAPI::Tvm2Npu(int32_t zero_point, float scale, sl::QuantizationInfo* npu_qinfo) { + *npu_qinfo = sl::QuantizationInfo(zero_point, scale); + return EthosnError(); +} + +EthosnError EthosnAPI::Tvm2Npu(const Array>& padding, sl::Padding* npu_padding) { + if (padding.size() != 4) { + return EthosnError(ErrStrm() << "padding tuple size=" << padding.size() + << ", padding tuple size must = 4"); + } + Array reduced_padding; + reduced_padding.push_back(padding[1][0]); + reduced_padding.push_back(padding[1][1]); + reduced_padding.push_back(padding[2][0]); + reduced_padding.push_back(padding[2][1]); + std::array dim; + if (EthosnError err = AsArray(reduced_padding, &dim)) { + return err; + } + *npu_padding = sl::Padding(dim[0], dim[1], dim[2], dim[3]); + return EthosnError(); +} + // Convert an array of IntImmNodes into ValueT // IndexT type of Array indexing variable // ValueT type of resulting value @@ -158,6 +334,20 @@ EthosnError EthosnAPI::AsConstant(const Expr& expr, T* out) { return EthosnError(); } +TVM_REGISTER_GLOBAL("relay.ethos-n.support.conv2d") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + Call call = args[0]; + ConvolutionParams params; + auto err = EthosnAPI::QnnConv2d(call, ¶ms); + if (params.is_depthwise) { + *rv = !err && sl::IsDepthwiseConvolutionSupported(params.bias_info, params.weights_info, + params.conv_info, params.activation_info); + } else { + *rv = !err && sl::IsConvolutionSupported(params.bias_info, params.weights_info, + params.conv_info, params.activation_info); + } + }); + TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; diff --git a/src/relay/backend/contrib/ethosn/ethosn_api.h b/src/relay/backend/contrib/ethosn/ethosn_api.h index 34af7ce0b1d8..20fe8bec03c6 100644 --- a/src/relay/backend/contrib/ethosn/ethosn_api.h +++ b/src/relay/backend/contrib/ethosn/ethosn_api.h @@ -44,6 +44,16 @@ namespace ethosn { namespace sl = ::ethosn::support_library; +struct ConvolutionParams { + sl::ConvolutionInfo conv_info; + sl::TensorInfo activation_info; + sl::TensorInfo weights_info; + sl::TensorInfo bias_info; + void* raw_weights = nullptr; + void* raw_bias = nullptr; + bool is_depthwise = false; +}; + struct ConcatenateParams { sl::QuantizationInfo qInfo; sl::ConcatenationInfo concat_info = sl::ConcatenationInfo(1, qInfo); @@ -115,6 +125,8 @@ class EthosnError { */ class EthosnAPI { public: + /*! \brief Extract the Support Library convolution params from an ethos-n.qnn_conv2d func */ + static EthosnError QnnConv2d(const Expr& expr, ConvolutionParams* params); /*! \brief Extract the Support Library concatenate params from a Relay qnn.concatenate call */ static EthosnError Concatenate(const Expr& expr, ConcatenateParams* params); /*! \brief Extract the Support Library split params from a Relay split call */ @@ -125,6 +137,16 @@ class EthosnAPI { static EthosnError Tvm2Npu(const Array& shape, sl::TensorShape* npu_shape); /*! \brief Convert a TVM data type to a SL data type */ static EthosnError Tvm2Npu(const tvm::DataType& dtype, sl::DataType* data_type); + /*! \brief Convert TVM 1D padding to SL padding */ + static EthosnError Tvm2Npu(const Array& padding, sl::Padding* npu_padding); + /*! \brief Convert TVM 1D striding to SL striding */ + static EthosnError Tvm2Npu(const Array& strides, sl::Stride* npu_stride); + /*! \brief Convert TVM data format to SL data format */ + static EthosnError Tvm2Npu(const std::string& dformat, sl::DataFormat* data_format); + /*! \brief Convert TVM quantization info to SL quantization info */ + static EthosnError Tvm2Npu(int32_t zero_point, float scale, sl::QuantizationInfo* npu_qinfo); + /*! \brief Convert TVM 2D padding to SL padding */ + static EthosnError Tvm2Npu(const Array>& padding, sl::Padding* npu_padding); // Convert an array of IntImmNodes into ValueT // IndexT type of Array indexing variable diff --git a/tests/python/contrib/test_ethosn/infrastructure.py b/tests/python/contrib/test_ethosn/infrastructure.py index c6278334cfec..b43d273553e5 100644 --- a/tests/python/contrib/test_ethosn/infrastructure.py +++ b/tests/python/contrib/test_ethosn/infrastructure.py @@ -94,6 +94,8 @@ def build(mod, params, npu=True, expected_host_ops=0, npu_partitions=1): f = relay.build_module.bind_params_by_name(mod["main"], params) mod = tvm.IRModule() mod["main"] = f + pattern = get_pattern_table("ethos-n") + mod = relay.transform.MergeComposite(pattern)(mod) mod = relay.transform.AnnotateTarget("ethos-n")(mod) mod = relay.transform.MergeCompilerRegions()(mod) mod = relay.transform.PartitionGraph()(mod) diff --git a/tests/python/contrib/test_ethosn/test_conv2d.py b/tests/python/contrib/test_ethosn/test_conv2d.py new file mode 100644 index 000000000000..52e3de94eb4d --- /dev/null +++ b/tests/python/contrib/test_ethosn/test_conv2d.py @@ -0,0 +1,204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Ethos-N integration conv2d tests""" + +import numpy as np +import math +import tvm +from tvm import relay +from tvm.relay.op.contrib.ethosn import ethosn_available +from . import infrastructure as tei + + +def _get_same_padding(data, kernel, dilation, stride): + dilated_kernel_h = dilation[0] * (kernel[0] - 1) + 1 + dilated_kernel_w = dilation[1] * (kernel[1] - 1) + 1 + out = int(math.ceil(float(data[0]) / float(stride[0]))) + pad = max(0, (out - 1) * stride[0] + dilated_kernel_h - data[0]) + pad_top = pad // 2 + pad_bottom = pad - pad_top + + out = int(math.ceil(float(data[1]) / float(stride[1]))) + pad = max(0, (out - 1) * stride[1] + dilated_kernel_w - data[1]) + pad_left = pad // 2 + pad_right = pad - pad_left + return [pad_top, pad_left, pad_bottom, pad_right] + + +def _get_model(shape, kernel_h, kernel_w, + input_zp, input_sc, + kernel_zp, kernel_sc, + output_zp, output_sc, + pad, strides, dilation, + groups, dtype, + out_channels, weight_format): + """Return a model and any parameters it may have""" + a = relay.var("a", shape=shape, dtype=dtype) + if pad == "op" or pad == "both": + p = _get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides) + a = relay.nn.pad(a, + pad_width=[(0, 0), (p[0], p[2]), (p[1], p[3]), (0, 0)], + pad_value=input_zp, pad_mode="constant") + shape = (shape[0], shape[1] + p[0] + p[2], shape[2] + p[1] + p[3], shape[3]) + + p = _get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides) + if weight_format == "HWIO": + weight_shape = (kernel_h, kernel_w, shape[3] // groups, out_channels) + else: + weight_shape = (kernel_h, kernel_w, out_channels, 1) + w = tvm.nd.array(np.random.randint(np.iinfo(dtype).min, high=np.iinfo(dtype).max, size=weight_shape, dtype=dtype)) + weights = relay.const(w, dtype) + conv = relay.qnn.op.conv2d( + a, + weights, + input_zero_point=relay.const(input_zp, "int32"), + kernel_zero_point=relay.const(kernel_zp, "int32"), + input_scale=relay.const(input_sc, "float32"), + kernel_scale=relay.const(kernel_sc, "float32"), + kernel_size=(kernel_h, kernel_w), + data_layout="NHWC", + kernel_layout=weight_format, + dilation=dilation, + strides=strides, + groups=groups, + channels=out_channels, + padding=p if pad == "attr" or pad == "both" else (0, 0, 0, 0), + out_dtype="int32", + ) + b = tvm.nd.array(np.random.randint(0, high=10, size=(out_channels,), dtype="int32")) + biasc = relay.const(b, "int32") + bias = relay.nn.bias_add(conv, biasc, axis=3) + req = relay.qnn.op.requantize( + bias, + relay.const(input_sc * kernel_sc, 'float32'), # input zero scale + relay.const(0, 'int32'), # input zero point + relay.const(output_sc, 'float32'), # output zero scale + relay.const(output_zp, 'int32'), # output zero point + out_dtype="uint8" + ) + params = {"w": w, + "b": b} + return req, params + + +def _get_conv2d_qnn_params(input_zp, input_sc, kernel_zp, kernel_sc, kernel_h, kernel_w, channels): + input_max = input_sc * (255 - input_zp) + input_min = - input_sc * input_zp + kernel_max = kernel_sc * (255 - kernel_zp) + kernel_min = - kernel_sc * kernel_zp + output_limits = [kernel_max * kernel_h * kernel_w * channels * input_max, + kernel_min * kernel_h * kernel_w * channels * input_max, + kernel_min * kernel_h * kernel_w * channels * input_min, + kernel_max * kernel_h * kernel_w * channels * input_min] + output_max = max(output_limits) + output_min = min(output_limits) + output_sc = (output_max - output_min) / 255 + output_zp = - int(output_min / output_sc) + return output_zp, output_sc + + +def test_conv2d(): + if not ethosn_available(): + return + + trials = [ + [(1, 17, 20, 26), 4, 3, 1, 'attr', (2, 2), (1, 1)], + [(1, 30, 27, 30), 5, 5, 3, 'none', (1, 1), (1, 1)], + [(1, 14, 28, 11), 6, 2, 2, 'op', (2, 2), (1, 1)], + [(1, 9, 20, 30), 7, 1, 5, 'none', (1, 1), (1, 1)], + [(1, 21, 21, 22), 8, 5, 1, 'attr', (2, 2), (1, 1)], + [(1, 21, 25, 29), 9, 2, 5, 'op', (1, 1), (1, 1)], + [(1, 31, 28, 15), 10, 1, 2, 'attr', (2, 2), (1, 1)], + [(1, 21, 21, 8), 11, 3, 3, 'none', (1, 1), (1, 1)], + [(1, 5, 11, 6), 12, 5, 2, 'op', (2, 2), (1, 1)], + [(1, 12, 7, 18), 13, 1, 3, 'op', (1, 1), (1, 1)], + [(1, 24, 6, 26), 14, 3, 5, 'none', (2, 2), (1, 1)], + [(1, 19, 24, 16), 15, 2, 1, 'attr', (1, 1), (1, 1)], + ] + + np.random.seed(0) + for depthwise in [False, True]: + for shape, out_channels, kernel_h, kernel_w, pad, stride, dilation in trials: + if depthwise: + out_channels = shape[3] + groups = out_channels + kernel_w = kernel_h + weight_format = "HWOI" + stride = (1, 1) if kernel_w == 1 else (2, 2) + else: + groups = 1 + weight_format = "HWIO" + + outputs = [] + inputs = { + "a": tvm.nd.array(np.random.randint(0, high=255, size=shape, dtype="uint8")), + } + input_zp = np.random.randint(0, 255) + input_sc = np.random.random() * 2 + kernel_zp = np.random.randint(0, 255) + kernel_sc = np.random.random() * 2 + output_zp, output_sc = _get_conv2d_qnn_params(input_zp, input_sc, + kernel_zp, kernel_sc, + kernel_h, kernel_w, shape[3]) + model, params = _get_model(shape, kernel_h, kernel_w, + input_zp, input_sc, + kernel_zp, kernel_sc, + output_zp, output_sc, + pad, stride, dilation, + groups, "uint8", + out_channels, weight_format) + for npu in [False, True]: + mod = tei.make_module(model, params) + outputs.append(tei.build_and_run(mod, inputs, 1, params, npu=npu)) + + tei.verify(outputs, 1) + + +def test_conv2d_failure(): + if not ethosn_available(): + return + + trials = [ + ((1, 4, 4, 4), 1, 1, 0, 1, 0, 1, 0, 1, "none", (1, 1), (1, 1), 1, "uint8", 8, "HWIO", + "Overall scale (of the input * weights / output) should be in the range [0, 1)"), + ((1, 4, 4, 4), 1, 1, 0, 1, 0, 1, 0, 1, "none", (1, 1), (1, 1), 1, "int8", 8, "HWIO", + "dtype='int8', dtype must be either uint8 or int32"), + ((1, 4, 4, 4), 2, 2, 0, 1, 0, 1, 0, 2, "both", (1, 1), (1, 1), 1, "uint8", 8, "HWIO", + "both op and attr padding exist, must be either op/attr only or no padding"), + ((1, 4, 4, 4), 1, 1, 0, 1, 0, 1, 0, 2, "none", (1, 1, 1), (1, 1), 1, "uint8", 8, "HWIO", + "stride size=3, stride size must = 2"), + ((1, 4, 4, 4), 1, 1, 0, 1, 0, 1, 0, 2, "none", (1, 1), (2, 1), 1, "uint8", 8, "HWIO", + "dilation=[2, 1], dilation must = [1, 1]"), + ((2, 4, 4, 4), 1, 1, 0, 1, 0, 1, 0, 2, "none", (1, 1), (1, 1), 1, "uint8", 8, "HWIO", + "batch size=2, batch size must = 1"), + ] + + np.random.seed(0) + for shape, kernel_h, kernel_w, input_zp, input_sc, kernel_zp,\ + kernel_sc, output_zp, output_sc, pad, stride, dilation,\ + groups, dtype, out_channels, weight_format, err_msg in trials: + model, params = _get_model(shape, kernel_h, kernel_w, + input_zp, input_sc, + kernel_zp, kernel_sc, + output_zp, output_sc, + pad, stride, dilation, + groups, dtype, + out_channels, weight_format) + model = tei.make_ethosn_composite(model, "ethos-n.qnn_conv2d") + mod = tei.make_ethosn_partition(model) + tei.test_error(mod, {}, err_msg) From 1899ad82f78976f1711972e6a0cbb7009c1228d6 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 27 Aug 2020 15:11:42 -0700 Subject: [PATCH 3/3] [Ansor][AutoTVM v2.0] Phase 2: Evolutionary Search (#6310) * init commit * Add rest rules * refactor * address comments * improve test * address comments --- python/tvm/auto_scheduler/search_policy.py | 20 ++ .../search_policy/sketch_policy.cc | 166 ++++++++++- .../search_policy/sketch_policy.h | 24 +- .../search_policy/sketch_policy_rules.cc | 279 +++++++++++++++++- .../search_policy/sketch_policy_rules.h | 57 +++- src/auto_scheduler/search_policy/utils.cc | 65 +++- src/auto_scheduler/search_policy/utils.h | 16 +- ...test_auto_scheduler_evolutionary_search.py | 75 +++++ 8 files changed, 674 insertions(+), 28 deletions(-) create mode 100644 tests/python/unittest/test_auto_scheduler_evolutionary_search.py diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py index 278822e2ca04..e2bfca392c1e 100644 --- a/python/tvm/auto_scheduler/search_policy.py +++ b/python/tvm/auto_scheduler/search_policy.py @@ -113,6 +113,8 @@ class SketchPolicy(SearchPolicy): "retry_search_one_round_on_empty": 10, 'evolutionary_search_population': 2048, + 'evolutionary_search_num_iters': 10, + 'evolutionary_search_mutation_prob': 0.85, "evolutionary_search_use_measured_ratio": 0.2, 'cpu_multi_level_tiling_structure': 'SSRSRS', @@ -178,3 +180,21 @@ def sample_initial_population(self, pop_size): """ states = _ffi_api.SketchPolicySampleInitialPopulation(self, pop_size) return states + + def evolutionary_search(self, init_populuations, out_size): + """Evolutionary search. + This python interface is mainly used for debugging and testing. + The actual search is all doen in c++. + Parameters + ---------- + init_populations: List[State] + The initial population states + out_size : int + The size of generated states + Returns + ------- + states: List[State] + The generated states + """ + states = _ffi_api.SketchPolicyEvolutionarySearch(self, init_populuations, out_size) + return states diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 51c138be70bb..4f536e829be4 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -65,6 +66,13 @@ static InitUnroll init_unroll; static InitVectorization init_vectorization; static InitThreadBind init_thread_bind; +/********** Mutation rules **********/ + +static MutateTileSize mutate_tile_size; +static MutateMaxUnrollFactor mutate_max_unroll_factor; +static MutateComputeLocation mutate_compute_location; +static MutateParallel mutate_parallel; + /********** Sketch policy **********/ TVM_REGISTER_NODE_TYPE(SketchPolicyNode); @@ -129,6 +137,12 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel schedule_cost_model, LOG(FATAL) << "No default init rules for target: " << task->target; } + // The default mutation rules. + node->mutation_rules.push_back(&mutate_tile_size); + node->mutation_rules.push_back(&mutate_max_unroll_factor); + node->mutation_rules.push_back(&mutate_compute_location); + node->mutation_rules.push_back(&mutate_parallel); + data_ = std::move(node); } @@ -336,7 +350,7 @@ Array SketchPolicyNode::SampleInitPopulation(const Array& sketches // Derivation rule based enumeration bool valid = true; for (const auto& rule : init_rules) { - if (rule->Apply(this, &tmp_s) == InitPopulationRule::ResultKind::kInvalid) { + if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) { valid = false; break; } @@ -363,8 +377,148 @@ Array SketchPolicyNode::EvolutionarySearch(const Array& init_popul Array best_states; auto tic_begin = std::chrono::high_resolution_clock::now(); - // TODO(comaniac, merrymercy, jcf94): Since we haven't finished porting the cost model part - // yet, currently delete the implementation of EvolutionarySearch. To be added later. + size_t population = init_population.size(); + int num_iters = GetIntParam(params, SketchParamKey::EvolutionarySearch::num_iters); + double mutation_prob = GetDoubleParam(params, SketchParamKey::EvolutionarySearch::mutation_prob); + + // Two ping pong buffers to avoid copy. + Array states_buf1{init_population}, states_buf2; + states_buf1.reserve(population); + states_buf2.reserve(population); + Array* pnow = &states_buf1; + Array* pnext = &states_buf2; + + // The set of explored states to avoid redundancy. + std::unordered_set explored_set; + + // The heap to maintain the so far best states. + using StateHeapItem = std::pair; + auto cmp = [](const StateHeapItem& left, const StateHeapItem& right) { + return left.second > right.second; + }; + using StateHeap = std::priority_queue, decltype(cmp)>; + StateHeap heap(cmp); + auto update_heap = [&heap, &explored_set](const Array& states, + const std::vector& scores, const int out_size) { + float max_score = 0.0; + for (size_t i = 0; i < states.size(); ++i) { + const State& state = states[i]; + std::string state_str = state.ToStr(); + + // Skip redundant states. + if (explored_set.count(state_str) > 0) { + continue; + } + explored_set.insert(state_str); + + if (static_cast(heap.size()) < out_size) { + // Directly push item if the heap is not full yet. + heap.push({state, scores[i]}); + } else if (scores[i] > heap.top().second) { + // Replace the worst state in the heap with the new state. + heap.pop(); + heap.push({state, scores[i]}); + } + max_score = (scores[i] > max_score) ? scores[i] : max_score; + } + return max_score; + }; + + // Cost model predicted scores. + std::vector scores; + scores.reserve(population); + + // The function to generate prefix sum probabilities based on the given scores. + auto assign_prob = [](const std::vector& scores, std::vector* prefix_sum_probs) { + // Compute selection probabilities. + double sum = 0.0; + prefix_sum_probs->resize(scores.size()); + for (size_t i = 0; i < scores.size(); ++i) { + sum += std::max(scores[i], 0.0f); + (*prefix_sum_probs)[i] = sum; + } + for (size_t i = 0; i < scores.size(); ++i) { + (*prefix_sum_probs)[i] /= sum; + } + }; + + // State selection probabilities. + std::uniform_real_distribution<> uniform_dist(0.0, 1.0); + std::vector state_select_probs; + state_select_probs.reserve(population); + + // Mutation rule selection probabilities. + std::vector rule_select_probs; + rule_select_probs.reserve(mutation_rules.size()); + std::vector rule_levels; + for (const auto& rule : mutation_rules) { + rule_levels.push_back(rule->GetLevel(search_task)); + } + assign_prob(rule_levels, &rule_select_probs); + + // Evaluate the init populations. + *pnow = search_task->compute_dag.InferBound(*pnow); + PruneInvalidState(search_task, pnow); + CHECK_GT(pnow->size(), 0) << "All initial populations are invalid"; + schedule_cost_model->Predict(search_task, *pnow, &scores); + + // Maintain the best states in the heap. + float max_score = update_heap(*pnow, scores, out_size); + + // Genetic algorithm. + for (auto iter_idx = 1; iter_idx <= num_iters; ++iter_idx) { + // Assign the selection probability to each state based on the cost model scores. + assign_prob(scores, &state_select_probs); + + // TODO(@comaniac): Perform cross over. + + // Perform mutations. + size_t fail_ct = 0; + while (pnext->size() < population && fail_ct < population * 2) { + // Select a state to be mutated. + State tmp_s = (*pnow)[RandomChoose(state_select_probs, &rand_gen)]; + if (uniform_dist(rand_gen) < mutation_prob) { + // Select a rule and mutate the state. + const auto& rule = mutation_rules[RandomChoose(rule_select_probs, &rand_gen)]; + if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kValid) { + pnext->push_back(std::move(tmp_s)); + } else { + fail_ct++; + } + } else { + // Do not mutate this state in this round. + pnext->push_back(std::move(tmp_s)); + } + } + + // Evaluate the new populations. + *pnext = search_task->compute_dag.InferBound(*pnext); + PruneInvalidState(search_task, pnext); + + // Throw away all states generated in this iterations if all new states are invalid. + if (pnext->size() > 0) { + std::swap(pnext, pnow); + schedule_cost_model->Predict(search_task, *pnow, &scores); + + // Maintain the best states in the heap. + float iter_max_score = update_heap(*pnow, scores, out_size); + max_score = (iter_max_score > max_score) ? iter_max_score : max_score; + } + pnext->clear(); + + if (iter_idx % 5 == 0 || iter_idx == num_iters) { + StdCout(verbose) << "GA Iter: " << iter_idx << std::fixed << std::setprecision(4) + << "\tMax Score: " << max_score << "\tPop Size: " << pnow->size() + << std::endl; + } + } + + // Copy best states in the heap to the output. + while (!heap.empty()) { + auto item = heap.top(); + heap.pop(); + best_states.push_back(std::move(item.first)); + } double duration = std::chrono::duration_cast>( std::chrono::high_resolution_clock::now() - tic_begin) @@ -441,5 +595,11 @@ TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicySampleInitialPopulation") return init_population; }); +TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicyEvolutionarySearch") + .set_body_typed([](SketchPolicy policy, Array init_population, int out_size) { + Array states = policy->EvolutionarySearch(init_population, out_size); + return states; + }); + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h index 0c1e6df170f4..2d93d8775c86 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.h +++ b/src/auto_scheduler/search_policy/sketch_policy.h @@ -56,6 +56,10 @@ struct SketchParamKey { struct EvolutionarySearch { /*! \brief The population size for evolutionary search. */ static constexpr const char* population = "evolutionary_search_population"; + /*! \brief The number of iterations performed by generic algorithm.*/ + static constexpr const char* num_iters = "evolutionary_search_num_iters"; + /*! \brief The mutation probability.*/ + static constexpr const char* mutation_prob = "evolutionary_search_mutation_prob"; /*! \brief The maximum percentage of measured states in the initial population for evolutionary * search. */ static constexpr const char* use_measured_ratio = "evolutionary_search_use_measured_ratio"; @@ -90,7 +94,9 @@ class SketchPolicyNode : public SearchPolicyNode { /*! \brief The rules to generate sketches. */ std::vector sketch_rules; /*! \brief The rules to generate initial states. */ - std::vector init_rules; + std::vector init_rules; + /*! \brief The rules to mutate states. */ + std::vector mutation_rules; /*! \brief Random generator. */ std::mt19937 rand_gen; /*! \brief Memorize split space for Split. */ @@ -113,6 +119,14 @@ class SketchPolicyNode : public SearchPolicyNode { */ Array SampleInitPopulation(const Array& sketches, int out_size); + /*! + * \brief Perform evolutionary search. + * \param init_populations The states generated from init population. + * \param out_size The number of expected output states. + * \return The generated states after evolutionary search. + */ + Array EvolutionarySearch(const Array& init_populations, int out_size); + static constexpr const char* _type_key = "auto_scheduler.SketchPolicy"; TVM_DECLARE_FINAL_OBJECT_INFO(SketchPolicyNode, SearchPolicyNode); @@ -127,14 +141,6 @@ class SketchPolicyNode : public SearchPolicyNode { */ Array SearchOneRound(int num_random_states, Array* random_states = nullptr); - /*! - * \brief Perform evolutionary search. - * \param init_populations The states generated from init population. - * \param out_size The number of expected output states. - * \return The generated states after evolutionary search. - */ - Array EvolutionarySearch(const Array& init_populations, int out_size); - /*! * \brief Pick states from best states and random states with eps-greedy policy. * \param best_states States picked by cost model. diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index 92073b68b73a..843301c2bb8f 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -436,8 +436,8 @@ std::vector> RuleSpecialComputeLocationGPU::Apply( /********** Init Population **********/ -InitPopulationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, - State* state) const { +PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, + State* state) const { StateNode* pstate = state->CopyOnWrite(); // Scan the transformation history and randomly fill tiles size for all SplitStep for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) { @@ -472,10 +472,11 @@ InitPopulationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, return ResultKind::kValid; } -InitPopulationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy, - State* state) const { +PopulationGenerationRule::ResultKind MutateComputeLocationCommon(SketchPolicyNode* policy, + State* state, + bool infer_bound = true) { if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) { - return ResultKind::kValid; + return PopulationGenerationRule::ResultKind::kValid; } for (int stage_id = static_cast((*state)->stages.size()) - 1; stage_id >= 0; stage_id--) { @@ -584,11 +585,19 @@ InitPopulationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode } } - *state = policy->search_task->compute_dag.InferBound(*state); - return ResultKind::kValid; + if (infer_bound) { + *state = policy->search_task->compute_dag.InferBound(*state); + } + return PopulationGenerationRule::ResultKind::kValid; } -InitPopulationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, State* state) const { +PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy, + State* state) const { + return MutateComputeLocationCommon(policy, state, false); +} + +PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, + State* state) const { std::function annotate_parallel; annotate_parallel = [&annotate_parallel](const SketchPolicyNode& policy, State* state, @@ -652,7 +661,8 @@ InitPopulationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, Sta return ResultKind::kValid; } -InitPopulationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State* state) const { +PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, + State* state) const { std::vector auto_unroll_configs = IsGPUTask(policy->search_task) ? std::vector({0, 16, 64, 512, 1024}) : std::vector({0, 16, 64, 512}); @@ -703,8 +713,8 @@ InitPopulationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State return ResultKind::kValid; } -InitPopulationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy, - State* state) const { +PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy, + State* state) const { for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { const Stage& stage = (*state)->stages[stage_id]; // Skip the inlined stage and placeholder stage @@ -762,7 +772,8 @@ InitPopulationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy return ResultKind::kValid; } -InitPopulationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, State* state) const { +PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, + State* state) const { std::set multi_level_tiling_root_set; for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { if (NeedsMultilevelTiling(policy->search_task, *state, stage_id)) { @@ -908,7 +919,251 @@ InitPopulationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, S state->bind(stage_id, iters1[1], IteratorAnnotation::kThreadX); } } + return ResultKind::kValid; +} + +PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, + State* state) const { + int max_innermost_split_factor = + GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor); + + // Extract all SplitStep + std::vector split_step_ids; + for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) { + if (auto ps = (*state)->transform_steps[i].as()) { + if (!ps->extent.defined() || !ps->extent.value()->IsInstance()) { + continue; + } + auto innermost_factor = ps->lengths.back().value_or(max_innermost_split_factor + 1); + if (GetIntImm(innermost_factor) <= max_innermost_split_factor) { + split_step_ids.push_back(i); + } + } + } + if (split_step_ids.empty()) { + // No tile size could be mutated. + return ResultKind::kInvalid; + } + + // Select a SplitStep with extent larger than one to mutate. + int retry_ct = 0; + int64_t extent = 1; + int step_id; + const SplitStepNode* ps; + + do { + step_id = split_step_ids[(policy->rand_gen)() % split_step_ids.size()]; + ps = (*state)->transform_steps[step_id].as(); + CHECK(ps != nullptr); + extent = GetIntImm(ps->extent.value()); + retry_ct += 1; + } while (retry_ct < static_cast(split_step_ids.size()) << 2 && (extent == 1 || extent == 0)); + + if (extent <= 1) { + // Cannot find a step with extent larger than one. + return ResultKind::kInvalid; + } + + // Fetch the current tile sizes. + std::vector lengths(ps->lengths.size() + 1, 1); + for (int i = 0; i < static_cast(ps->lengths.size()); ++i) { + lengths[i + 1] = GetIntImm(ps->lengths[i].value()); + } + lengths[0] = extent / ElementProduct(lengths); + + // Random permute the tile size order. + std::vector random_perm; + RandomPermutation(lengths.size(), &random_perm, &(policy->rand_gen)); + + // Try to divide a factor from one tile size and multiple it to another. + for (size_t i = 0; i < random_perm.size(); ++i) { + size_t src_idx = random_perm[i]; + int length = lengths[src_idx]; + if (length <= 1) { + continue; + } + + size_t dst_idx = random_perm[(i + 1) % random_perm.size()]; + const std::vector& factors = policy->split_memo.GetFactors(length); + CHECK_GE(factors.size(), 1); + + int divide_factor; + if (dst_idx == lengths.size() - 1) { + // Maintain the restriction of hardware_params.max_innermost_split_factor. + int max_factor_index = static_cast(factors.size()) - 1; + for (; max_factor_index >= 1; max_factor_index--) { + if (factors[max_factor_index] * lengths[dst_idx] <= max_innermost_split_factor) { + break; + } + } + if (max_factor_index == 0) { + // Failed on this dst_idx, try next one. + continue; + } + divide_factor = factors[1 + (policy->rand_gen)() % (max_factor_index)]; + } else { + divide_factor = factors[1 + (policy->rand_gen)() % (factors.size() - 1)]; + } + + // Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx]. + Array new_lengths; + for (size_t j = 1; j < lengths.size(); ++j) { + if (j == src_idx) { + new_lengths.push_back(Integer(lengths[j] / divide_factor)); + } else if (j == dst_idx) { + new_lengths.push_back(Integer(lengths[j] * divide_factor)); + } else { + new_lengths.push_back(Integer(lengths[j])); + } + } + + StateNode* pstate = state->CopyOnWrite(); + pstate->transform_steps.Set( + step_id, SplitStep(ps->stage_id, ps->iter_id, ps->extent, + Array>(new_lengths.begin(), new_lengths.end()), + ps->inner_to_outer)); + return ResultKind::kValid; + } + return ResultKind::kInvalid; +} + +PopulationGenerationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNode* policy, + State* state) const { + // Extract all auto_unroll_max_step pragma steps. + std::vector annotate_steps; + for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) { + if (auto ps = (*state)->transform_steps[i].as()) { + if (StrStartsWith(ps->pragma_type, "auto_unroll_max_step")) { + annotate_steps.push_back(i); + } + } + } + if (annotate_steps.empty()) { + return ResultKind::kInvalid; + } + + // Random pick up one unroll factor candidate. + auto cands = (IsGPUTask(policy->search_task)) ? &gpu_unroll_cands_ : &cpu_unroll_cands_; + auto new_factor = std::to_string((*cands)[(policy->rand_gen)() % cands->size()]); + + // Random pick up and mutate an unroll step. + auto step_id = annotate_steps[(policy->rand_gen)() % annotate_steps.size()]; + auto ps = (*state)->transform_steps[step_id].as(); + CHECK(ps); + StateNode* pstate = state->CopyOnWrite(); + pstate->transform_steps.Set(step_id, + PragmaStep(ps->stage_id, ps->iter_id, + std::string("auto_unroll_max_step") + "$" + new_factor)); + return ResultKind::kValid; +} + +PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy, + State* state) const { + return MutateComputeLocationCommon(policy, state, true); +} + +PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy, + State* state) const { + // This mutation rule only focuses on a case that parallel was added to + // the outermost loop and the loop is generated by fusing other loops. + // In short, we mutate the fusion step before the parallel step. + + // Extract all parallel steps. + std::vector parallel_steps; + for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) { + auto ps = (*state)->transform_steps[s].as(); + if (!ps || ps->annotation != IteratorAnnotation::kParallel) { + continue; + } + + // Skip non-outermost loop or the parallel step without fusion beforehand. + if (ps->iter_id > 0 || s == 0 || !(*state)->transform_steps[s - 1].as()) { + continue; + } + parallel_steps.push_back(s); + } + if (parallel_steps.empty()) { + return ResultKind::kInvalid; + } + + // Randomly pick one parallel step. + size_t step_id = parallel_steps[(policy->rand_gen)() % parallel_steps.size()]; + auto ps = (*state)->transform_steps[step_id].as(); + CHECK(ps); + size_t stage_id = ps->stage_id; + size_t iter_id = ps->iter_id; + const Stage& stage = (*state)->stages[stage_id]; + const Iterator& it = stage->iters[iter_id]; + + // Replay a new state until the picked fuse step. + State tmp_s = policy->search_task->compute_dag->init_state; + for (size_t s = 0; s < step_id - 1; ++s) { + auto step = (*state)->transform_steps[s]; + tmp_s.CopyOnWrite()->transform_steps.push_back(step); + StepApplyToState(step, &tmp_s, policy->search_task->compute_dag); + } + + // Determine the fusion mutation direction. + // 0: fuse less; 1: fuse more. + auto fuse_step = (*state)->transform_steps[step_id - 1].as(); + auto fused_ids = fuse_step->fused_ids; + std::vector fuse_dir = {0.5, 1.0}; + + // The case that we can only fuse more. This may happen after multiple mutations. + if (fused_ids.size() == 1) { + fuse_dir[0] = 0.0; + } + + // The cases that we cannot fuse the next iters. + if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id)) || + it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) { + if (fuse_dir[0] == 0.0) { + // No room to mutate this fusion. + return ResultKind::kInvalid; + } + fuse_dir[0] = 1.0; + } + + // Mutate the fusion iters and replay the mutated fused/annotation steps. + int iter_offset = 0; + if (RandomChoose(fuse_dir, &(policy->rand_gen)) == 0) { + fused_ids.pop_back(); + iter_offset = 1; + } else { + auto last_id = fused_ids.back().get()->value; + fused_ids.push_back(last_id + 1); + iter_offset = -1; + } + auto new_fuse_step = FuseStep(stage_id, fused_ids); + tmp_s.CopyOnWrite()->transform_steps.push_back(new_fuse_step); + StepApplyToState(new_fuse_step, &tmp_s, policy->search_task->compute_dag); + tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[step_id]); + StepApplyToState((*state)->transform_steps[step_id], &tmp_s, policy->search_task->compute_dag); + + // Replay the rest steps. + for (size_t s = step_id + 1; s < (*state)->transform_steps.size(); ++s) { + auto step = (*state)->transform_steps[s]; + if (step->stage_id == static_cast(stage_id)) { + // Since we changed the loop structure, iter ID in later steps to the same stage + // has to be adjusted. + auto ps = step.as(); + if (ps) { + if (ps->iter_id == 0) { + step = AnnotationStep(ps->stage_id, 0, ps->annotation); + } else { + CHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size()); + step = AnnotationStep(ps->stage_id, ps->iter_id + iter_offset, ps->annotation); + } + } else { + // Unexpected step node that we did not process for now. + return ResultKind::kInvalid; + } + } + tmp_s.CopyOnWrite()->transform_steps.push_back(step); + StepApplyToState(step, &tmp_s, policy->search_task->compute_dag); + } + *state = tmp_s; return ResultKind::kValid; } diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h index 5ddfd181cc5b..418fbda6a030 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.h +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h @@ -26,10 +26,13 @@ #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_ #include +#include #include #include +#include "utils.h" + namespace tvm { namespace auto_scheduler { @@ -122,7 +125,7 @@ DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU); /********** Init Population **********/ /*! \brief The base class for derivation rules used in the initial population. */ -class InitPopulationRule { +class PopulationGenerationRule { public: /*! \brief Result enumeration of the apply function. */ enum class ResultKind : int { kValid = 0, kInvalid = 1 }; @@ -138,7 +141,7 @@ class InitPopulationRule { }; #define DEFINE_INIT_POPULATION_RULE(rule_name) \ - class rule_name : public InitPopulationRule { \ + class rule_name : public PopulationGenerationRule { \ public: \ ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \ }; @@ -162,6 +165,56 @@ DEFINE_INIT_POPULATION_RULE(InitVectorization); /*! \brief The rule that annotates thread binding for GPU. */ DEFINE_INIT_POPULATION_RULE(InitThreadBind); +/********** Mutation **********/ + +/*! \brief The base class for mutation rules used in the evolutionary search. */ +class PopulationMutationRule : public PopulationGenerationRule { + public: + /*! + * \brief Get the priority level of this mutation rule. + * \return The priority level of this mutation rule. Higher the better. + */ + virtual int GetLevel(const SearchTask& task) const = 0; +}; + +// A helper to define mutation rules with a constant rule level. +#define DEFINE_MUTATE_POPULATION_RULE(rule_name, rule_level) \ + class rule_name : public PopulationMutationRule { \ + public: \ + ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \ + int GetLevel(const SearchTask& task) const final { return rule_level; } \ + }; + +/*! \brief The rule that mutates tile size by randomly dividing a tile size by a factor + and multipling it to another tile size. */ +DEFINE_MUTATE_POPULATION_RULE(MutateTileSize, 100); + +/*! \brief The rule that mutates the fusion iterators annotated by parallel. */ +DEFINE_MUTATE_POPULATION_RULE(MutateParallel, 50); + +/*! \brief The rule that mutates the factor of a randomly selected auto max unroll step. */ +class MutateMaxUnrollFactor : public PopulationMutationRule { + public: + ResultKind Apply(SketchPolicyNode* policy, State* state) const final; + int GetLevel(const SearchTask& task) const final { return 10; } + + const std::vector cpu_unroll_cands_ = {0, 16, 64, 512, 1024}; + const std::vector gpu_unroll_cands_ = {0, 16, 64, 512}; +}; + +/*! \brief The rule that randomly changes the computation location for some stages, which do not + * need tiling and are not strictly inlineable(e.g. data padding). */ +class MutateComputeLocation : public PopulationMutationRule { + public: + ResultKind Apply(SketchPolicyNode* policy, State* state) const final; + int GetLevel(const SearchTask& task) const final { + if (IsGPUTask(task)) { + return 0; + } + return 5; + } +}; + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc index b3f07b1c160f..a09ea596984a 100644 --- a/src/auto_scheduler/search_policy/utils.cc +++ b/src/auto_scheduler/search_policy/utils.cc @@ -18,7 +18,7 @@ */ /*! - * \file auto_scheduler/utils.cc + * \file auto_scheduler/search_policy/utils.cc * \brief Common utilities */ @@ -270,6 +270,69 @@ State FollowTiling(const State& state, int stage_id, const std::vector& spl return tmp_s; } +// Return whether a state has nested parallel, which is invalid on CPUs +bool HasNestedParallel(const State& state) { + std::function count_parallel_ct; + + count_parallel_ct = [&state, &count_parallel_ct](int stage_id, size_t* parallel_ct) { + const Stage& stage = state->stages[stage_id]; + + if (stage->compute_at == ComputeAtKind::kInlined) { + return; + } + + for (size_t i = 0; i < stage->iters.size(); ++i) { + if (stage->iters[i]->annotation == IteratorAnnotation::kParallel) { + (*parallel_ct)++; + } + + IterKey iter_key(stage_id, i); + auto pair = state->attach_map->iter_to_attached_stages.find(iter_key); + if (pair != state->attach_map->iter_to_attached_stages.end()) { + for (const auto& attach_stage_id : pair->second) { + count_parallel_ct(attach_stage_id, parallel_ct); + } + } + } + }; + + for (size_t stage_id = 0; stage_id < state->stages.size(); ++stage_id) { + size_t parallel_ct = 0; + + if (state->stages[stage_id]->compute_at == ComputeAtKind::kRoot) { + count_parallel_ct(stage_id, ¶llel_ct); + if (parallel_ct >= 2) { + return true; + } + } + } + + return false; +} + +void PruneInvalidState(const SearchTask& task, Array* states) { + size_t pt = 0; + for (size_t i = 0; i < states->size(); ++i) { + if (!(*states)[i].defined()) { + continue; + } + if (!IsGPUTask(task) && HasNestedParallel((*states)[i])) { + continue; + } + + if (i != pt) { + states->Set(pt, (*states)[i]); + } + pt++; + } + + if (pt == 0) { + LOG(INFO) << "All states are invalid."; + } else { + states->resize(pt); + } +} + const Array>& SplitFactorizationMemo::GetFactorizationSchemes( int extent, int n_lengths, int max_innermost_factor) { QueryKey key = std::make_tuple(extent, n_lengths, max_innermost_factor); diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h index 2d49ab007c78..792102a2a1ce 100644 --- a/src/auto_scheduler/search_policy/utils.h +++ b/src/auto_scheduler/search_policy/utils.h @@ -18,7 +18,7 @@ */ /*! - * \file auto_scheduler/search_policy/utils.cc + * \file auto_scheduler/search_policy/utils.h * \brief Common utilities for search policies. */ @@ -662,6 +662,20 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo State FollowTiling(const State& state, int stage_id, const std::vector& split_step_ids, int n_split); +// Random choose an index according to a prefix sum probability. +inline int RandomChoose(const std::vector& prefix_sum_probs, std::mt19937* random_gen) { + std::uniform_real_distribution<> dis(0.0, 1.0); + double x = dis(*random_gen); + + CHECK(!prefix_sum_probs.empty()); + + return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) - + prefix_sum_probs.begin(); +} + +// Prune invalid states and return the results in-place. +void PruneInvalidState(const SearchTask& task, Array* states); + } // namespace auto_scheduler } // namespace tvm diff --git a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py new file mode 100644 index 000000000000..f06f06ac73c0 --- /dev/null +++ b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test evolutionary search. """ + +import tvm +from test_auto_scheduler_common import matmul_auto_scheduler_test +from tvm import auto_scheduler, te +from tvm.auto_scheduler.cost_model.cost_model import PythonBasedModel + + +class MockCostModel(PythonBasedModel): + """A mock cost model that rates 1 only for the states with tile_k=2.""" + def predict(self, task, states): + scores = [] + found = False + for state in states: + for line in str(state).split('\n'): + if line.find('k.1') != -1 and line.find('(0,2)') != -1: + found = True + break + scores.append(1 if found else 0) + return scores + +def test_evo_search(): + """Test evolutionary search. Since we cannot mock random number generator, + we mocked the cost model to manually guide the evo search. If evo search works + as expected, it should find the target state after a sufficient number of iterations. + This unit test has been tested with 1,000 runs with no failures, meaning that + the failure rate is less than 0.1%. + """ + workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (10, 10, 4)) + dag = auto_scheduler.ComputeDAG(workload_key) + task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.create('llvm')) + policy = auto_scheduler.SketchPolicy(task, schedule_cost_model=MockCostModel(), verbose=0) + states = policy.sample_initial_population(50) + pruned_states = [] + for state in states: + found = False + for line in str(state).split('\n'): + # Remove all tile_k=2 states and expect evo search will fine them. + if line.find('k.1') != -1 and line.find('(0,2)') != -1: + found = True + break + if not found: + pruned_states.append(state) + + new_states = policy.evolutionary_search(pruned_states, 50) + found = False + for state in new_states: + for line in str(state).split('\n'): + # Check if evo search found at least one state with tile_k=2. + if line.find('k.1') != -1 and line.find('(0,2)') != -1: + found = True + break + if found: + break + assert found + + +if __name__ == "__main__": + test_evo_search()