Skip to content

Commit

Permalink
【PIR Dist Op Reg No.4 and No.26】 reg global_scatter and limit_by_capacity (#62579)
Browse files Browse the repository at this point in the history

* feat(pir): reg global_scatter and limit_by_capacity

* feat(pir): reg global_scatter and limit_by_capacity

* feat(pir): reg global_scatter and limit_by_capacity

* feat(pir): reg global_scatter and limit_by_capacity

* feat(pir): reg global_scatter and limit_by_capacity

* feat(pir): reg global_scatter and limit_by_capacity

* feat(pir): reg global_scatter and limit_by_capacity
  • Loading branch information
xiaoyewww authored Mar 20, 2024
1 parent 7def47f commit 4024e45
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 1 deletion.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/limit_by_capacity_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class LimitByCapacityOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("capacity", "(Tensor) The input capacity.");
AddOutput("Out",
"(Tensor) The output tensor expert count limit by capacity.");
AddAttr<int>("n_worker", "int), The number of works.");
AddAttr<int>("n_worker", "(int), The number of works.");
AddComment(
R"DOC(limit_by_capacity Operator.limit expert count by capacity.)DOC");
}
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@
'partial_allgather_',
'nop',
'nop_',
'limit_by_capacity',
'global_scatter',
]


Expand Down
18 changes: 18 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,15 @@
kernel:
func: get_tensor_from_selected_rows {selected_rows -> dense}

- op : global_scatter
args : (Tensor x, Tensor local_count, Tensor global_count, int ring_id=0, bool use_calc_stream=false)
output : Tensor(out)
infer_meta :
func : GlobalScatterInferMeta
kernel :
func : global_scatter
data_type : x

- op : greater_equal
args : (Tensor x, Tensor y)
output : Tensor(out)
Expand Down Expand Up @@ -919,6 +928,15 @@
inplace: (x -> out)
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : limit_by_capacity
args : (Tensor expert_count, Tensor capacity, int n_worker)
output : Tensor(out)
infer_meta :
func : LimitByCapacityInferMeta
kernel :
func : limit_by_capacity
data_type : expert_count

- op : linspace
args : (Tensor start, Tensor stop, Tensor number, DataType dtype, Place place)
output : Tensor(out)
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1614,6 +1614,12 @@
attrs :
{pre_nms_top_n : pre_nms_topN, post_nms_top_n : post_nms_topN}

- op : global_scatter
inputs :
{x : X}
outputs :
out : Out

- op : grad_add
inputs :
{x : X, y : Y}
Expand Down Expand Up @@ -3769,6 +3775,10 @@
outputs :
{param_out: ParamOut, velocity_out: VelocityOut, master_param_out: MasterParamOut}

- op: limit_by_capacity
outputs :
out : Out

- op: lod_array_length
inputs :
{x: X}
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2167,6 +2167,15 @@ void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
out->set_dtype(x.dtype());
}

// Infers output meta for limit_by_capacity: the result is element-wise over
// expert_count, so its dims, LoD, and dtype all mirror expert_count.
// capacity and n_worker influence values only, not the output metadata.
void LimitByCapacityInferMeta(const MetaTensor& expert_count,
                              const MetaTensor& capacity,
                              int n_worker,
                              MetaTensor* out) {
  // Propagate expert_count's metadata wholesale; the setters are independent.
  out->set_dtype(expert_count.dtype());
  out->share_dims(expert_count);
  out->share_lod(expert_count);
}

void LogLossInferMeta(const MetaTensor& input,
const MetaTensor& label,
float epsilon,
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,11 @@ void IndexAddInferMeta(const MetaTensor& x,

void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);

// Infers meta for the limit_by_capacity op; out mirrors expert_count
// (dims/LoD/dtype). See the definition in binary.cc.
void LimitByCapacityInferMeta(const MetaTensor& expert_count,
                              const MetaTensor& capacity,
                              int n_worker,
                              MetaTensor* out);

void LogicalBinaryInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out);
Expand Down
27 changes: 27 additions & 0 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,33 @@ void InstanceNormInferMeta(const MetaTensor& x,
}
}

void GlobalScatterInferMeta(const MetaTensor& x,
const MetaTensor& local_count,
const MetaTensor& global_count,
int ring_id,
bool use_calc_stream,
MetaTensor* out) {
PADDLE_ENFORCE_GE(
ring_id,
0,
phi::errors::InvalidArgument(
"The ring_id (%d) for global scatter op must be non-negative.",
ring_id));
auto input_dims = x.dims();
auto ndim_input = input_dims.size();
// dim check
PADDLE_ENFORCE_EQ(
ndim_input,
2,
phi::errors::InvalidArgument("The input tensor's dimension must be 2. "
"But received input's dimension = %d.",
ndim_input));

phi::DDim out_dims = common::make_ddim({-1, -1});
out->set_dims(out_dims);
out->set_dtype(x.dtype());
}

void GroupNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ void InstanceNormInferMeta(const MetaTensor& x,
MetaTensor* saved_variance,
MetaConfig config = MetaConfig());

// Infers meta for the global_scatter op; output dims are dynamic ({-1, -1})
// since they depend on runtime counts, dtype follows x. Defined in ternary.cc.
void GlobalScatterInferMeta(const MetaTensor& x,
                            const MetaTensor& local_count,
                            const MetaTensor& global_count,
                            int ring_id,
                            bool use_calc_stream,
                            MetaTensor* out);

void GroupNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
Expand Down
2 changes: 2 additions & 0 deletions test/ir/pir/translator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_partial_recv_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST
test_prune_gate_by_capacity_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_random_routing_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_limit_by_capacity_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_global_scatter_translator)

if(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_INTERP_CASES ${DISTRIBUTED_OP_TRANSLATOR_TEST})
Expand Down
50 changes: 50 additions & 0 deletions test/ir/pir/translator/test_global_scatter_translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import test_op_translator

import paddle
from paddle.base.layer_helper import LayerHelper


class TestGlobalScatterOpTranslator(
    test_op_translator.TestOpTranslator
):
    """Checks that the legacy ``global_scatter`` op translates into PIR.

    NOTE(review): the class was originally named
    ``TestDistributedLookupTableOpTranslator`` — a copy-paste leftover from
    another translator test. Renamed to match the op actually under test;
    unittest discovers it by the ``Test*`` prefix, so no caller is affected.
    """

    def append_op(self):
        # Build a program containing a single global_scatter op for the
        # translator to convert.
        self.op_type = "global_scatter"
        x = paddle.ones(shape=(4, 8), dtype='float32')
        # Presumably per-expert send/receive row counts consumed by the
        # communication kernel — TODO confirm against the op definition.
        local_count = paddle.to_tensor([0, 1], dtype='int64')
        global_count = paddle.to_tensor([0, 1], dtype='int64')
        out = paddle.ones(shape=(2, 8), dtype='float32')
        attrs = {'ring_id': 0, 'use_calc_stream': False}
        helper = LayerHelper(self.op_type)
        helper.append_op(
            type=self.op_type,
            inputs={
                "X": x,
                "local_count": local_count,
                "global_count": global_count,
            },
            outputs={"Out": out},
            attrs=attrs,
        )

    def test_translator(self):
        # Delegates to the base class, which runs the translation and
        # verifies the resulting PIR program.
        self.check()


# Run the translator tests when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
47 changes: 47 additions & 0 deletions test/ir/pir/translator/test_limit_by_capacity_translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import test_op_translator

import paddle
from paddle.base.layer_helper import LayerHelper


class TestLimitByCapacityOpTranslator(
    test_op_translator.TestOpTranslator
):
    """Checks that the legacy ``limit_by_capacity`` op translates into PIR.

    NOTE(review): the class was originally named
    ``TestDistributedLookupTableOpTranslator`` — a copy-paste leftover from
    another translator test. Renamed to match the op actually under test;
    unittest discovers it by the ``Test*`` prefix, so no caller is affected.
    """

    def append_op(self):
        # Build a program containing a single limit_by_capacity op for the
        # translator to convert.
        self.op_type = "limit_by_capacity"
        # expert_count is (n_worker * n_expert,) flattened counts; out here
        # is only a placeholder var — its shape presumably need not match
        # expert_count for translation purposes — TODO confirm.
        expert_count = paddle.ones(shape=(8 * 8192,), dtype='int64')
        capacity = paddle.ones(shape=(8,), dtype='int64')
        out = paddle.ones(shape=(8,), dtype='int64')
        attrs = {
            'n_worker': 8192,
        }
        helper = LayerHelper(self.op_type)
        helper.append_op(
            type=self.op_type,
            inputs={"expert_count": expert_count, "capacity": capacity},
            outputs={"Out": out},
            attrs=attrs,
        )

    def test_translator(self):
        # Delegates to the base class, which runs the translation and
        # verifies the resulting PIR program.
        self.check()


# Run the translator tests when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 4024e45

Please sign in to comment.