diff --git a/libs/spandrel/spandrel/architectures/PLKSR/arch/PLKSR.py b/libs/spandrel/spandrel/architectures/PLKSR/arch/PLKSR.py index b030271e..08094412 100644 --- a/libs/spandrel/spandrel/architectures/PLKSR/arch/PLKSR.py +++ b/libs/spandrel/spandrel/architectures/PLKSR/arch/PLKSR.py @@ -64,11 +64,19 @@ def __init__(self, dim, kernel_size, with_idt): self.idx = dim def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.with_idt: - x[:, : self.idx] = x[:, : self.idx] + self.conv(x[:, : self.idx]) + if self.training: + x1, x2 = torch.split(x, [self.idx, x.size(1) - self.idx], dim=1) + if self.with_idt: + x1 = self.conv(x1) + x1 + else: + x1 = self.conv(x1) + return torch.cat([x1, x2], dim=1) else: - x[:, : self.idx] = self.conv(x[:, : self.idx]) - return x + if self.with_idt: + x[:, : self.idx] = x[:, : self.idx] + self.conv(x[:, : self.idx]) + else: + x[:, : self.idx] = self.conv(x[:, : self.idx]) + return x class RectSparsePLKConv2d(nn.Module): diff --git a/libs/spandrel/spandrel/architectures/PLKSR/arch/RealPLKSR.py b/libs/spandrel/spandrel/architectures/PLKSR/arch/RealPLKSR.py index e8b078c1..0501d535 100644 --- a/libs/spandrel/spandrel/architectures/PLKSR/arch/RealPLKSR.py +++ b/libs/spandrel/spandrel/architectures/PLKSR/arch/RealPLKSR.py @@ -29,6 +29,10 @@ def __init__(self, dim: int, kernel_size: int): self.idx = dim def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.training: + x1, x2 = torch.split(x, [self.idx, x.size(1) - self.idx], dim=1) + x1 = self.conv(x1) + return torch.cat([x1, x2], dim=1) x[:, : self.idx] = self.conv(x[:, : self.idx]) return x @@ -109,7 +113,9 @@ def __init__( dropout: float = 0, ): super().__init__() - dropout = 0 + + if not self.training: + dropout = 0 self.feats = nn.Sequential( *[nn.Conv2d(3, dim, 3, 1, 1)] diff --git a/libs/spandrel/spandrel/architectures/SPAN/arch/span.py b/libs/spandrel/spandrel/architectures/SPAN/arch/span.py index e24318f1..503201b1 100644 --- a/libs/spandrel/spandrel/architectures/SPAN/arch/span.py +++ b/libs/spandrel/spandrel/architectures/SPAN/arch/span.py @@ -145,9 +145,11 @@ def __init__( stride=s, bias=bias, ) - self.eval_conv.weight.requires_grad = False - self.eval_conv.bias.requires_grad = False # type: ignore - self.update_params() + + if not self.training: + self.eval_conv.weight.requires_grad = False + self.eval_conv.bias.requires_grad = False # type: ignore + self.update_params() def update_params(self): w1 = self.conv[0].weight.data.clone().detach()