combine initreduce with computeat (PaddlePaddle#323)
haozech authored Jan 11, 2021
1 parent 7fb48cc commit bad12a1
Showing 3 changed files with 7 additions and 2 deletions.
Empty file modified cinn/common/ir_util.cc
100644 → 100755
Empty file.
Empty file modified cinn/hlir/pe/schedule.h
100644 → 100755
Empty file.
9 changes: 7 additions & 2 deletions cinn/ir/tensor.cc
@@ -236,8 +236,13 @@ ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, const Target &target)
  if (target.arch == Target::Arch::NVGPU) {
    if (init_tensor->shape.size() > 1) {
      stages[init_tensor]->Split(1, 2);
      stages[init_tensor]->Bind(0, "blockIdx.x");
      stages[init_tensor]->Bind(1, "threadIdx.x");
    }
    stages[init_tensor]->ComputeAt2(stages[this], stages[init_tensor]->axis_names().size() - 1);
    auto temp = stages[this]->ctrl_depends();
    for (auto &i : temp) {
      if (i->name != init_tensor->name) {
        stages[init_tensor]->CtrlDepend(i);
      }
    }
  }
  stages[this]->CtrlDepend(init_tensor);
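
The dependency-propagation loop in this hunk can be illustrated in isolation. The following is a minimal sketch, not CINN code: `MockStage`, `InheritCtrlDepends`, and the use of plain strings for dependencies are hypothetical stand-ins for `poly::Stage` and its tensor-based `ctrl_depends()`, chosen only to show the rule "copy every control dependency of the reduce stage to the init stage, except the init stage itself".

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical stand-in for a CINN stage: it holds only a name and the names
// of the stages it must run after. This is NOT the real poly::Stage class.
struct MockStage {
  std::string name;
  std::set<std::string> ctrl_depends;

  // Record that this stage must run after `other`.
  void CtrlDepend(const std::string &other) {
    assert(other != name && "a stage cannot depend on itself");
    ctrl_depends.insert(other);
  }
};

// Copy every control dependency of `reduce` onto `init`, skipping `init`
// itself, which mirrors the name check in the loop from the diff.
void InheritCtrlDepends(const MockStage &reduce, MockStage &init) {
  // Iterate over a copy, as the diff does with `temp`.
  const std::set<std::string> temp = reduce.ctrl_depends;
  for (const std::string &dep : temp) {
    if (dep != init.name) {
      init.CtrlDepend(dep);
    }
  }
}
```

With a reduce stage that depends on `A`, `B`, and its own init stage, calling `InheritCtrlDepends` leaves the init stage depending on `A` and `B` only. The name check is what prevents the init stage from acquiring a dependency on itself.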