
Merge pull request NVIDIA#11 from lcskrishna/cl/fused-optimizers-bfp16
[FusedOptimizers] Bug fixes in fused optimizers for fp16/bfp16.
sunway513 authored May 22, 2020
2 parents bdd481d + 9297be6 commit 5cfdc01
Showing 3 changed files with 1 addition and 9 deletions.
csrc/multi_tensor_sgd_kernel.cu (2 changes: 1 addition & 1 deletion)
@@ -271,7 +271,7 @@ void multi_tensor_sgd_cuda(
       scale);
   }
   // Case 5. bfp16, bfp16, bfp16, No
-  if(grad_type == at::ScalarType::BFloat16 &&
+  else if(grad_type == at::ScalarType::BFloat16 &&
      weight_type == at::ScalarType::BFloat16 &&
      num_tensors == 3)
   {
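Why that one-line change matters: multi_tensor_sgd_cuda selects a kernel through a chain of mutually exclusive type cases, and chains like this commonly end in an error branch for unsupported combinations. Written as a bare "if", Case 5 is detached from the chain, so a call already handled by an earlier case (say, fp16) re-tests the bfp16 condition, fails it, and falls into the error branch. A minimal sketch of the control flow, in Python for brevity (ScalarType, the case bodies, and the final error are simplified stand-ins, not the apex source):

from enum import Enum

class ScalarType(Enum):
    HALF = 0
    BFLOAT16 = 1

def multi_tensor_sgd_dispatch(grad_type, weight_type, num_tensors):
    # Case 4 (simplified): fp16 grads, fp16 weights.
    if grad_type is ScalarType.HALF and weight_type is ScalarType.HALF and num_tensors == 3:
        return "fp16 kernel"
    # Case 5: bfp16 grads, bfp16 weights. Before the fix this was a bare
    # "if": an fp16 call handled above would also re-test this condition,
    # fail it, and hit the error branch below. "elif" (C++ "else if")
    # keeps the cases mutually exclusive.
    elif grad_type is ScalarType.BFLOAT16 and weight_type is ScalarType.BFLOAT16 and num_tensors == 3:
        return "bfp16 kernel"
    else:
        raise RuntimeError("unsupported grad/weight type combination")

assert multi_tensor_sgd_dispatch(ScalarType.HALF, ScalarType.HALF, 3) == "fp16 kernel"
assert multi_tensor_sgd_dispatch(ScalarType.BFLOAT16, ScalarType.BFLOAT16, 3) == "bfp16 kernel"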
tests/L0/run_amp/test_fused_sgd.py (4 changes: 0 additions & 4 deletions)
@@ -54,7 +54,6 @@ def tearDown(self):
         pass
 
     @unittest.skipIf(disabled, "amp_C is unavailable")
-    @skipIfRocm
     def test_2models2losses1optimizer(self):
         model0 = MyModel(1)
         model1 = MyModel(2)
@@ -187,7 +186,6 @@ def test_2models2losses1optimizer(self):
         _amp_state.handle._deactivate()
 
     @unittest.skipIf(disabled, "amp_C is unavailable")
-    @skipIfRocm
     def test_3models2losses1optimizer(self):
 
         model0 = MyModel(1)
@@ -349,7 +347,6 @@ def test_3models2losses1optimizer(self):
         _amp_state.handle._deactivate()
 
     @unittest.skipIf(disabled, "amp_C is unavailable")
-    @skipIfRocm
     def test_2models2losses2optimizers(self):
         model0 = MyModel(1)
         model1 = MyModel(2)
@@ -545,7 +542,6 @@ def what_got_skipped(which_iter, which_backward):
         _amp_state.handle._deactivate()
 
     @unittest.skipIf(disabled, "amp_C is unavailable")
-    @skipIfRocm
     def test_3models2losses2optimizers(self):
         model0 = MyModel(1)
         model1 = MyModel(2)
tests/L0/run_amp/test_multiple_models_optimizers_losses.py (4 changes: 0 additions & 4 deletions)
@@ -44,7 +44,6 @@ def setUp(self):
     def tearDown(self):
         pass
 
-    @skipIfRocm
     def test_2models2losses1optimizer(self):
         model0 = MyModel(1)
         model1 = MyModel(2)
@@ -170,7 +169,6 @@ def test_2models2losses1optimizer(self):
         if opt_level == "O1":
             _amp_state.handle._deactivate()
 
-    @skipIfRocm
     def test_3models2losses1optimizer(self):
 
         model0 = MyModel(1)
@@ -327,7 +325,6 @@ def test_3models2losses1optimizer(self):
         if opt_level == "O1":
             _amp_state.handle._deactivate()
 
-    @skipIfRocm
     def test_2models2losses2optimizers(self):
         model0 = MyModel(1)
         model1 = MyModel(2)
@@ -518,7 +515,6 @@ def what_got_skipped(which_iter, which_backward):
         if opt_level == "O1":
             _amp_state.handle._deactivate()
 
-    @skipIfRocm
     def test_3models2losses2optimizers(self):
         model0 = MyModel(1)
         model1 = MyModel(2)
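The two test files carry the companion change: every deleted line is a @skipIfRocm decorator, re-enabling these AMP tests on ROCm builds now that the bfp16 dispatch is fixed. The helper itself is outside this diff; in PyTorch-family test suites it is typically a thin wrapper over unittest.skipIf, roughly as below (a hypothetical reconstruction; the actual helper imported by these tests may differ):

import unittest

import torch

# Hypothetical reconstruction, not the repo's helper: torch.version.hip is
# set on ROCm builds of PyTorch and is None on CUDA builds.
TEST_WITH_ROCM = getattr(torch.version, "hip", None) is not None

def skipIfRocm(fn):
    """Skip the decorated test when running on a ROCm build of PyTorch."""
    return unittest.skipIf(TEST_WITH_ROCM, "test skipped on ROCm")(fn)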
