
Commit

Merge pull request NVIDIA#17 from rohithkrn/fused_opt_bf16
Enable bfloat16 for optimizers
sunway513 authored May 28, 2020
2 parents 5cfdc01 + 8554990 commit 38ade0a
Showing 8 changed files with 69 additions and 35 deletions.
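
In practical terms, parameters kept in bfloat16 can now be handed straight to the fused optimizers, where previously only fp16 and fp32 were accepted. A minimal usage sketch, not part of the commit: it assumes a CUDA device and an apex build that includes these changes, and the layer size and hyperparameters are arbitrary.

import torch
import apex

# Sketch only: a module cast to bfloat16, optimized with apex's FusedAdam.
model = torch.nn.Linear(1024, 1024).cuda().to(torch.bfloat16)
optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=5e-4)

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()   # toy loss computed in fp32
loss.backward()                         # produces bf16 grads for bf16 params
optimizer.step()                        # dispatches to the fused 16-bit path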
4 changes: 2 additions & 2 deletions apex/optimizers/fused_adagrad.py
@@ -91,7 +91,7 @@ def step(self, closure=None):
                 if len(state) == 0:
                     # Exponential moving average of gradient values
                     state['sum'] = torch.zeros_like(p.data)
-                if p.dtype == torch.float16:
+                if p.dtype in {torch.float16, torch.bfloat16}:
                     g_16.append(p.grad.data)
                     p_16.append(p.data)
                     h_16.append(state['sum'])
@@ -100,7 +100,7 @@ def step(self, closure=None):
                     p_32.append(p.data)
                     h_32.append(state['sum'])
                 else:
-                    raise RuntimeError('FusedAdagrad only support fp16 and fp32.')
+                    raise RuntimeError('FusedAdagrad only support fp16, bfloat16 and fp32.')

             if(len(g_16) > 0):
                 multi_tensor_applier(self.multi_tensor_adagrad,
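
Each of the four fused optimizers touched in this commit follows the same bucketing pattern: 16-bit parameters (fp16 and, after this change, bf16) are collected into one set of lists and fp32 parameters into another before the multi-tensor kernel is launched. A standalone sketch of that idea (illustrative only, not apex source; the helper name is made up):

import torch

# Hypothetical helper: partition parameters into the 16-bit bucket (fp16/bf16)
# and the fp32 bucket, mirroring the g_16/p_16 vs. g_32/p_32 lists built
# inside each fused optimizer's step().
def bucket_by_dtype(params):
    g_16, p_16, g_32, p_32 = [], [], [], []
    for p in params:
        if p.grad is None:
            continue
        if p.dtype in {torch.float16, torch.bfloat16}:
            g_16.append(p.grad.data)
            p_16.append(p.data)
        elif p.dtype == torch.float32:
            g_32.append(p.grad.data)
            p_32.append(p.data)
        else:
            raise RuntimeError('only fp16, bfloat16 and fp32 are supported')
    return (g_16, p_16), (g_32, p_32)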
4 changes: 2 additions & 2 deletions apex/optimizers/fused_adam.py
@@ -130,7 +130,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p.data)

-                if p.dtype == torch.float16:
+                if p.dtype in {torch.float16, torch.bfloat16}:
                     g_16.append(p.grad.data)
                     p_16.append(p.data)
                     m_16.append(state['exp_avg'])
@@ -141,7 +141,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
                     m_32.append(state['exp_avg'])
                     v_32.append(state['exp_avg_sq'])
                 else:
-                    raise RuntimeError('FusedAdam only support fp16 and fp32.')
+                    raise RuntimeError('FusedAdam only support fp16, bfloat16 and fp32.')

             if(len(g_16) > 0):
                 multi_tensor_applier(self.multi_tensor_adam,
4 changes: 2 additions & 2 deletions apex/optimizers/fused_lamb.py
@@ -130,7 +130,7 @@ def step(self, closure=None):
                     # Exponential moving average of gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p.data)

-                if p.dtype == torch.float16:
+                if p.dtype in {torch.float16, torch.bfloat16}:
                     g_16.append(p.grad.data)
                     p_16.append(p.data)
                     m_16.append(state['exp_avg'])
@@ -141,7 +141,7 @@ def step(self, closure=None):
                     m_32.append(state['exp_avg'])
                     v_32.append(state['exp_avg_sq'])
                 else:
-                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')
+                    raise RuntimeError('FusedLAMB only support fp16, bfloat16 and fp32.')

             if(len(g_16) > 0):
                 multi_tensor_applier(self.multi_tensor_lamb,
4 changes: 2 additions & 2 deletions apex/optimizers/fused_novograd.py
@@ -142,7 +142,7 @@ def step(self, closure=None):
                     # Exponential moving average of gradient values
                     state['exp_avg'] = torch.zeros_like(p.data)

-                if p.dtype == torch.float16:
+                if p.dtype in {torch.float16, torch.bfloat16}:
                     g_16.append(p.grad.data)
                     p_16.append(p.data)
                     m_16.append(state['exp_avg'])
@@ -151,7 +151,7 @@ def step(self, closure=None):
                     p_32.append(p.data)
                     m_32.append(state['exp_avg'])
                 else:
-                    raise RuntimeError('FusedNovoGrad only support fp16 and fp32.')
+                    raise RuntimeError('FusedNovoGrad only support fp16, bfloat16 and fp32.')

             # we store per weight norm as one tensor for one group/precision combination
             # different from optim.Adam, we store norm here(not ^2) so we can unify calculation for norm types
4 changes: 2 additions & 2 deletions csrc/layer_norm_cuda_kernel.cu
@@ -690,7 +690,7 @@ void cuda_layer_norm(
     double epsilon)
 {
   using namespace at;
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
+  DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(input->scalar_type(), 0, "layer_norm_cuda_kernel",
       using accscalar_t = at::acc_type<scalar_t_0, true>;
       HostApplyLayerNorm(
         output->DATA_PTR<scalar_t_0>(),
@@ -793,7 +793,7 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_beta)
 {
   using namespace at;
-  DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(input->scalar_type(), 0, "cuComputeGradInput",
       using accscalar_t = at::acc_type<scalar_t_0, true>;
       HostLayerNormGradient(
         dout->DATA_PTR<scalar_t_0>(),
2 changes: 1 addition & 1 deletion csrc/multi_tensor_adagrad.cu
@@ -90,7 +90,7 @@ void multi_tensor_adagrad_cuda(
   using namespace at;

   // Assume single type across p,g,h now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
+  DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
       tensor_lists[0][0].scalar_type(), 0, "adagrad",
       multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                             AdagradFunctor<scalar_t_0>(), epsilon, lr,
41 changes: 29 additions & 12 deletions tests/L0/run_optimizers/test_adagrad.py
@@ -14,22 +14,28 @@ def setUp(self, max_abs_diff=1e-6, max_rel_diff=1, iters=7):
     def tearDown(self):
         pass

-    def gen_param_optim(self, tensors, adagrad_option):
+    def gen_param_optim(self, tensors, adagrad_option, apex_only=False):
         ref_param = []
         tst_param = []
         for tensor in tensors:
-            ref_param.append(torch.nn.Parameter(tensor.clone()))
+            if apex_only:
+                ref_param.append(torch.nn.Parameter(tensor.clone().float()))
+            else:
+                ref_param.append(torch.nn.Parameter(tensor.clone()))
             tst_param.append(torch.nn.Parameter(tensor.clone()))

-        ref_optim = torch.optim.Adagrad(ref_param, **adagrad_option)
+        if apex_only:
+            ref_optim = apex.optimizers.FusedAdagrad(ref_param, **adagrad_option)
+        else:
+            ref_optim = torch.optim.Adagrad(ref_param, **adagrad_option)
         tst_optim = apex.optimizers.FusedAdagrad(tst_param, **adagrad_option)

         return (ref_param, tst_param, ref_optim, tst_optim)

-    def gen_grad(self, ref_param, tst_param):
+    def gen_grad(self, ref_param, tst_param, apex_only=False):
         for p_ref, p_tst in zip(ref_param, tst_param):
-            p_ref.grad = torch.rand_like(p_ref)
-            p_tst.grad = p_ref.grad
+            p_tst.grad = torch.rand_like(p_tst)
+            p_ref.grad = p_tst.grad.detach().float() if apex_only else p_tst.grad

     def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
         half_grads = []
@@ -38,9 +44,11 @@ def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
             p_ref.grad = half_grads[-1].float() / scale
         return half_grads

-    def get_max_diff(self, ref_param, tst_param):
+    def get_max_diff(self, ref_param, tst_param, apex_only=False):
         max_abs_diff = max_rel_diff = 0
         for p_ref, p_tst in zip(ref_param, tst_param):
+            if apex_only:
+                p_tst = p_tst.float()
             max_abs_diff_p = (p_ref - p_tst).abs().max().item()
             max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()

@@ -51,23 +59,24 @@ def get_max_diff(self, ref_param, tst_param):

         return max_abs_diff, max_rel_diff

-    def gen_single_type_test(self, param_type=torch.float):
+    def gen_single_type_test(self, param_type=torch.float, apex_only=False):
         nelem = 278011
         adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 1.0e-5}

         tensor = torch.rand(nelem, dtype=param_type, device="cuda")
         ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
-            [tensor], adagrad_option
+            [tensor], adagrad_option, apex_only=apex_only
         )

         for _ in range(self.iters):
-            self.gen_grad(ref_param, tst_param)
+            self.gen_grad(ref_param, tst_param, apex_only=apex_only)
             ref_optim.step()
             tst_optim.step()
-            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
+            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param, apex_only=apex_only)

             self.assertLessEqual(max_abs_diff, self.max_abs_diff)
-            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
+            if not apex_only:
+                self.assertLessEqual(max_rel_diff, self.max_rel_diff)

     def test_float(self):
         self.gen_single_type_test(param_type=torch.float)
@@ -76,6 +85,14 @@ def test_float(self):
     def test_half(self):
         self.gen_single_type_test(param_type=torch.float16)

+    # Compares bfloat16 computation against float32 as gold standard.
+    # Uses apex optimizers(controlled by apex_only flag) for both types.
+    # Doesn't use upstream optimizer like other tests as they seem to be
+    # numerically unstable for half types(see skip note for test above).
+    def test_bfloat16(self):
+        self.max_abs_diff = 1e-2
+        self.gen_single_type_test(param_type=torch.bfloat16, apex_only=True)
+
     def test_multi_params(self):
         sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
         adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0}
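
A note on the new test_bfloat16 tolerances: bfloat16 keeps only about 8 mantissa bits, so values in [0, 1) already differ from their float32 counterparts by roughly 2**-9 before any optimizer arithmetic, which is presumably why the test relaxes max_abs_diff to 1e-2 and skips the relative-difference check. A quick self-contained illustration (not part of the commit):

import torch

# Round-trip precision of bfloat16 on the same value range the test uses.
x = torch.rand(278011)                  # test tensors come from torch.rand(...)
x_bf16 = x.to(torch.bfloat16).float()
print((x - x_bf16).abs().max().item())  # on the order of 2**-9, i.e. ~2e-3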
41 changes: 29 additions & 12 deletions tests/L0/run_optimizers/test_adam.py
@@ -15,22 +15,28 @@ def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
     def tearDown(self):
         pass

-    def gen_param_optim(self, tensors, adam_option):
+    def gen_param_optim(self, tensors, adam_option, apex_only=False):
         ref_param = []
         tst_param = []
         for tensor in tensors:
-            ref_param.append(torch.nn.Parameter(tensor.clone()))
+            if apex_only:
+                ref_param.append(torch.nn.Parameter(tensor.clone().float()))
+            else:
+                ref_param.append(torch.nn.Parameter(tensor.clone()))
             tst_param.append(torch.nn.Parameter(tensor.clone()))

-        ref_optim = torch.optim.Adam(ref_param, **adam_option)
+        if apex_only:
+            ref_optim = apex.optimizers.FusedAdam(ref_param, **adam_option)
+        else:
+            ref_optim = torch.optim.Adam(ref_param, **adam_option)
         tst_optim = apex.optimizers.FusedAdam(tst_param, **adam_option)

         return (ref_param, tst_param, ref_optim, tst_optim)

-    def gen_grad(self, ref_param, tst_param):
+    def gen_grad(self, ref_param, tst_param, apex_only=False):
         for p_ref, p_tst in zip(ref_param, tst_param):
-            p_ref.grad = torch.rand_like(p_ref)
-            p_tst.grad = p_ref.grad
+            p_tst.grad = torch.rand_like(p_tst)
+            p_ref.grad = p_tst.grad.detach().float() if apex_only else p_tst.grad

     def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
         half_grads = []
@@ -39,9 +45,11 @@ def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
             p_ref.grad = half_grads[-1].float() / scale
         return half_grads

-    def get_max_diff(self, ref_param, tst_param):
+    def get_max_diff(self, ref_param, tst_param, apex_only=False):
         max_abs_diff = max_rel_diff = 0
         for p_ref, p_tst in zip(ref_param, tst_param):
+            if apex_only:
+                p_tst = p_tst.float()
             max_abs_diff_p = (p_ref - p_tst).abs().max().item()
             max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()

@@ -50,30 +58,39 @@ def get_max_diff(self, ref_param, tst_param):

         return max_abs_diff, max_rel_diff

-    def gen_single_type_test(self, param_type=torch.float):
+    def gen_single_type_test(self, param_type=torch.float, apex_only=False):
         nelem = 278011
         adam_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08,
                        'weight_decay':0, 'amsgrad':False}

         tensor = torch.rand(nelem, dtype=param_type, device='cuda')
         ref_param, tst_param, ref_optim, tst_optim = \
-            self.gen_param_optim([tensor], adam_option)
+            self.gen_param_optim([tensor], adam_option, apex_only=apex_only)

         for i in range(self.iters):
-            self.gen_grad(ref_param, tst_param)
+            self.gen_grad(ref_param, tst_param, apex_only=apex_only)
             ref_optim.step()
             tst_optim.step()
-            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
+            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param, apex_only=apex_only)

             self.assertLessEqual(max_abs_diff, self.max_abs_diff)
-            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
+            if not apex_only:
+                self.assertLessEqual(max_rel_diff, self.max_rel_diff)

     def test_float(self):
         self.gen_single_type_test(param_type=torch.float)

     def test_half(self):
         self.gen_single_type_test(param_type=torch.float16)

+    # Compares bfloat16 computation against float32 as gold standard.
+    # Uses apex optimizers(controlled by apex_only flag) for both types.
+    # Doesn't use upstream optimizer like other tests as they seem to be
+    # numerically unstable for half types
+    def test_bfloat16(self):
+        self.max_abs_diff = 1e-2
+        self.gen_single_type_test(param_type=torch.bfloat16, apex_only=True)
+
     @unittest.skip('Disable until 8/1/2019 adam/adamw upstream picked')
     def test_multi_params(self):
         sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
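
Condensed, the apex_only path exercised by test_bfloat16 amounts to the sketch below: the bf16 parameters are stepped by FusedAdam and compared, after an fp32 cast, against an fp32 copy stepped by the same fused optimizer. Illustrative only; it needs CUDA and an apex build with these changes, and the tensor size and iteration count are arbitrary.

import torch
import apex

adam_option = {"lr": 5e-4, "betas": (0.9, 0.999), "eps": 1e-8, "weight_decay": 0}

tensor = torch.rand(4096, device="cuda", dtype=torch.bfloat16)
tst_param = torch.nn.Parameter(tensor.clone())          # bfloat16 copy under test
ref_param = torch.nn.Parameter(tensor.clone().float())  # fp32 gold standard

tst_optim = apex.optimizers.FusedAdam([tst_param], **adam_option)
ref_optim = apex.optimizers.FusedAdam([ref_param], **adam_option)

for _ in range(7):
    tst_param.grad = torch.rand_like(tst_param)          # bf16 gradient
    ref_param.grad = tst_param.grad.detach().float()     # same values in fp32
    tst_optim.step()
    ref_optim.step()

max_abs_diff = (ref_param.detach() - tst_param.detach().float()).abs().max().item()
assert max_abs_diff <= 1e-2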
