-
Notifications
You must be signed in to change notification settings - Fork 3
/
utils.py
68 lines (54 loc) · 2.24 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import json
import numpy as np
from datetime import datetime
def init_exp(exp_id):
    """Create the per-experiment output directories.

    Ensures ``saved_models/<exp_id>`` and ``log/<exp_id>`` exist, relative to
    the current working directory. Calling it again for the same experiment
    is a no-op.

    Args:
        exp_id: Experiment identifier; formatted into both directory names.
    """
    import os

    for root in ('saved_models', 'log'):
        # exist_ok=True replaces the racy "exists? then makedirs" pair: a
        # concurrent creator between the check and the create would otherwise
        # make os.makedirs raise FileExistsError.
        os.makedirs('{}/{}'.format(root, exp_id), exist_ok=True)
def generate_probability_partition(num_entries):
    """Split the unit interval into ``num_entries`` random non-negative pieces.

    Draws ``num_entries - 1`` points with ``np.random.rand()``, sorts them
    together with the endpoints 0 and 1, and returns the gaps between
    consecutive points, so the pieces sum to 1 up to float rounding.

    Args:
        num_entries: Number of pieces. Values below 1 give an empty list and
            consume no random draws.

    Returns:
        A list of ``num_entries`` floats.
    """
    # Draw first so that a non-integer num_entries fails inside range() exactly
    # as before.
    cuts = [np.random.rand() for _ in range(num_entries - 1)]
    if num_entries < 1:
        return []
    breakpoints = sorted([0., 1.] + cuts)
    return [upper - lower for lower, upper in zip(breakpoints, breakpoints[1:])]
class _Logger:
    """Write experiment metadata and progress to files under ``./log/<exp_id>/``.

    The target directory is not created here; it must already exist (for
    example via ``init_exp``). All methods are static, so they can be called on
    the class or on an instance.
    """

    # Client-selection history per experiment: {exp_id: {t: selection}}.
    # Keyed by exp_id so that logging several experiments in one process does
    # not leak one experiment's rounds into another's client_selection.json.
    _client_selection_cache = {}

    @staticmethod
    def _log_path(exp_id, filename):
        """Return the path of ``filename`` inside the log directory of ``exp_id``."""
        return './log/{}/{}'.format(exp_id, filename)

    @staticmethod
    def _json_default(obj):
        """Convert NumPy arrays/scalars to plain Python for ``json.dump``; reject anything else."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(
            'Object of type {} is not JSON serializable'.format(type(obj).__name__))

    @staticmethod
    def _dump_json(exp_id, filename, payload):
        """Overwrite ``filename`` in the experiment's log directory with ``payload`` as indented JSON."""
        path = _Logger._log_path(exp_id, filename)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=4, default=_Logger._json_default)

    @staticmethod
    def log_exp_info(exp_id, description):
        """Write exp_info.json with the id, local start time (``%m/%d/%Y %H:%M:%S``) and description."""
        info_dict = {
            'exp_id': exp_id,
            'start_datetime': datetime.now().strftime("%m/%d/%Y %H:%M:%S"),
            'description': description,
        }
        _Logger._dump_json(exp_id, 'exp_info.json', info_dict)

    @staticmethod
    def log_model_description(exp_id, model_fn):
        """Build a model via ``model_fn()`` and write its printed form to model_description.txt."""
        model = model_fn()
        path = _Logger._log_path(exp_id, 'model_description.txt')
        # Explicit UTF-8 so a non-ASCII repr does not depend on the platform's
        # default encoding.
        with open(path, 'w', encoding='utf-8') as f:
            print(model, file=f)

    @staticmethod
    def log_server_solver(exp_id, server_solver):
        """Write the server solver configuration to server_solver.json."""
        _Logger._dump_json(exp_id, 'server_solver.json', server_solver)

    @staticmethod
    def log_client_solver(exp_id, client_solver):
        """Write the client solver configuration to client_solver.json."""
        _Logger._dump_json(exp_id, 'client_solver.json', client_solver)

    @staticmethod
    def log_client_preparation(exp_id, client_preparation_dict):
        """Write the client preparation settings to client_preparation.json."""
        _Logger._dump_json(exp_id, 'client_preparation.json', client_preparation_dict)

    @staticmethod
    def log_validation_data(exp_id, validation_dict):
        """Write validation results to validation.json."""
        _Logger._dump_json(exp_id, 'validation.json', validation_dict)

    @staticmethod
    def log_client_selection(exp_id, t, selection):
        """Record ``selection`` for round ``t`` and rewrite client_selection.json.

        The file is rewritten with this experiment's full history on every
        call, so it always holds every round logged so far. JSON stores
        non-string keys (e.g. an int ``t``) as strings.
        """
        history = _Logger._client_selection_cache.setdefault(exp_id, {})
        history[t] = selection
        _Logger._dump_json(exp_id, 'client_selection.json', history)
# Shared module-level handle. Every _Logger method is static and its only state
# (the client-selection cache) lives on the class, so all instances behave the
# same; callers use e.g. ``logger.log_exp_info(exp_id, description)``.
logger: _Logger = _Logger()