Skip to content

Commit

Permalink
1bit: remove the weight-decay momentum (wdmom) wrapper used by the 1-bit compressor
Browse files Browse the repository at this point in the history
  • Loading branch information
jasperzhong committed Jun 20, 2020
1 parent 7f66e90 commit f888c8d
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 45 deletions.
8 changes: 1 addition & 7 deletions byteps/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,6 @@ def _register_compressor(self, params, optimizer_params, compression_params):

# change
if compression_params.get("momentum"):
# 1bit compressor use an additional momentum for weight decay
if compressor == "onebit" and "wd" in optimizer_params:
intra_compressor = Compression.wdmom(
intra_compressor, optimizer_params["momentum"], optimizer_params["wd"])
del optimizer_params["wd"]

del optimizer_params['momentum']

return intra_compressor
Expand All @@ -308,7 +302,7 @@ def _allreduce_grads(self):
byteps_push_pull(compressed, is_average=False,
name="gradient_" + str(i), priority=-i)
param._grad[0] = self._intra_compressors[i].decompress(
compressed, ctx, x=param._data[0])
compressed, ctx)

def _init_params(self):
tensors = []
Expand Down
38 changes: 0 additions & 38 deletions byteps/mxnet/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,41 +65,6 @@ def decompress(self, tensor, ctx, *args, **kwargs):
return tensor_decompressed


class WeightDecayMomentum(Compressor):
    """Wrapper compressor that adds a weight-decay momentum term to the
    decompressed gradient. Used with the 1-bit compressor.

    It delegates actual compression/decompression to the wrapped
    ``compressor`` and, when the parameter tensor is supplied to
    ``decompress`` via ``x=...``, folds an extra momentum of the
    weight-decay term into the gradient before decompressing.
    """

    def __init__(self, compressor, mu, wd, *args, **kwargs):
        # compressor: the underlying Compressor instance being wrapped.
        # mu: momentum coefficient for the weight-decay momentum state.
        # wd: weight-decay coefficient applied to the parameter tensor.
        self.compressor = compressor
        # Momentum state; lazily allocated on the first decompress call
        # so it can match the gradient tensor's shape/context.
        self.mom = None
        # Scratch buffer holding wd * x; same lazy allocation.
        self.cache = None
        self.mu = mu
        self.wd = wd

    def compress(self, tensor, *args, **kwargs):
        """Delegate compression to the wrapped compressor.

        NOTE(review): extra ``*args``/``**kwargs`` are accepted but not
        forwarded to the wrapped compressor — confirm this is intended.
        """
        return self.compressor.compress(tensor)

    def decompress(self, tensor, ctx, *args, **kwargs):
        """Add the weight-decay momentum to ``tensor`` (in place), then
        delegate to the wrapped compressor's ``decompress``.

        Requires the parameter tensor as keyword argument ``x``; without
        it this is a plain pass-through decompress. What the code
        computes is:

            cache  = wd * x
            mom    = mu * (mom + cache)
            tensor = tensor + mom + cache

        NOTE(review): that is m_t = mu * (m_{t-1} + wd * x_t), which
        differs from the m_t = mu * m_{t-1} + wd * x_t form stated in
        the original comment — confirm which update rule was intended.
        Both ``tensor`` and the internal state are mutated in place.
        """
        if "x" not in kwargs:
            # No parameter tensor available: nothing to accumulate.
            return self.compressor.decompress(tensor, ctx)

        x = kwargs["x"]

        if self.mom is None:
            # Lazily allocate state buffers shaped like the gradient.
            self.mom = nd.zeros_like(tensor)
            self.cache = nd.zeros_like(tensor)

        # cache <- wd * x, written in place into the scratch buffer
        # (private mxnet API; avoids allocating a temporary).
        nd._internal._mul_scalar(x, self.wd, out=self.cache)
        self.mom += self.cache
        # mom <- mu * mom, in place.
        nd._internal._mul_scalar(self.mom, self.mu, out=self.mom)
        # Mutates the caller's gradient tensor in place.
        tensor += self.mom + self.cache
        return self.compressor.decompress(tensor, ctx)


class Compression(object):
"""Optional gradient compression algorithm used during push_pull."""

Expand All @@ -109,9 +74,6 @@ class Compression(object):
"""Compress all floating point gradients to 16-bit."""
fp16 = FP16Compressor()

"""Additional Momentum for weight decay. This is only for 1bit. This is a wrapper."""
wdmom = WeightDecayMomentum


# if __name__ == "__main__":
# x = WeightDecayMomentum(Compression.none, 0.9, 1e-4)
Expand Down

0 comments on commit f888c8d

Please sign in to comment.