Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR AMP]Adapt more amp uts in PIR #62880

Merged
merged 1 commit into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions test/amp/amp_base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import paddle
from paddle import nn
from paddle.base import core
from paddle.framework import in_dynamic_mode
from paddle.framework import in_dynamic_or_pir_mode


def copy_bits_from_float_to_uint16(f):
Expand Down Expand Up @@ -68,7 +68,7 @@ def _build_optimizer(
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
else:
grad_clip = None
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
assert model is not None
parameters = model.parameters()
else:
Expand All @@ -82,7 +82,7 @@ def _build_optimizer(
epsilon=1e-4,
weight_decay=0.01,
)
if not in_dynamic_mode() and use_amp:
if not in_dynamic_or_pir_mode() and use_amp:
optimizer = paddle.static.amp.decorate(
optimizer,
amp_lists,
Expand Down Expand Up @@ -178,7 +178,7 @@ def forward(self, x):
def build_conv_model(
use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
):
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
model = SimpleConvNet()
optimizer = _build_optimizer(use_amp=False, model=model)
if use_amp and amp_dtype == "float16":
Expand Down
141 changes: 141 additions & 0 deletions test/amp/test_amp_promote.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,100 @@ def test_o2_promote_off(self):
)


@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 7.0,
"run test when gpu's compute capability is at least 7.0.",
)
class TestPirAmpPromoteStats(AmpTestBase):
def check_promote_results(
self, dtype, level, use_promote, expected_op_calls, debug_info
):
with paddle.pir_utils.IrGuard():
startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
model, optimizer, scaler = build_conv_model(
use_amp=True,
amp_dtype=dtype,
amp_level=level,
use_promote=use_promote,
)
model.train()

with paddle.amp.auto_cast(
enable=True,
dtype=dtype,
level=level,
use_promote=use_promote,
):
x = paddle.static.data(
'x', shape=[1, 1, 6, 6], dtype='float32'
)
out = model(x)
loss = paddle.mean(out)
scaled = scaler.scale(loss)
scaler.minimize(optimizer, scaled)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
paddle.amp.debugging.enable_operator_stats_collection()
exe.run(
main,
feed={
'x': np.random.random([1, 1, 6, 6]).astype('float32'),
},
fetch_list=[loss],
)
paddle.amp.debugging.disable_operator_stats_collection()
op_stats = paddle.base.core.get_low_precision_op_list()

self._check_op_calls(
op_stats,
expected_fp16_calls=expected_op_calls,
debug_info=debug_info,
)

def test_o2_promote_on(self):
paddle.set_flags({"FLAGS_pir_apply_inplace_pass": 0})
expected_fp16_calls = {
"pd_op.conv2d": 1,
"pd_op.add": 2,
"pd_op.relu": 0,
"pd_op.matmul": 1,
"pd_op.softmax": 1,
"pd_op.mean": 1,
"pd_op.adamw_": 4,
}
self.check_promote_results(
'float16',
'O2',
use_promote=True,
expected_op_calls=expected_fp16_calls,
debug_info="TestEagerAmpPromoteStats/test_o2_promote_on",
)

def test_o2_promote_off(self):
paddle.set_flags({"FLAGS_pir_apply_inplace_pass": 0})
expected_fp16_calls = {
"pd_op.conv2d": 1,
"pd_op.add": 2,
"pd_op.relu": 1,
"pd_op.matmul": 1,
"pd_op.softmax": 1,
"pd_op.mean": 1,
"pd_op.adamw_": 4,
}
self.check_promote_results(
'float16',
'O2',
use_promote=False,
expected_op_calls=expected_fp16_calls,
debug_info="TestEagerAmpPromoteStats/test_o2_promote_off",
)


@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 7.0,
Expand Down Expand Up @@ -220,5 +314,52 @@ def test_o2_use_promote_off(self):
self.assertEqual(linear_out.dtype, paddle.float16)


@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 7.0,
"run test when gpu's compute capability is at least 7.0.",
)
class TestPirAmpPromoteSimple(AmpTestBase):
def init_net(self):
self._conv = paddle.nn.Conv2D(
in_channels=1, out_channels=6, kernel_size=3, bias_attr=False
)
self._linear = paddle.nn.Linear(in_features=4, out_features=4)

def test_o2_use_promote_on(self):
with paddle.pir_utils.IrGuard():
startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
self.init_net()
with paddle.amp.auto_cast(level='O2'):
x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32')
conv_out = self._conv(x)
y = paddle.rand(shape=conv_out.shape, dtype='float16')
add_out = conv_out + y
linear_out = self._linear(add_out)

self.assertEqual(conv_out.dtype, paddle.float16)
self.assertEqual(add_out.dtype, paddle.float16)
self.assertEqual(linear_out.dtype, paddle.float32)

def test_o2_use_promote_off(self):
with paddle.pir_utils.IrGuard():
startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
self.init_net()
with paddle.amp.auto_cast(level='O2', use_promote=False):
x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32')
conv_out = self._conv(x)
y = paddle.rand(shape=conv_out.shape, dtype='float16')
add_out = conv_out + y
linear_out = self._linear(add_out)

self.assertEqual(conv_out.dtype, paddle.float16)
self.assertEqual(add_out.dtype, paddle.float16)
self.assertEqual(linear_out.dtype, paddle.float16)


if __name__ == '__main__':
unittest.main()
85 changes: 84 additions & 1 deletion test/amp/test_collect_operator_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import unittest

import numpy as np
from amp_base_models import build_while_model

import paddle
Expand All @@ -38,7 +39,7 @@ def _check_result(self, dtype):
self.assertTrue(conv_num == 1)
self.assertTrue(add_num == 1)

if dtype == "float16":
if dtype == paddle.float16:
self.assertTrue(int(conv2d_called[0]) == 1)
self.assertTrue(int(add_called[0]) == 1)

Expand Down Expand Up @@ -67,6 +68,88 @@ def test_context(self):
self._check_result(dtype=out.dtype)


class TestOpStatsPir(unittest.TestCase):
def _check_result(self, dtype):
# Returned the dict.
op_list = paddle.base.core.get_low_precision_op_list()

self.assertTrue('pd_op.add' in op_list)
self.assertTrue('pd_op.conv2d' in op_list)

conv2d_called = op_list['pd_op.conv2d'].split(',')
add_called = op_list['pd_op.add'].split(',')
add_num = 0
conv_num = 0
for i in range(4):
add_num += int(add_called[i])
conv_num += int(add_called[i])

self.assertTrue(conv_num == 1)
self.assertTrue(add_num == 1)

if dtype == paddle.float16:
self.assertTrue(int(conv2d_called[0]) == 1)
self.assertTrue(int(add_called[0]) == 1)

def test_enable_disable(self):
if not paddle.is_compiled_with_cuda():
return
paddle.set_flags({"FLAGS_pir_apply_inplace_pass": 0})
with paddle.pir_utils.IrGuard():
startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
conv = paddle.nn.Conv2D(3, 2, 3)
x = paddle.static.data('x', [10, 3, 32, 32], 'float32')

with paddle.amp.auto_cast(enable=True, level='O2'):
out = conv(x)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
paddle.amp.debugging.enable_operator_stats_collection()
exe.run(
main,
feed={
'x': np.random.random([10, 3, 32, 32]).astype(
'float32'
),
},
fetch_list=[out],
)
paddle.amp.debugging.disable_operator_stats_collection()
self._check_result(dtype=out.dtype)

def test_context(self):
if not paddle.is_compiled_with_cuda():
return
paddle.set_flags({"FLAGS_pir_apply_inplace_pass": 0})
with paddle.pir_utils.IrGuard():
startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
conv = paddle.nn.Conv2D(3, 2, 3)
x = paddle.static.data('x', [10, 3, 32, 32], 'float32')
with paddle.amp.auto_cast(enable=True, level='O2'):
out = conv(x)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
with paddle.amp.debugging.collect_operator_stats():
exe.run(
main,
feed={
'x': np.random.random([10, 3, 32, 32]).astype(
'float32'
),
},
fetch_list=[out],
)
self._check_result(dtype=out.dtype)


class TestOpStatsStatic(unittest.TestCase):
def test_while_op(self):
paddle.enable_static()
Expand Down
80 changes: 78 additions & 2 deletions test/amp/test_compare_accuracy_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@

import unittest

import numpy as np

import paddle
from paddle.base import core


@unittest.skipIf(
not core.is_compiled_with_cuda(), "not support cpu TestCompareAccuracyApi"
not core.is_compiled_with_cuda(),
"not support cpu TestEagerCompareAccuracyApi",
)
class TestCompareAccuracyApi(unittest.TestCase):
class TestEagerCompareAccuracyApi(unittest.TestCase):
def calc(self, path, dtype):
paddle.base.core.set_nan_inf_debug_path(path)
x = paddle.to_tensor(
Expand Down Expand Up @@ -67,5 +70,78 @@ def test2(self):
)


@unittest.skipIf(
not core.is_compiled_with_cuda(),
"not support cpu TestPirCompareAccuracyApi",
)
class TestPirCompareAccuracyApi(unittest.TestCase):
def calc(self, path, dtype):
paddle.base.core.set_nan_inf_debug_path(path)
with paddle.pir_utils.IrGuard():
startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
x = paddle.static.data(
'x',
[
4,
],
dtype,
)
y = paddle.static.data(
'y',
[
4,
],
dtype,
)
# normal
z1 = x + y
# inf
z2 = x * y
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
exe.run(
main,
feed={
'x': np.array([2000, 3000, 4, 0]).astype(dtype),
'y': np.array([100, 500, 2, 10000]).astype(dtype),
},
fetch_list=[z2],
)

def test(self):
paddle.set_flags(
{"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 3}
)
fp32_path = "workerlog_fp32_log_dir"
fp16_path = "workerlog_fp16_log_dir"
self.calc(fp32_path, "float32")
self.calc(fp16_path, "float16")

out_excel = "compare_accuracy_out_excel.csv"
paddle.amp.debugging.compare_accuracy(
fp32_path,
fp16_path,
out_excel,
loss_scale=1,
dump_all_tensors=False,
)

def test2(self):
fp32_path = "workerlog_fp32_log_dir"
fp16_path = "workerlog_fp16_null_log_dir"
self.calc(fp32_path, "float32")
out_excel = "compare_accuracy_out_excel_2.csv"
paddle.amp.debugging.compare_accuracy(
fp32_path,
fp16_path,
out_excel,
loss_scale=1,
dump_all_tensors=False,
)


if __name__ == '__main__':
unittest.main()