-
Notifications
You must be signed in to change notification settings - Fork 1
/
cluster.py
83 lines (59 loc) · 1.98 KB
/
cluster.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import sys
import math
import cairo
import numpy as np
import sklearn.cluster
import sklearn.manifold
import matplotlib.pyplot as plt
# Hierarchical clustering driver.
#
# Usage: cluster.py <prefix> <output>
#   <prefix>.vi   comma-separated square matrix of pairwise distances
#   <prefix>.ent  one value per dataset (presumably an entropy -- confirm
#                 against whatever produces the file)
#   <output>      text file receiving one "i j distance" row per edge
#
# Affinity propagation is applied repeatedly: every round links each point to
# its exemplar, then only the exemplars are carried into the next round, until
# at most two points remain.
sys.stderr.write("Loading data from '%s'...\n" % sys.argv[1])
A = np.genfromtxt(sys.argv[1]+'.vi', delimiter=',')
assert np.size(A,0) == np.size(A,1), "Expected square distance matrix"
n_seqs = np.size(A,0)
ent = np.loadtxt(sys.argv[1]+'.ent')
assert np.size(ent) == n_seqs
sys.stderr.write("Number of datasets: %d\n" % n_seqs)

# fix A: mirror it across the diagonal.  A + A.T counts the diagonal twice,
# so one copy is removed.  np.diag(A) alone is a 1-D vector that would be
# broadcast over every row; wrapping it in a second np.diag turns it back
# into a diagonal matrix so only the diagonal entries are corrected.
A = A + np.transpose(A) - np.diag(np.diag(A))

# construct affinity matrix (smaller distance -> larger affinity)
aff = np.exp(-A)

sys.stderr.write("Clustering...\n")

# cluster
edges = []
# mapping[k] is the original dataset index of row/column k of the current
# (shrinking) affinity matrix.  Sized from the data rather than hard-coded.
mapping = list(range(n_seqs))
prev_clusters = n_seqs
# Number of consecutive rounds that failed to reduce the cluster count; each
# one lowers the preference by a further 0.1 standard deviations.
n = 0
while len(mapping) > 2:
    preference = np.median(aff) - 0.1*n*np.std(aff)
    ap = sklearn.cluster.AffinityPropagation(affinity='precomputed', preference=preference)
    ap.fit(aff)
    labels = ap.labels_
    indices = ap.cluster_centers_indices_
    clusters = len(indices)
    sys.stderr.write("clustered to %d\n" % clusters)

    if clusters == 0:
        # Without any exemplar the lookup indices[label] below would fail
        # with an opaque IndexError; stop with an explicit message instead.
        raise RuntimeError(
            "Affinity propagation returned no cluster centers "
            "(%d points remaining); cannot continue" % len(mapping))

    if clusters == prev_clusters:
        # No reduction this round: try again with a lower preference.
        n += 1
        sys.stderr.write("\tRetrying with %f standard deviations lower preference\n" % (0.1*n))
    else:
        if n > 0:
            sys.stderr.write("\tSituation normalized at %f stdevs\n" % (0.1*n))
        n = 0
    prev_clusters = clusters

    # add edges to centers: one (center, member, distance) triple for every
    # point that is not its own exemplar, expressed in original indices
    for node, label in enumerate(labels):
        center = mapping[indices[label]]
        member = mapping[node]
        if member != center:
            edges.append((center, member, A[center, member]))

    # Only the exemplars survive into the next round.
    mapping = [mapping[k] for k in indices]
    aff = aff[indices,:][:,indices]

if len(mapping) == 2:
    # Join the two remaining points; the one whose ent value is not the
    # larger of the two is written first.
    if ent[mapping[0]] > ent[mapping[1]]:
        j, i = mapping[0], mapping[1]
    else:
        i, j = mapping[0], mapping[1]
    edges.append((i, j, A[i,j]))

sys.stderr.write("Produced %d edges.\n" % len(edges))

# output
sys.stderr.write("Writing output...\n")
# reshape keeps three columns even when no edge was produced, so the
# three-entry fmt list still matches the array
np.savetxt(sys.argv[2], np.array(edges, dtype=float).reshape(-1, 3), fmt=["%d", "%d", "%f"])
sys.stderr.write("done.\n")