forked from danieljf24/dual_encoding

tester_msrvtt.py
122 lines (102 loc) · 5.46 KB
# -*- coding: UTF-8 -*-
from __future__ import print_function
import argparse
import json
import logging
import os
import pickle
import sys
import numpy as np
import torch
from tensorboardX import SummaryWriter
import evaluation_vatex
from util.msrvtt_dataloader import Dataset2BertRes, collate_data
import util.data_provider as data
from basic.bigfile import BigFile
from basic.constant import ROOT_PATH
from basic.util import AverageMeter, LogCollector, read_dict
from model_part.model_attention import get_model
from util.text2vec import get_text_encoder
from util.vatex_dataloader import Dataset2BertI3d
from util.vocab import Vocabulary
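
# Note: several of the imports above appear unused in this script but may be
# needed elsewhere (e.g. SummaryWriter for the commented-out embedding
# visualization below, Vocabulary for unpickling saved checkpoints); they are
# kept as-is.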

def parse_args():
    # Hyper parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--runpath', type=str, default='/home/fengkai/dataset/', help='Root path to the datasets.')
    parser.add_argument('--overwrite', type=int, default=0, choices=[0, 1], help='Overwrite existing files. (default: 0)')
    parser.add_argument('--log_step', default=10, type=int, help='Number of steps between log prints.')
    parser.add_argument('--batch_size', default=128, type=int, help='Size of an evaluation mini-batch.')
    parser.add_argument('--workers', default=5, type=int, help='Number of data loader workers.')
    parser.add_argument('--logger_name', default='/home/fengkai/PycharmProjects/dual_encoding/result/fengkai_msrvtt/dual_encoding_concate_full_dp_0.2_measure_cosine/vocab_word_vocab_5_word_dim_768_text_rnn_size_1024_text_norm_True_kernel_sizes_2-3-4_num_512/visual_feat_dim_1024_visual_rnn_size_1024_visual_norm_True_kernel_sizes_2-3-4-5_num_512/mapping_text_0-2048_img_0-2048/loss_func_mrl_margin_0.2_direction_all_max_violation_False_cost_style_sum/optimizer_adam_lr_0.0001_decay_0.99_grad_clip_2.0_val_metric_recall/msrvtt_attention_1', help='Path where the model and TensorBoard log are saved.')
    parser.add_argument('--checkpoint_name', default='model_best.pth.tar', type=str, help='Name of the checkpoint file (default: model_best.pth.tar)')
    parser.add_argument('--n_caption', type=int, default=20, help='Number of captions per video (default: 20).')
    args = parser.parse_args()
    return args

def load_config(config_path):
    # Execute a Python-style config file and return its `config` object.
    # (Not called in this tester; kept for compatibility with other scripts.)
    variables = {}
    exec(compile(open(config_path, "rb").read(), config_path, 'exec'), variables)
    return variables['config']
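
# For reference, a minimal sketch of what the evaluation below computes. These
# helpers are illustrative only and are NOT called by this script; the exact
# matrix layout and tie-breaking inside evaluation_vatex may differ.
def _sketch_cosine_error(video_embs, cap_embs):
    # Negated cosine similarity between L2-normalized embeddings, assuming
    # numpy arrays shaped (n_videos, dim) and (n_captions, dim).
    # Rows of the result index captions, columns index videos; lower = better.
    v = video_embs / (np.linalg.norm(video_embs, axis=1, keepdims=True) + 1e-12)
    c = cap_embs / (np.linalg.norm(cap_embs, axis=1, keepdims=True) + 1e-12)
    return -np.dot(c, v.T)


def _sketch_t2i_recall(errors, n_caption=20):
    # Rank-based text-to-video recall from an error matrix with one row per
    # caption and one column per video, assuming captions are grouped
    # n_caption per video, in video order.
    ranks = np.zeros(errors.shape[0])
    for i in range(errors.shape[0]):
        gt = i // n_caption  # index of the ground-truth video for caption i
        order = np.argsort(errors[i])  # ascending error = best match first
        ranks[i] = np.where(order == gt)[0][0]
    r1 = 100.0 * np.mean(ranks < 1)
    r5 = 100.0 * np.mean(ranks < 5)
    r10 = 100.0 * np.mean(ranks < 10)
    medr = np.median(ranks) + 1
    meanr = ranks.mean() + 1
    return r1, r5, r10, medr, meanr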

def main():
    opt = parse_args()
    print(json.dumps(vars(opt), indent=2))

    rootpath = opt.runpath
    n_caption = opt.n_caption
    resume = os.path.join(opt.logger_name, opt.checkpoint_name)

    if not os.path.exists(resume):
        logging.info(resume + ' does not exist.')
        sys.exit(0)

    # Load the checkpoint.
    checkpoint = torch.load(resume)
    start_epoch = checkpoint['epoch']
    best_rsum = checkpoint['best_rsum']
    print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
          .format(resume, start_epoch, best_rsum))
    options = checkpoint['opt']
    if not hasattr(options, 'concate'):
        setattr(options, "concate", "full")

    # Feature and caption file paths.
    visual_feat_path = os.path.join(rootpath, 'msrvtt/msrvtt10ktest/FeatureData')
    caption_files = os.path.join(rootpath, 'msrvtt/msrvtt10ktest/TextData/bert_text')
    visual_feat_test = BigFile(visual_feat_path)
    video_frames_test = read_dict(os.path.join(visual_feat_path, 'video2frames.txt'))

    # Construct the model and restore its weights.
    model = get_model(options.model)(options)
    model.load_state_dict(checkpoint['model'])
    model.Eiters = checkpoint['Eiters']
    model.val_start()

    # Set up the test data loader.
    dset = Dataset2BertRes(caption_files, visual_feat_test, video_frames_test, videoEmbed_num=32)
    data_loaders_val = torch.utils.data.DataLoader(dataset=dset,
                                                   batch_size=opt.batch_size,
                                                   shuffle=False,
                                                   pin_memory=True,
                                                   num_workers=opt.workers,
                                                   collate_fn=collate_data)

    video_embs, cap_embs, video_ids = evaluation_vatex.encode_data(model, data_loaders_val, opt.log_step, logging.info)

    # Visualization analysis of the embeddings (kept for reference):
    # tensor_show = torch.cat((video_embs.data, torch.ones(len(video_embs), 1)), 1)
    # with SummaryWriter(log_dir='./results', comment='embedding_show') as writer:
    #     writer.add_embedding(
    #         video_embs.data,
    #         label_img=cap_embs.data,
    #         global_step=1)

    c2i_all_errors = evaluation_vatex.cal_error(video_embs, cap_embs, options.measure)

    # Text-to-video retrieval.
    (r1i, r5i, r10i, medri, meanri) = evaluation_vatex.t2i(c2i_all_errors, n_caption=n_caption)
    t2i_map_score = evaluation_vatex.t2i_map(c2i_all_errors, n_caption=n_caption)

    # Video-to-text retrieval.
    (r1, r5, r10, medr, meanr) = evaluation_vatex.i2t(c2i_all_errors, n_caption=n_caption)
    i2t_map_score = evaluation_vatex.i2t_map(c2i_all_errors, n_caption=n_caption)

    print(" * Text to Video:")
    print(" * r_1_5_10, medr, meanr: {}".format([round(r1i, 1), round(r5i, 1), round(r10i, 1), round(medri, 1), round(meanri, 1)]))
    print(" * recall sum: {}".format(round(r1i + r5i + r10i, 1)))
    print(" * mAP: {}".format(round(t2i_map_score, 3)))
    print(" * " + '-' * 10)

    print(" * Video to Text:")
    print(" * r_1_5_10, medr, meanr: {}".format([round(r1, 1), round(r5, 1), round(r10, 1), round(medr, 1), round(meanr, 1)]))
    print(" * recall sum: {}".format(round(r1 + r5 + r10, 1)))
    print(" * mAP: {}".format(round(i2t_map_score, 3)))
    print(" * " + '-' * 10)

if __name__ == '__main__':
    main()
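
# Example invocation (illustrative; adjust --runpath and --logger_name to your
# own dataset root and trained-model directory):
#   python tester_msrvtt.py --runpath /home/fengkai/dataset/ --n_caption 20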