Skip to content

Commit

Permalink
sequential cpp test (apache#5745)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and Trevor Morris committed Jun 9, 2020
1 parent 6790a19 commit 2fab9c1
Showing 1 changed file with 33 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,44 @@
*/

#include <gtest/gtest.h>
#include <topi/broadcast.h>
#include <topi/generic/injective.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/module.h>
#include <tvm/node/structural_equal.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op_strategy.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>

TVM_REGISTER_GLOBAL("schedule").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
*rv = topi::generic::schedule_injective(args[0], args[1]);
});
using namespace tvm;

TVM_REGISTER_GLOBAL("test.seq.strategy")
.set_body_typed([](const Attrs& attrs, const Array<te::Tensor>& inputs, const Type& out_type,
const Target& target) {
relay::FTVMCompute fcompute = [](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) -> Array<te::Tensor> {
CHECK_EQ(inputs.size(), 2U);
return {topi::add(inputs[0], inputs[1])};
};
relay::FTVMSchedule fschedule = [](const Attrs& attrs, const Array<te::Tensor>& outs,
const Target& target) {
With<Target> target_scope(target);
return topi::generic::schedule_injective(target, outs);
};

auto n = make_object<relay::OpStrategyNode>();
auto strategy = relay::OpStrategy(std::move(n));
strategy.AddImplementation(fcompute, fschedule, "test.strategy", 10);
return strategy;
});

TEST(Relay, Sequential) {
using namespace tvm;
auto tensor_type = relay::TensorType({1, 2, 3}, DataType::Float(32));
auto c_data = tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});

Expand All @@ -53,14 +73,16 @@ TEST(Relay, Sequential) {
auto z3 = relay::Let(a, c, z2);
relay::Function func = relay::Function(relay::FreeVars(z3), z3, relay::Type(), {});

// Get schedule
auto reg = tvm::runtime::Registry::Get("relay.op._Register");
auto sch = tvm::runtime::Registry::Get("schedule");
if (!reg || !sch) {
LOG(FATAL) << "Register/schedule is not defined.";
auto reg = tvm::runtime::Registry::Get("ir.RegisterOpAttr");
if (!reg) {
LOG(FATAL) << "Register is not defined.";
}

(*reg)("add", "FTVMSchedule", *sch, 10);
auto fs = tvm::runtime::Registry::Get("test.seq.strategy");
if (!fs) {
LOG(FATAL) << "Strategy is not defined.";
}
auto fgeneric = GenericFunc::Get("test.strategy_generic").set_default(*fs);
(*reg)("add", "FTVMStrategy", fgeneric, 10);

// Run sequential passes.
tvm::Array<relay::transform::Pass> pass_seqs{
Expand Down

0 comments on commit 2fab9c1

Please sign in to comment.