diff --git a/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py b/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py
index 6bea7380552d..51137188d3cd 100644
--- a/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py
+++ b/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py
@@ -94,11 +94,14 @@ def optimizer_step(self,
                        closure: Callable[[], Any],
                        **kwargs: Any) -> Any:
         """Hook to run the optimizer step."""
+        if type(optimizer) in IPEX_FUSED_OPTIMIZER_LIST:
+            return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)
+
         if isinstance(model, pl.LightningModule):
             closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        # Automatically call closure for optimizer not supported by IPEX
-        if type(optimizer) not in IPEX_FUSED_OPTIMIZER_LIST:
-            closure()
-
-        return optimizer.step(closure, **kwargs)
+
+        closure_result = closure()
+        optimizer.step(closure=None, **kwargs)
+
+        return closure_result
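For context, a minimal runnable sketch of the control flow this hunk introduces: fused IPEX optimizers keep the default Lightning path (the closure is handed to `optimizer.step`), while every other optimizer runs the closure explicitly, steps with `closure=None`, and returns the closure's result so the loss still reaches the training loop. The helper name `step_like_patched_hook` and the use of plain `torch.optim.SGD` are illustrative only; `IPEX_FUSED_OPTIMIZER_LIST` stands in for the module-level list defined elsewhere in `ipex_strategy.py`.

```python
# Illustrative sketch only; mirrors the dispatch of the patched optimizer_step.
import torch

# Stand-in for the module-level list of IPEX fused optimizer classes.
IPEX_FUSED_OPTIMIZER_LIST = []


def step_like_patched_hook(optimizer, closure, **kwargs):
    # Fused IPEX optimizers delegate to the default path, which passes the
    # closure straight through to optimizer.step (super().optimizer_step in
    # the real strategy).
    if type(optimizer) in IPEX_FUSED_OPTIMIZER_LIST:
        return optimizer.step(closure, **kwargs)

    # Non-fused path: evaluate the closure ourselves, step without a closure,
    # and return the closure's result (the loss) as the hook's return value.
    closure_result = closure()
    optimizer.step(closure=None, **kwargs)
    return closure_result


# Usage: plain SGD is not in the fused list, so it takes the non-fused path.
param = torch.nn.Parameter(torch.ones(1))
opt = torch.optim.SGD([param], lr=0.1)


def closure():
    opt.zero_grad()
    loss = (param ** 2).sum()
    loss.backward()
    return loss


print(step_like_patched_hook(opt, closure))  # prints the loss tensor
```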