
【PIR Dist Op Reg No.4 and No.26】 reg global_scatter and limit_by_capacity #62579

Merged (7 commits into PaddlePaddle:develop, Mar 20, 2024)

Conversation

xiaoyewww (Contributor)

PR types

Others

PR changes

Others

Description

#60436
Register the operators global_scatter and limit_by_capacity.

paddle-bot bot commented Mar 9, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@xingmingyyj (Contributor) left a comment:

- op: limit_by_capacity
  inputs:
    {expert_count : expert_count, capacity : capacity}
  outputs :
    out : Out

The indentation here needs to be adjusted.
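A correctly aligned version of this op_compat.yaml entry might look like the following sketch, using the two-space block-mapping indentation that the surrounding entries in the file use:

```yaml
# Sketch: keys of the op entry indented two spaces under the list item,
# with the mapping values indented a further two spaces.
- op : limit_by_capacity
  inputs :
    {expert_count : expert_count, capacity : capacity}
  outputs :
    out : Out
```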

@@ -875,6 +884,15 @@
inplace: (x -> out)
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : limit_by_capacity
args : (Tensor expert_count, Tensor capacity, int n_worker)
output : Tensor(Out)

Suggested change
output : Tensor(Out)
output : Tensor(out)

paddle/fluid/pir/dialect/operator/ir/ops.yaml (resolved)
@@ -1604,6 +1604,12 @@
attrs :
{pre_nms_top_n : pre_nms_topN, post_nms_top_n : post_nms_topN}

- op : global_scatter(global_scatter)

Suggested change
- op : global_scatter(global_scatter)
- op : global_scatter

@xiaoyewww (Contributor, Author):

@xingmingyyj Could you tell me where the problem is here? Is something in the yaml filled in incorrectly?


xingmingyyj commented Mar 13, 2024

> @xingmingyyj Could you tell me where the problem is here? Is something in the yaml filled in incorrectly?

You can check whether the inputs in the limit_by_capacity yaml are misaligned; the indentation looks off.

@@ -3730,6 +3736,12 @@
outputs :
{param_out: ParamOut, velocity_out: VelocityOut, master_param_out: MasterParamOut}

- op: limit_by_capacity
inputs:
{expert_count : expert_count, capacity : capacity}

Suggested change
{expert_count : expert_count, capacity : capacity}

If the names are exactly the same, no mapping is needed here.

@@ -1604,6 +1604,12 @@
attrs :
{pre_nms_top_n : pre_nms_topN, post_nms_top_n : post_nms_topN}

- op : global_scatter
inputs :
{x : X, local_count : local_count, global_count : global_count, ring_id : ring_id, use_calc_stream : use_calc_stream}

Suggested change
{x : X, local_count : local_count, global_count : global_count, ring_id : ring_id, use_calc_stream : use_calc_stream}
{x : X}

If the names are exactly the same, no mapping is needed here.

@@ -1019,6 +1019,15 @@
func : gelu_grad
composite: gelu_grad(x, out_grad, approximate, x_grad)

- backward_op : global_scatter_grad
forward : (Tensor x, Tensor local_count, Tensor global_count, int ring_id=0, bool use_calc_stream=false) -> Tensor(out)

Suggested change
forward : (Tensor x, Tensor local_count, Tensor global_count, int ring_id=0, bool use_calc_stream=false) -> Tensor(out)
forward : global_scatter (Tensor x, Tensor local_count, Tensor global_count, int ring_id=0, bool use_calc_stream=false) -> Tensor(out)

This is likely the cause of the current compilation error: the matching forward op for this backward op could not be found.
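In the PHI yaml convention, the `forward` field of a backward op must begin with the forward op's name so the code generator can pair the two definitions. A minimal sketch of the corrected entry (only the header lines are shown; the remaining fields are omitted):

```yaml
# Sketch: the forward op's name prefixes its signature, letting the
# generator locate the matching forward definition for this backward op.
- backward_op : global_scatter_grad
  forward : global_scatter (Tensor x, Tensor local_count, Tensor global_count, int ring_id=0, bool use_calc_stream=false) -> Tensor(out)
```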

@xiaoyewww (Contributor, Author):

@xingmingyyj @kangguangli Could you please review this again to check whether there are any remaining issues?

func : GlobalScatterInferMeta
kernel :
func : global_scatter
data_type : dtype

Suggested change
data_type : dtype
data_type : x

The data_type here should align with that of the old IR:

  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                          ctx.GetPlace());
  }
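Under the old IR, `IndicateVarDataType(ctx, "X")` derives the kernel key's dtype from the input variable `X`; in the new-IR yaml the same choice is expressed by pointing `data_type` at that input. A sketch of the kernel block, assuming this correspondence:

```yaml
# Sketch: data_type names the input tensor whose dtype selects the kernel,
# mirroring IndicateVarDataType(ctx, "X") in the old-IR GetExpectedKernelType.
  kernel :
    func : global_scatter
    data_type : x
```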

@xiaoyewww (Contributor, Author) replied:

Fixed, thanks.

@kangguangli kangguangli merged commit 4024e45 into PaddlePaddle:develop Mar 20, 2024
30 checks passed
@xiaoyewww xiaoyewww deleted the pir-pr branch May 10, 2024 15:10
Labels
contributor (External developers), HappyOpenSource (Happy Open Source activity issues and PRs)
5 participants