few_shot.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from protonets.models import register_model
from .utils import euclidean_dist
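
# Note: euclidean_dist(x, y), imported from .utils above, is expected to return an
# (N, M) matrix of squared Euclidean distances between N query embeddings and M
# class prototypes. A minimal sketch of such a helper (illustrative; the real
# implementation lives in the sibling utils module) is:
#
#   def euclidean_dist(x, y):               # x: (N, D), y: (M, D)
#       n, m, d = x.size(0), y.size(0), x.size(1)
#       x = x.unsqueeze(1).expand(n, m, d)  # (N, M, D)
#       y = y.unsqueeze(0).expand(n, m, d)  # (N, M, D)
#       return torch.pow(x - y, 2).sum(2)   # (N, M)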


class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # Collapse all non-batch dimensions: (B, C, H, W) -> (B, C*H*W)
        return x.view(x.size(0), -1)


class Protonet(nn.Module):
    def __init__(self, encoder):
        super(Protonet, self).__init__()
        self.encoder = encoder

    def loss(self, sample):
        xs = Variable(sample['xs'])  # support set: (n_class, n_support, C, H, W)
        xq = Variable(sample['xq'])  # query set:   (n_class, n_query, C, H, W)

        n_class = xs.size(0)
        assert xq.size(0) == n_class
        n_support = xs.size(1)
        n_query = xq.size(1)

        # The target index of each query is its class position within the episode.
        target_inds = torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).long()
        target_inds = Variable(target_inds, requires_grad=False)

        if xq.is_cuda:
            target_inds = target_inds.cuda()

        # Embed support and query images in a single forward pass.
        x = torch.cat([xs.view(n_class * n_support, *xs.size()[2:]),
                       xq.view(n_class * n_query, *xq.size()[2:])], 0)
        z = self.encoder.forward(x)
        z_dim = z.size(-1)

        # Class prototypes: mean of the support embeddings of each class.
        z_proto = z[:n_class * n_support].view(n_class, n_support, z_dim).mean(1)
        zq = z[n_class * n_support:]

        # Classify queries via a softmax over negative distances to the prototypes.
        dists = euclidean_dist(zq, z_proto)
        log_p_y = F.log_softmax(-dists, dim=1).view(n_class, n_query, -1)

        # Negative log-likelihood of the true class, averaged over all queries.
        loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()

        _, y_hat = log_p_y.max(2)
        acc_val = torch.eq(y_hat, target_inds.squeeze()).float().mean()

        return loss_val, {
            'loss': loss_val.item(),
            'acc': acc_val.item()
        }


@register_model('protonet_conv')
def load_protonet_conv(**kwargs):
    x_dim = kwargs['x_dim']
    hid_dim = kwargs['hid_dim']
    z_dim = kwargs['z_dim']

    def conv_block(in_channels, out_channels):
        # 3x3 conv -> batch norm -> ReLU -> 2x2 max-pool (halves spatial size).
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

    # Embedding network: four conv blocks followed by flattening of the feature map.
    encoder = nn.Sequential(
        conv_block(x_dim[0], hid_dim),
        conv_block(hid_dim, hid_dim),
        conv_block(hid_dim, hid_dim),
        conv_block(hid_dim, z_dim),
        Flatten()
    )

    return Protonet(encoder)
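

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module). The
# dimensions below, Omniglot-style 1x28x28 inputs with x_dim=(1, 28, 28),
# hid_dim=64, z_dim=64, and the 5-way, 5-shot episode shapes are example
# values, not requirements of the API.
if __name__ == '__main__':
    model = load_protonet_conv(x_dim=(1, 28, 28), hid_dim=64, z_dim=64)

    n_class, n_support, n_query = 5, 5, 15
    sample = {
        'xs': torch.randn(n_class, n_support, 1, 28, 28),  # support images
        'xq': torch.randn(n_class, n_query, 1, 28, 28)      # query images
    }

    loss, stats = model.loss(sample)
    loss.backward()
    print(stats)  # {'loss': ..., 'acc': ...}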