Skip to content

Commit

Permalink
Fixed bug from last commit (sorting paramMap by parameter names in toString). Fixed bug in persisting logreg data. Added threshold_internal to logreg for faster test-time prediction (avoiding map lookup).
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbradley committed Feb 5, 2015
1 parent 601e792 commit 0617d61
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ private[classification] trait LogisticRegressionParams extends ClassifierParams

/**
* Logistic regression.
* Currently, this class only supports binary classification.
*/
class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressionModel]
with LogisticRegressionParams {
Expand All @@ -71,7 +72,8 @@ class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressi
val oldDataset = dataset.map { case LabeledPoint(label: Double, features: Vector, weight) =>
org.apache.spark.mllib.regression.LabeledPoint(label, features)
}
val handlePersistence = oldDataset.getStorageLevel == StorageLevel.NONE
// If dataset is persisted, do not persist oldDataset.
val handlePersistence = dataset.getStorageLevel == StorageLevel.NONE
if (handlePersistence) {
oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
}
Expand All @@ -84,6 +86,7 @@ class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressi
if (handlePersistence) {
oldDataset.unpersist()
}
lrm.setThreshold(paramMap(threshold))
lrm
}
}
Expand All @@ -103,9 +106,15 @@ class LogisticRegressionModel private[ml] (
with ProbabilisticClassificationModel
with LogisticRegressionParams {

def setThreshold(value: Double): this.type = set(threshold, value)
/**
 * Sets the decision threshold.
 *
 * The value is written to two places: the `threshold_internal` field and the
 * `threshold` param (via `set`).
 *
 * @param value the new threshold
 * @return the result of `set(threshold, value)`, typed as `this.type`
 */
def setThreshold(value: Double): this.type = {
// Mirror the value into the plain field before recording it as a param.
this.threshold_internal = value
set(threshold, value)
}
def setScoreCol(value: String): this.type = set(scoreCol, value)

/** Store for faster test-time prediction. */
private var threshold_internal: Double = this.getThreshold

/**
 * Maps a feature vector to `BLAS.dot(features, weights) + intercept`.
 * NOTE(review): `weights` and `intercept` are defined outside the visible hunk.
 */
private val margin: Vector => Double = (features) => {
BLAS.dot(features, weights) + intercept
}
Expand All @@ -121,11 +130,8 @@ class LogisticRegressionModel private[ml] (
val scoreFunction = udf { v: Vector =>
val margin = BLAS.dot(v, weights)
1.0 / (1.0 + math.exp(-margin))
}
val t = map(threshold)
val predictFunction = udf { score: Double =>
if (score > t) 1.0 else 0.0
}
val t = threshold_internal
val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 }
dataset
.select($"*", scoreFunction(col(map(featuresCol))).as(map(scoreCol)))
.select($"*", predictFunction(col(map(scoreCol))).as(map(predictionCol)))
Expand All @@ -138,7 +144,7 @@ class LogisticRegressionModel private[ml] (
* The behavior of this can be adjusted using [[threshold]].
*/
override def predict(features: Vector): Double = {
if (score(features) > paramMap(threshold)) 1 else 0
if (score(features) > threshold_internal) 1 else 0
}

override def predictProbabilities(features: Vector): Vector = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
def copy: ParamMap = new ParamMap(map.clone())

override def toString: String = {
map.toSeq.sorted.map { case (param, value) =>
map.toSeq.sortBy(_._1.name).map { case (param, value) =>
s"\t${param.parent.uid}-${param.name}: $value"
}.mkString("{\n", ",\n", "\n}")
}
Expand Down

0 comments on commit 0617d61

Please sign in to comment.