-
Notifications
You must be signed in to change notification settings - Fork 27
/
eval_att.py
80 lines (71 loc) · 2.95 KB
/
eval_att.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import argparse
import logging
import math
import os
import time
import editdistance
import kaldi_io
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
from seq2seq.seq2seq import Seq2seq
from DataLoader import SequentialLoader, TokenAcc, rephone
# Command-line interface. The description previously mentioned MXNet, but this
# script loads and runs a PyTorch model.
parser = argparse.ArgumentParser(description='PyTorch Seq2seq attention acoustic model decoding on TIMIT.')
parser.add_argument('model', help='trained model filename')
parser.add_argument('--beam', type=int, default=0, help='apply beam search, beam width')
parser.add_argument('--ctc', default=False, action='store_true', help='decode CTC acoustic model')
parser.add_argument('--bi', default=False, action='store_true', help='bidirectional LSTM')
parser.add_argument('--dataset', default='test', help='decoding data set')
parser.add_argument('--out', type=str, default='', help='decoded result output dir')
args = parser.parse_args()

# Resolve the log file path (kept in `logdir` for compatibility with the rest
# of the script):
#   * no --out: write decode.log next to the model. os.path.join is used so a
#     model path without a directory component gives "decode.log" in the
#     current directory rather than "/decode.log" at the filesystem root.
#   * --out names an existing directory: write decode.log inside it.
#   * otherwise: --out is used verbatim as the log file name, as before.
if not args.out:
    logdir = os.path.join(os.path.dirname(args.model), 'decode.log')
elif os.path.isdir(args.out):
    logdir = os.path.join(args.out, 'decode.log')
else:
    logdir = args.out
logging.basicConfig(format='%(asctime)s: %(message)s', datefmt="%H:%M:%S", filename=logdir, level=logging.INFO)
# Load model
# NOTE(review): the positional constructor arguments are defined by Seq2seq,
# which is not visible here; they must match the trained checkpoint.
model = Seq2seq(123, 63, 250, 3, 0.5, bidirectional=args.bi)
# torch.load unpickles the file, so only load checkpoints from trusted sources.
model.load_state_dict(torch.load(args.model, map_location='cpu'))
# Modules are created in training mode; switch to evaluation mode so that
# train-only layers (presumably dropout, given the 0.5 argument -- confirm)
# are disabled while decoding.
model.eval()

# data set
# Kaldi read specifier: a shell pipeline that streams feature matrices for the
# chosen data set (copy-feats -> apply-cmvn -> add-deltas -> nnet-forward).
# The adjacent literals concatenate to exactly the same string as before.
feat = (
    'ark:copy-feats scp:data/{0}/feats.scp ark:- | '
    'apply-cmvn --utt2spk=ark:data/{0}/utt2spk scp:data/{0}/cmvn.scp ark:- ark:- |'
    'add-deltas --delta-order=2 ark:- ark:- | '
    'nnet-forward data/final.feature_transform ark:- ark:- |'
).format(args.dataset)
# Reference transcripts: map each utterance id (first column of the text file)
# to its list of label tokens (remaining columns).
with open(os.path.join('data', args.dataset, 'text'), 'r') as f:
    label = {}
    for line in f:
        fields = line.split()
        if not fields:
            # Skip blank lines instead of failing on fields[0].
            continue
        label[fields[0]] = fields[1:]
# Phone map
# Build a mapping from each phone in the first column of the map file to the
# phone in its third column. rephone[0] is the symbol that distance() treats
# as blank and drops, so phones with fewer than three columns are mapped to
# it, and it maps to itself.
with open('conf/phones.60-48-39.map', 'r') as f:
    pmap = {rephone[0]: rephone[0]}
    for line in f:
        fields = line.split()
        if not fields:
            # Tolerate blank lines instead of failing on fields[0].
            continue
        pmap[fields[0]] = fields[2] if len(fields) >= 3 else rephone[0]
# Kept from the original script: dump the resulting map to stdout.
print(pmap)
def distance(y, t, blank=rephone[0]):
    """Normalise two label sequences and return their edit distance.

    Both sequences are reduced the same way: an element is kept only if it is
    not ``blank`` and differs from the element immediately before it (the
    first element is compared against ``blank``).

    Args:
        y: hypothesis label sequence.
        t: reference label sequence.
        blank: label to drop from both sequences.

    Returns:
        A tuple ``(y, t, dist)`` with the reduced hypothesis, the reduced
        reference and ``editdistance.eval`` of the two reduced sequences.
    """
    def collapse(seq):
        items = list(seq)
        # Pair every element with its predecessor; the first one is paired
        # with `blank`, mirroring the initial "previous" value.
        predecessors = [blank] + items[:-1]
        return [
            cur
            for before, cur in zip(predecessors, items)
            if cur != blank and cur != before
        ]

    hyp = collapse(y)
    ref = collapse(t)
    return hyp, ref, editdistance.eval(hyp, ref)
def decode():
    """Decode every utterance of the selected data set and log the error rate.

    For each (utterance id, feature matrix) pair read from the Kaldi pipeline
    in ``feat``, the model is decoded with beam search when ``args.beam > 0``
    and greedily otherwise. The hypothesis indices are converted to phones via
    ``rephone`` and ``pmap``, the reference from ``label`` is mapped via
    ``pmap``, and both are scored with ``distance``. The reference, the
    hypothesis and the score returned by the decoder are logged per utterance,
    followed by the overall error rate (total edit distance divided by total
    reference length).

    Raises:
        KeyError: if an utterance has no entry in ``label`` or a phone is
            missing from ``pmap``.
    """
    logging.info('Decoding Seq-Att model:')
    err = cnt = 0
    # torch.no_grad() replaces Variable(..., volatile=True): the volatile flag
    # no longer disables autograd in PyTorch >= 0.4, so without this every
    # forward pass would build a graph that is never used.
    with torch.no_grad():
        for k, v in kaldi_io.read_mat_ark(feat):
            # Add a leading axis of size one in front of the feature matrix.
            xs = torch.FloatTensor(v[None, ...])
            if args.beam > 0:
                y, nll = model.beam_search(xs, args.beam)
            else:
                y, nll = model.greedy_decode(xs)
            y = [pmap[rephone[i]] for i in y]
            t = [pmap[i] for i in label[k]]
            y, t, e = distance(y, t)
            err += e
            cnt += len(t)
            logging.info('[{}]: {}'.format(k, ' '.join(t)))
            logging.info('[{}]: {}\nlog-likelihood: {:.2f}\n'.format(k, ' '.join(y), nll))
    if cnt == 0:
        # Nothing was scored (no utterances or only empty references); avoid
        # dividing by zero when computing the error rate.
        logging.warning('No reference tokens were scored; PER is undefined.')
        return
    logging.info('{} set Seq-Att PER {:.2f}%\n'.format(args.dataset.capitalize(), 100*err/cnt))
# Script entry point: run decoding over the selected data set.
decode()