diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 51f0a5fd4da9..8e7a5207c522 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -132,15 +132,23 @@ def clip_global_norm(arrays, max_norm, check_isfinite=True):
       False. Otherwise a float is returned.
     """
-    def _norm(array):
-        if array.stype == 'default':
-            x = array.reshape((-1,))
-            return ndarray.dot(x, x)
-        return array.norm().square()
-
     assert len(arrays) > 0
+    # group arrays by ctx
+    def group_by_ctx(arr_list):
+        groups = collections.defaultdict(list)
+        for arr in arr_list:
+            ctx = arr.context
+            groups[ctx].append(arr)
+        return groups
+    arrays_groups = group_by_ctx(arrays)
+    all_ctx_sum = []
     ctx = arrays[0].context
-    total_norm = ndarray.add_n(*[_norm(arr).as_in_context(ctx) for arr in arrays])
-    total_norm = ndarray.sqrt(total_norm)
+    for group in arrays_groups:
+        sum_sq = ndarray.multi_sum_sq(*arrays_groups[group],
+                                      num_arrays=len(arrays_groups[group]))
+        sum_sq = ndarray.add_n(*sum_sq)
+        all_ctx_sum.append(sum_sq.as_in_context(ctx))
+    # global reduce
+    total_norm = ndarray.add_n(*all_ctx_sum).sqrt()
     if check_isfinite:
         if not np.isfinite(total_norm.asscalar()):
             warnings.warn(
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index aa56eee33dc4..bd9bbe3a2500 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -336,14 +336,18 @@ def test_global_norm_clip_multi_device():
     for check_isfinite in [True, False]:
         x1 = mx.nd.ones((3, 3), ctx=mx.gpu(0))
         x2 = mx.nd.ones((4, 4), ctx=mx.cpu(0))
+        x3 = mx.nd.ones((7, 4), ctx=mx.gpu(0))
+        x4 = mx.nd.ones((7, 4), ctx=mx.cpu(0))
         norm = gluon.utils.clip_global_norm(
-            [x1, x2], 1.0, check_isfinite=check_isfinite)
+            [x1, x2, x3, x4], 1.0, check_isfinite=check_isfinite)
         if check_isfinite:
-            assert norm == 5.0
+            assert norm == 9.0
         else:
-            assert norm.asscalar() == 5.0
-        assert_almost_equal(x1, np.ones((3, 3)) / 5)
-        assert_almost_equal(x2, np.ones((4, 4)) / 5)
+            assert norm.asscalar() == 9.0
+        assert_almost_equal(x1, np.ones((3, 3)) / 9)
+        assert_almost_equal(x2, np.ones((4, 4)) / 9)
+        assert_almost_equal(x3, np.ones((7, 4)) / 9)
+        assert_almost_equal(x4, np.ones((7, 4)) / 9)
 
 
 def _check_batchnorm_result(input, num_devices=1, cuda=False):
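
An illustrative usage sketch (not part of the patch) follows. It mirrors the updated test case and assumes an MXNet build with at least one GPU so that mx.gpu(0) is available; with the per-context multi_sum_sq reduction above, the global norm over arrays spread across CPU and GPU contexts is still sqrt(9 + 16 + 28 + 28) = 9.

import mxnet as mx
import numpy as np
from mxnet import gluon
from mxnet.test_utils import assert_almost_equal

# Arrays spread across two contexts; shapes match the updated test case.
x1 = mx.nd.ones((3, 3), ctx=mx.gpu(0))
x2 = mx.nd.ones((4, 4), ctx=mx.cpu(0))
x3 = mx.nd.ones((7, 4), ctx=mx.gpu(0))
x4 = mx.nd.ones((7, 4), ctx=mx.cpu(0))

# Global L2 norm over all four arrays: sqrt(9 + 16 + 28 + 28) = 9.0.
norm = gluon.utils.clip_global_norm([x1, x2, x3, x4], max_norm=1.0)
assert norm == 9.0

# Each array has been rescaled in place by max_norm / global_norm = 1/9.
assert_almost_equal(x1.asnumpy(), np.ones((3, 3)) / 9)
assert_almost_equal(x4.asnumpy(), np.ones((7, 4)) / 9)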