【PIR Dist Op Reg No.25】 reg distributed_fused_lamb_init #62050

Merged
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -69,6 +69,8 @@
{{"{name}", (PyCFunction)(void (*)(void)){name}, METH_VARARGS | METH_KEYWORDS, "C++ interface function for {name}."}},"""

NEED_GEN_STATIC_ONLY_APIS = [
'distributed_fused_lamb_init',
'distributed_fused_lamb_init_',
'fetch',
'fused_bias_dropout_residual_layer_norm',
'fused_embedding_eltwise_layernorm',
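For reference, substituting {name} in the template string above, the generated binding entries for the two newly listed APIs should look roughly like the following (a sketch of the expected generator output, not copied from the generated source):

{"distributed_fused_lamb_init", (PyCFunction)(void (*)(void))distributed_fused_lamb_init, METH_VARARGS | METH_KEYWORDS, "C++ interface function for distributed_fused_lamb_init."},
{"distributed_fused_lamb_init_", (PyCFunction)(void (*)(void))distributed_fused_lamb_init_, METH_VARARGS | METH_KEYWORDS, "C++ interface function for distributed_fused_lamb_init_."},

The trailing-underscore name is the inplace variant, which corresponds to the inplace mapping added in ops.yaml below.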
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -401,6 +401,16 @@
data_type : fpn_rois
optional : rois_num, multi_level_rois_num

- op : distributed_fused_lamb_init
args : (Tensor[] param, Tensor[] grad, float beta1, float beta2, int[] apply_weight_decay, int alignment, int rank, int nranks)
output : Tensor(fp32_fused_param), Tensor(fp32_fused_grad), Tensor(fp16_fused_param), Tensor(fp16_fused_grad), Tensor(moment1), Tensor(moment2), Tensor(beta1_pow), Tensor(beta2_pow), Tensor(fused_param_offsets), Tensor(fp32_shard_fused_param_offsets), Tensor(fp16_shard_fused_param_offsets), Tensor(param_info), Tensor(param_order), Tensor[](param_out){param.size()}, Tensor[](master_param_out){param.size()}, Tensor[](grad_out){grad.size()}, Tensor(global_scale), Tensor(step)
infer_meta :
func : DistributedFusedLambInitInferMeta
kernel :
func : distributed_fused_lamb_init
optional : fp32_fused_param, fp32_fused_grad, fp16_fused_param, fp16_fused_grad
inplace: (param -> param_out), (grad -> grad_out)

- op : distributed_lookup_table
args : (Tensor[] ids, Tensor w, int table_id = 0, bool is_distributed = false, str lookup_table_version = "lookup_table", int64_t padding_idx = -1, DataType dtype = DataType::FLOAT32, bool is_test = false)
output : Tensor[](outputs){ids.size()}
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -3611,6 +3611,12 @@
multi_level_rois_num: MultiLevelRoIsNum
restore_index: RestoreIndex

- op: distributed_fused_lamb_init
inputs:
{param: Param, grad: Grad}
outputs:
{fp32_fused_param: FP32FusedParam, fp32_fused_grad: FP32FusedGrad, fp16_fused_param: FP16FusedParam, fp16_fused_grad: FP16FusedGrad, moment1: Moment1, moment2: Moment2, beta1_pow: Beta1Pow, beta2_pow: Beta2Pow, fused_param_offsets: FusedParamOffsets, fp32_shard_fused_param_offsets: FP32ShardFusedParamOffsets, fp16_shard_fused_param_offsets: FP16ShardFusedParamOffsets, param_info: ParamInfo, param_order: ParamOrder, param_out: ParamOut, master_param_out: MasterParamOut, grad_out: GradOut, global_scale: GlobalScale, step: Step}

- op: distributed_lookup_table
inputs:
{ids: Ids, w: W}
54 changes: 54 additions & 0 deletions paddle/phi/infermeta/binary.cc
@@ -1201,6 +1201,60 @@ void DistributeFpnProposalsInferMeta(
}
}

void DistributedFusedLambInitInferMeta(
const std::vector<const MetaTensor*>& param,
const std::vector<const MetaTensor*>& grad,
float beta1,
float beta2,
const std::vector<int>& apply_weight_decay,
int alignment,
int rank,
int nranks,
MetaTensor* fp32_fused_param,
MetaTensor* fp32_fused_grad,
MetaTensor* fp16_fused_param,
MetaTensor* fp16_fused_grad,
MetaTensor* moment1,
MetaTensor* moment2,
MetaTensor* beta1_pow,
MetaTensor* beta2_pow,
MetaTensor* fused_param_offsets,
MetaTensor* fp32_shard_fused_param_offsets,
MetaTensor* fp16_shard_fused_param_offsets,
MetaTensor* param_info,
MetaTensor* param_order,
std::vector<MetaTensor*> param_out,
std::vector<MetaTensor*> master_param_out,
std::vector<MetaTensor*> grad_out,
MetaTensor* global_scale,
MetaTensor* step) {
fp32_fused_param->set_dtype(DataType::FLOAT32);
[Review comment — Contributor] Doesn't this op also need to set its dims information?

[Reply — Contributor] Under the old IR this op's InferShape was empty; the dtype settings here are newly added in this PR. Dims could only be guessed from the kernel implementation, so I'd rather keep the current state and extend InferMeta later if it actually becomes necessary.

fp32_fused_grad->set_dtype(DataType::FLOAT32);
fp16_fused_param->set_dtype(DataType::FLOAT16);
fp16_fused_grad->set_dtype(DataType::FLOAT16);
moment1->set_dtype(DataType::FLOAT32);
moment2->set_dtype(DataType::FLOAT32);
beta1_pow->set_dtype(DataType::FLOAT32);
beta2_pow->set_dtype(DataType::FLOAT32);
fused_param_offsets->set_dtype(DataType::INT32);
fp32_shard_fused_param_offsets->set_dtype(DataType::INT32);
fp16_shard_fused_param_offsets->set_dtype(DataType::INT32);
param_info->set_dtype(DataType::INT32);
param_order->set_dtype(DataType::INT32);

for (size_t i = 0; i < param.size(); ++i) {
param_out[i]->set_dtype(param[i]->dtype());
master_param_out[i]->set_dtype(DataType::FLOAT32);
}

for (size_t i = 0; i < grad.size(); ++i) {
grad_out[i]->set_dtype(grad[i]->dtype());
}

global_scale->set_dtype(DataType::FLOAT32);
step->set_dtype(DataType::INT64);
}
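As noted in the review thread above, this InferMeta intentionally sets only dtypes, mirroring the empty InferShape of the op under the old IR. If dims ever become necessary, a minimal extension could look like the sketch below; the shapes are illustrative guesses that would have to be verified against the kernel implementation, and depending on the branch make_ddim lives in phi:: or common::.

// Hypothetical follow-up, not part of this PR: dims would be derived from the
// kernel. GlobalScale and Step are 1-element tensors in the accompanying test,
// so these two are the least speculative; everything else needs checking.
global_scale->set_dims(phi::make_ddim({1}));
step->set_dims(phi::make_ddim({1}));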

void DropoutInferMeta(const MetaTensor& x,
const MetaTensor& seed_tensor,
const Scalar& p,
28 changes: 28 additions & 0 deletions paddle/phi/infermeta/binary.h
@@ -210,6 +210,34 @@
MetaTensor* restore_index,
MetaConfig config = MetaConfig());

void DistributedFusedLambInitInferMeta(
const std::vector<const MetaTensor*>& param,
const std::vector<const MetaTensor*>& grad,
float beta1,
float beta2,
const std::vector<int>& apply_weight_decay,
int alignment,
int rank,
int nranks,
MetaTensor* fp32_fused_param,
MetaTensor* fp32_fused_grad,
MetaTensor* fp16_fused_param,
MetaTensor* fp16_fused_grad,
MetaTensor* moment1,
MetaTensor* moment2,
MetaTensor* beta1_pow,
MetaTensor* beta2_pow,
MetaTensor* fused_param_offsets,
MetaTensor* fp32_shard_fused_param_offsets,
MetaTensor* fp16_shard_fused_param_offsets,
MetaTensor* param_info,
MetaTensor* param_order,
std::vector<MetaTensor*> param_out,
std::vector<MetaTensor*> master_param_out,
std::vector<MetaTensor*> grad_out,
MetaTensor* global_scale,
MetaTensor* step);

void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);

void DropoutInferMeta(const MetaTensor& x,
1 change: 1 addition & 0 deletions test/ir/pir/translator/CMakeLists.txt
@@ -9,6 +9,7 @@ list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_allreduce_min_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_allreduce_prod_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST
test_distributed_lookup_table_translate)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_distributed_fused_lamb_init)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_partial_send_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_reduce_max_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_reduce_prod_translator)
152 changes: 152 additions & 0 deletions test/ir/pir/translator/test_distributed_fused_lamb_init.py
@@ -0,0 +1,152 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import test_op_translator

import paddle
from paddle.base import unique_name
from paddle.base.layer_helper import LayerHelper


class TestDistributedFusedLambInitOpTranslator(
test_op_translator.TestOpTranslator
):
def _create_persistable_var(self, name=None, shape=[-1], dtype='float32'):
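# Create the variable in both the startup and the main program, similar to
# how the old-IR DistributedFusedLamb optimizer sets up its persistable state.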
startup_block = self.helper.startup_program.global_block()
if name is not None:
name = unique_name.generate(name)
startup_var = startup_block.create_var(
name=name,
shape=shape,
dtype=dtype,
persistable=True,
stop_gradient=True,
)
main_block = self.helper.main_program.global_block()
main_var = main_block.create_var(
name=startup_var.name,
shape=startup_var.shape,
dtype=startup_var.dtype,
persistable=True,
stop_gradient=True,
)
return main_var

def _create_scale_from_constant(self):
name = unique_name.generate('global_scale')
return paddle.static.create_global_var(
name=name,
shape=[1],
dtype='float32',
value=1.0,
persistable=True,
)

def append_op(self):
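# Build the op with the legacy (old-IR) input/output names declared in
# op_compat.yaml, so that check() can exercise the ProgramDesc-to-PIR
# translation of distributed_fused_lamb_init.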
self.op_type = "distributed_fused_lamb_init"
self.helper = LayerHelper('distributed_fused_lamb')
rank = paddle.distributed.get_rank()
nranks = paddle.distributed.get_world_size()
local_rank = rank % nranks
params = [paddle.ones(shape=(1, 1), dtype='float32')]
grads = [paddle.ones(shape=(1, 1), dtype='float32')]
apply_weight_decay = [1] * len(params)

fp32_fused_param = self._create_persistable_var('fp32_fused_param')
fp32_fused_grad = self._create_persistable_var('fp32_fused_grad')
fp16_fused_param = self._create_persistable_var(
'fp16_fused_param', dtype='float16'
)
fp16_fused_grad = self._create_persistable_var(
'fp16_fused_grad', dtype='float16'
)
moment1 = self._create_persistable_var('moment1')
moment1.is_distributed = True
moment2 = self._create_persistable_var('moment2')
moment2.is_distributed = True
beta1pow = self._create_persistable_var('beta1pow')
beta2pow = self._create_persistable_var('beta2pow')
param_info = self._create_persistable_var('param_info', dtype='int32')
param_info.is_distributed = True

fused_offsets = self._create_persistable_var(
'fused_offsets', dtype='int32'
)

fp32_partial_fused_offsets = self._create_persistable_var(
'fp32_partial_fused_offsets', dtype='int32'
)
fp32_partial_fused_offsets.is_distributed = True

fp16_partial_fused_offsets = self._create_persistable_var(
'fp16_partial_fused_offsets', dtype='int32'
)
fp16_partial_fused_offsets.is_distributed = True

param_order = self._create_persistable_var('param_order', dtype='int32')
param_order.is_distributed = True

scale = self._create_scale_from_constant()
step = self._create_persistable_var('step', dtype='int64')

master_params = []
for p in params:
master_p = self._create_persistable_var('master_weight')
master_params.append(master_p)

attrs = {
'alignment': 128,
'rank': local_rank,
'nranks': nranks,
'apply_weight_decay': apply_weight_decay,
'moment1': 0.0,
'moment2': 0.0,
'beta1': 0.9,
'beta2': 0.999,
}
helper = LayerHelper(self.op_type)
helper.append_op(
type=self.op_type,
inputs={"Param": params, "Grad": grads},
outputs={
'FP32FusedParam': [fp32_fused_param],
'FP32FusedGrad': [fp32_fused_grad],
'FP16FusedParam': [fp16_fused_param],
'FP16FusedGrad': [fp16_fused_grad],
'Moment1': [moment1],
'Moment2': [moment2],
'Beta1Pow': [beta1pow],
'Beta2Pow': [beta2pow],
'GlobalScale': [scale],
'ParamInfo': [param_info],
'ParamOut': params,
'MasterParamOut': master_params,
'GradOut': grads,
'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
'FusedParamOffsets': [fused_offsets],
'ParamOrder': [param_order],
'Step': [step],
},
attrs=attrs,
)

def test_translator(self):
self.check()


if __name__ == "__main__":
unittest.main()