test: full test coverage #53

Merged: 13 commits, Aug 12, 2020
20 changes: 11 additions & 9 deletions byteps/common/compressor/impl/dithering.cc
@@ -33,12 +33,14 @@ CompressorRegistry::Register reg(
       auto seed = HyperParamFinder<unsigned>(kwargs, "seed", true,
                                              [](unsigned x) { return x != 0; });
 
-      auto ptype_int = HyperParamFinder<int>(
-          kwargs, "partition", true, [](int x) { return x == 0 || x == 1; });
+      auto ptype_int =
+          HyperParamFinder<int>(kwargs, "dithering_partition", true,
+                                [](int x) { return x == 0 || x == 1; });
       auto ptype = static_cast<DitheringCompressor::PartitionType>(ptype_int);
 
-      auto ntype_int = HyperParamFinder<int>(
-          kwargs, "normalize", true, [](int x) { return x == 0 || x == 1; });
+      auto ntype_int =
+          HyperParamFinder<int>(kwargs, "dithering_normalize", true,
+                                [](int x) { return x == 0 || x == 1; });
       auto ntype = static_cast<DitheringCompressor::NomalizeType>(ntype_int);
 
       return std::unique_ptr<Compressor>(
@@ -85,11 +87,11 @@ tensor_t DitheringCompressor::CompressImpl(index_t* dst, const scalar_t* src,
   const unsigned level = 1 << (_s - 1);
   for (size_t i = 0; i < len; ++i) {
     float abs_x = std::abs(src[i]);
-    float normalized = (abs_x / scale) * level;
-    unsigned low = RoundNextPow2(std::ceil(normalized)) >> 1;
-    unsigned length = (low != 0) ? low : 1;
-    unsigned quantized =
-        low + length * _rng.Bernoulli((normalized - low) / length);
+    double normalized = (abs_x / scale) * level;
+    unsigned floor = RoundNextPow2(std::ceil(normalized)) >> 1;
+    unsigned length = (floor != 0) ? floor : 1;
+    double p = (normalized - floor) / length;
+    unsigned quantized = floor + length * _rng.Bernoulli(p);
     if (quantized) {
       size_t diff = i - last_non_zero_pos;
       last_non_zero_pos = i;
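
The rewritten loop keeps `normalized` in double precision and names the quantities explicitly: `floor` is half the next power of two above `normalized` (the lower quantization level), and a Bernoulli draw with probability `p` decides whether to round up by `length`. Since E[quantized] = floor + length * p = normalized, the quantizer stays unbiased. A minimal NumPy sketch of this step (illustrative names, not BytePS API; assumes `RoundNextPow2` rounds up to the next power of two):

import numpy as np

def round_next_pow2(v):
    # Round each uint32 up to the next power of two (0 stays 0),
    # mirroring the bit-twiddling helper on the C++ side.
    v = v.astype(np.uint32) - np.uint32(1)
    for shift in (1, 2, 4, 8, 16):
        v |= v >> np.uint32(shift)
    return v + np.uint32(1)

def natural_dithering_step(src, scale, s, rand):
    # Stochastic power-of-two quantization of |src| / scale, as in the
    # updated CompressImpl loop; rand(shape) draws uniforms in [0, 1).
    level = 1 << (s - 1)
    normalized = (np.abs(src) / scale) * level
    floor = round_next_pow2(np.ceil(normalized).astype(np.uint32)) >> np.uint32(1)
    length = np.where(floor != 0, floor, 1)
    p = (normalized - floor) / length          # probability of rounding up
    return floor + length * (rand(src.shape) < p)

rng = np.random.default_rng(2020)
x = np.array([0.05, -0.4, 0.9], dtype=np.float32)
print(natural_dithering_step(x, scale=0.9, s=3, rand=rng.random))
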
2 changes: 1 addition & 1 deletion byteps/common/compressor/impl/dithering.h
@@ -76,7 +76,7 @@ class DitheringCompressor : public Compressor {
                      const index_t* compressed, size_t compressed_size);
 
   /*! \brief number of levels */
-  unsigned int _s;
+  const unsigned int _s;
 
   PartitionType _ptype;
   NomalizeType _ntype;
2 changes: 1 addition & 1 deletion byteps/common/compressor/utils.h
@@ -87,7 +87,7 @@ class XorShift128PlusBitShifterRNG {
   double Rand() { return double(xorshift128p()) / MAX; }
 
   // Bernoulli distribution
-  bool Bernoulli(double p) { return xorshift128p() < uint64_t(p * MAX); }
+  bool Bernoulli(double p) { return xorshift128p() < p * MAX; }
 
   void set_seed(uint64_t seed) { _state = {seed, seed}; }
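
Dropping the `uint64_t` cast fixes an edge case: for `p` near 1, `p * MAX` rounds up to 2^64 as a double, and converting that back to `uint64_t` is undefined behavior (in practice it can wrap to 0), so `Bernoulli(1.0)` could always return false. The new version compares in floating point instead. A small Python model of the difference, assuming `MAX` is the generator's largest output, 2^64 - 1:

MAX = float(2**64 - 1)               # rounds to 2.0**64 as a double

def bernoulli_old(draw, p):
    # the old cast; the out-of-range conversion is modeled here as a wrap
    return draw < int(p * MAX) % 2**64

def bernoulli_new(draw, p):
    # the new comparison stays in floating point
    return draw < p * MAX

draw = 2**63
print(bernoulli_old(draw, 1.0))      # False: threshold wrapped to 0
print(bernoulli_new(draw, 1.0))      # True
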

22 changes: 22 additions & 0 deletions byteps/mxnet/__init__.py
@@ -227,6 +227,12 @@ def __init__(self, params, optimizer, optimizer_params=None, root_rank=0, compre
                 )
                 byteps_declare_tensor("gradient_" + str(i), **byteps_params)
 
+    def __del__(self):
+        if local_rank() == 0:
+            self._f.close()
+            if os.path.exists("lr.s"):
+                os.remove("lr.s")
+
     def _register_compressor(self, params, optimizer_params, compression_params):
         """Register compressor for BytePS
 
@@ -274,6 +280,22 @@ def _register_compressor(self, params, optimizer_params, compression_params):
                 if compression_params.get("seed", None) is not None:
                     setattr(param, "byteps_seed",
                             compression_params["seed"])
+
+                if compression_params.get("partition"):
+                    if compression_params["partition"] == "linear":
+                        setattr(param, "byteps_dithering_partition", "0")
+                    elif compression_params["partition"] == "natural":
+                        setattr(param, "byteps_dithering_partition", "1")
+                    else:
+                        raise ValueError("Unsupported partition")
+
+                if compression_params.get("normalize"):
+                    if compression_params["normalize"] == "max":
+                        setattr(param, "byteps_dithering_normalize", "0")
+                    elif compression_params["normalize"] == "l2":
+                        setattr(param, "byteps_dithering_normalize", "1")
+                    else:
+                        raise ValueError("Unsupported normalization")
 
         # the following code will delete some items in `optimizer_params`
         # to avoid duplication
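
These branches translate the user-facing strings into the `byteps_dithering_*` attributes consumed by the C++ registry above. A sketch of how they are exercised (mirroring tests/test_dithering.py; `net` and `optimizer_params` as defined there):

compression_params = {
    "compressor": "dithering",
    "k": 2,
    "partition": "natural",   # stored as byteps_dithering_partition = "1"
    "normalize": "max",       # stored as byteps_dithering_normalize = "0"
    "seed": 2020,
}
trainer = bps.DistributedTrainer(net.collect_params(), "sgd",
                                 optimizer_params,
                                 compression_params=compression_params)
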
85 changes: 85 additions & 0 deletions tests/meta_test.py
@@ -0,0 +1,85 @@
# Copyright 2020 Amazon Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import copy
import time
import os
import subprocess
import sys
import threading

import byteps.mxnet as bps


class MetaTest(type):
    BASE_ENV = {"DMLC_NUM_WORKER": "1",
                "DMLC_NUM_SERVER": "1",
                "DMLC_PS_ROOT_URI": "127.0.0.1",
                "DMLC_PS_ROOT_PORT": "1234",
                "BYTEPS_LOG_LEVEL": "INFO",
                "BYTEPS_MIN_COMPRESS_BYTES": "0",
                "BYTEPS_PARTITION_BYTES": "2147483647"}
    for name, value in os.environ.items():
        if name not in BASE_ENV:
            BASE_ENV[name] = value
    SCHEDULER_ENV = copy.copy(BASE_ENV)
    SCHEDULER_ENV.update(DMLC_ROLE="scheduler")
    SERVER_ENV = copy.copy(BASE_ENV)
    SERVER_ENV.update(DMLC_ROLE="server")

    def __new__(cls, name, bases, dict):
        # decorate all test cases
        for k, v in dict.items():
            if k.startswith("test_") and hasattr(v, "__call__"):
                dict[k] = cls.launch_bps(v)

        for k, v in cls.BASE_ENV.items():
            os.environ[k] = v
        os.environ["NVIDIA_VISIBLE_DEVICES"] = "0"
        os.environ["DMLC_WORKER_ID"] = "0"
        os.environ["DMLC_ROLE"] = "worker"
        os.environ["BYTEPS_THREADPOOL_SIZE"] = "4"
        os.environ["BYTEPS_FORCE_DISTRIBUTED"] = "1"
        os.environ["BYTEPS_LOCAL_RANK"] = "0"
        os.environ["BYTEPS_LOCAL_SIZE"] = "1"
        return type(name, bases, dict)

    @classmethod
    def launch_bps(cls, func):
        def wrapper(*args, **kwargs):
            def run(env):
                subprocess.check_call(args=["bpslaunch"], shell=True,
                                      stdout=sys.stdout, stderr=sys.stderr,
                                      env=env)

            print("bps init")
            scheduler = threading.Thread(target=run,
                                         args=(cls.SCHEDULER_ENV,))
            server = threading.Thread(target=run, args=(cls.SERVER_ENV,))
            scheduler.daemon = True
            server.daemon = True
            scheduler.start()
            server.start()

            bps.init()
            func(*args, **kwargs)
            bps.shutdown()

            scheduler.join()
            server.join()
            print("bps shutdown")
            time.sleep(2)

        return wrapper
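
The metaclass wraps every `test_*` method with `launch_bps`, so each case runs between `bps.init()` and `bps.shutdown()` while a scheduler and a server run in daemon threads, with roles and ports taken from `BASE_ENV`. A minimal, hypothetical test module built on it:

import unittest

import byteps.mxnet as bps

from meta_test import MetaTest


class ExampleTestCase(unittest.TestCase, metaclass=MetaTest):
    def test_world_size(self):
        # Runs inside launch_bps: BytePS is already initialized here,
        # and BASE_ENV declares a single worker.
        assert bps.size() == 1


if __name__ == "__main__":
    unittest.main()
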
12 changes: 2 additions & 10 deletions tests/run_byteps_test.sh
@@ -2,7 +2,7 @@
 
 path="$(dirname $0)"
 
-export PATH=~/.local/bin:$PATH
+export PATH=~/anaconda3/envs/mxnet_p36/bin:$PATH
 export DMLC_NUM_WORKER=1
 export DMLC_NUM_SERVER=1
 export DMLC_PS_ROOT_URI=127.0.0.1
@@ -32,17 +32,9 @@ export BYTEPS_THREADPOOL_SIZE=4
 export BYTEPS_FORCE_DISTRIBUTED=1
 export BYTEPS_LOG_LEVEL=WARNING
 
-if [ "$TEST_TYPE" == "mxnet" ]; then
-    echo "TEST MXNET ..."
-    bpslaunch python3 $path/test_mxnet.py $@
-elif [ "$TEST_TYPE" == "keras" ]; then
+if [ "$TEST_TYPE" == "keras" ]; then
     echo "TEST KERAS ..."
     python $path/test_tensorflow_keras.py $@
-elif [ "$TEST_TYPE" == "onebit" ] || [ "$TEST_TYPE" == "topk" ] || [ "$TEST_TYPE" == "randomk" ] || [ "$TEST_TYPE" == "dithering" ]; then
-    export BYTEPS_MIN_COMPRESS_BYTES=0
-    export BYTEPS_PARTITION_BYTES=2147483647
-    echo "TEST $TEST_TYPE"
-    bpslaunch python3 test_$TEST_TYPE.py
 else
     echo "Error: unsupported $TEST_TYPE"
     exit 1
73 changes: 36 additions & 37 deletions tests/test_dithering.py
@@ -1,3 +1,20 @@
+# Copyright 2020 Amazon Technologies, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import copy
+import itertools
 import unittest
 
 import byteps.mxnet as bps
@@ -10,6 +27,7 @@
 from parameterized import parameterized
 from tqdm import tqdm
 
+from meta_test import MetaTest
 from utils import bernoulli, fake_data


@@ -25,7 +43,6 @@ def round_next_pow2(v):
     return v
 
 
-# partition: 'linear' or 'natural'
 def dithering(x, k, state, partition='linear', norm="max"):
     y = x.flatten()
     if norm == "max":
@@ -48,7 +65,7 @@ def dithering(x, k, state, partition='linear', norm="max"):
     elif partition == "natural":
         y *= 2**(k-1)
         low = round_next_pow2((np.ceil(y).astype(np.uint32))) >> 1
-        length = low.copy()
+        length = copy.deepcopy(low)
         length[length == 0] = 1
         p = (y - low) / length
         y = low + length * bernoulli(p, state)
@@ -62,33 +79,27 @@ def dithering(x, k, state, partition='linear', norm="max"):
     return y.reshape(x.shape)
 
 
-class DitheringTestCase(unittest.TestCase):
-    def setUp(self):
-        print("init")
-        bps.init()
-
-    @parameterized.expand([(2, "natural", "max"),])
-    def test_dithering(self, k, ptype, ntype):
+class DitheringTestCase(unittest.TestCase, metaclass=MetaTest):
+    @parameterized.expand(itertools.product([2, 4, 8], ["linear", "natural"], ["max", "l2"], np.random.randint(0, 2020, size=3).tolist()))
+    def test_dithering(self, k, ptype, ntype, seed):
         ctx = mx.gpu(0)
         net = get_model("resnet18_v2")
         net.initialize(mx.init.Xavier(), ctx=ctx)
         net.summary(nd.ones((1, 3, 224, 224), ctx=ctx))
 
         # hyper-params
-        seed = 2020
         batch_size = 32
         optimizer_params = {'momentum': 0, 'wd': 0,
                             'learning_rate': 0.01}
 
         compression_params = {
             "compressor": "dithering",
-            # "ef": "vanilla",
-            # "momentum": "nesterov",
             "k": k,
             "partition": ptype,
             "normalize": ntype,
             "seed": seed
         }
+        print(compression_params)
 
         trainer = bps.DistributedTrainer(net.collect_params(
         ), "sgd", optimizer_params, compression_params=compression_params)
@@ -98,20 +109,12 @@ def test_dithering(self, k, ptype, ntype):
         train_data = fake_data(batch_size=batch_size)
 
         params = {}
-        errors = {}
-        errors_s = {}
-        moms = {}
-        wd_moms = {}
         rngs = {}
         rngs_s = {}
 
         for i, param in enumerate(trainer._params):
            if param.grad_req != 'null':
                 params[i] = param._data[0].asnumpy()
-                errors[i] = np.zeros_like(params[i])
-                errors_s[i] = np.zeros_like(params[i])
-                moms[i] = np.zeros_like(params[i])
-                wd_moms[i] = np.zeros_like(params[i])
                 rngs[i] = np.array([seed, seed], dtype=np.uint64)
                 rngs_s[i] = np.array([seed, seed], dtype=np.uint64)

@@ -138,41 +141,37 @@ def test_dithering(self, k, ptype, ntype):
             for i, param in enumerate(trainer._params):
                 if param.grad_req != "null":
                     g = gs[i] / (batch_size * bps.size())
-                    # print("norm2", norm2(g.flatten())/k)
-                    # moms[i] *= 0.9
-                    # moms[i] += g
-                    # g += 0.9 * moms[i]
-                    # g += errors[i]
                     c = dithering(g, k, rngs[i], ptype, ntype)
-                    # errors[i] = g - c
 
-                    # c += errors_s[i]
                     cs = dithering(c, k, rngs_s[i], ptype, ntype)
-                    # errors_s[i] = c - cs
                     c = cs
 
-                    # c += 1e-4*xs[i]
                     params[i] -= optimizer_params["learning_rate"] * c
 
                     np_g = c.flatten()
                     mx_g = param._grad[0].asnumpy().flatten()
                     if not np.allclose(np_g, mx_g, atol=np.finfo(np.float32).eps):
                         diff = np.abs(np_g - mx_g)
-                        print("np", np_g)
-                        print("mx", mx_g)
-                        print("diff", diff)
+                        print("max diff", np.max(diff))
+                        idx = np.nonzero(diff > 1e-5)
+                        print("idx", idx, np_g[idx], mx_g[idx])
+                        input()
 
         cnt = 0
         tot = 0
+        diffs = []
         for i, param in enumerate(trainer._params):
             if param.grad_req != "null":
                 x = param._data[0].asnumpy()
                 tot += len(x.flatten())
                 if not np.allclose(params[i], x, atol=np.finfo(np.float32).eps):
                     diff = np.abs(x.flatten() - params[i].flatten())
+                    diffs.append(np.max(diff))
                     idx = np.where(diff > np.finfo(np.float32).eps)
                     cnt += len(idx[0])
 
-        print("false=%d tot=%d false / tot = %lf" % (cnt, tot, cnt / tot))
+        if diffs:
+            print("max_diff=%f\tmin_diff=%f\tmean_diff=%f" %
+                  (np.max(diffs), np.min(diffs), np.mean(diffs)))
 
-        assert cnt == 0
+        assert cnt == 0, "false/tot=%d/%d=%f" % (cnt, tot, cnt/tot)
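
A note on the structure of the check above: each gradient goes through the reference model twice, once with the worker stream (rngs) and again with a second stream (rngs_s), which presumably mirrors the server re-compressing the aggregated tensor before broadcasting it back. A hypothetical standalone version of that double pass, reusing `dithering` from this file:

state_w = np.array([2020, 2020], dtype=np.uint64)  # worker RNG state
state_s = np.array([2020, 2020], dtype=np.uint64)  # second-stage RNG state
g = np.random.randn(16).astype(np.float32)
c = dithering(g, 2, state_w, "natural", "max")     # worker-side compression
c = dithering(c, 2, state_s, "natural", "max")     # second-stage re-compression
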


if __name__ == '__main__':