Skip to content

Commit

Permalink
Merge pull request #104 from the-database/dev
Browse files Browse the repository at this point in the history
remove more related to channels last
  • Loading branch information
the-database authored Oct 8, 2024
2 parents bd94813 + cec25ad commit 53fde3f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 25 deletions.
7 changes: 2 additions & 5 deletions traiNNer/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,7 @@ def model_to_device(self, net: nn.Module) -> nn.Module:
net (nn.Module)
"""
assert isinstance(self.opt.num_gpu, int)
net = net.to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765
net = net.to(self.device) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765
if self.opt.dist:
find_unused_parameters = self.opt.find_unused_parameters
net = DistributedDataParallel(
Expand Down Expand Up @@ -377,7 +374,7 @@ def save_network(
key = key[7:]
if key == "n_averaged": # ema key, breaks compatibility
continue
new_state_dict[key] = param.to("cpu", non_blocking=True)
new_state_dict[key] = param.cpu()

# avoid occasional writing errors
retry = 3
Expand Down
21 changes: 1 addition & 20 deletions traiNNer/models/sr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,7 @@ def init_training_settings(self) -> None:

init_net_g_ema = build_network(
{**self.opt.network_g, "scale": self.opt.scale}
).to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue]
).to(self.device) # pyright: ignore[reportCallIssue]

# load pretrained model
if self.opt.path.pretrain_network_g_ema is not None:
Expand Down Expand Up @@ -212,28 +209,24 @@ def init_training_settings(self) -> None:
if train_opt.pixel_opt.get("loss_weight", 0) > 0:
self.cri_pix = build_loss(train_opt.pixel_opt).to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765

if train_opt.mssim_opt:
if train_opt.mssim_opt.get("loss_weight", 0) > 0:
self.cri_mssim = build_loss(train_opt.mssim_opt).to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765

if train_opt.ms_ssim_l1_opt:
if train_opt.ms_ssim_l1_opt.get("loss_weight", 0) > 0:
self.cri_ms_ssim_l1 = build_loss(train_opt.ms_ssim_l1_opt).to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765

if train_opt.ldl_opt:
if train_opt.ldl_opt.get("loss_weight", 0) > 0:
self.cri_ldl = build_loss(train_opt.ldl_opt).to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765

if train_opt.perceptual_opt:
Expand All @@ -243,21 +236,18 @@ def init_training_settings(self) -> None:
):
self.cri_perceptual = build_loss(train_opt.perceptual_opt).to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765

if train_opt.dists_opt:
if train_opt.dists_opt.get("loss_weight", 0) > 0:
self.cri_dists = build_loss(train_opt.dists_opt).to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765

if train_opt.hr_inversion_opt:
if train_opt.hr_inversion_opt.get("loss_weight", 0) > 0:
self.cri_hr_inversion = build_loss(train_opt.hr_inversion_opt).to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765

if train_opt.dinov2_opt:
Expand All @@ -270,49 +260,42 @@ def init_training_settings(self) -> None:
if train_opt.contextual_opt.get("loss_weight", 0) > 0:
self.cri_contextual = build_loss(train_opt.contextual_opt).to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765

if train_opt.color_opt:
if train_opt.color_opt.get("loss_weight", 0) > 0:
self.cri_color = build_loss(train_opt.color_opt).to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765

if train_opt.luma_opt:
if train_opt.luma_opt.get("loss_weight", 0) > 0:
self.cri_luma = build_loss(train_opt.luma_opt).to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765

if train_opt.hsluv_opt:
if train_opt.hsluv_opt.get("loss_weight", 0) > 0:
self.cri_hsluv = build_loss(train_opt.hsluv_opt).to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765

if train_opt.avg_opt:
if train_opt.avg_opt.get("loss_weight", 0) > 0:
self.cri_avg = build_loss(train_opt.avg_opt).to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765

if train_opt.bicubic_opt:
if train_opt.bicubic_opt.get("loss_weight", 0) > 0:
self.cri_bicubic = build_loss(train_opt.bicubic_opt).to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765

if train_opt.gan_opt:
if train_opt.gan_opt.get("loss_weight", 0) > 0:
self.cri_gan = build_loss(train_opt.gan_opt).to(
self.device,
non_blocking=True,
) # pyright: ignore[reportCallIssue] # https://github.com/pytorch/pytorch/issues/131765
else:
self.cri_gan = None
Expand Down Expand Up @@ -370,12 +353,10 @@ def feed_data(self, data: DataFeed) -> None:
assert "lq" in data
self.lq = data["lq"].to(
self.device,
non_blocking=True,
)
if "gt" in data:
self.gt = data["gt"].to(
self.device,
non_blocking=True,
)

# moa
Expand Down

0 comments on commit 53fde3f

Please sign in to comment.