-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathWeight.py
257 lines (231 loc) · 7.83 KB
/
Weight.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
"""
File name: Weight.py
Date created: 01/12/2016
Date last modified: 05/07/2016
Python version: 3.5.1
Description: Weighting system used to
adjust SVM classification of genes
during training.
"""
###############
### IMPORTS ###
###############
from Profile import Profile
from Loader import Loader
import numpy as np
import matplotlib.pyplot as plt
import argparse
############
### CODE ###
############
class Weight:
    """Cluster profiles by pairwise distance and weight them by cluster size.

    Profiles that fall into the same cluster share a total weight of 1, so
    each profile in a cluster of size ``k`` receives weight ``1 / k``.
    """

    def __init__(self, profiles, pairwiseDict):
        """Store the profiles and their pairwise distances.

        Input:
            profiles: iterable of profile objects (the command-line driver
                passes the result of Loader.load). For clustering, each
                profile must be hashable and expose a ``name`` attribute.
            pairwiseDict: a 2D dictionary ``{name1: {name2: distance}}``.
                A pair only needs to be stored in one key order.
        Returns:
            None.
        """
        self.profiles = profiles
        self.pairwiseDict = pairwiseDict

    @staticmethod
    def load(filename):
        """Read a whitespace-separated distance file into a 2D dictionary.

        Each non-blank line is expected to hold ``name1 name2 distance``
        (the layout of a RAxML .dist file); any further columns are ignored.

        Input:
            filename: path of the distance file.
        Returns:
            Dictionary ``{name1: {name2: float(distance)}}``.
        """
        pairwiseDict = {}
        # The context manager guarantees the handle is closed, even on error.
        with open(filename, 'r') as handle:
            for line in handle:
                parts = line.split()
                if not parts:
                    # Blank (e.g. trailing) lines carry no data.
                    continue
                pairwiseDict.setdefault(parts[0], {})[parts[1]] = float(parts[2])
        return pairwiseDict

    def _profile_distance(self, profile1, profile2):
        """Return the stored distance between two profiles, in either key order.

        Raises:
            KeyError: if the pair is stored in neither order.
        """
        table = self.pairwiseDict
        name1 = profile1.name
        name2 = profile2.name
        if name1 in table and name2 in table[name1]:
            return table[name1][name2]
        try:
            return table[name2][name1]
        except KeyError:
            raise KeyError("No pairwise distance recorded for %r and %r"
                           % (name1, name2)) from None

    def _all_distances(self):
        """Return every distance stored in the pairwise dictionary as a list."""
        return [dist
                for inner in self.pairwiseDict.values()
                for dist in inner.values()]

    def cluster(self, clusterType, cutoff, profiles=None):
        """Agglomeratively cluster profiles until the cutoff is exceeded.

        Every profile starts in its own cluster. The two closest clusters
        are merged repeatedly until only one cluster is left or the closest
        pair is farther apart than ``cutoff``.

        Input:
            clusterType: 'nearest' measures the distance between two
                clusters by their closest pair of profiles (single
                linkage); 'farthest' uses their most distant pair
                (complete linkage).
            cutoff: largest cluster distance that is still merged.
            profiles: profiles to cluster; defaults to self.profiles.
        Returns:
            A set of frozensets of profiles, or None (after printing a
            message) when clusterType is not recognized.
        Raises:
            KeyError: if a needed pair is missing from the distance table.
        """
        # Set profiles
        if profiles is None:
            profiles = self.profiles
        # Set cluster type: how a cluster-to-cluster distance is reduced
        # from the individual profile-to-profile distances.
        if clusterType == "nearest":
            linkage = min
        elif clusterType == "farthest":
            linkage = max
        else:
            print("Cluster type",clusterType,"not recongnized.")
            return None
        # Build the initial singleton clusters from the requested profiles
        clusters = {frozenset([profile]) for profile in profiles}
        # Loop until all clusters are merged
        while len(clusters) > 1:
            closestDist = float("inf")
            closestPair = None
            # Snapshot so that every unordered pair is visited exactly once
            ordered = list(clusters)
            for index, cluster1 in enumerate(ordered):
                for cluster2 in ordered[index + 1:]:
                    dist = linkage(self._profile_distance(profile1, profile2)
                                   for profile1 in cluster1
                                   for profile2 in cluster2)
                    if dist < closestDist:
                        closestDist = dist
                        closestPair = (cluster1, cluster2)
            # Check distance cutoff (also stops if no finite distance exists)
            if closestPair is None or closestDist > cutoff:
                return clusters
            # Merge the closest pair and drop the two old clusters
            cluster1, cluster2 = closestPair
            clusters.add(cluster1 | cluster2)
            clusters.remove(cluster1)
            clusters.remove(cluster2)
        return clusters

    def weight(self, clusters):
        """Set ``profile.weight`` to ``1 / cluster size`` for every profile.

        Input:
            clusters: iterable of clusters, as returned by cluster().
        Returns:
            None. The profiles are modified in place.
        """
        for cluster in clusters:
            share = 1.0 / len(cluster)
            for profile in cluster:
                profile.weight = share

    # Visualization and Analysis Functions

    def visualize_clusters(self, clusters):
        """Print a text rendering of every cluster, each followed by '-*'."""
        for cluster in clusters:
            self.visualize_helper(cluster, True, ' ')
            print("-*")

    def visualize_helper(self, cluster, isFirst, indent):
        """Print one cluster.

        Flat clusters (the sets/frozensets produced by cluster()) are
        printed one member per line. Indexable nested clusters are printed
        recursively as a binary tree of cluster[0] and cluster[1].

        Input:
            cluster: a set/frozenset of profiles, or an indexable nested
                cluster.
            isFirst: True when drawing the upper branch.
            indent: prefix printed before every line of this cluster.
        """
        angle = '/' if isFirst else '\\'
        if isinstance(cluster, (set, frozenset)):
            # Sets are unordered and cannot be indexed, so list the members
            # in a stable order instead of recursing.
            members = sorted(cluster, key=str)
            last = len(members) - 1
            for position, member in enumerate(members):
                if last == 0:
                    mark = angle
                elif position == 0:
                    mark = '/'
                elif position == last:
                    mark = '\\'
                else:
                    mark = '|'
                print(indent+' '+mark+'-'+str(member))
        elif len(cluster) ==1:
            print(indent+' '+angle+'-'+str(cluster[0]))
        else:
            self.visualize_helper(cluster[0], True, indent + (' ' if isFirst else ' | '))
            print(indent+' '+angle+'-*')
            self.visualize_helper(cluster[1], False, indent + (' | ' if isFirst else ' '))

    def plot_hist(self):
        """Show a histogram of all stored pairwise distances."""
        plt.hist(self._all_distances())
        plt.title("Pairwise Distance Histogram")
        plt.show()

    def print_quartiles(self):
        """Print the value ranges of the four quartiles of the distances."""
        vals = sorted(self._all_distances())
        l = len(vals)
        if l == 0:
            # Nothing to index into; avoid an IndexError on vals[0].
            print("No pairwise distances available.")
            return
        print("Q1: %f - %f" % (vals[0], vals[int(l/4)]))
        print("Q2: %f - %f" % (vals[int(l/4)], vals[int(l/2)]))
        print("Q3: %f - %f" % (vals[int(l/2)], vals[int(3*l/4)]))
        print("Q4: %f - %f" % (vals[int(3*l/4)], vals[l-1]))

    def print_under_threshold(self, cutoff):
        """Print how many stored distances are strictly below cutoff."""
        vals = self._all_distances()
        below = [dist for dist in vals if dist < cutoff]
        print("%d/%d under threshold of %f" % (len(below), len(vals), cutoff))

    def count_clusters(self, clusters):
        """Print how many profiles sit in clusters of more than one member."""
        clustered = 0
        cl = 0
        for cluster in clusters:
            if len(cluster) > 1:
                clustered += len(cluster)
                cl += 1
        print("%d genes have been clustered, out of %d total, over %d clusters." % (clustered, len(self.profiles), cl))
# Command-line driver: load profiles and distances, cluster, weight, and
# optionally print/plot diagnostics selected by the flags below.
if __name__ == '__main__':
    # Define arg parser
    parser = argparse.ArgumentParser(description="Gene weighting.")
    parser.add_argument("-p", "--profile_path", type=str, nargs=1,
                        dest="profile_path", required=True,
                        help="The .faa or .fna file used in calculating pairwise distances")
    parser.add_argument("-w", "--pairwise_path", type=str, nargs=1,
                        dest="pairwise_path", required=True,
                        help="The .dist file of pairwise distances created by RAxML")
    # "--cluter_type" was the original (misspelled) flag; it is kept as an
    # alias so existing command lines keep working.
    # choices rejects unknown types up front; otherwise cluster() returns
    # None and weighting would fail on it.
    parser.add_argument("-t", "--cluster_type", "--cluter_type", type=str, nargs=1,
                        dest="cluster_type", required=False, default=["farthest"],
                        choices=["farthest", "nearest"],
                        help="Specify 'farthest' or 'nearest' neighbors clustering.")
    parser.add_argument("-d", "--cutoff_distance", type=float, nargs=1,
                        dest="cutoff", required=False, default=[0.05],
                        help="Specify the cutoff distance for clustering.")
    parser.add_argument("-v", "--visualize", action="store_true",
                        dest="v", required=False,
                        help="Shows visualization of the clustered profiles.")
    parser.add_argument("-i", "--histogram", action="store_true",
                        dest="i", required=False,
                        help="Shows histogram of pairwise distances.")
    parser.add_argument("-q", "--quartiles", action="store_true",
                        dest="q", required=False,
                        help="Shows quartiles of pairwise distances.")
    parser.add_argument("-r", "--threshold", type=float, nargs=1,
                        dest="r", required=False,
                        help="Show number of pairwise distances that fall below the given threshold.")
    parser.add_argument("-c", "--count", action="store_true",
                        dest="c", required=False,
                        help="Shows how many profiles clustered into how many clusters.")
    args = parser.parse_args()
    # Load in profiles and pairwise dictionary
    profiles = Loader.load(args.profile_path[0])
    pairwiseDict = Weight.load(args.pairwise_path[0])
    # Init Weight object
    weight = Weight(profiles, pairwiseDict)
    # Create clusters
    clusters = weight.cluster(args.cluster_type[0], args.cutoff[0])
    # Assign weights to profiles
    weight.weight(clusters)
    # Visualize
    if args.v:
        weight.visualize_clusters(clusters)
    # Histogram
    if args.i:
        weight.plot_hist()
    # Quartiles
    if args.q:
        weight.print_quartiles()
    # Threshold (args.r is None when the flag is absent)
    if args.r is not None:
        weight.print_under_threshold(args.r[0])
    # Count
    if args.c:
        weight.count_clusters(clusters)