diff --git a/predict.py b/predict.py index 49641a1..4717952 100644 --- a/predict.py +++ b/predict.py @@ -67,11 +67,13 @@ def main(): dataset_test = MelanomaDataset(df_test, 'test', meta_features, transform=transforms_val) test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, num_workers=args.num_workers) - PROBS = [] + # load model + models = [] for fold in range(5): if args.eval == 'best': - model_file = os.path.join(args.model_dir, f'{args.kernel_type}_best_fold{fold}.pth') + model_file = + os.path.join(args.model_dir, f'{args.kernel_type}_best_fold{fold}.pth') elif args.eval == 'best_20': model_file = os.path.join(args.model_dir, f'{args.kernel_type}_best_20_fold{fold}.pth') if args.eval == 'final': @@ -96,31 +98,36 @@ def main(): model = torch.nn.DataParallel(model) model.eval() + models.append(model) - PROBS = [] - with torch.no_grad(): - for (data) in tqdm(test_loader): - - if args.use_meta: - data, meta = data - data, meta = data.to(device), meta.to(device) - probs = torch.zeros((data.shape[0], args.out_dim)).to(device) + # predict + PROBS = [] + with torch.no_grad(): + for (data) in tqdm(test_loader): + if args.use_meta: + data, meta = data + data, meta = data.to(device), meta.to(device) + probs = torch.zeros((data.shape[0], args.out_dim)).to(device) + for model in models: for I in range(args.n_test): l = model(get_trans(data, I), meta) probs += l.softmax(1) - else: - data = data.to(device) - probs = torch.zeros((data.shape[0], args.out_dim)).to(device) + else: + data = data.to(device) + probs = torch.zeros((data.shape[0], args.out_dim)).to(device) + for model in models: for I in range(args.n_test): l = model(get_trans(data, I)) probs += l.softmax(1) - probs /= args.n_test + probs /= args.n_test + probs /= len(models) - PROBS.append(probs.detach().cpu()) + PROBS.append(probs.detach().cpu()) - PROBS = torch.cat(PROBS).numpy() + PROBS = torch.cat(PROBS).numpy() + # save cvs df_test['target'] = PROBS[:, mel_idx] df_test[['image_name', 'target']].to_csv(os.path.join(args.sub_dir, f'sub_{args.kernel_type}_{args.eval}.csv'), index=False)