-
Notifications
You must be signed in to change notification settings - Fork 248
/
models.py
57 lines (40 loc) · 1.2 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
from torchvision import models
import numpy as np
def encrypt_vector(public_key, x):
    """Encrypt every element of ``x`` with ``public_key.encrypt``.

    Args:
        public_key: Object exposing an ``encrypt(value)`` method.
        x: Iterable of plaintext values.

    Returns:
        A new list of ciphertexts, in the same order as ``x``.
    """
    ciphertexts = []
    for value in x:
        ciphertexts.append(public_key.encrypt(value))
    return ciphertexts
def encrypt_matrix(public_key, x):
    """Encrypt a 2-D structure row by row.

    Args:
        public_key: Object exposing an ``encrypt(value)`` method.
        x: Iterable of rows, each an iterable of plaintext values.

    Returns:
        A list of lists of ciphertexts, one inner list per row of ``x``
        (each produced by ``encrypt_vector``).
    """
    return [encrypt_vector(public_key, row) for row in x]
def decrypt_vector(private_key, x):
    """Decrypt every element of ``x`` with ``private_key.decrypt``.

    Args:
        private_key: Object exposing a ``decrypt(ciphertext)`` method.
        x: Iterable of ciphertexts.

    Returns:
        A new list of decrypted values, in the same order as ``x``.
    """
    plaintexts = []
    for ciphertext in x:
        plaintexts.append(private_key.decrypt(ciphertext))
    return plaintexts
def decrypt_matrix(private_key, x):
    """Decrypt a 2-D structure of ciphertexts row by row.

    Args:
        private_key: Object exposing a ``decrypt(ciphertext)`` method.
        x: Iterable of rows, each an iterable of ciphertexts.

    Returns:
        A list of lists of decrypted values, one inner list per row of
        ``x`` (each produced by ``decrypt_vector``).
    """
    return [decrypt_vector(private_key, row) for row in x]
class LR_Model(object):
    """Holds a model's weight vector in plaintext and in encrypted form.

    ``weights`` is the plaintext vector; ``encrypt_weights`` is the
    encrypted vector, produced via ``encrypt_vector(public_key, ...)``
    unless the supplied weights are already encrypted.
    """

    def __init__(self, public_key, w_size=None, w=None, encrypted=False):
        """Initialise the weights.

        Args:
            public_key: Object with an ``encrypt`` method; used to encrypt
                the weights when ``encrypted`` is false.
            w_size: Number of weight parameters. Only used when ``w`` is
                None, in which case weights are drawn uniformly from
                [-0.5, 0.5) with NumPy's global RNG.
            w: Existing weights to use directly. Only one of ``w`` and
                ``w_size`` needs to be given; ``w`` takes precedence.
            encrypted: Whether the weights are already in encrypted form.
                If true they are stored as-is, so ``weights`` and
                ``encrypt_weights`` refer to the same object.

        Raises:
            TypeError: If neither ``w`` nor ``w_size`` is provided.
        """
        self.public_key = public_key
        if w is not None:
            self.weights = w
        else:
            if w_size is None:
                raise TypeError("LR_Model requires either w or w_size")
            self.weights = np.random.uniform(-0.5, 0.5, (w_size,))
        if encrypted:
            # NOTE(review): when w is None the random weights above are
            # plaintext, yet they are stored unencrypted here — confirm that
            # callers always pass w together with encrypted=True.
            self.encrypt_weights = self.weights
        else:
            self.encrypt_weights = encrypt_vector(public_key, self.weights)

    def set_encrypt_weights(self, w):
        """Overwrite ``encrypt_weights`` element-wise, in place, from ``w``.

        Mutating in place (rather than rebinding) keeps any aliases of the
        container in sync. ``w`` must not be longer than the current vector.
        """
        for idx, value in enumerate(w):
            self.encrypt_weights[idx] = value

    def set_raw_weights(self, w):
        """Overwrite plaintext ``weights`` element-wise, in place, from ``w``.

        Mutating in place (rather than rebinding) keeps any aliases of the
        container in sync. ``w`` must not be longer than the current vector.
        """
        for idx, value in enumerate(w):
            self.weights[idx] = value