Nano: update ipex_bf16_inference_model (#5470)
* rollback requirement-doc

* Update

* Update

* Update

* Update

* Fix ipex with amp error

* Fix unit test

* Update

* Remove redundant method

* Fix error

* remove redundant import

* Update ut to cover cases
y199387 authored Aug 24, 2022
1 parent 5552783 commit dc44a2d
Showing 3 changed files with 97 additions and 11 deletions.
32 changes: 30 additions & 2 deletions python/nano/src/bigdl/nano/deps/ipex/ipex_inference_bf16_model.py
@@ -14,6 +14,12 @@
 # limitations under the License.
 #
 
+import contextlib
+import subprocess
+from logging import info, warning
+
+from ...utils.log4Error import invalidInputError
+
 from .ipex_inference_model import PytorchIPEXJITModel
 from bigdl.nano.pytorch.amp.bfloat16 import autocast
 import torch
@@ -38,13 +44,35 @@ def __init__(self, model, input_sample=None, use_ipex=False,
                the parameter will be ignored if use_ipex is False.
         :param from_load: this will only be set by _load method.
         '''
+        if use_ipex:
+            invalidInputError(
+                self._check_cpu_isa,
+                errMsg="Applying IPEX BF16 optimization needs the cpu support avx512.",
+                fixMsg="Please set use_ipex to False or not set precision to bf16."
+            )
         PytorchIPEXJITModel.__init__(self, model, input_sample=input_sample, use_ipex=use_ipex,
                                      dtype=torch.bfloat16, use_jit=use_jit,
                                      channels_last=channels_last, from_load=from_load)
 
-    @autocast()
+    @property
+    def _check_cpu_isa(self):
+        """Indicator to verify if cpu supports avx512"""
+        msg = subprocess.check_output(["lscpu"]).decode("utf-8")
+        return 'avx512' in msg or 'amx' in msg
+
+    def autocast_context_manager(self):
+        """Create autocast context"""
+        return autocast(enabled=self._check_cpu_isa)
+
+    @contextlib.contextmanager
+    def forward_context(self):
+        """Enable autocast context"""
+        with self.autocast_context_manager():
+            yield
+
     def forward_step(self, *inputs):
-        return super().forward_step(*inputs)
+        with self.forward_context():
+            return super().forward_step(*inputs)
 
     @property
     def status(self):
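The change above gates BF16 execution on the ISA check twice: construction fails fast via invalidInputError when use_ipex is requested on an unsupported CPU, and forward_step wraps every call in an autocast context that is only enabled when avx512/amx is present. A minimal standalone sketch of that pattern — assuming torch.cpu.amp.autocast as a stand-in for Nano's own autocast helper, and a Linux host where lscpu is available:

import contextlib
import subprocess

import torch


def cpu_supports_bf16():
    # Same probe as PytorchIPEXJITBF16Model._check_cpu_isa above: parse lscpu flags
    flags = subprocess.check_output(["lscpu"]).decode("utf-8")
    return "avx512" in flags or "amx" in flags


@contextlib.contextmanager
def forward_context():
    # When the check fails, autocast is simply disabled and forward runs in FP32
    with torch.cpu.amp.autocast(enabled=cpu_supports_bf16(), dtype=torch.bfloat16):
        yield


model = torch.nn.Linear(4, 2)
with forward_context():
    out = model(torch.rand(1, 4))
print(out.dtype)  # torch.bfloat16 on avx512/amx CPUs, torch.float32 otherwise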
36 changes: 36 additions & 0 deletions python/nano/src/bigdl/nano/pytorch/amp/bfloat16.py
@@ -17,6 +17,8 @@
 import contextlib
 import io
 from logging import info, warning
+import sys
+import fcntl
 
 from pytorch_lightning import LightningModule
 import torch
@@ -46,6 +48,40 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         return super().__exit__(exc_type, exc_val, exc_tb)
 
 
+class RedirectStream(object):
+    """Context manager to capture output of shared library"""
+    def __init__(self, stream=sys.stdout, target=None):
+        self.origin_stream = stream
+        self.origin_stream_fileno = stream.fileno()
+        self.target = io.StringIO() if target is None else target
+        # Create a pipe to capture the stream
+        self.pipe_out, self.pipe_in = os.pipe()
+
+    def __enter__(self):
+        # Save a copy of the original stream
+        self.origin_stream_fileno_dup = os.dup(self.origin_stream_fileno)
+        # Replace the original stream with the write pipe
+        os.dup2(self.pipe_in, self.origin_stream_fileno)
+        os.close(self.pipe_in)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.origin_stream.flush()
+        # Make pipe_out non-blocking
+        fcntl.fcntl(self.pipe_out, fcntl.F_SETFL, os.O_NONBLOCK)
+        while True:
+            try:
+                buf = os.read(self.pipe_out, 1024)
+                if not buf:
+                    break
+                self.target.write(buf.decode('utf-8'))
+            except OSError as e:
+                break
+        os.close(self.pipe_out)
+        os.dup2(self.origin_stream_fileno_dup, self.origin_stream_fileno)
+        os.close(self.origin_stream_fileno_dup)
+
+
 class BF16Model(LightningModule):
     """Model of BFloat16 with auto mixed precision."""
 
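RedirectStream duplicates and swaps the stream's OS-level file descriptor, so it also captures output that native (C/C++) libraries write straight to fd 1 — output a plain sys.stdout reassignment would miss. A rough usage sketch, assuming the class above is in scope, that the module already imports os, and a Linux/macOS host where ctypes can load libc:

import ctypes
import io
import sys

libc = ctypes.CDLL(None)  # handle to the C library already loaded in-process

buf = io.StringIO()
with RedirectStream(stream=sys.stdout, target=buf):
    libc.puts(b"written by C code, invisible to sys.stdout redirection")
    libc.fflush(None)  # flush C stdio buffers so the text reaches the pipe
print("captured:", buf.getvalue().strip())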
40 changes: 31 additions & 9 deletions python/nano/test/pytorch/tests/test_bf16_ipex.py
@@ -19,7 +19,8 @@
 from bigdl.nano.pytorch import Trainer
 from torchvision.models.resnet import resnet18
 from unittest.mock import MagicMock, Mock, PropertyMock, patch
-from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_10, TORCH_VERSION_LESS_1_12
+from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_10, TORCH_VERSION_LESS_1_11
+from bigdl.nano.common import check_avx512
 
 
 class Pytorch1_9:
@@ -35,30 +36,51 @@ def test_bf16_pytorch_less_1_10(self):
             trainer.quantize(model, precision='bf16', use_ipex=True)
 
 
-class Pytorch1_12:
-    def test_bf16_common(self):
-        """
-        Debug mode. Allow run bf16 forward without bf16 instruction support.
-        """
+class CaseWithoutAVX512:
+    def test_unsupported_HW_or_OS(self):
         trainer = Trainer(max_epochs=1)
         model = resnet18(num_classes=10)
 
+        with pytest.raises(RuntimeError,
+                           match="Applying IPEX BF16 optimization needs the cpu support avx512."):
+            bf16_model = trainer.quantize(model, precision='bf16', use_ipex=True)
+
+
+class Pytorch1_11:
+    @patch('bigdl.nano.deps.ipex.ipex_inference_bf16_model.'
+           'PytorchIPEXJITBF16Model._check_cpu_isa', new_callable=PropertyMock)
+    def test_unsupported_HW_or_OS(self, mocked_check_cpu_isa):
+        mocked_check_cpu_isa.return_value = False
+        trainer = Trainer(max_epochs=1)
+        model = resnet18(num_classes=10)
+
+        with pytest.raises(RuntimeError,
+                           match="Applying IPEX BF16 optimization needs the cpu support avx512."):
+            bf16_model = trainer.quantize(model, precision='bf16', use_ipex=True)
+
+    def test_bf16_with_avx512_core(self):
+        trainer = Trainer(max_epochs=1)
+        model = resnet18(num_classes=10)
+
         x = torch.rand((10, 3, 256, 256))
         y = torch.ones((10,), dtype=torch.long)
 
         bf16_model = trainer.quantize(model, precision='bf16', use_ipex=True)
         # Debug mode to test functionality, make sure forward is called sucessfully
         y_hat = bf16_model(x)
 
         assert y_hat.shape == (10, 10) and y_hat.dtype == torch.bfloat16
 
 
-TORCH_VERSION_CLS = Pytorch1_12
+TORCH_VERSION_CLS = Pytorch1_11
 
 if TORCH_VERSION_LESS_1_10:
     print("ipex 1.9")
     TORCH_VERSION_CLS = Pytorch1_9
+elif not check_avx512():
+    print("IPEX Inference Model Without AVX512")
+    TORCH_VERSION_CLS = CaseWithoutAVX512
 
 
-class TestBF16(TORCH_VERSION_CLS, TestCase):
+class TestIPEXBF16(TORCH_VERSION_CLS, TestCase):
     pass
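The Pytorch1_11 case uses patch(..., new_callable=PropertyMock) to force _check_cpu_isa to report False even on machines that do support AVX512. A self-contained toy version of that mocking pattern, with Probe as a hypothetical stand-in class:

from unittest import mock


class Probe:
    @property
    def supported(self):
        raise RuntimeError("the real property would shell out to lscpu")


with mock.patch.object(Probe, "supported", new_callable=mock.PropertyMock) as fake:
    fake.return_value = False
    assert Probe().supported is False  # the property read is served by the mock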
