-
Notifications
You must be signed in to change notification settings - Fork 4
/
completeness.py
169 lines (142 loc) · 8.88 KB
/
completeness.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""
Perform cross-validation with a given training algorithm and classification algorithm
over a grid of completeness and contamination levels, and write the summary statistics to a file.
@author Norman MacDonald
@date 2010-02-16
"""
import os,sys,threading,time
from optparse import OptionParser
from pica.io.FileIO import FileIO
from pica.Completeness import Completeness
from pica.TestConfiguration import TestConfiguration
from pica.io.FileIO import error
from pprint import pprint # add by RVF
# Upper bound, in seconds, on how long the script polls for worker threads
# to finish before it goes on to write the summary anyway.
MAX_THREAD_WAIT_SECONDS = 1800


def fraction_grid(steps, fixed_value):
    """Return the list of fractions to evaluate for one axis of the grid.

    steps       -- number of equal intervals between 0.0 and 1.0; 0 means
                   "do not sweep, use fixed_value only".
    fixed_value -- the single value returned (wrapped in a list) when
                   steps is 0.

    For steps > 0 the result holds steps + 1 evenly spaced floats from 0.0
    to 1.0 inclusive. Raises ValueError for negative steps, which would
    otherwise silently yield an empty grid.
    """
    if steps < 0:
        raise ValueError("steps must be >= 0, got %d" % steps)
    if steps == 0:
        return [fixed_value]
    return [i / float(steps) for i in range(steps + 1)]


def load_plugin_class(package, dotted_name):
    """Import "<package>.<dotted_name>" and return the attribute named like
    the last component of dotted_name.

    Example: ("pica.trainers", "libsvm.libSVMTrainer") imports the module
    pica.trainers.libsvm.libSVMTrainer and returns its libSVMTrainer member.
    """
    class_name = dotted_name.split(".")[-1]
    module_path = "%s.%s" % (package, dotted_name)
    # A non-empty fromlist makes __import__ return the leaf module instead
    # of the top-level package.
    module = __import__(module_path, fromlist=(class_name,))
    return getattr(module, class_name)


def write_summary(stats, fout):
    """Write summary statistics to fout, one "[key]" section per statistic.

    stats is indexed as stats[w][z][key]; every key of stats[0][0] becomes
    a section holding one tab-separated line per w with one column per z,
    followed by a blank line.
    NOTE(review): w and z presumably index the completeness and
    contamination levels (matching the -w / -z options) -- confirm against
    Completeness.get_summary_statistics.
    """
    # Regroup into key -> rows first and iterate that dict afterwards, so
    # the section order in the file stays exactly as it was before.
    # Index-based access is kept on purpose: only len() and integer
    # indexing are known to be supported by the stats containers.
    resorted = {}
    for index in stats[0][0].keys():
        resorted[index] = [
            [stats[w][z][index] for z in range(len(stats[w]))]
            for w in range(len(stats))
        ]
    for index in resorted.keys():
        fout.write("[%s]\n" % index)
        for row in resorted[index]:
            fout.write("%s\n" % "\t".join([str(value) for value in row]))
        fout.write("\n")


if __name__ == "__main__":
    parser = OptionParser(version="PICA %prog 1.0.1")
    parser.add_option("-a", "--training_algorithm",
                      help="Training algorithm [default = %default]",
                      metavar="ALG", default="libsvm.libSVMTrainer")
    # NOTE: no type is given, so C reaches the trainer as a string.
    parser.add_option("-k", "--svm_cost", action="store", dest="C", metavar="FLOAT",
                      help="Set the SVM misclassification penalty parameter C to FLOAT")
    parser.add_option("-b", "--classification_algorithm",
                      help="Testing algorithm [default = %default]",
                      metavar="ALG", default="libsvm.libSVMClassifier")
    parser.add_option("-m", "--accuracy_measure",
                      help="Accuracy measure [default = %default]",
                      metavar="ALG", default="laplace")
    parser.add_option("-r", "--replicates", type="int",
                      help="Number of replicates [default = %default]", default=10)
    parser.add_option("-v", "--folds", type="int",
                      help="v-fold cross-validation [default = %default]", default=5)
    parser.add_option("-s", "--samples", action="store", dest="input_samples_filename",
                      help="Read samples from FILE", metavar="FILE")
    parser.add_option("-c", "--classes", action="store", dest="input_classes_filename",
                      help="Read class labels from FILE", metavar="FILE")
    parser.add_option("-t", "--targetclass", action="store", dest="target_class",
                      help="Set the target CLASS for testing", metavar="CLASS")
    parser.add_option("-o", "--output_filename",
                      help="Write results to FILE", metavar="FILE", default=None)
    parser.add_option("-p", "--parameters", action="store", dest="parameters",
                      help="FILE with additional, classifier-specific parameters. (confounders for CWMI)",
                      metavar="FILE", default=None)
    parser.add_option("-x", "--profile", action="store_true", dest="profile",
                      help="Profile the code", default=False)
    # RVF add option save crossval files
    parser.add_option("-y", "--save_crossval_files", action="store_true", dest="crossval_files",
                      help="Save the training and test sets for crossvalidation to files under /crossvalidation",
                      default=False)
    parser.add_option("-d", "--metadata",
                      help="Load metadata from FILE and add to misclassification report [default: %default]",
                      metavar="FILE", default=None)
    parser.add_option("-f", "--feature_select",
                      help="Model file (currently only association rule files) with features to select from [default: %default]",
                      metavar="FILE", default=None)
    parser.add_option("-1", "--feature_select_score",
                      help="Order features by (feature selection option)", default="order_cwmi")
    parser.add_option("-n", "--feature_select_top_n",
                      help="Take the top n features(feature selection option)", type="int", default=20)
    # PH add option completeness, contamination
    parser.add_option("-w", "--completeness_steps",
                      help="Completeness steps between (default = %default)",
                      type="int", metavar="INT", default=10)
    parser.add_option("-z", "--contamination_steps",
                      help="Contamination steps between (default = %default)",
                      type="int", metavar="INT", default=0)
    parser.add_option("--completeness",
                      help="If completeness_steps=0, use specified completeness (default = %default)",
                      type="float", metavar="FLOAT", default=1.0)
    parser.add_option("--contamination",
                      help="If contamination_steps=0, use specified contamination (default = %default)",
                      type="float", metavar="FLOAT", default=0.0)
    parser.add_option("--threads", help="Allow multiple threads",
                      type="int", metavar="INT", default=1)
    (options, args) = parser.parse_args()

    # Check arguments for crucial errors. Every problem is reported before
    # exiting so the user can fix them all in one go.
    errorCount = 0
    if not options.input_samples_filename:
        error("Please provide a genotype sample file with -s /path/to/genotype.file")
        errorCount += 1
    if not options.input_classes_filename:
        error("Please provide a phenotype class file with -c /path/to/phenotype.file")
        errorCount += 1
    if not options.target_class:
        error("Please provide the phenotype target to be predicted with -t \"TRAITNAME\"")
        errorCount += 1
    if not options.output_filename:
        error("Please specify a file for the output with -o /path/to/result.file")
        errorCount += 1
    if options.completeness_steps < 0 or options.contamination_steps < 0:
        # A negative step count would produce an empty grid and only fail
        # much later, when the summary is written.
        error("Please provide non-negative values for -w/--completeness_steps and -z/--contamination_steps")
        errorCount += 1
    if errorCount > 0:
        error("For help on usage, try calling:\n\tpython %s -h" % os.path.basename(sys.argv[0]))
        # sys.exit instead of the site-provided exit(), which is meant for
        # interactive use and is absent when the site module is not loaded.
        sys.exit(1)

    fileio = FileIO()
    # The sample file is loaded twice on purpose: `samples` is reduced below
    # (feature selection, compression) while `unmodified_samples` keeps the
    # full feature set and is handed to Completeness separately.
    unmodified_samples = fileio.load_samples(options.input_samples_filename)
    samples = fileio.load_samples(options.input_samples_filename)
    if options.feature_select:
        print("Selecting top %d features from %s, ordered by %s" % (
            options.feature_select_top_n, options.feature_select, options.feature_select_score))
        from pica.AssociationRule import load_rules, AssociationRuleSet
        selected_rules = AssociationRuleSet()
        rules = load_rules(options.feature_select)
        rules.set_target_accuracy(options.feature_select_score)
        selected_rules.extend(rules[:options.feature_select_top_n])
        samples = samples.feature_select(selected_rules)
    classes = fileio.load_classes(options.input_classes_filename)
    unmodified_samples.load_class_labels(classes)
    samples.load_class_labels(classes)
    print("Sample set has %d features." % (samples.get_number_of_features()))
    unmodified_samples.set_current_class(options.target_class)
    samples.set_current_class(options.target_class)
    print("Parameters from %s" % (options.parameters))
    # Written without a newline so the next message continues the same line.
    sys.stdout.write("Compressing features... ")
    # for the moment: don't compress features. potential bug with testing!
    # potentially interferes with completeness check
    # NOTE(review): the comment above says compression is disabled, but the
    # call below is active -- confirm which of the two is intended.
    samples = samples.compress_features()
    print("compressed to %d distinct features." % (samples.get_number_of_features()))
    unmodified_samples.hide_nulls(options.target_class)
    # The current class is set again on the object returned by
    # compress_features().
    samples.set_current_class(options.target_class)
    samples.hide_nulls(options.target_class)

    # Trainer and classifier are resolved by name at run time.
    trainer_class = load_plugin_class("pica.trainers", options.training_algorithm)
    if options.C:
        trainer = trainer_class(options.parameters, C=options.C)
    else:
        trainer = trainer_class(options.parameters)
    trainer.set_null_flag("NULL")
    classifier_class = load_plugin_class("pica.classifiers", options.classification_algorithm)
    classifier = classifier_class(options.parameters)
    classifier.set_null_flag("NULL")
    test_configurations = [TestConfiguration("A", None, trainer, classifier)]

    # HP added contamination/completeness
    contamination = fraction_grid(options.contamination_steps, options.contamination)
    completeness = fraction_grid(options.completeness_steps, options.completeness)
    print(completeness)
    print(contamination)
    # Not referenced again in this script; kept in case other code relies on it.
    threadLock = threading.Lock()

    # RVF changed (added the last 3 parameters)
    if options.crossval_files:
        crossvalidator = Completeness(
            samples, options.parameters, options.folds, options.replicates,
            completeness, contamination, test_configurations, unmodified_samples,
            options.threads, False, None, options.target_class, options.output_filename)
    else:
        crossvalidator = Completeness(
            samples, options.parameters, options.folds, options.replicates,
            completeness, contamination, test_configurations, unmodified_samples,
            options.threads)
    crossvalidator.crossvalidate()

    # Poll once per second until only the main thread is left, or until the
    # time limit is reached.
    waited = 0
    print("Waiting for threads to finish")
    while threading.activeCount() > 1 and waited < MAX_THREAD_WAIT_SECONDS:
        time.sleep(1)
        waited += 1
    if threading.activeCount() > 1:
        # The summary is still written, but make the timeout visible instead
        # of passing possibly partial results off as complete.
        sys.stderr.write(
            "Warning: threads still running after %d seconds; "
            "the summary may be incomplete.\n" % MAX_THREAD_WAIT_SECONDS)

    # contamination makes this kind of output quite difficult. so leaving out at the moment
    # The context manager closes the file even if collecting or writing the
    # statistics raises.
    with open(options.output_filename, "w") as fout:
        stats = crossvalidator.get_summary_statistics(0)
        write_summary(stats, fout)