Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
y199387 committed Aug 10, 2022
1 parent ab0c52e commit d08d49f
Showing 1 changed file with 8 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,14 @@ def optimizer_step(self,
closure: Callable[[], Any],
**kwargs: Any) -> Any:
"""Hook to run the optimizer step."""
if type(optimizer) in IPEX_FUSED_OPTIMIZER_LIST:
return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)

if isinstance(model, pl.LightningModule):
closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)

closure_result = closure()
optimizer.step(closure=None, **kwargs)

return closure_result

# Automatically call closure for optimizer not supported by IPEX
if type(optimizer) not in IPEX_FUSED_OPTIMIZER_LIST:
closure()

return optimizer.step(closure, **kwargs)

0 comments on commit d08d49f

Please sign in to comment.