cppmatcher.py
# This program needs to be run inside a conda environment; CONDA_PREFIX is used
# below to locate the shared libraries for the C++ matcher.
import csv
import math
import os
import sys
import warnings
import faiss
import julius
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.multiprocessing as mp
import tqdm
import subprocess
if os.name == 'nt':
    print(os.name)
    import msvcrt
# torchaudio (0.7) currently throws a warning that cannot be disabled
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import torchaudio
import simpleutils
from model import FpNetwork
from datautil.melspec import build_mel_spec_layer
from datautil.musicdata import MusicDataset
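# The main process extracts fingerprint embeddings on the GPU and delegates the
# nearest-neighbor search to an external C++ matcher over a stdin/stdout pipe.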
if __name__ == "__main__":
    mp.set_start_method('spawn')
    if len(sys.argv) < 4:
        print('Usage: python %s <query list> <database dir> <result file>' % sys.argv[0])
        sys.exit()
    file_list_for_query = sys.argv[1]
    dir_for_db = sys.argv[2]
    result_file = sys.argv[3]
    result_file2 = os.path.splitext(result_file)  # for more detailed output
    result_file2 = result_file2[0] + '_detail.csv'
    result_file_score = result_file + '.bin'
    configs = os.path.join(dir_for_db, 'configs.json')
    params = simpleutils.read_config(configs)

    visualize = False

    d = params['model']['d']
    h = params['model']['h']
    u = params['model']['u']
    F_bin = params['n_mels']
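    # segn: samples per fingerprint segment; T: STFT frames per segment (rounded up)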
    segn = int(params['segment_size'] * params['sample_rate'])
    T = (segn + params['stft_hop'] - 1) // params['stft_hop']

    top_k = params['indexer']['top_k']
    frame_shift_mul = params['indexer'].get('frame_shift_mul', 1)
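    # only every frame_shift_mul-th embedding (offset 0) is sent to the C++ matcher below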
    print('loading model...')
    device = torch.device('cuda')
    model = FpNetwork(d, h, u, F_bin, T, params['model']).to(device)
    model.load_state_dict(torch.load(os.path.join(dir_for_db, 'model.pt')))
    print('model loaded')

    print('loading database...')
    with open(os.path.join(dir_for_db, 'songList.txt'), 'r', encoding='utf8') as fin:
        songList = []
        for line in fin:
            if line.endswith('\n'): line = line[:-1]
            songList.append(line)

    # doing inference, turn off gradient
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    dataset = MusicDataset(file_list_for_query, params)
    # no task parallelism
    loader = DataLoader(dataset, num_workers=0)
    # launch the external C++ matcher; queries go to its stdin, answers come back on its stdout
    env = {**os.environ}
    env['LD_LIBRARY_PATH'] = os.environ['CONDA_PREFIX'] + '/lib'
    query_proc = subprocess.Popen(['cpp/faisscputest', dir_for_db],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE,
        universal_newlines=False, env=env)
    if os.name == 'nt':
        # only Windows needs this: put the pipes into binary mode
        print('nt')
        msvcrt.setmode(query_proc.stdin.fileno(), os.O_BINARY)
        msvcrt.setmode(query_proc.stdout.fileno(), os.O_BINARY)
    mel = build_mel_spec_layer(params).to(device)

    fout = open(result_file, 'w', encoding='utf8', newline='\n')
    fout2 = open(result_file2, 'w', encoding='utf8', newline='\n')
    fout_score = open(result_file_score, 'wb')
    detail_writer = csv.writer(fout2)
    detail_writer.writerow(['query', 'answer', 'score', 'time', 'part_scores'])

    torch.set_num_threads(1)
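    # main query loop: one audio file per iteration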
    for dat in tqdm.tqdm(loader):
        embeddings = []
        grads = []
        specs = []
        i, name, wav = dat
        i = int(i)  # i is leaking file handles!
        # batch size should be less than 20 because query contains at most 19 segments
        for batch in torch.split(wav.squeeze(0), 16):
            g = batch.to(device)
            # Mel spectrogram
            with warnings.catch_warnings():
                # torchaudio is still using deprecated function torch.rfft
                warnings.simplefilter("ignore")
                g = mel(g)
            z = model.forward(g, norm=False).cpu()
            z = torch.nn.functional.normalize(z, p=2)
            embeddings.append(z)
        embeddings = torch.cat(embeddings)

        song_score = np.zeros(len(songList), dtype=np.float32)
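        # wire protocol: send the number of float32 values as a 4-byte little-endian
        # integer, then the raw embedding bytes; the matcher replies with the
        # 4-byte little-endian index of the best-matching song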
        query_proc.stdin.write(embeddings[0::frame_shift_mul].numpy().size.to_bytes(4, 'little'))
        query_proc.stdin.write(embeddings[0::frame_shift_mul].numpy().tobytes())
        query_proc.stdin.flush()
        ans = int.from_bytes(query_proc.stdout.read(4), 'little')
        ans = songList[ans]
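        # only the winning song index comes back from the C++ matcher, so the
        # score and matched time are left as placeholders in the detailed output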
        sco = 0
        tim = 0
        upsco = []
        tim /= frame_shift_mul
        tim *= params['hop_size']

        fout.write('%s\t%s\n' % (name[0], ans))
        fout.flush()
        detail_writer.writerow([name[0], ans, sco, tim] + upsco)
        fout2.flush()
        #fout_score.write(song_score.tobytes())
    fout.close()
    fout2.close()
else:
    # executed when this module is imported (e.g. by a spawned worker process)
    # instead of run as a script
    torch.set_num_threads(1)