Skip to content

Commit

Permalink
[SPARK-20183][ML] Added outlierRatio arg to MLTestingUtils.testOutliersWithSmallWeights
Browse files Browse the repository at this point in the history

## What changes were proposed in this pull request?

This is a small piece of #16722, which will ultimately add sample weights to decision trees. Adding the `outlierRatio` argument to `MLTestingUtils.testOutliersWithSmallWeights` allows more flexibility in testing outliers, since linear models and trees behave differently.

Note: The primary author when this is committed should be sethah since this is taken from his code.

## How was this patch tested?

Existing tests

Author: Joseph K. Bradley <[email protected]>

Closes #17501 from jkbradley/SPARK-20183.
  • Loading branch information
Seth Hendrickson authored and jkbradley committed Apr 5, 2017
1 parent 295747e commit a59759e
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
MLTestingUtils.testArbitrarilyScaledWeights[LinearSVCModel, LinearSVC](
dataset.as[LabeledPoint], estimator, modelEquals)
MLTestingUtils.testOutliersWithSmallWeights[LinearSVCModel, LinearSVC](
dataset.as[LabeledPoint], estimator, 2, modelEquals)
dataset.as[LabeledPoint], estimator, 2, modelEquals, outlierRatio = 3)
MLTestingUtils.testOversamplingVsWeighting[LinearSVCModel, LinearSVC](
dataset.as[LabeledPoint], estimator, modelEquals, 42L)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1874,7 +1874,7 @@ class LogisticRegressionSuite
MLTestingUtils.testArbitrarilyScaledWeights[LogisticRegressionModel, LogisticRegression](
dataset.as[LabeledPoint], estimator, modelEquals)
MLTestingUtils.testOutliersWithSmallWeights[LogisticRegressionModel, LogisticRegression](
dataset.as[LabeledPoint], estimator, numClasses, modelEquals)
dataset.as[LabeledPoint], estimator, numClasses, modelEquals, outlierRatio = 3)
MLTestingUtils.testOversamplingVsWeighting[LogisticRegressionModel, LogisticRegression](
dataset.as[LabeledPoint], estimator, modelEquals, seed)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
MLTestingUtils.testArbitrarilyScaledWeights[NaiveBayesModel, NaiveBayes](
dataset.as[LabeledPoint], estimatorNoSmoothing, modelEquals)
MLTestingUtils.testOutliersWithSmallWeights[NaiveBayesModel, NaiveBayes](
dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses, modelEquals)
dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses, modelEquals, outlierRatio = 3)
MLTestingUtils.testOversamplingVsWeighting[NaiveBayesModel, NaiveBayes](
dataset.as[LabeledPoint], estimatorWithSmoothing, modelEquals, seed)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,8 @@ class LinearRegressionSuite
MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression](
datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals)
MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression](
datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses, modelEquals)
datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses, modelEquals,
outlierRatio = 3)
MLTestingUtils.testOversamplingVsWeighting[LinearRegressionModel, LinearRegression](
datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals, seed)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,13 @@ object MLTestingUtils extends SparkFunSuite {
data: Dataset[LabeledPoint],
estimator: E with HasWeightCol,
numClasses: Int,
modelEquals: (M, M) => Unit): Unit = {
modelEquals: (M, M) => Unit,
outlierRatio: Int): Unit = {
import data.sqlContext.implicits._
val outlierDS = data.withColumn("weight", lit(1.0)).as[Instance].flatMap {
case Instance(l, w, f) =>
val outlierLabel = if (numClasses == 0) -l else numClasses - l - 1
List.fill(3)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f))
List.fill(outlierRatio)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f))
}
val trueModel = estimator.set(estimator.weightCol, "").fit(data)
val outlierModel = estimator.set(estimator.weightCol, "weight").fit(outlierDS)
Expand Down

0 comments on commit a59759e

Please sign in to comment.