Skip to content

Commit

Permalink
[CINN][New Hardware Update] standardize CINN_WITH_CUDA (#64506)
Browse files Browse the repository at this point in the history
  • Loading branch information
DongBaiYue authored May 23, 2024
1 parent 7b73a73 commit 8a3dc8b
Show file tree
Hide file tree
Showing 13 changed files with 279 additions and 182 deletions.
144 changes: 73 additions & 71 deletions paddle/cinn/common/cas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1072,77 +1072,79 @@ bool CasSimplifyMutator::SimplifySpecificSumMod(Expr* result, Expr a, Expr b) {
}
}
}
#ifdef CINN_WITH_CUDA
return false;
#else

int const_value = 0;
Expr lower_bound;
Expr upper_bound;
Expr rest_oper;
bool can_simplify = true;
bool has_int = false;
// fold only the expr bound(may contains the var) and try to simplify the var
Expr unfolded_lower_bound, unfolded_upper_bound;
for (Expr& v : a_sum->operands()) {
auto* v_int = v.As<IntImm>();
if (v_int) {
const_value += v_int->value;
has_int = true;
} else if (GetVarBound(&lower_bound, &upper_bound, v, false)) {
AddBaseAndSimplify(&rest_oper, v);
} else {
can_simplify = false;
break;
}
}
can_simplify = can_simplify && has_int &&
std::abs(const_value) % b_i->value == b_i->value - 1 &&
lower_bound.defined() && upper_bound.defined() &&
rest_oper.defined();
// further infer the vars' bound by the intervals infos, try to get the
// constant
if (can_simplify) {
std::vector<Expr> bounds = {lower_bound, upper_bound};
for (int i = 0; i < bounds.size(); ++i) {
Expr bound = bounds[i];
Expr bound_l, bound_r;
GetExprBound(&bound_l, &bound_r, bound);
if (i == 0 && bound_l.defined()) {
lower_bound = bound_l;
}
if (i == 1 && bound_r.defined()) {
upper_bound = bound_r;
}
}
} else {
return false;
}
// case1: (32+(-x))%33 = 32-x%33 (0<=x<=32)
// case2: (x-32)%33 = x%33 - 32%33 (0<=x<=32)
can_simplify = can_simplify && lower_bound.is_constant();
bool case1 = can_simplify && const_value >= 0 &&
lower_bound.get_constant() >= -const_value &&
upper_bound.is_constant() && upper_bound.get_constant() <= 0;
bool case2 = can_simplify && const_value <= 0 &&
lower_bound.get_constant() >= 0 && upper_bound.is_constant() &&
upper_bound.get_constant() <= -const_value;
can_simplify = can_simplify && (case1 || case2);
if (can_simplify) {
Expr const_expr;
if (const_value < 0) {
const_expr = make_const(b->type(), const_value % b_i->value);
} else {
const_expr = make_const(b->type(), const_value % b_i->value);
}
*result = CasSimplify(
Sum::Make(
{const_expr, CasSimplify(Mod::Make(rest_oper, b), var_intervals)}),
var_intervals);
return true;
}
return false;
#endif
return cinn::common::DefaultDeviceTarget().arch.Match(
[&](common::NVGPUArch) { return false; },
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
int const_value = 0;
Expr lower_bound;
Expr upper_bound;
Expr rest_oper;
bool can_simplify = true;
bool has_int = false;
// fold only the expr bound(may contains the var) and try to simplify
// the var
Expr unfolded_lower_bound, unfolded_upper_bound;
for (Expr& v : a_sum->operands()) {
auto* v_int = v.As<IntImm>();
if (v_int) {
const_value += v_int->value;
has_int = true;
} else if (GetVarBound(&lower_bound, &upper_bound, v, false)) {
AddBaseAndSimplify(&rest_oper, v);
} else {
can_simplify = false;
break;
}
}
can_simplify = can_simplify && has_int &&
std::abs(const_value) % b_i->value == b_i->value - 1 &&
lower_bound.defined() && upper_bound.defined() &&
rest_oper.defined();
// further infer the vars' bound by the intervals infos, try to get the
// constant
if (can_simplify) {
std::vector<Expr> bounds = {lower_bound, upper_bound};
for (int i = 0; i < bounds.size(); ++i) {
Expr bound = bounds[i];
Expr bound_l, bound_r;
GetExprBound(&bound_l, &bound_r, bound);
if (i == 0 && bound_l.defined()) {
lower_bound = bound_l;
}
if (i == 1 && bound_r.defined()) {
upper_bound = bound_r;
}
}
} else {
return false;
}
// case1: (32+(-x))%33 = 32-x%33 (0<=x<=32)
// case2: (x-32)%33 = x%33 - 32%33 (0<=x<=32)
can_simplify = can_simplify && lower_bound.is_constant();
bool case1 = can_simplify && const_value >= 0 &&
lower_bound.get_constant() >= -const_value &&
upper_bound.is_constant() &&
upper_bound.get_constant() <= 0;
bool case2 = can_simplify && const_value <= 0 &&
lower_bound.get_constant() >= 0 &&
upper_bound.is_constant() &&
upper_bound.get_constant() <= -const_value;
can_simplify = can_simplify && (case1 || case2);
if (can_simplify) {
Expr const_expr;
if (const_value < 0) {
const_expr = make_const(b->type(), const_value % b_i->value);
} else {
const_expr = make_const(b->type(), const_value % b_i->value);
}
*result = CasSimplify(
Sum::Make({const_expr,
CasSimplify(Mod::Make(rest_oper, b), var_intervals)}),
var_intervals);
return true;
}
return false;
});
}

// Return if the var's interval is nonnegative.
Expand Down
8 changes: 7 additions & 1 deletion paddle/cinn/hlir/framework/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,15 @@ void Graph::VisualizeGroupedGraph(
for (int idx = 0; idx < groups.size(); ++idx) {
// Create fusion_group_x folder
int device_id = 0;
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
cudaGetDevice(&device_id);
cudaGetDevice(&device_id);
#endif
});
auto group_path =
utils::StringFormat("%s/device_%d/fusion_group_%d",
FLAGS_cinn_fusion_groups_graphviz_dir.c_str(),
Expand Down
8 changes: 7 additions & 1 deletion paddle/cinn/hlir/framework/graph_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,15 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(
CHECK_EQ(funcs_after_schedule.size(), expr_pack.size());
std::vector<ir::LoweredFunc> res;
for (int i = 0; i < funcs_after_schedule.size(); i++) {
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
optim::OptimizeExprGPU(&(funcs_after_schedule[i]->body));
optim::OptimizeExprGPU(&(funcs_after_schedule[i]->body));
#endif
});
auto temp_buffers = lang::GetTempBuffers(
all_arg_tensors, tensor_group, funcs_after_schedule[i]->body);

Expand Down
7 changes: 6 additions & 1 deletion paddle/cinn/hlir/framework/instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,9 +438,14 @@ std::string Instruction::DumpInstruction() const {

void Instruction::CheckResults(
const std::map<std::string, cinn_pod_value_t>* name2podargs, void* stream) {
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
cudaStreamSynchronize(static_cast<cudaStream_t>(stream));
cudaStreamSynchronize(static_cast<cudaStream_t>(stream));
#endif
});

if (fn_names_.size() == 1) {
std::unordered_set<std::string> skipped_instr_set = {
Expand Down
8 changes: 7 additions & 1 deletion paddle/cinn/hlir/framework/instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,15 @@ class Instruction {
((lower_func_ptr_t)fn_ptrs_[idx])(static_cast<void*>(pod_args.data()),
pod_args.size());
}
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaDeviceSynchronize());
CUDA_CALL(cudaDeviceSynchronize());
#endif
});
}
}
if (flag >= 0) {
Expand Down
11 changes: 8 additions & 3 deletions paddle/cinn/hlir/framework/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,16 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
}

auto func_body = ir_sch->GetModule().GetExprs().at(0);
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
if (apply_pass) {
optim::OptimizeExprGPU(&(func_body));
}
if (apply_pass) {
optim::OptimizeExprGPU(&(func_body));
}
#endif
});
// 2.Prepare temp buffers
poly::StageMap stages;
auto temp_buffers =
Expand Down
14 changes: 12 additions & 2 deletions paddle/cinn/hlir/framework/parallel_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,14 @@ void ParallelCompiler::SplitTask() {
context_->graph->fusion_groups.size() ==
context_->lowered_funcs.size());
int device_id = 0;
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaGetDevice(&device_id));
CUDA_CALL(cudaGetDevice(&device_id));
#endif
});
for (int group_id = 0; group_id < context_->graph->fusion_groups.size();
++group_id) {
tasks_.emplace_back(device_id, group_id, this, context_);
Expand Down Expand Up @@ -132,9 +137,14 @@ void ParallelCompiler::RunTask() {

void ParallelCompiler::LaunchTask() {
int device_id = 0;
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaGetDevice(&device_id));
CUDA_CALL(cudaGetDevice(&device_id));
#endif
});
int num_threads = FLAGS_cinn_parallel_compile_thread;
#if defined(PADDLE_WITH_DISTRIBUTE)
if (device_id > 0) {
Expand Down
10 changes: 8 additions & 2 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -828,10 +828,16 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
std::vector<ir::LoweredFunc> lowered_funcs;
for (ir::Expr func_body : func_bodies) {
optim::EliminateDeadScheduleBlock(&(func_body), group->output_names());
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
optim::EliminateCommonGlobalMemoryRead(&(func_body));
optim::OptimizeExprGPU(&(func_body));
optim::EliminateCommonGlobalMemoryRead(&(func_body));
optim::OptimizeExprGPU(&(func_body));
#endif
});

// 2.Prepare temp buffers
auto temp_buffers =
Expand Down
23 changes: 19 additions & 4 deletions paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,22 +114,37 @@ void StaticShapeGroupScheduler::Schedule() {
&StaticShapeGroupScheduler::IsKeepGraphDependency);
DoLoopAlignment();
DoComputeInline();
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
OptimizeReduction();
OptimizeReduction();
#endif
});
DoHorizontalLoopFusion();
DoVerticalLoopFusion();
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
BindCudaAxis();
AllocateStorage();
BindCudaAxis();
AllocateStorage();
#endif
});
}

// Schedule entry point used for the map-expr path.
//
// Runs DoComputeInline() unconditionally, then dispatches on the arch of the
// default device target:
//   * UnknownArch / X86Arch / ARMArch: nothing further is done.
//   * NVGPUArch: AllocateStorage() is invoked, but only in builds that define
//     CINN_WITH_CUDA; without that macro the branch is an empty lambda.
void StaticShapeGroupScheduler::MapExprSchedule() {
  DoComputeInline();
  cinn::common::DefaultDeviceTarget().arch.Match(
      // Non-GPU targets need no extra step after inlining.
      [&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
      },
      [&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
        // Invoked exactly once per call of MapExprSchedule().
        AllocateStorage();
#endif
      });
}

std::vector<std::pair<SymbolicPredicate, ir::Expr>>
Expand Down
Loading

0 comments on commit 8a3dc8b

Please sign in to comment.