[IPU] Decoupling ipu sharding and modeling (PaddlePaddle#43164)
* Decoupling ipu sharding and modeling (PaddlePaddle#665)

* feat(shard): decoupling shard setting with modeling.

* fix(shard): split test cases to avoid failure.

* fix(shard): add function docs and fix typo.

* test(shard): add tests.

* test(shard): more test case.

* fix(): change ipu_index/stage default value to -1.

* fix format

Co-authored-by: czr-gc <[email protected]>
2 people authored and sneaxiy committed Jun 27, 2022
1 parent bc0c6a4 commit f6becba
Showing 3 changed files with 321 additions and 10 deletions.
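
At a high level, the change means IPU placement no longer has to be written inside the modeling code itself: instead of wrapping individual statements in ipu_shard_guard, a whole paddle.nn.Layer or plain function can be handed to the new set_ipu_shard wrapper after the model is built. The following sketch contrasts the two styles using only the APIs touched by this commit; it is an illustration rather than part of the commit, and assumes a Paddle build compiled with IPU support plus an IpuStrategy with enable_manual_shard/enable_pipelining turned on.

    # Minimal sketch (not part of the commit) contrasting the two sharding styles.
    import paddle
    import paddle.nn as nn

    paddle.enable_static()
    main_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog):
        x = paddle.static.data(name='X', shape=[10, 46], dtype='float32')

        # Style 1: ipu_shard_guard, sharding is written inside the modeling code.
        with paddle.static.ipu_shard_guard(index=0, stage=0):
            y = nn.Linear(46, 46)(x)

        # Style 2: set_ipu_shard, sharding is attached to an existing layer,
        # so the modeling code itself does not need to change.
        relu = nn.ReLU()
        relu = paddle.static.set_ipu_shard(relu, index=1, stage=1)
        out = relu(y)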
76 changes: 67 additions & 9 deletions python/paddle/fluid/framework.py
@@ -48,6 +48,7 @@
'program_guard',
'name_scope',
'ipu_shard_guard',
'set_ipu_shard',
'cuda_places',
'cpu_places',
'xpu_places',
@@ -252,28 +253,28 @@ def _test_eager_guard(place=None):
_enable_legacy_dygraph()


global_ipu_index = None
global_ipu_stage = None
global_ipu_index = -1
global_ipu_stage = -1
ipu_index_attr_name = 'ipu_index'
ipu_stage_attr_name = 'ipu_stage'


@signature_safe_contextmanager
def ipu_shard_guard(index=None, stage=None):
def ipu_shard_guard(index=-1, stage=-1):
"""
Used to shard the graph on IPUs. Set which IPU each Op runs on in the sharding and which stage it belongs to in the pipelining.
Args:
index(int, optional): Specify which IPU the Tensor is computed on (such as '0, 1, 2, 3').
The default value is None, which means the Op only runs on IPU 0.
The default value is -1, which means the Op only runs on IPU 0.
stage(int, optional): Specify the computation order of the sharded model (such as '0, 1, 2, 3').
The sharded model will be computed from small to large. The default value is None,
The sharded model will be computed from small to large. The default value is -1,
which means no pipelining computation order and Ops are run in the order of the graph.
**Note**:
Only if enable_manual_shard=True can 'index' be set to a value other than None. Please refer
Only if enable_manual_shard=True can 'index' be set to a value other than -1. Please refer
to :code:`paddle.static.IpuStrategy`.
Only if enable_pipelining=True can 'stage' be set to a value other than None. Please refer
Only if enable_pipelining=True can 'stage' be set to a value other than -1. Please refer
to :code:`paddle.static.IpuStrategy`.
An index is allowed to match no stage or one stage. A stage is only allowed to match a new or
duplicated index.
@@ -311,6 +312,63 @@ def ipu_shard_guard(index=None, stage=None):
global_ipu_stage = prev_ipu_stage


def set_ipu_shard(call_func, index=-1, stage=-1):
"""
Set IPU sharding for the given call function. Every op created in the call function is assigned the given IPU index and stage.
Args:
call_func(Layer|function): Specify the call function to be wrapped.
index(int, optional): Specify which IPU the Tensor is computed on (such as '0, 1, 2, 3').
The default value is -1, which means the Op only runs on IPU 0.
stage(int, optional): Specify the computation order of the sharded model (such as '0, 1, 2, 3').
The sharded model will be computed from small to large. The default value is -1,
which means no pipelining computation order and Ops are run in the order of the graph.
Returns:
The wrapped call function.
Examples:
.. code-block:: python
# required: ipu
import paddle
paddle.enable_static()
a = paddle.static.data(name='data', shape=[None, 1], dtype='float32')
relu = paddle.nn.ReLU()
relu = paddle.static.set_ipu_shard(relu, index=1, stage=1)
relu(a)
"""

def decorate(func):

def wrapper(*args, **kwargs):
with ipu_shard_guard(index=index, stage=stage):
return func(*args, **kwargs)

return wrapper

from .dygraph.layers import Layer
if not isinstance(call_func, Layer):
if callable(call_func):
return decorate(call_func)
else:
raise TypeError(
"Unsupported type. Only accept paddle.nn.Layer or function.")

# patch paddle.nn.Layer
class BlockFn(type(call_func)):

def __call__(self, *args, **kwargs):
with ipu_shard_guard(index=index, stage=stage):
return super().__call__(*args, **kwargs)

BlockFn.__name__ = type(call_func).__name__
call_func.__class__ = BlockFn
return call_func


def require_version(min_version, max_version=None):
"""
Check if the installed version of PaddlePaddle is in [min_version, max_version],
@@ -2772,10 +2830,10 @@ def find_name(var_list, name):

# proto.attrs doesn't include ipu_index
if core.is_compiled_with_ipu():
if global_ipu_index is not None:
if global_ipu_index >= 0:
self._update_desc_attr(ipu_index_attr_name,
global_ipu_index)
if global_ipu_stage is not None:
if global_ipu_stage >= 0:
self._update_desc_attr(ipu_stage_attr_name,
global_ipu_stage)

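The new set_ipu_shard has two code paths: a plain callable is wrapped in a decorator-style function that enters ipu_shard_guard before delegating, while a paddle.nn.Layer instance keeps its parameters and hooks intact and only has its class swapped for a dynamically created subclass whose __call__ enters the guard. Below is a simplified, framework-free illustration of that class-patching pattern; the names fake_guard, MyLayer, and patch_call are invented for the example and do not appear in the commit.

    # Standalone sketch of the Layer-patching path used by set_ipu_shard above.
    import contextlib

    @contextlib.contextmanager
    def fake_guard(index):
        # Stand-in for ipu_shard_guard: just record entry and exit.
        print("enter shard", index)
        yield
        print("leave shard", index)

    class MyLayer:
        def __call__(self, x):
            return x + 1

    def patch_call(obj, index):
        # Build a subclass of the object's own class whose __call__ wraps the
        # original call in the guard, then swap the instance's class in place.
        class Patched(type(obj)):
            def __call__(self, *args, **kwargs):
                with fake_guard(index):
                    return super().__call__(*args, **kwargs)
        Patched.__name__ = type(obj).__name__
        obj.__class__ = Patched
        return obj

    layer = patch_call(MyLayer(), index=3)
    print(layer(1))  # prints the enter/leave messages, then 2
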
252 changes: 252 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/test_set_ipu_shard_api.py
@@ -0,0 +1,252 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle
import paddle.nn as nn
from paddle.static import set_ipu_shard

paddle.enable_static()


class SimpleNet(paddle.nn.Layer):

def __init__(self, input_size, output_size):
super(SimpleNet, self).__init__()
self.linear1 = nn.Linear(input_size, output_size)
self.relu1 = nn.ReLU()
self.linear2 = nn.Linear(input_size, output_size)
self.relu2 = nn.ReLU()
self.linear3 = nn.Linear(input_size, output_size)

def forward(self, x):
x = self.linear1(x)
x = self.relu1(x)
x = self.linear_relu2(x)
x = self.linear3(x)
return x

def linear_relu2(self, x):
x = self.linear2(x)
x = self.relu2(x)
return x


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestSetIpuShard(unittest.TestCase):

def _test(self):
# build graph
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
x = paddle.static.data(name='X', shape=[10, 46], dtype='float32')
label = paddle.static.data(name='Y',
shape=[10, 46],
dtype='float32')
model = SimpleNet(46, 46)

set_ipu_shard(model.linear1, index=1)
set_ipu_shard(model.relu1, index=2)
model.linear_relu2 = set_ipu_shard(model.linear_relu2, index=3)
model.linear3 = set_ipu_shard(model.linear3, index=4)
out = model(x)

ipu_index_list = []
for op in main_prog.global_block().ops:
if op.desc.has_attr("ipu_index"):
ipu_index_list.append(op.desc.attr("ipu_index"))

return ipu_index_list

def test_set_ipu_shard(self):
ipu_index_list = self._test()
expected_ipu_index_list = [1, 1, 2, 3, 3, 3, 4, 4]

self.assertTrue(
np.allclose(ipu_index_list, expected_ipu_index_list, atol=0))


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestSetIpuPipeline(unittest.TestCase):

def _test(self):
# build graph
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
x = paddle.static.data(name='X', shape=[10, 46], dtype='float32')
label = paddle.static.data(name='Y',
shape=[10, 46],
dtype='float32')
model = SimpleNet(46, 46)

set_ipu_shard(model.linear1, stage=1)
set_ipu_shard(model.relu1, stage=2)
model.linear_relu2 = set_ipu_shard(model.linear_relu2, stage=3)
model.linear3 = set_ipu_shard(model.linear3, stage=4)
out = model(x)

ipu_index_list = []
for op in main_prog.global_block().ops:
if op.desc.has_attr("ipu_stage"):
ipu_index_list.append(op.desc.attr("ipu_stage"))

return ipu_index_list

def test_set_ipu_shard(self):
ipu_index_list = self._test()
expected_ipu_index_list = [1, 1, 2, 3, 3, 3, 4, 4]

self.assertTrue(
np.allclose(ipu_index_list, expected_ipu_index_list, atol=0))


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestSetIpuShardAndPipeline(unittest.TestCase):

def _test(self):
# build graph
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
x = paddle.static.data(name='X', shape=[10, 46], dtype='float32')
label = paddle.static.data(name='Y',
shape=[10, 46],
dtype='float32')
model = SimpleNet(46, 46)

set_ipu_shard(model.linear1, index=1, stage=2)
set_ipu_shard(model.relu1, index=2, stage=3)
model.linear_relu2 = set_ipu_shard(model.linear_relu2,
index=3,
stage=4)
model.linear3 = set_ipu_shard(model.linear3, index=4, stage=1)
out = model(x)

ipu_index_list = []
ipu_stage_list = []
for op in main_prog.global_block().ops:
if op.desc.has_attr("ipu_index"):
ipu_index_list.append(op.desc.attr("ipu_index"))
if op.desc.has_attr("ipu_stage"):
ipu_stage_list.append(op.desc.attr("ipu_stage"))

return ipu_index_list + ipu_stage_list

def test_set_ipu_shard(self):
ipu_index_list = self._test()
expected_ipu_index_list = [
1, 1, 2, 3, 3, 3, 4, 4, 2, 2, 3, 4, 4, 4, 1, 1
]

self.assertTrue(
np.allclose(ipu_index_list, expected_ipu_index_list, atol=0))


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestSetIpuForModel(unittest.TestCase):

def _test(self):
# build graph
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
x = paddle.static.data(name='X', shape=[10, 46], dtype='float32')
label = paddle.static.data(name='Y',
shape=[10, 46],
dtype='float32')
model = SimpleNet(46, 46)

set_ipu_shard(model, index=1, stage=2)
out = model(x)

ipu_index_list = []
ipu_stage_list = []
for op in main_prog.global_block().ops:
if op.desc.has_attr("ipu_index"):
ipu_index_list.append(op.desc.attr("ipu_index"))
if op.desc.has_attr("ipu_stage"):
ipu_stage_list.append(op.desc.attr("ipu_stage"))

return ipu_index_list + ipu_stage_list

def test_set_ipu_shard(self):
ipu_index_list = self._test()
expected_ipu_index_list = [
1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2
]

self.assertTrue(
np.allclose(ipu_index_list, expected_ipu_index_list, atol=0))


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestSetIpuMixedModel(unittest.TestCase):

def setUp(self):

def linear_relu2_mixed(self, x):
with paddle.static.ipu_shard_guard(index=2, stage=3):
x = self.linear2(x)
with paddle.static.ipu_shard_guard(index=3, stage=4):
x = self.relu2(x)
return x

self._old_linear = SimpleNet.linear_relu2
SimpleNet.linear_relu2 = linear_relu2_mixed

def tearDown(self):
SimpleNet.linear_relu2 = self._old_linear

def _test(self):
# build graph
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
x = paddle.static.data(name='X', shape=[10, 46], dtype='float32')
label = paddle.static.data(name='Y',
shape=[10, 46],
dtype='float32')
model = SimpleNet(46, 46)

set_ipu_shard(model.linear1, index=1, stage=2)
set_ipu_shard(model.relu1, index=2, stage=3)
model.linear3 = set_ipu_shard(model.linear3, index=4, stage=1)
out = model(x)

ipu_index_list = []
ipu_stage_list = []
for op in main_prog.global_block().ops:
if op.desc.has_attr("ipu_index"):
ipu_index_list.append(op.desc.attr("ipu_index"))
if op.desc.has_attr("ipu_stage"):
ipu_stage_list.append(op.desc.attr("ipu_stage"))

return ipu_index_list + ipu_stage_list

def test_set_ipu_shard(self):
ipu_index_list = self._test()
expected_ipu_index_list = [
1, 1, 2, 2, 2, 3, 4, 4, 2, 2, 3, 3, 3, 4, 1, 1
]

self.assertTrue(
np.allclose(ipu_index_list, expected_ipu_index_list, atol=0))


if __name__ == "__main__":
unittest.main()
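
A note on the expected attribute lists in these tests: under the usual static-graph lowering, each nn.Linear produces two ops (a matmul plus an elementwise add for the bias) and each nn.ReLU produces one, which is why, for example, [1, 1, 2, 3, 3, 3, 4, 4] corresponds to two ops from linear1 at index 1, one from relu1 at index 2, three from linear_relu2 at index 3, and two from linear3 at index 4. When such a test fails, a small helper along the lines of the sketch below (not part of the commit) can show how each op was tagged.

    # Hypothetical debugging helper: dump each op's type and IPU attributes.
    def dump_ipu_attrs(prog):
        for op in prog.global_block().ops:
            index = op.desc.attr("ipu_index") if op.desc.has_attr("ipu_index") else None
            stage = op.desc.attr("ipu_stage") if op.desc.has_attr("ipu_stage") else None
            print(op.type, index, stage)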
3 changes: 2 additions & 1 deletion python/paddle/static/__init__.py
@@ -51,6 +51,7 @@
from ..fluid.framework import npu_places # noqa: F401
from ..fluid.framework import Variable # noqa: F401
from ..fluid.framework import ipu_shard_guard # noqa: F401
from ..fluid.framework import set_ipu_shard # noqa: F401
from ..fluid.layers.control_flow import Print # noqa: F401
from ..fluid.layers.nn import py_func # noqa: F401
from ..fluid.parallel_executor import ParallelExecutor # noqa: F401
@@ -81,5 +82,5 @@
'deserialize_persistables', 'load_from_file', 'normalize_program',
'load_program_state', 'set_program_state', 'cpu_places', 'cuda_places',
'xpu_places', 'npu_places', 'mlu_places', 'Variable', 'create_global_var',
'accuracy', 'auc', 'device_guard', 'create_parameter'
'accuracy', 'auc', 'device_guard', 'create_parameter', 'set_ipu_shard'
]
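
With this export, set_ipu_shard becomes importable from the public paddle.static namespace alongside ipu_shard_guard, which is the import path the new test file uses:

    # Both entry points are re-exported by python/paddle/static/__init__.py.
    from paddle.static import ipu_shard_guard, set_ipu_shard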
