From 0ad51c7264a2f69041a0c56723c0dfab2a6425ff Mon Sep 17 00:00:00 2001 From: moisesh Date: Fri, 21 Feb 2020 02:45:48 -0800 Subject: [PATCH 1/2] Use multi-tensor sumSQ in clip_global_norm --- python/mxnet/gluon/utils.py | 26 ++++++++++++++++++-------- tests/python/gpu/test_gluon_gpu.py | 14 +++++++++----- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 51f0a5fd4da9..eeacbb9c6576 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -31,6 +31,8 @@ import weakref import requests +import mxnet as mx +from mxnet import nd import numpy as np from .. import ndarray @@ -132,15 +134,23 @@ def clip_global_norm(arrays, max_norm, check_isfinite=True): False. Otherwise a float is returned. """ - def _norm(array): - if array.stype == 'default': - x = array.reshape((-1,)) - return ndarray.dot(x, x) - return array.norm().square() - assert len(arrays) > 0 + # group arrays by ctx + def group_by_ctx(arr_list): + groups = collections.defaultdict(list) + for arr in arr_list: + ctx = arr.context + groups[ctx].append(arr) + return groups + arrays_groups = group_by_ctx(arrays) + all_ctx_sum = [] ctx = arrays[0].context - total_norm = ndarray.add_n(*[_norm(arr).as_in_context(ctx) for arr in arrays]) - total_norm = ndarray.sqrt(total_norm) + for group in arrays_groups: + sum_sq = mx.nd.multi_sum_sq(*arrays_groups[group], + num_arrays=len(arrays_groups[group])) + sum_sq = nd.add_n(*sum_sq) + all_ctx_sum.append(sum_sq.as_in_context(ctx)) + # global reduce + total_norm = nd.add_n(*all_ctx_sum).sqrt() if check_isfinite: if not np.isfinite(total_norm.asscalar()): warnings.warn( diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index aa56eee33dc4..bd9bbe3a2500 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -336,14 +336,18 @@ def test_global_norm_clip_multi_device(): for check_isfinite in [True, False]: x1 = mx.nd.ones((3, 3), ctx=mx.gpu(0)) x2 = mx.nd.ones((4, 4), ctx=mx.cpu(0)) + x3 = mx.nd.ones((7, 4), ctx=mx.gpu(0)) + x4 = mx.nd.ones((7, 4), ctx=mx.cpu(0)) norm = gluon.utils.clip_global_norm( - [x1, x2], 1.0, check_isfinite=check_isfinite) + [x1, x2, x3, x4], 1.0, check_isfinite=check_isfinite) if check_isfinite: - assert norm == 5.0 + assert norm == 9.0 else: - assert norm.asscalar() == 5.0 - assert_almost_equal(x1, np.ones((3, 3)) / 5) - assert_almost_equal(x2, np.ones((4, 4)) / 5) + assert norm.asscalar() == 9.0 + assert_almost_equal(x1, np.ones((3, 3)) / 9) + assert_almost_equal(x2, np.ones((4, 4)) / 9) + assert_almost_equal(x3, np.ones((7, 4)) / 9) + assert_almost_equal(x4, np.ones((7, 4)) / 9) def _check_batchnorm_result(input, num_devices=1, cuda=False): From a2d69d444e5ca5b7fd7b27a55e5268c6fdf72295 Mon Sep 17 00:00:00 2001 From: moisesh Date: Fri, 21 Feb 2020 03:29:22 -0800 Subject: [PATCH 2/2] fix pylint --- python/mxnet/gluon/utils.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index eeacbb9c6576..8e7a5207c522 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -31,8 +31,6 @@ import weakref import requests -import mxnet as mx -from mxnet import nd import numpy as np from .. import ndarray @@ -145,12 +143,12 @@ def group_by_ctx(arr_list): all_ctx_sum = [] ctx = arrays[0].context for group in arrays_groups: - sum_sq = mx.nd.multi_sum_sq(*arrays_groups[group], - num_arrays=len(arrays_groups[group])) - sum_sq = nd.add_n(*sum_sq) + sum_sq = ndarray.multi_sum_sq(*arrays_groups[group], + num_arrays=len(arrays_groups[group])) + sum_sq = ndarray.add_n(*sum_sq) all_ctx_sum.append(sum_sq.as_in_context(ctx)) # global reduce - total_norm = nd.add_n(*all_ctx_sum).sqrt() + total_norm = ndarray.add_n(*all_ctx_sum).sqrt() if check_isfinite: if not np.isfinite(total_norm.asscalar()): warnings.warn(