Skip to content

Commit

Permalink
test SyncBatchNorm
Browse files Browse the repository at this point in the history
  • Loading branch information
zkh2016 committed Jun 14, 2022
1 parent f8c44d2 commit 0b27ba5
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions python/paddle/fluid/tests/unittests/test_sparse_norm_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import unittest
import numpy as np
import paddle
from paddle.incubate.sparse import nn
import paddle.fluid as fluid
from paddle.fluid.framework import _test_eager_guard
import copy
Expand Down Expand Up @@ -56,11 +57,10 @@ def test(self):

# test backward
sparse_y.backward(sparse_y)
assert np.allclose(
dense_x.grad.flatten().numpy(),
sparse_x.grad.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)
assert np.allclose(dense_x.grad.flatten().numpy(),
sparse_x.grad.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})

def test_error_layout(self):
Expand All @@ -86,5 +86,22 @@ def test2(self):
# [1, 6, 6, 6, 3]


class TestConvertSyncBatchNorm(unittest.TestCase):

def test_convert(self):
base_model = paddle.nn.Sequential(nn.Conv3D(3, 5, 3), nn.BatchNorm(5),
nn.BatchNorm(5))

model = paddle.nn.Sequential(
nn.Conv3D(3, 5, 3), nn.BatchNorm(5),
nn.BatchNorm(5,
weight_attr=fluid.ParamAttr(name='bn.scale'),
bias_attr=fluid.ParamAttr(name='bn.bias')))
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
for idx, sublayer in enumerate(base_model.sublayers()):
if isinstance(sublayer, nn.BatchNorm):
self.assertEqual(isinstance(model[idx], nn.SyncBatchNorm), True)


if __name__ == "__main__":
unittest.main()

1 comment on commit 0b27ba5

@paddle-bot-old
Copy link

@paddle-bot-old paddle-bot-old bot commented on 0b27ba5 Jun 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍 PR: #43520 Commit ID: 0b27ba5 contains failed CI.

🔹 Failed: PR-CI-Kunlun-KP-Build

Unknown Failed
Unknown Failed

🔹 Failed: PR-CI-Static-Check

Unknown Failed
Unknown Failed

🔹 Failed: PR-CE-Framework

Unknown Failed
Unknown Failed

Please sign in to comment.