Merge pull request #110 from dzenanz/windowsFixes
Windows-related fixes
WuJunde authored May 21, 2024
2 parents ac6d4f8 + 5ca0eed commit 721c4f0
Showing 2 changed files with 110 additions and 185 deletions.
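In short, the two diffs below appear to address the same Windows limitation from two sides: the run timestamp is switched to a filename-safe format (Windows path components cannot contain ':'), and the module-level training code in train.py is moved into a main() function behind an if __name__ == '__main__' guard, which spawn-based multiprocessing on Windows requires.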
conf/global_settings.py: 2 changes (1 addition, 1 deletion)
@@ -37,7 +37,7 @@
#INIT_LR = 0.1

#time of we run the script
-TIME_NOW = datetime.now().isoformat()
+TIME_NOW = datetime.now().strftime("%F_%H-%M-%S.%f")

#tensorboard log dir
LOG_DIR = 'runs'
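For context on the one-line change above: datetime.now().isoformat() yields strings such as 2024-05-21T13:45:07.123456, and the ':' characters make any checkpoint or TensorBoard directory named after TIME_NOW impossible to create on Windows. Below is a minimal sketch of the difference, using an illustrative fixed datetime value; %F is shorthand for %Y-%m-%d on platforms whose strftime supports it, and spelling it out is the more portable equivalent.

```python
from datetime import datetime

now = datetime(2024, 5, 21, 13, 45, 7, 123456)   # illustrative fixed value

# isoformat() contains ':' characters, which are not allowed in Windows
# file and directory names, so os.makedirs() on a path built from it fails.
print(now.isoformat())                        # 2024-05-21T13:45:07.123456

# Filename-safe variant; "%Y-%m-%d" is the portable spelling of "%F".
print(now.strftime("%Y-%m-%d_%H-%M-%S.%f"))   # 2024-05-21_13-45-07.123456
```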
train.py: 293 changes (109 additions, 184 deletions)
@@ -35,188 +35,113 @@
from dataset import *
from utils import *

args = cfg.parse_args()

GPUdevice = torch.device('cuda', args.gpu_device)

net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution = args.distributed)
if args.pretrain:
weights = torch.load(args.pretrain)
net.load_state_dict(weights,strict=False)

optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) #learning rate decay

'''load pretrained model'''
if args.weights != 0:
print(f'=> resuming from {args.weights}')
assert os.path.exists(args.weights)
checkpoint_file = os.path.join(args.weights)
assert os.path.exists(checkpoint_file)
loc = 'cuda:{}'.format(args.gpu_device)
checkpoint = torch.load(checkpoint_file, map_location=loc)
start_epoch = checkpoint['epoch']
best_tol = checkpoint['best_tol']

net.load_state_dict(checkpoint['state_dict'],strict=False)
# optimizer.load_state_dict(checkpoint['optimizer'], strict=False)

args.path_helper = checkpoint['path_helper']
def main():

args = cfg.parse_args()

GPUdevice = torch.device('cuda', args.gpu_device)

net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution = args.distributed)
if args.pretrain:
weights = torch.load(args.pretrain)
net.load_state_dict(weights,strict=False)

optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) #learning rate decay

'''load pretrained model'''
if args.weights != 0:
print(f'=> resuming from {args.weights}')
assert os.path.exists(args.weights)
checkpoint_file = os.path.join(args.weights)
assert os.path.exists(checkpoint_file)
loc = 'cuda:{}'.format(args.gpu_device)
checkpoint = torch.load(checkpoint_file, map_location=loc)
start_epoch = checkpoint['epoch']
best_tol = checkpoint['best_tol']

net.load_state_dict(checkpoint['state_dict'],strict=False)
# optimizer.load_state_dict(checkpoint['optimizer'], strict=False)

args.path_helper = checkpoint['path_helper']
logger = create_logger(args.path_helper['log_path'])
print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')

args.path_helper = set_log_dir('logs', args.exp_name)
logger = create_logger(args.path_helper['log_path'])
print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')

args.path_helper = set_log_dir('logs', args.exp_name)
logger = create_logger(args.path_helper['log_path'])
logger.info(args)


# '''segmentation data'''
# transform_train = transforms.Compose([
# transforms.Resize((args.image_size,args.image_size)),
# transforms.ToTensor(),
# ])

# transform_train_seg = transforms.Compose([
# transforms.Resize((args.out_size,args.out_size)),
# transforms.ToTensor(),
# ])

# transform_test = transforms.Compose([
# transforms.Resize((args.image_size, args.image_size)),
# transforms.ToTensor(),
# ])

# transform_test_seg = transforms.Compose([
# transforms.Resize((args.out_size,args.out_size)),
# transforms.ToTensor(),
# ])

# transform_3d_seg = transforms.Compose([
# transforms.ToTensor(),
# ])

# if args.dataset == 'isic':
# '''isic data'''
# isic_train_dataset = ISIC2016(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training')
# isic_test_dataset = ISIC2016(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test')

# nice_train_loader = DataLoader(isic_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True)
# nice_test_loader = DataLoader(isic_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)
# '''end'''

# elif args.dataset == 'decathlon':
# nice_train_loader, nice_test_loader, transform_train, transform_val, train_list, val_list = get_decath_loader(args)


# elif args.dataset == 'REFUGE':
# '''REFUGE data'''
# refuge_train_dataset = REFUGE(args, args.data_path, transform = transform_3d_seg, transform_msk= transform_3d_seg, mode = 'Training')
# refuge_test_dataset = REFUGE(args, args.data_path, transform = transform_3d_seg, transform_msk= transform_3d_seg, mode = 'Test')

# nice_train_loader = DataLoader(refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True)
# nice_test_loader = DataLoader(refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)
# '''end'''

# elif args.dataset == 'LIDC':
# '''LIDC data'''
# # dataset = LIDC(data_path = args.data_path)
# dataset = MyLIDC(args, data_path = args.data_path,transform = transform_train, transform_msk= transform_train_seg)

# dataset_size = len(dataset)
# indices = list(range(dataset_size))
# split = int(np.floor(0.2 * dataset_size))
# np.random.shuffle(indices)
# train_sampler = SubsetRandomSampler(indices[split:])
# test_sampler = SubsetRandomSampler(indices[:split])

# nice_train_loader = DataLoader(dataset, batch_size=args.b, sampler=train_sampler, num_workers=8, pin_memory=True)
# nice_test_loader = DataLoader(dataset, batch_size=args.b, sampler=test_sampler, num_workers=8, pin_memory=True)
# '''end'''

# elif args.dataset == 'DDTI':
# '''REFUGE data'''
# refuge_train_dataset = DDTI(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training')
# refuge_test_dataset = DDTI(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test')

# nice_train_loader = DataLoader(refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True)
# nice_test_loader = DataLoader(refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)
# '''end'''

# elif args.dataset == 'Brat':
# '''REFUGE data'''
# refuge_train_dataset = Brat(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training')
# refuge_test_dataset = Brat(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test')

# nice_train_loader = DataLoader(refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True)
# nice_test_loader = DataLoader(refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)
# '''end'''
nice_train_loader, nice_test_loader = get_dataloader(args)

'''checkpoint path and tensorboard'''
# iter_per_epoch = len(Glaucoma_training_loader)
checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW)
#use tensorboard
if not os.path.exists(settings.LOG_DIR):
os.mkdir(settings.LOG_DIR)
writer = SummaryWriter(log_dir=os.path.join(
settings.LOG_DIR, args.net, settings.TIME_NOW))
# input_tensor = torch.Tensor(args.b, 3, 256, 256).cuda(device = GPUdevice)
# writer.add_graph(net, Variable(input_tensor, requires_grad=True))

#create checkpoint folder to save model
if not os.path.exists(checkpoint_path):
os.makedirs(checkpoint_path)
checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')

'''begain training'''
best_acc = 0.0
best_tol = 1e4
best_dice = 0.0

for epoch in range(settings.EPOCH):

if epoch and epoch < 5:
if args.dataset != 'REFUGE':
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
else:
tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {epoch}.')

net.train()
time_start = time.time()
loss = function.train_sam(args, net, optimizer, nice_train_loader, epoch, writer, vis = args.vis)
logger.info(f'Train loss: {loss} || @ epoch {epoch}.')
time_end = time.time()
print('time_for_training ', time_end - time_start)

net.eval()
if epoch and epoch % args.val_freq == 0 or epoch == settings.EPOCH-1:
if args.dataset != 'REFUGE':
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
else:
tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {epoch}.')

if args.distributed != 'none':
sd = net.module.state_dict()
else:
sd = net.state_dict()

if edice > best_dice:
best_tol = tol
is_best = True

save_checkpoint({
'epoch': epoch + 1,
'model': args.net,
'state_dict': sd,
'optimizer': optimizer.state_dict(),
'best_tol': best_dice,
'path_helper': args.path_helper,
}, is_best, args.path_helper['ckpt_path'], filename="best_dice_checkpoint.pth")
else:
is_best = False

writer.close()
logger.info(args)

nice_train_loader, nice_test_loader = get_dataloader(args)

'''checkpoint path and tensorboard'''
# iter_per_epoch = len(Glaucoma_training_loader)
checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW)
#use tensorboard
if not os.path.exists(settings.LOG_DIR):
os.mkdir(settings.LOG_DIR)
writer = SummaryWriter(log_dir=os.path.join(
settings.LOG_DIR, args.net, settings.TIME_NOW))
# input_tensor = torch.Tensor(args.b, 3, 256, 256).cuda(device = GPUdevice)
# writer.add_graph(net, Variable(input_tensor, requires_grad=True))

#create checkpoint folder to save model
if not os.path.exists(checkpoint_path):
os.makedirs(checkpoint_path)
checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')

'''begain training'''
best_acc = 0.0
best_tol = 1e4
best_dice = 0.0

for epoch in range(settings.EPOCH):

if epoch and epoch < 5:
if args.dataset != 'REFUGE':
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
else:
tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {epoch}.')

net.train()
time_start = time.time()
loss = function.train_sam(args, net, optimizer, nice_train_loader, epoch, writer, vis = args.vis)
logger.info(f'Train loss: {loss} || @ epoch {epoch}.')
time_end = time.time()
print('time_for_training ', time_end - time_start)

net.eval()
if epoch and epoch % args.val_freq == 0 or epoch == settings.EPOCH-1:
if args.dataset != 'REFUGE':
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
else:
tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {epoch}.')

if args.distributed != 'none':
sd = net.module.state_dict()
else:
sd = net.state_dict()

if edice > best_dice:
best_tol = tol
is_best = True

save_checkpoint({
'epoch': epoch + 1,
'model': args.net,
'state_dict': sd,
'optimizer': optimizer.state_dict(),
'best_tol': best_dice,
'path_helper': args.path_helper,
}, is_best, args.path_helper['ckpt_path'], filename="best_dice_checkpoint.pth")
else:
is_best = False

writer.close()


if __name__ == '__main__':
main()
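The larger train.py change wraps the previously module-level training code in main() and adds the if __name__ == '__main__' entry point shown above. Windows has no fork(), so multiprocessing uses the spawn start method: every DataLoader worker process re-imports the training script, and training code left at module level would either be re-executed in each worker or abort with a RuntimeError pointing at the missing guard. Below is a minimal, self-contained sketch of the required pattern, using a toy dataset rather than this repository's own loaders.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def main():
    # On Windows, DataLoader workers are started with the "spawn" method and
    # re-import this module, so everything except definitions must stay out
    # of module scope and behind the __main__ guard below.
    dataset = TensorDataset(torch.randn(64, 3, 32, 32),
                            torch.randint(0, 2, (64,)))
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)

    for images, labels in loader:
        pass  # a real training step would go here


if __name__ == '__main__':
    main()  # without this guard, num_workers > 0 fails on Windows
```

The same constraint applies to any DataLoader created with num_workers > 0, which the commented-out dataset code removed by this commit did with num_workers=8.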
