Skip to content

Commit

Permalink
Update pd_loss.py
Browse files Browse the repository at this point in the history
  • Loading branch information
the-database committed Oct 25, 2024
1 parent 841b4e6 commit 6a30c8d
Showing 1 changed file with 53 additions and 7 deletions.
60 changes: 53 additions & 7 deletions traiNNer/losses/pd_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,43 @@ def __init__(
layer_weights: dict[str, float] | None = None,
w_lambda: float = 0.01,
loss_weight: float = 1,
alpha: list[float] | None = None,
) -> None:
super().__init__()
if layer_weights is None:
layer_weights = {
"conv1_2": 0.1,
"relu1_2": 0.1,
"conv2_2": 0.1,
"relu2_2": 0.1,
"conv3_4": 1,
"relu3_4": 1,
"conv4_4": 1,
"relu4_4": 1,
"conv5_4": 1,
"relu5_4": 1,
}
if alpha is None:
alpha = []
for k in layer_weights:
if k.startswith("conv"):
alpha.append(0.0)
else:
alpha.append(1.0)

self.vgg = VGG(list(layer_weights.keys())).cuda()
self.loss_weight = loss_weight * w_lambda
self.loss_weight = loss_weight
self.w_lambda = w_lambda
self.layer_weights = layer_weights
self.criterion = self.w_distance
self.alpha = alpha

self.criterion1 = None
self.criterion2 = None

if any(x < 1 for x in self.alpha):
self.criterion1 = nn.L1Loss()
if any(x > 0 for x in self.alpha):
self.criterion2 = self.w_distance

def w_distance(self, x_vgg: Tensor, y_vgg: Tensor) -> Tensor:
x_vgg = x_vgg / (torch.sum(x_vgg, dim=(2, 3), keepdim=True) + 1e-14)
Expand All @@ -91,8 +114,23 @@ def forward_once(self, x: Tensor) -> dict[str, Tensor]:
def forward(self, x: Tensor, gt: Tensor) -> Tensor:
x_vgg, gt_vgg = self.forward_once(x), self.forward_once(gt.detach())
score = torch.tensor(0.0, device=x.device)
for k in x_vgg:
score += self.criterion(x_vgg[k], gt_vgg[k]) * self.layer_weights[k]
for i, k in enumerate(x_vgg):
alpha = self.alpha[i]
s = torch.tensor(0.0, device=score.device)
if alpha < 1:
assert self.criterion1 is not None
temp = self.criterion1(x_vgg[k], gt_vgg[k]) * (1 - alpha)
s += temp
# print("l1", k, temp)
if alpha > 0:
assert self.criterion2 is not None
temp = self.criterion2(x_vgg[k], gt_vgg[k]) * alpha * self.w_lambda
s += temp
# print("pd", k, temp)

s *= self.layer_weights[k]
score += s

return score * self.loss_weight


Expand All @@ -105,6 +143,8 @@ def __init__(self, layer_name_list: list[str]) -> None:
).features
assert isinstance(vgg_pretrained_features, torch.nn.Sequential)

self._disable_inplace_relu(vgg_pretrained_features)

self.stages: nn.ModuleDict = nn.ModuleDict()
stage_breakpoints = {}

Expand Down Expand Up @@ -149,13 +189,19 @@ def _change_padding_mode(conv: nn.Module, padding_mode: str) -> nn.Conv2d:
new_conv.bias.copy_(conv.bias)
return new_conv

@staticmethod
def _disable_inplace_relu(model: nn.Module) -> None:
for module in model.modules():
if isinstance(module, nn.ReLU):
module.inplace = False

@torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda") # pyright: ignore[reportPrivateImportUsage] # https://github.com/pytorch/pytorch/issues/131765
def forward(self, x: Tensor) -> dict[str, Tensor]:
x = (x - self.mean) / self.std
h = (x - self.mean) / self.std

feats = {}
for layer_name, stage in self.stages.items():
x = stage(x)
feats[layer_name] = x.clone()
last_h = h if not feats else feats[next(reversed(feats))]
feats[layer_name] = stage(last_h)

return feats

0 comments on commit 6a30c8d

Please sign in to comment.