This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Merge branch 'develop' of https://github.com/PaddlePaddle/CINN into b
MayYouBeProsperous committed Sep 27, 2022
2 parents d9bea0c + 9ff559d commit a43ff90
Showing 127 changed files with 5,233 additions and 1,217 deletions.
1 change: 1 addition & 0 deletions cinn/auto_schedule/CMakeLists.txt
@@ -6,6 +6,7 @@ add_subdirectory(search_space)
add_subdirectory(search_strategy)
add_subdirectory(task)
add_subdirectory(task_scheduler)
add_subdirectory(tests)

proto_library(auto_schedule_proto SRCS auto_schedule.proto)

12 changes: 9 additions & 3 deletions cinn/auto_schedule/auto_tuner.cc
@@ -27,7 +27,9 @@
#include "cinn/auto_schedule/task/task_creator.h"
#include "cinn/auto_schedule/task/tune_task.h"
#include "cinn/auto_schedule/task_scheduler/task_scheduler.h"
#include "cinn/common/context.h"
#include "cinn/common/type.h"
#include "cinn/hlir/framework/op.h"

namespace cinn {
namespace auto_schedule {
@@ -43,11 +45,15 @@ void AutoTuner::Initialize(const Config& config, hlir::framework::GraphCompiler*
// create tasks
TaskCreator task_creator;
tasks_ = task_creator.CreateTuneTaskOpLevel(graph_);

const auto& dtype_dict = graph_->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype");
const auto& shape_dict = graph_->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");

op_lowerer_ = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target_);
for (TuneTask& task : tasks_) {
task.SetGraphCompiler(graph_compiler);
task.SetOpLowerer(op_lowerer_.get());
task.TaskGraphToUnoptLoweredFunc();
task.SerializeToString(graph_->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape"),
graph_->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype"));
task.SerializeToString(shape_dict, dtype_dict);
VLOG(3) << "Add a task with serialized_key:\n" << task.serialized_key;
}
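In short, Initialize now reads the "inferdtype" and "infershape" attribute maps once, builds a single OpLowerer, and shares it across every task, rather than re-fetching both maps inside SerializeToString per task. For orientation, a minimal caller-side sketch of the resulting flow, pieced together from this file and the tests below; the Config construction is an assumption, since its definition is not part of this diff:

    // Hypothetical driver, mirroring the test fixtures further down.
    AutoTuner tuner(target, graph.get());
    AutoTuner::Config config;   // assumed default-constructible; not shown in this diff
    TuningOptions options;
    tuner.Initialize(config, graph_compiler.get());  // creates tasks and the shared OpLowerer
    TuningResult result = tuner.Tune(options);       // returns tuned graph and optimized exprs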

2 changes: 2 additions & 0 deletions cinn/auto_schedule/auto_tuner.h
@@ -26,6 +26,7 @@
#include "cinn/common/target.h"
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/op_lowering.h"

namespace cinn {
namespace auto_schedule {
@@ -55,6 +56,7 @@ class AutoTuner {
private:
const common::Target& target_;
hlir::framework::Graph* graph_;
std::unique_ptr<hlir::framework::OpLowerer> op_lowerer_;

// Tasks to tune
std::vector<TuneTask> tasks_;
112 changes: 91 additions & 21 deletions cinn/auto_schedule/auto_tuner_test.cc
@@ -17,16 +17,21 @@
#include <glog/logging.h>
#include <gtest/gtest.h>

#include <cstdlib>
#include <iostream>

#include "cinn/common/target.h"
#include "cinn/frontend/net_builder.h"
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/pass.h"
#include "cinn/ir/ir_base.h"
#include "cinn/runtime/flags.h"

DECLARE_bool(auto_schedule_use_cost_model);
DECLARE_bool(cinn_ir_schedule);

namespace cinn {
namespace auto_schedule {
@@ -35,9 +40,10 @@ using ::cinn::hlir::framework::BuildScope;
using ::cinn::hlir::framework::Graph;
using ::cinn::hlir::framework::GraphCompiler;
using ::cinn::hlir::framework::Instruction;
using ::cinn::hlir::framework::Node;
using ::cinn::hlir::framework::Scope;

class TestAutoTuner : public ::testing::Test {
class TestAutoTunerWithoutFusion : public ::testing::Test {
public:
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
@@ -50,34 +56,47 @@ class TestAutoTuner : public ::testing::Test {
std::unique_ptr<GraphCompiler> graph_compiler;
std::unique_ptr<AutoTuner> tuner;

static frontend::Program CreateAddReluProgram();
frontend::Program CreateAddReluProgram() {
frontend::NetBuilder builder("test");

auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
auto b = builder.CreateInput(Float(32), {64}, "B");
auto c = builder.Add(a, b, 1);
auto d = builder.Relu(c);

return builder.Build();
}

void SetUp() override {
graph = std::make_shared<Graph>(CreateAddReluProgram(), target);
compiled_scope = BuildScope(target, graph);
graph_compiler = std::make_unique<GraphCompiler>(target, compiled_scope, graph);
tuner = std::make_unique<AutoTuner>(target, graph.get());
srand(0);
// AutoTuner is combined with new IR Schedule
FLAGS_cinn_ir_schedule = true;
graph = std::make_shared<Graph>(CreateAddReluProgram(), target);
compiled_scope = BuildScope(target, graph);
graph_compiler = std::make_unique<GraphCompiler>(target, compiled_scope, graph);
tuner = std::make_unique<AutoTuner>(target, graph.get());
}

TuningResult InitializeAndTune(const AutoTuner::Config& config, const TuningOptions& options) {
tuner->Initialize(config, graph_compiler.get());
return tuner->Tune(options);
}

void BasicCheckResult(const TuningResult& result) {
virtual void BasicCheckResult(const TuningResult& result) {
ASSERT_EQ(2, result.tuned_graph.size());
const auto& sub_graph1 = result.tuned_graph.front();
ASSERT_EQ(1, sub_graph1.groups.size());
ASSERT_EQ(sub_graph1.groups[0][0]->op()->name, "elementwise_add");
ASSERT_EQ(sub_graph1.groups[0]->CollectNodes()[0]->op()->name, "elementwise_add");
const auto& sub_graph2 = result.tuned_graph.back();
ASSERT_EQ(1, sub_graph2.groups.size());
ASSERT_EQ(sub_graph2.groups[0][0]->op()->name, "relu");
ASSERT_EQ(sub_graph2.groups[0]->CollectNodes()[0]->op()->name, "relu");

ASSERT_EQ(result.optimized_exprs.size(), 2UL);
ASSERT_EQ(result.optimized_exprs[0].lowered_funcs.size(), 1UL);
ASSERT_EQ(result.optimized_exprs[0].lowered_funcs[0].size(), 1UL);
}

void ApplyTunedAndRun(const TuningResult& result) {
virtual void ApplyTunedAndRun(const TuningResult& result) {
// build runtime program with tuning result
GraphCompiler::CompileOptions compile_options;
compile_options.with_instantiate_variables = true;
@@ -119,33 +138,84 @@ class TestAutoTuner : public ::testing::Test {
}
};

frontend::Program TestAutoTuner::CreateAddReluProgram() {
frontend::NetBuilder builder("test");
TEST_F(TestAutoTunerWithoutFusion, ZeroMeasure_DisableCostModel) {
FLAGS_auto_schedule_use_cost_model = false;
ZeroMeasure();
}

auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
auto b = builder.CreateInput(Float(32), {64}, "B");
auto c = builder.Add(a, b, 1);
auto d = builder.Relu(c);
TEST_F(TestAutoTunerWithoutFusion, ZeroMeasure_EnableCostModel) {
FLAGS_auto_schedule_use_cost_model = true;
ZeroMeasure();
}

TEST_F(TestAutoTunerWithoutFusion, NonZeroMeasure_DisableCostModel) {
FLAGS_auto_schedule_use_cost_model = false;
NonZeroMeasure();
}

return builder.Build();
TEST_F(TestAutoTunerWithoutFusion, NonZeroMeasure_EnableCostModel) {
FLAGS_auto_schedule_use_cost_model = true;
NonZeroMeasure();
}

TEST_F(TestAutoTuner, ZeroMeasure_DisableCostModel) {
class TestAutoTunerWithFusion : public TestAutoTunerWithoutFusion {
public:
void SetUp() override {
srand(0);
// AutoTuner is combined with new IR Schedule
FLAGS_cinn_ir_schedule = true;
graph = std::make_shared<Graph>(CreateAddReluProgram(), target);
ApplyPass(graph.get(), "OpFusionPass");
compiled_scope = BuildScope(target, graph);
graph_compiler = std::make_unique<GraphCompiler>(target, compiled_scope, graph);
tuner = std::make_unique<AutoTuner>(target, graph.get());
}

void BasicCheckResult(const TuningResult& result) override {
ASSERT_EQ(result.tuned_graph.size(), 1UL);
const std::vector<Node*>& nodes = result.tuned_graph[0].groups[0]->CollectNodes();
ASSERT_EQ(nodes.size(), 3UL);
ASSERT_EQ(nodes[0]->op()->name, "broadcast_to");
ASSERT_EQ(nodes[1]->op()->name, "elementwise_add");
ASSERT_EQ(nodes[2]->op()->name, "relu");

ASSERT_EQ(result.optimized_exprs.size(), 1UL);
ASSERT_EQ(result.optimized_exprs[0].lowered_funcs.size(), 1UL);
ASSERT_EQ(result.optimized_exprs[0].lowered_funcs[0].size(), 1UL);
}

void ApplyTunedAndRun(const TuningResult& result) override {
// build runtime program with tuning result
GraphCompiler::CompileOptions compile_options;
compile_options.with_instantiate_variables = true;
compile_options.Apply(result);
ASSERT_EQ(1, compile_options.groups.size());
ASSERT_EQ(1, compile_options.lowered_funcs.size());
ASSERT_EQ(1, compile_options.lowered_funcs[0].size());
VLOG(6) << "Print lowered_funcs before building";
VLOG(6) << compile_options.lowered_funcs[0][0];
auto runtime_program = graph_compiler->Build(compile_options).runtime_program;
ASSERT_EQ(1, runtime_program->size());
runtime_program->Execute();
}
};

TEST_F(TestAutoTunerWithFusion, ZeroMeasure_DisableCostModel) {
FLAGS_auto_schedule_use_cost_model = false;
ZeroMeasure();
}

TEST_F(TestAutoTuner, ZeroMeasure_EnableCostModel) {
TEST_F(TestAutoTunerWithFusion, ZeroMeasure_EnableCostModel) {
FLAGS_auto_schedule_use_cost_model = true;
ZeroMeasure();
}

TEST_F(TestAutoTuner, NonZeroMeasure_DisableCostModel) {
TEST_F(TestAutoTunerWithFusion, NonZeroMeasure_DisableCostModel) {
FLAGS_auto_schedule_use_cost_model = false;
NonZeroMeasure();
}

TEST_F(TestAutoTuner, NonZeroMeasure_EnableCostModel) {
TEST_F(TestAutoTunerWithFusion, NonZeroMeasure_EnableCostModel) {
FLAGS_auto_schedule_use_cost_model = true;
NonZeroMeasure();
}
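The restructuring above is the standard gtest fixture-inheritance idiom: TestAutoTunerWithFusion derives from TestAutoTunerWithoutFusion, overrides SetUp to run OpFusionPass, and overrides the virtual check/run hooks, so the four TEST_F bodies stay identical across both fixtures. A stripped-down, self-contained sketch of the pattern with hypothetical names:

    #include <gtest/gtest.h>

    class WithoutFusion : public ::testing::Test {
     public:
      void SetUp() override { groups = 2; }            // e.g. add and relu kept as separate groups
      virtual void CheckResult() { ASSERT_EQ(groups, 2); }
      int groups = 0;
    };

    class WithFusion : public WithoutFusion {
     public:
      void SetUp() override { groups = 1; }            // the fusion pass merges the groups
      void CheckResult() override { ASSERT_EQ(groups, 1); }
    };

    // Identical test bodies run against both fixtures:
    TEST_F(WithoutFusion, Basic) { CheckResult(); }
    TEST_F(WithFusion, Basic) { CheckResult(); }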
14 changes: 13 additions & 1 deletion cinn/auto_schedule/measure/measurer_test.cc
@@ -24,6 +24,9 @@
#include "cinn/frontend/net_builder.h"
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/runtime/flags.h"

DECLARE_bool(cinn_ir_schedule);

namespace cinn {
namespace auto_schedule {
@@ -51,6 +54,7 @@ class TestMeasurer : public ::testing::Test {
std::vector<MeasureInput> inputs;

void SetUp() override {
FLAGS_cinn_ir_schedule = true;
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
@@ -60,15 +64,23 @@ class TestMeasurer : public ::testing::Test {
auto scope = BuildScope(target, graph);
graph_compiler = std::make_unique<GraphCompiler>(target, scope, graph);
TaskCreator task_creator;
tasks = task_creator.CreateTuneTaskOpLevel(graph.get());
tasks = task_creator.CreateTuneTaskOpLevel(graph.get());
const auto& dtype_dict = graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype");
const auto& shape_dict = graph->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");

auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target);

inputs.reserve(tasks.size());
for (int i = 0; i < tasks.size(); ++i) {
auto* task = &tasks[i];
task->SetOpLowerer(op_lowerer.get());
task->TaskGraphToUnoptLoweredFunc();
MeasureInput input;
input.task = task;
// TODO(CtfGo): currently FusedGraphToLoweredFunc doesn't work well on NVGPU target,
// this setting of lowered_funcs will be enabled once we fix the bug
// input.lowered_funcs = graph_compiler->FusedGraphToLoweredFunc(task->task_graph);
input.lowered_funcs.push_back(task->lowered_funcs);
inputs.emplace_back(input);
}
}
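One general C++ caveat with this arrangement (a property of the pattern, not something this diff changes): SetOpLowerer hands each task a raw pointer, so the OpLowerer must outlive any task that still uses it. Here op_lowerer is a unique_ptr local to SetUp, which is safe only because the tasks finish lowering before SetUp returns; a sketch of a more defensive layout, assuming the same types as above:

    // Holding the lowerer as a fixture member ties its lifetime to the test,
    // so raw pointers stored in tasks would stay valid inside TEST_F bodies too.
    std::unique_ptr<hlir::framework::OpLowerer> op_lowerer_;

    void SetUp() override {
      // ... build graph, dtype_dict and shape_dict as above ...
      op_lowerer_ = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target);
      for (auto& task : tasks) task.SetOpLowerer(op_lowerer_.get());
    }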
cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.cc
@@ -31,7 +31,7 @@ int AutoGenRule::NumberApplicable() const {
return num_applicable_;
}

ir::ModuleExpr AutoGenRule::ApplyRandomly() {
ir::IRSchedule AutoGenRule::ApplyRandomly() {
CHECK_GT(num_applicable_, 0) << "Call " << GetRuleName() << "::ApplyRandomly() with NumberApplicable() == 0";
int index = rand() % num_applicable_;
return Apply(index);
6 changes: 3 additions & 3 deletions cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h
@@ -49,7 +49,7 @@ class AutoGenRule {

// Initialize the AutoGenRule, it must be called before further actions.
// Returns false if the rule cannot be applied on the mod_expr, true otherwise.
virtual RuleApplyType Init(const ir::ModuleExpr& mod_expr) = 0;
virtual RuleApplyType Init(const ir::IRSchedule& init_schedule) = 0;

// CINN IRSchedule can contain many ScheduleBlock(s) and Loop(s), so
an auto gen rule may be suitable to different number of
@@ -58,11 +58,11 @@
virtual int NumberApplicable() const;

// Applies rule on the ir::ModuleExpr for a schedule block randomly
virtual ir::ModuleExpr ApplyRandomly();
virtual ir::IRSchedule ApplyRandomly();

// Applies rule on the ir::ModuleExpr for a schedule block specified by index
// between 0 (inclusive) and NumberApplicable() (exclusive)
virtual ir::ModuleExpr Apply(int index) = 0;
virtual ir::IRSchedule Apply(int index) = 0;

// Returns the name of the rule, used for debug.
virtual std::string GetRuleName() const = 0;
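Taken together, the comments above imply the following call protocol for a rule after the ModuleExpr-to-IRSchedule change: Init first, then check applicability, then each Apply returns a fresh ir::IRSchedule rather than mutating the input. A hedged usage sketch, where rule is any AutoGenRule* and its construction is elided:

    if (rule->Init(schedule) == RuleApplyType::kApply) {
      int n = rule->NumberApplicable();               // schedule blocks the rule can target
      ir::IRSchedule picked = rule->Apply(n - 1);     // apply to one specific block...
      ir::IRSchedule random = rule->ApplyRandomly();  // ...or to one chosen at random
    }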
18 changes: 11 additions & 7 deletions cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc
@@ -29,6 +29,7 @@
#include "cinn/ir/ir_base.h"
#include "cinn/ir/ir_printer.h"
#include "cinn/ir/ir_schedule.h"
#include "cinn/optim/ir_copy.h"

namespace cinn {
namespace auto_schedule {
@@ -53,7 +54,9 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr) cons
return false;
}

if (no_inline_output_names_.find(tensor->name) != no_inline_output_names_.end()) {
// LoweredFunc output can be tensor name or tensor buffer name
if (no_inline_output_names_.find(tensor->name) != no_inline_output_names_.end() ||
no_inline_output_names_.find(tensor->buffer->name) != no_inline_output_names_.end()) {
return false;
}

@@ -119,8 +122,8 @@ AutoInlineType AutoInline::AnalyzeInlineType(const Expr& sche_block_realize_expr
return AutoInlineType::kCannotInline;
}

RuleApplyType AutoInline::Init(const ir::ModuleExpr& mod_expr) {
ir_schedule_ = std::make_unique<ir::IRSchedule>(mod_expr);
RuleApplyType AutoInline::Init(const ir::IRSchedule& init_schedule) {
ir_schedule_ = std::make_unique<ir::IRSchedule>(optim::IRCopy(init_schedule));
all_block_realizes_ = ir_schedule_->GetAllBlocks();
apply_indices_and_type_.clear();
num_applicable_ = 0;
@@ -138,12 +141,13 @@ RuleApplyType AutoInline::Init(const ir::ModuleExpr& mod_expr) {
return num_applicable_ > 0 ? RuleApplyType::kApply : RuleApplyType::kCannotApply;
}

ir::ModuleExpr AutoInline::Apply(int index) {
ir::IRSchedule AutoInline::Apply(int index) {
CHECK(ir_schedule_ != nullptr) << "Run AutoInline::Apply without Init";
CHECK(num_applicable_ > 0 && apply_indices_and_type_.size() == num_applicable_)
<< "AutoInline::Apply pre-condition doesn't meet";
CHECK(num_applicable_ > index)
<< "Invalid index for AutoInline::Apply, the index needs 0 <= index && index < NumberApplicable()";
CHECK(index >= 0 && num_applicable_ > index)
<< "Invalid index for AutoInline::Apply, the index needs 0 <= index && index < NumberApplicable(), "
<< "Currently index = " << index << ", NumberApplicable() = " << num_applicable_;

int apply_index = apply_indices_and_type_[index].first;
AutoInlineType type = apply_indices_and_type_[index].second;
@@ -171,7 +175,7 @@ ir::ModuleExpr AutoInline::Apply(int index) {
sche_block->write_buffers = {};
AnalyzeScheduleBlockReadWriteBuffer(sche_block);
}
return ir_schedule_->GetModule();
return optim::IRCopy(*ir_schedule_);
}

std::string AutoInline::GetRuleName() const { return "AutoInline"; }
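A note on the two optim::IRCopy calls introduced above: Init deep-copies the incoming schedule into the rule's private state, and Apply copies the result back out, so applying AutoInline never mutates the caller's IRSchedule and repeated applications yield independent candidates. In sketch form, assuming IRCopy is a deep copy (which is how this diff uses it):

    AutoInline rule(target, no_inline_output_names);
    if (rule.Init(schedule) == RuleApplyType::kApply) {
      ir::IRSchedule candidate = rule.Apply(0);  // 'schedule' itself is left untouched
    }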
4 changes: 2 additions & 2 deletions cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h
@@ -44,9 +44,9 @@ class AutoInline : public AutoGenRule {
AutoInline(const common::Target& target, const std::unordered_set<std::string>& no_inline_output_names);
~AutoInline() = default;

RuleApplyType Init(const ir::ModuleExpr& mod_expr) override;
RuleApplyType Init(const ir::IRSchedule& init_schedule) override;

ir::ModuleExpr Apply(int index) override;
ir::IRSchedule Apply(int index) override;

std::string GetRuleName() const override;

(The remaining changed files did not load in this capture.)
