【PIR Dist Op Reg No.4 and No.26】 reg global_scatter and limit_by_capacity #62579
Conversation
Your PR was submitted successfully. Thank you for contributing to the open source project!
- op: limit_by_capacity
inputs:
{expert_count : expert_count, capacity : capacity}
outputs :
out : Out
The indentation here needs to be adjusted.
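For reference, a minimal sketch of what the aligned entry might look like, following the two-space indentation used by neighboring op_compat.yaml entries (the exact final form is an assumption based on the snippet quoted above, not the merged version):

```yaml
- op : limit_by_capacity
  inputs :
    {expert_count : expert_count, capacity : capacity}
  outputs :
    out : Out
```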
@@ -875,6 +884,15 @@
    inplace: (x -> out)
    interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : limit_by_capacity
  args : (Tensor expert_count, Tensor capacity, int n_worker)
  output : Tensor(Out)
Suggested change: `output : Tensor(Out)` → `output : Tensor(out)`
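Putting this together with the args shown in the hunk, the registered entry would plausibly read as below; the lowercase `out` is what the suggestion asks for, matching the naming style of neighboring entries in the new yaml (a sketch, not the merged version):

```yaml
- op : limit_by_capacity
  args : (Tensor expert_count, Tensor capacity, int n_worker)
  output : Tensor(out)
```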
paddle/phi/api/yaml/op_compat.yaml (Outdated)
@@ -1604,6 +1604,12 @@
  attrs :
    {pre_nms_top_n : pre_nms_topN, post_nms_top_n : post_nms_topN}

- op : global_scatter(global_scatter)
Suggested change: `- op : global_scatter(global_scatter)` → `- op : global_scatter`
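For background: in op_compat.yaml, the parenthesized form maps a new op name to a differently named legacy op, so it is only needed when the two names diverge. A hedged illustration (the `add (elementwise_add)` entry is recalled from memory and may differ in detail):

```yaml
# renamed op: new name first, legacy name in parentheses
- op : add (elementwise_add)

# same name under both IRs: no parentheses needed
- op : global_scatter
```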
@xingmingyyj Could you point out what the problem is here? Is something in the yaml filled in incorrectly?
Take a look at the inputs in the limit_by_capacity yaml; they don't seem to be aligned.
paddle/phi/api/yaml/op_compat.yaml (Outdated)
@@ -3730,6 +3736,12 @@
  outputs :
    {param_out: ParamOut, velocity_out: VelocityOut, master_param_out: MasterParamOut}

- op: limit_by_capacity
  inputs:
    {expert_count : expert_count, capacity : capacity}
Suggested change: remove the line `{expert_count : expert_count, capacity : capacity}`.
If the names are exactly the same, there is no need to handle them here.
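In other words, a compat entry only has to map names that differ between the legacy op and the new definition, as the momentum outputs in the hunk above do (`param_out` vs. `ParamOut`). Under that convention, the limit_by_capacity entry would keep only the output mapping, since `out` and `Out` do differ (a sketch assuming this convention):

```yaml
# expert_count and capacity are spelled identically under both IRs,
# so no inputs block is needed; only the output name differs
- op : limit_by_capacity
  outputs :
    out : Out
```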
paddle/phi/api/yaml/op_compat.yaml (Outdated)
@@ -1604,6 +1604,12 @@
  attrs :
    {pre_nms_top_n : pre_nms_topN, post_nms_top_n : post_nms_topN}

- op : global_scatter
  inputs :
    {x : X, local_count : local_count, global_count : global_count, ring_id : ring_id, use_calc_stream : use_calc_stream}
Suggested change:
- {x : X, local_count : local_count, global_count : global_count, ring_id : ring_id, use_calc_stream : use_calc_stream}
+ {x : X}
If the names are exactly the same, there is no need to handle them here.
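Applying the same rule to global_scatter, only `x`/`X` differ in spelling, so the compat entry would plausibly shrink to the sketch below (attrs such as ring_id and use_calc_stream keep identical names under both IRs and need no mapping; this is an assumption about the final form, not the merged code):

```yaml
- op : global_scatter
  inputs :
    {x : X}
```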
paddle/phi/api/yaml/backward.yaml (Outdated)
@@ -1019,6 +1019,15 @@
    func : gelu_grad
    composite: gelu_grad(x, out_grad, approximate, x_grad)

- backward_op : global_scatter_grad
  forward : (Tensor x, Tensor local_count, Tensor global_count, int ring_id=0, bool use_calc_stream=false) -> Tensor(out)
Suggested change:
- forward : (Tensor x, Tensor local_count, Tensor global_count, int ring_id=0, bool use_calc_stream=false) -> Tensor(out)
+ forward : global_scatter (Tensor x, Tensor local_count, Tensor global_count, int ring_id=0, bool use_calc_stream=false) -> Tensor(out)
This is most likely the cause of the current compile error: without the leading op name, the forward op corresponding to this backward op cannot be found.
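With the forward op name added, the backward entry would start like the sketch below; the remaining fields are omitted here because they follow the grad definition in the PR itself:

```yaml
- backward_op : global_scatter_grad
  forward : global_scatter (Tensor x, Tensor local_count, Tensor global_count, int ring_id=0, bool use_calc_stream=false) -> Tensor(out)
  # args / output / infer_meta / kernel as defined in the PR
```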
@xingmingyyj @kangguangli Could you please review this again and check whether there are any remaining problems?
    func : GlobalScatterInferMeta
  kernel :
    func : global_scatter
    data_type : dtype
Suggested change: `data_type : dtype` → `data_type : x`

The `data_type` here should align with what the op declares under the old IR:

phi::KernelKey GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                        ctx.GetPlace());
}
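Translated to the new yaml, the kernel block would then pick its dtype from the `x` input, mirroring `IndicateVarDataType(ctx, "X")` above (a sketch):

```yaml
  kernel :
    func : global_scatter
    data_type : x
```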
Fixed, thanks.
PR types
Others
PR changes
Others
Description
#60436
Register the operators global_scatter and limit_by_capacity.