
Commit

Revert "multi-tensor multi-precision lamb update operator"
This reverts commit 8a8e166.
Rohit Kumar Srivastava committed Nov 8, 2019
1 parent 8a8e166 commit c0508d3
Showing 6 changed files with 26 additions and 549 deletions.
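
For context, the code path this revert restores drives the LAMB step one tensor at a time through the Python Optimizer API. The sketch below is a minimal, hypothetical usage example (it assumes the class is registered in the optimizer registry under the name 'lamb', as the other optimizers in this module are; shapes and hyperparameters are illustrative only):

import mxnet as mx

# Build the optimizer through the registry and allocate its (mean, var) state
# for one parameter index (the 'lamb' registration name is assumed here).
opt = mx.optimizer.create('lamb', learning_rate=0.001)
weight = mx.nd.random.uniform(shape=(4, 4))
grad = mx.nd.random.uniform(shape=(4, 4))
state = opt.create_state(0, weight)   # the (mean, var) pair created in the diff below

# One step: after this revert, update() calls the fused lamb_update kernel
# directly on a single weight/grad pair instead of an aggregated multi-tensor call.
opt.update(0, weight, grad, state)

In training code these calls are normally made indirectly, for example through gluon.Trainer(net.collect_params(), 'lamb', ...).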
python/mxnet/optimizer/optimizer.py (94 changes: 16 additions & 78 deletions)
@@ -34,7 +34,7 @@
                        multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
                        multi_mp_sgd_mom_update, preloaded_multi_sgd_update,
                        preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
-                       preloaded_multi_mp_sgd_mom_update, lamb_update, multi_mp_lamb_update)
+                       preloaded_multi_mp_sgd_mom_update, lamb_update)
 from ..ndarray import sparse
 from ..random import normal
 from ..util import is_np_array
@@ -1259,91 +1259,29 @@ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
         self.upper_bound = upper_bound
         self.bias_correction = bias_correction

-    def create_state_multi_precision(self, index, weight):
-        weight_master_copy = None
-        if self.multi_precision and weight.dtype == numpy.float16:
-            weight_master_copy = weight.astype(numpy.float32)
-            return (self.create_state(index, weight_master_copy), weight_master_copy)
-        if weight.dtype == numpy.float16 and not self.multi_precision:
-            warnings.warn("Accumulating with float16 in optimizer can lead to "
-                          "poor accuracy or slow convergence. "
-                          "Consider using multi_precision=True option of the "
-                          "LAMB optimizer")
-        return self.create_state(index, weight)

     def create_state(self, index, weight):
         stype = weight.stype
         dtype = weight.dtype
         return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype),
                 zeros(weight.shape, weight.context, dtype=dtype, stype=stype))

-    def _update_impl(self, indices, weights, grads, states, multi_precision=False):
-        aggregate = True
-        if not isinstance(indices, (tuple, list)):
-            indices = [indices]
-            weights = [weights]
-            grads = [grads]
-            states = [states]
-        for weight, grad in zip(weights, grads):
-            assert(isinstance(weight, NDArray))
-            assert(isinstance(grad, NDArray))
-            aggregate = (aggregate and
-                         weight.stype == 'default' and
-                         grad.stype == 'default')
-        self._update_count(indices)
-        lrs = self._get_lrs(indices)
-        wds = self._get_wds(indices)
-        print("+++++++++++++++++indices={}+++++++++++++++".format(indices))
-        for idx in indices:
-            t = self._index_update_count[idx]
-
-        kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
-                  'lower_bound': self.lower_bound, 'upper_bound': self.upper_bound,
-                  'bias_correction': self.bias_correction, 't': t,
-                  'rescale_grad': self.rescale_grad}
-
-        if self.clip_gradient:
-            kwargs['clip_gradient'] = self.clip_gradient
-
-        if multi_precision:
-            multi_mp_lamb_update(*_flatten_list(zip(weights, grads,
-                                                    list(zip(*states))[1])),
-                                 out=weights, num_weights=len(weights),
-                                 lrs=lrs, wds=wds, **kwargs)
-        else:
-            for weight, grad, state, lr, wd in zip(weights, grads, states, lrs, wds):
-                mean, var = state
-                lamb_update(weight, grad, mean, var, out=weight, lr=lr, wd=wd, **kwargs)

-    def update(self, index, weight, grad, state):
-        self._update_impl(index, weight, grad, state, multi_precision=False)
+    def update(self, index, weight, grad, state):
+        assert(isinstance(weight, NDArray))
+        assert(isinstance(grad, NDArray))
+        self._update_count(index)
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+        t = self._index_update_count[index]

-    def update_multi_precision(self, index, weight, grad, state):
-        if not isinstance(index, (tuple, list)):
-            use_multi_precision = self.multi_precision and weight.dtype == numpy.float16
-        else:
-            use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16
-        print("========================use_multi_precision={}".format(use_multi_precision))
-        self._update_impl(index, weight, grad, state,
-                          multi_precision=use_multi_precision)
+        kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
+                  'lower_bound': self.lower_bound, 'upper_bound': self.upper_bound,
+                  'bias_correction': self.bias_correction, 't': t,
+                  'rescale_grad': self.rescale_grad}
+        if self.clip_gradient:
+            kwargs['clip_gradient'] = self.clip_gradient

-    # def update(self, index, weight, grad, state):
-    #     assert(isinstance(weight, NDArray))
-    #     assert(isinstance(grad, NDArray))
-    #     self._update_count(index)
-    #     lr = self._get_lr(index)
-    #     wd = self._get_wd(index)
-    #     t = self._index_update_count[index]
-    #
-    #     kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
-    #               'lower_bound': self.lower_bound, 'upper_bound': self.upper_bound,
-    #               'bias_correction': self.bias_correction, 't': t,
-    #               'rescale_grad': self.rescale_grad}
-    #     if self.clip_gradient:
-    #         kwargs['clip_gradient'] = self.clip_gradient
-    #
-    #     mean, var = state
-    #     lamb_update(weight, grad, mean, var, out=weight, lr=lr, wd=wd, **kwargs)
+        mean, var = state
+        lamb_update(weight, grad, mean, var, out=weight, lr=lr, wd=wd, **kwargs)


 # pylint: enable=line-too-long
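
For reference, below is a rough NumPy sketch of the per-tensor update rule that the fused lamb_update operator used above is understood to implement, following the LAMB paper (You et al., 2019). The function name, the zero-norm fallback, and the exact placement of the lower_bound/upper_bound clipping of the trust ratio are illustrative assumptions, not taken from this commit:

import numpy as np

def lamb_step(weight, grad, mean, var, t, lr, wd,
              beta1=0.9, beta2=0.999, epsilon=1e-6,
              lower_bound=None, upper_bound=None, bias_correction=False,
              rescale_grad=1.0, clip_gradient=None):
    """Reference (non-fused) LAMB step on a single NumPy tensor; updates in place."""
    g = grad * rescale_grad
    if clip_gradient is not None:
        g = np.clip(g, -clip_gradient, clip_gradient)

    # Adam-style first and second moment estimates.
    mean[:] = beta1 * mean + (1.0 - beta1) * g
    var[:] = beta2 * var + (1.0 - beta2) * g * g
    m_hat, v_hat = mean, var
    if bias_correction:
        m_hat = mean / (1.0 - beta1 ** t)
        v_hat = var / (1.0 - beta2 ** t)

    # Raw update direction, including decoupled weight decay.
    r = m_hat / (np.sqrt(v_hat) + epsilon) + wd * weight

    # Layer-wise trust ratio ||w|| / ||r||, with optional clipping of the weight norm.
    w_norm = np.linalg.norm(weight)
    if lower_bound is not None:
        w_norm = max(w_norm, lower_bound)
    if upper_bound is not None:
        w_norm = min(w_norm, upper_bound)
    r_norm = np.linalg.norm(r)
    trust = w_norm / r_norm if w_norm > 0 and r_norm > 0 else 1.0

    weight[:] = weight - lr * trust * r
    return weight

The multi-tensor multi_mp_lamb_update path removed by this commit would have applied the same rule to a whole list of weight/grad pairs in a single call, carrying a float32 master copy of each float16 weight.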
(The remaining 5 changed files in this commit are not shown.)

