diff --git a/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py b/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py index e22c5bc69907..6bea7380552d 100644 --- a/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py +++ b/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py @@ -83,7 +83,8 @@ class IPEXBF16Precision(PrecisionPlugin): @contextmanager def forward_context(self): """AMP for managing model forward/training_step/evaluation_step/predict_step.""" - with torch.cpu.amp.autocast(): + # Manually set the dtype + with torch.cpu.amp.autocast(dtype=torch.bfloat16): yield def optimizer_step(self,