Skip to content

Commit

Permalink
fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
gglin001 committed Jun 6, 2022
1 parent 7ae714c commit 3996edc
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 22 deletions.
2 changes: 2 additions & 0 deletions python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def set_ipu_shard(call_func, index=-1, stage=-1):
"""

def decorate(func):

def wrapper(*args, **kwargs):
with ipu_shard_guard(index=index, stage=stage):
return func(*args, **kwargs)
Expand All @@ -358,6 +359,7 @@ def wrapper(*args, **kwargs):

# patch paddle.nn.Layer
class BlockFn(type(call_func)):

def __call__(self, *args, **kwargs):
with ipu_shard_guard(index=index, stage=stage):
return super().__call__(*args, **kwargs)
Expand Down
52 changes: 30 additions & 22 deletions python/paddle/fluid/tests/unittests/ipu/test_set_ipu_shard_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@


class SimpleNet(paddle.nn.Layer):

def __init__(self, input_size, output_size):
super(SimpleNet, self).__init__()
self.linear1 = nn.Linear(input_size, output_size)
Expand All @@ -47,13 +48,15 @@ def linear_relu2(self, x):
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestSetIpuShard(unittest.TestCase):

def _test(self):
# build graph
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
x = paddle.static.data(name='X', shape=[10, 46], dtype='float32')
label = paddle.static.data(
name='Y', shape=[10, 46], dtype='float32')
label = paddle.static.data(name='Y',
shape=[10, 46],
dtype='float32')
model = SimpleNet(46, 46)

set_ipu_shard(model.linear1, index=1)
Expand All @@ -74,20 +77,21 @@ def test_set_ipu_shard(self):
expected_ipu_index_list = [1, 1, 2, 3, 3, 3, 4, 4]

self.assertTrue(
np.allclose(
ipu_index_list, expected_ipu_index_list, atol=0))
np.allclose(ipu_index_list, expected_ipu_index_list, atol=0))


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestSetIpuPipeline(unittest.TestCase):

def _test(self):
# build graph
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
x = paddle.static.data(name='X', shape=[10, 46], dtype='float32')
label = paddle.static.data(
name='Y', shape=[10, 46], dtype='float32')
label = paddle.static.data(name='Y',
shape=[10, 46],
dtype='float32')
model = SimpleNet(46, 46)

set_ipu_shard(model.linear1, stage=1)
Expand All @@ -108,26 +112,28 @@ def test_set_ipu_shard(self):
expected_ipu_index_list = [1, 1, 2, 3, 3, 3, 4, 4]

self.assertTrue(
np.allclose(
ipu_index_list, expected_ipu_index_list, atol=0))
np.allclose(ipu_index_list, expected_ipu_index_list, atol=0))


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestSetIpuShardAndPipeline(unittest.TestCase):

def _test(self):
# build graph
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
x = paddle.static.data(name='X', shape=[10, 46], dtype='float32')
label = paddle.static.data(
name='Y', shape=[10, 46], dtype='float32')
label = paddle.static.data(name='Y',
shape=[10, 46],
dtype='float32')
model = SimpleNet(46, 46)

set_ipu_shard(model.linear1, index=1, stage=2)
set_ipu_shard(model.relu1, index=2, stage=3)
model.linear_relu2 = set_ipu_shard(
model.linear_relu2, index=3, stage=4)
model.linear_relu2 = set_ipu_shard(model.linear_relu2,
index=3,
stage=4)
model.linear3 = set_ipu_shard(model.linear3, index=4, stage=1)
out = model(x)

Expand All @@ -148,20 +154,21 @@ def test_set_ipu_shard(self):
]

self.assertTrue(
np.allclose(
ipu_index_list, expected_ipu_index_list, atol=0))
np.allclose(ipu_index_list, expected_ipu_index_list, atol=0))


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestSetIpuForModel(unittest.TestCase):

def _test(self):
# build graph
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
x = paddle.static.data(name='X', shape=[10, 46], dtype='float32')
label = paddle.static.data(
name='Y', shape=[10, 46], dtype='float32')
label = paddle.static.data(name='Y',
shape=[10, 46],
dtype='float32')
model = SimpleNet(46, 46)

set_ipu_shard(model, index=1, stage=2)
Expand All @@ -184,14 +191,15 @@ def test_set_ipu_shard(self):
]

self.assertTrue(
np.allclose(
ipu_index_list, expected_ipu_index_list, atol=0))
np.allclose(ipu_index_list, expected_ipu_index_list, atol=0))


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestSetIpuMixedModel(unittest.TestCase):

def setUp(self):

def linear_relu2_mixed(self, x):
with paddle.static.ipu_shard_guard(index=2, stage=3):
x = self.linear2(x)
Expand All @@ -210,8 +218,9 @@ def _test(self):
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
x = paddle.static.data(name='X', shape=[10, 46], dtype='float32')
label = paddle.static.data(
name='Y', shape=[10, 46], dtype='float32')
label = paddle.static.data(name='Y',
shape=[10, 46],
dtype='float32')
model = SimpleNet(46, 46)

set_ipu_shard(model.linear1, index=1, stage=2)
Expand All @@ -236,8 +245,7 @@ def test_set_ipu_shard(self):
]

self.assertTrue(
np.allclose(
ipu_index_list, expected_ipu_index_list, atol=0))
np.allclose(ipu_index_list, expected_ipu_index_list, atol=0))


if __name__ == "__main__":
Expand Down

0 comments on commit 3996edc

Please sign in to comment.