Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Ram Sriharsha committed May 22, 2015
1 parent 2f76295 commit 46c41b1
Showing 1 changed file with 24 additions and 23 deletions.
47 changes: 24 additions & 23 deletions docs/ml-ensembles.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,11 @@ println(metrics.confusionMatrix)
// the Iris dataset has three classes
val numClasses = 3

val fprs = (0 until numClasses).map(label => label + "\t" + metrics.falsePositiveRate(label.toDouble)).mkString("\n")
println("label\tfpr\n" + fprs)
println("label\tfpr\n")
(0 until numClasses).foreach { index =>
val label = index.toDouble
println(label + "\t" + metrics.falsePositiveRate(label))
}
{% endhighlight %}
</div>
<div data-lang="java" markdown="1">
Expand All @@ -67,38 +70,37 @@ import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.OneVsRestModel;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
SQLContext jsql = new SQLContext(jsc);

RDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(),
"data/mllib/sample_multiclass_classification_data.txt");

RDD<LabeledPoint>[] split = data.randomSplit(new double[]{0.7, 0.3}, 12345);
RDD<LabeledPoint> train = split[0];
RDD<LabeledPoint> test = split[1];
DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class);
DataFrame[] splits = dataFrame.randomSplit(new double[]{0.7, 0.3}, 12345);
DataFrame train = splits[0];
DataFrame test = splits[1];

// instantiate the One Vs Rest Classifier
OneVsRest ovr = new OneVsRest().setClassifier(new LogisticRegression());

// train the multiclass model
DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class);
OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache());
OneVsRestModel ovrModel = ovr.fit(train.cache());

// score the model on test data
DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class);
DataFrame predictions = ovrModel
.transform(testDataFrame.cache())
.select("prediction", "label");
.transform(test)
.select("prediction", "label");

// obtain metrics
MulticlassMetrics metrics = new MulticlassMetrics(predictions);
Expand All @@ -109,20 +111,19 @@ System.out.println("Confusion Matrix");
System.out.println(confusionMatrix);

// compute the false positive rate per label
StringBuilder results = new StringBuilder();
results.append("label\tfpr\n");
System.out.println();
System.out.println("label\tfpr\n");

// the Iris dataset has three classes
int numClasses = 3;

for (int label = 0; label < numClasses; label++) {
results.append(label);
results.append("\t");
results.append(metrics.falsePositiveRate((double) label));
results.append("\n");
for (int index = 0; index < numClasses; index++) {
double label = (double) index;
System.out.print(label);
System.out.print("\t");
System.out.print(metrics.falsePositiveRate(label));
System.out.println();
}
System.out.println();
System.out.println(results);

{% endhighlight %}
</div>
</div>

0 comments on commit 46c41b1

Please sign in to comment.