From dc44a2dc9d08c91b3af4e00e21bc627e63ea1c6c Mon Sep 17 00:00:00 2001
From: Mingzhi Hu <49382651+y199387@users.noreply.github.com>
Date: Wed, 24 Aug 2022 08:19:29 +0800
Subject: [PATCH] Nano: update ipex_bf16_inference_model (#5470)

* rollback requirement-doc

* Update

* Update

* Update

* Update

* Fix ipex with amp error

* Fix unit test

* Update

* Remove redundant method

* Fix error

* remove redundant import

* Update ut to cover cases
---
 .../deps/ipex/ipex_inference_bf16_model.py    | 32 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 .../src/bigdl/nano/pytorch/amp/bfloat16.py    | 37 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 .../nano/test/pytorch/tests/test_bf16_ipex.py | 40 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------
 3 files changed, 98 insertions(+), 11 deletions(-)

diff --git a/python/nano/src/bigdl/nano/deps/ipex/ipex_inference_bf16_model.py b/python/nano/src/bigdl/nano/deps/ipex/ipex_inference_bf16_model.py
index efd2d87b6a9..47ac051449a 100644
--- a/python/nano/src/bigdl/nano/deps/ipex/ipex_inference_bf16_model.py
+++ b/python/nano/src/bigdl/nano/deps/ipex/ipex_inference_bf16_model.py
@@ -14,6 +14,12 @@
 # limitations under the License.
 #
 
+import contextlib
+import subprocess
+from logging import info, warning
+
+from ...utils.log4Error import invalidInputError
+
 from .ipex_inference_model import PytorchIPEXJITModel
 from bigdl.nano.pytorch.amp.bfloat16 import autocast
 import torch
@@ -38,13 +44,35 @@ def __init__(self, model, input_sample=None, use_ipex=False,
         the parameter will be ignored if use_ipex is False.
         :param from_load: this will only be set by _load method.
         '''
+        if use_ipex:
+            invalidInputError(
+                self._check_cpu_isa,
+                errMsg="Applying IPEX BF16 optimization requires a CPU that supports AVX512.",
+                fixMsg="Please set use_ipex to False or do not set precision to bf16."
+            )
         PytorchIPEXJITModel.__init__(self, model, input_sample=input_sample,
                                      use_ipex=use_ipex, dtype=torch.bfloat16, use_jit=use_jit,
                                      channels_last=channels_last, from_load=from_load)
 
-    @autocast()
+    @property
+    def _check_cpu_isa(self):
+        """Whether the CPU supports AVX512 (or AMX)."""
+        msg = subprocess.check_output(["lscpu"]).decode("utf-8")
+        return 'avx512' in msg or 'amx' in msg
+
+    def autocast_context_manager(self):
+        """Create an autocast context manager, enabled only on supported hardware."""
+        return autocast(enabled=self._check_cpu_isa)
+
+    @contextlib.contextmanager
+    def forward_context(self):
+        """Run the wrapped code under the autocast context."""
+        with self.autocast_context_manager():
+            yield
+
     def forward_step(self, *inputs):
-        return super().forward_step(*inputs)
+        with self.forward_context():
+            return super().forward_step(*inputs)
 
     @property
     def status(self):
diff --git a/python/nano/src/bigdl/nano/pytorch/amp/bfloat16.py b/python/nano/src/bigdl/nano/pytorch/amp/bfloat16.py
index df934a2238a..080e6906809 100644
--- a/python/nano/src/bigdl/nano/pytorch/amp/bfloat16.py
+++ b/python/nano/src/bigdl/nano/pytorch/amp/bfloat16.py
@@ -17,6 +17,9 @@
 import contextlib
 import io
 from logging import info, warning
+import os
+import sys
+import fcntl
 
 from pytorch_lightning import LightningModule
 import torch
@@ -46,6 +49,40 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         return super().__exit__(exc_type, exc_val, exc_tb)
 
 
+class RedirectStream(object):
+    """Context manager to capture output written by a shared library."""
+    def __init__(self, stream=sys.stdout, target=None):
+        self.origin_stream = stream
+        self.origin_stream_fileno = stream.fileno()
+        self.target = io.StringIO() if target is None else target
+        # Create a pipe to capture the stream
+        self.pipe_out, self.pipe_in = os.pipe()
+
+    def __enter__(self):
+        # Save a copy of the original stream
+        self.origin_stream_fileno_dup = os.dup(self.origin_stream_fileno)
+        # Replace the original stream with the write pipe
+        os.dup2(self.pipe_in, self.origin_stream_fileno)
+        os.close(self.pipe_in)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.origin_stream.flush()
+        # Make pipe_out non-blocking so the drain loop cannot hang
+        fcntl.fcntl(self.pipe_out, fcntl.F_SETFL, os.O_NONBLOCK)
+        while True:
+            try:
+                buf = os.read(self.pipe_out, 1024)
+                if not buf:
+                    break
+                self.target.write(buf.decode('utf-8'))
+            except OSError:
+                break
+        os.close(self.pipe_out)
+        os.dup2(self.origin_stream_fileno_dup, self.origin_stream_fileno)
+        os.close(self.origin_stream_fileno_dup)
+
+
 class BF16Model(LightningModule):
     """Model of BFloat16 with auto mixed precision."""
 
diff --git a/python/nano/test/pytorch/tests/test_bf16_ipex.py b/python/nano/test/pytorch/tests/test_bf16_ipex.py
index 38e45ce165d..f6ad3c5e2f3 100644
--- a/python/nano/test/pytorch/tests/test_bf16_ipex.py
+++ b/python/nano/test/pytorch/tests/test_bf16_ipex.py
@@ -19,7 +19,8 @@
 from bigdl.nano.pytorch import Trainer
 from torchvision.models.resnet import resnet18
 from unittest.mock import MagicMock, Mock, PropertyMock, patch
-from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_10, TORCH_VERSION_LESS_1_12
+from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_10, TORCH_VERSION_LESS_1_11
+from bigdl.nano.common import check_avx512
 
 
 class Pytorch1_9:
@@ -35,11 +36,28 @@ def test_bf16_pytorch_less_1_10(self):
         trainer.quantize(model, precision='bf16', use_ipex=True)
 
 
-class Pytorch1_12:
-    def test_bf16_common(self):
-        """
-        Debug mode. Allow run bf16 forward without bf16 instruction support.
-        """
+class CaseWithoutAVX512:
+    def test_unsupported_HW_or_OS(self):
+        trainer = Trainer(max_epochs=1)
+        model = resnet18(num_classes=10)
+
+        with pytest.raises(RuntimeError,
+                           match="Applying IPEX BF16 optimization requires a CPU that supports AVX512."):
+            bf16_model = trainer.quantize(model, precision='bf16', use_ipex=True)
+
+
+class Pytorch1_11:
+    @patch('bigdl.nano.deps.ipex.ipex_inference_bf16_model.PytorchIPEXJITBF16Model._check_cpu_isa', new_callable=PropertyMock)
+    def test_unsupported_HW_or_OS(self, mocked_check_cpu_isa):
+        mocked_check_cpu_isa.return_value = False
+        trainer = Trainer(max_epochs=1)
+        model = resnet18(num_classes=10)
+
+        with pytest.raises(RuntimeError,
+                           match="Applying IPEX BF16 optimization requires a CPU that supports AVX512."):
+            bf16_model = trainer.quantize(model, precision='bf16', use_ipex=True)
+
+    def test_bf16_with_avx512_core(self):
 
         trainer = Trainer(max_epochs=1)
         model = resnet18(num_classes=10)
@@ -47,18 +65,22 @@ def test_bf16_common(self):
         x = torch.rand((10, 256, 256, 3))
         y = torch.ones((10,), dtype=torch.long)
 
         bf16_model = trainer.quantize(model, precision='bf16', use_ipex=True)
 
-        # Debug mode to test functionality, make sure forward is called sucessfully
         y_hat = bf16_model(x)
+        assert y_hat.shape == (10, 10) and y_hat.dtype == torch.bfloat16
 
 
-TORCH_VERSION_CLS = Pytorch1_12
+TORCH_VERSION_CLS = Pytorch1_11
 if TORCH_VERSION_LESS_1_10:
+    print("ipex 1.9")
     TORCH_VERSION_CLS = Pytorch1_9
+elif not check_avx512():
+    print("IPEX Inference Model Without AVX512")
+    TORCH_VERSION_CLS = CaseWithoutAVX512
 
 
-class TestBF16(TORCH_VERSION_CLS, TestCase):
+class TestIPEXBF16(TORCH_VERSION_CLS, TestCase):
     pass
 
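--
Usage note (an illustrative sketch, not part of the patch): the new RedirectStream
helper captures output that bypasses Python's sys.stdout, e.g. log lines printed
by a native shared library such as IPEX during initialization, which an ordinary
sys.stdout swap cannot catch. A minimal example, assuming Linux (the
implementation relies on os.pipe and fcntl); the message written here is made up
for demonstration:

    import os
    import sys

    from bigdl.nano.pytorch.amp.bfloat16 import RedirectStream

    with RedirectStream(stream=sys.stdout) as captured:
        # Write straight to file descriptor 1, bypassing sys.stdout,
        # just like a C/C++ library would.
        os.write(1, b"hello from a shared library\n")

    # The original stdout is restored on exit; the low-level output
    # ended up in the StringIO target instead of the terminal.
    assert captured.target.getvalue() == "hello from a shared library\n"

Since the pipe is only drained in __exit__, the capture suits short bursts of
output: a writer producing more than the finite pipe buffer before the context
exits would block.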