diff --git a/cinn/common/ir_util.cc b/cinn/common/ir_util.cc old mode 100644 new mode 100755 diff --git a/cinn/hlir/pe/schedule.h b/cinn/hlir/pe/schedule.h old mode 100644 new mode 100755 diff --git a/cinn/ir/tensor.cc b/cinn/ir/tensor.cc index 300e27477cb30..c9317edff375b 100755 --- a/cinn/ir/tensor.cc +++ b/cinn/ir/tensor.cc @@ -236,8 +236,13 @@ ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, const Target &target) if (target.arch == Target::Arch::NVGPU) { if (init_tensor->shape.size() > 1) { stages[init_tensor]->Split(1, 2); - stages[init_tensor]->Bind(0, "blockIdx.x"); - stages[init_tensor]->Bind(1, "threadIdx.x"); + } + stages[init_tensor]->ComputeAt2(stages[this], stages[init_tensor]->axis_names().size() - 1); + auto temp = stages[this]->ctrl_depends(); + for (auto &i : temp) { + if (i->name != init_tensor->name) { + stages[init_tensor]->CtrlDepend(i); + } } } stages[this]->CtrlDepend(init_tensor);