-
Notifications
You must be signed in to change notification settings - Fork 328
/
Copy pathmfcc_solver.py
145 lines (121 loc) · 3.98 KB
/
mfcc_solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
Offline digit-recognition benchmark: classifies held-out recordings by DTW distance between their MFCCs
and up to CONTROL_SIZE reference recordings per digit, then prints the resulting accuracy.
"""
import librosa
import os
import pickle
import random
import time
import matplotlib.pyplot as plt
from fastdtw import fastdtw
from operator import itemgetter
# Pickle cache holding the MFCC matrix of every recording, keyed by file path.
MFCC_PATH = "mfccs_all.pickle"

# Maximum number of reference (template) recordings collected per digit.
CONTROL_SIZE = 250

# Number of held-out recordings to classify before the run stops.
RUN_LIMIT = 10
def detect_leading_silence(sound, silence_threshold=-40.0, chunk_size=10):
    """Return the number of milliseconds of leading silence in *sound*.

    The clip is scanned in ``chunk_size``-ms slices. The scan stops at the
    first slice whose ``dBFS`` is not below ``silence_threshold``.

    Args:
        sound: object supporting ms slicing, ``len()`` (length in ms) and a
            ``dBFS`` attribute on slices. Presumably a pydub ``AudioSegment``;
            the type is not visible in this file, so confirm.
        silence_threshold: dBFS level below which a chunk counts as silent.
        chunk_size: scan step in ms; must be positive.

    Returns:
        Leading silent duration in ms, capped at ``len(sound)`` so an
        all-silent clip reports its full length.

    Raises:
        ValueError: if ``chunk_size`` is not positive (the scan would never
            advance).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    trim_ms = 0  # ms
    # The length bound is what terminates an all-silent clip. A slice past
    # the end is empty, and an empty segment reports -inf dBFS in pydub,
    # which always counts as "silent".
    while trim_ms < len(sound) and sound[trim_ms:trim_ms + chunk_size].dBFS < silence_threshold:
        trim_ms += chunk_size
    return min(trim_ms, len(sound))
def plot(f1, f2):
    """Overlay two series on one matplotlib figure and display it.

    *f1* is drawn in red and *f2* in blue, both at 50% opacity, so
    overlapping regions stay visible.
    """
    for series, colour in ((f1, 'red'), (f2, 'blue')):
        plt.plot(series, color=colour, alpha=0.5)
    plt.show()
def mfcc(f):
    """Load the audio file at path *f* and return its MFCCs, one row per frame.

    ``librosa.feature.mfcc`` returns an ``(n_mfcc, frames)`` array for mono
    input (``librosa.load`` downmixes to mono by default). The transpose puts
    frames on the first axis, presumably the sequence axis the fastdtw
    comparisons in this script align along.
    """
    y, sr = librosa.load(f)
    # y/sr must be passed by keyword: librosa >= 0.10 made these parameters
    # keyword-only, so the positional call raises TypeError there.
    return librosa.feature.mfcc(y=y, sr=sr).T
def build_mfccs(fs):
    """Compute MFCCs for every complete sample directory and cache them.

    A directory in *fs* is used only when it contains an ``oracle`` file and
    exactly ten entries whose names contain ``"_0"``. Each of those files is
    run through ``mfcc`` and stored under the key ``path + "/" + name``.
    That is the same key format used when the cache is read back elsewhere
    in this script.

    The resulting dict is pickled to ``MFCC_PATH`` and also returned.
    """
    mfccs = {}
    # print() call form: the Python 2 print statement is a SyntaxError on
    # Python 3, and this single-argument form behaves identically on both.
    print("No MFCCs found, building...")
    for path in fs:
        if not os.path.exists(path + "/oracle"):
            continue
        digits = [f for f in os.listdir(path) if "_0" in f]
        if len(digits) != 10:
            continue  # incomplete sample directory
        for name in digits:
            key = path + "/" + name
            mfccs[key] = mfcc(key)
    with open(MFCC_PATH, 'wb') as handle:
        pickle.dump(mfccs, handle)
    return mfccs
def build_controls(fs, mfccs=None, control_size=None):
    """Pick up to ``control_size`` reference MFCC matrices per digit.

    Directories in *fs* are visited in the given order. A directory is used
    when it has an ``oracle`` file and exactly ten files whose names contain
    ``"_0"``. The i-th of those files, in sorted name order, is labelled with
    ``int()`` of the i-th oracle character.

    Args:
        fs: sample directory paths to scan.
        mfccs: dict mapping ``path + "/" + filename`` to an MFCC matrix.
            Defaults to the module-level ``all_mfccs``.
        control_size: per-digit cap on reference matrices. Defaults to
            ``CONTROL_SIZE``.

    Returns:
        ``(control, control_files)``, where:
        - ``control`` maps each int digit seen to its list of reference MFCC
          matrices (the list may be empty when the cap is 0);
        - ``control_files`` maps the key of every file chosen as a reference
          to ``True``.

    Raises:
        IndexError: if an oracle has fewer than ten characters.
        ValueError: if an oracle character is not a digit.
        KeyError: if a chosen file's key is missing from *mfccs*.
    """
    if mfccs is None:
        mfccs = all_mfccs
    if control_size is None:
        control_size = CONTROL_SIZE
    control = {}
    control_files = {}
    for path in fs:
        oracle_path = path + "/oracle"
        if not os.path.exists(oracle_path):
            continue
        # Read inside ``with`` so the handle is closed promptly.
        with open(oracle_path) as oracle_file:
            oracle_nums = list(oracle_file.read())
        # os.listdir order is arbitrary. Sorting is what lines the i-th file
        # up with the i-th oracle character. Assumes zero-padded *_00 ... *_09
        # style names; confirm against the data set.
        digits = sorted(f for f in os.listdir(path) if "_0" in f)
        if len(digits) != 10:
            continue
        for i, name in enumerate(digits):
            num = int(oracle_nums[i])
            templates = control.setdefault(num, [])
            if len(templates) < control_size:
                key = path + "/" + name
                templates.append(mfccs[key])
                control_files[key] = True
    return control, control_files
# ---------------------------------------------------------------------------
# Script body: collect sample directories, load or build MFCCs, pick the
# reference templates, then classify held-out recordings by the smallest
# DTW distance to any template of each digit.
# ---------------------------------------------------------------------------


def _dtw_distance(a, b):
    """Return the smaller fastdtw distance over both argument orders.

    Both orders are tried, presumably because fastdtw is an approximation
    and need not be symmetric.
    """
    return min(fastdtw(a, b)[0], fastdtw(b, a)[0])


def _nearest_digit(m, control):
    """Return the key of *control* whose closest template is nearest to *m*.

    The prediction is the digit KEY itself, not its position in the dict.
    That keeps the result correct regardless of dict iteration order or
    missing digits. Ties go to the first digit in iteration order.
    """
    return min(
        control,
        key=lambda digit: min(_dtw_distance(cm, m) for cm in control[digit]),
    )


fs = []
# os.path.join keeps nested paths correct; plain ``root + f`` drops the
# separator for every directory below the top level of the walk.
for root, dirs, files in os.walk("data/"):
    fs.extend(os.path.join(root, d) for d in dirs)
random.shuffle(fs)

if os.path.exists(MFCC_PATH):
    print("Found saved MFCCs, reading")
    # NOTE: pickle.load can execute code embedded in the file; only load a
    # cache this script produced itself.
    with open(MFCC_PATH, 'rb') as handle:
        all_mfccs = pickle.load(handle)
else:
    all_mfccs = build_mfccs(fs)

print("Processing...")
control, control_files = build_controls(fs)

correct = 0
wrong = 0
total = 0
accuracy = {}  # oracle character -> number classified correctly
totals = {}    # oracle character -> number classified
to_break = False
time1 = time.time()
for path in fs:
    print(total)
    if to_break:
        break
    oracle_path = path + "/oracle"
    if not os.path.exists(oracle_path):
        continue
    with open(oracle_path) as oracle_file:
        oracle_nums = list(oracle_file.read())
    # os.listdir order is arbitrary. Sorting lines the i-th file up with the
    # i-th oracle character. Assumes zero-padded *_00 ... *_09 style names;
    # confirm against the data set.
    digits = sorted(f for f in os.listdir(path) if "_0" in f)
    if len(digits) != 10:
        continue
    for i, name in enumerate(digits):
        if total >= RUN_LIMIT:
            to_break = True
            break
        key = path + "/" + name
        if key in control_files:
            continue  # used as a reference template; never test on it
        m = all_mfccs[key]
        label = oracle_nums[i]
        accuracy.setdefault(label, 0)
        totals.setdefault(label, 0)
        best = _nearest_digit(m, control)
        if int(label) == best:
            correct += 1
            accuracy[label] += 1
        else:
            wrong += 1
        totals[label] += 1
        total += 1
time2 = time.time()

# time.time() differences are in seconds; convert so the "ms" label is true.
print('took %0.3f ms' % ((time2 - time1) * 1000.0))
print(accuracy)
print(totals)
# Fall back to 1 so an empty run reports 0 rates instead of raising
# ZeroDivisionError.
denominator = float(total) if total else 1.0
print("CORRECT: %d (%f)" % (correct, correct / denominator))
print("INCORRECT: %d (%f)" % (wrong, wrong / denominator))
print("TOTAL: %d" % (total))
print("%d" % CONTROL_SIZE)