diff --git a/README.md b/README.md
index cbffdc8..0aaf5bb 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,42 @@
-Implementation of "CLIF: Complementary Leaky Integrate-and-Fire Neuron for Spiking Neural Networks"
-
-Paper: https://arxiv.org/pdf/2402.04663.pdf
+# CLIF
+
+Implementation of "CLIF: Complementary Leaky Integrate-and-Fire Neuron for Spiking Neural Networks" (paper: https://arxiv.org/pdf/2402.04663.pdf).
+
+## Dependencies
+- Python 3
+- PyTorch, torchvision
+- spikingjelly 0.0.0.0.12
+- Python packages: `pip install tqdm progress torchtoolbox thop`
+
+
+## Training
+We use a single GTX4090 GPU to run all experiments. Multi-GPU training is not supported by the current code.
+
+
+### Setup
+Training commands for CIFAR-10, CIFAR-100, Tiny-ImageNet, DVS-CIFAR10, and DVS-Gesture:
+
+    # CIFAR-10
+    python train_BPTT.py -data_dir ./data_dir -dataset cifar10 -model spiking_resnet18 -T_max 200 -epochs 200 -weight_decay 5e-5 -neuron CLIF
+
+    # CIFAR-100
+    python train_BPTT.py -data_dir ./data_dir -dataset cifar100 -model spiking_resnet18 -T_max 200 -epochs 200 -neuron CLIF
+
+    # Tiny-ImageNet
+    python train_BPTT.py -data_dir ./data_dir -dataset tiny_imagenet -model spiking_vgg13_bn -neuron CLIF
+
+    # DVS-CIFAR10
+    python train_BPTT.py -data_dir ./data_dir -dataset DVSCIFAR10 -T 10 -drop_rate 0.3 -model spiking_vgg11_bn -lr=0.05 -mse_n_reg -neuron CLIF
+
+    # DVS-Gesture
+    python train_BPTT.py -data_dir ./data_dir -dataset dvsgesture -model spiking_vgg11_bn -T 20 -b 16 -drop_rate 0.4 -neuron CLIF
+
+To switch the neuron model, pass ``LIF`` or ``PLIF`` instead of ``CLIF`` after ``-neuron``.
+
+For example, to train with the LIF neuron on CIFAR-10:
+
+    # LIF neuron for CIFAR-10
+    python train_BPTT.py -data_dir ./data_dir -dataset cifar10 -model spiking_resnet18 -amp -T_max 200 -epochs 200 -weight_decay 5e-5 -neuron LIF
+
+
+
+## Inference
+For the inference setup, refer to the file ``run_inference_script``.
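+
+### Quick sanity check
+The neuron classes in ``modules/neuron.py`` can also be used standalone. A minimal sketch (it assumes the repository root is on ``PYTHONPATH``; the layer sizes here are illustrative only):
+
+    import torch
+    from spikingjelly.clock_driven import functional
+    from modules.neuron import ComplementaryLIFNeuron
+
+    net = torch.nn.Sequential(torch.nn.Linear(10, 5), ComplementaryLIFNeuron(tau=2.))
+    x = torch.rand(4, 10)                     # one time-step of input
+    rate = sum(net(x) for _ in range(6)) / 6  # mean firing rate over T=6 steps
+    functional.reset_net(net)                 # clear membrane state between samples
+    print(rate.shape)                         # torch.Size([4, 5])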
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..1c88514
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,538 @@
+import argparse
+import collections
+import datetime
+import os
+import random
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data as data
+import torchvision
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+from spikingjelly.clock_driven import functional, surrogate as surrogate_sj
+from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
+from torch.cuda import amp
+from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import default_collate
+from torch.utils.tensorboard import SummaryWriter
+from torchtoolbox.transform import Cutout
+from torchvision.transforms import autoaugment, transforms
+from torchvision.transforms.functional import InterpolationMode
+
+from models import spiking_resnet, vgg_model, spiking_vgg_bn
+from modules import neuron
+from modules import surrogate as surrogate_self
+from utils import Bar, AverageMeter, accuracy, static_cifar_util, augmentation
+from utils.augmentation import ToPILImage, Resize, ToTensor
+from utils.cifar10_dvs import CIFAR10DVS, DVSCifar10
+from utils.data_loaders import TinyImageNet
+from thop import profile
+
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+
+def main():
+    parser = argparse.ArgumentParser(description='SNN inference')
+    parser.add_argument('-seed', default=2022, type=int)
+    parser.add_argument('-name', default='', type=str, help='specify a name for the checkpoint and log files')
+    parser.add_argument('-T', default=6, type=int, help='simulating time-steps')
+    parser.add_argument('-tau', default=2.0, type=float, help='a hyperparameter for the LIF model')
+    parser.add_argument('-b', default=128, type=int, help='batch size')
+
+    parser.add_argument('-j', default=0, type=int, metavar='N', help='number of data loading workers (default: 0)')
+    parser.add_argument('-data_dir', type=str, default='./data', help='directory of the used dataset')
+    parser.add_argument('-dataset', default='cifar10', type=str,
+                        help='should be cifar10, cifar100, DVSCIFAR10, dvsgesture, or imagenet')
+    parser.add_argument('-out_dir', type=str, default='./logs_infer', help='root dir for saving logs and checkpoint')
+    parser.add_argument('-surrogate', default='rectangle', type=str,
+                        help='used surrogate function. should be sigmoid, rectangle, or triangle')
+    parser.add_argument('-resume', type=str, help='resume from the checkpoint path')
+    parser.add_argument('-pre_train', type=str, help='load a pretrained model. used for imagenet')
+    parser.add_argument('-amp', default=True, type=bool, help='automatic mixed precision training')
+
+    parser.add_argument('-model', type=str, default='spiking_vgg11_bn', help='use which SNN model')
+    parser.add_argument('-drop_rate', type=float, default=0.0, help='dropout rate. used for DVSCIFAR10')
+    parser.add_argument('-weight_decay', type=float, default=5e-4)
+    parser.add_argument('-loss_lambda', type=float, default=0.05, help='the scaling factor for the MSE term in the loss')
+    parser.add_argument('-mse_n_reg', action='store_true', help='loss function setting')
+    parser.add_argument('-loss_means', type=float, default=1.0, help='used in the loss function when mse_n_reg=False')
+    parser.add_argument('-save_init', action='store_true', help='save the initialization of parameters')
+    parser.add_argument('-neuron_model', type=str, default='LIF', help='which neuron model to use: LIF, CLIF, PLIF, or relu')
+    parser.add_argument('-multiple_step', type=bool, default=False, help='whether to use the multi-step neuron implementation')
+    parser.add_argument('-cutupmix_auto', action='store_true', help='cutupmix autoaugmentation for cifar and tinyimagenet')
+    parser.add_argument('-label_smoothing', type=float, default=0.0, help='label_smoothing for cross entropy')
+
+    args = parser.parse_args()
+    print(args)
+
+    _seed_ = args.seed
+    random.seed(_seed_)
+    torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
+    torch.cuda.manual_seed_all(_seed_)
+    np.random.seed(_seed_)
+
+    ##########################################################
+    # data loading
+    ##########################################################
+    in_dim = None
+    c_in = None
+    if args.dataset == 'cifar10' or args.dataset == 'cifar100':
+
+        c_in = 3
+        if args.dataset == 'cifar10':
+            dataloader = torchvision.datasets.CIFAR10
+            num_classes = 10
+            normalization_mean = (0.4914, 0.4822, 0.4465)
+            normalization_std = (0.2023, 0.1994, 0.2010)
+        elif args.dataset == 'cifar100':
+            dataloader = torchvision.datasets.CIFAR100
+            num_classes = 100
+            normalization_mean = (0.5071, 0.4867, 0.4408)
+            normalization_std = (0.2675, 0.2565, 0.2761)
+        else:
+            raise NotImplementedError
+
+        if args.cutupmix_auto:
+            mixup_transforms = []
+            mixup_transforms.append(static_cifar_util.RandomMixup(num_classes, p=1.0, alpha=0.2))
+            mixup_transforms.append(static_cifar_util.RandomCutmix(num_classes, p=1.0, alpha=1.))
+            mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
+            collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
+
+            transform_train = 
static_cifar_util.ClassificationPresetTrain(mean=normalization_mean, + std=normalization_std, + interpolation=InterpolationMode('bilinear'), + auto_augment_policy='ta_wide', + random_erase_prob=0.1) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(normalization_mean, normalization_std), + ]) + + train_set = dataloader( + root=args.data_dir, + train=True, + transform=transform_train, + download=False, ) + + test_set = dataloader( + root=args.data_dir, + train=False, + transform=transform_test, + download=False) + + train_data_loader = torch.utils.data.DataLoader( + dataset=train_set, + batch_size=args.b, + collate_fn=collate_fn, + shuffle=True, + drop_last=True, + num_workers=args.j, + pin_memory=True + ) + + test_data_loader = torch.utils.data.DataLoader( + dataset=test_set, + batch_size=args.b, + shuffle=False, + drop_last=False, + num_workers=args.j, + pin_memory=True + ) + else: + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + Cutout(), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(normalization_mean, normalization_std), + ]) + + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(normalization_mean, normalization_std), + ]) + + trainset = dataloader(root=args.data_dir, train=True, download=True, transform=transform_train) + train_data_loader = DataLoader(trainset, batch_size=args.b, shuffle=True, num_workers=args.j) + + testset = dataloader(root=args.data_dir, train=False, download=False, transform=transform_test) + test_data_loader = DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.j) + + elif args.dataset == 'DVSCIFAR10': + c_in = 2 + num_classes = 10 + + transform_train = transforms.Compose([ + ToPILImage(), + Resize(48), + # augmentation.Cutout(), + # augmentation.RandomSizedCrop(48), + augmentation.RandomHorizontalFlip(), + augmentation.RandomRotation(), + ToTensor(), + + ]) + + transform_test = transforms.Compose([ + ToPILImage(), + Resize(48), + ToTensor(), + ]) + + trainset = CIFAR10DVS(args.data_dir, train=True, use_frame=True, frames_num=args.T, split_by='number', + normalization=None, transform=transform_train) + testset = CIFAR10DVS(args.data_dir, train=False, use_frame=True, frames_num=args.T, split_by='number', + normalization=None, transform=transform_test) + + train_data_loader = DataLoader(trainset, batch_size=args.b, shuffle=True, num_workers=args.j) + test_data_loader = DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.j) + + elif args.dataset == 'DVSCIFAR10-pt': + c_in = 2 + num_classes = 10 + in_dim = 48 + train_path = args.data_dir + '/train' + val_path = args.data_dir + '/test' + trainset = DVSCifar10(root=train_path, transform=True) + testset = DVSCifar10(root=val_path, transform=False) + train_data_loader = DataLoader(trainset, batch_size=args.b, shuffle=True, num_workers=args.j) + test_data_loader = DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.j, drop_last=False, + pin_memory=True) + + elif args.dataset == 'dvsgesture': + c_in = 2 + num_classes = 11 + in_dim = 128 + + trainset = DVS128Gesture(root=args.data_dir, train=True, data_type='frame', frames_number=args.T, + split_by='number') + train_data_loader = DataLoader(trainset, batch_size=args.b, shuffle=True, num_workers=args.j, drop_last=True, + pin_memory=True) + + testset = DVS128Gesture(root=args.data_dir, train=False, data_type='frame', frames_number=args.T, + split_by='number') + 
test_data_loader = DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.j, drop_last=False, + pin_memory=True) + + elif args.dataset == 'tiny_imagenet': + data_dir = args.data_dir + c_in = 3 + num_classes = 200 + normalize = transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]) + + transoform_list = [ + transforms.RandomCrop(64), + transforms.RandomHorizontalFlip(0.5), + ] + + if args.cutupmix_auto: + transoform_list.append(autoaugment.AutoAugment()) + + transoform_list += [transforms.ToTensor(), normalize] + + train_transforms = transforms.Compose(transoform_list) + val_transforms = transforms.Compose([transforms.ToTensor(), normalize, ]) + + train_data = TinyImageNet(data_dir, train=True, transform=train_transforms) + test_data = TinyImageNet(data_dir, train=False, transform=val_transforms) + + train_data_loader = torch.utils.data.DataLoader( + train_data, + batch_size=args.b, shuffle=True, + num_workers=args.j, pin_memory=True) + + test_data_loader = torch.utils.data.DataLoader( + test_data, + batch_size=args.b, shuffle=False, + num_workers=args.j, pin_memory=True) + else: + raise NotImplementedError + + ########################################################## + # model preparing + ########################################################## + if args.surrogate == 'sigmoid': + surrogate_function = surrogate_sj.Sigmoid() + elif args.surrogate == 'rectangle': + surrogate_function = surrogate_self.Rectangle() + elif args.surrogate == 'triangle': + surrogate_function = surrogate_sj.PiecewiseQuadratic() + else: + raise NotImplementedError + + if args.neuron_model == 'LIF': + neuron_model = neuron.BPTTNeuron + elif args.neuron_model == 'CLIF': + neuron_model = neuron.ComplementaryLIFNeuron + elif args.neuron_model == 'PLIF': + neuron_model = neuron.PLIFNeuron + elif args.neuron_model == 'relu': + neuron_model = neuron.ReLU + args.T = 1 + else: + raise NotImplementedError + + if args.model in ['spiking_resnet18', 'spiking_resnet34', 'spiking_resnet50', 'spiking_resnet101', + 'spiking_resnet152']: + net = spiking_resnet.__dict__[args.model](neuron=neuron_model, num_classes=num_classes, + neuron_dropout=args.drop_rate, + tau=args.tau, surrogate_function=surrogate_function, c_in=c_in, + fc_hw=1) + print('using Resnet model.') + elif args.model in ['spiking_vgg11_bn', 'spiking_vgg13_bn', 'spiking_vgg16_bn', 'spiking_vgg19_bn']: + net = spiking_vgg_bn.__dict__[args.model](neuron=neuron_model, num_classes=num_classes, + neuron_dropout=args.drop_rate, + tau=args.tau, surrogate_function=surrogate_function, c_in=c_in, + fc_hw=in_dim if in_dim else None) + print('using Spiking VGG model.') + elif args.model in ['vggsnn', 'snn5_noAP']: # snn5_noAP use for statistical experiment + net = vgg_model.__dict__[args.model](neuron=neuron_model, num_classes=num_classes, + neuron_dropout=args.drop_rate, + tau=args.tau, surrogate_function=surrogate_function, c_in=c_in, + fc_hw=in_dim if in_dim else None) + print('using Spiking VGG model.') + else: + raise NotImplementedError + + print('Total Parameters: %.2fM' % (sum(p.numel() for p in net.parameters()) / 1000000.0)) + net.cuda() + # + # ########################################################## + # # optimizer preparing + # ########################################################## + # if args.opt == 'SGD': + # optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, + # weight_decay=args.weight_decay) + # elif args.opt == 'AdamW': + # optimizer = torch.optim.AdamW(net.parameters(), lr=args.lr, 
weight_decay=args.weight_decay) + # else: + # raise NotImplementedError(args.opt) + # + # if args.lr_scheduler == 'StepLR': + # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma) + # elif args.lr_scheduler == 'CosALR': + # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.T_max) + # else: + # raise NotImplementedError(args.lr_scheduler) + # + # scaler = None + # if args.amp: + # scaler = amp.GradScaler() + + ########################################################## + # loading models from checkpoint + ########################################################## + + max_test_acc = 0 + + if args.resume: + print('resuming...') + checkpoint = torch.load(args.resume, map_location='cpu') + net.load_state_dict(checkpoint['net']) + # optimizer.load_state_dict(checkpoint['optimizer']) + # lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + start_epoch = checkpoint['epoch'] + 1 + max_test_acc = checkpoint['max_test_acc'] + print('start epoch:', start_epoch, ', max test acc:', max_test_acc) + + if args.pre_train: + checkpoint = torch.load(args.pre_train, map_location='cpu') + state_dict2 = collections.OrderedDict([(k, v) for k, v in checkpoint['net'].items()]) + net.load_state_dict(state_dict2) + print('use pre-trained model, max test acc:', checkpoint['max_test_acc']) + + ########################################################## + # output setting + ########################################################## + out_dir = os.path.join(args.out_dir, + f'inference_{args.dataset}_{args.model}_{args.name}_T{args.T}_tau{args.tau}_bs{args.b}') + + if args.neuron_model != 'LIF': + out_dir += f'_{args.neuron_model}_' + + # if args.lr_scheduler == 'CosALR': + # out_dir += f'CosALR_{args.T_max}' + # elif args.lr_scheduler == 'StepLR': + # out_dir += f'StepLR_{args.step_size}_{args.gamma}' + # else: + # raise NotImplementedError(args.lr_scheduler) + + if args.amp: + out_dir += '_amp' + + if args.cutupmix_auto: + out_dir += '_cutupmix_auto' + + if not os.path.exists(out_dir): + os.makedirs(out_dir) + print(f'Mkdir {out_dir}.') + else: + print(out_dir) + + # save the initialization of parameters + if args.save_init: + checkpoint = { + 'net': net.state_dict(), + 'epoch': 0, + 'max_test_acc': 0.0 + } + torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_0.pth')) + + with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt: + args_txt.write(str(args)) + + ########################################################## + # testing + ########################################################## + criterion_mse = nn.MSELoss() + + start_time = time.time() + net.eval() + + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + end = time.time() + bar = Bar('Processing', max=len(test_data_loader)) + + test_loss = 0 + test_acc = 0 + test_samples = 0 + batch_idx = 0 + with torch.no_grad(): + for data in test_data_loader: + if args.dataset == 'SHD': + frame, label, _ = data + else: + frame, label = data + + batch_idx += 1 + if (args.dataset != 'DVSCIFAR10'): + frame = frame.float().cuda() + + if args.dataset == 'dvsgesture' or args.dataset == "SHD" or args.dataset == "DVSCIFAR10-pt": + frame = frame.transpose(0, 1) + label = label.cuda() + + t_step = args.T + if args.dataset == 'SHD': + t_step = len(frame) + + label_real = torch.cat([label for _ in range(t_step)], 0) + # print(t_step) + + out_all = [] + for t in range(t_step): + + if 
(args.dataset == 'DVSCIFAR10'): + input_frame = frame[t].float().cuda() + elif args.dataset == 'dvsgesture' or args.dataset == "SHD" or args.dataset == "DVSCIFAR10-pt": + input_frame = frame[t] + else: + input_frame = frame + if t == 0: + out_fr = net(input_frame) + total_fr = out_fr.clone().detach() + out_all.append(out_fr) + else: + out_fr = net(input_frame) + total_fr += out_fr.clone().detach() + out_all.append(out_fr) + + out_all = torch.cat(out_all, 0) + # Calculate the loss + if args.loss_lambda > 0.0: # the loss is a cross entropy term plus a mse term + if args.mse_n_reg: # the mse term is not treated as a regularizer + label_one_hot = F.one_hot(label_real, num_classes).float() + else: + label_one_hot = torch.zeros_like(out_all).fill_(args.loss_means).to(out_all.device) + mse_loss = criterion_mse(out_all, label_one_hot) + loss = ((1 - args.loss_lambda) * F.cross_entropy(out_all, label_real, + label_smoothing=args.label_smoothing) + args.loss_lambda * mse_loss) + else: # the loss is just a cross entropy term + loss = F.cross_entropy(out_all, label_real, label_smoothing=args.label_smoothing) + total_loss = loss + + test_samples += label.numel() + test_loss += total_loss.item() * label.numel() + test_acc += (total_fr.argmax(1) == label).float().sum().item() + functional.reset_net(net) + + # measure accuracy and record loss + prec1, prec5 = accuracy(total_fr.data, label.data, topk=(1, 5)) + losses.update(total_loss, input_frame.size(0)) + top1.update(prec1.item(), input_frame.size(0)) + top5.update(prec5.item(), input_frame.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + # plot progress + bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( + batch=batch_idx, + size=len(test_data_loader), + data=data_time.avg, + bt=batch_time.avg, + total=bar.elapsed_td, + eta=bar.eta_td, + loss=losses.avg, + top1=top1.avg, + top5=top5.avg, + ) + bar.next() + bar.finish() + + test_loss /= test_samples + test_acc /= test_samples + + ############### calculation ############### + + total_time = time.time() - start_time + info = f'test_loss={test_loss}, test_acc={test_acc}, max_test_acc={max_test_acc}, total_time={total_time}' + print(info) + mem_cost = "after one epoch: %fGB" % (torch.cuda.max_memory_cached(0) / 1024 / 1024 / 1024) + print(mem_cost) + + + + B, C, H, W = input_frame.shape + optimal_batch_size = B + dummy_input = torch.randn(optimal_batch_size, C, H, W, dtype=torch.float).cuda() + + repetitions = 100 + total_time = 0 + with torch.no_grad(): + for rep in range(repetitions): + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + starter.record() + _ = net(dummy_input) + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) / 1000 + total_time += curr_time + Throughput = (repetitions * optimal_batch_size) / total_time + + print("Final Throughput:", Throughput) + with open(os.path.join(out_dir, 'args.txt'), 'a+', encoding='utf-8') as args_txt: + args_txt.write("\n") + args_txt.write(info + "\n") + args_txt.write(mem_cost + "\n") + + args_txt.write("Throughput" + "\n") + args_txt.write(str(Throughput) + "\n") + + + +if __name__ == '__main__': + main() diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/spiking_resnet.py b/models/spiking_resnet.py new file mode 100644 index 0000000..1251b62 --- 
/dev/null +++ b/models/spiking_resnet.py @@ -0,0 +1,231 @@ +import torch.nn as nn +from spikingjelly.clock_driven import layer + +__all__ = [ + 'PreActResNet', 'spiking_resnet18', 'spiking_resnet34', 'spiking_resnet50', 'spiking_resnet101', 'spiking_resnet152' +] + + +class PreActBlock(nn.Module): + '''Pre-activation version of the BasicBlock.''' + expansion = 1 + + def __init__(self, in_channels, out_channels, stride, dropout, neuron: callable = None, **kwargs): + super(PreActBlock, self).__init__() + whether_bias = True + self.bn1 = nn.BatchNorm2d(in_channels) + + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=whether_bias) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.dropout = layer.Dropout(dropout) + self.conv2 = nn.Conv2d(out_channels, self.expansion * out_channels, kernel_size=3, stride=1, padding=1, + bias=whether_bias) + + if stride != 1 or in_channels != self.expansion * out_channels: + self.shortcut = nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, + padding=0, bias=whether_bias) + else: + self.shortcut = nn.Sequential() + + self.relu1 = neuron(**kwargs) + self.relu2 = neuron(**kwargs) + + def forward(self, x): + x = self.relu1(self.bn1(x)) + out = self.conv1(x) + out = self.conv2(self.dropout(self.relu2(self.bn2(out)))) + out = out + self.shortcut(x) + return out + + +class PreActBottleneck(nn.Module): + '''Pre-activation version of the original Bottleneck module.''' + expansion = 4 + + def __init__(self, in_channels, out_channels, stride, dropout, neuron: callable = None, **kwargs): + super(PreActBottleneck, self).__init__() + whether_bias = True + + self.bn1 = nn.BatchNorm2d(in_channels) + + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=whether_bias) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=whether_bias) + self.bn3 = nn.BatchNorm2d(out_channels) + self.dropout = layer.Dropout(dropout) + self.conv3 = nn.Conv2d(out_channels, self.expansion * out_channels, kernel_size=1, stride=1, padding=0, + bias=whether_bias) + + if stride != 1 or in_channels != self.expansion * out_channels: + self.shortcut = nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, + padding=0, bias=whether_bias) + else: + self.shortcut = nn.Sequential() + + self.relu1 = neuron(**kwargs) + self.relu2 = neuron(**kwargs) + self.relu3 = neuron(**kwargs) + + def forward(self, x): + x = self.relu1(self.bn1(x)) + + out = self.conv1(x) + out = self.conv2(self.relu2(self.bn2(out))) + out = self.conv3(self.dropout(self.relu3(self.bn3(out)))) + + out = out + self.shortcut(x) + + return out + + +class PreActResNet(nn.Module): + + def __init__(self, block, num_blocks, num_classes, dropout, neuron: callable = None, **kwargs): + super(PreActResNet, self).__init__() + self.num_blocks = num_blocks + + self.data_channels = kwargs.get('c_in', 3) + self.init_channels = 64 + self.conv1 = nn.Conv2d(self.data_channels, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.layer1 = self._make_layer(block, 64, num_blocks[0], 1, dropout, neuron, **kwargs) + self.layer2 = self._make_layer(block, 128, num_blocks[1], 2, dropout, neuron, **kwargs) + self.layer3 = self._make_layer(block, 256, num_blocks[2], 2, dropout, neuron, **kwargs) + self.layer4 = self._make_layer(block, 512, num_blocks[3], 2, dropout, neuron, **kwargs) + + self.bn1 = nn.BatchNorm2d(512 * 
block.expansion) + self.pool = nn.AvgPool2d(4) + self.flat = nn.Flatten() + self.drop = layer.Dropout(dropout) + self.linear = nn.Linear(512 * block.expansion, num_classes) + + self.relu1 = neuron(**kwargs) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, val=1) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + def _make_layer(self, block, out_channels, num_blocks, stride, dropout, neuron, **kwargs): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.init_channels, out_channels, stride, dropout, neuron, **kwargs)) + self.init_channels = out_channels * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = self.conv1(x) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = self.pool(self.relu1(self.bn1(out))) + out = self.drop(self.flat(out)) + out = self.linear(out) + return out + + +# class Bottleneck(nn.Module): +# expansion = 4 +# +# def __init__(self, in_planes, planes, stride=1, bn_type='', **kwargs_spikes): +# super(Bottleneck, self).__init__() +# self.kwargs_spikes = kwargs_spikes +# self.nb_steps = kwargs_spikes['nb_steps'] +# self.conv1 = tdLayer(nn.Conv2d(in_planes, planes, kernel_size=1, bias=False), self.nb_steps) +# self.bn1 = warpBN(planes, bn_type, self.nb_steps) +# self.spike1 = LIFLayer(**kwargs_spikes) +# self.conv2 = tdLayer(nn.Conv2d(planes, planes, kernel_size=3, +# stride=stride, padding=1, bias=False), self.nb_steps) +# self.bn2 = warpBN(planes, bn_type, self.nb_steps) +# self.spike2 = LIFLayer(**kwargs_spikes) +# self.conv3 = tdLayer(nn.Conv2d(planes, self.expansion * +# planes, kernel_size=1, bias=False), self.nb_steps) +# self.bn3 = warpBN(self.expansion * +# planes, bn_type, self.nb_steps) +# +# self.shortcut = nn.Sequential() +# if stride != 1 or in_planes != self.expansion * planes: +# self.shortcut = nn.Sequential( +# tdLayer(nn.Conv2d(in_planes, self.expansion * planes, +# kernel_size=1, stride=stride, bias=False), self.nb_steps), +# warpBN(self.expansion * planes, bn_type, self.nb_steps) +# ) +# self.spike3 = LIFLayer(**kwargs_spikes) +# +# def forward(self, x): +# out = self.spike1(self.bn1(self.conv1(x))) +# out = self.spike2(self.bn2(self.conv2(out))) +# out = self.bn3(self.conv3(out)) +# out += self.shortcut(x) +# out = self.spike3(out) +# return out +# +# +# class ResNet19(nn.Module): +# def __init__(self, block, num_block_layers, num_classes=10, in_channel=3, bn_type='', **kwargs_spikes): +# super(ResNet19, self).__init__() +# self.in_planes = 128 +# self.bn_type = bn_type +# self.kwargs_spikes = kwargs_spikes +# self.nb_steps = kwargs_spikes['nb_steps'] +# self.conv0 = nn.Sequential( +# tdLayer(nn.Conv2d(in_channel, self.in_planes, kernel_size=3, padding=1, stride=1, bias=False), +# nb_steps=self.nb_steps), +# warpBN(self.in_planes, bn_type, self.nb_steps), +# LIFLayer(**kwargs_spikes) +# ) +# self.layer1 = self._make_layer(block, 128, num_block_layers[0], stride=1) +# self.layer2 = self._make_layer(block, 256, num_block_layers[1], stride=2) +# self.layer3 = self._make_layer(block, 512, num_block_layers[2], stride=2) +# self.avg_pool = tdLayer(nn.AdaptiveAvgPool2d((1, 1)), nb_steps=self.nb_steps) +# self.classifier = nn.Sequential( +# tdLayer(nn.Linear(512 * block.expansion, 256, bias=False), nb_steps=self.nb_steps), +# 
LIFLayer(**kwargs_spikes), +# tdLayer(nn.Linear(256, num_classes, bias=False), nb_steps=self.nb_steps), +# Readout() +# ) +# +# def _make_layer(self, block, planes, num_blocks, stride): +# strides = [stride] + [1] * (num_blocks - 1) +# layers = [] +# for stride in strides: +# layers.append(block(self.in_planes, planes, stride, self.bn_type, **self.kwargs_spikes)) +# self.in_planes = planes * block.expansion +# return nn.Sequential(*layers) +# +# def forward(self, x): +# out, _ = torch.broadcast_tensors(x, torch.zeros((self.nb_steps,) + x.shape)) +# out = self.conv0(out) +# out = self.layer1(out) +# out = self.layer2(out) +# out = self.layer3(out) +# out = self.avg_pool(out) +# out = out.view(out.shape[0], out.shape[1], -1) +# out = self.classifier(out) +# return out + +def spiking_resnet18(neuron: callable = None, num_classes=10, neuron_dropout=0, **kwargs): + return PreActResNet(PreActBlock, [2, 2, 2, 2], num_classes, neuron_dropout, neuron=neuron, **kwargs) + + +def spiking_resnet34(neuron: callable = None, num_classes=10, neuron_dropout=0, **kwargs): + return PreActResNet(PreActBlock, [3, 4, 6, 3], num_classes, neuron_dropout, neuron=neuron, **kwargs) + + +def spiking_resnet50(neuron: callable = None, num_classes=10, neuron_dropout=0, **kwargs): + return PreActResNet(PreActBottleneck, [3, 4, 6, 3], num_classes, neuron_dropout, neuron=neuron, **kwargs) + + +def spiking_resnet101(neuron: callable = None, num_classes=10, neuron_dropout=0, **kwargs): + return PreActResNet(PreActBottleneck, [3, 4, 23, 3], num_classes, neuron_dropout, neuron=neuron, **kwargs) + + +def spiking_resnet152(neuron: callable = None, num_classes=10, neuron_dropout=0, **kwargs): + return PreActResNet(PreActBottleneck, [3, 8, 36, 3], num_classes, neuron_dropout, neuron=neuron, **kwargs) diff --git a/models/spiking_vgg_bn.py b/models/spiking_vgg_bn.py new file mode 100644 index 0000000..571bbef --- /dev/null +++ b/models/spiking_vgg_bn.py @@ -0,0 +1,115 @@ +import torch.nn as nn +from spikingjelly.clock_driven import layer + +__all__ = [ + 'SpikingVGGBN', 'spiking_vgg11_bn', 'spiking_vgg13_bn', 'spiking_vgg16_bn', 'spiking_vgg19_bn' +] + +cfg = { + + 'VGG11': [ + [64, 'M'], + [128, 'M'], + [256, 256, 'M'], + [512, 512, 'M'], + [512, 512, 'M'] + ], + 'VGG13': [ + [64, 64, 'M'], + [128, 128, 'M'], + [256, 256, 'M'], + [512, 512, 'M'], + [512, 512, 'M'] + ], + 'VGG16': [ + [64, 64, 'M'], + [128, 128, 'M'], + [256, 256, 256, 'M'], + [512, 512, 512, 'M'], + [512, 512, 512, 'M'] + ], + 'VGG19': [ + [64, 64, 'M'], + [128, 128, 'M'], + [256, 256, 256, 256, 'M'], + [512, 512, 512, 512, 'M'], + [512, 512, 512, 512, 'M'] + ] +} + + +class SpikingVGGBN(nn.Module): + def __init__(self, vgg_name, neuron: callable = None, dropout=0.0, num_classes=10, **kwargs): + super(SpikingVGGBN, self).__init__() + self.whether_bias = True + self.init_channels = kwargs.get('c_in', 2) + + self.layer1 = self._make_layers(cfg[vgg_name][0], dropout, neuron, **kwargs) + self.layer2 = self._make_layers(cfg[vgg_name][1], dropout, neuron, **kwargs) + self.layer3 = self._make_layers(cfg[vgg_name][2], dropout, neuron, **kwargs) + self.layer4 = self._make_layers(cfg[vgg_name][3], dropout, neuron, **kwargs) + self.layer5 = self._make_layers(cfg[vgg_name][4], dropout, neuron, **kwargs) + + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + + self.classifier = nn.Sequential( + nn.Flatten(), + nn.Linear(512 * 7 * 7, num_classes), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + 
if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + def _make_layers(self, cfg, dropout, neuron, **kwargs): + layers = [] + for x in cfg: + if x == 'M': + layers.append(nn.AvgPool2d(kernel_size=2, stride=2)) + else: + layers.append(nn.Conv2d(self.init_channels, x, kernel_size=3, padding=1, bias=self.whether_bias)) + layers.append(nn.BatchNorm2d(x)) + # kwargs["l_i"] += 1 + layers.append(neuron(**kwargs)) + layers.append(layer.Dropout(dropout)) + self.init_channels = x + return nn.Sequential(*layers) + + def forward(self, x): + out = self.layer1(x) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = self.layer5(out) + out = self.avgpool(out) + out = self.classifier(out) + + return out + + +def spiking_vgg9_bn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs): + return SpikingVGGBN('VGG9', neuron=neuron, dropout=neuron_dropout, num_classes=num_classes, **kwargs) + + +def spiking_vgg11_bn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs): + return SpikingVGGBN('VGG11', neuron=neuron, dropout=neuron_dropout, num_classes=num_classes, **kwargs) + + +def spiking_vgg13_bn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs): + return SpikingVGGBN('VGG13', neuron=neuron, dropout=neuron_dropout, num_classes=num_classes, **kwargs) + + +def spiking_vgg16_bn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs): + return SpikingVGGBN('VGG16', neuron=neuron, dropout=neuron_dropout, num_classes=num_classes, **kwargs) + + +def spiking_vgg19_bn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs): + return SpikingVGGBN('VGG19', neuron=neuron, dropout=neuron_dropout, num_classes=num_classes, **kwargs) diff --git a/models/vgg_model.py b/models/vgg_model.py new file mode 100644 index 0000000..274c82c --- /dev/null +++ b/models/vgg_model.py @@ -0,0 +1,185 @@ +import torch +from spikingjelly.clock_driven import layer + +__all__ = [ + 'vggsnn', 'snn5', 'snn5_noAP' +] + +from torch import nn + + +class SNN5(nn.Module): + def __init__(self, neuron, num_classes=10, dropout=0.0, **kwargs): + super(SNN5, self).__init__() + pool = nn.Sequential(nn.AvgPool2d(2)) + self.features = nn.Sequential( + Layer(3, 16, 3, 1, 1, neuron, **kwargs), + Layer(16, 64, 5, 1, 1, neuron, **kwargs), + pool, + Layer(64, 128, 5, 1, 1, neuron, **kwargs), + pool, + Layer(128, 256, 5, 1, 1, neuron, **kwargs), + pool, + Layer(256, 512, 3, 1, 1, neuron, **kwargs), + pool, + ) + W = int(32 / 2 / 2 / 2 / 2 / 2) + + self.classifier = nn.Linear(512 * W * W, num_classes) + self.drop = layer.Dropout(dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + + def forward(self, input): + x = self.features(input) + # print(x.shape) + x = self.drop(torch.flatten(x, start_dim=-3, end_dim=-1)) + x = self.classifier(x) + return x + + +# use for Figure.2 +class SNN5_noAP(nn.Module): + def __init__(self, neuron, num_classes=10, dropout=0.0, **kwargs): + super(SNN5_noAP, self).__init__() + pool = nn.Sequential(nn.AvgPool2d(2)) + # pool = APLayer(2) + self.features = nn.Sequential( + Layer(3, 16, 3, 1, 1, neuron, **kwargs), + Layer(16, 64, 5, 2, 1, neuron, **kwargs), + Layer(64, 128, 5, 2, 1, neuron, **kwargs), + Layer(128, 256, 5, 4, 1, neuron, 
**kwargs), + Layer(256, 256, 3, 2, 1, neuron, **kwargs), + ) + # W = int(32 / 2 / 2 / 2 / 4 / 2) + # if "fc_hw" in kwargs: + # W = int(kwargs["fc_hw"] / 2 / 2 / 2 / 2 / 2) + + self.classifier = nn.Linear(256, num_classes) + self.drop = layer.Dropout(dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + + def forward(self, input): + x = self.features(input) + x = self.drop(torch.flatten(x, start_dim=-3, end_dim=-1)) + x = self.classifier(x) + return x + + +def snn5(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs): + return SNN5(neuron=neuron, num_classes=num_classes, dropout=neuron_dropout, **kwargs) + + +def snn5_noAP(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs): + return SNN5_noAP(neuron=neuron, num_classes=num_classes, dropout=neuron_dropout, **kwargs) + + +class Layer(nn.Module): + def __init__(self, in_plane, out_plane, kernel_size, stride, padding, neuron, **kwargs): + super(Layer, self).__init__() + self.fwd = nn.Sequential( + nn.Conv2d(in_plane, out_plane, kernel_size, stride, padding), + nn.BatchNorm2d(out_plane) + ) + self.act = neuron(**kwargs) + + def forward(self, x): + x = self.fwd(x) + x = self.act(x) + # print(x.shape) + return x + + +class VGGSNN(nn.Module): + def __init__(self, neuron, num_classes=10, neuron_dropout=0.0, **kwargs): + super(VGGSNN, self).__init__() + pool = nn.Sequential(nn.AvgPool2d(2)) + # pool = APLayer(2) + self.features = nn.Sequential( + Layer(2, 64, 3, 1, 1, neuron, **kwargs), + Layer(64, 128, 3, 1, 1, neuron, **kwargs), + pool, + Layer(128, 256, 3, 1, 1, neuron, **kwargs), + Layer(256, 256, 3, 1, 1, neuron, **kwargs), + pool, + Layer(256, 512, 3, 1, 1, neuron, **kwargs), + Layer(512, 512, 3, 1, 1, neuron, **kwargs), + pool, + Layer(512, 512, 3, 1, 1, neuron, **kwargs), + Layer(512, 512, 3, 1, 1, neuron, **kwargs), + pool, + ) + W = int(48 / 2 / 2 / 2 / 2) + if "fc_hw" in kwargs: + W = int(kwargs["fc_hw"] / 2 / 2 / 2 / 2) + # self.T = 4 + # self.classifier = SeqToANNContainer(nn.Linear(512 * W * W, 10)) + self.classifier = nn.Linear(512 * W * W, num_classes) + self.drop = layer.Dropout(neuron_dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + + def forward(self, input): + x = self.features(input) + # x = torch.flatten(x, 2) + x = self.drop(torch.flatten(x, start_dim=-3, end_dim=-1)) + x = self.classifier(x) + return x + + +class VGGSNNwoAP(nn.Module): + def __init__(self, neuron, num_classes=10, neuron_dropout=0.0, **kwargs): + super(VGGSNNwoAP, self).__init__() + self.features = nn.Sequential( + Layer(2, 64, 3, 1, 1, neuron, **kwargs), + Layer(64, 128, 3, 2, 1, neuron, **kwargs), + Layer(128, 256, 3, 1, 1, neuron, **kwargs), + Layer(256, 256, 3, 2, 1, neuron, **kwargs), + Layer(256, 512, 3, 1, 1, neuron, **kwargs), + Layer(512, 512, 3, 2, 1, neuron, **kwargs), + Layer(512, 512, 3, 1, 1, neuron, **kwargs), + Layer(512, 512, 3, 2, 1, neuron, **kwargs), + ) + W = int(48 / 2 / 2 / 2 / 2) + if "fc_hw" in kwargs: + W = int(kwargs["fc_hw"] / 2 / 2 / 2 / 2) + + # self.T = 4 + # self.classifier = SeqToANNContainer(nn.Linear(512 * W * W, 10)) + self.classifier = nn.Linear(512 * W * W, num_classes) + self.drop = layer.Dropout(neuron_dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + + def forward(self, input): + # print(input.shape) + x = 
self.features(input)
+        # print(x.shape)
+        x = self.drop(torch.flatten(x, start_dim=-3, end_dim=-1))
+
+        x = self.classifier(x)
+        return x
+
+
+def vggsnn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs):
+    return VGGSNN(neuron=neuron, num_classes=num_classes, neuron_dropout=neuron_dropout, **kwargs)
+
+
+if __name__ == '__main__':
+    # model = VGGSNNwoAP()
+    from modules.neuron import ComplementaryLIFNeuron
+    from thop import profile
+
+    model = snn5_noAP(neuron=ComplementaryLIFNeuron)
+    input = torch.randn(1, 3, 32, 32)
+    flops, params = profile(model, inputs=(input,))
+    print(model)
+    print('FLOPs: %.2fM, Params: %.2fM' % (flops / 1e6, params / 1e6))
diff --git a/modules/neuron.py b/modules/neuron.py
index 8b4712d..730952b 100644
--- a/modules/neuron.py
+++ b/modules/neuron.py
@@ -2,24 +2,54 @@
 import torch
 from spikingjelly.clock_driven.neuron import LIFNode as LIFNode_sj
+from spikingjelly.clock_driven.neuron import ParametricLIFNode as PLIFNode_sj
+from torch import nn
 
 from modules.surrogate import Rectangle
 
 
+# multi-step pure-PyTorch version
+class CLIFSpike(nn.Module):
+    def __init__(self, tau: float):
+        super(CLIFSpike, self).__init__()
+        # the symbols correspond to those in the paper
+        # self.spike_func = surrogate_function
+        self.spike_func = Rectangle()
+
+        self.v_th = 1.
+        self.gamma = 1 - 1. / tau
+
+    def forward(self, x_seq):
+        # x_seq.shape should be [T, N, *]
+        _spike = []
+        u = 0
+        m = 0
+        T = x_seq.shape[0]
+        for t in range(T):
+            u = self.gamma * u + x_seq[t, ...]
+            spike = self.spike_func(u - self.v_th)
+            _spike.append(spike)
+            m = m * torch.sigmoid_((1. - self.gamma) * u) + spike
+            u = u - spike * (self.v_th + torch.sigmoid_(m))
+        # self.pre_spike_mem = torch.stack(_mem)
+        return torch.stack(_spike, dim=0)
+
+
+# spikingjelly single-step version
 class ComplementaryLIFNeuron(LIFNode_sj):
     def __init__(self, tau: float = 2., decay_input: bool = False, v_threshold: float = 1.,
                  v_reset: float = None, surrogate_function: Callable = Rectangle(),
                  detach_reset: bool = False, cupy_fp32_inference=False, **kwargs):
         super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset, cupy_fp32_inference)
-        self.register_memory('c', 0.)  # Complementary memory
+        self.register_memory('m', 0.)  # Complementary memory
 
     def forward(self, x: torch.Tensor):
-        self.neuronal_charge(x)  # LIF charging
-        self.c = self.c * torch.sigmoid(self.v / self.tau)  # Forming
-        spike = self.neuronal_fire()  # LIF fire
-        self.c += spike  # Strengthen
-        self.neuronal_reset(spike)  # LIF reset
-        self.v = self.v - spike * torch.sigmoid(self.c)  # Reset
+        self.neuronal_charge(x)  # LIF charging
+        self.m = self.m * torch.sigmoid(self.v / self.tau)  # Forming
+        spike = self.neuronal_fire()  # LIF fire
+        self.m += spike  # Strengthen
+        self.neuronal_reset(spike)  # LIF reset
+        self.v = self.v - spike * torch.sigmoid(self.m)  # Reset
         return spike
 
     def neuronal_charge(self, x: torch.Tensor):
@@ -52,6 +82,7 @@ def _reset(self, spike):
             self.v = (1. - spike) * self.v + spike * self.v_reset
 
 
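+# The CLIF update rule implemented above, written out (notation follows the
+# paper; gamma = 1 - 1/tau, H is the surrogate Heaviside, soft reset assumed):
+#   u[t] = gamma * u[t-1] + x[t]                      # leaky charging
+#   s[t] = H(u[t] - v_th)                             # firing
+#   m[t] = m[t-1] * sigmoid(u[t] / tau) + s[t]        # complementary memory
+#   u[t] = u[t] - s[t] * (v_th + sigmoid(m[t]))       # spike-dependent reset
+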
+# spikingjelly multi-step version
 class MultiStepCLIFNeuron(ComplementaryLIFNeuron):
     def __init__(self, tau: float = 2., decay_input: bool = False, v_threshold: float = 1.,
                  v_reset: float = None, surrogate_function: Callable = Rectangle(),
@@ -71,9 +102,30 @@ def forward(self, x_seq: torch.Tensor):
         return spike_seq
 
 
+class ReLU(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.relu(x)
+
+
+class BPTTNeuron(LIFNode_sj):
+    def __init__(self, tau: float = 2., decay_input: bool = False, v_threshold: float = 1.,
+                 v_reset: float = None, surrogate_function: Callable = Rectangle(),
+                 detach_reset: bool = False, cupy_fp32_inference=False, **kwargs):
+        super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset, cupy_fp32_inference)
+
+
+class PLIFNeuron(PLIFNode_sj):
+    def __init__(self, tau: float = 2., decay_input: bool = False, v_threshold: float = 1.,
+                 v_reset: float = None, surrogate_function: Callable = None,
+                 detach_reset: bool = False, cupy_fp32_inference=False, **kwargs):
+        super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset)
+
 if __name__ == '__main__':
-    T= 8
+    T = 8
     x_input = torch.rand((T, 3, 32, 32)) * 1.2
     clif = ComplementaryLIFNeuron()
     clif_m = MultiStepCLIFNeuron()
@@ -88,4 +140,4 @@
     print(s_list.mean())
     print(s_output.mean())
 
-    assert torch.sum(s_output - torch.Tensor(s_list)) == 0
\ No newline at end of file
+    assert torch.sum(s_output - torch.Tensor(s_list)) == 0
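+
+    # Cross-check (sketch): the pure-PyTorch CLIFSpike above is expected to
+    # match the spikingjelly implementations under the defaults assumed here
+    # (tau=2.0, soft reset, Rectangle surrogate); printing the maximum
+    # deviation keeps that assumption easy to verify without asserting it.
+    s_torch = CLIFSpike(tau=2.)(x_input)
+    print((s_torch - s_output).abs().max())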
diff --git a/run_inference_script b/run_inference_script
new file mode 100644
index 0000000..9f5228b
--- /dev/null
+++ b/run_inference_script
@@ -0,0 +1,59 @@
+# Cifar10
+python inference.py \
+-data_dir ./data_dir -dataset cifar10 -model spiking_resnet18 -neuron CLIF \
+-name without_auto_aug \
+-resume ./save_logs/logs_cifar10/BPTT_cifar10_spiking_resnet18__T6_tau2.0_e200_bs128_SGD_lr0.1_wd5e-05_SG_rectangle_drop0.0_losslamb0.05_LTLIF_CosALR_200_amp/checkpoint_max.pth
+
+python inference.py \
+-data_dir ./data_dir -dataset cifar10 -model spiking_resnet18 -neuron LIF \
+-name without_auto_aug \
+-resume ./save_logs/logs_cifar10/BPTT_cifar10_spiking_resnet18__T6_tau2.0_e200_bs128_SGD_lr0.1_wd5e-05_SG_rectangle_drop0.0_losslamb0.05_CosALR_200_amp/checkpoint_max.pth
+
+
+# Cifar100
+python inference.py \
+-data_dir ./data_dir -dataset cifar100 -model spiking_resnet18 -neuron CLIF \
+-name without_auto_aug \
+-resume ./save_logs/logs_cifar100/BPTT_cifar100_spiking_resnet18__T6_tau2.0_e200_bs128_SGD_lr0.1_wd0.0005_SG_rectangle_drop0.0_losslamb0.05_LTLIF_CosALR_200_amp/checkpoint_max.pth
+
+python inference.py \
+-data_dir ./data_dir -dataset cifar100 -model spiking_resnet18 -neuron LIF \
+-name without_auto_aug \
+-resume ./save_logs/logs_cifar100/BPTT_cifar100_spiking_resnet18__T6_tau2.0_e200_bs128_SGD_lr0.1_wd0.0005_SG_rectangle_drop0.0_losslamb0.05_CosALR_200_amp/checkpoint_max.pth
+
+
+# TinyImagenet
+python inference.py \
+-data_dir ./data_dir -dataset tiny_imagenet -model spiking_vgg13_bn -b 256 -neuron CLIF \
+-name without_auto_aug \
+-resume ./save_logs/logs_tiny_imagenet/without_auto_aug/test_t_BPTT_tiny_imagenet_spiking_vgg13_bn__T6_tau2.0_e200_bs256_SGD_lr0.1_wd0.0001_SG_rectangle_drop0.0_losslamb0.05_LTLIF_CosALR_200_amp/checkpoint_max.pth
+
+python inference.py \
+-data_dir ./data_dir -dataset tiny_imagenet -model spiking_vgg13_bn -b 256 -neuron LIF \
+-name without_aug \
+-resume ./save_logs/logs_tiny_imagenet/without_auto_aug/test_t_BPTT_tiny_imagenet_spiking_vgg13_bn__T6_tau2.0_e200_bs256_SGD_lr0.1_wd0.0001_SG_rectangle_drop0.0_losslamb0.05CosALR_200_amp/checkpoint_max.pth
+
+
+# DVSCifar
+python inference.py \
+-data_dir ./data_dir -dataset DVSCIFAR10 -model spiking_vgg11_bn -T 10 -neuron CLIF \
+-name without_auto_aug \
+-resume ./save_logs/logs_dvscifar/BPTT_DVSCIFAR10_spiking_vgg11_bn__T10_tau2.0_e300_bs128_SGD_lr0.05_wd0.0005_SG_rectangle_drop0.3_losslamb0.05_LTLIF_CosALR_300_amp/checkpoint_max.pth
+
+python inference.py \
+-data_dir ./data_dir -dataset DVSCIFAR10 -model spiking_vgg11_bn -T 10 -neuron LIF \
+-name without_auto_aug \
+-resume ./save_logs/logs_dvscifar/BPTT_DVSCIFAR10_spiking_vgg11_bn__T10_tau2.0_e300_bs128_SGD_lr0.05_wd0.0005_SG_rectangle_drop0.3_losslamb0.05_CosALR_300_amp/checkpoint_max.pth
+
+
+# DVSGesture
+python inference.py \
+-data_dir ./data_dir -dataset dvsgesture -model spiking_vgg11_bn -T 20 -b 16 -neuron CLIF \
+-name without_auto_aug \
+-resume ./save_logs/logs_dvsgesture/BPTT_dvsgesture_spiking_vgg11_bn__T20_tau2.0_e300_bs16_SGD_lr0.1_wd0.0005_SG_rectangle_drop0.4_losslamb0.05_LTLIF_CosALR_300_amp/checkpoint_max.pth
+
+python inference.py \
+-data_dir ./data_dir -dataset dvsgesture -model spiking_vgg11_bn -T 20 -b 16 -neuron LIF \
+-name without_auto_aug \
+-resume ./save_logs/logs_dvsgesture/BPTT_dvsgesture_spiking_vgg11_bn__T20_tau2.0_e300_bs16_SGD_lr0.1_wd0.0005_SG_rectangle_drop0.4_losslamb0.05CosALR_300_amp/checkpoint_max.pth
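+
+# Generic template (sketch; the placeholders are to be substituted with values
+# from the training runs above, and -resume must point at an existing checkpoint):
+# python inference.py \
+#   -data_dir ./data_dir -dataset <dataset> -model <model> -neuron <LIF|CLIF|PLIF> \
+#   -resume <path/to/checkpoint_max.pth>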
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..5352bf4
--- /dev/null
+++ b/train.py
@@ -0,0 +1,667 @@
+import argparse
+import collections
+import datetime
+import os
+import random
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data as data
+import torchvision
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+from spikingjelly.clock_driven import functional, surrogate as surrogate_sj
+from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
+from torch.cuda import amp
+from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import default_collate
+from torch.utils.tensorboard import SummaryWriter
+from torchtoolbox.transform import Cutout
+from torchvision.transforms import autoaugment, transforms
+from torchvision.transforms.functional import InterpolationMode
+
+from models import spiking_resnet, vgg_model, spiking_vgg_bn
+from modules import neuron
+from modules import surrogate as surrogate_self
+from utils import Bar, AverageMeter, accuracy, static_cifar_util, augmentation
+from utils.augmentation import ToPILImage, Resize, ToTensor
+from utils.cifar10_dvs import CIFAR10DVS, DVSCifar10
+from utils.data_loaders import TinyImageNet
+
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+
+def main():
+    parser = argparse.ArgumentParser(description='SNN training')
+    parser.add_argument('-seed', default=2022, type=int)
+    parser.add_argument('-name', default='', type=str, help='specify a name for the checkpoint and log files')
+    parser.add_argument('-T', default=6, type=int, help='simulating time-steps')
+    parser.add_argument('-tau', default=2.0, type=float, help='a hyperparameter for the LIF model')
+    parser.add_argument('-b', default=128, type=int, help='batch size')
+    parser.add_argument('-epochs', default=300, type=int, metavar='N', help='number of total epochs to run')
+    parser.add_argument('-j', default=0, type=int, metavar='N', help='number of data loading workers (default: 0)')
+    parser.add_argument('-data_dir', type=str, default='./data', help='directory of the used dataset')
+    parser.add_argument('-dataset', default='cifar10', type=str, help='should be cifar10, cifar100, DVSCIFAR10, dvsgesture, or imagenet')
+    parser.add_argument('-out_dir', type=str, default='./logs', help='root dir for saving logs and checkpoint')
+    parser.add_argument('-surrogate', default='rectangle', type=str, help='used surrogate function. should be sigmoid, rectangle, or triangle')
+    parser.add_argument('-resume', type=str, help='resume from the checkpoint path')
+    parser.add_argument('-pre_train', type=str, help='load a pretrained model. used for imagenet')
+    parser.add_argument('-amp', default=True, type=bool, help='automatic mixed precision training')
+    parser.add_argument('-opt', type=str, help='use which optimizer. SGD or AdamW', default='SGD')
+    parser.add_argument('-lr', default=0.1, type=float, help='learning rate')
+    parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
+    parser.add_argument('-lr_scheduler', default='CosALR', type=str, help='use which schedule. StepLR or CosALR')
+    parser.add_argument('-step_size', default=100, type=float, help='step_size for StepLR')
+    parser.add_argument('-gamma', default=0.1, type=float, help='gamma for StepLR')
+    parser.add_argument('-T_max', default=300, type=int, help='T_max for CosineAnnealingLR')
+    parser.add_argument('-model', type=str, default='spiking_vgg11_bn', help='use which SNN model')
+    parser.add_argument('-drop_rate', type=float, default=0.0, help='dropout rate')
+    parser.add_argument('-weight_decay', type=float, default=5e-4)
+    parser.add_argument('-loss_lambda', type=float, default=0.05, help='the scaling factor for the MSE term in the loss')
+    parser.add_argument('-mse_n_reg', action='store_true', help='loss function setting')
+    parser.add_argument('-loss_means', type=float, default=1.0, help='used in the loss function when mse_n_reg=False')
+    parser.add_argument('-save_init', action='store_true', help='save the initialization of parameters')
+    parser.add_argument('-neuron_model', type=str, default='LIF', help='which neuron model to use: LIF, CLIF, PLIF, or relu')
+    parser.add_argument('-multiple_step', type=bool, default=False, help='whether to use the multi-step neuron implementation')
+    parser.add_argument('-cutupmix_auto', action='store_true', help='cutupmix autoaugmentation for cifar and tinyimagenet')
+    parser.add_argument('-label_smoothing', type=float, default=0.0, help='label_smoothing for cross entropy')
+
+    args = parser.parse_args()
+    print(args)
+
+    _seed_ = args.seed
+    random.seed(_seed_)
+    torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
+    torch.cuda.manual_seed_all(_seed_)
+    np.random.seed(_seed_)
+
+    ##########################################################
+    # data loading
+    ##########################################################
+    in_dim = None
+    c_in = None
+    if args.dataset == 'cifar10' or args.dataset == 'cifar100':
+
+        c_in = 3
+        if args.dataset == 'cifar10':
+            dataloader = torchvision.datasets.CIFAR10
+            num_classes = 10
+            normalization_mean = (0.4914, 0.4822, 0.4465)
+            normalization_std = (0.2023, 0.1994, 0.2010)
+        elif args.dataset == 'cifar100':
+            dataloader = torchvision.datasets.CIFAR100
+            num_classes = 100
+            normalization_mean = (0.5071, 0.4867, 0.4408)
+            normalization_std = (0.2675, 0.2565, 0.2761)
+        else:
+            raise NotImplementedError
+
+        if 
args.cutupmix_auto: + mixup_transforms = [] + mixup_transforms.append(static_cifar_util.RandomMixup(num_classes, p=1.0, alpha=0.2)) + mixup_transforms.append(static_cifar_util.RandomCutmix(num_classes, p=1.0, alpha=1.)) + mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) + collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731 + + transform_train = static_cifar_util.ClassificationPresetTrain(mean=normalization_mean, + std=normalization_std, + interpolation=InterpolationMode('bilinear'), + auto_augment_policy='ta_wide', + random_erase_prob=0.1) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(normalization_mean, normalization_std), + ]) + + train_set = dataloader( + root=args.data_dir, + train=True, + transform=transform_train, + download=False, ) + + test_set = dataloader( + root=args.data_dir, + train=False, + transform=transform_test, + download=False) + + train_data_loader = torch.utils.data.DataLoader( + dataset=train_set, + batch_size=args.b, + collate_fn=collate_fn, + shuffle=True, + drop_last=True, + num_workers=args.j, + pin_memory=True + ) + + test_data_loader = torch.utils.data.DataLoader( + dataset=test_set, + batch_size=args.b, + shuffle=False, + drop_last=False, + num_workers=args.j, + pin_memory=True + ) + else: + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + Cutout(), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(normalization_mean, normalization_std), + ]) + + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(normalization_mean, normalization_std), + ]) + + trainset = dataloader(root=args.data_dir, train=True, download=True, transform=transform_train) + train_data_loader = DataLoader(trainset, batch_size=args.b, shuffle=True, num_workers=args.j) + + testset = dataloader(root=args.data_dir, train=False, download=False, transform=transform_test) + test_data_loader = DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.j) + + elif args.dataset == 'DVSCIFAR10': + c_in = 2 + num_classes = 10 + + transform_train = transforms.Compose([ + ToPILImage(), + Resize(48), + # augmentation.Cutout(), + # augmentation.RandomSizedCrop(48), + augmentation.RandomHorizontalFlip(), + augmentation.RandomRotation(), + ToTensor(), + + ]) + + transform_test = transforms.Compose([ + ToPILImage(), + Resize(48), + ToTensor(), + ]) + + trainset = CIFAR10DVS(args.data_dir, train=True, use_frame=True, frames_num=args.T, split_by='number', normalization=None, transform=transform_train) + testset = CIFAR10DVS(args.data_dir, train=False, use_frame=True, frames_num=args.T, split_by='number', normalization=None, transform=transform_test) + + train_data_loader = DataLoader(trainset, batch_size=args.b, shuffle=True, num_workers=args.j) + test_data_loader = DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.j) + + elif args.dataset == 'DVSCIFAR10-pt': + c_in = 2 + num_classes = 10 + in_dim = 48 + train_path = args.data_dir + '/train' + val_path = args.data_dir + '/test' + trainset = DVSCifar10(root=train_path, transform=True) + testset = DVSCifar10(root=val_path, transform=False) + train_data_loader = DataLoader(trainset, batch_size=args.b, shuffle=True, num_workers=args.j) + test_data_loader = DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.j, drop_last=False, pin_memory=True) + + elif args.dataset == 'dvsgesture': + c_in = 2 + num_classes = 11 + in_dim = 128 + + 
trainset = DVS128Gesture(root=args.data_dir, train=True, data_type='frame', frames_number=args.T, split_by='number') + train_data_loader = DataLoader(trainset, batch_size=args.b, shuffle=True, num_workers=args.j, drop_last=True, pin_memory=True) + + testset = DVS128Gesture(root=args.data_dir, train=False, data_type='frame', frames_number=args.T, split_by='number') + test_data_loader = DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.j, drop_last=False, pin_memory=True) + + elif args.dataset == 'tiny_imagenet': + data_dir = args.data_dir + c_in = 3 + num_classes = 200 + normalize = transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]) + + transoform_list = [ + transforms.RandomCrop(64), + transforms.RandomHorizontalFlip(0.5), + ] + + if args.cutupmix_auto: + transoform_list.append(autoaugment.AutoAugment()) + + transoform_list += [transforms.ToTensor(), normalize] + + train_transforms = transforms.Compose(transoform_list) + val_transforms = transforms.Compose([transforms.ToTensor(), normalize,]) + + train_data = TinyImageNet(data_dir, train=True, transform=train_transforms) + test_data = TinyImageNet(data_dir, train=False, transform=val_transforms) + + train_data_loader = torch.utils.data.DataLoader( + train_data, + batch_size=args.b, shuffle=True, + num_workers=args.j, pin_memory=True) + + test_data_loader = torch.utils.data.DataLoader( + test_data, + batch_size=args.b, shuffle=False, + num_workers=args.j, pin_memory=True) + else: + raise NotImplementedError + + ########################################################## + # model preparing + ########################################################## + if args.surrogate == 'sigmoid': + surrogate_function = surrogate_sj.Sigmoid() + elif args.surrogate == 'rectangle': + surrogate_function = surrogate_self.Rectangle() + elif args.surrogate == 'triangle': + surrogate_function = surrogate_sj.PiecewiseQuadratic() + else: + raise NotImplementedError + + if args.neuron_model == 'LIF': + neuron_model = neuron.BPTTNeuron + elif args.neuron_model == 'CLIF': + neuron_model = neuron.ComplementaryLIFNeuron + elif args.neuron_model == 'PLIF': + neuron_model = neuron.PLIFNeuron + elif args.neuron_model == 'relu': + neuron_model = neuron.ReLU + args.T = 1 + else: + raise NotImplementedError + + if args.model in ['spiking_resnet18', 'spiking_resnet34', 'spiking_resnet50', 'spiking_resnet101', 'spiking_resnet152']: + net = spiking_resnet.__dict__[args.model](neuron=neuron_model, num_classes=num_classes, + neuron_dropout=args.drop_rate, + tau=args.tau, surrogate_function=surrogate_function, c_in=c_in, + fc_hw=1) + print('using Resnet model.') + elif args.model in ['spiking_vgg11_bn', 'spiking_vgg13_bn', 'spiking_vgg16_bn', 'spiking_vgg19_bn']: + net = spiking_vgg_bn.__dict__[args.model](neuron=neuron_model, num_classes=num_classes, + neuron_dropout=args.drop_rate, + tau=args.tau, surrogate_function=surrogate_function, c_in=c_in, + fc_hw=in_dim if in_dim else None) + print('using Spiking VGG model.') + elif args.model in ['vggsnn', 'snn5_noAP']: # snn5_noAP use for statistical experiment + net = vgg_model.__dict__[args.model](neuron=neuron_model, num_classes=num_classes, + neuron_dropout=args.drop_rate, + tau=args.tau, surrogate_function=surrogate_function, c_in=c_in, + fc_hw=in_dim if in_dim else None) + print('using Spiking VGG model.') + else: + raise NotImplementedError + + print('Total Parameters: %.2fM' % (sum(p.numel() for p in net.parameters()) / 1000000.0)) + net.cuda() + + 
########################################################## + # optimizer preparing + ########################################################## + if args.opt == 'SGD': + optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, + weight_decay=args.weight_decay) + elif args.opt == 'AdamW': + optimizer = torch.optim.AdamW(net.parameters(), lr=args.lr, weight_decay=args.weight_decay) + else: + raise NotImplementedError(args.opt) + + if args.lr_scheduler == 'StepLR': + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma) + elif args.lr_scheduler == 'CosALR': + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.T_max) + else: + raise NotImplementedError(args.lr_scheduler) + + scaler = None + if args.amp: + scaler = amp.GradScaler() + + ########################################################## + # loading models from checkpoint + ########################################################## + start_epoch = 0 + max_test_acc = 0 + + if args.resume: + print('resuming...') + checkpoint = torch.load(args.resume, map_location='cpu') + net.load_state_dict(checkpoint['net']) + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + start_epoch = checkpoint['epoch'] + 1 + max_test_acc = checkpoint['max_test_acc'] + print('start epoch:', start_epoch, ', max test acc:', max_test_acc) + + if args.pre_train: + checkpoint = torch.load(args.pre_train, map_location='cpu') + state_dict2 = collections.OrderedDict([(k, v) for k, v in checkpoint['net'].items()]) + net.load_state_dict(state_dict2) + print('use pre-trained model, max test acc:', checkpoint['max_test_acc']) + + ########################################################## + # output setting + ########################################################## + out_dir = os.path.join(args.out_dir, f'train_{args.dataset}_{args.model}_{args.name}_T{args.T}_tau{args.tau}_e{args.epochs}_bs{args.b}_{args.opt}_lr{args.lr}_wd{args.weight_decay}_SG_{args.surrogate}_drop{args.drop_rate}_losslamb{args.loss_lambda}_labelsmoothing{args.label_smoothing}') + + if args.neuron_model != 'LIF': + out_dir += f'_{args.neuron_model}_' + + if args.lr_scheduler == 'CosALR': + out_dir += f'CosALR_{args.T_max}' + elif args.lr_scheduler == 'StepLR': + out_dir += f'StepLR_{args.step_size}_{args.gamma}' + else: + raise NotImplementedError(args.lr_scheduler) + + if args.amp: + out_dir += '_amp' + + if args.cutupmix_auto: + out_dir += '_cutupmix_auto' + + if not os.path.exists(out_dir): + os.makedirs(out_dir) + print(f'Mkdir {out_dir}.') + else: + print(out_dir) + + # save the initialization of parameters + if args.save_init: + checkpoint = { + 'net': net.state_dict(), + 'epoch': 0, + 'max_test_acc': 0.0 + } + torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_0.pth')) + + with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt: + args_txt.write(str(args)) + + writer = SummaryWriter(os.path.join(out_dir, 'logs'), purge_step=start_epoch) + + ########################################################## + # training and testing + ########################################################## + criterion_mse = nn.MSELoss() + + for epoch in range(start_epoch, args.epochs): + ############### training ############### + start_time = time.time() + net.train() + + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + end = time.time() + + bar = 
Bar('Processing', max=len(train_data_loader)) + + train_loss = 0 + train_acc = 0 + train_samples = 0 + batch_idx = 0 + for data in train_data_loader: + + if args.dataset == 'SHD': + frame, label, _ = data + # print(frame.shape) + # print(label.shape) + else: + frame, label = data + + batch_idx += 1 + if (args.dataset != 'DVSCIFAR10'): + frame = frame.float().cuda() + if args.dataset == 'dvsgesture' or args.dataset == 'SHD' or args.dataset == "DVSCIFAR10-pt": + frame = frame.transpose(0, 1) + + t_step = args.T + if args.dataset == 'SHD': + t_step = len(frame) + # print(t_step) + + label = label.cuda() + label_real = torch.cat([label for _ in range(t_step)], 0) + + optimizer.zero_grad() + out_all = [] + for t in range(t_step): + if (args.dataset == 'DVSCIFAR10'): + input_frame = frame[t].float().cuda() + # print(input_frame.shape) + elif args.dataset == 'dvsgesture' or args.dataset == "SHD" or args.dataset == "DVSCIFAR10-pt": + input_frame = frame[t] + else: + input_frame = frame + + if args.amp: + with amp.autocast(): + if t == 0: + out_fr = net(input_frame) + total_fr = out_fr.clone().detach() + out_all.append(out_fr) + else: + out_fr = net(input_frame) + total_fr += out_fr.clone().detach() + out_all.append(out_fr) + else: + raise NotImplementedError + + out_all = torch.cat(out_all, 0) + # print(out_all.shape) + # print(label_real.shape) + # print(out_all) + if args.amp: + with amp.autocast(): + # print(label.shape) + # print(out_all.shape) + # Calculate the loss + if args.loss_lambda > 0.0: # the loss is a cross entropy term plus a mse term + if args.mse_n_reg: # the mse term is not treated as a regularizer + label_one_hot = F.one_hot(label_real, num_classes).float() + else: + label_one_hot = torch.zeros_like(out_all).fill_(args.loss_means).to(out_all.device) + mse_loss = criterion_mse(out_all, label_one_hot) + loss = ((1 - args.loss_lambda) * F.cross_entropy(out_all, label_real, + label_smoothing=args.label_smoothing) + args.loss_lambda * mse_loss) + else: # the loss is just a cross entropy term + loss = F.cross_entropy(out_all, label_real, label_smoothing=args.label_smoothing) + scaler.scale(loss).backward() + + else: + raise NotImplementedError( + 'We do not implement mixed precision training for BPTT. 
Please implement it by yourself.') + + batch_loss = loss.item() + train_loss += loss.item() * label.numel() + + if args.amp: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + + # measure accuracy and record loss + if args.cutupmix_auto: + label = label.argmax(dim=-1) + prec1, prec5 = accuracy(total_fr.data, label.data, topk=(1, 5)) + losses.update(batch_loss, input_frame.size(0)) + top1.update(prec1.item(), input_frame.size(0)) + top5.update(prec5.item(), input_frame.size(0)) + + train_samples += label.numel() + train_acc += (total_fr.argmax(1) == label).float().sum().item() + + functional.reset_net(net) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + # plot progress + bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( + batch=batch_idx, + size=len(train_data_loader), + data=data_time.avg, + bt=batch_time.avg, + total=bar.elapsed_td, + eta=bar.eta_td, + loss=losses.avg, + top1=top1.avg, + top5=top5.avg, + ) + bar.next() + bar.finish() + + train_loss /= train_samples + train_acc /= train_samples + + writer.add_scalar('train_loss', train_loss, epoch) + writer.add_scalar('train_acc', train_acc, epoch) + lr_scheduler.step() + + ############### testing ############### + net.eval() + + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + end = time.time() + bar = Bar('Processing', max=len(test_data_loader)) + + test_loss = 0 + test_acc = 0 + test_samples = 0 + batch_idx = 0 + with torch.no_grad(): + for data in test_data_loader: + if args.dataset == 'SHD': + frame, label, _ = data + else: + frame, label = data + + batch_idx += 1 + if (args.dataset != 'DVSCIFAR10'): + frame = frame.float().cuda() + + if args.dataset == 'dvsgesture' or args.dataset == "SHD" or args.dataset == "DVSCIFAR10-pt": + frame = frame.transpose(0, 1) + label = label.cuda() + + t_step = args.T + if args.dataset == 'SHD': + t_step = len(frame) + + label_real = torch.cat([label for _ in range(t_step)], 0) + # print(t_step) + + out_all = [] + for t in range(t_step): + + if (args.dataset == 'DVSCIFAR10'): + input_frame = frame[t].float().cuda() + elif args.dataset == 'dvsgesture' or args.dataset == "SHD" or args.dataset == "DVSCIFAR10-pt": + input_frame = frame[t] + else: + input_frame = frame + if t == 0: + out_fr = net(input_frame) + total_fr = out_fr.clone().detach() + out_all.append(out_fr) + else: + out_fr = net(input_frame) + total_fr += out_fr.clone().detach() + out_all.append(out_fr) + + out_all = torch.cat(out_all, 0) + # Calculate the loss + if args.loss_lambda > 0.0: # the loss is a cross entropy term plus a mse term + if args.mse_n_reg: # the mse term is not treated as a regularizer + label_one_hot = F.one_hot(label_real, num_classes).float() + else: + label_one_hot = torch.zeros_like(out_all).fill_(args.loss_means).to(out_all.device) + mse_loss = criterion_mse(out_all, label_one_hot) + loss = ((1 - args.loss_lambda) * F.cross_entropy(out_all, label_real, + label_smoothing=args.label_smoothing) + args.loss_lambda * mse_loss) + else: # the loss is just a cross entropy term + loss = F.cross_entropy(out_all, label_real, label_smoothing=args.label_smoothing) + total_loss = loss + + test_samples += label.numel() + test_loss += total_loss.item() * label.numel() + test_acc += (total_fr.argmax(1) == label).float().sum().item() + functional.reset_net(net) + + # measure 
accuracy and record loss + prec1, prec5 = accuracy(total_fr.data, label.data, topk=(1, 5)) + losses.update(total_loss, input_frame.size(0)) + top1.update(prec1.item(), input_frame.size(0)) + top5.update(prec5.item(), input_frame.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + # plot progress + bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( + batch=batch_idx, + size=len(test_data_loader), + data=data_time.avg, + bt=batch_time.avg, + total=bar.elapsed_td, + eta=bar.eta_td, + loss=losses.avg, + top1=top1.avg, + top5=top5.avg, + ) + bar.next() + bar.finish() + + test_loss /= test_samples + test_acc /= test_samples + writer.add_scalar('test_loss', test_loss, epoch) + writer.add_scalar('test_acc', test_acc, epoch) + + ############### saving checkpoint ############### + save_max = False + if test_acc > max_test_acc: + max_test_acc = test_acc + save_max = True + + checkpoint = { + 'net': net.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'max_test_acc': max_test_acc + } + + if save_max: + torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth')) + + torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth')) + + total_time = time.time() - start_time + info = f'epoch={epoch}, train_loss={train_loss}, train_acc={train_acc}, test_loss={test_loss}, test_acc={test_acc}, max_test_acc={max_test_acc}, total_time={total_time}, escape_time={(datetime.datetime.now() + datetime.timedelta(seconds=total_time * (args.epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}' + print(info) + mem_cost = "after one epoch: %fGB" % (torch.cuda.max_memory_cached(0) / 1024 / 1024 / 1024) + print(mem_cost) + + with open(os.path.join(out_dir, 'args.txt'), 'a+', encoding='utf-8') as args_txt: + args_txt.write("\n") + args_txt.write(info + "\n") + args_txt.write(mem_cost + "\n") + + +if __name__ == '__main__': + main() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..3161d40 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,12 @@ +"""Useful utils +""" +from .misc import * +from .logger import * +from .visualize import * +from .eval import * + +# progress bar +import os, sys + +sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) +from progress.bar import Bar as Bar diff --git a/utils/augmentation.py b/utils/augmentation.py new file mode 100644 index 0000000..1622df8 --- /dev/null +++ b/utils/augmentation.py @@ -0,0 +1,456 @@ +import collections +import math +import numbers +import random + +import numpy as np +import torch +import torchtoolbox.transform +import torchvision +import torchvision.transforms.functional as F +from PIL import Image, ImageOps +from torchvision import transforms + + +class Padding: + def __init__(self, pad): + self.pad = pad + + def __call__(self, imgmap): + return [ImageOps.expand(img, border=self.pad, fill=0) for img in imgmap] + + +class Scale: + def __init__(self, size, interpolation=Image.NEAREST): + assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) + self.size = size + self.interpolation = interpolation + + def __call__(self, imgmap): + # assert len(imgmap) > 1 # list of images + img1 = imgmap[0] + if isinstance(self.size, int): + w, h = img1.size + if (w <= h and w == self.size) or (h <= w and h == self.size): + return imgmap + if w < h: + ow = 
self.size + oh = int(self.size * h / w) + return [i.resize((ow, oh), self.interpolation) for i in imgmap] + else: + oh = self.size + ow = int(self.size * w / h) + return [i.resize((ow, oh), self.interpolation) for i in imgmap] + else: + return [i.resize(self.size, self.interpolation) for i in imgmap] + + +class CenterCrop: + def __init__(self, size, consistent=True): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, imgmap): + img1 = imgmap[0] + w, h = img1.size + th, tw = self.size + x1 = int(round((w - tw) / 2.)) + y1 = int(round((h - th) / 2.)) + return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] + + +class RandomCropWithProb: + def __init__(self, size, p=0.8, consistent=True): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.consistent = consistent + self.threshold = p + + def __call__(self, imgmap): + img1 = imgmap[0] + w, h = img1.size + if self.size is not None: + th, tw = self.size + if w == tw and h == th: + return imgmap + if self.consistent: + if random.random() < self.threshold: + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + else: + x1 = int(round((w - tw) / 2.)) + y1 = int(round((h - th) / 2.)) + return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] + else: + result = [] + for i in imgmap: + if random.random() < self.threshold: + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + else: + x1 = int(round((w - tw) / 2.)) + y1 = int(round((h - th) / 2.)) + result.append(i.crop((x1, y1, x1 + tw, y1 + th))) + return result + else: + return imgmap + + +class RandomCrop: + def __init__(self, size, consistent=True): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.consistent = consistent + + def __call__(self, imgmap, flowmap=None): + img1 = imgmap[0] + w, h = img1.size + if self.size is not None: + th, tw = self.size + if w == tw and h == th: + return imgmap + if not flowmap: + if self.consistent: + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] + else: + result = [] + for i in imgmap: + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + result.append(i.crop((x1, y1, x1 + tw, y1 + th))) + return result + elif flowmap is not None: + assert (not self.consistent) + result = [] + for idx, i in enumerate(imgmap): + proposal = [] + for j in range(3): # number of proposal: use the one with largest optical flow + x = random.randint(0, w - tw) + y = random.randint(0, h - th) + proposal.append([x, y, abs(np.mean(flowmap[idx, y:y + th, x:x + tw, :]))]) + [x1, y1, _] = max(proposal, key=lambda x: x[-1]) + result.append(i.crop((x1, y1, x1 + tw, y1 + th))) + return result + else: + raise ValueError('wrong case') + else: + return imgmap + + +class RandomSizedCrop: + def __init__(self, size, interpolation=Image.BILINEAR, consistent=True, p=1.0): + self.size = size + self.interpolation = interpolation + self.consistent = consistent + self.threshold = p + + def __call__(self, imgmap): + img1 = imgmap[0] + if random.random() < self.threshold: # do RandomSizedCrop + for attempt in range(10): + area = img1.size[0] * img1.size[1] + target_area = random.uniform(0.5, 1) * area + aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if self.consistent: + if random.random() < 0.5: + w, h = h, w + if w <= img1.size[0] and h <= img1.size[1]: + x1 = random.randint(0, img1.size[0] - w) + y1 = random.randint(0, img1.size[1] - h) + + imgmap = [i.crop((x1, y1, x1 + w, y1 + h)) for i in imgmap] + for i in imgmap: assert (i.size == (w, h)) + + return [i.resize((self.size, self.size), self.interpolation) for i in imgmap] + else: + result = [] + for i in imgmap: + if random.random() < 0.5: + w, h = h, w + if w <= img1.size[0] and h <= img1.size[1]: + x1 = random.randint(0, img1.size[0] - w) + y1 = random.randint(0, img1.size[1] - h) + result.append(i.crop((x1, y1, x1 + w, y1 + h))) + assert (result[-1].size == (w, h)) + else: + result.append(i) + + assert len(result) == len(imgmap) + return [i.resize((self.size, self.size), self.interpolation) for i in result] + + # Fallback + scale = Scale(self.size, interpolation=self.interpolation) + crop = CenterCrop(self.size) + return crop(scale(imgmap)) + else: # don't do RandomSizedCrop, do CenterCrop + crop = CenterCrop(self.size) + return crop(imgmap) + + +class RandomHorizontalFlip: + def __init__(self, consistent=True, command=None): + self.consistent = consistent + if command == 'left': + self.threshold = 0 + elif command == 'right': + self.threshold = 1 + else: + self.threshold = 0.5 + + def __call__(self, imgmap): + if self.consistent: + if random.random() < self.threshold: + return [i.transpose(Image.FLIP_LEFT_RIGHT) for i in imgmap] + else: + return imgmap + else: + result = [] + for i in imgmap: + if random.random() < self.threshold: + result.append(i.transpose(Image.FLIP_LEFT_RIGHT)) + else: + result.append(i) + assert len(result) == len(imgmap) + return result + + +class RandomGray: + '''Actually it is a channel splitting, not strictly grayscale images''' + + def __init__(self, consistent=True, p=0.5): + self.consistent = consistent + self.p = p # probability to apply grayscale + + def __call__(self, imgmap): + if self.consistent: + if random.random() < self.p: + return [self.grayscale(i) for i in imgmap] + else: + return imgmap + else: + result = [] + for i in imgmap: + if random.random() < self.p: + result.append(self.grayscale(i)) + else: + result.append(i) + assert len(result) == len(imgmap) + return result + + def grayscale(self, img): + channel = np.random.choice(3) + np_img = np.array(img)[:, :, channel] + np_img = np.dstack([np_img, np_img, np_img]) + img = Image.fromarray(np_img, 'RGB') + return img + + +class ColorJitter(object): + """Randomly change the brightness, contrast and saturation of an image. --modified from pytorch source code + Args: + brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. 
+ Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, consistent=False, p=1.0): + self.brightness = self._check_input(brightness, 'brightness') + self.contrast = self._check_input(contrast, 'contrast') + self.saturation = self._check_input(saturation, 'saturation') + self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), + clip_first_on_zero=False) + self.consistent = consistent + self.threshold = p + + def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError("If {} is a single number, it must be non negative.".format(name)) + value = [center - value, center + value] + if clip_first_on_zero: + value[0] = max(value[0], 0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError("{} values should be between {}".format(name, bound)) + else: + raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + return value + + @staticmethod + def get_params(brightness, contrast, saturation, hue): + """Get a randomized transform to be applied on image. + Arguments are same as that of __init__. + Returns: + Transform which randomly adjusts brightness, contrast and + saturation in a random order. + """ + transforms = [] + + if brightness is not None: + brightness_factor = random.uniform(brightness[0], brightness[1]) + transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) + + if contrast is not None: + contrast_factor = random.uniform(contrast[0], contrast[1]) + transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) + + if saturation is not None: + saturation_factor = random.uniform(saturation[0], saturation[1]) + transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) + + if hue is not None: + hue_factor = random.uniform(hue[0], hue[1]) + transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_hue(img, hue_factor))) + + random.shuffle(transforms) + transform = torchvision.transforms.Compose(transforms) + + return transform + + def __call__(self, imgmap): + if random.random() < self.threshold: # do ColorJitter + if self.consistent: + transform = self.get_params(self.brightness, self.contrast, + self.saturation, self.hue) + return [transform(i) for i in imgmap] + else: + result = [] + for img in imgmap: + transform = self.get_params(self.brightness, self.contrast, + self.saturation, self.hue) + result.append(transform(img)) + return result + else: # don't do ColorJitter, do nothing + return imgmap + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + format_string += 'brightness={0}'.format(self.brightness) + format_string += ', contrast={0}'.format(self.contrast) + format_string += ', saturation={0}'.format(self.saturation) + format_string += ', hue={0})'.format(self.hue) + return format_string + + +class RandomRotation: + def __init__(self, consistent=True, degree=15, p=1.0): + self.consistent = consistent + self.degree = degree + self.threshold = p + + def __call__(self, imgmap): + if random.random() < self.threshold: # do RandomRotation + if 
self.consistent: + deg = np.random.randint(-self.degree, self.degree, 1)[0] + return [i.rotate(deg, expand=True) for i in imgmap] + else: + return [i.rotate(np.random.randint(-self.degree, self.degree, 1)[0], expand=True) for i in imgmap] + else: # don't do RandomRotation, do nothing + return imgmap + + +class ToTensor: + def __call__(self, imgmap): + totensor = transforms.ToTensor() + return [totensor(i) for i in imgmap] + + +class ToPILImage: + def __call__(self, imgmap): + topilimage = transforms.ToPILImage() + return [topilimage(i) for i in imgmap] + + +class Resize: + def __init__(self, size): + self.size = size + + def __call__(self, imgmap): + resize = transforms.Resize(self.size) + return [resize(i) for i in imgmap] + + +class Cutout: + def __call__(self, imgmap): + if random.random() < 0.5: + cutout = torchtoolbox.transform.Cutout(p=1) + return [cutout(i) for i in imgmap] + else: + return imgmap + + +class Normalize: + def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): + self.mean = mean + self.std = std + + def __call__(self, imgmap): + normalize = transforms.Normalize(mean=self.mean, std=self.std) + return [normalize(i) for i in imgmap] + + +class Roll: + def __init__(self): + self.off1 = random.randint(-5, 5) + self.off2 = random.randint(-5, 5) + + def __call__(self, imgmap): + return [torch.roll(i, shifts=(self.off1, self.off2), dims=(1, 2)) for i in imgmap] + + +if __name__ == '__main__': + from utils.cifar10_dvs import CIFAR10DVS + + c_in = 2 + num_classes = 10 + + transform_train = transforms.Compose([ + # CIFAR10_DVS_Aug(), # it has been resize 48 + ToPILImage(), + Resize(48), + RandomSizedCrop(48), + RandomHorizontalFlip(), + ToTensor(), + # Roll(), + ]) + + transform_test = transforms.Compose([ + ToPILImage(), + Resize(48), + ToTensor(), + ]) + data_dir = "./data_dir" + trainset = CIFAR10DVS(data_dir, train=True, use_frame=True, frames_num=10, split_by='number', + normalization=None, transform=transform_train) + testset = CIFAR10DVS(data_dir, train=False, use_frame=True, frames_num=10, split_by='number', + normalization=None, transform=transform_test) + + # train_data_loader = data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=0) + # test_data_loader = data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=0) + print(trainset[0]) diff --git a/utils/cifar10_dvs.py b/utils/cifar10_dvs.py new file mode 100644 index 0000000..5ebc360 --- /dev/null +++ b/utils/cifar10_dvs.py @@ -0,0 +1,765 @@ +import random +import threading +import zipfile + +import numpy as np +import torch +import torchvision +import torchvision.transforms as transforms +from torch.utils.data import Dataset + + +class FunctionThread(threading.Thread): + def __init__(self, f, *args, **kwargs): + super().__init__() + self.f = f + self.args = args + self.kwargs = kwargs + + def run(self): + self.f(*self.args, **self.kwargs) + + +def integrate_events_to_frames(events, height, width, frames_num=10, split_by='time', normalization=None): + ''' + * :ref:`API in English ` + + .. 
_integrate_events_to_frames.__init__-cn: + + :param events: 键是{'t', 'x', 'y', 'p'},值是np数组的的字典 + :param height: 脉冲数据的高度,例如对于CIFAR10-DVS是128 + :param width: 脉冲数据的宽度,例如对于CIFAR10-DVS是128 + :param frames_num: 转换后数据的帧数 + :param split_by: 脉冲数据转换成帧数据的累计方式,允许的取值为 ``'number', 'time'`` + :param normalization: 归一化方法,允许的取值为 ``None, 'frequency', 'max', 'norm', 'sum'`` + :return: 转化后的frames数据,是一个 ``shape = [frames_num, 2, height, width]`` 的np数组 + + 记脉冲数据为 :math:`E_{i} = (t_{i}, x_{i}, y_{i}, p_{i}), i=0,1,...,N-1`,转换为帧数据 :math:`F(j, p, x, y), j=0,1,...,M-1`。 + + 若划分方式 ``split_by`` 为 ``'time'``,则 + + .. math:: + + \\Delta T & = [\\frac{t_{N-1} - t_{0}}{M}] \\\\ + j_{l} & = \\mathop{\\arg\\min}\\limits_{k} \\{t_{k} | t_{k} \\geq t_{0} + \\Delta T \\cdot j\\} \\\\ + j_{r} & = \\begin{cases} \\mathop{\\arg\\max}\\limits_{k} \\{t_{k} | t_{k} < t_{0} + \\Delta T \\cdot (j + 1)\\} + 1, & j < M - 1 \\cr N, & j = M - 1 \\end{cases} \\\\ + F(j, p, x, y) & = \\sum_{i = j_{l}}^{j_{r} - 1} \\mathcal{I_{p, x, y}(p_{i}, x_{i}, y_{i})} + + 若划分方式 ``split_by`` 为 ``'number'``,则 + + .. math:: + + j_{l} & = [\\frac{N}{M}] \\cdot j \\\\ + j_{r} & = \\begin{cases} [\\frac{N}{M}] \\cdot (j + 1), & j < M - 1 \\cr N, & j = M - 1 \\end{cases}\\\\ + F(j, p, x, y) & = \\sum_{i = j_{l}}^{j_{r} - 1} \\mathcal{I_{p, x, y}(p_{i}, x_{i}, y_{i})} + + 其中 :math:`\\mathcal{I}` 为示性函数,当且仅当 :math:`(p, x, y) = (p_{i}, x_{i}, y_{i})` 时为1,否则为0。 + + 若 ``normalization`` 为 ``'frequency'``, + + 若 ``split_by`` 为 ``time`` 则 + + .. math:: + F_{norm}(j, p, x, y) = \\begin{cases} \\frac{F(j, p, x, y)}{\\Delta T}, & j < M - 1 + \\cr \\frac{F(j, p, x, y)}{\\Delta T + (t_{N-1} - t_{0}) \\bmod M}, & j = M - 1 \\end{cases} + + 若 ``split_by`` 为 ``number`` 则 + + .. math:: + F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y)}{t_{j_{r}} - t_{j_{l}}} + + + 若 ``normalization`` 为 ``'max'`` 则 + + .. math:: + F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y)}{\\mathrm{max} F(j, p)} + + 若 ``normalization`` 为 ``'norm'`` 则 + + .. math:: + F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y) - \\mathrm{E}(F(j, p))}{\\sqrt{\\mathrm{Var}(F(j, p))}} + + 若 ``normalization`` 为 ``'sum'`` 则 + + .. math:: + F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y)}{\\sum_{a, b} F(j, p, a, b)} + + * :ref:`中文API ` + + .. _integrate_events_to_frames.__init__-en: + + :param events: a dict with keys are {'t', 'x', 'y', 'p'} and values are numpy arrays + :param height: the height of events data, e.g., 128 for CIFAR10-DVS + :param width: the width of events data, e.g., 128 for CIFAR10-DVS + :param frames_num: frames number + :param split_by: how to split the events, can be ``'number', 'time'`` + :param normalization: how to normalize frames, can be ``None, 'frequency', 'max', 'norm', 'sum'`` + :return: the frames data with ``shape = [frames_num, 2, height, width]`` + + The events data are denoted by :math:`E_{i} = (t_{i}, x_{i}, y_{i}, p_{i}), i=0,1,...,N-1`, and the converted frames + data are denoted by :math:`F(j, p, x, y), j=0,1,...,M-1`. + + If ``split_by`` is ``'time'``, then + + .. math:: + + \\Delta T & = [\\frac{t_{N-1} - t_{0}}{M}] \\\\ + j_{l} & = \\mathop{\\arg\\min}\\limits_{k} \\{t_{k} | t_{k} \\geq t_{0} + \\Delta T \\cdot j\\} \\\\ + j_{r} & = \\begin{cases} \\mathop{\\arg\\max}\\limits_{k} \\{t_{k} | t_{k} < t_{0} + \\Delta T \\cdot (j + 1)\\} + 1, & j < M - 1 \\cr N, & j = M - 1 \\end{cases} \\\\ + F(j, p, x, y) & = \\sum_{i = j_{l}}^{j_{r} - 1} \\mathcal{I_{p, x, y}(p_{i}, x_{i}, y_{i})} + + If ``split_by`` is ``'number'``, then + + .. 
math:: + + j_{l} & = [\\frac{N}{M}] \\cdot j \\\\ + j_{r} & = \\begin{cases} [\\frac{N}{M}] \\cdot (j + 1), & j < M - 1 \\cr N, & j = M - 1 \\end{cases}\\\\ + F(j, p, x, y) & = \\sum_{i = j_{l}}^{j_{r} - 1} \\mathcal{I_{p, x, y}(p_{i}, x_{i}, y_{i})} + + where :math:`\\mathcal{I}` is the characteristic function,if and only if :math:`(p, x, y) = (p_{i}, x_{i}, y_{i})`, + this function is identically 1 else 0. + + If ``normalization`` is ``'frequency'``, + + if ``split_by`` is ``time``, + + .. math:: + F_{norm}(j, p, x, y) = \\begin{cases} \\frac{F(j, p, x, y)}{\\Delta T}, & j < M - 1 + \\cr \\frac{F(j, p, x, y)}{\\Delta T + (t_{N-1} - t_{0}) \\bmod M}, & j = M - 1 \\end{cases} + + if ``split_by`` is ``number``, + + .. math:: + F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y)}{t_{j_{r}} - t_{j_{l}}} + + If ``normalization`` is ``'max'``, then + + .. math:: + F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y)}{\\mathrm{max} F(j, p)} + + If ``normalization`` is ``'norm'``, then + + .. math:: + F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y) - \\mathrm{E}(F(j, p))}{\\sqrt{\\mathrm{Var}(F(j, p))}} + + If ``normalization`` is ``'sum'``, then + + .. math:: + F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y)}{\\sum_{a, b} F(j, p, a, b)} + ''' + frames = np.zeros(shape=[frames_num, 2, height * width]) + + # 创建j_{l}和j_{r} + j_l = np.zeros(shape=[frames_num], dtype=int) + j_r = np.zeros(shape=[frames_num], dtype=int) + if split_by == 'time': + events['t'] -= events['t'][0] # 时间从0开始 + assert events['t'][-1] > frames_num + dt = events['t'][-1] // frames_num # 每一段的持续时间 + idx = np.arange(events['t'].size) + for i in range(frames_num): + t_l = dt * i + t_r = t_l + dt + mask = np.logical_and(events['t'] >= t_l, events['t'] < t_r) + idx_masked = idx[mask] + j_l[i] = idx_masked[0] + j_r[i] = idx_masked[-1] + 1 if i < frames_num - 1 else events['t'].size + + elif split_by == 'number': + di = events['t'].size // frames_num + for i in range(frames_num): + j_l[i] = i * di + j_r[i] = j_l[i] + di if i < frames_num - 1 else events['t'].size + else: + raise NotImplementedError + + # 开始累计脉冲 + # 累计脉冲需要用bitcount而不能直接相加,原因可参考下面的示例代码,以及 + # https://stackoverflow.com/questions/15973827/handling-of-duplicate-indices-in-numpy-assignments + # height = 3 + # width = 3 + # frames = np.zeros(shape=[2, height, width]) + # events = { + # 'x': np.asarray([1, 2, 1, 1]), + # 'y': np.asarray([1, 1, 1, 2]), + # 'p': np.asarray([0, 1, 0, 1]) + # } + # + # frames[0, events['y'], events['x']] += (1 - events['p']) + # frames[1, events['y'], events['x']] += events['p'] + # print('wrong accumulation\n', frames) + # + # frames = np.zeros(shape=[2, height, width]) + # for i in range(events['p'].__len__()): + # frames[events['p'][i], events['y'][i], events['x'][i]] += 1 + # print('correct accumulation\n', frames) + # + # frames = np.zeros(shape=[2, height, width]) + # frames = frames.reshape(2, -1) + # + # mask = [events['p'] == 0] + # mask.append(np.logical_not(mask[0])) + # for i in range(2): + # position = events['y'][mask[i]] * height + events['x'][mask[i]] + # events_number_per_pos = np.bincount(position) + # idx = np.arange(events_number_per_pos.size) + # frames[i][idx] += events_number_per_pos + # frames = frames.reshape(2, height, width) + # print('correct accumulation by bincount\n', frames) + + for i in range(frames_num): + x = events['x'][j_l[i]:j_r[i]] + y = events['y'][j_l[i]:j_r[i]] + p = events['p'][j_l[i]:j_r[i]] + mask = [] + mask.append(p == 0) + mask.append(np.logical_not(mask[0])) + for j in range(2): + position = y[mask[j]] * height + x[mask[j]] 
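+            # np.bincount counts repeated flattened positions correctly; a fancy-indexed
+            # '+=' would silently drop duplicate (x, y) events (see the comment block above).
+            # 'y * height + x' is a valid flattening here only because height == width == 128.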
+ events_number_per_pos = np.bincount(position) + frames[i][j][np.arange(events_number_per_pos.size)] += events_number_per_pos + + if normalization == 'frequency': + if split_by == 'time': + if i < frames_num - 1: + frames[i] /= dt + else: + frames[i] /= (dt + events['t'][-1] % frames_num) + elif split_by == 'number': + frames[i] /= (events['t'][j_r[i]] - events['t'][j_l[i]]) # 表示脉冲发放的频率 + + else: + raise NotImplementedError + + # 其他的normalization方法,在数据集类读取数据的时候进行通过调用normalize_frame(frames: np.ndarray, normalization: str) + # 函数操作,而不是在转换数据的时候进行 + return frames.reshape((frames_num, 2, height, width)) + + +def normalize_frame(frames: np.ndarray or torch.Tensor, normalization: str): + eps = 1e-5 # 涉及到除法的地方,被除数加上eps,防止出现除以0 + for i in range(frames.shape[0]): + if normalization == 'max': + frames[i][0] /= max(frames[i][0].max(), eps) + frames[i][1] /= max(frames[i][1].max(), eps) + + elif normalization == 'norm': + frames[i][0] = (frames[i][0] - frames[i][0].mean()) / np.sqrt(max(frames[i][0].var(), eps)) + frames[i][1] = (frames[i][1] - frames[i][1].mean()) / np.sqrt(max(frames[i][1].var(), eps)) + + elif normalization == 'sum': + frames[i][0] /= max(frames[i][0].sum(), eps) + frames[i][1] /= max(frames[i][1].sum(), eps) + + else: + raise NotImplementedError + return frames + + +def convert_events_dir_to_frames_dir(events_data_dir, frames_data_dir, suffix, read_function, height, width, + frames_num=10, split_by='time', normalization=None, thread_num=1, compress=False): + # 遍历events_data_dir目录下的所有脉冲数据文件,在frames_data_dir目录下生成帧数据文件 + def cvt_fun(events_file_list): + for events_file in events_file_list: + frames = integrate_events_to_frames(read_function(events_file), height, width, frames_num, split_by, + normalization) + if compress: + frames_file = os.path.join(frames_data_dir, + os.path.basename(events_file)[0: -suffix.__len__()] + '.npz') + np.savez_compressed(frames_file, frames) + else: + frames_file = os.path.join(frames_data_dir, + os.path.basename(events_file)[0: -suffix.__len__()] + '.npy') + np.save(frames_file, frames) + + events_file_list = utils.list_files(events_data_dir, suffix, True) + if thread_num == 1: + cvt_fun(events_file_list) + else: + # 多线程加速 + thread_list = [] + block = events_file_list.__len__() // thread_num + for i in range(thread_num - 1): + thread_list.append(FunctionThread(cvt_fun, events_file_list[i * block: (i + 1) * block])) + thread_list[-1].start() + print(f'thread {i} start, processing files index: {i * block} : {(i + 1) * block}.') + thread_list.append(FunctionThread(cvt_fun, events_file_list[(thread_num - 1) * block:])) + thread_list[-1].start() + print( + f'thread {thread_num} start, processing files index: {(thread_num - 1) * block} : {events_file_list.__len__()}.') + for i in range(thread_num): + thread_list[i].join() + print(f'thread {i} finished.') + + +def extract_zip_in_dir(source_dir, target_dir): + ''' + :param source_dir: 保存有zip文件的文件夹 + :param target_dir: 保存zip解压后数据的文件夹 + :return: None + + 将 ``source_dir`` 目录下的所有*.zip文件,解压到 ``target_dir`` 目录下的对应文件夹内 + ''' + + for file_name in os.listdir(source_dir): + if file_name[-3:] == 'zip': + with zipfile.ZipFile(os.path.join(source_dir, file_name), 'r') as zip_file: + zip_file.extractall(os.path.join(target_dir, file_name[:-4])) + + +class EventsFramesDatasetBase(Dataset): + @staticmethod + def get_wh(): + ''' + :return: (width, height) + width: int + events或frames图像的宽度 + height: int + events或frames图像的高度 + :rtype: tuple + ''' + raise NotImplementedError + + @staticmethod + def read_bin(file_name: str): + ''' 
+ :param file_name: 脉冲数据的文件名 + :type file_name: str + :return: events + 键是{'t', 'x', 'y', 'p'},值是np数组的的字典 + :rtype: dict + ''' + raise NotImplementedError + + @staticmethod + def get_events_item(file_name): + ''' + :param file_name: 脉冲数据的文件名 + :type file_name: str + :return: (events, label) + events: dict + 键是{'t', 'x', 'y', 'p'},值是np数组的的字典 + label: int + 数据的标签 + :rtype: tuple + ''' + raise NotImplementedError + + @staticmethod + def get_frames_item(file_name): + ''' + :param file_name: 帧数据的文件名 + :type file_name: str + :return: (frames, label) + frames: np.ndarray + ``shape = [frames_num, 2, height, width]`` 的np数组 + label: int + 数据的标签 + :rtype: tuple + ''' + raise NotImplementedError + + @staticmethod + def download_and_extract(download_root: str, extract_root: str): + ''' + :param download_root: 保存下载文件的文件夹 + :type download_root: str + :param extract_root: 保存解压后文件的文件夹 + :type extract_root: str + + 下载数据集到 ``download_root``,然后解压到 ``extract_root``。 + ''' + raise NotImplementedError + + @staticmethod + def create_frames_dataset(events_data_dir: str, frames_data_dir: str, frames_num: int, split_by: str, + normalization: str or None): + ''' + :param events_data_dir: 保存脉冲数据的文件夹,文件夹的文件全部是脉冲数据 + :type events_data_dir: str + :param frames_data_dir: 保存帧数据的文件夹 + :type frames_data_dir: str + :param frames_num: 转换后数据的帧数 + :type frames_num: int + :param split_by: 脉冲数据转换成帧数据的累计方式 + :type split_by: str + :param normalization: 归一化方法 + :type normalization: str or None + + 将 ``events_data_dir`` 文件夹下的脉冲数据全部转换成帧数据,并保存在 ``frames_data_dir``。 + 转换参数的详细含义,参见 ``integrate_events_to_frames`` 函数。 + ''' + raise NotImplementedError + + +import numpy as np +import os +from torchvision.datasets import utils +import torch + +labels_dict = { + 'airplane': 0, + 'automobile': 1, + 'bird': 2, + 'cat': 3, + 'deer': 4, + 'dog': 5, + 'frog': 6, + 'horse': 7, + 'ship': 8, + 'truck': 9 +} +# https://figshare.com/articles/dataset/CIFAR10-DVS_New/4724671 +resource = { + 'airplane': ('https://ndownloader.figshare.com/files/7712788', '0afd5c4bf9ae06af762a77b180354fdd'), + 'automobile': ('https://ndownloader.figshare.com/files/7712791', '8438dfeba3bc970c94962d995b1b9bdd'), + 'bird': ('https://ndownloader.figshare.com/files/7712794', 'a9c207c91c55b9dc2002dc21c684d785'), + 'cat': ('https://ndownloader.figshare.com/files/7712812', '52c63c677c2b15fa5146a8daf4d56687'), + 'deer': ('https://ndownloader.figshare.com/files/7712815', 'b6bf21f6c04d21ba4e23fc3e36c8a4a3'), + 'dog': ('https://ndownloader.figshare.com/files/7712818', 'f379ebdf6703d16e0a690782e62639c3'), + 'frog': ('https://ndownloader.figshare.com/files/7712842', 'cad6ed91214b1c7388a5f6ee56d08803'), + 'horse': ('https://ndownloader.figshare.com/files/7712851', 'e7cbbf77bec584ffbf913f00e682782a'), + 'ship': ('https://ndownloader.figshare.com/files/7712836', '41c7bd7d6b251be82557c6cce9a7d5c9'), + 'truck': ('https://ndownloader.figshare.com/files/7712839', '89f3922fd147d9aeff89e76a2b0b70a7') +} +# https://github.com/jackd/events-tfds/blob/master/events_tfds/data_io/aedat.py + + +EVT_DVS = 0 # DVS event type +EVT_APS = 1 # APS event + + +def read_bits(arr, mask=None, shift=None): + if mask is not None: + arr = arr & mask + if shift is not None: + arr = arr >> shift + return arr + + +y_mask = 0x7FC00000 +y_shift = 22 + +x_mask = 0x003FF000 +x_shift = 12 + +polarity_mask = 0x800 +polarity_shift = 11 + +valid_mask = 0x80000000 +valid_shift = 31 + + +def skip_header(fp): + p = 0 + lt = fp.readline() + ltd = lt.decode().strip() + while ltd and ltd[0] == "#": + p += len(lt) + lt = fp.readline() 
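+        # the .aedat header consists of ASCII '#' comment lines; the first chunk
+        # that fails to decode marks the start of the binary event payload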
+        try:
+            ltd = lt.decode().strip()
+        except UnicodeDecodeError:
+            break
+    return p
+
+
+def load_raw_events(fp,
+                    bytes_skip=0,
+                    bytes_trim=0,
+                    filter_dvs=False,
+                    times_first=False):
+    p = skip_header(fp)
+    fp.seek(p + bytes_skip)
+    data = fp.read()
+    if bytes_trim > 0:
+        data = data[:-bytes_trim]
+    # np.fromstring was removed from numpy; frombuffer reads the same big-endian uint32 stream
+    data = np.frombuffer(data, dtype='>u4')
+    if len(data) % 2 != 0:
+        print(data[:20:2])
+        print('---')
+        print(data[1:21:2])
+        raise ValueError('odd number of data elements')
+    raw_addr = data[::2]
+    timestamp = data[1::2]
+    if times_first:
+        timestamp, raw_addr = raw_addr, timestamp
+    if filter_dvs:
+        valid = read_bits(raw_addr, valid_mask, valid_shift) == EVT_DVS
+        timestamp = timestamp[valid]
+        raw_addr = raw_addr[valid]
+    return timestamp, raw_addr
+
+
+def parse_raw_address(addr,
+                      x_mask=x_mask,
+                      x_shift=x_shift,
+                      y_mask=y_mask,
+                      y_shift=y_shift,
+                      polarity_mask=polarity_mask,
+                      polarity_shift=polarity_shift):
+    # np.bool was removed from numpy; the builtin bool is equivalent here
+    polarity = read_bits(addr, polarity_mask, polarity_shift).astype(bool)
+    x = read_bits(addr, x_mask, x_shift)
+    y = read_bits(addr, y_mask, y_shift)
+    return x, y, polarity
+
+
+def load_events(
+        fp,
+        filter_dvs=False,
+        # bytes_skip=0,
+        # bytes_trim=0,
+        # times_first=False,
+        **kwargs):
+    timestamp, addr = load_raw_events(
+        fp,
+        filter_dvs=filter_dvs,
+        # bytes_skip=bytes_skip,
+        # bytes_trim=bytes_trim,
+        # times_first=times_first
+    )
+    x, y, polarity = parse_raw_address(addr, **kwargs)
+    return timestamp, x, y, polarity
+
+
+class CIFAR10DVS(EventsFramesDatasetBase):
+    @staticmethod
+    def get_wh():
+        return 128, 128
+
+    @staticmethod
+    def download_and_extract(download_root: str, extract_root: str):
+        for key in resource.keys():
+            file_name = os.path.join(os.path.join(download_root, "download"), key + '.zip')
+            if os.path.exists(file_name):
+                if utils.check_md5(file_name, resource[key][1]):
+                    print(f'extract {file_name} to {extract_root}')
+                    utils.extract_archive(file_name, extract_root)
+                else:
+                    print(f'{file_name} corrupted, re-download...')
+                    utils.download_and_extract_archive(resource[key][0], download_root, extract_root,
+                                                       filename=key + '.zip',
+                                                       md5=resource[key][1])
+            else:
+                utils.download_and_extract_archive(resource[key][0], download_root, extract_root, filename=key + '.zip',
+                                                   md5=resource[key][1])
+
+    @staticmethod
+    def read_bin(file_name: str):
+        with open(file_name, 'rb') as fp:
+            t, x, y, p = load_events(fp,
+                                     x_mask=0xfE,
+                                     x_shift=1,
+                                     y_mask=0x7f00,
+                                     y_shift=8,
+                                     polarity_mask=1,
+                                     polarity_shift=None)
+            return {'t': t, 'x': 127 - x, 'y': y, 'p': 1 - p.astype(int)}
+            # the original author's address decoding seems slightly off, so x is
+            # mirrored and the polarity flipped instead of returning t, x, y, p as-is
+
+    @staticmethod
+    def create_frames_dataset(events_data_dir: str, frames_data_dir: str, frames_num: int, split_by: str,
+                              normalization: str or None):
+        width, height = CIFAR10DVS.get_wh()
+        thread_list = []
+        for key in resource.keys():
+            source_dir = os.path.join(events_data_dir, key)
+            target_dir = os.path.join(frames_data_dir, key)
+            os.mkdir(target_dir)
+            print(f'mkdir {target_dir}')
+            print(f'convert {source_dir} to {target_dir}')
+            thread_list.append(FunctionThread(
+                convert_events_dir_to_frames_dir,
+                source_dir, target_dir, '.aedat',
+                CIFAR10DVS.read_bin, height, width, frames_num, split_by, normalization, 1, True))
+            thread_list[-1].start()
+            print(f'thread {thread_list.__len__() - 1} start')
+
+        for i in range(thread_list.__len__()):
+            thread_list[i].join()
+            print(f'thread {i} finished')
+
+    @staticmethod
+    def get_frames_item(file_name):
+        return torch.from_numpy(np.load(file_name)['arr_0']).float(), 
labels_dict[file_name.split('_')[-2]] + + @staticmethod + def get_events_item(file_name): + return CIFAR10DVS.read_bin(file_name), labels_dict[file_name.split('_')[-2]] + + def __init__(self, root: str, train: bool, split_ratio=0.9, use_frame=True, frames_num=10, split_by='number', + normalization='max', transform=None): + ''' + :param root: 保存数据集的根目录 + :type root: str + :param train: 是否使用训练集 + :type train: bool + :param split_ratio: 分割比例。每一类中前split_ratio的数据会被用作训练集,剩下的数据为测试集 + :type split_ratio: float + :param use_frame: 是否将事件数据转换成帧数据 + :type use_frame: bool + :param frames_num: 转换后数据的帧数 + :type frames_num: int + :param split_by: 脉冲数据转换成帧数据的累计方式。``'time'`` 或 ``'number'`` + :type split_by: str + :param normalization: 归一化方法,为 ``None`` 表示不进行归一化; + 为 ``'frequency'`` 则每一帧的数据除以每一帧的累加的原始数据数量; + 为 ``'max'`` 则每一帧的数据除以每一帧中数据的最大值; + 为 ``norm`` 则每一帧的数据减去每一帧中的均值,然后除以标准差 + :type normalization: str or None + + CIFAR10 DVS数据集,出自 `CIFAR10-DVS: An Event-Stream Dataset for Object Classification `_, + 数据来源于DVS相机拍摄的显示器上的CIFAR10图片。原始数据的下载地址为 https://figshare.com/articles/dataset/CIFAR10-DVS_New/4724671。 + + 关于转换成帧数据的细节,参见 :func:`~spikingjelly.datasets.utils.integrate_events_to_frames`。 + ''' + super().__init__() + self.transform = transform + self.train = train + events_root = os.path.join(root, 'events') + if os.path.exists(events_root): + print(f'{events_root} already exists') + else: + self.download_and_extract(root, events_root) + + self.use_frame = use_frame + if use_frame: + self.normalization = normalization + if normalization == 'frequency': + dir_suffix = normalization + else: + dir_suffix = None + frames_root = os.path.join(root, f'frames_num_{frames_num}_split_by_{split_by}_normalization_{dir_suffix}') + if os.path.exists(frames_root): + print(f'{frames_root} already exists') + else: + os.mkdir(frames_root) + print(f'mkdir {frames_root}') + self.create_frames_dataset(events_root, frames_root, frames_num, split_by, normalization) + self.data_dir = frames_root if use_frame else events_root + + self.file_name = [] + if train: + index = np.arange(0, int(split_ratio * 1000)) + else: + index = np.arange(int(split_ratio * 1000), 1000) + + for class_name in labels_dict.keys(): + class_dir = os.path.join(self.data_dir, class_name) + for i in index: + if self.use_frame: + self.file_name.append(os.path.join(class_dir, 'cifar10_' + class_name + '_' + str(i) + '.npz')) + else: + self.file_name.append(os.path.join(class_dir, 'cifar10_' + class_name + '_' + str(i) + '.aedat')) + + def __len__(self): + return self.file_name.__len__() + + def __getitem__(self, index): + if self.use_frame: + frames, labels = self.get_frames_item(self.file_name[index]) + if self.normalization is not None and self.normalization != 'frequency': + frames = normalize_frame(frames, self.normalization) + if self.transform is not None: + frames = self.transform(frames) + return frames, labels + else: + return self.get_events_item(self.file_name[index]) + + +class CIFAR10_DVS_Aug(): + def __init__(self): + self.resize = transforms.Resize(size=(48, 48), interpolation=torchvision.transforms.InterpolationMode.NEAREST) + self.rotate = transforms.RandomRotation(degrees=30) + self.shearx = transforms.RandomAffine(degrees=0, shear=(-30, 30)) + + def __call__(self, data): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is index of the target class. 
+ """ + data = self.resize(data.permute([3, 0, 1, 2])) + choices = ['roll', 'rotate', 'shear'] + aug = np.random.choice(choices) + if aug == 'roll': + off1 = random.randint(-5, 5) + off2 = random.randint(-5, 5) + data = torch.roll(data, shifts=(off1, off2), dims=(2, 3)) + if aug == 'rotate': + data = self.rotate(data) + if aug == 'shear': + data = self.shearx(data) + return data + + +class DVSCifar10(Dataset): + def __init__(self, root, train=True, transform=None, target_transform=None): + self.root = os.path.expanduser(root) + self.transform = transform + self.target_transform = target_transform + self.train = train + self.resize = transforms.Resize(size=(48, 48), + interpolation=torchvision.transforms.InterpolationMode.NEAREST) # 48 48 + self.rotate = transforms.RandomRotation(degrees=30) + self.shearx = transforms.RandomAffine(degrees=0, shear=(-30, 30)) + self.tensorx = transforms.ToTensor() + self.imgx = transforms.ToPILImage() + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is index of the target class. + """ + # print(index) + data, target = torch.load(self.root + '/{}.pt'.format(index)) + # if self.train: + new_data = [] + for t in range(data.shape[0]): + data_t = self.imgx(data[t, ...]) + data_t = self.resize(data_t) + data_t = self.tensorx(data_t) + # print(d_t.shape) + new_data.append(data_t) + + data = torch.stack(new_data, dim=0) + if self.transform is not None: + + choices = ['roll', 'rotate', 'shear'] + aug = np.random.choice(choices) + flip = random.random() > 0.5 + if flip: + data = torch.flip(data, dims=(3,)) + # off1 = random.randint(-5, 5) + # off2 = random.randint(-5, 5) + # data = torch.roll(data, shifts=(off1, off2), dims=(2, 3)) + + # if aug == 'roll': + off1 = random.randint(-5, 5) + off2 = random.randint(-5, 5) + data = torch.roll(data, shifts=(off1, off2), dims=(2, 3)) + + if aug == 'rotate': + data = self.rotate(data) + if aug == 'shear': + data = self.shearx(data) + + # print(target.shape) + if self.target_transform is not None: + target = self.target_transform(target) + # return data, target.long().squeeze(-1) + target = target.long()[0] + # target = target[0] if len(target) >= 1 else target + return data, target + + def __len__(self): + num = len(os.listdir(self.root)) + # print("data number is: ", num) + return num + + +if __name__ == '__main__': + pass diff --git a/utils/data_loaders.py b/utils/data_loaders.py new file mode 100644 index 0000000..219abd8 --- /dev/null +++ b/utils/data_loaders.py @@ -0,0 +1,354 @@ +import sys + +import torch +import torchvision +import random +import torchvision.transforms as transforms +from torch.utils.data import Dataset, DataLoader +import warnings +import os +import numpy as np +from os.path import isfile, join +from PIL import Image + +warnings.filterwarnings('ignore') + + +class DVSCutout(object): + """Randomly mask out one or more patches from an image. + Args: + n_holes (int): Number of patches to cut out of each image. + length (int): The length (in pixels) of each square patch. + """ + + def __init__(self, length): + self.length = length + + def __call__(self, img): + h = img.size(2) + w = img.size(3) + mask = np.ones((h, w), np.float32) + y = np.random.randint(h) + x = np.random.randint(w) + y1 = np.clip(y - self.length // 2, 0, h) + y2 = np.clip(y + self.length // 2, 0, h) + x1 = np.clip(x - self.length // 2, 0, w) + x2 = np.clip(x + self.length // 2, 0, w) + mask[y1: y2, x1: x2] = 0. 
+ mask = torch.from_numpy(mask) + mask = mask.expand_as(img) + img = img * mask + return img + + +class NCaltech101(Dataset): + def __init__(self, data_path='data/n-caltech/frames_number_10_split_by_number', + data_type='train', transform=False): + + self.filepath = os.path.join(data_path) + self.clslist = os.listdir(self.filepath) + self.clslist.sort() + + self.dvs_filelist = [] + self.targets = [] + self.resize = transforms.Resize(size=(48, 48), interpolation=torchvision.transforms.InterpolationMode.NEAREST) + + for i, cls in enumerate(self.clslist): + # print (i, cls) + file_list = os.listdir(os.path.join(self.filepath, cls)) + num_file = len(file_list) + + cut_idx = int(num_file * 0.9) + train_file_list = file_list[:cut_idx] + test_split_list = file_list[cut_idx:] + for file in file_list: + if data_type == 'train': + if file in train_file_list: + self.dvs_filelist.append(os.path.join(self.filepath, cls, file)) + self.targets.append(i) + else: + if file in test_split_list: + self.dvs_filelist.append(os.path.join(self.filepath, cls, file)) + self.targets.append(i) + + self.data_num = len(self.dvs_filelist) + self.data_type = data_type + if data_type != 'train': + counts = np.unique(np.array(self.targets), return_counts=True)[1] + class_weights = counts.sum() / (counts * len(counts)) + self.class_weights = torch.Tensor(class_weights) + self.classes = range(101) + self.transform = transform + self.rotate = transforms.RandomRotation(degrees=15) + self.shearx = transforms.RandomAffine(degrees=0, shear=(-15, 15)) + + def __getitem__(self, index): + file_pth = self.dvs_filelist[index] + label = self.targets[index] + data = torch.from_numpy(np.load(file_pth)['frames']).float() + data = self.resize(data) + + if self.transform: + + choices = ['roll', 'rotate', 'shear'] + aug = np.random.choice(choices) + if aug == 'roll': + off1 = random.randint(-3, 3) + off2 = random.randint(-3, 3) + data = torch.roll(data, shifts=(off1, off2), dims=(2, 3)) + if aug == 'rotate': + data = self.rotate(data) + if aug == 'shear': + data = self.shearx(data) + + return data, label + + def __len__(self): + return self.data_num + + +def build_ncaltech(transform=False): + train_dataset = NCaltech101(transform=transform) + val_dataset = NCaltech101(data_type='test', transform=False) + + return train_dataset, val_dataset + + +class DVSCifar10(Dataset): + def __init__(self, data_path='data/dvscifar/frames_number_10_split_by_number', + data_type='train', transform=False): + + self.filepath = os.path.join(data_path) + self.clslist = os.listdir(self.filepath) + self.clslist.sort() + + self.dvs_filelist = [] + self.targets = [] + self.resize = transforms.Resize(size=(48, 48), interpolation=torchvision.transforms.InterpolationMode.NEAREST) + + for i, cls in enumerate(self.clslist): + # print (i, cls) + file_list = os.listdir(os.path.join(self.filepath, cls)) + num_file = len(file_list) + + cut_idx = int(num_file * 0.9) + train_file_list = file_list[:cut_idx] + test_split_list = file_list[cut_idx:] + for file in file_list: + if data_type == 'train': + if file in train_file_list: + self.dvs_filelist.append(os.path.join(self.filepath, cls, file)) + self.targets.append(i) + else: + if file in test_split_list: + self.dvs_filelist.append(os.path.join(self.filepath, cls, file)) + self.targets.append(i) + + self.data_num = len(self.dvs_filelist) + self.data_type = data_type + if data_type != 'train': + counts = np.unique(np.array(self.targets), return_counts=True)[1] + class_weights = counts.sum() / (counts * len(counts)) + 
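# inverse-frequency weighting: classes with fewer samples receive a
+            # proportionally larger weight
+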
self.class_weights = torch.Tensor(class_weights)
+        self.classes = range(10)
+        self.transform = transform
+        self.rotate = transforms.RandomRotation(degrees=15)
+        self.shearx = transforms.RandomAffine(degrees=0, shear=(-15, 15))
+
+    def __getitem__(self, index):
+        file_pth = self.dvs_filelist[index]
+        label = self.targets[index]
+        data = torch.from_numpy(np.load(file_pth)['frames']).float()
+        data = self.resize(data)
+
+        if self.transform:
+            # apply exactly one randomly chosen geometric augmentation
+            choices = ['roll', 'rotate', 'shear']
+            aug = np.random.choice(choices)
+            if aug == 'roll':
+                off1 = random.randint(-3, 3)
+                off2 = random.randint(-3, 3)
+                data = torch.roll(data, shifts=(off1, off2), dims=(2, 3))
+            if aug == 'rotate':
+                data = self.rotate(data)
+            if aug == 'shear':
+                data = self.shearx(data)
+
+        return data, label
+
+    def __len__(self):
+        return self.data_num
+
+
+class transformPolicy:
+    # was mistakenly declared with 'def', which made __init__ an unreachable inner
+    # function; kept here as a class (it is unused elsewhere in this file)
+    def __init__(self):
+        self.resize = transforms.Resize(size=(48, 48), interpolation=torchvision.transforms.InterpolationMode.NEAREST)
+
+
+def build_dvscifar(path='data/cifar-dvs', transform=False):
+    train_dataset = DVSCifar10(data_path=path, data_type='train', transform=transform)
+    val_dataset = DVSCifar10(data_path=path, transform=False)
+
+    return train_dataset, val_dataset
+
+
+def mixup_criterion(criterion, pred, y_a, y_b, lam):
+    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
+
+
+def rand_bbox(size, lam):
+    W = size[3]
+    H = size[4]
+    cut_rat = np.sqrt(1. - lam)
+    # np.int was removed from numpy; the builtin int is equivalent here
+    cut_w = int(W * cut_rat)
+    cut_h = int(H * cut_rat)
+
+    # uniform
+    cx = np.random.randint(W)
+    cy = np.random.randint(H)
+
+    bbx1 = np.clip(cx - cut_w // 2, 0, W)
+    bby1 = np.clip(cy - cut_h // 2, 0, H)
+    bbx2 = np.clip(cx + cut_w // 2, 0, W)
+    bby2 = np.clip(cy + cut_h // 2, 0, H)
+
+    return bbx1, bby1, bbx2, bby2
+
+
+def cutmix_data(input, target, alpha=1.0):
+    lam = np.random.beta(alpha, alpha)
+    rand_index = torch.randperm(input.size()[0]).cuda()
+
+    target_a = target
+    target_b = target[rand_index]
+
+    # generate mixed sample
+    bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
+    input[:, :, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, :, bbx1:bbx2, bby1:bby2]
+    # adjust lambda to exactly match pixel ratio
+    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
+    return input, target_a, target_b, lam
+
+
+class TinyImageNet(Dataset):
+    def __init__(self, root, train=True, transform=None):
+        self.Train = train
+        self.root_dir = root
+        self.transform = transform
+        self.train_dir = os.path.join(self.root_dir, "train")
+        self.val_dir = os.path.join(self.root_dir, "val")
+
+        if (self.Train):
+            self._create_class_idx_dict_train()
+        else:
+            self._create_class_idx_dict_val()
+
+        self._make_dataset(self.Train)
+
+        words_file = os.path.join(self.root_dir, "words.txt")
+        wnids_file = os.path.join(self.root_dir, "wnids.txt")
+
+        self.set_nids = set()
+
+        with open(wnids_file, 'r') as fo:
+            data = fo.readlines()
+            for entry in data:
+                self.set_nids.add(entry.strip("\n"))
+
+        self.class_to_label = {}
+        with open(words_file, 'r') as fo:
+            data = fo.readlines()
+            for entry in data:
+                words = entry.split("\t")
+                if words[0] in self.set_nids:
+                    self.class_to_label[words[0]] = (words[1].strip("\n").split(","))[0]
+
+    def _create_class_idx_dict_train(self):
+        if sys.version_info >= (3, 5):
+            classes = [d.name for d in os.scandir(self.train_dir) if d.is_dir()]
+        else:
+            classes = [d for d in os.listdir(self.train_dir) if os.path.isdir(os.path.join(self.train_dir, d))]
+        classes = sorted(classes)
+        num_images = 0
+        for root, dirs, 
files in os.walk(self.train_dir): + for f in files: + if f.endswith(".JPEG"): + num_images = num_images + 1 + + self.len_dataset = num_images + + self.tgt_idx_to_class = {i: classes[i] for i in range(len(classes))} + self.class_to_tgt_idx = {classes[i]: i for i in range(len(classes))} + + def _create_class_idx_dict_val(self): + val_image_dir = os.path.join(self.val_dir, "images") + if sys.version_info >= (3, 5): + images = [d.name for d in os.scandir(val_image_dir) if d.is_file()] + else: + images = [d for d in os.listdir(val_image_dir) if os.path.isfile(os.path.join(self.val_dir, d))] + val_annotations_file = os.path.join(self.val_dir, "val_annotations.txt") + self.val_img_to_class = {} + set_of_classes = set() + with open(val_annotations_file, 'r') as fo: + entry = fo.readlines() + for data in entry: + words = data.split("\t") + self.val_img_to_class[words[0]] = words[1] + set_of_classes.add(words[1]) + + self.len_dataset = len(list(self.val_img_to_class.keys())) + classes = sorted(list(set_of_classes)) + # self.idx_to_class = {i:self.val_img_to_class[images[i]] for i in range(len(images))} + self.class_to_tgt_idx = {classes[i]: i for i in range(len(classes))} + self.tgt_idx_to_class = {i: classes[i] for i in range(len(classes))} + + def _make_dataset(self, Train=True): + self.images = [] + if Train: + img_root_dir = self.train_dir + list_of_dirs = [target for target in self.class_to_tgt_idx.keys()] + else: + img_root_dir = self.val_dir + list_of_dirs = ["images"] + + for tgt in list_of_dirs: + dirs = os.path.join(img_root_dir, tgt) + if not os.path.isdir(dirs): + continue + + for root, _, files in sorted(os.walk(dirs)): + for fname in sorted(files): + if (fname.endswith(".JPEG")): + path = os.path.join(root, fname) + if Train: + item = (path, self.class_to_tgt_idx[tgt]) + else: + item = (path, self.class_to_tgt_idx[self.val_img_to_class[fname]]) + self.images.append(item) + + def return_label(self, idx): + return [self.class_to_label[self.tgt_idx_to_class[i.item()]] for i in idx] + + def __len__(self): + return self.len_dataset + + def __getitem__(self, idx): + img_path, tgt = self.images[idx] + with open(img_path, 'rb') as f: + sample = Image.open(img_path) + sample = sample.convert('RGB') + if self.transform is not None: + sample = self.transform(sample) + + return sample, tgt + + +if __name__ == '__main__': + # choices = ['roll', 'rotate', 'shear'] + # aug = np.random.choice(choices) + # print(aug) + # train_dataset, val_dataset = build_dvscifar( + # path='./data_dir', + # transform=True) + # train_dataset[0] + pass diff --git a/utils/eval.py b/utils/eval.py new file mode 100644 index 0000000..7cdb5bf --- /dev/null +++ b/utils/eval.py @@ -0,0 +1,19 @@ +from __future__ import print_function, absolute_import + +__all__ = ['accuracy'] + + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res diff --git a/utils/image_augment.py b/utils/image_augment.py new file mode 100644 index 0000000..d13b456 --- /dev/null +++ b/utils/image_augment.py @@ -0,0 +1,324 @@ +from PIL import Image, ImageEnhance, ImageOps +import numpy as np +import random + +import torch + +""" +Reference: + xxPolicy: Cubuk, E. 
D., Zoph, B., Mane, D., Vasudevan, V., & Le, Q. V. (2019). Autoaugment: Learning augmentation strategies from data. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 113-123). + Cutout: DeVries, Terrance, and Graham W. Taylor. "Improved regularization of convolutional neural networks with cutout." arXiv preprint arXiv:1708.04552 (2017). +""" + + +class Cutout(object): + """Randomly mask out one or more patches from an image. + Args: + n_holes (int): Number of patches to cut out of each image. + length (int): The length (in pixels) of each square patch. + """ + + def __init__(self, n_holes, length): + self.n_holes = n_holes + self.length = length + + def __call__(self, img): + """ + Args: + img (Tensor): Tensor image of size (C, H, W). + Returns: + Tensor: Image with n_holes of dimension length x length cut out of it. + """ + h = img.size(1) + w = img.size(2) + + mask = np.ones((h, w), np.float32) + + for n in range(self.n_holes): + y = np.random.randint(h) + x = np.random.randint(w) + + y1 = np.clip(y - self.length // 2, 0, h) + y2 = np.clip(y + self.length // 2, 0, h) + x1 = np.clip(x - self.length // 2, 0, w) + x2 = np.clip(x + self.length // 2, 0, w) + + mask[y1: y2, x1: x2] = 0. + + mask = torch.from_numpy(mask) + mask = mask.expand_as(img) + img = img * mask + + return img + + +class ImageNetPolicy(object): + """ Randomly choose one of the best 24 Sub-policies on ImageNet. + + Example: + >>> policy = ImageNetPolicy() + >>> transformed = policy(image) + + Example as a PyTorch Transform: + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> ImageNetPolicy(), + >>> transforms.ToTensor()]) + """ + + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), + SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), + SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), + SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), + + SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), + SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), + SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), + SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), + SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), + + SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), + SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), + SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), + SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), + SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), + + SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), + SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), + SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), + SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), + SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), + + SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), + SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), + SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), + SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor) + ] + + def __call__(self, img): + policy_idx = random.randint(0, len(self.policies) - 1) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment ImageNet Policy" + + +class CIFAR10Policy(object): + """ Randomly choose one of the best 25 Sub-policies 
on CIFAR10. + + Example: + >>> policy = CIFAR10Policy() + >>> transformed = policy(image) + + Example as a PyTorch Transform: + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> CIFAR10Policy(), + >>> transforms.ToTensor()]) + """ + + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), + SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), + SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), + SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), + SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), + + SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), + SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), + SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), + SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), + SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), + + SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), + SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), + SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), + SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), + SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), + + SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), + SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), + SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), + SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), + SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), + + SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), + SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), + SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), + SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) + ] + + def __call__(self, img): + policy_idx = random.randint(0, len(self.policies) - 1) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment CIFAR10 Policy" + + +class SVHNPolicy(object): + """ Randomly choose one of the best 25 Sub-policies on SVHN. 
+
+    Example:
+    >>> policy = SVHNPolicy()
+    >>> transformed = policy(image)
+
+    Example as a PyTorch Transform:
+    >>> transform=transforms.Compose([
+    >>>     transforms.Resize(256),
+    >>>     SVHNPolicy(),
+    >>>     transforms.ToTensor()])
+    """
+
+    def __init__(self, fillcolor=(128, 128, 128)):
+        self.policies = [
+            SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
+            SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
+            SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
+            SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
+            SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
+
+            SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
+            SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
+            SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
+            SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
+            SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
+
+            SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
+            SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
+            SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
+            SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
+            SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
+
+            SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
+            SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
+            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
+            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
+            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
+
+            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
+            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
+            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
+            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
+            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
+        ]
+
+    def __call__(self, img):
+        policy_idx = random.randint(0, len(self.policies) - 1)
+        return self.policies[policy_idx](img)
+
+    def __repr__(self):
+        return "AutoAugment SVHN Policy"
+
+
+def rotate_with_fill(img, magnitude):
+    # rotate on an RGBA canvas and composite over gray so corners are filled
+    rot = img.convert("RGBA").rotate(magnitude)
+    return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
+
+
+class SubPolicy(object):
+    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
+        ranges = {
+            "shearX": np.linspace(0, 0.3, 10),
+            "shearY": np.linspace(0, 0.3, 10),
+            "translateX": np.linspace(0, 150 / 331, 10),
+            "translateY": np.linspace(0, 150 / 331, 10),
+            "rotate": np.linspace(0, 30, 10),
+            "color": np.linspace(0.0, 0.9, 10),
+            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),  # np.int was removed in NumPy >= 1.24
+            "solarize": np.linspace(256, 0, 10),
+            "contrast": np.linspace(0.0, 0.9, 10),
+            "sharpness": np.linspace(0.0, 0.9, 10),
+            "brightness": np.linspace(0.0, 0.9, 10),
+            "autocontrast": [0] * 10,
+            "equalize": [0] * 10,
+            "invert": [0] * 10
+        }
+
+        func = {
+            "shearX": lambda img, magnitude: img.transform(
+                img.size, Image.AFFINE, (1, magnitude *
+                                         random.choice([-1, 1]), 0, 0, 1, 0),
+                Image.BICUBIC, fillcolor=fillcolor),
+            "shearY": lambda img, magnitude: img.transform(
+                img.size, Image.AFFINE, (1, 0, 0, magnitude *
+                                         random.choice([-1, 1]), 1, 0),
+                Image.BICUBIC, fillcolor=fillcolor),
+            "translateX": lambda img, magnitude: img.transform(
+                img.size, Image.AFFINE, (1, 0, magnitude *
+                                         img.size[0] * random.choice([-1, 1]), 0, 1, 0),
+                fillcolor=fillcolor),
+            "translateY": lambda img, magnitude: img.transform(
+                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude *
+                                         img.size[1] * random.choice([-1, 1])),
+                fillcolor=fillcolor),
+            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
+            # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
+            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
+            "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
+            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
+            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
+                1 + magnitude * random.choice([-1, 1])),
+            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
+                1 + magnitude * random.choice([-1, 1])),
+            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
+                1 + magnitude * random.choice([-1, 1])),
+            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
+            "equalize": lambda img, magnitude: ImageOps.equalize(img),
+            "invert": lambda img, magnitude: ImageOps.invert(img)
+        }
+
+        # self.name = "{}_{:.2f}_and_{}_{:.2f}".format(
+        #     operation1, ranges[operation1][magnitude_idx1],
+        #     operation2, ranges[operation2][magnitude_idx2])
+        self.p1 = p1
+        self.operation1 = func[operation1]
+        self.magnitude1 = ranges[operation1][magnitude_idx1]
+        self.p2 = p2
+        self.operation2 = func[operation2]
+        self.magnitude2 = ranges[operation2][magnitude_idx2]
+
+    def __call__(self, img):
+        if random.random() < self.p1:
+            img = self.operation1(img, self.magnitude1)
+        if random.random() < self.p2:
+            img = self.operation2(img, self.magnitude2)
+        return img
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000..f431aa4
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,126 @@
+from __future__ import absolute_import
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+__all__ = ['Logger', 'LoggerMonitor', 'savefig']
+
+
+def savefig(fname, dpi=None):
+    dpi = 150 if dpi is None else dpi
+    plt.savefig(fname, dpi=dpi)
+
+
+def plot_overlap(logger, names=None):
+    names = logger.names if names is None else names
+    numbers = logger.numbers
+    for _, name in enumerate(names):
+        x = np.arange(len(numbers[name]))
+        plt.plot(x, np.asarray(numbers[name]))
+    return [logger.title + '(' + name + ')' for name in names]
+
+
+class Logger(object):
+    '''Save training process to log file with simple plot function.'''
+
+    def __init__(self, fpath, title=None, resume=False):
+        self.file = None
+        self.resume = resume
+        self.title = '' if title is None else title
+        if fpath is not None:
+            if resume:
+                self.file = open(fpath, 'r')
+                name = self.file.readline()
+                self.names = name.rstrip().split('\t')
+                self.numbers = {}
+                for _, name in enumerate(self.names):
+                    self.numbers[name] = []
+
+                for numbers in self.file:
+                    numbers = numbers.rstrip().split('\t')
+                    for i in range(0, len(numbers)):
+                        # cast back to float so resumed histories plot correctly
+                        self.numbers[self.names[i]].append(float(numbers[i]))
+                self.file.close()
+                self.file = open(fpath, 'a')
+            else:
+                self.file = open(fpath, 'w')
+
+    def set_names(self, names):
+        if self.resume:
+            pass
+        # initialize numbers as empty list
+        self.numbers = {}
+        self.names = names
+        for _, name in enumerate(self.names):
+            self.file.write(name)
+            self.file.write('\t')
+            self.numbers[name] = []
+        self.file.write('\n')
+        self.file.flush()
+
+    def append(self, numbers):
+        assert len(self.names) == len(numbers), 'Numbers do not match names'
+        for index, num in enumerate(numbers):
+            self.file.write("{0:.6f}".format(num))
+            self.file.write('\t')
+            self.numbers[self.names[index]].append(num)
+        self.file.write('\n')
+        self.file.flush()
+
+    def plot(self, names=None):
+        names = self.names if names is None else names
+        numbers = self.numbers
+        for _, name in enumerate(names):
+            x = np.arange(len(numbers[name]))
+            plt.plot(x, np.asarray(numbers[name]))
+        plt.legend([self.title + '(' + name + ')' for name in names])
+        plt.grid(True)
+
+    def close(self):
+        if self.file is not None:
+            self.file.close()
+
+
+class LoggerMonitor(object):
+    '''Load and visualize multiple logs.'''
+
+    def __init__(self, paths):
+        '''paths is a dictionary with {name: filepath} pairs'''
+        self.loggers = []
+        for title, path in paths.items():
+            logger = Logger(path, title=title, resume=True)
+            self.loggers.append(logger)
+
+    def plot(self, names=None):
+        plt.figure()
+        plt.subplot(121)
+        legend_text = []
+        for logger in self.loggers:
+            legend_text += plot_overlap(logger, names)
+        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
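+        # plot() only draws onto the current figure; the caller is expected to
+        # call plt.show() or savefig() afterwards (as in the example below)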
+        plt.grid(True)
+
+
+if __name__ == '__main__':
+    # # Example
+    # logger = Logger('test.txt')
+    # logger.set_names(['Train loss', 'Valid loss', 'Test loss'])
+
+    # length = 100
+    # t = np.arange(length)
+    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
+    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
+    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
+
+    # for i in range(0, length):
+    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
+    # logger.plot()
+
+    # Example: logger monitor
+    # LoggerMonitor expects a {title: filepath} dict; the entry below is an
+    # illustrative placeholder
+    paths = {'example run': './path'}
+
+    field = ['Valid Acc.']
+
+    monitor = LoggerMonitor(paths)
+    monitor.plot(names=field)
+    savefig('test.eps')
diff --git a/utils/misc.py b/utils/misc.py
new file mode 100644
index 0000000..b3897ab
--- /dev/null
+++ b/utils/misc.py
@@ -0,0 +1,77 @@
+'''Some helper functions for PyTorch, including:
+    - get_mean_and_std: calculate the mean and std value of a dataset.
+    - init_params: net parameter initialization.
+    - mkdir_p: make a directory if it does not exist.
+    - AverageMeter: track the current and running-average value of a metric.
+'''
+import errno
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+
+__all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']
+
+
+def get_mean_and_std(dataset):
+    '''Compute the mean and std value of dataset.'''
+    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
+
+    mean = torch.zeros(3)
+    std = torch.zeros(3)
+    print('==> Computing mean and std..')
+    for inputs, targets in dataloader:
+        for i in range(3):
+            mean[i] += inputs[:, i, :, :].mean()
+            std[i] += inputs[:, i, :, :].std()
+    mean.div_(len(dataset))
+    std.div_(len(dataset))
+    return mean, std
+
+
+def init_params(net):
+    '''Init layer parameters.'''
+    for m in net.modules():
+        if isinstance(m, nn.Conv2d):
+            init.kaiming_normal_(m.weight, mode='fan_out')
+            if m.bias is not None:
+                init.constant_(m.bias, 0)
+        elif isinstance(m, nn.BatchNorm2d):
+            init.constant_(m.weight, 1)
+            init.constant_(m.bias, 0)
+        elif isinstance(m, nn.Linear):
+            init.normal_(m.weight, std=1e-3)
+            if m.bias is not None:
+                init.constant_(m.bias, 0)
+
+
+def mkdir_p(path):
+    '''make dir if not exist'''
+    try:
+        os.makedirs(path)
+    except OSError as exc:  # Python >2.5
+        if exc.errno == errno.EEXIST and os.path.isdir(path):
+            pass
+        else:
+            raise
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value
+       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
+    """
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
diff --git a/utils/read_log_cal_metrics.py b/utils/read_log_cal_metrics.py
new file mode 100644
index 0000000..65b8094
--- /dev/null
+++ b/utils/read_log_cal_metrics.py
@@ -0,0 +1,94 @@
+import os
+
+
+def extract_gb_data(line):
+    gb_index = line.find("after one epoch") + len("after one epoch: ")
+    gb_str = line[gb_index:].strip()
+    return gb_str
+
+
+def seconds_to_hours(seconds):
+    hours = seconds / 3600
+    return round(hours, 2)
+
+
+def extract_total_time(filename):
+    total_time_sum = 0.0
+    max_test_acc = 0.0
+    min_test_loss = None
+    gb_data = 0.00
+    epoch = 0
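+    # The parsing below assumes (not verified against every run) log lines
+    # shaped like the per-epoch training output, e.g.:
+    #   ... total_time=27.31, max_test_acc=0.9321, test_loss=0.8512, ...
+    #   ... after one epoch: 3.21GB
+    # Lines containing "Namespace" (the dumped argparse arguments) are skipped.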
+    with open(filename, 'r') as file:
+
+        lines = file.readlines()
+        for line in lines:
+            if "Namespace" in line:
+                continue
+
+            if "total_time" in line:
+                time_index = line.find("total_time") + len("total_time=")
+                time_str = line[time_index:].strip()
+                time_str = time_str.split(',')[0]
+                total_time_sum += float(time_str)
+                epoch += 1
+
+            if "max_test_acc" in line:
+                max_test_acc_index = line.find("max_test_acc") + len("max_test_acc=")
+                max_test_acc_str = line[max_test_acc_index:].strip()
+                max_test_acc_str = max_test_acc_str.split(',')[0]
+                max_test_acc = float(max_test_acc_str)
+
+            if "test_loss" in line:
+                min_test_loss_index = line.find("test_loss") + len("test_loss=")
+                min_test_loss_str = line[min_test_loss_index:].strip()
+                min_test_loss_str = min_test_loss_str.split(',')[0]
+                min_test_loss = float(min_test_loss_str)
+
+            if "after one epoch" in line:
+                gb_data = extract_gb_data(line)
+                gb_data = float(gb_data[:-2])
+
+    extract_info_file = "info.txt"
+    total_time = f"Total Time Sum: {seconds_to_hours(total_time_sum)} h"
+    time_per_epoch = f"Time / epoch: {round(total_time_sum / epoch, 2)} s / epoch" if epoch != 0 else "Time / epoch: None"
+    max_test_acc = f"Max Test Acc: {round(max_test_acc * 100, 2)} %"
+    min_test_loss = f"Min Test Loss: {round(min_test_loss, 4)}" if min_test_loss is not None else "Min Test Loss: None"
+    memory_use = f"Memory use / Epoch: {round(gb_data, 3)} GB"
+    print(total_time)
+    print(time_per_epoch)
+    print(max_test_acc)
+    print(min_test_loss)
+    print(memory_use)
+
+    new_file = os.path.join(os.path.dirname(filename), extract_info_file)
+    with open(new_file, "w") as file:
+        file.write(total_time + "\n")
+        file.write(time_per_epoch + "\n")
+        file.write(max_test_acc + "\n")
+        file.write(min_test_loss + "\n")
+        file.write(memory_use + "\n")
+
+
+if __name__ == '__main__':
+
+    current_path = os.path.dirname(os.path.abspath(__file__))
+    parent_path = os.path.dirname(current_path)
+    print("Parent Path:", parent_path)
+    print("Current File's Path:", current_path)
+
+    log_dir = os.path.join(parent_path, "logs")
+    print(log_dir)
+
+    # directory_folder = "/BPTT_DVSCIFAR10_spiking_vgg11_bn__T10_tau1.1_e300_bs128_SGD_lr0.05_wd0.0005_SG_triangle_drop0.3_losslamb0.05_CosALR_300_amp"
+
+    # file = log_dir + directory_folder + "/args.txt"
+    # print(directory_folder)
+    # extract_total_time(file)
+
+    for folder_name in os.listdir(log_dir):
+        folder_path = os.path.join(log_dir, folder_name)
+        file_path = folder_path + "/args.txt"
+        print(file_path)
+        if os.path.exists(file_path):
+            extract_total_time(file_path)
+            print("\n")
diff --git a/utils/static_cifar_util.py b/utils/static_cifar_util.py
new file mode 100644
index 0000000..83f49e2
--- /dev/null
+++ b/utils/static_cifar_util.py
@@ -0,0 +1,215 @@
+import math
+from typing import Tuple
+
+import torch
+import torchvision
+from torch import Tensor
+from torchvision.transforms import autoaugment, transforms
+from torchvision.transforms.functional import InterpolationMode
+
+
+class ClassificationPresetTrain:
+    def __init__(
+            self,
+            mean=(0.485, 0.456, 0.406),
+            std=(0.229, 0.224, 0.225),
+            interpolation=InterpolationMode.BILINEAR,
+            hflip_prob=0.5,
+            auto_augment_policy=None,
+            random_erase_prob=0.0,
+    ):
+        trans = []
+        if hflip_prob > 0:
+            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+        if auto_augment_policy is not None:
+            if auto_augment_policy == "ra":
+                trans.append(autoaugment.RandAugment(interpolation=interpolation))
+            elif auto_augment_policy == "ta_wide": 
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation)) + else: + aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) + trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) + trans.extend( + [ + transforms.PILToTensor(), + transforms.ConvertImageDtype(torch.float), + transforms.Normalize(mean=mean, std=std), + ] + ) + if random_erase_prob > 0: + trans.append(transforms.RandomErasing(p=random_erase_prob)) + + self.transforms = transforms.Compose(trans) + + def __call__(self, img): + return self.transforms(img) + + +class RandomMixup(torch.nn.Module): + """Randomly apply Mixup to the provided batch and targets. + The class implements the data augmentations as described in the paper + `"mixup: Beyond Empirical Risk Minimization" `_. + + Args: + num_classes (int): number of classes used for one-hot encoding. + p (float): probability of the batch being transformed. Default value is 0.5. + alpha (float): hyperparameter of the Beta distribution used for mixup. + Default value is 1.0. + inplace (bool): boolean to make this transform inplace. Default set to False. + """ + + def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: + super().__init__() + assert num_classes > 0, "Please provide a valid positive value for the num_classes." + assert alpha > 0, "Alpha param can't be zero." + + self.num_classes = num_classes + self.p = p + self.alpha = alpha + self.inplace = inplace + + def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + batch (Tensor): Float tensor of size (B, C, H, W) + target (Tensor): Integer tensor of size (B, ) + + Returns: + Tensor: Randomly transformed batch. + """ + if batch.ndim != 4: + raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") + if target.ndim != 1: + raise ValueError(f"Target ndim should be 1. Got {target.ndim}") + if not batch.is_floating_point(): + raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") + if target.dtype != torch.int64: + raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") + + if not self.inplace: + batch = batch.clone() + target = target.clone() + + if target.ndim == 1: + target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) + + if torch.rand(1).item() >= self.p: + return batch, target + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = batch.roll(1, 0) + target_rolled = target.roll(1, 0) + + # Implemented as on mixup paper, page 3. + lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) + batch_rolled.mul_(1.0 - lambda_param) + batch.mul_(lambda_param).add_(batch_rolled) + + target_rolled.mul_(1.0 - lambda_param) + target.mul_(lambda_param).add_(target_rolled) + + return batch, target + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(" + f"num_classes={self.num_classes}" + f", p={self.p}" + f", alpha={self.alpha}" + f", inplace={self.inplace}" + f")" + ) + return s + + +class RandomCutmix(torch.nn.Module): + """Randomly apply Cutmix to the provided batch and targets. + The class implements the data augmentations as described in the paper + `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" + `_. + + Args: + num_classes (int): number of classes used for one-hot encoding. + p (float): probability of the batch being transformed. Default value is 0.5. 
+ alpha (float): hyperparameter of the Beta distribution used for cutmix. + Default value is 1.0. + inplace (bool): boolean to make this transform inplace. Default set to False. + """ + + def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: + super().__init__() + assert num_classes > 0, "Please provide a valid positive value for the num_classes." + assert alpha > 0, "Alpha param can't be zero." + + self.num_classes = num_classes + self.p = p + self.alpha = alpha + self.inplace = inplace + + def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + batch (Tensor): Float tensor of size (B, C, H, W) + target (Tensor): Integer tensor of size (B, ) + + Returns: + Tensor: Randomly transformed batch. + """ + if batch.ndim != 4: + raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") + if target.ndim != 1: + raise ValueError(f"Target ndim should be 1. Got {target.ndim}") + if not batch.is_floating_point(): + raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") + if target.dtype != torch.int64: + raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") + + if not self.inplace: + batch = batch.clone() + target = target.clone() + + if target.ndim == 1: + target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) + + if torch.rand(1).item() >= self.p: + return batch, target + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = batch.roll(1, 0) + target_rolled = target.roll(1, 0) + + # Implemented as on cutmix paper, page 12 (with minor corrections on typos). + lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) + W, H = torchvision.transforms.functional.get_image_size(batch) + + r_x = torch.randint(W, (1,)) + r_y = torch.randint(H, (1,)) + + r = 0.5 * math.sqrt(1.0 - lambda_param) + r_w_half = int(r * W) + r_h_half = int(r * H) + + x1 = int(torch.clamp(r_x - r_w_half, min=0)) + y1 = int(torch.clamp(r_y - r_h_half, min=0)) + x2 = int(torch.clamp(r_x + r_w_half, max=W)) + y2 = int(torch.clamp(r_y + r_h_half, max=H)) + + batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] + lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) + + target_rolled.mul_(1.0 - lambda_param) + target.mul_(lambda_param).add_(target_rolled) + + return batch, target + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(" + f"num_classes={self.num_classes}" + f", p={self.p}" + f", alpha={self.alpha}" + f", inplace={self.inplace}" + f")" + ) + return s diff --git a/utils/visualize.py b/utils/visualize.py new file mode 100644 index 0000000..2665718 --- /dev/null +++ b/utils/visualize.py @@ -0,0 +1,110 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch +import torchvision + +__all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] + + +# functions to show an image +def make_image(img, mean=(0, 0, 0), std=(1, 1, 1)): + for i in range(0, 3): + img[i] = img[i] * std[i] + mean[i] # unnormalize + npimg = img.numpy() + return np.transpose(npimg, (1, 2, 0)) + + +def gauss(x, a, b, c): + return torch.exp(-torch.pow(torch.add(x, -b), 2).div(2 * c * c)).mul(a) + + +def colorize(x): + ''' Converts a one-channel grayscale image to a color heatmap image ''' + if x.dim() == 2: + torch.unsqueeze(x, 0, out=x) + if x.dim() == 3: + cl = torch.zeros([3, x.size(1), x.size(2)]) + cl[0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3) + 
cl[1] = gauss(x, 1, .5, .3)
+        cl[2] = gauss(x, 1, .2, .3)
+        cl[cl.gt(1)] = 1
+    elif x.dim() == 4:
+        cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
+        cl[:, 0, :, :] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
+        cl[:, 1, :, :] = gauss(x, 1, .5, .3)
+        cl[:, 2, :, :] = gauss(x, 1, .2, .3)
+    return cl
+
+
+def show_batch(images, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
+    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
+    plt.imshow(images)
+    plt.show()
+
+
+def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
+    im_size = images.size(2)
+
+    # save for adding mask
+    im_data = images.clone()
+    for i in range(0, 3):
+        im_data[:, i, :, :] = im_data[:, i, :, :] * Std[i] + Mean[i]  # unnormalize
+
+    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
+    plt.subplot(2, 1, 1)
+    plt.imshow(images)
+    plt.axis('off')
+
+    # for b in range(mask.size(0)):
+    #     mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min())
+    mask_size = mask.size(2)
+    # print('Max %f Min %f' % (mask.max(), mask.min()))
+    # `upsampling` was undefined here; use nearest-neighbor interpolation to resize the mask
+    mask = torch.nn.functional.interpolate(mask, scale_factor=im_size / mask_size)
+    # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size))
+    # for c in range(3):
+    #     mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c]
+
+    # print(mask.size())
+    mask = make_image(torchvision.utils.make_grid(0.3 * im_data + 0.7 * mask.expand_as(im_data)))
+    # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std)
+    plt.subplot(2, 1, 2)
+    plt.imshow(mask)
+    plt.axis('off')
+
+
+def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
+    im_size = images.size(2)
+
+    # save for adding mask
+    im_data = images.clone()
+    for i in range(0, 3):
+        im_data[:, i, :, :] = im_data[:, i, :, :] * Std[i] + Mean[i]  # unnormalize
+
+    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
+    plt.subplot(1 + len(masklist), 1, 1)
+    plt.imshow(images)
+    plt.axis('off')
+
+    for i in range(len(masklist)):
+        mask = masklist[i].data.cpu()
+        # for b in range(mask.size(0)):
+        #     mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min())
+        mask_size = mask.size(2)
+        # print('Max %f Min %f' % (mask.max(), mask.min()))
+        # `upsampling` was undefined here; use nearest-neighbor interpolation to resize the mask
+        mask = torch.nn.functional.interpolate(mask, scale_factor=im_size / mask_size)
+        # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size))
+        # for c in range(3):
+        #     mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c]
+
+        # print(mask.size())
+        mask = make_image(torchvision.utils.make_grid(0.3 * im_data + 0.7 * mask.expand_as(im_data)))
+        # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std)
+        plt.subplot(1 + len(masklist), 1, i + 2)
+        plt.imshow(mask)
+        plt.axis('off')
+
+# x = torch.zeros(1, 3, 3)
+# out = colorize(x)
+# out_im = make_image(out)
+# plt.imshow(out_im)
+# plt.show()
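+
+# Hypothetical usage sketch for show_batch (illustrative only; the dataset path
+# is a placeholder, and the identity Mean/Std are an assumption that holds when
+# the transform is just ToTensor(), i.e. the images are not normalized):
+# import torchvision.transforms as T
+# from torch.utils.data import DataLoader
+# dataset = torchvision.datasets.CIFAR10('./data', train=False, download=True,
+#                                        transform=T.ToTensor())
+# images, _ = next(iter(DataLoader(dataset, batch_size=16)))
+# show_batch(images, Mean=(0, 0, 0), Std=(1, 1, 1))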