diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..ba85002 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 lyakaap + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..afe347d --- /dev/null +++ b/README.md @@ -0,0 +1,58 @@ +# NetVLAD-pytorch +Pytorch implementation of NetVLAD & Online Hardest Triplet Loss. +In NetVLAD, broadcasting is used to calculate residuals of clusters and it makes whole calculation time much faster. + +NetVLAD: https://arxiv.org/abs/1511.07247 + +In Defense of the Triplet Loss for Person Re-Identification: https://arxiv.org/abs/1703.07737 https://omoindrot.github.io/triplet-loss + +## Usage +``` +import torch +import torch.nn as nn +from torch.autograd import Variable + +from netvlad import NetVLAD +from netvlad import EmbedNet +from hard_triplet_loss import HardTripletLoss +from torchvision.models import resnet18 + + +# Discard layers at the end of base network +encoder = resnet18(pretrained=True) +base_model = nn.Sequential( + encoder.conv1, + encoder.bn1, + encoder.relu, + encoder.maxpool, + encoder.layer1, + encoder.layer2, + encoder.layer3, + encoder.layer4, +]) +dim = list(base_model.parameters())[-1].shape[0] # last channels (512) + +# Define model for embedding +net_vlad = NetVLAD(num_clusters=32, dim=dim, alpha=1.0) +model = EmbedNet(base_model, net_vlad).cuda() + +# Define loss +criterion = HardTripletLoss(margin=0.1).cuda() + +# This is just toy example. Typically, the number of samples in each classes are 4. +labels = torch.randint(0, 10, (40, )).long() +x = torch.rand(40, 3, 128, 128).cuda() +output = model(x) + +triplet_loss = criterion(output, labels) +``` + + +# ghostVlAD +use fc features +contain NetVLAD and ghostVLAD +RUN +``` +python ghostVLAD.py +``` + diff --git a/gostVALD.py b/gostVALD.py new file mode 100644 index 0000000..c6041a2 --- /dev/null +++ b/gostVALD.py @@ -0,0 +1,231 @@ +#codeing=utf-8 +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.models import resnet18 +from torch.autograd import Variable +''' +针对人脸问题,针对同一人多张人脸照片问题,多张人脸特征后进行特征融合,修改VLAD,将FC层替换掉卷积层 + +''' +class netVLAD(nn.Module): + ''' + 参数量:8*128* + ''' + def __init__(self,num_clusters=8,dim=128,normalize_input=True): + super(netVLAD, self).__init__() + self.num_clusters=num_clusters + self.dim=dim + self.normalize_input=normalize_input + self.fc=nn.Linear(dim,num_clusters) + self.centroids=nn.Parameter(torch.rand(num_clusters,dim)) + self._init_params() + def _init_params(self): + nn.init.xavier_normal_(self.fc.weight.data) + nn.init.constant_(self.fc.bias.data, 0.0) + #self.alpha=100. + #self.fc.weight = nn.Parameter( + # (2.0 * self.alpha * self.centroids).unsqueeze(-1).unsqueeze(-1) + #) + #self.fc.bias = nn.Parameter( + # - self.alpha * self.centroids.norm(dim=1) + #) + def forward(self,x): + ''' + x:(10,128) + ''' + N,C=x.shape[:2]#10,128 + assert C==self.dim ,"feature dim not correct" + if self.normalize_input: + x=F.normalize(x,p=2,dim=0) + soft_assign=self.fc(x).unsqueeze(0).permute(0,2,1)#(10,8)->(1,10,8)->(1,8,10) + soft_assign=F.softmax(soft_assign,dim=1) #nn.Softmax(dim=1) + x_flatten=x.view(1,C,-1) + #print(x_flatten.shape) + #print(x_flatten.expand(self.num_clusters, -1, -1, -1).permute(1, 0, 2, 3).shape) + #print(self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0).shape) + residual = x_flatten.expand(self.num_clusters, -1, -1, -1).permute(1, 0, 2, 3) - \ + self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) + residual *= soft_assign.unsqueeze(2) + vlad = residual.sum(dim=-1)#(1,8,128) + vlad = F.normalize(vlad, p=2, dim=2) + vlad = vlad.view(1, -1) + vlad = F.normalize(vlad, p=2, dim=1) #(1,8*128) + return vlad + +class netVLAD2(nn.Module): + ''' + 参数量:8*128* + ''' + def __init__(self,num_clusters=8,dim=128,normalize_input=True): + super(netVLAD2, self).__init__() + self.num_clusters=num_clusters + self.dim=dim + self.normalize_input=normalize_input + self.fc=nn.Linear(dim,num_clusters) + self.batch_norm = nn.BatchNorm1d(num_clusters, eps=1e-3, momentum=0.01) + self.softmax = nn.Softmax(dim=1) + self.centroids=nn.Parameter(torch.rand(num_clusters,dim)) + self._init_params() + def _init_params(self): + nn.init.xavier_normal_(self.fc.weight.data) + nn.init.constant_(self.fc.bias.data, 0.0) + def forward(self,x): + N,C=x.shape[:2] + if self.normalize_input: + x=F.normalize(x,p=2,dim=1) + soft_assign=self.fc(x) + soft_assign=self.softmax(soft_assign).unsqueeze(0)#(1,10,8) + a_sum = soft_assign.sum(-2).unsqueeze(1)#(1,1,8) + a = torch.mul(a_sum, self.centroids.transpose(1,0).unsqueeze(0))#(1,128,8) + print(soft_assign.size(),a_sum.size(),a.size()) + soft_assign = soft_assign.permute(0, 2, 1).contiguous() + x=x.view([-1, N, self.dim]) + vlad = torch.matmul(soft_assign, x).permute(0, 2, 1).contiguous() + vlad = vlad.sub(a).view([-1, self.num_clusters * self.dim]) + vlad = F.normalize(vlad, p=2, dim=1) + return vlad + def forward2(self,x): + ''' + x:(10,128) + ''' + N,C=x.shape[:2]#10,128 + assert C==self.dim ,"feature dim not correct" + if self.normalize_input: + x=F.normalize(x,p=2,dim=1) + soft_assign=self.fc(x).unsqueeze(0).permute(0,2,1)#(10,8)->(1,10,8)->(1,8,10) + soft_assign=F.softmax(soft_assign,dim=1) #nn.Softmax(dim=1) #(1,8,10) + x_flatten=x.unsqueeze(0).permute(0,2,1)#(1,128,10) + #print(x_flatten.shape) + #print(x_flatten.expand(self.num_clusters, -1, -1, -1).shape)#(8,1,128,40) + #print(self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0).shape) + #[(1,128,10)->(8,1,128,10)->(1,8,128,10)]-[(8,128)->(10,8,128)->(8,128,10)->(1,8,128,10)] + residual = x_flatten.expand(self.num_clusters, -1, -1, -1).permute(1, 0, 2, 3) - \ + self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) + #print(residual.size())#(1,8,128,10) + residual *= soft_assign.unsqueeze(2) #(1,8,128,10)*(1,8,1,10)->(1,8,128,10) + vlad = residual.sum(dim=-1)#(1,8,128) + vlad = F.normalize(vlad, p=2, dim=2) + vlad = vlad.view(1, -1) + vlad = F.normalize(vlad, p=2, dim=1) #(1,8*128) + return vlad +class gostVLAD(nn.Module): + def __init__(self,num_clusters=8,gost=1,dim=128,normalize_input=True): + super(gostVLAD, self).__init__() + self.num_clusters=num_clusters + self.dim=dim + self.gost=gost + self.normalize_input=normalize_input + self.fc=nn.Linear(dim,num_clusters+gost) + self.centroids=nn.Parameter(torch.rand(num_clusters,dim)) + self._init_params() + def _init_params(self): + nn.init.xavier_normal_(self.fc.weight.data) + nn.init.constant_(self.fc.bias.data, 0.0) + def forward(self,x): + ''' + x:NxD + ''' + N,C=x.shape[:2]#10,128 + assert C==self.dim ,"feature dim not correct" + if self.normalize_input: + x=F.normalize(x,p=2,dim=0) + soft_assign=self.fc(x).unsqueeze(0).permute(0,2,1)#(10,9)->(1,10,9)->(1,9,10) + soft_assign=F.softmax(soft_assign,dim=1) + + soft_assign=soft_assign[:,:self.num_clusters,:]#(1,8,10) + + x_flatten=x.view(1,C,-1) + residual = x_flatten.expand(self.num_clusters, -1, -1, -1).permute(1, 0, 2, 3) - \ + self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) + residual *= soft_assign.unsqueeze(2) + vlad = residual.sum(dim=-1)#(1,8,128) + vlad = F.normalize(vlad, p=2, dim=2) + vlad = vlad.view(1, -1) + vlad = F.normalize(vlad, p=2, dim=1) #(1,8*128) + return vlad + + +class gostVLAD2(nn.Module): + def __init__(self,num_clusters=8,gost=1,dim=128,normalize_input=True): + super(gostVLAD2, self).__init__() + self.num_clusters=num_clusters + self.dim=dim + self.gost=gost + self.normalize_input=normalize_input + self.fc=nn.Linear(dim,num_clusters+gost) + self.centroids=nn.Parameter(torch.rand(num_clusters+gost,dim)) + self._init_params() + def _init_params(self): + nn.init.xavier_normal_(self.fc.weight.data) + nn.init.constant_(self.fc.bias.data, 0.0) + def forward(self,x): + ''' + x:NxD + ''' + N,C=x.shape[:2]#10,128 + assert C==self.dim ,"feature dim not correct" + if self.normalize_input: + x=F.normalize(x,p=2,dim=0) + soft_assign=self.fc(x).unsqueeze(0).permute(0,2,1)#(10,9)->(1,10,9)->(1,9,10) + soft_assign=F.softmax(soft_assign,dim=1) + + #soft_assign=soft_assign[:,:self.num_clusters,:]#(1,8,10) + + x_flatten=x.unsqueeze(0).permute(0,2,1)#x.view(1,C,-1) + residual = x_flatten.expand(self.num_clusters+self.gost, -1, -1, -1).permute(1, 0, 2, 3) - \ + self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) + residual *= soft_assign.unsqueeze(2) + vlad = residual.sum(dim=-1)#(1,9,128) + vald=vald[:,:self.num_clusters,:]#(1,8,128) + vlad = F.normalize(vlad, p=2, dim=2) + vlad = vlad.view(1, -1) + vlad = F.normalize(vlad, p=2, dim=1) #(1,8*128) + return vlad + + +class EmbedNet(nn.Module): + def __init__(self, base_model, net_vlad,dim_in=512,dim_out=128): + super(EmbedNet, self).__init__() + self.base_model = base_model + self.net_vlad = net_vlad + self.conv=nn.Conv2d(dim_in,dim_out,kernel_size=(1,1),bias=True) + self.avgp=nn.AdaptiveAvgPool2d(1) + def forward(self, x): + x = self.base_model(x) + x=self.conv(x) # + x=self.avgp(x) + x=x.squeeze() #(N,128) + embedded_x = self.net_vlad.forward(x) + emb2=self.net_vlad.forward2(x) + return embedded_x,emb2 + + +def test(): + encoder = resnet18(pretrained=False) + base_model = nn.Sequential( + encoder.conv1, + encoder.bn1, + encoder.relu, + encoder.maxpool, + encoder.layer1, + encoder.layer2, + encoder.layer3, + encoder.layer4, + ) + dim_in = list(base_model.parameters())[-1].shape[0]#512 + dim_out=128 + net_vlad=netVLAD2(dim=dim_out) + #net_vlad=gostVLAD(dim=dim_out) + model=EmbedNet(base_model,net_vlad,dim_in=dim_in,dim_out=dim_out) + + x=torch.rand(10,3,128,128) + output1,output2=model(x) + print(output1.shape,output2.shape)#(1,8*128) + print(output1) + print(output2.detach().numpy()) + + +test() + + diff --git a/hard_triplet_loss.py b/hard_triplet_loss.py new file mode 100644 index 0000000..6832f42 --- /dev/null +++ b/hard_triplet_loss.py @@ -0,0 +1,143 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class HardTripletLoss(nn.Module): + """Hard/Hardest Triplet Loss + (pytorch implementation of https://omoindrot.github.io/triplet-loss) + + For each anchor, we get the hardest positive and hardest negative to form a triplet. + """ + def __init__(self, margin=0.1, hardest=False, squared=False): + """ + Args: + margin: margin for triplet loss + hardest: If true, loss is considered only hardest triplets. + squared: If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. + """ + super(HardTripletLoss, self).__init__() + self.margin = margin + self.hardest = hardest + self.squared = squared + + def forward(self, embeddings, labels): + """ + Args: + labels: labels of the batch, of size (batch_size,) + embeddings: tensor of shape (batch_size, embed_dim) + + Returns: + triplet_loss: scalar tensor containing the triplet loss + """ + pairwise_dist = _pairwise_distance(embeddings, squared=self.squared) + + if self.hardest: + # Get the hardest positive pairs + mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float() + valid_positive_dist = pairwise_dist * mask_anchor_positive + hardest_positive_dist, _ = torch.max(valid_positive_dist, dim=1, keepdim=True) + + # Get the hardest negative pairs + mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float() + max_anchor_negative_dist, _ = torch.max(pairwise_dist, dim=1, keepdim=True) + anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * ( + 1.0 - mask_anchor_negative) + hardest_negative_dist, _ = torch.min(anchor_negative_dist, dim=1, keepdim=True) + + # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss + triplet_loss = F.relu(hardest_positive_dist - hardest_negative_dist + 0.1) + triplet_loss = torch.mean(triplet_loss) + else: + anc_pos_dist = pairwise_dist.unsqueeze(dim=2) + anc_neg_dist = pairwise_dist.unsqueeze(dim=1) + + # Compute a 3D tensor of size (batch_size, batch_size, batch_size) + # triplet_loss[i, j, k] will contain the triplet loss of anc=i, pos=j, neg=k + # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1) + # and the 2nd (batch_size, 1, batch_size) + loss = anc_pos_dist - anc_neg_dist + self.margin + + mask = _get_triplet_mask(labels).float() + triplet_loss = loss * mask + + # Remove negative losses (i.e. the easy triplets) + triplet_loss = F.relu(triplet_loss) + + # Count number of hard triplets (where triplet_loss > 0) + hard_triplets = torch.gt(triplet_loss, 1e-16).float() + num_hard_triplets = torch.sum(hard_triplets) + + triplet_loss = torch.sum(triplet_loss) / (num_hard_triplets + 1e-16) + + return triplet_loss + + +def _pairwise_distance(x, squared=False, eps=1e-16): + # Compute the 2D matrix of distances between all the embeddings. + + cor_mat = torch.matmul(x, x.t()) + norm_mat = cor_mat.diag() + distances = norm_mat.unsqueeze(1) - 2 * cor_mat + norm_mat.unsqueeze(0) + distances = F.relu(distances) + + if not squared: + mask = torch.eq(distances, 0.0).float() + distances = distances + mask * eps + distances = torch.sqrt(distances) + distances = distances * (1.0 - mask) + + return distances + + +def _get_anchor_positive_triplet_mask(labels): + # Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label. + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + indices_not_equal = torch.eye(labels.shape[0]).to(device).byte() ^ 1 + + # Check if labels[i] == labels[j] + labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1) + + mask = indices_not_equal * labels_equal + + return mask + + +def _get_anchor_negative_triplet_mask(labels): + # Return a 2D mask where mask[a, n] is True iff a and n have distinct labels. + + # Check if labels[i] != labels[k] + labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1) + mask = labels_equal ^ 1 + + return mask + + +def _get_triplet_mask(labels): + """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid. + + A triplet (i, j, k) is valid if: + - i, j, k are distinct + - labels[i] == labels[j] and labels[i] != labels[k] + """ + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # Check that i, j and k are distinct + indices_not_same = torch.eye(labels.shape[0]).to(device).byte() ^ 1 + i_not_equal_j = torch.unsqueeze(indices_not_same, 2) + i_not_equal_k = torch.unsqueeze(indices_not_same, 1) + j_not_equal_k = torch.unsqueeze(indices_not_same, 0) + distinct_indices = i_not_equal_j * i_not_equal_k * j_not_equal_k + + # Check if labels[i] == labels[j] and labels[i] != labels[k] + label_equal = torch.eq(torch.unsqueeze(labels, 0), torch.unsqueeze(labels, 1)) + i_equal_j = torch.unsqueeze(label_equal, 2) + i_equal_k = torch.unsqueeze(label_equal, 1) + valid_labels = i_equal_j * (i_equal_k ^ 1) + + mask = distinct_indices * valid_labels # Combine the two masks + + return mask diff --git a/netvlad.py b/netvlad.py new file mode 100644 index 0000000..701a275 --- /dev/null +++ b/netvlad.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class NetVLAD(nn.Module): + """NetVLAD layer implementation""" + + def __init__(self, num_clusters=64, dim=128, alpha=100.0, + normalize_input=True): + """ + Args: + num_clusters : int + The number of clusters + dim : int + Dimension of descriptors + alpha : float + Parameter of initialization. Larger value is harder assignment. + normalize_input : bool + If true, descriptor-wise L2 normalization is applied to input. + """ + super(NetVLAD, self).__init__() + self.num_clusters = num_clusters + self.dim = dim + self.alpha = alpha + self.normalize_input = normalize_input + self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=True) + self.centroids = nn.Parameter(torch.rand(num_clusters, dim)) + self._init_params() + + def _init_params(self): + self.conv.weight = nn.Parameter( + (2.0 * self.alpha * self.centroids).unsqueeze(-1).unsqueeze(-1) + ) + self.conv.bias = nn.Parameter( + - self.alpha * self.centroids.norm(dim=1) + ) + + def forward(self, x): + N, C = x.shape[:2] + + if self.normalize_input: + x = F.normalize(x, p=2, dim=1) # across descriptor dim + + # soft-assignment + soft_assign = self.conv(x).view(N, self.num_clusters, -1) + soft_assign = F.softmax(soft_assign, dim=1) + + x_flatten = x.view(N, C, -1) + + # calculate residuals to each clusters + residual = x_flatten.expand(self.num_clusters, -1, -1, -1).permute(1, 0, 2, 3) - \ + self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) + residual *= soft_assign.unsqueeze(2) + vlad = residual.sum(dim=-1) + + vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization + vlad = vlad.view(x.size(0), -1) # flatten + vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize + + return vlad + + +class EmbedNet(nn.Module): + def __init__(self, base_model, net_vlad): + super(EmbedNet, self).__init__() + self.base_model = base_model + self.net_vlad = net_vlad + + def forward(self, x): + x = self.base_model(x) + embedded_x = self.net_vlad(x) + return embedded_x + + +class TripletNet(nn.Module): + def __init__(self, embed_net): + super(TripletNet, self).__init__() + self.embed_net = embed_net + + def forward(self, a, p, n): + embedded_a = self.embed_net(a) + embedded_p = self.embed_net(p) + embedded_n = self.embed_net(n) + return embedded_a, embedded_p, embedded_n + + def feature_extract(self, x): + return self.embed_net(x) diff --git a/test.py b/test.py new file mode 100644 index 0000000..9b523d3 --- /dev/null +++ b/test.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable + +from netvlad import NetVLAD +from netvlad import EmbedNet +from hard_triplet_loss import HardTripletLoss +from torchvision.models import resnet18 + + +# Discard layers at the end of base network +encoder = resnet18(pretrained=False) +base_model = nn.Sequential( + encoder.conv1, + encoder.bn1, + encoder.relu, + encoder.maxpool, + encoder.layer1, + encoder.layer2, + encoder.layer3, + encoder.layer4, +) +dim = list(base_model.parameters())[-1].shape[0] # last channels (512) + +# Define model for embedding +net_vlad = NetVLAD(num_clusters=32, dim=dim, alpha=1.0) +model = EmbedNet(base_model, net_vlad).cuda() + +# Define loss +criterion = HardTripletLoss(margin=0.1).cuda() + +# This is just toy example. Typically, the number of samples in each classes are 4. +labels = torch.randint(0, 10, (40, )).long().cuda() +x = torch.rand(40, 3, 128, 128).cuda() +base_model.cuda() +print(base_model(x).shape) +output = model(x) +print(output.shape) +triplet_loss = criterion(output, labels)