-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbbt_solver.py
63 lines (58 loc) · 2.28 KB
/
bbt_solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from tree_utils import *
# import likelihood_optimizer_LP import *
def LLH_method(method):
    """Select the likelihood-optimizer backend used by ``BBT_solver``.

    Imports the chosen backend module and binds its ``Likelihood_optimizer``
    class to the module-level global of the same name, which
    ``BBT_solver.__init__`` instantiates. Call this before constructing a
    ``BBT_solver``.

    Args:
        method: ``"cvxopt"`` selects ``likelihood_optimizer``;
            ``"pLP"`` selects ``likelihood_optimizer_LP``.

    Raises:
        ValueError: if ``method`` is not one of the supported names. Without
            this check the global would stay unbound and the failure would
            surface later as a NameError inside ``BBT_solver.__init__``.
    """
    global Likelihood_optimizer
    # Imports are deferred into the branches, so only the selected backend
    # module is ever imported (presumably so the other one's dependencies
    # need not be installed).
    if method == "cvxopt":
        import likelihood_optimizer
        Likelihood_optimizer = likelihood_optimizer.Likelihood_optimizer
    elif method == "pLP":
        import likelihood_optimizer_LP
        Likelihood_optimizer = likelihood_optimizer_LP.Likelihood_optimizer
    else:
        raise ValueError(
            "unknown likelihood method %r; expected 'cvxopt' or 'pLP'" % (method,)
        )
class BBT_solver:
    """Level-by-level tree search that inserts one column of ``V`` per level.

    Columns are ranked by the descending column sum of ``V / (V + R)``.
    Level 1 holds a single seed tree containing the top-ranked column. Each
    call to ``iteration`` expands every tree of the current level with the
    next-ranked column and scores every candidate with the likelihood
    optimizer's ``llh``. It keeps the candidates whose score satisfies
    ``llh >= best + lapprox - llh_EPS``.

    Results are stored in ``self.BBTs``, a dict mapping a level number
    (starting at 1) to a list of ``(tree, llh)`` pairs. ``init`` creates it.

    NOTE(review): ``V`` and ``R`` are presumably numpy arrays of variant and
    reference read counts, shaped ``(m, n)``. The code relies on ``.shape``,
    element-wise ``+`` and ``/``, and ``.sum(axis=0)``. Confirm against callers.
    """

    def __init__(self, V, R, lapprox, EPS, llh_EPS, neg):
        """Store inputs, rank the columns, and build the likelihood solver.

        Args:
            V, R: two-dimensional arrays of equal shape ``(m, n)``.
            lapprox: added to the best likelihood to form the selection
                threshold in ``iteration``.
            EPS: forwarded to ``Likelihood_optimizer``.
            llh_EPS: subtracted from the selection threshold as a tolerance.
            neg: forwarded unchanged to ``Tree.expand``.

        The module-level ``Likelihood_optimizer`` must already be bound,
        e.g. by calling ``LLH_method``.
        """
        self.V = V
        self.R = R
        self.m, self.n = self.V.shape
        self.T = self.V + self.R
        self.lapprox = lapprox
        self.EPS = EPS
        self.llh_EPS = llh_EPS
        # Per-entry fraction V / (V + R).
        # NOTE(review): a zero entry in V + R divides by zero here -- confirm
        # inputs exclude that.
        self.ave_f = self.V / self.T
        # Column sums of the fractions; columns with larger sums are inserted
        # into trees first.
        self.sof = self.ave_f.sum(axis=0)
        self.rank = list(range(self.n))
        self.rank.sort(key=lambda col: -self.sof[col])
        self.solver = Likelihood_optimizer(V, R, EPS)
        self.neg = neg
        self.filters = None  # not read anywhere in this class

    def init(self):
        """Reset ``self.BBTs`` to level 1: one tree with the top-ranked column."""
        # NOTE(review): (-1, col) presumably encodes an edge from a virtual
        # root (-1) to the column -- confirm with tree_utils.Tree.
        init_tree = Tree([(-1, self.rank[0])])
        llh = self.solver.llh(init_tree)
        self.BBTs = {1: [(init_tree, llh)]}

    def greedy_expand(self, tree):
        """Return the best single expansion of ``tree``.

        The tree is expanded with column ``self.rank[tree.n]``, and every
        candidate is scored.

        Returns:
            An ``(llh, tree)`` pair for the highest-scoring candidate. Note
            this is the reverse of the ``(tree, llh)`` order used in
            ``self.BBTs``. When scores tie, the first such candidate wins.

        Raises:
            ValueError: if ``tree.expand`` yields no candidates (from ``max``).
        """
        candidates = tree.expand(self.rank[tree.n], self.neg)
        scored = [(self.solver.llh(t), t) for t in candidates]
        # Compare on the score only: sorting (llh, tree) tuples falls back to
        # comparing Tree objects on ties, which raises TypeError unless Tree
        # defines an ordering.
        return max(scored, key=lambda pair: pair[0])

    def iteration(self):
        """Build level ``k + 1`` of ``self.BBTs`` from level ``k``.

        Here ``k`` is ``len(self.BBTs)``, so levels are assumed to be the
        contiguous keys ``1..k``. Every tree at level ``k`` is expanded with
        column ``self.rank[k]``. A candidate is kept when
        ``llh >= max_llh + lapprox - llh_EPS``.

        Raises:
            ValueError: if no candidate trees are produced (from ``max``).
        """
        level = len(self.BBTs)
        candidates = []
        for tree, _llh in self.BBTs[level]:
            candidates.extend(tree.expand(self.rank[level], self.neg))
        print("iteration %d: exploring %d trees" % (level, len(candidates)))
        scores = [self.solver.llh(t) for t in candidates]
        # The threshold is the same for every candidate; compute it once.
        threshold = max(scores) + self.lapprox - self.llh_EPS
        self.BBTs[level + 1] = [
            (t, score) for t, score in zip(candidates, scores) if score >= threshold
        ]
        print("iteration %d: %d trees selected" % (level, len(self.BBTs[level + 1])))

    def main(self, ell, tau=-1):
        """Run the search and return one level of ``(tree, llh)`` pairs.

        Args:
            ell: if positive, stop once level ``ell`` is built and return it.
                Levels never exceed ``self.n``.
            tau: if positive and a newly built level holds more than ``tau``
                trees, return the previous level instead of continuing.

        Returns:
            The list stored at the level where the search stopped. This is
            level ``self.n`` when neither limit triggers.
        """
        self.init()
        for level in range(2, self.n + 1):
            if ell > 0 and level - 1 >= ell:
                return self.BBTs[level - 1]
            self.iteration()
            if tau > 0 and len(self.BBTs[level]) > tau:
                return self.BBTs[level - 1]
        return self.BBTs[self.n]