diff --git a/byteps/common/compressor/impl/dithering.cc b/byteps/common/compressor/impl/dithering.cc
index e1fef8b25..ca84fcfee 100644
--- a/byteps/common/compressor/impl/dithering.cc
+++ b/byteps/common/compressor/impl/dithering.cc
@@ -33,12 +33,14 @@ CompressorRegistry::Register reg(
   auto seed = HyperParamFinder<unsigned>(kwargs, "seed", true,
                                          [](unsigned x) { return x != 0; });
 
-  auto ptype_int = HyperParamFinder<int>(
-      kwargs, "partition", true, [](int x) { return x == 0 || x == 1; });
+  auto ptype_int =
+      HyperParamFinder<int>(kwargs, "dithering_partition", true,
+                            [](int x) { return x == 0 || x == 1; });
   auto ptype = static_cast<DitheringCompressor::PartitionType>(ptype_int);
 
-  auto ntype_int = HyperParamFinder<int>(
-      kwargs, "normalize", true, [](int x) { return x == 0 || x == 1; });
+  auto ntype_int =
+      HyperParamFinder<int>(kwargs, "dithering_normalize", true,
+                            [](int x) { return x == 0 || x == 1; });
   auto ntype = static_cast<DitheringCompressor::NomalizeType>(ntype_int);
 
   return std::unique_ptr<Compressor>(
@@ -85,11 +87,11 @@ tensor_t DitheringCompressor::CompressImpl(index_t* dst, const scalar_t* src,
     const unsigned level = 1 << (_s - 1);
     for (size_t i = 0; i < len; ++i) {
       float abs_x = std::abs(src[i]);
-      float normalized = (abs_x / scale) * level;
-      unsigned low = RoundNextPow2(std::ceil(normalized)) >> 1;
-      unsigned length = (low != 0) ? low : 1;
-      unsigned quantized =
-          low + length * _rng.Bernoulli((normalized - low) / length);
+      double normalized = (abs_x / scale) * level;
+      unsigned floor = RoundNextPow2(std::ceil(normalized)) >> 1;
+      unsigned length = (floor != 0) ? floor : 1;
+      double p = (normalized - floor) / length;
+      unsigned quantized = floor + length * _rng.Bernoulli(p);
       if (quantized) {
         size_t diff = i - last_non_zero_pos;
         last_non_zero_pos = i;
diff --git a/byteps/common/compressor/impl/dithering.h b/byteps/common/compressor/impl/dithering.h
index a5ffc8fe8..d27c5e5cf 100644
--- a/byteps/common/compressor/impl/dithering.h
+++ b/byteps/common/compressor/impl/dithering.h
@@ -76,7 +76,7 @@ class DitheringCompressor : public Compressor {
                         const index_t* compressed, size_t compressed_size);
 
   /*!
    \brief number of levels */
-  unsigned int _s;
+  const unsigned int _s;
   PartitionType _ptype;
   NomalizeType _ntype;
diff --git a/byteps/common/compressor/utils.h b/byteps/common/compressor/utils.h
index b5c21a2ee..64683fdc5 100644
--- a/byteps/common/compressor/utils.h
+++ b/byteps/common/compressor/utils.h
@@ -87,7 +87,7 @@ class XorShift128PlusBitShifterRNG {
   double Rand() { return double(xorshift128p()) / MAX; }
 
   // Bernoulli Distributation
-  bool Bernoulli(double p) { return xorshift128p() < uint64_t(p * MAX); }
+  bool Bernoulli(double p) { return xorshift128p() < p * MAX; }
 
   void set_seed(uint64_t seed) { _state = {seed, seed}; }
 
diff --git a/byteps/mxnet/__init__.py b/byteps/mxnet/__init__.py
index c2771add2..3f0699cb7 100644
--- a/byteps/mxnet/__init__.py
+++ b/byteps/mxnet/__init__.py
@@ -227,6 +227,12 @@ def __init__(self, params, optimizer, optimizer_params=None, root_rank=0, compre
                 )
             byteps_declare_tensor("gradient_" + str(i), **byteps_params)
 
+    def __del__(self):
+        if local_rank() == 0:
+            self._f.close()
+            if os.path.exists("lr.s"):
+                os.remove("lr.s")
+
     def _register_compressor(self, params, optimizer_params, compression_params):
         """Register compressor for BytePS
 
@@ -274,6 +280,22 @@ def _register_compressor(self, params, optimizer_params, compression_params):
 
             if compression_params.get("seed", None) is not None:
                 setattr(param, "byteps_seed", compression_params["seed"])
+
+            if compression_params.get("partition"):
+                if compression_params["partition"] == "linear":
+                    setattr(param, "byteps_dithering_partition", "0")
+                elif compression_params["partition"] == "natural":
+                    setattr(param, "byteps_dithering_partition", "1")
+                else:
+                    raise ValueError("Unsupported partition")
+
+            if compression_params.get("normalize"):
+                if compression_params["normalize"] == "max":
+                    setattr(param, "byteps_dithering_normalize", "0")
+                elif compression_params["normalize"] == "l2":
+                    setattr(param, "byteps_dithering_normalize", "1")
+                else:
+                    raise ValueError("Unsupported normalization")
 
         # the following code will delete some items in `optimizer_params`
         # to avoid duplication
diff --git a/tests/meta_test.py b/tests/meta_test.py
new file mode 100644
index 000000000..fbdbfb456
--- /dev/null
+++ b/tests/meta_test.py
@@ -0,0 +1,85 @@
+# Copyright 2020 Amazon Technologies, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import copy
+import time
+import os
+import subprocess
+import sys
+import threading
+
+import byteps.mxnet as bps
+
+
+class MetaTest(type):
+    BASE_ENV = {"DMLC_NUM_WORKER": "1",
+                "DMLC_NUM_SERVER": "1",
+                "DMLC_PS_ROOT_URI": "127.0.0.1",
+                "DMLC_PS_ROOT_PORT": "1234",
+                "BYTEPS_LOG_LEVEL": "INFO",
+                "BYTEPS_MIN_COMPRESS_BYTES": "0",
+                "BYTEPS_PARTITION_BYTES": "2147483647"}
+    for name, value in os.environ.items():
+        if name not in BASE_ENV:
+            BASE_ENV[name] = value
+    SCHEDULER_ENV = copy.copy(BASE_ENV)
+    SCHEDULER_ENV.update(DMLC_ROLE="scheduler")
+    SERVER_ENV = copy.copy(BASE_ENV)
+    SERVER_ENV.update(DMLC_ROLE="server")
+
+    def __new__(cls, name, bases, dict):
+        # decorate all test cases
+        for k, v in dict.items():
+            if k.startswith("test_") and hasattr(v, "__call__"):
+                dict[k] = cls.launch_bps(v)
+
+        for k, v in cls.BASE_ENV.items():
+            os.environ[k] = v
+        os.environ["NVIDIA_VISIBLE_DEVICES"] = "0"
+        os.environ["DMLC_WORKER_ID"] = "0"
+        os.environ["DMLC_ROLE"] = "worker"
+        os.environ["BYTEPS_THREADPOOL_SIZE"] = "4"
+        os.environ["BYTEPS_FORCE_DISTRIBUTED"] = "1"
+        os.environ["BYTEPS_LOCAL_RANK"] = "0"
+        os.environ["BYTEPS_LOCAL_SIZE"] = "1"
+        return type(name, bases, dict)
+
+    @classmethod
+    def launch_bps(cls, func):
+        def wrapper(*args, **kwargs):
+            def run(env):
+                subprocess.check_call(args=["bpslaunch"], shell=True,
+                                      stdout=sys.stdout, stderr=sys.stderr,
+                                      env=env)
+
+            print("bps init")
+            scheduler = threading.Thread(target=run,
+                                         args=(cls.SCHEDULER_ENV,))
+            server = threading.Thread(target=run, args=(cls.SERVER_ENV,))
+            scheduler.daemon = True
+            server.daemon = True
+            scheduler.start()
+            server.start()
+
+            bps.init()
+            func(*args, **kwargs)
+            bps.shutdown()
+
+            scheduler.join()
+            server.join()
+            print("bps shutdown")
+            time.sleep(2)
+
+        return wrapper
diff --git a/tests/run_byteps_test.sh b/tests/run_byteps_test.sh
index ea022c3c3..a0d125ca6 100755
--- a/tests/run_byteps_test.sh
+++ b/tests/run_byteps_test.sh
@@ -2,7 +2,7 @@
 
 path="$(dirname $0)"
 
-export PATH=~/.local/bin:$PATH
+export PATH=~/anaconda3/envs/mxnet_p36/bin:$PATH
 export DMLC_NUM_WORKER=1
 export DMLC_NUM_SERVER=1
 export DMLC_PS_ROOT_URI=127.0.0.1
@@ -32,17 +32,9 @@ export BYTEPS_THREADPOOL_SIZE=4
 export BYTEPS_FORCE_DISTRIBUTED=1
 export BYTEPS_LOG_LEVEL=WARNING
 
-if [ "$TEST_TYPE" == "mxnet" ]; then
-    echo "TEST MXNET ..."
-    bpslaunch python3 $path/test_mxnet.py $@
-elif [ "$TEST_TYPE" == "keras" ]; then
+if [ "$TEST_TYPE" == "keras" ]; then
     echo "TEST KERAS ..."
     python $path/test_tensorflow_keras.py $@
-elif [ "$TEST_TYPE" == "onebit" ] || [ "$TEST_TYPE" == "topk" ] || [ "$TEST_TYPE" == "randomk" ] || [ "$TEST_TYPE" == "dithering" ]; then
-    export BYTEPS_MIN_COMPRESS_BYTES=0
-    export BYTEPS_PARTITION_BYTES=2147483647
-    echo "TEST $TEST_TYPE"
-    bpslaunch python3 test_$TEST_TYPE.py
 else
     echo "Error: unsupported $TEST_TYPE"
     exit 1
diff --git a/tests/test_dithering.py b/tests/test_dithering.py
index b25b15b71..f1fa159d0 100644
--- a/tests/test_dithering.py
+++ b/tests/test_dithering.py
@@ -1,3 +1,20 @@
+# Copyright 2020 Amazon Technologies, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import copy
+import itertools
 import unittest
 
 import byteps.mxnet as bps
@@ -10,6 +27,7 @@
 from parameterized import parameterized
 from tqdm import tqdm
 
+from meta_test import MetaTest
 from utils import bernoulli, fake_data
 
 
@@ -25,7 +43,6 @@ def round_next_pow2(v):
     return v
 
 
-# partition: 'linear' or 'natural'
 def dithering(x, k, state, partition='linear', norm="max"):
     y = x.flatten()
     if norm == "max":
@@ -48,7 +65,7 @@ def dithering(x, k, state, partition='linear', norm="max"):
     elif partition == "natural":
         y *= 2**(k-1)
         low = round_next_pow2((np.ceil(y).astype(np.uint32))) >> 1
-        length = low.copy()
+        length = copy.deepcopy(low)
         length[length == 0] = 1
         p = (y - low) / length
         y = low + length * bernoulli(p, state)
@@ -62,33 +79,27 @@ def dithering(x, k, state, partition='linear', norm="max"):
     return y.reshape(x.shape)
 
 
-class DitheringTestCase(unittest.TestCase):
-    def setUp(self):
-        print("init")
-        bps.init()
-
-    @parameterized.expand([(2, "natural", "max"),])
-    def test_dithering(self, k, ptype, ntype):
+class DitheringTestCase(unittest.TestCase, metaclass=MetaTest):
+    @parameterized.expand(itertools.product([2, 4, 8], ["linear", "natural"], ["max", "l2"], np.random.randint(0, 2020, size=3).tolist()))
+    def test_dithering(self, k, ptype, ntype, seed):
         ctx = mx.gpu(0)
         net = get_model("resnet18_v2")
         net.initialize(mx.init.Xavier(), ctx=ctx)
         net.summary(nd.ones((1, 3, 224, 224), ctx=ctx))
 
         # hyper-params
-        seed = 2020
         batch_size = 32
         optimizer_params = {'momentum': 0, 'wd': 0, 'learning_rate': 0.01}
 
         compression_params = {
             "compressor": "dithering",
-            # "ef": "vanilla",
-            # "momentum": "nesterov",
             "k": k,
             "partition": ptype,
             "normalize": ntype,
             "seed": seed
         }
+        print(compression_params)
 
         trainer = bps.DistributedTrainer(net.collect_params(
         ), "sgd", optimizer_params, compression_params=compression_params)
 
@@ -98,20 +109,12 @@ def test_dithering(self, k, ptype, ntype):
         train_data = fake_data(batch_size=batch_size)
 
         params = {}
-        errors = {}
-        errors_s = {}
-        moms = {}
-        wd_moms = {}
         rngs = {}
         rngs_s = {}
 
         for i, param in enumerate(trainer._params):
             if param.grad_req != 'null':
                 params[i] = param._data[0].asnumpy()
-                errors[i] = np.zeros_like(params[i])
-                errors_s[i] = np.zeros_like(params[i])
-                moms[i] = np.zeros_like(params[i])
-                wd_moms[i] = np.zeros_like(params[i])
                 rngs[i] = np.array([seed, seed], dtype=np.uint64)
                 rngs_s[i] = np.array([seed, seed], dtype=np.uint64)
 
@@ -138,41 +141,37 @@ def test_dithering(self, k, ptype, ntype):
             for i, param in enumerate(trainer._params):
                 if param.grad_req != "null":
                     g = gs[i] / (batch_size * bps.size())
-                    # print("norm2", norm2(g.flatten())/k)
-                    # moms[i] *= 0.9
-                    # moms[i] += g
-                    # g += 0.9 * moms[i]
-                    # g += errors[i]
                     c = dithering(g, k, rngs[i], ptype, ntype)
-                    # errors[i] = g - c
-                    # c += errors_s[i]
                     cs = dithering(c, k, rngs_s[i], ptype, ntype)
-                    # errors_s[i] = c - cs
                     c = cs
-                    # c += 1e-4*xs[i]
 
                     params[i] -= optimizer_params["learning_rate"] * c
 
+                    np_g = c.flatten()
+                    mx_g = param._grad[0].asnumpy().flatten()
+                    if not np.allclose(np_g, mx_g,
+                                       atol=np.finfo(np.float32).eps):
+                        diff = np.abs(np_g - mx_g)
+                        print("np", np_g)
+                        print("mx", mx_g)
+                        print("diff", diff)
+                        print("max diff", np.max(diff))
+                        idx = np.nonzero(diff > 1e-5)
+                        print("idx", idx, np_g[idx], mx_g[idx])
+                        input()
+
         cnt = 0
         tot = 0
-        diffs = []
         for i, param in enumerate(trainer._params):
             if param.grad_req != "null":
                 x = param._data[0].asnumpy()
                 tot += len(x.flatten())
                 if not np.allclose(params[i], x, atol=np.finfo(np.float32).eps):
                     diff = np.abs(x.flatten() - params[i].flatten())
-                    diffs.append(np.max(diff))
                     idx = np.where(diff > np.finfo(np.float32).eps)
                     cnt += len(idx[0])
 
-        print("false=%d tot=%d false / tot = %lf" % (cnt, tot, cnt / tot))
-        if diffs:
-            print("max_diff=%f\tmin_diff=%f\tmean_diff=%f" %
-                  (np.max(diffs), np.min(diffs), np.mean(diffs)))
-
-        assert cnt == 0
+        assert cnt == 0, "false/tot=%d/%d=%f" % (cnt, tot, cnt/tot)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_mxnet.py b/tests/test_mxnet.py
index 9bb2e5db6..7f192915b 100644
--- a/tests/test_mxnet.py
+++ b/tests/test_mxnet.py
@@ -1,3 +1,5 @@
+# Copyright 2020 Amazon Technologies, Inc. All Rights Reserved.
+# Copyright 2019 ByteDance Technologies, Inc. All Rights Reserved.
 # Copyright 2018 Uber Technologies, Inc. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,45 +15,29 @@
 # limitations under the License.
 # ==============================================================================
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+import itertools
+import unittest
 
 import byteps.mxnet as bps
-import itertools
 import mxnet as mx
-import os
 import numpy as np
-import unittest
-from mxnet.base import MXNetError
-from mxnet.test_utils import same
+
+from meta_test import MetaTest
 
 has_gpu = mx.context.num_gpus() > 0
 
-# MLSL supports only byte, float and double data types
-mlsl_supported_types = set(['float32', 'float64'])
 
-class MXTest:
+class MXTest(unittest.TestCase, metaclass=MetaTest):
     """
     Tests for ops in byteps.mxnet.
""" - def _current_context(self): if has_gpu: return mx.gpu(bps.local_rank()) else: return mx.current_context() - - def filter_supported_types(self, types): - if 'MLSL_ROOT' in os.environ: - types = [t for t in types if t in mlsl_supported_types] - return types - + def test_byteps_trainer_param_order(self): - size = bps.size() - dtypes = self.filter_supported_types(['float32']) - dims = [1] - ctx = self._current_context() net = mx.gluon.nn.Sequential() # layers may be added in a random order for all workers layers = {'ones_': 1, 'zeros_': 0} @@ -65,19 +51,19 @@ def test_byteps_trainer_param_order(self): # check the result of bps_broadcast for name, init in layers.items(): weight = params[name + 'weight'].data()[0].asnumpy() - expected = np.full(shape=weight.shape, fill_value=init, dtype=weight.dtype) + expected = np.full(shape=weight.shape, + fill_value=init, dtype=weight.dtype) assert np.array_equal(weight, expected), (weight, expected) print('test_byteps_trainer_param_order passed') def test_byteps_push_pull(self): """Test that the byteps_push_pull correctly sums 1D, 2D, 3D tensors.""" - size = bps.size() - dtypes = self.filter_supported_types(['float32']) - dims = [1] + dtypes = ['float16', 'float32', 'float64'] + dims = [1, 2, 3] + count = 0 ctx = self._current_context() - count = 100 - shapes = [(), (17)] + shapes = [(), (17), (17, 17), (17, 17, 17)] for dtype, dim in itertools.product(dtypes, dims): # MXNet uses gpu_id as part of the seed, so to get identical seeds # we must set a context. @@ -85,24 +71,24 @@ def test_byteps_push_pull(self): tensor = mx.nd.random.uniform(-100, 100, shape=shapes[dim], ctx=ctx) tensor = tensor.astype(dtype) + input = tensor.asnumpy() - print("tensor before push_pull:", tensor) bps.byteps_declare_tensor("tensor_" + str(count)) bps.byteps_push_pull(tensor, name="tensor_"+str(count)) tensor.wait_to_read() - print("tensor after push_pull:", tensor) + output = tensor.asnumpy() + assert np.allclose(input, output) + count += 1 print('test_byteps_push_pull passed') - def test_byteps_push_pull_inplace(self): """Test that the byteps_push_pull correctly sums 1D, 2D, 3D tensors.""" size = bps.size() - dtypes = self.filter_supported_types(['int32', 'int64', - 'float32', 'float64']) + dtypes = ['float16', 'float32', 'float64'] dims = [1, 2, 3] + count = 0 ctx = self._current_context() - count = 200 shapes = [(), (17), (17, 17), (17, 17, 17)] for dtype, dim in itertools.product(dtypes, dims): mx.random.seed(1234, ctx=ctx) @@ -111,7 +97,7 @@ def test_byteps_push_pull_inplace(self): tensor = tensor.astype(dtype) multiplied = tensor.copy() bps.byteps_declare_tensor("tensor_" + str(count)) - bps.byteps_push_pull(tensor, name= "tensor_" + str(count)) + bps.byteps_push_pull(tensor, name="tensor_" + str(count)) max_difference = mx.nd.max(mx.nd.subtract(tensor, multiplied)) count += 1 @@ -136,54 +122,5 @@ def test_byteps_push_pull_inplace(self): print('test_byteps_push_pull_inplace passed') - def test_byteps_broadcast(self): - """Test that the broadcast correctly broadcasts 1D, 2D, 3D tensors.""" - rank = bps.rank() - size = bps.size() - - # This test does not apply if there is only one worker. 
-        if size == 1:
-            return
-
-        dtypes = ['int32', 'int64',
-                  'float32', 'float64']
-        dims = [1, 2, 3]
-        ctx = self._current_context()
-        count = 300
-        shapes = [(), (17), (17, 17), (17, 17, 17)]
-        root_ranks = list(range(size))
-        for dtype, dim, root_rank in itertools.product(dtypes, dims,
-                                                       root_ranks):
-            tensor = mx.nd.ones(shapes[dim], ctx=ctx) * rank
-            root_tensor = mx.nd.ones(shapes[dim], ctx=ctx) * root_rank
-            tensor = tensor.astype(dtype)
-            root_tensor = root_tensor.astype(dtype)
-
-            broadcast_tensor = bps.broadcast(tensor, root_rank=root_rank,
-                                             name=str(count))
-            if rank != root_rank:
-                if same(tensor.asnumpy(), root_tensor.asnumpy()):
-                    print("broadcast", count, dtype, dim,
-                          mx.nd.max(tensor == root_tensor))
-                    print("tensor", bps.rank(), tensor)
-                    print("root_tensor", bps.rank(), root_tensor)
-                    print("comparison", bps.rank(), tensor == root_tensor)
-                assert not same(tensor.asnumpy(), root_tensor.asnumpy()), \
-                    'bps.broadcast modifies source tensor'
-            if not same(broadcast_tensor.asnumpy(), root_tensor.asnumpy()):
-                print("broadcast", count, dtype, dim)
-                print("broadcast_tensor", bps.rank(), broadcast_tensor)
-                print("root_tensor", bps.rank(), root_tensor)
-                print("comparison", bps.rank(),
-                      broadcast_tensor == root_tensor)
-            assert same(broadcast_tensor.asnumpy(), root_tensor.asnumpy()), \
-                'bps.broadcast produces incorrect broadcasted tensor'
-
-
 if __name__ == '__main__':
-    mxtest = MXTest()
-    bps.init()
-    mxtest.test_byteps_push_pull()
-    mxtest.test_byteps_trainer_param_order()
-    #mxtest.test_byteps_broadcast()
-    mxtest.test_byteps_push_pull_inplace()
+    unittest.main()
diff --git a/tests/test_onebit.py b/tests/test_onebit.py
index d995fbfd4..426e803f2 100644
--- a/tests/test_onebit.py
+++ b/tests/test_onebit.py
@@ -1,3 +1,19 @@
+# Copyright 2020 Amazon Technologies, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import itertools
 import unittest
 
 import byteps.mxnet as bps
@@ -6,20 +22,27 @@ import numpy as np
 from gluoncv.model_zoo import get_model
 from mxnet import autograd, gluon
+from parameterized import parameterized
 from tqdm import tqdm
 
+from meta_test import MetaTest
 from utils import fake_data
 
 
-def onebit(x):
-    l1 = np.linalg.norm(x.flatten(), 1)
+def onebit(x, scaling):
+    if scaling:
+        l1 = np.linalg.norm(x.flatten(), 1)
     sign = x < 0
     sign = -((sign << 1) - 1)
-    return l1 / len(x.flatten()) * sign
+    if scaling:
+        return l1 / len(x.flatten()) * sign
+    else:
+        return sign
 
 
-class OnebitTestCase(unittest.TestCase):
-    def test_onebit(self):
+class OnebitTestCase(unittest.TestCase, metaclass=MetaTest):
+    @parameterized.expand(itertools.product([True, False]))
+    def test_onebit(self, scaling):
         bps.init()
         ctx = mx.gpu(0)
         net = get_model("resnet18_v2")
@@ -33,9 +56,7 @@ def test_onebit(self):
 
         compression_params = {
             "compressor": "onebit",
-            # "ef": "vanilla",
-            # "momentum": "nesterov",
-            "scaling": True,
+            "scaling": scaling,
         }
 
         trainer = bps.DistributedTrainer(net.collect_params(
@@ -46,18 +67,10 @@ def test_onebit(self):
         train_data = fake_data(batch_size=batch_size)
 
         params = {}
-        errors = {}
-        errors_s = {}
-        moms = {}
-        wd_moms = {}
 
         for i, param in enumerate(trainer._params):
             if param.grad_req != 'null':
                 params[i] = param._data[0].asnumpy()
-                errors[i] = np.zeros_like(params[i])
-                errors_s[i] = np.zeros_like(params[i])
-                moms[i] = np.zeros_like(params[i])
-                wd_moms[i] = np.zeros_like(params[i])
 
         for it, batch in tqdm(enumerate(train_data)):
             data = batch[0].as_in_context(ctx)
@@ -82,41 +95,25 @@ def test_onebit(self):
             for i, param in enumerate(trainer._params):
                 if param.grad_req != "null":
                     g = gs[i] / (batch_size * bps.size())
-                    # moms[i] *= 0.9
-                    # moms[i] += g
-                    # g += 0.9 * moms[i]
-                    # g += errors[i]
-                    c = onebit(g)
-                    # errors[i] = g - c
-
-                    # c += errors_s[i]
-                    cs = onebit(c)
-                    # errors_s[i] = c - cs
+                    c = onebit(g, scaling)
+
+                    cs = onebit(c, scaling)
                     c = cs
-                    # wd_moms[i] = 0.9 * wd_moms[i] + 1e-4 * xs[i]
-                    # c += 0.9 * wd_moms[i] + 1e-4 * xs[i]
 
                     params[i] -= optimizer_params["learning_rate"] * c
 
         cnt = 0
         tot = 0
-        diffs = []
         for i, param in enumerate(trainer._params):
             if param.grad_req != "null":
                 x = param._data[0].asnumpy()
                 tot += len(x.flatten())
                 if not np.allclose(params[i], x, atol=np.finfo(np.float32).eps):
                     diff = np.abs(x.flatten() - params[i].flatten())
-                    diffs.append(np.max(diff))
                     idx = np.where(diff > np.finfo(np.float32).eps)
                     cnt += len(idx[0])
 
-        print("false=%d tot=%d false / tot = %lf" % (cnt, tot, cnt / tot))
-        if diffs:
-            print("max_diff=%f\tmin_diff=%f\tmean_diff=%f" %
-                  (np.max(diffs), np.min(diffs), np.mean(diffs)))
-
-        assert cnt == 0
+        assert cnt == 0, "false/tot=%d/%d=%f" % (cnt, tot, cnt/tot)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_randomk.py b/tests/test_randomk.py
index 634a2ab97..9c481e7d8 100644
--- a/tests/test_randomk.py
+++ b/tests/test_randomk.py
@@ -1,3 +1,19 @@
+# Copyright 2020 Amazon Technologies, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import itertools
 import unittest
 
 import byteps.mxnet as bps
@@ -10,6 +26,7 @@
 from parameterized import parameterized
 from tqdm import tqdm
 
+from meta_test import MetaTest
 from utils import fake_data, randint
 
 
@@ -27,29 +44,21 @@ def randomk(x, k, state):
     return y.reshape(x.shape)
 
 
-class RandomkTestCase(unittest.TestCase):
-    def setUp(self):
-        print("init")
-        bps.init()
-
-    @parameterized.expand([(1,)])
-    def test_randomk(self, k):
+class RandomkTestCase(unittest.TestCase, metaclass=MetaTest):
+    @parameterized.expand(itertools.product([1, 3, 5], np.random.randint(0, 2020, size=3).tolist()))
+    def test_randomk(self, k, seed):
         ctx = mx.gpu(0)
-        # np.random.seed(2020)
         net = get_model("resnet18_v2")
         net.initialize(mx.init.Xavier(), ctx=ctx)
         net.summary(nd.ones((1, 3, 224, 224), ctx=ctx))
 
         # hyper-params
-        seed = 2020
         batch_size = 32
         optimizer_params = {'momentum': 0, 'wd': 0, 'learning_rate': 0.01}
 
         compression_params = {
             "compressor": "randomk",
-            # "ef": "vanilla",
-            # "momentum": "nesterov",
            "k": k,
            "seed": seed
         }
@@ -62,20 +71,12 @@ def test_randomk(self, k):
         train_data = fake_data(batch_size=batch_size)
 
         params = {}
-        errors = {}
-        errors_s = {}
-        moms = {}
-        wd_moms = {}
         rngs = {}
         rngs_s = {}
 
         for i, param in enumerate(trainer._params):
             if param.grad_req != 'null':
                 params[i] = param._data[0].asnumpy()
-                errors[i] = np.zeros_like(params[i])
-                errors_s[i] = np.zeros_like(params[i])
-                moms[i] = np.zeros_like(params[i])
-                wd_moms[i] = np.zeros_like(params[i])
                 rngs[i] = np.array([seed, seed], dtype=np.uint64)
                 rngs_s[i] = np.array([seed, seed], dtype=np.uint64)
 
@@ -102,40 +103,25 @@ def test_randomk(self, k):
             for i, param in enumerate(trainer._params):
                 if param.grad_req != "null":
                     g = gs[i] / (batch_size * bps.size())
-                    # moms[i] *= 0.9
-                    # moms[i] += g
-                    # g += 0.9 * moms[i]
-                    # g += errors[i]
                     c = randomk(g, k, rngs[i])
-                    # errors[i] = g - c
-                    # c += errors_s[i]
                     cs = randomk(c, k, rngs_s[i])
-                    # errors_s[i] = c - cs
                     c = cs
-                    # c += 1e-4*xs[i]
 
                     params[i] -= optimizer_params["learning_rate"] * c
 
         cnt = 0
         tot = 0
-        diffs = []
         for i, param in enumerate(trainer._params):
             if param.grad_req != "null":
                 x = param._data[0].asnumpy()
                 tot += len(x.flatten())
                 if not np.allclose(params[i], x, atol=np.finfo(np.float32).eps):
                     diff = np.abs(x.flatten() - params[i].flatten())
-                    diffs.append(np.max(diff))
                     idx = np.where(diff > np.finfo(np.float32).eps)
                     cnt += len(idx[0])
 
-        print("false=%d tot=%d false / tot = %lf" % (cnt, tot, cnt / tot))
-        if diffs:
-            print("max_diff=%f\tmin_diff=%f\tmean_diff=%f" %
-                  (np.max(diffs), np.min(diffs), np.mean(diffs)))
-
-        assert cnt == 0
+        assert cnt == 0, "false/tot=%d/%d=%f" % (cnt, tot, cnt/tot)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_topk.py b/tests/test_topk.py
index 6504ca1a4..2a5409b89 100644
--- a/tests/test_topk.py
+++ b/tests/test_topk.py
@@ -1,3 +1,20 @@
+# Copyright 2020 Amazon Technologies, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import itertools
+import random
 import unittest
 
 import byteps.mxnet as bps
@@ -9,6 +26,7 @@
 from parameterized import parameterized
 from tqdm import tqdm
 
+from meta_test import MetaTest
 from utils import fake_data
 
 
@@ -22,12 +40,8 @@ def topk(x, k):
     return y.reshape(x.shape)
 
 
-class TopkTestCase(unittest.TestCase):
-    def setUp(self):
-        print("init")
-        bps.init()
-
-    @parameterized.expand([(1,)])
+class TopkTestCase(unittest.TestCase, metaclass=MetaTest):
+    @parameterized.expand(itertools.product([1, 3, 5]))
     def test_topk(self, k):
         ctx = mx.gpu(0)
         net = get_model("resnet18_v2")
@@ -41,8 +55,6 @@ def test_topk(self, k):
 
         compression_params = {
             "compressor": "topk",
-            # "ef": "vanilla",
-            # "momentum": "nesterov",
             "k": k,
         }
 
@@ -54,18 +66,10 @@ def test_topk(self, k):
         train_data = fake_data(batch_size=batch_size)
 
         params = {}
-        errors = {}
-        errors_s = {}
-        moms = {}
-        wd_moms = {}
 
         for i, param in enumerate(trainer._params):
             if param.grad_req != 'null':
                 params[i] = param._data[0].asnumpy()
-                errors[i] = np.zeros_like(params[i])
-                errors_s[i] = np.zeros_like(params[i])
-                moms[i] = np.zeros_like(params[i])
-                wd_moms[i] = np.zeros_like(params[i])
 
         for it, batch in tqdm(enumerate(train_data)):
             data = batch[0].as_in_context(ctx)
@@ -90,40 +94,25 @@ def test_topk(self, k):
             for i, param in enumerate(trainer._params):
                 if param.grad_req != "null":
                     g = gs[i] / (batch_size * bps.size())
-                    # moms[i] *= 0.9
-                    # moms[i] += g
-                    # g += 0.9 * moms[i]
-                    # g += errors[i]
                     c = topk(g, k)
-                    # errors[i] = g - c
-                    # c += errors_s[i]
                     cs = topk(c, k)
-                    # errors_s[i] = c - cs
                     c = cs
-                    # c += 1e-4*xs[i]
 
                     params[i] -= optimizer_params["learning_rate"] * c
 
         cnt = 0
        tot = 0
-        diffs = []
         for i, param in enumerate(trainer._params):
             if param.grad_req != "null":
                 x = param._data[0].asnumpy()
                 tot += len(x.flatten())
                 if not np.allclose(params[i], x, atol=np.finfo(np.float32).eps):
                     diff = np.abs(x.flatten() - params[i].flatten())
-                    diffs.append(np.max(diff))
                     idx = np.where(diff > np.finfo(np.float32).eps)
                     cnt += len(idx[0])
 
-        print("false=%d tot=%d false / tot = %lf" % (cnt, tot, cnt / tot))
-        if diffs:
-            print("max_diff=%f\tmin_diff=%f\tmean_diff=%f" %
-                  (np.max(diffs), np.min(diffs), np.mean(diffs)))
-
-        assert cnt == 0
+        assert cnt == 0, "false/tot=%d/%d=%f" % (cnt, tot, cnt/tot)
 
 
 if __name__ == '__main__':
diff --git a/tests/utils.py b/tests/utils.py
index 59fa44e35..b3ee93fcf 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -43,7 +43,7 @@ def xorshift128p(state):
 @jit(nopython=True)
 def bernoulli(p, state):
     t = p * np.iinfo(np.uint64).max
-    r = np.array([xorshift128p(state) for _ in range(len(p))], dtype=np.uint64)
+    r = np.array([xorshift128p(state) for _ in range(len(p))], dtype=np.float32)
     return r < t