From 756ce86c184246eed8867436289782074c566805 Mon Sep 17 00:00:00 2001 From: ksaito-ut Date: Tue, 10 Jul 2018 23:26:35 +0900 Subject: [PATCH] add visda classification code --- visda_classification/basenet.py | 895 +++++++++++++++++++++++++ visda_classification/res_train_main.py | 253 +++++++ visda_classification/taskcv_loader.py | 85 +++ visda_classification/utils.py | 36 + 4 files changed, 1269 insertions(+) create mode 100644 visda_classification/basenet.py create mode 100755 visda_classification/res_train_main.py create mode 100644 visda_classification/taskcv_loader.py create mode 100644 visda_classification/utils.py diff --git a/visda_classification/basenet.py b/visda_classification/basenet.py new file mode 100644 index 0000000..82a7113 --- /dev/null +++ b/visda_classification/basenet.py @@ -0,0 +1,895 @@ +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torchvision import datasets, models, transforms +import torch.nn.functional as F +from torch.autograd import Variable +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd import Variable +import torch.nn.init as init + +#from resnet200 import Res200 +#from resnext import ResNeXt +from torch.nn.utils.weight_norm import WeightNorm +from torch.nn.utils.weight_norm import weight_norm +from torch.autograd import Function +class GradReverse(Function): + def __init__(self, lambd): + self.lambd = lambd + def forward(self, x): + return x.view_as(x) + def backward(self, grad_output): + return (grad_output*-self.lambd) +def grad_reverse(x,lambd=1.0): + return GradReverse(lambd)(x) +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout) + + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout): + conv_block = [] + p = 0 + # TODO: support padding types + assert(padding_type == 'zero') + p = 1 + + # TODO: InstanceNorm + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim, affine=True), + nn.ReLU(True)] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim, affine=True)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + + + +class L2Norm(nn.Module): + def __init__(self,n_channels, scale): + super(L2Norm,self).__init__() + self.n_channels = n_channels + self.gamma = scale or None + self.eps = 1e-10 + self.weight = nn.Parameter(torch.Tensor(self.n_channels)) + self.reset_parameters() + + def reset_parameters(self): + init.constant(self.weight,self.gamma) + + def forward(self, x): + norm = x.pow(2).sum(1).sqrt()+self.eps + x/=norm.expand_as(x) + out = self.weight.unsqueeze(0).expand_as(x) * x + return out +class BaseNet(nn.Module): + #Model VGG + def __init__(self): + super(BaseNet, self).__init__() + model_ft = models.vgg16(pretrained=True) + mod = list(model_ft.features.children()) + self.features = nn.Sequential(*mod) + mod = list(model_ft.classifier.children()) + mod.pop() + self.classifier = nn.Sequential(*mod) + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), 512 * 7 * 7) + return x +class AlexNet(nn.Module): + def __init__(self): + super(AlexNet, self).__init__() + model_ft = models.alexnet(pretrained=True) + mod = list(model_ft.features.children()) + self.features = model_ft.features#nn.Sequential(*mod) + print(self.features[0]) + #mod = list(model_ft.classifier.children()) + #mod.pop() + + #self.classifier = nn.Sequential(*mod) + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0),9216) + #x = self.classifier(x) + + return x +class AlexNet_office(nn.Module): + def __init__(self): + super(AlexNet_office, self).__init__() + model_ft = models.alexnet(pretrained=True) + mod = list(model_ft.features.children()) + self.features = model_ft.features#nn.Sequential(*mod) + mod = list(model_ft.classifier.children()) + mod.pop() + print(mod) + self.classifier = nn.Sequential(*mod) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0),9216) + x = self.classifier(x) + #x = F.dropout(F.relu(self.top(x)),training=self.training) + + return x +class AlexMiddle_office(nn.Module): + def __init__(self): + super(AlexMiddle_office, self).__init__() + self.top = nn.Linear(4096,256) + def forward(self, x): + x = F.dropout(F.relu(self.top(x)),training=self.training) + return x + + +class AlexClassifier(nn.Module): + # Classifier for VGG + def __init__(self, num_classes=12): + super(AlexClassifier, self).__init__() + mod = [] + mod.append(nn.Dropout()) + mod.append(nn.Linear(4096,256)) + #mod.append(nn.BatchNorm1d(256,affine=True)) + mod.append(nn.ReLU()) + #mod.append(nn.Linear(256,256)) + mod.append(nn.Dropout()) + #mod.append(nn.ReLU()) + mod.append(nn.Dropout()) + #self.top = nn.Linear(256,256) + mod.append(nn.Linear(256,31)) + self.classifier = nn.Sequential(*mod) + def set_lambda(self, lambd): + self.lambd = lambd + def forward(self, x,reverse=False): + if reverse: + x = grad_reverse(x, self.lambd) + x = self.classifier(x) + return x + + + +class Classifier(nn.Module): + # Classifier for VGG + def __init__(self, num_classes=12): + super(Classifier, self).__init__() + model_ft = models.alexnet(pretrained=False) + mod = list(model_ft.classifier.children()) + mod.pop() + mod.append(nn.Linear(4096,num_classes)) + self.classifier = nn.Sequential(*mod) + + def forward(self, x): + + x = self.classifier(x) + return x + +class ClassifierMMD(nn.Module): + def __init__(self, num_classes=12): + super(ClassifierMMD, self).__init__() + model_ft = models.vgg16(pretrained=True) + mod = list(model_ft.classifier.children()) + mod.pop() + self.classifier1 = nn.Sequential(*mod) + self.classifier2 = nn.Sequential( + nn.Dropout(), + nn.Linear(4096, 1000), + nn.ReLU(inplace=True), + ) + self.classifier3 = nn.Sequential( + nn.BatchNorm1d(1000,affine=True), + nn.Dropout(), + nn.ReLU(inplace=True), + ) + self.last = nn.Linear(1000, num_classes) + + def forward(self, x): + x = self.classifier1(x) + x1 = self.classifier2(x) + x2 = self.classifier3(x1) + x3 = self.last(x2) + return x3,x2,x1 +class ResBase(nn.Module): + def __init__(self,option = 'resnet18',pret=True): + super(ResBase, self).__init__() + self.dim = 2048 + if option == 'resnet18': + model_ft = models.resnet18(pretrained=pret) + self.dim = 512 + if option == 'resnet50': + model_ft = models.resnet50(pretrained=pret) + if option == 'resnet101': + model_ft = models.resnet101(pretrained=pret) + if option == 'resnet152': + model_ft = models.resnet152(pretrained=pret) + if option == 'resnet200': + model_ft = Res200() + if option == 'resnetnext': + model_ft = ResNeXt(layer_num=101) + mod = list(model_ft.children()) + mod.pop() + #self.model_ft =model_ft + self.features = nn.Sequential(*mod) + def forward(self, x): + + x = self.features(x) + + x = x.view(x.size(0), self.dim) + return x +class ResBasePlus(nn.Module): + def __init__(self,option = 'resnet18',pret=True): + super(ResBasePlus, self).__init__() + self.dim = 2048 + if option == 'resnet18': + model_ft = models.resnet18(pretrained=pret) + self.dim = 512 + if option == 'resnet50': + model_ft = models.resnet50(pretrained=pret) + if option == 'resnet101': + model_ft = models.resnet101(pretrained=pret) + if option == 'resnet152': + model_ft = models.resnet152(pretrained=pret) + if option == 'resnet200': + model_ft = Res200() + if option == 'resnetnext': + model_ft = ResNeXt(layer_num=101) + mod = list(model_ft.children()) + mod.pop() + #self.model_ft =model_ft + self.layer = nn.Sequential( + nn.Dropout(), + nn.Linear(2048, 1000), + nn.ReLU(inplace=True), + nn.BatchNorm1d(1000,affine=True), + nn.Dropout(), + nn.ReLU(inplace=True), + ) + self.features = nn.Sequential(*mod) + def forward(self, x): + + x = self.features(x) + x = x.view(x.size(0), self.dim) + x = self.layer(x) + return x + +class ResNet_all(nn.Module): + def __init__(self,option = 'resnet18',pret=True): + super(ResNet_all, self).__init__() + self.dim = 2048 + if option == 'resnet18': + model_ft = models.resnet18(pretrained=pret) + self.dim = 512 + if option == 'resnet50': + model_ft = models.resnet50(pretrained=pret) + if option == 'resnet101': + model_ft = models.resnet101(pretrained=pret) + if option == 'resnet152': + model_ft = models.resnet152(pretrained=pret) + if option == 'resnet200': + model_ft = Res200() + if option == 'resnetnext': + model_ft = ResNeXt(layer_num=101) + #mod = list(model_ft.children()) + #mod.pop() + #self.model_ft =model_ft + self.conv1 = model_ft.conv1 + self.bn0 = model_ft.bn1 + self.relu = model_ft.relu + self.maxpool = model_ft.maxpool + self.layer1 = model_ft.layer1 + self.layer2 = model_ft.layer2 + self.layer3 = model_ft.layer3 + self.layer4 = model_ft.layer4 + self.pool = model_ft.avgpool + self.fc = nn.Linear(2048,12) + def forward(self, x,layer_return = False,input_mask=False,mask=None,mask2=None): + if input_mask: + x = self.conv1(x) + x = self.bn0(x) + x = self.relu(x) + conv_x = x + x = self.maxpool(x) + fm1 = mask*self.layer1(x) + fm2 = mask2*self.layer2(fm1) + fm3 = self.layer3(fm2) + fm4 = self.pool(self.layer4(fm3)) + x = fm4.view(fm4.size(0), self.dim) + x = self.fc(x) + return x#,fm1 + else: + x = self.conv1(x) + x = self.bn0(x) + x = self.relu(x) + conv_x = x + x = self.maxpool(x) + fm1 = self.layer1(x) + fm2 = self.layer2(fm1) + fm3 = self.layer3(fm2) + fm4 = self.pool(self.layer4(fm3)) + x = fm4.view(fm4.size(0), self.dim) + x = self.fc(x) + if layer_return: + return x,fm1,fm2 + else: + return x + +class Mask_Generator(nn.Module): + def __init__(self): + super(Mask_Generator, self).__init__() + self.conv1 = nn.Conv2d(256, 256, kernel_size=1,stride=1,padding=0) + self.bn1 = nn.BatchNorm2d(256) + self.conv2 = nn.Conv2d(256, 256, kernel_size=1,stride=1,padding=0) + self.conv1_2 = nn.Conv2d(512, 256, kernel_size=1,stride=1,padding=0) + self.bn1_2 = nn.BatchNorm2d(256) + self.conv2_2 = nn.Conv2d(256, 512, kernel_size=1,stride=1,padding=0) + + def forward(self, x,x2): + x = F.relu(self.bn1(self.conv1(x))) + x = F.sigmoid(self.conv2(x)) + x2 = F.relu(self.bn1_2(self.conv1_2(x2))) + x2 = F.sigmoid(self.conv2_2(x2)) + return x,x2 + + +class ResMiddle_office(nn.Module): + def __init__(self,option = 'resnet18',pret=True): + super(ResMiddle_office, self).__init__() + self.dim = 2048 + layers = [] + layers.append(nn.Linear(self.dim,256)) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Dropout()) + self.bottleneck = nn.Sequential(*layers) + #self.features = nn.Sequential(*mod) + def forward(self, x): + x = self.bottleneck(x) + return x + + +class ResBase_office(nn.Module): + def __init__(self,option = 'resnet18',pret=True): + super(ResBase_office, self).__init__() + self.dim = 2048 + if option == 'resnet18': + model_ft = models.resnet18(pretrained=pret) + self.dim = 512 + if option == 'resnet50': + model_ft = models.resnet50(pretrained=pret) + if option == 'resnet101': + model_ft = models.resnet101(pretrained=pret) + if option == 'resnet152': + model_ft = models.resnet152(pretrained=pret) + if option == 'resnet200': + model_ft = Res200() + if option == 'resnetnext': + model_ft = ResNeXt(layer_num=101) + #mod = list(model_ft.children()) + #mod.pop() + #self.model_ft =model_ft + + self.conv1 = model_ft.conv1 + self.bn0 = model_ft.bn1 + self.relu = model_ft.relu + self.maxpool = model_ft.maxpool + + self.layer1 = model_ft.layer1 + self.layer2 = model_ft.layer2 + self.layer3 = model_ft.layer3 + self.layer4 = model_ft.layer4 + self.pool = model_ft.avgpool + #self.bottleneck = nn.Sequential(*layers) + #self.features = nn.Sequential(*mod) + def forward(self, x): + x = self.conv1(x) + x = self.bn0(x) + x = self.relu(x) + conv_x = x + x = self.maxpool(x) + fm1 = self.layer1(x) + fm2 = self.layer2(fm1) + fm3 = self.layer3(fm2) + fm4 = self.pool(self.layer4(fm3)) + x = fm4.view(fm4.size(0), self.dim) + #x = self.bottleneck(x) + return x +class ResBase_D(nn.Module): + def __init__(self,option = 'resnet18',pret=True): + super(ResBase_D, self).__init__() + self.dim = 2048 + if option == 'resnet18': + model_ft = models.resnet18(pretrained=pret) + self.dim = 512 + if option == 'resnet50': + model_ft = models.resnet50(pretrained=pret) + if option == 'resnet101': + model_ft = models.resnet101(pretrained=pret) + if option == 'resnet152': + model_ft = models.resnet152(pretrained=pret) + if option == 'resnet200': + model_ft = Res200() + if option == 'resnetnext': + model_ft = ResNeXt(layer_num=101) + #mod = list(model_ft.children()) + #mod.pop() + #self.model_ft =model_ft + self.conv1 = model_ft.conv1 + self.bn0 = model_ft.bn1 + self.relu = model_ft.relu + self.maxpool = model_ft.maxpool + self.drop0 = nn.Dropout2d() + self.layer1 = model_ft.layer1 + self.drop1 = nn.Dropout2d() + self.layer2 = model_ft.layer2 + self.drop2 = nn.Dropout2d() + self.layer3 = model_ft.layer3 + self.drop3 = nn.Dropout2d() + self.layer4 = model_ft.layer4 + self.drop4 = nn.Dropout2d() + self.pool = model_ft.avgpool + #self.features = nn.Sequential(*mod) + def forward(self, x): + x = self.conv1(x) + x = self.bn0(x) + x = self.relu(x) + x = self.drop0(x) + conv_x = x + x = self.maxpool(x) + fm1 = self.layer1(x) + x = self.drop1(x) + fm2 = self.layer2(fm1) + x = self.drop2(x) + fm3 = self.layer3(fm2) + x = self.drop3(x) + fm4 = self.pool(self.drop4(self.layer4(fm3))) + x = fm4.view(fm4.size(0), self.dim) + return x + +class ResBasePararrel(nn.Module): + def __init__(self,option = 'resnet18',pret=True,gpu_ids=[]): + super(ResBasePararrel, self).__init__() + self.dim = 2048 + if option == 'resnet18': + model_ft = models.resnet18(pretrained=pret) + self.dim = 512 + if option == 'resnet50': + model_ft = models.resnet50(pretrained=pret) + if option == 'resnet101': + model_ft = models.resnet101(pretrained=pret) + if option == 'resnet152': + model_ft = models.resnet152(pretrained=pret) + if option == 'resnet200': + model_ft = Res200() + if option == 'resnetnext': + model_ft = ResNeXt(layer_num=101) + mod = list(model_ft.children()) + mod.pop() + #self.model_ft =model_ft + self.gpu_ids = [0,1] + self.features = nn.Sequential(*mod) + def forward(self, x): + x = x + Variable(torch.randn(x.size()).cuda())*0.05 + x = nn.parallel.data_parallel(self.features, x, self.gpu_ids) + #x = self.features(x) + + x = x.view(x.size(0), self.dim) + return x + +class ResFreeze(nn.Module): + def __init__(self,option = 'resnet18',pret=True): + super(ResFreeze, self).__init__() + self.dim = 2048*2*2 + if option == 'resnet18': + model_ft = models.resnet18(pretrained=pret) + self.dim = 512 + if option == 'resnet50': + model_ft = models.resnet50(pretrained=pret) + if option == 'resnet101': + model_ft = models.resnet101(pretrained=pret) + if option == 'resnet152': + model_ft = models.resnet152(pretrained=pret) + if option == 'resnet200': + model_ft = Res200() + self.conv1 = model_ft.conv1 + self.bn0 = model_ft.bn1 + self.relu = model_ft.relu + self.maxpool = model_ft.maxpool + self.layer1 = model_ft.layer1 + self.layer2 = model_ft.layer2 + self.layer3 = model_ft.layer3 + self.layer4 = model_ft.layer4 + self.avgpool = model_ft.avgpool + + def forward(self, x): + x = self.conv1(x) + x = self.bn0(x) + x = self.relu(x) + conv_x = x + x = self.maxpool(x) + pool_x = x + fm1 = self.layer1(x) + fm2 = self.layer2(fm1) + fm3 = self.layer3(fm2) + fm4 = F.max_pool2d(self.layer4(fm3),kernel_size=3) + #print(fm1) + #print(fm2) + #print(fm3) + #print(fm4) + #x = self.avgpool(fm4) + x = fm4.view(fm4.size(0), self.dim) + return x + + +class DenseBase(nn.Module): + def __init__(self,option = 'densenet201',pret=True): + super(DenseBase, self).__init__() + self.dim = 2048 + if option == 'densenet201': + model_ft = models.densenet201(pretrained=pret) + self.dim = 1920 + if option == 'densenet161': + model_ft = models.densenet161(pretrained=pret) + self.dim = 2208 + mod = list(model_ft.children()) + #mod.pop() + + self.features = nn.Sequential(*mod) + def forward(self, x): + x = self.features(x) + #print x + #x = F.avg_pool2d(x,(7,7)) + #x = x.view(x.size(0), self.dim) + return x + + +class ResClassifier(nn.Module): + def __init__(self, num_classes=13,num_layer = 2,num_unit=2048,prob=0.5,middle=1000): + super(ResClassifier, self).__init__() + layers = [] + # currently 10000 units + layers.append(nn.Dropout(p=prob)) + layers.append(nn.Linear(num_unit,middle)) + layers.append(nn.BatchNorm1d(middle,affine=True)) + layers.append(nn.ReLU(inplace=True)) + + for i in range(num_layer-1): + layers.append(nn.Dropout(p=prob)) + layers.append(nn.Linear(middle,middle)) + layers.append(nn.BatchNorm1d(middle,affine=True)) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Linear(middle,num_classes)) + self.classifier = nn.Sequential(*layers) + + #self.classifier = nn.Sequential( + # nn.Dropout(), + # nn.Linear(2048, 1000), + # nn.BatchNorm1d(1000,affine=True), + # nn.ReLU(inplace=True), + # nn.Dropout(), + # nn.Linear(1000, 1000), + # nn.BatchNorm1d(1000,affine=True), + # nn.ReLU(inplace=True), + # nn.Linear(1000, num_classes), + + def set_lambda(self, lambd): + self.lambd = lambd + def forward(self, x,reverse=False): + if reverse: + x = grad_reverse(x, self.lambd) + x = self.classifier(x) + return x +class ResClassifier_office(nn.Module): + def __init__(self, num_classes=12,num_layer = 2,num_unit=2048,prob=0.5,middle=256): + super(ResClassifier_office, self).__init__() + layers = [] + # currently 10000 units + layers.append(nn.Dropout(p=prob)) + layers.append(nn.Linear(num_unit,middle)) + layers.append(nn.BatchNorm1d(middle,affine=True)) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Dropout(p=prob)) + layers.append(nn.Linear(middle,num_classes)) + self.classifier = nn.Sequential(*layers) + def set_lambda(self, lambd): + self.lambd = lambd + def forward(self, x,reverse=False): + if reverse: + x = grad_reverse(x, self.lambd) + x = self.classifier(x) + return x +class DenseClassifier(nn.Module): + def __init__(self, num_classes=12,num_layer = 2): + super(DenseClassifier, self).__init__() + layers = [] + # currently 1000 units + layers.append(nn.Dropout()) + layers.append(nn.Linear(1000,500)) + layers.append(nn.BatchNorm1d(500,affine=True)) + layers.append(nn.ReLU(inplace=True)) + + for i in range(num_layer-1): + layers.append(nn.Dropout()) + layers.append(nn.Linear(500,500)) + layers.append(nn.BatchNorm1d(500,affine=True)) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Linear(500,num_classes)) + #layers.append(nn.BatchNorm1d(num_classes,affine=True,momentum=0.95)) + self.classifier = nn.Sequential(*layers) + + #self.classifier = nn.Sequential( + # nn.Dropout(), + # nn.Linear(2048, 1000), + # nn.BatchNorm1d(1000,affine=True), + # nn.ReLU(inplace=True), + # nn.Dropout(), + # nn.Linear(1000, 1000), + # nn.BatchNorm1d(1000,affine=True), + # nn.ReLU(inplace=True), + # nn.Linear(1000, num_classes), + + #) + def forward(self, x): + x = self.classifier(x) + #x = self.classifier(x) + return x + + +class AE(nn.Module): + def __init__(self, num_classes=12,num_layer = 2,ngf=32,norm_layer=nn.BatchNorm2d): + super(AE, self).__init__() + layers = [] + layers.append(nn.Dropout()) + layers.append(nn.Linear(512,32*8*8)) + #layers.append(nn.BatchNorm1d(64*8*8,affine=True)) + layers.append(nn.ReLU(inplace=True)) + self.classifier = nn.Sequential(*layers) + n_downsampling=5 + mult = 2**n_downsampling + n_blocks = 3 + model2 = [nn.Conv2d(32, ngf*mult, kernel_size=5, + stride=4, padding=1), + norm_layer(ngf * mult, affine=True), + nn.ReLU()] + + #model2 = [nn.ConvTranspose2d(64, ngf * mult, + # kernel_size=3, stride=2, + # padding=1, output_padding=1), + # norm_layer(ngf * mult, affine=True), + # nn.ReLU()] + for i in range(n_blocks): + model2 += [ResnetBlock(ngf * mult, 'zero', norm_layer=norm_layer, use_dropout=True)] + #print ngf*mult + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model2 += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + kernel_size=3, stride=2, + padding=1, output_padding=1), + norm_layer(int(ngf * mult / 2), affine=True), + nn.ReLU()] + #model2 += [nn.Conv2d(ngf*mult/2, ngf, + # kernel_size=1, padding=1), + # norm_layer(int(ngf), affine=True), + # nn.ReLU()] + #model2 += [nn.Conv2d(int(ngf * mult / 2), 3, kernel_size=, padding=3)] + model2 += [nn.Conv2d(ngf, 3, kernel_size=11, padding=1)] + model2 += [nn.Tanh()] + + self.classifier2 = nn.Sequential(*model2) + def forward(self, x): + x = self.classifier(x) + x = x.view(x.size(0),32,8,8) + x = self.classifier2(x) + return x + + +class InceptionBase(nn.Module): + def __init__(self): + super(InceptionBase, self).__init__() + model_ft = models.inception_v3(pretrained=True) + #mod = list(model_ft.children()) + #mod.pop() + self.features = model_ft#nn.Sequential(*mod) + def forward(self, x): + x = self.features(x) + #x = x.view(x.size(0), 2048) + return x +class InceptionClassifier(nn.Module): + def __init__(self, num_classes=12,num_layer = 2): + super(InceptionClassifier, self).__init__() + layers = [] + layers.append(nn.Dropout()) + layers.append(nn.Linear(1000,1000)) + layers.append(nn.BatchNorm1d(1000,affine=True)) + layers.append(nn.ReLU(inplace=True)) + + for i in range(num_layer-1): + layers.append(nn.Dropout()) + layers.append(nn.Linear(1000,1000)) + layers.append(nn.BatchNorm1d(1000,affine=True)) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Linear(1000,num_classes)) + self.classifier = nn.Sequential(*layers) + + def forward(self, x): + x = self.classifier(x) + #x = self.classifier(x) + return x + + +class ResClassifierMMD(nn.Module): + def __init__(self, num_classes=12): + super(ResClassifierMMD, self).__init__() + self.classifier = nn.Sequential( + nn.Linear(512, 1000), + nn.BatchNorm1d(1000,affine=True), + #nn.Dropout(), + nn.ReLU(inplace=True), + ) + self.classifier2 = nn.Sequential( + nn.Linear(1000, 256), + nn.BatchNorm1d(256,affine=True), + #nn.Dropout(), + nn.ReLU(inplace=True), + ) + self.last = nn.Linear(256, num_classes) + def forward(self, x): + x1 = self.classifier(x) + x2 = self.classifier2(x1) + x3 = self.last(x2) + return x3,x2,x1 +class BaseShallow(nn.Module): + def __init__(self,num_classes=12): + super(BaseShallow, self).__init__() + layers = [] + nc = 3 + ndf = 64 + self.features = nn.Sequential( + # input is (nc) x 64 x 64 + nn.Conv2d(nc, ndf, 7, 2, 1, bias=False), + nn.ReLU(inplace=True), + # state size. (ndf) x 32 x 32 + nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), + #nn.BatchNorm2d(ndf * 2), + nn.ReLU(inplace=True), + # state size. (ndf*2) x 16 x 16 + nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), + #nn.BatchNorm2d(ndf * 4), + nn.ReLU(inplace=True), + # state size. (ndf*4) x 8 x 8 + nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), + #nn.BatchNorm2d(ndf * 8), + nn.ReLU(inplace=True), + # state size. (ndf*8) x 4 x 4 + nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), + ) + + self.last = nn.Sequential( + nn.Linear(256, num_classes), + ) + def forward(self, x): + x1 = self.features(x) + x1 = x1.view(-1,100) + #x2 = self.last(x1) + #x3 = self.last(x2) + return x1 + +class ClassifierShallow(nn.Module): + + def __init__(self, num_classes=12): + super(ClassifierShallow, self).__init__() + self.classifier2 = nn.Sequential( + nn.Dropout(), + nn.Linear(100, 1000), + nn.ReLU(inplace=True), + nn.BatchNorm1d(1000,affine=True), + nn.Dropout(), + nn.ReLU(inplace=True), + nn.Linear(1000, num_classes), + ) + def forward(self, x): + #x = self.classifier1(x) + x = self.classifier2(x) + return x + +class Discriminator(nn.Module): + def __init__(self, num_classes=12): + super(Discriminator, self).__init__() + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(2048, 100), + nn.ReLU(inplace=True), + nn.BatchNorm1d(100,affine=True), + nn.Dropout(), + nn.ReLU(inplace=True), + nn.Linear(100, 2), + ) + def forward(self, x): + #x = self.classifier1(x) + #print x + x = self.classifier(x) + return x + + +class EClassifier(nn.Module): + + def __init__(self, num_classes=12): + super(EClassifier, self).__init__() + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(9216+12, 1000), + nn.ReLU(inplace=True), + nn.BatchNorm1d(1000,affine=True), + nn.Dropout(), + nn.ReLU(inplace=True), + nn.Linear(1000, num_classes), + ) + self.classifier2 = nn.Sequential( + nn.Linear(num_classes,12), + ) + def forward(self, x1,x2): + x = torch.cat([x1,x2],1) + x = self.classifier(x) + x_source = self.classifier2(x) + return x,x_source + +class Resbridge(nn.Module): + def __init__(self, num_classes=12,num_layer = 2,num_unit=2048,prob=0.5): + super(Resbridge, self).__init__() + layers = [] + # currently 1000 units + layers.append(nn.Dropout(p=prob)) + layers.append(nn.Linear(num_unit,500)) + layers.append(nn.BatchNorm1d(500,affine=True)) + layers.append(nn.ReLU(inplace=True)) + self.classifier1 = nn.Sequential(*layers) + layers2 = [] + layers2.append(nn.Dropout(p=prob)) + layers2.append(nn.Linear(1000,500)) + layers2.append(nn.BatchNorm1d(500,affine=True)) + layers2.append(nn.ReLU(inplace=True)) + for i in range(num_layer-1): + layers2.append(nn.Dropout(p=prob)) + layers2.append(nn.Linear(500,500)) + layers2.append(nn.BatchNorm1d(500,affine=True)) + layers2.append(nn.ReLU(inplace=True)) + layers2.append(nn.Linear(500,num_classes)) + self.classifier2 = nn.Sequential(*layers2) + + + layers3 = [] + # currently 1000 units + layers3.append(nn.Dropout(p=prob)) + layers3.append(nn.Linear(num_unit,500)) + layers3.append(nn.BatchNorm1d(500,affine=True)) + layers3.append(nn.ReLU(inplace=True)) + self.classifier3 = nn.Sequential(*layers3) + layers4 = [] + layers4.append(nn.Dropout(p=prob)) + layers4.append(nn.Linear(1000,500)) + layers4.append(nn.BatchNorm1d(500,affine=True)) + layers4.append(nn.ReLU(inplace=True)) + for i in range(num_layer-1): + layers4.append(nn.Dropout(p=prob)) + layers4.append(nn.Linear(500,500)) + layers4.append(nn.BatchNorm1d(500,affine=True)) + layers4.append(nn.ReLU(inplace=True)) + layers4.append(nn.Linear(500,num_classes)) + self.classifier4 = nn.Sequential(*layers4) + + def forward(self, x): + x1 = self.classifier1(x) + x3 = self.classifier3(x) + x2 = torch.cat((x1,x3),1) + x2 = self.classifier2(x2) + x4 = torch.cat((x3,x1),1) + x4 = self.classifier4(x4) + + return x2,x4 diff --git a/visda_classification/res_train_main.py b/visda_classification/res_train_main.py new file mode 100755 index 0000000..230ff0e --- /dev/null +++ b/visda_classification/res_train_main.py @@ -0,0 +1,253 @@ +from __future__ import print_function +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.autograd import Variable +import numpy as np +from utils import * +from taskcv_loader import CVDataLoader +from basenet import * +import torch.nn.functional as F +import os +# Training settings +parser = argparse.ArgumentParser(description='Visda Classification') +parser.add_argument('--batch-size', type=int, default=32, metavar='N', + help='input batch size for training (default: 64)') +parser.add_argument('--test-batch-size', type=int, default=32, metavar='N', + help='input batch size for testing (default: 1000)') +parser.add_argument('--epochs', type=int, default=50, metavar='N', + help='number of epochs to train (default: 10)') +parser.add_argument('--lr', type=float, default=0.001, metavar='LR', + help='learning rate (default: 0.001)') +parser.add_argument('--momentum', type=float, default=0.9, metavar='M', + help='SGD momentum (default: 0.5)') +parser.add_argument('--optimizer', type=str, default='momentum', metavar='OP', + help='the name of optimizer') +parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') +parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') +parser.add_argument('--log-interval', type=int, default=50, metavar='N', + help='how many batches to wait before logging training status') +parser.add_argument('--num_k', type=int, default=4, metavar='K', + help='how many steps to repeat the generator update') +parser.add_argument('--num-layer', type=int, default=2, metavar='K', + help='how many layers for classifier') +parser.add_argument('--name', type=str, default='board', metavar='B', + help='board dir') +parser.add_argument('--save', type=str, default='save/mcd', metavar='B', + help='board dir') +parser.add_argument('--train_path', type=str, default='/data/ugui0/dataset/adaptation/train', metavar='B', + help='directory of source datasets') +parser.add_argument('--val_path', type=str, default='/data/ugui0/dataset/adaptation/validation', metavar='B', + help='directory of target datasets') +parser.add_argument('--resnet', type=str, default='101', metavar='B', + help='which resnet 18,50,101,152,200') + +args = parser.parse_args() +args.cuda = not args.no_cuda and torch.cuda.is_available() +train_path = args.train_path +val_path = args.val_path +num_k = args.num_k +num_layer = args.num_layer +batch_size = args.batch_size +save_path = args.save+'_'+str(args.num_k) + +data_transforms = { + train_path: transforms.Compose([ + transforms.Scale(256), + transforms.RandomHorizontalFlip(), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]), + val_path: transforms.Compose([ + transforms.Scale(256), + transforms.RandomHorizontalFlip(), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]), +} +dsets = {x: datasets.ImageFolder(os.path.join(x), data_transforms[x]) for x in [train_path,val_path]} +dset_sizes = {x: len(dsets[x]) for x in [train_path, val_path]} +dset_classes = dsets[train_path].classes +print ('classes'+str(dset_classes)) +use_gpu = torch.cuda.is_available() +torch.manual_seed(args.seed) +if args.cuda: + torch.cuda.manual_seed(args.seed) +train_loader = CVDataLoader() +train_loader.initialize(dsets[train_path],dsets[val_path],batch_size) +dataset = train_loader.load_data() +test_loader = CVDataLoader() +opt= args +test_loader.initialize(dsets[train_path],dsets[val_path],batch_size,shuffle=True) +dataset_test = test_loader.load_data() +option = 'resnet'+args.resnet +G = ResBase(option) +F1 = ResClassifier(num_layer=num_layer) +F2 = ResClassifier(num_layer=num_layer) +F1.apply(weights_init) +F2.apply(weights_init) +lr = args.lr +if args.cuda: + G.cuda() + F1.cuda() + F2.cuda() +if args.optimizer == 'momentum': + optimizer_g = optim.SGD(list(G.features.parameters()), lr=args.lr,weight_decay=0.0005) + optimizer_f = optim.SGD(list(F1.parameters())+list(F2.parameters()),momentum=0.9,lr=args.lr,weight_decay=0.0005) +elif args.optimizer == 'adam': + optimizer_g = optim.Adam(G.features.parameters(), lr=args.lr,weight_decay=0.0005) + optimizer_f = optim.Adam(list(F1.parameters())+list(F2.parameters()), lr=args.lr,weight_decay=0.0005) +else: + optimizer_g = optim.Adadelta(G.features.parameters(),lr=args.lr,weight_decay=0.0005) + optimizer_f = optim.Adadelta(list(F1.parameters())+list(F2.parameters()),lr=args.lr,weight_decay=0.0005) + +def train(num_epoch): + criterion = nn.CrossEntropyLoss().cuda() + for ep in range(num_epoch): + G.train() + F1.train() + F2.train() + for batch_idx, data in enumerate(dataset): + if batch_idx * batch_size > 30000: + break + if args.cuda: + data1 = data['S'] + target1 = data['S_label'] + data2 = data['T'] + target2 = data['T_label'] + data1, target1 = data1.cuda(), target1.cuda() + data2, target2 = data2.cuda(), target2.cuda() + # when pretraining network source only + eta = 1.0 + data = Variable(torch.cat((data1,data2),0)) + target1 = Variable(target1) + # Step A train all networks to minimize loss on source + optimizer_g.zero_grad() + optimizer_f.zero_grad() + output = G(data) + output1 = F1(output) + output2 = F2(output) + + output_s1 = output1[:batch_size,:] + output_s2 = output2[:batch_size,:] + output_t1 = output1[batch_size:,:] + output_t2 = output2[batch_size:,:] + output_t1 = F.softmax(output_t1) + output_t2 = F.softmax(output_t2) + + entropy_loss = - torch.mean(torch.log(torch.mean(output_t1,0)+1e-6)) + entropy_loss -= torch.mean(torch.log(torch.mean(output_t2,0)+1e-6)) + + loss1 = criterion(output_s1, target1) + loss2 = criterion(output_s2, target1) + all_loss = loss1 + loss2 + 0.01 * entropy_loss + all_loss.backward() + optimizer_g.step() + optimizer_f.step() + + #Step B train classifier to maximize discrepancy + optimizer_g.zero_grad() + optimizer_f.zero_grad() + + output = G(data) + output1 = F1(output) + output2 = F2(output) + output_s1 = output1[:batch_size,:] + output_s2 = output2[:batch_size,:] + output_t1 = output1[batch_size:,:] + output_t2 = output2[batch_size:,:] + output_t1 = F.softmax(output_t1) + output_t2 = F.softmax(output_t2) + loss1 = criterion(output_s1, target1) + loss2 = criterion(output_s2, target1) + entropy_loss = - torch.mean(torch.log(torch.mean(output_t1,0)+1e-6)) + entropy_loss -= torch.mean(torch.log(torch.mean(output_t2,0)+1e-6)) + loss_dis = torch.mean(torch.abs(output_t1-output_t2)) + F_loss = loss1 + loss2 - eta*loss_dis + 0.01 * entropy_loss + F_loss.backward() + optimizer_f.step() + # Step C train genrator to minimize discrepancy + for i in range(num_k): + optimizer_g.zero_grad() + output = G(data) + output1 = F1(output) + output2 = F2(output) + + output_s1 = output1[:batch_size,:] + output_s2 = output2[:batch_size,:] + output_t1 = output1[batch_size:,:] + output_t2 = output2[batch_size:,:] + + loss1 = criterion(output_s1, target1) + loss2 = criterion(output_s2, target1) + output_t1 = F.softmax(output_t1) + output_t2 = F.softmax(output_t2) + loss_dis = torch.mean(torch.abs(output_t1-output_t2)) + entropy_loss = -torch.mean(torch.log(torch.mean(output_t1,0)+1e-6)) + entropy_loss -= torch.mean(torch.log(torch.mean(output_t2,0)+1e-6)) + + loss_dis.backward() + optimizer_g.step() + if batch_idx % args.log_interval == 0: + print('Train Ep: {} [{}/{} ({:.0f}%)]\tLoss1: {:.6f}\tLoss2: {:.6f}\t Dis: {:.6f} Entropy: {:.6f}'.format( + ep, batch_idx * len(data), 70000, + 100. * batch_idx / 70000, loss1.data[0],loss2.data[0],loss_dis.data[0],entropy_loss.data[0])) + if batch_idx == 1 and ep >1: + test(ep) + G.train() + F1.train() + F2.train() + +def test(epoch): + G.eval() + F1.eval() + F2.eval() + test_loss = 0 + correct = 0 + correct2 = 0 + size = 0 + + for batch_idx, data in enumerate(dataset_test): + if batch_idx*batch_size > 5000: + break + if args.cuda: + data2 = data['T'] + target2 = data['T_label'] + if val: + data2 = data['S'] + target2 = data['S_label'] + data2, target2 = data2.cuda(), target2.cuda() + data1, target1 = Variable(data2, volatile=True), Variable(target2) + output = G(data1) + output1 = F1(output) + output2 = F2(output) + test_loss += F.nll_loss(output1, target1).data[0] + pred = output1.data.max(1)[1] # get the index of the max log-probability + correct += pred.eq(target1.data).cpu().sum() + pred = output2.data.max(1)[1] # get the index of the max log-probability + k = target1.data.size()[0] + correct2 += pred.eq(target1.data).cpu().sum() + + size += k + test_loss = test_loss + test_loss /= len(test_loader) # loss function already averages over batch size + print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%) ({:.0f}%)\n'.format( + test_loss, correct, size, + 100. * correct / size,100.*correct2/size)) + #if 100. * correct / size > 67 or 100. * correct2 / size > 67: + value = max(100. * correct / size,100. * correct2 / size) + if not val and value > 60: + torch.save(F1.state_dict(), save_path+'_'+args.resnet+'_'+str(value)+'_'+'F1.pth') + torch.save(F2.state_dict(), save_path+'_'+args.resnet+'_'+str(value)+'_'+'F2.pth') + torch.save(G.state_dict(), save_path+'_'+args.resnet+'_'+str(value)+'_'+'G.pth') + + +#for epoch in range(1, args.epochs + 1): +train(args.epochs+1) diff --git a/visda_classification/taskcv_loader.py b/visda_classification/taskcv_loader.py new file mode 100644 index 0000000..075837e --- /dev/null +++ b/visda_classification/taskcv_loader.py @@ -0,0 +1,85 @@ +import random +import torch.utils.data +import torchvision.transforms as transforms +#import torchnet as tnt +# pip install future --upgrade +from builtins import object +from pdb import set_trace as st +import torch.utils.data as data_utils +class PairedData(object): + def __init__(self, data_loader_A, data_loader_B, max_dataset_size, flip): + self.data_loader_A = data_loader_A + self.data_loader_B = data_loader_B + self.stop_A = False + self.stop_B = False + self.max_dataset_size = max_dataset_size + self.flip = flip + + def __iter__(self): + self.stop_A = False + self.stop_B = False + self.data_loader_A_iter = iter(self.data_loader_A) + self.data_loader_B_iter = iter(self.data_loader_B) + self.iter = 0 + return self + + def __next__(self): + A, A_paths = None, None + B, B_paths = None, None + try: + A, A_paths = next(self.data_loader_A_iter) + except StopIteration: + if A is None or A_paths is None: + self.stop_A = True + self.data_loader_A_iter = iter(self.data_loader_A) + A, A_paths = next(self.data_loader_A_iter) + + try: + B, B_paths = next(self.data_loader_B_iter) + except StopIteration: + if B is None or B_paths is None: + self.stop_B = True + self.data_loader_B_iter = iter(self.data_loader_B) + B, B_paths = next(self.data_loader_B_iter) + + if (self.stop_A and self.stop_B) or self.iter > self.max_dataset_size: + self.stop_A = False + self.stop_B = False + raise StopIteration() + else: + self.iter += 1 + if self.flip and random.random() < 0.5: + idx = [i for i in range(A.size(3) - 1, -1, -1)] + idx = torch.LongTensor(idx) + A = A.index_select(3, idx) + B = B.index_select(3, idx) + return {'S': A, 'S_label': A_paths, + 'T': B, 'T_label': B_paths} + +class CVDataLoader(object): + def initialize(self, dataset_A,dataset_B,batch_size,shuffle=True): + #normalize = transforms.Normalize(mean=mean_im,std=std_im) + self.max_dataset_size = float("inf") + data_loader_A = torch.utils.data.DataLoader( + dataset_A, + batch_size=batch_size, + shuffle=shuffle, + num_workers=4) + data_loader_B = torch.utils.data.DataLoader( + dataset_B, + batch_size=batch_size, + shuffle=shuffle, + num_workers=4) + self.dataset_A = dataset_A + self.dataset_B = dataset_B + flip = False + self.paired_data = PairedData(data_loader_A, data_loader_B, self.max_dataset_size, flip) + + def name(self): + return 'UnalignedDataLoader' + + def load_data(self): + return self.paired_data + + def __len__(self): + return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.max_dataset_size) diff --git a/visda_classification/utils.py b/visda_classification/utils.py new file mode 100644 index 0000000..8f4fdb3 --- /dev/null +++ b/visda_classification/utils.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +from numpy.random import * +from torch.autograd import Variable +from torch.nn.modules.batchnorm import BatchNorm2d, BatchNorm1d, BatchNorm3d + + +def textread(path): + # if not os.path.exists(path): + # print path, 'does not exist.' + # return False + f = open(path) + lines = f.readlines() + f.close() + for i in range(len(lines)): + lines[i] = lines[i].replace('\n', '') + return lines + +def adjust_learning_rate(optimizer, epoch,lr=0.001): + """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" + lr = lr * 0.99#min(1, 2 - epoch/float(20))#0.95 best + for param_group in optimizer.param_groups: + param_group['lr'] = lr + return lr +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.normal_(0.0, 0.01) + m.bias.data.normal_(0.0, 0.01) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_(1.0, 0.01) + m.bias.data.fill_(0) + elif classname.find('Linear') != -1: + m.weight.data.normal_(0.0, 0.01) + m.bias.data.normal_(0.0, 0.01) +