-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
35 lines (29 loc) · 1019 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from ast import iter_child_nodes
import torch
import torch.nn as nn
def init_weight(model):
    """Initialize the parameters of a single module in place.

    Presumably meant to be used via ``net.apply(init_weight)``, which calls it
    once per submodule (confirm against callers).

    * ``nn.Conv2d``: weight is drawn from a normal distribution with mean 0 and
      std 0.02; the bias is left unchanged.
    * ``nn.BatchNorm2d``: bias (the affine shift) is set to 0 when it exists;
      the weight is left unchanged.

    Any other module type is left untouched.

    NOTE(review): ``nn.ConvTranspose2d`` is not a subclass of ``nn.Conv2d``, so
    transposed-conv layers are not re-initialized here. Confirm this is intended.

    Args:
        model: the module to initialize.
    """
    if isinstance(model, nn.Conv2d):
        nn.init.normal_(model.weight, 0, 0.02)
    elif isinstance(model, nn.BatchNorm2d):
        # BatchNorm2d(..., affine=False) has bias=None; constant_ would raise
        # on it, so only initialize a bias that actually exists.
        if model.bias is not None:
            nn.init.constant_(model.bias, 0)
def compute_gradient(disc, img, f_img, esp):
    """Return the gradient of ``disc`` w.r.t. points between real and fake images.

    The evaluation points are ``x = esp * img + (1 - esp) * f_img`` (the
    interpolation used by a WGAN-GP style gradient penalty).

    Args:
        disc: discriminator / critic, called as ``disc(x)``.
        img: batch of real images.
        f_img: batch of fake images; presumably the same shape as ``img``.
        esp: interpolation weight(s). Either a scalar or a tensor broadcastable
            against ``img`` (e.g. one weight per sample; assumption, confirm
            at call site). It no longer needs ``requires_grad=True``.

    Returns:
        The gradient of ``disc(x)`` w.r.t. ``x`` (same shape as ``x``). It is
        built with ``create_graph=True``, so a penalty computed from it can be
        backpropagated further.
    """
    interpolate = esp * img + (1 - esp) * f_img
    # autograd.grad can only differentiate w.r.t. a tensor that is part of the
    # graph. If none of the inputs required grad, mark the interpolation itself
    # (it is a fresh leaf in that case) instead of raising.
    if not interpolate.requires_grad:
        interpolate.requires_grad_(True)
    disc_interpolate = disc(interpolate)
    gradient = torch.autograd.grad(
        outputs=disc_interpolate,
        inputs=interpolate,
        # A ones seed yields the gradient of disc_interpolate.sum().
        grad_outputs=torch.ones_like(disc_interpolate),
        retain_graph=True,
        create_graph=True,
    )
    return gradient[0]
def gradient_penalty(grad):
    """Return the gradient penalty ``mean((||g_i||_2 - 1) ** 2)`` over the batch.

    Args:
        grad: per-sample gradients with the batch on dim 0 and any trailing
            shape (e.g. the result of ``compute_gradient``).

    Returns:
        A 0-dim tensor. Each sample's gradient is flattened, and its L2 norm is
        compared with 1. The squared deviations are averaged over the batch.
    """
    # reshape (not view) so non-contiguous gradients, e.g. in channels_last
    # memory format, are flattened instead of raising.
    flat = grad.reshape(len(grad), -1)
    norm_grad = flat.norm(2, dim=1)
    return torch.mean(torch.pow(norm_grad - 1, 2))
def disc_loss_gp(real, fake, grad):
    """Critic loss with an additive gradient-penalty term.

    Computes ``mean(real - fake + grad)`` over the broadcast of the three
    inputs. For a scalar ``grad`` this equals ``mean(real - fake) + grad``.

    NOTE(review): the usual WGAN-GP critic loss is
    ``mean(fake) - mean(real) + lambda * gp``. This function has the opposite
    sign on the score terms and applies no weight to ``grad``. That is only
    correct if callers use a "higher = more fake" score convention (or swap the
    arguments) and pre-scale the penalty. Confirm at the call site.
    """
    score_gap = real - fake
    return torch.mean(score_gap + grad)
def gen_loss(fake):
    """Least-squares generator loss.

    Averages, over every element, the squared distance between the
    discriminator's scores on generated samples and the target 1.0.
    """
    residual = 1.0 - fake
    return residual.pow(2).mean()
def disc_loss(real, fake):
    """Least-squares discriminator loss.

    Pushes scores on real samples toward 1.0 and scores on generated samples
    toward 0.0. Each term is averaged over its own tensor, and then the two
    averages are added (no 1/2 factor is applied).
    """
    real_term = torch.mean((1.0 - real) ** 2)
    fake_term = torch.mean(fake ** 2)
    return real_term + fake_term