-
Notifications
You must be signed in to change notification settings - Fork 4
/
test.py
57 lines (48 loc) · 1.81 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import random
import logging
import torch
import numpy as np
from model import CARTON
from dataset import CSQADataset
from utils import Predictor, Inference
# import constants
from constants import *
# Logging: every record is written both to a per-question-type log file under
# args.path_results (opened in 'w' mode, so each run overwrites it) and to the
# console stream handler.
_log_handlers = [
    logging.FileHandler(f'{args.path_results}/test_{args.question_type}.log', mode='w'),
    logging.StreamHandler(),
]
logging.basicConfig(
    format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
    datefmt='%d/%m/%Y %I:%M:%S %p',
    level=logging.INFO,
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)
# Seed every random number generator the script may touch, so that inference
# runs are reproducible for a given args.seed.
random.seed(args.seed)
np.random.seed(args.seed)
# torch.manual_seed must run unconditionally: it seeds the CPU generator, which
# is used even when no GPU is present (and it also seeds CUDA devices).
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    # Explicit seeding of all visible GPUs (includes the current device); kept
    # for clarity and for older PyTorch versions.
    torch.cuda.manual_seed_all(args.seed)
def main():
    """Load the CSQA inference data and a trained CARTON checkpoint, then run
    ``Inference().construct_actions`` over the inference data.

    Reads the module-level ``args``, ``DEVICE`` and ``ROOT_PATH`` names, which
    this file obtains via ``from constants import *``. Sets
    ``args.start_epoch`` from the checkpoint's ``'epoch'`` entry.
    """
    # load data
    dataset = CSQADataset()
    vocabs = dataset.get_vocabs()
    inference_data = dataset.get_inference_data()
    logger.info(f'Inference question type: {args.question_type}')
    logger.info('Inference data prepared')
    logger.info(f"Num of inference data: {len(inference_data)}")

    # load model
    model = CARTON(vocabs).to(DEVICE)
    logger.info(f"=> loading checkpoint '{args.model_path}'")
    # Map all stored tensors directly onto DEVICE. Without map_location, tensors
    # are restored onto the device they were saved from, which fails when the
    # checkpoint came from a GPU index that does not exist on this machine.
    checkpoint = torch.load(
        f'{ROOT_PATH}/{args.model_path}',
        encoding='latin1',
        map_location=DEVICE,
    )
    args.start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    # Switch to inference mode so dropout/normalisation layers (if any) behave
    # deterministically; harmless if Predictor also calls eval().
    model.eval()
    logger.info(f"=> loaded checkpoint '{args.model_path}' (epoch {checkpoint['epoch']})")

    # construct actions
    predictor = Predictor(model, vocabs)
    Inference().construct_actions(inference_data, predictor)


if __name__ == '__main__':
    main()