diff --git a/conf/global_settings.py b/conf/global_settings.py
index 797f57db..daa98ace 100644
--- a/conf/global_settings.py
+++ b/conf/global_settings.py
@@ -37,7 +37,7 @@
 #INIT_LR = 0.1
 
 #time of we run the script
-TIME_NOW = datetime.now().isoformat()
+TIME_NOW = datetime.now().strftime("%F_%H-%M-%S.%f")
 
 #tensorboard log dir
 LOG_DIR = 'runs'
diff --git a/train.py b/train.py
index e9f8e2b3..f752e24e 100644
--- a/train.py
+++ b/train.py
@@ -35,188 +35,113 @@
 from dataset import *
 from utils import *
 
-args = cfg.parse_args()
-
-GPUdevice = torch.device('cuda', args.gpu_device)
-
-net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution = args.distributed)
-if args.pretrain:
-    weights = torch.load(args.pretrain)
-    net.load_state_dict(weights,strict=False)
-
-optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
-scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) #learning rate decay
-
-'''load pretrained model'''
-if args.weights != 0:
-    print(f'=> resuming from {args.weights}')
-    assert os.path.exists(args.weights)
-    checkpoint_file = os.path.join(args.weights)
-    assert os.path.exists(checkpoint_file)
-    loc = 'cuda:{}'.format(args.gpu_device)
-    checkpoint = torch.load(checkpoint_file, map_location=loc)
-    start_epoch = checkpoint['epoch']
-    best_tol = checkpoint['best_tol']
-
-    net.load_state_dict(checkpoint['state_dict'],strict=False)
-    # optimizer.load_state_dict(checkpoint['optimizer'], strict=False)
-
-    args.path_helper = checkpoint['path_helper']
+def main():
+
+    args = cfg.parse_args()
+
+    GPUdevice = torch.device('cuda', args.gpu_device)
+
+    net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution = args.distributed)
+    if args.pretrain:
+        weights = torch.load(args.pretrain)
+        net.load_state_dict(weights,strict=False)
+
+    optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
+    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) #learning rate decay
+
+    '''load pretrained model'''
+    if args.weights != 0:
+        print(f'=> resuming from {args.weights}')
+        assert os.path.exists(args.weights)
+        checkpoint_file = os.path.join(args.weights)
+        assert os.path.exists(checkpoint_file)
+        loc = 'cuda:{}'.format(args.gpu_device)
+        checkpoint = torch.load(checkpoint_file, map_location=loc)
+        start_epoch = checkpoint['epoch']
+        best_tol = checkpoint['best_tol']
+
+        net.load_state_dict(checkpoint['state_dict'],strict=False)
+        # optimizer.load_state_dict(checkpoint['optimizer'], strict=False)
+
+        args.path_helper = checkpoint['path_helper']
+        logger = create_logger(args.path_helper['log_path'])
+        print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
+
+    args.path_helper = set_log_dir('logs', args.exp_name)
     logger = create_logger(args.path_helper['log_path'])
-    print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
-
-args.path_helper = set_log_dir('logs', args.exp_name)
-logger = create_logger(args.path_helper['log_path'])
-logger.info(args)
-
-
-# '''segmentation data'''
-# transform_train = transforms.Compose([
-#     transforms.Resize((args.image_size,args.image_size)),
-#     transforms.ToTensor(),
-# ])
-
-# transform_train_seg = transforms.Compose([
-#     transforms.Resize((args.out_size,args.out_size)),
-#     transforms.ToTensor(),
-# ])
-
-# transform_test = transforms.Compose([
-#     transforms.Resize((args.image_size, args.image_size)),
-#     transforms.ToTensor(),
-# ])
-
-# transform_test_seg = transforms.Compose([
-#     transforms.Resize((args.out_size,args.out_size)),
-#     transforms.ToTensor(),
-# ])
-
-# transform_3d_seg = transforms.Compose([
-#     transforms.ToTensor(),
-# ])
-
-# if args.dataset == 'isic':
-#     '''isic data'''
-#     isic_train_dataset = ISIC2016(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training')
-#     isic_test_dataset = ISIC2016(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test')
-
-#     nice_train_loader = DataLoader(isic_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True)
-#     nice_test_loader = DataLoader(isic_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)
-#     '''end'''
-
-# elif args.dataset == 'decathlon':
-#     nice_train_loader, nice_test_loader, transform_train, transform_val, train_list, val_list = get_decath_loader(args)
-
-
-# elif args.dataset == 'REFUGE':
-#     '''REFUGE data'''
-#     refuge_train_dataset = REFUGE(args, args.data_path, transform = transform_3d_seg, transform_msk= transform_3d_seg, mode = 'Training')
-#     refuge_test_dataset = REFUGE(args, args.data_path, transform = transform_3d_seg, transform_msk= transform_3d_seg, mode = 'Test')
-
-#     nice_train_loader = DataLoader(refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True)
-#     nice_test_loader = DataLoader(refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)
-#     '''end'''
-
-# elif args.dataset == 'LIDC':
-#     '''LIDC data'''
-#     # dataset = LIDC(data_path = args.data_path)
-#     dataset = MyLIDC(args, data_path = args.data_path,transform = transform_train, transform_msk= transform_train_seg)
-
-#     dataset_size = len(dataset)
-#     indices = list(range(dataset_size))
-#     split = int(np.floor(0.2 * dataset_size))
-#     np.random.shuffle(indices)
-#     train_sampler = SubsetRandomSampler(indices[split:])
-#     test_sampler = SubsetRandomSampler(indices[:split])
-
-#     nice_train_loader = DataLoader(dataset, batch_size=args.b, sampler=train_sampler, num_workers=8, pin_memory=True)
-#     nice_test_loader = DataLoader(dataset, batch_size=args.b, sampler=test_sampler, num_workers=8, pin_memory=True)
-#     '''end'''
-
-# elif args.dataset == 'DDTI':
-#     '''REFUGE data'''
-#     refuge_train_dataset = DDTI(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training')
-#     refuge_test_dataset = DDTI(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test')
-
-#     nice_train_loader = DataLoader(refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True)
-#     nice_test_loader = DataLoader(refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)
-#     '''end'''
-
-# elif args.dataset == 'Brat':
-#     '''REFUGE data'''
-#     refuge_train_dataset = Brat(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training')
-#     refuge_test_dataset = Brat(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test')
-
-#     nice_train_loader = DataLoader(refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True)
-#     nice_test_loader = DataLoader(refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)
-#     '''end'''
-nice_train_loader, nice_test_loader = get_dataloader(args)
-
-'''checkpoint path and tensorboard'''
-# iter_per_epoch = len(Glaucoma_training_loader)
-checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW)
-#use tensorboard
-if not os.path.exists(settings.LOG_DIR):
-    os.mkdir(settings.LOG_DIR)
-writer = SummaryWriter(log_dir=os.path.join(
-        settings.LOG_DIR, args.net, settings.TIME_NOW))
-# input_tensor = torch.Tensor(args.b, 3, 256, 256).cuda(device = GPUdevice)
-# writer.add_graph(net, Variable(input_tensor, requires_grad=True))
-
-#create checkpoint folder to save model
-if not os.path.exists(checkpoint_path):
-    os.makedirs(checkpoint_path)
-checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')
-
-'''begain training'''
-best_acc = 0.0
-best_tol = 1e4
-best_dice = 0.0
-
-for epoch in range(settings.EPOCH):
-
-    if epoch and epoch < 5:
-        if args.dataset != 'REFUGE':
-            tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
-            logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
-        else:
-            tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
-            logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {epoch}.')
-
-    net.train()
-    time_start = time.time()
-    loss = function.train_sam(args, net, optimizer, nice_train_loader, epoch, writer, vis = args.vis)
-    logger.info(f'Train loss: {loss} || @ epoch {epoch}.')
-    time_end = time.time()
-    print('time_for_training ', time_end - time_start)
-
-    net.eval()
-    if epoch and epoch % args.val_freq == 0 or epoch == settings.EPOCH-1:
-        if args.dataset != 'REFUGE':
-            tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
-            logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
-        else:
-            tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
-            logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {epoch}.')
-
-        if args.distributed != 'none':
-            sd = net.module.state_dict()
-        else:
-            sd = net.state_dict()
-
-        if edice > best_dice:
-            best_tol = tol
-            is_best = True
-
-            save_checkpoint({
-                'epoch': epoch + 1,
-                'model': args.net,
-                'state_dict': sd,
-                'optimizer': optimizer.state_dict(),
-                'best_tol': best_dice,
-                'path_helper': args.path_helper,
-            }, is_best, args.path_helper['ckpt_path'], filename="best_dice_checkpoint.pth")
-        else:
-            is_best = False
-
-writer.close()
+    logger.info(args)
+
+    nice_train_loader, nice_test_loader = get_dataloader(args)
+
+    '''checkpoint path and tensorboard'''
+    # iter_per_epoch = len(Glaucoma_training_loader)
+    checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW)
+    #use tensorboard
+    if not os.path.exists(settings.LOG_DIR):
+        os.mkdir(settings.LOG_DIR)
+    writer = SummaryWriter(log_dir=os.path.join(
+            settings.LOG_DIR, args.net, settings.TIME_NOW))
+    # input_tensor = torch.Tensor(args.b, 3, 256, 256).cuda(device = GPUdevice)
+    # writer.add_graph(net, Variable(input_tensor, requires_grad=True))
+
+    #create checkpoint folder to save model
+    if not os.path.exists(checkpoint_path):
+        os.makedirs(checkpoint_path)
+    checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')
+
+    '''begin training'''
+    best_acc = 0.0
+    best_tol = 1e4
+    best_dice = 0.0
+
+    for epoch in range(settings.EPOCH):
+
+        if epoch and epoch < 5:
+            if args.dataset != 'REFUGE':
+                tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
+                logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
+            else:
+                tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
+                logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {epoch}.')
+
+        net.train()
+        time_start = time.time()
+        loss = function.train_sam(args, net, optimizer, nice_train_loader, epoch, writer, vis = args.vis)
+        logger.info(f'Train loss: {loss} || @ epoch {epoch}.')
+        time_end = time.time()
+        print('time_for_training ', time_end - time_start)
+
+        net.eval()
+        if epoch and epoch % args.val_freq == 0 or epoch == settings.EPOCH-1:
+            if args.dataset != 'REFUGE':
+                tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
+                logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
+            else:
+                tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
+                logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {epoch}.')
+
+            if args.distributed != 'none':
+                sd = net.module.state_dict()
+            else:
+                sd = net.state_dict()
+
+            if edice > best_dice:
+                best_tol = tol
+                is_best = True
+
+                save_checkpoint({
+                    'epoch': epoch + 1,
+                    'model': args.net,
+                    'state_dict': sd,
+                    'optimizer': optimizer.state_dict(),
+                    'best_tol': best_dice,
+                    'path_helper': args.path_helper,
+                }, is_best, args.path_helper['ckpt_path'], filename="best_dice_checkpoint.pth")
+            else:
+                is_best = False
+
+    writer.close()
+
+
+if __name__ == '__main__':
+    main()
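
Note on the conf/global_settings.py change: TIME_NOW is interpolated into the checkpoint and TensorBoard run directory names, and datetime.now().isoformat() yields values such as 2024-05-01T13:07:42.123456, whose ':' characters are awkward in paths and invalid on Windows. The strftime pattern produces a filesystem-safe name instead. A minimal sketch of the difference (timestamps are illustrative; "%F" is the C99 shorthand for "%Y-%m-%d" and is not guaranteed by every platform's strftime, so "%Y-%m-%d" is the portable spelling):

from datetime import datetime

now = datetime.now()
print(now.isoformat())                       # e.g. 2024-05-01T13:07:42.123456 -> contains ':'
print(now.strftime("%Y-%m-%d_%H-%M-%S.%f"))  # e.g. 2024-05-01_13-07-42.123456 -> safe in directory names
# "%F_%H-%M-%S.%f", as used in the diff, is the same format wherever the platform's
# strftime supports "%F" (glibc, macOS); Python defers to the C library here.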
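
Note on the train.py refactor: beyond tidying, moving the module-level training code into main() behind an if __name__ == '__main__' guard matters because the data loaders run with num_workers=8. Under the spawn start method (the default on Windows and macOS), each worker process re-imports the main module, so unguarded module-level training code would execute again in every worker. A minimal, self-contained sketch of the pattern under that assumption (the toy dataset below is a placeholder, not the repo's get_dataloader):

import torch
from torch.utils.data import DataLoader, TensorDataset

def main():
    ds = TensorDataset(torch.arange(8).float())           # placeholder data
    loader = DataLoader(ds, batch_size=2, num_workers=2)  # workers re-import this module
    for batch in loader:
        print(batch)

if __name__ == '__main__':  # without this guard, spawn-based workers would re-run module-level code
    main()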
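
One observation on the checkpoint bookkeeping in the new main(): best_dice starts at 0.0 and is never advanced, so once edice exceeds 0.0 the edice > best_dice test passes on every later validation epoch and best_dice_checkpoint.pth is rewritten whether or not the score improved (the saved 'best_tol' field likewise stays at best_dice, i.e. 0.0). A minimal sketch of one way to track the running best; update_best is a hypothetical helper, not part of the repo:

def update_best(edice, best_dice):
    """Return (new_best_dice, is_best) for one validation round."""
    if edice > best_dice:
        return edice, True
    return best_dice, False

best_dice = 0.0
for edice in [0.62, 0.60, 0.71, 0.70]:
    best_dice, is_best = update_best(edice, best_dice)
    print(f'dice={edice:.2f} best={best_dice:.2f} is_best={is_best}')
# is_best is True only for 0.62 and 0.71; with the loop as written in the diff it
# would be True for every value above 0.0, because best_dice never changes.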