Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
qishen-ha committed Sep 22, 2020
1 parent f428131 commit 7dc2c3b
Showing 1 changed file with 23 additions and 16 deletions.
39 changes: 23 additions & 16 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@ def main():
dataset_test = MelanomaDataset(df_test, 'test', meta_features, transform=transforms_val)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, num_workers=args.num_workers)

PROBS = []
# load model
models = []
for fold in range(5):

if args.eval == 'best':
model_file = os.path.join(args.model_dir, f'{args.kernel_type}_best_fold{fold}.pth')
model_file =
os.path.join(args.model_dir, f'{args.kernel_type}_best_fold{fold}.pth')
elif args.eval == 'best_20':
model_file = os.path.join(args.model_dir, f'{args.kernel_type}_best_20_fold{fold}.pth')
if args.eval == 'final':
Expand All @@ -96,31 +98,36 @@ def main():
model = torch.nn.DataParallel(model)

model.eval()
models.append(model)

PROBS = []
with torch.no_grad():
for (data) in tqdm(test_loader):

if args.use_meta:
data, meta = data
data, meta = data.to(device), meta.to(device)
probs = torch.zeros((data.shape[0], args.out_dim)).to(device)
# predict
PROBS = []
with torch.no_grad():
for (data) in tqdm(test_loader):
if args.use_meta:
data, meta = data
data, meta = data.to(device), meta.to(device)
probs = torch.zeros((data.shape[0], args.out_dim)).to(device)
for model in models:
for I in range(args.n_test):
l = model(get_trans(data, I), meta)
probs += l.softmax(1)
else:
data = data.to(device)
probs = torch.zeros((data.shape[0], args.out_dim)).to(device)
else:
data = data.to(device)
probs = torch.zeros((data.shape[0], args.out_dim)).to(device)
for model in models:
for I in range(args.n_test):
l = model(get_trans(data, I))
probs += l.softmax(1)

probs /= args.n_test
probs /= args.n_test
probs /= len(models)

PROBS.append(probs.detach().cpu())
PROBS.append(probs.detach().cpu())

PROBS = torch.cat(PROBS).numpy()
PROBS = torch.cat(PROBS).numpy()

# save cvs
df_test['target'] = PROBS[:, mel_idx]
df_test[['image_name', 'target']].to_csv(os.path.join(args.sub_dir, f'sub_{args.kernel_type}_{args.eval}.csv'), index=False)

Expand Down

0 comments on commit 7dc2c3b

Please sign in to comment.