-
Notifications
You must be signed in to change notification settings - Fork 191
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1426 from wkentaro/merge-sklearn-to-jsk-perception
Merge sklearn to jsk_perception
- Loading branch information
Showing
10 changed files
with
353 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
#!/usr/bin/env python | ||
|
||
try: | ||
from ml_classifiers.srv import * | ||
except: | ||
import roslib;roslib.load_manifest("ml_classifiers") | ||
from ml_classifiers.srv import * | ||
|
||
import rospy | ||
import numpy as np | ||
from sklearn.cross_validation import cross_val_score | ||
from sklearn.ensemble import RandomForestClassifier | ||
from sklearn.ensemble import ExtraTreesClassifier | ||
from sklearn.externals import joblib | ||
|
||
class RandomForestServer: | ||
def __init__(self, clf): | ||
self.clf = clf | ||
s = rospy.Service('predict', ClassifyData, self.classifyData) | ||
|
||
@classmethod | ||
def initWithData(cls, data_x, data_y): | ||
if len(data_x) != len(data_y): | ||
rospy.logerr("Lenght of datas are different") | ||
exit() | ||
rospy.loginfo("InitWithData please wait..") | ||
clf = RandomForestClassifier(n_estimators=250, max_features=2, max_depth=29, min_samples_split=1, random_state=0) | ||
clf.fit(data_x, data_y) | ||
return cls(clf) | ||
|
||
@classmethod | ||
def initWithFileModel(cls, filename): | ||
rospy.loginfo("InitWithFileModel with%s please wait.."%filename) | ||
clf = joblib.load(filename) | ||
return cls(clf) | ||
|
||
#Return predict result | ||
def classifyData(self, req): | ||
ret = [] | ||
for data in req.data: | ||
print data | ||
ret.append(" ".join([str(predict_data) for predict_data in self.clf.predict(data.point)])) | ||
rospy.loginfo("req : " + str(data.point) + "-> answer : " + str(ret)) | ||
return ClassifyDataResponse(ret) | ||
|
||
#Run random forest | ||
def run(self): | ||
rospy.loginfo("RandomForestServer is running!") | ||
rospy.spin() | ||
|
||
|
||
if __name__ == "__main__": | ||
rospy.init_node('random_forest_cloth_classifier') | ||
|
||
try: | ||
train_file = rospy.get_param('~random_forest_train_file') | ||
except KeyError: | ||
rospy.logerr("Train File is not Set. Set train_data file or tree model file as ~random_forest_train_file.") | ||
exit() | ||
|
||
if train_file.endswith("pkl"): | ||
node = RandomForestServer.initWithFileModel(train_file) | ||
else: | ||
try: | ||
class_file = rospy.get_param('~random_forest_train_class_file') | ||
|
||
data_x = [] | ||
data_y = [] | ||
for l in open(train_file).readlines(): | ||
float_strings = l.split(","); | ||
data_x.append(map(lambda x: float(x), float_strings)) | ||
|
||
for l in open(class_file).readlines(): | ||
data_y.append(float(l)) | ||
|
||
#build servece server | ||
node = RandomForestServer.initWithData(np.array(data_x), np.array(data_y)) | ||
|
||
except KeyError: | ||
rospy.logerr("Train Class File is not Set. Set train_data file or tree model file.") | ||
rospy.logerr("Or Did you expect Extension to be pkl?.") | ||
exit() | ||
|
||
|
||
#run | ||
node.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
#!/usr/bin/env python | ||
|
||
try: | ||
from ml_classifiers.srv import * | ||
from ml_classifiers.msg import * | ||
except: | ||
import roslib;roslib.load_manifest("ml_classifiers") | ||
from ml_classifiers.srv import * | ||
from ml_classifiers.msg import * | ||
|
||
import rospy | ||
import random | ||
|
||
HEADER = '\033[95m' | ||
OKBLUE = '\033[94m' | ||
OKGREEN = '\033[92m' | ||
WARNING = '\033[93m' | ||
FAIL = '\033[91m' | ||
ENDC = '\033[0m' | ||
|
||
if __name__ == "__main__": | ||
rospy.init_node("random_forest_client") | ||
|
||
rospy.wait_for_service('predict') | ||
|
||
rospy.loginfo("Start Request Service!!") | ||
|
||
predict_data = rospy.ServiceProxy('predict', ClassifyData) | ||
|
||
while not rospy.is_shutdown(): | ||
req = ClassifyDataRequest() | ||
req_point = ClassDataPoint() | ||
target = [random.random(), random.random()] | ||
answer = 1 | ||
#Check if it is in the circle radius = 1? | ||
if target[0] * target[0] + target[1] * target[1] > 1: | ||
answer = 0 | ||
req_point.point = target | ||
req.data.append(req_point) | ||
print OKGREEN,"Send Request ====================> Answer",ENDC | ||
print OKGREEN," ",req_point.point," : ",str(answer),ENDC | ||
response = predict_data(req) | ||
print WARNING,"Get the result : ",ENDC | ||
print WARNING,response.classifications,ENDC | ||
if response.classifications[0].find(str(answer)): | ||
print OKBLUE,"Succeed!!!",ENDC | ||
else: | ||
print FAIL,"FAIL...",FAIL | ||
print "--- --- --- ---" | ||
|
||
rospy.sleep(1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
<launch> | ||
<!-- <arg name="" value=""/> --> | ||
<!-- <param name="depth_registered" value="true"/> --> | ||
<node pkg="jsk_perception" name="random_server" type="random_forest_server.py" args="" output="screen" clear_params="true"> | ||
<param name="random_forest_train_file" value="$(find jsk_perception)/sample/random_forest_sample_data_x.txt" /> | ||
<param name="random_forest_train_class_file" value="$(find jsk_perception)/sample/random_forest_sample_data_y.txt" /> | ||
</node> | ||
|
||
<node pkg="jsk_perception" name="random_client" type="random_forest_client_sample.py" launch-prefix="xterm -e"/> | ||
</launch> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
7.050784631809560166e-01,8.046683086007709873e-01 | ||
6.096653288013824668e-01,7.873620471643174579e-01 | ||
2.811033817416092040e-01,6.339361361558826236e-01 | ||
7.801995416294117414e-01,9.467792753305790399e-01 | ||
7.858225659011705000e-01,3.710612600887184254e-01 | ||
8.157686426855794704e-01,7.984707129129120506e-01 | ||
2.008397047641143907e-01,1.329179488631015982e-02 | ||
6.773168302846169775e-01,2.905122631407458522e-01 | ||
2.079836293439472072e-01,6.254115473743593334e-01 | ||
7.216258736132767915e-01,3.847332391767432913e-02 | ||
5.163072309573801810e-02,1.249715431727107529e-01 | ||
2.150153896831472622e-01,5.837452646033300940e-01 | ||
2.767705194540767133e-01,3.492657764725383140e-01 | ||
2.013643369030752028e-01,7.757491708705339661e-01 | ||
5.415505329133989409e-01,2.319589952524581111e-02 | ||
2.698666376378994203e-01,3.669089946957688753e-01 | ||
5.296933526868621289e-01,4.069964469669882234e-01 | ||
2.664463141134059132e-01,8.369360506284481138e-01 | ||
8.051822863223647708e-01,1.675662256027702357e-01 | ||
2.687941928729791208e-01,2.171535420430178442e-01 | ||
1.053074571508881840e-01,2.711817581978339664e-01 | ||
5.158533198836301459e-02,9.394090247086499534e-01 | ||
2.466315652519230905e-01,9.734071182253897225e-01 | ||
8.653412173224572790e-02,2.261526755386399357e-01 | ||
6.869330723463005217e-01,5.380188137167629669e-01 | ||
5.471386901022898819e-01,8.918904215021232762e-01 | ||
2.675423374270380350e-01,6.740002378173290953e-01 | ||
3.268184989877687130e-03,7.090074769354818285e-01 | ||
8.391603466549982793e-01,5.669705723955342780e-02 | ||
7.211545439187765361e-01,9.850668938532211039e-01 | ||
4.844099981878126071e-01,4.941561918934577191e-01 | ||
1.471206538456847346e-01,7.073126583052363747e-01 | ||
9.822274374102388794e-01,8.768447162117745108e-01 | ||
9.143497127765143340e-01,6.936789555828239973e-01 | ||
6.035165577771864909e-01,5.837708035406097284e-02 | ||
5.393450860458999241e-01,1.824217117855180259e-01 | ||
8.158059495447460563e-01,7.929565696279639031e-02 | ||
8.306436465272276637e-01,4.989808136165542196e-01 | ||
2.829505133146952289e-01,9.949029553139620008e-01 | ||
1.281045915884742037e-01,7.528956565994686656e-01 | ||
1.850423725893397542e-01,3.134362407450973498e-01 | ||
7.396952149344313554e-01,8.114037505563848063e-01 | ||
6.824773309134731791e-01,5.522168947517891446e-01 | ||
3.030528002909151919e-01,1.337830333163891883e-01 | ||
5.741048404123196836e-01,4.737696799137651738e-01 | ||
8.637830826280318286e-01,5.528775710088408291e-01 | ||
1.396915958837541272e-01,6.515037182002766381e-01 | ||
4.012962574290585005e-01,5.197016038589425957e-01 | ||
7.577012975291815833e-01,3.706450040906402732e-01 | ||
2.684103900855477898e-01,7.276711576533216874e-01 | ||
9.965897847761595596e-02,1.838793361641620772e-01 | ||
2.779462995697515870e-01,2.270655270495752776e-01 | ||
8.737111277152969091e-01,5.716882344470158861e-01 | ||
6.592180916131151758e-01,8.474997569849554990e-01 | ||
5.593060704853801690e-01,6.694066835862061415e-01 | ||
3.735499482308757280e-02,8.531284190093497699e-01 | ||
8.140702971958405643e-01,7.893409870953725926e-01 | ||
8.127160367563831533e-01,5.439395437305392100e-01 | ||
8.509148259986432095e-01,2.163504345102591486e-01 | ||
2.877437959593864836e-01,7.953302541305056206e-01 | ||
5.001102093375447977e-01,3.108868299410670888e-01 | ||
6.432199670156300009e-01,5.625250966059655022e-01 | ||
7.848035624885292272e-01,7.431135737213656611e-03 | ||
3.219624042074087367e-02,1.467644557075351575e-01 | ||
6.671419740483219840e-02,3.861318887671982836e-01 | ||
4.795189499530881916e-01,9.041206996074405700e-01 | ||
4.042078160444128043e-01,5.419698992833000828e-01 | ||
1.074755515266740957e-01,5.140925691505582318e-02 | ||
1.469059233865445124e-01,9.897740054999489834e-01 | ||
9.986750459267872415e-01,5.496220608900808102e-01 | ||
5.841178111248690463e-01,1.423295359011074179e-01 | ||
6.359040826409402269e-01,9.415321709674080441e-01 | ||
7.861235814806490918e-01,2.810114352113507463e-01 | ||
2.946780257911257861e-01,4.307006849751624511e-01 | ||
7.112529598823839061e-01,2.830527648390895878e-01 | ||
2.058680579224158036e-01,8.000722853219316422e-01 | ||
1.116652935663198232e-01,7.360326682701556766e-01 | ||
3.533948303481136977e-01,9.818561212186378562e-01 | ||
5.672246039618782376e-01,7.270181639629115233e-01 | ||
2.740087283094545523e-01,2.072742667750498979e-01 | ||
1.790022781007799546e-01,6.498815999461178272e-01 | ||
8.007329814483360453e-01,8.097217794144797587e-01 | ||
9.600935891014146240e-01,9.917877775501137139e-01 | ||
3.152721787368474304e-01,7.928939419973043412e-01 | ||
6.581246756992120694e-01,3.484833634922872569e-01 | ||
3.075544003935030135e-01,8.405545291722384960e-01 | ||
7.723844355687795593e-01,4.995003530699433369e-01 | ||
8.244572772248098813e-01,4.273997286068799140e-01 | ||
8.967955894960450980e-01,8.254071783341614399e-01 | ||
7.971367566536775584e-01,7.320521387105597411e-01 | ||
3.133880876766332868e-01,3.609505297154206316e-01 | ||
1.378965426670620831e-01,9.060382547642155115e-01 | ||
2.869324686970693428e-01,2.485228388448748049e-01 | ||
8.293043938044526442e-01,3.350924672150027428e-01 | ||
4.709058863016502006e-01,5.953563465745531635e-01 | ||
8.717637347859621411e-01,9.304323976436368326e-02 | ||
9.492556445724387171e-01,8.137608387150720990e-01 | ||
3.644126361279739212e-01,3.506210381753425143e-01 | ||
6.981186511356197721e-01,8.346210791790131811e-02 | ||
1.123579513420740472e-01,8.253611616172757959e-01 |
Oops, something went wrong.