-
Notifications
You must be signed in to change notification settings - Fork 0
/
ensemble_perplexity.py
61 lines (48 loc) · 1.78 KB
/
ensemble_perplexity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import subprocess
import signal
from perplexity_model import PerplexityModel
import time
import socket
import atexit
import zmq
# Module-wide ZeroMQ context, created at import time.
# start_single_perplexity_model() creates its REQ sockets from it.
context = zmq.Context()
def find_free_port():
    """Return a port number the OS reported as free at the time of the call.

    A throwaway socket is bound to port 0, which lets the OS choose the
    port; the chosen number is read back and the socket is closed again.
    The port is therefore no longer reserved when this function returns, so
    another process could take it before the caller uses it.
    """
    probe = socket.socket()
    try:
        # Port 0 asks the host to assign any currently unused port.
        probe.bind(('', 0))
        bound_address = probe.getsockname()
        return bound_address[1]
    finally:
        probe.close()
def start_single_perplexity_model(model_name, gpus, port):
    """Launch one worker script and open a REQ socket pointed at its port.

    Runs ``start_single_perplexity_model.py`` with the ``python`` found on
    PATH, passing ``gpus``, ``model_name`` and ``port`` as stringified
    command-line arguments. A ZeroMQ REQ socket is then created from the
    module-level ``context`` and connected to ``tcp://localhost:<port>``.
    This function does not wait for the child process to become ready.

    Returns:
        A ``(process, socket)`` tuple: the ``subprocess.Popen`` handle of
        the child and the connected REQ socket.
    """
    command = [
        "python", "start_single_perplexity_model.py",
        "--gpus", str(gpus),
        "--model_name", str(model_name),
        "--port", str(port),
    ]
    worker = subprocess.Popen(command)

    request_socket = context.socket(zmq.REQ)
    endpoint = f"tcp://localhost:{port}"
    request_socket.connect(endpoint)
    return worker, request_socket
def cleanup_processes(processes):
    """Call ``terminate()`` on each process handle in ``processes``, in order.

    Does not wait for the processes to exit. An exception raised by one
    ``terminate()`` call propagates and stops the loop.
    """
    remaining = iter(processes)
    for child in remaining:
        child.terminate()
class EnsemblePerplexity():
    """Client for a group of worker processes, one per model name.

    On construction one worker is launched per entry of ``model_names`` via
    ``start_single_perplexity_model``; the worker for the i-th name is given
    ``[i]`` as its ``gpus`` argument and a port from ``find_free_port``.
    The REQ socket for each worker is kept in ``self.sockets`` keyed by
    model name, and the process handles are kept in ``self.processes``.
    """

    def __init__(self, model_names=None, startup_wait=15):
        """Launch the workers and wait for them to come up.

        Args:
            model_names: Iterable of model names to launch. ``None`` (the
                default) means no models.
            startup_wait: Seconds to sleep after launching the workers
                (default 15, the previously hard-coded value) — presumably
                to give the workers time to load before the first request.
                ``None`` or a non-positive value skips the sleep.
        """
        # A fresh list per instance: the former ``model_names=[]`` default
        # was one list object shared by every instance built without
        # arguments.
        self.model_names = [] if model_names is None else model_names
        self.processes = []
        self.sockets = {}
        for i, model_name in enumerate(self.model_names):
            port = find_free_port()
            process, s = start_single_perplexity_model(model_name, [i], port)
            self.processes.append(process)
            # NOTE(review): a repeated model name replaces the earlier
            # socket here while its process stays in self.processes.
            self.sockets[model_name] = s

        if not self.processes:
            # Nothing was launched: no cleanup hook or startup delay needed.
            return

        # Terminate the workers when the interpreter exits.
        atexit.register(cleanup_processes, self.processes)
        if startup_wait and startup_wait > 0:
            time.sleep(startup_wait)

    def get_ensemble_perplexity(self, prompts, model_names=None):
        """Send ``prompts`` to the selected workers and collect their replies.

        ``prompts`` is sent as JSON to every selected worker before any
        reply is read; then one JSON reply is received from each of them.

        Args:
            prompts: JSON-serializable payload forwarded unchanged to each
                selected worker.
            model_names: Names of the workers to query. ``None`` queries
                all of them. A single name may be given as a bare string.
                Names without a socket in ``self.sockets`` are ignored.

        Returns:
            Dict mapping each queried model name to its decoded JSON reply.
        """
        if model_names is None:
            selected = list(self.sockets.items())
        else:
            # Materialize the filter once. Evaluating it separately for the
            # send and receive passes broke with one-shot iterables: the
            # first pass consumed them, so requests were sent but no reply
            # was ever read.
            if isinstance(model_names, str):
                wanted = (model_names,)
            else:
                wanted = tuple(model_names)
            selected = [
                (name, sock)
                for name, sock in self.sockets.items()
                if name in wanted
            ]

        for _, sock in selected:
            sock.send_json(prompts)

        perplexities = {name: sock.recv_json() for name, sock in selected}
        print(f"Received reply {perplexities}")
        return perplexities