Skip to content

Commit

Permalink
Refine xpu callback
Browse files Browse the repository at this point in the history
  • Loading branch information
sovrasov committed Nov 30, 2023
1 parent 98ffd5d commit 3f3c4d3
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions src/otx/algorithms/anomaly/adapters/anomalib/callbacks/xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import torch
from pytorch_lightning import Callback

from otx.algorithms.common.utils.utils import is_xpu_available


class XPUCallback(Callback):
"""XPU device callback.
Expand All @@ -20,11 +18,10 @@ def __init__(self, device_idx=0):

def on_fit_start(self, trainer, pl_module):
"""Applies IPEX optimization before training."""
if is_xpu_available():
pl_module.to(self.device)
model, optimizer = torch.xpu.optimize(trainer.model, optimizer=trainer.optimizers[0], dtype=torch.float32)
trainer.optimizers = [optimizer]
trainer.model = model
pl_module.to(self.device)
model, optimizer = torch.xpu.optimize(trainer.model, optimizer=trainer.optimizers[0])
trainer.optimizers = [optimizer]
trainer.model = model

def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
"""Moves train batch tensors to XPU."""
Expand Down

0 comments on commit 3f3c4d3

Please sign in to comment.