diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential_test.cc similarity index 71% rename from tests/cpp/relay_transform_sequential.cc rename to tests/cpp/relay_transform_sequential_test.cc index f08d5574d51c..bb4bf928b018 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential_test.cc @@ -18,24 +18,44 @@ */ #include +#include #include #include #include #include #include #include +#include +#include #include #include #include #include #include -TVM_REGISTER_GLOBAL("schedule").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { - *rv = topi::generic::schedule_injective(args[0], args[1]); -}); +using namespace tvm; + +TVM_REGISTER_GLOBAL("test.seq.strategy") + .set_body_typed([](const Attrs& attrs, const Array& inputs, const Type& out_type, + const Target& target) { + relay::FTVMCompute fcompute = [](const Attrs& attrs, const Array& inputs, + const Type& out_type) -> Array { + CHECK_EQ(inputs.size(), 2U); + return {topi::add(inputs[0], inputs[1])}; + }; + relay::FTVMSchedule fschedule = [](const Attrs& attrs, const Array& outs, + const Target& target) { + With target_scope(target); + return topi::generic::schedule_injective(target, outs); + }; + + auto n = make_object(); + auto strategy = relay::OpStrategy(std::move(n)); + strategy.AddImplementation(fcompute, fschedule, "test.strategy", 10); + return strategy; + }); TEST(Relay, Sequential) { - using namespace tvm; auto tensor_type = relay::TensorType({1, 2, 3}, DataType::Float(32)); auto c_data = tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); @@ -53,14 +73,16 @@ TEST(Relay, Sequential) { auto z3 = relay::Let(a, c, z2); relay::Function func = relay::Function(relay::FreeVars(z3), z3, relay::Type(), {}); - // Get schedule - auto reg = tvm::runtime::Registry::Get("relay.op._Register"); - auto sch = tvm::runtime::Registry::Get("schedule"); - if (!reg || !sch) { - LOG(FATAL) << "Register/schedule is not defined."; + auto reg = tvm::runtime::Registry::Get("ir.RegisterOpAttr"); + if (!reg) { + LOG(FATAL) << "Register is not defined."; } - - (*reg)("add", "FTVMSchedule", *sch, 10); + auto fs = tvm::runtime::Registry::Get("test.seq.strategy"); + if (!fs) { + LOG(FATAL) << "Strategy is not defined."; + } + auto fgeneric = GenericFunc::Get("test.strategy_generic").set_default(*fs); + (*reg)("add", "FTVMStrategy", fgeneric, 10); // Run sequential passes. tvm::Array pass_seqs{