Commit 45164ff

made case class to deal with model selector metadata (#39)

leahmcguire authored Aug 9, 2018
1 parent 308fc1a commit 45164ff
Showing 33 changed files with 1,032 additions and 556 deletions.
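At a high level, this commit replaces free-form `String` evaluator names with the typed `EvalMetric` enum hierarchy, adding an `OpEvaluatorNames.Custom` case class so user-defined names survive the round trip through model selector metadata. A minimal sketch of the resulting API, based only on the diffs below:

```scala
import com.salesforce.op.evaluators.{BinaryClassEvalMetrics, EvalMetric, OpEvaluatorNames}

// Built-in metrics are now enum entries instead of raw strings
val auROC: EvalMetric = BinaryClassEvalMetrics.AuROC

// User-defined evaluators carry a Custom case class instead of a free-form name
val custom: EvalMetric = OpEvaluatorNames.Custom("myEval", "my evaluation metric")

// Parsing a stored name never throws: unknown names fall back to Custom
val parsed: EvalMetric = EvalMetric.withNameInsensitive("auROC") // BinaryClassEvalMetrics.AuROC
```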
307 changes: 84 additions & 223 deletions core/src/main/scala/com/salesforce/op/ModelInsights.scala

Large diffs are not rendered by default.

43 changes: 22 additions & 21 deletions core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala
@@ -36,6 +36,8 @@ import com.salesforce.op.utils.json.JsonUtils
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.Dataset

import scala.util.Try

/**
* Just a handy factory for evaluators
*/
@@ -57,7 +59,7 @@ object Evaluators {
*/
def auROC(): OpBinaryClassificationEvaluator =
new OpBinaryClassificationEvaluator(
-      name = BinaryClassEvalMetrics.AuROC.humanFriendlyName, isLargerBetter = true) {
+      name = BinaryClassEvalMetrics.AuROC, isLargerBetter = true) {
override def evaluate(dataset: Dataset[_]): Double =
getBinaryEvaluatorMetric(BinaryClassEvalMetrics.AuROC, dataset)
}
@@ -66,7 +68,7 @@ object Evaluators {
* Area under Precision/Recall curve
*/
def auPR(): OpBinaryClassificationEvaluator =
-    new OpBinaryClassificationEvaluator(name = BinaryClassEvalMetrics.AuPR.humanFriendlyName, isLargerBetter = true) {
+    new OpBinaryClassificationEvaluator(name = BinaryClassEvalMetrics.AuPR, isLargerBetter = true) {
override def evaluate(dataset: Dataset[_]): Double =
getBinaryEvaluatorMetric(BinaryClassEvalMetrics.AuPR, dataset)
}
@@ -76,7 +78,7 @@ object Evaluators {
*/
def precision(): OpBinaryClassificationEvaluator =
new OpBinaryClassificationEvaluator(
-      name = MultiClassEvalMetrics.Precision.humanFriendlyName, isLargerBetter = true) {
+      name = MultiClassEvalMetrics.Precision, isLargerBetter = true) {
override def evaluate(dataset: Dataset[_]): Double = {
import dataset.sparkSession.implicits._
new MulticlassMetrics(dataset.select(getPredictionCol, getLabelCol).as[(Double, Double)].rdd).precision(1.0)
@@ -88,7 +90,7 @@ object Evaluators {
*/
def recall(): OpBinaryClassificationEvaluator =
new OpBinaryClassificationEvaluator(
-      name = MultiClassEvalMetrics.Recall.humanFriendlyName, isLargerBetter = true) {
+      name = MultiClassEvalMetrics.Recall, isLargerBetter = true) {
override def evaluate(dataset: Dataset[_]): Double = {
import dataset.sparkSession.implicits._
new MulticlassMetrics(dataset.select(getPredictionCol, getLabelCol).as[(Double, Double)].rdd).recall(1.0)
@@ -99,7 +101,7 @@ object Evaluators {
* F1 score
*/
def f1(): OpBinaryClassificationEvaluator =
-    new OpBinaryClassificationEvaluator(name = MultiClassEvalMetrics.F1.humanFriendlyName, isLargerBetter = true) {
+    new OpBinaryClassificationEvaluator(name = MultiClassEvalMetrics.F1, isLargerBetter = true) {
override def evaluate(dataset: Dataset[_]): Double = {
import dataset.sparkSession.implicits._
new MulticlassMetrics(
@@ -112,7 +114,7 @@ object Evaluators {
*/
def error(): OpBinaryClassificationEvaluator =
new OpBinaryClassificationEvaluator(
-      name = MultiClassEvalMetrics.Error.humanFriendlyName, isLargerBetter = false) {
+      name = MultiClassEvalMetrics.Error, isLargerBetter = false) {
override def evaluate(dataset: Dataset[_]): Double =
1.0 - getMultiEvaluatorMetric(MultiClassEvalMetrics.Error, dataset)
}
@@ -135,15 +137,15 @@ object Evaluators {
new OpBinaryClassificationEvaluatorBase[SingleMetric](
uid = UID[OpBinaryClassificationEvaluatorBase[SingleMetric]]
) {
-      override val name: String = metricName
+      override val name: EvalMetric = OpEvaluatorNames.Custom(metricName, metricName)
override val isLargerBetter: Boolean = islbt
override def getDefaultMetric: SingleMetric => Double = _.value
override def evaluateAll(dataset: Dataset[_]): SingleMetric = {
import dataset.sparkSession.implicits._
val ds = dataset.select(getLabelCol, getRawPredictionCol, getProbabilityCol, getPredictionCol)
.as[(Double, OPVector#Value, OPVector#Value, Double)]
val metric = evaluateFn(ds)
-        SingleMetric(name, metric)
+        SingleMetric(name.humanFriendlyName, metric)
}
}
}
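The factory above wires a user-supplied `evaluateFn` into an anonymous evaluator whose `name` is now an `OpEvaluatorNames.Custom` entry. A sketch of how it might be called; the factory is assumed to be exposed as `Evaluators.BinaryClassification.custom`, and only `metricName`, `evaluateFn`, and the dataset tuple type are visible in this hunk:

```scala
import com.salesforce.op.evaluators.Evaluators

// Sketch: a custom error-rate metric, where lower is better
val customError = Evaluators.BinaryClassification.custom(
  metricName = "myErrorRate",
  isLargerBetter = false,
  evaluateFn = ds => {
    // ds: Dataset[(label, rawPrediction, probability, prediction)] per the hunk above
    val total = ds.count().toDouble
    val wrong = ds.filter { case (label, _, _, pred) => label != pred }.count().toDouble
    if (total == 0.0) 0.0 else wrong / total
  }
)
// customError.name == OpEvaluatorNames.Custom("myErrorRate", "myErrorRate")
```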
@@ -166,7 +168,7 @@ object Evaluators {
*/
def precision(): OpMultiClassificationEvaluator =
new OpMultiClassificationEvaluator(
-      name = MultiClassEvalMetrics.Precision.humanFriendlyName, isLargerBetter = true) {
+      name = MultiClassEvalMetrics.Precision, isLargerBetter = true) {
override def evaluate(dataset: Dataset[_]): Double =
getMultiEvaluatorMetric(MultiClassEvalMetrics.Precision, dataset)
}
@@ -175,7 +177,7 @@ object Evaluators {
* Weighted Recall
*/
def recall(): OpMultiClassificationEvaluator =
-    new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.Recall.humanFriendlyName, isLargerBetter = true) {
+    new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.Recall, isLargerBetter = true) {
override def evaluate(dataset: Dataset[_]): Double =
getMultiEvaluatorMetric(MultiClassEvalMetrics.Recall, dataset)
}
@@ -184,7 +186,7 @@ object Evaluators {
* F1 Score
*/
def f1(): OpMultiClassificationEvaluator =
-    new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.F1.humanFriendlyName, isLargerBetter = true) {
+    new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.F1, isLargerBetter = true) {
override def evaluate(dataset: Dataset[_]): Double =
getMultiEvaluatorMetric(MultiClassEvalMetrics.F1, dataset)
}
@@ -193,7 +195,7 @@ object Evaluators {
* Prediction Error
*/
def error(): OpMultiClassificationEvaluator =
-    new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.Error.humanFriendlyName, isLargerBetter = false) {
+    new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.Error, isLargerBetter = false) {
override def evaluate(dataset: Dataset[_]): Double =
1.0 - getMultiEvaluatorMetric(MultiClassEvalMetrics.Error, dataset)
}
@@ -216,7 +218,7 @@ object Evaluators {
new OpMultiClassificationEvaluatorBase[SingleMetric](
uid = UID[OpMultiClassificationEvaluatorBase[SingleMetric]]
) {
-      override val name: String = metricName
+      override val name: EvalMetric = OpEvaluatorNames.Custom(metricName, metricName)
override val isLargerBetter: Boolean = islbt

override def getDefaultMetric: SingleMetric => Double = _.value
@@ -227,7 +229,7 @@ object Evaluators {
.as[(Double, OPVector#Value, OPVector#Value, Double)]
try {
val metric = evaluateFn(ds)
-          SingleMetric(name, metric)
+          SingleMetric(name.humanFriendlyName, metric)
} catch {
case iae: IllegalArgumentException =>
val size = dataset.count
@@ -257,7 +259,7 @@ object Evaluators {
*/
def mse(): OpRegressionEvaluator =
new OpRegressionEvaluator(
-      name = RegressionEvalMetrics.MeanSquaredError.humanFriendlyName, isLargerBetter = false) {
+      name = RegressionEvalMetrics.MeanSquaredError, isLargerBetter = false) {
override def evaluate(dataset: Dataset[_]): Double =
getRegEvaluatorMetric(RegressionEvalMetrics.MeanSquaredError, dataset)
}
@@ -267,7 +269,7 @@ object Evaluators {
*/
def mae(): OpRegressionEvaluator =
new OpRegressionEvaluator(
-      name = RegressionEvalMetrics.MeanAbsoluteError.humanFriendlyName, isLargerBetter = false) {
+      name = RegressionEvalMetrics.MeanAbsoluteError, isLargerBetter = false) {
override def evaluate(dataset: Dataset[_]): Double =
getRegEvaluatorMetric(RegressionEvalMetrics.MeanAbsoluteError, dataset)
}
@@ -276,7 +278,7 @@ object Evaluators {
* R2
*/
def r2(): OpRegressionEvaluator =
-    new OpRegressionEvaluator(name = RegressionEvalMetrics.R2.humanFriendlyName, isLargerBetter = true) {
+    new OpRegressionEvaluator(name = RegressionEvalMetrics.R2, isLargerBetter = true) {
override def evaluate(dataset: Dataset[_]): Double =
getRegEvaluatorMetric(RegressionEvalMetrics.R2, dataset)
}
@@ -286,7 +288,7 @@ object Evaluators {
*/
def rmse(): OpRegressionEvaluator =
new OpRegressionEvaluator(
-      name = RegressionEvalMetrics.RootMeanSquaredError.humanFriendlyName, isLargerBetter = false) {
+      name = RegressionEvalMetrics.RootMeanSquaredError, isLargerBetter = false) {
override def evaluate(dataset: Dataset[_]): Double =
getRegEvaluatorMetric(RegressionEvalMetrics.RootMeanSquaredError, dataset)
}
@@ -309,7 +311,7 @@ object Evaluators {
new OpRegressionEvaluatorBase[SingleMetric](
uid = UID[OpRegressionEvaluatorBase[SingleMetric]]
) {
-      override val name: String = metricName
+      override val name: EvalMetric = OpEvaluatorNames.Custom(metricName, metricName)
override val isLargerBetter: Boolean = islbt

override def getDefaultMetric: SingleMetric => Double = _.value
@@ -318,7 +320,7 @@ object Evaluators {
import dataset.sparkSession.implicits._
val ds = dataset.select(getLabelCol, getPredictionCol).as[(Double, Double)]
val metric = evaluateFn(ds)
-        SingleMetric(name, metric)
+        SingleMetric(name.humanFriendlyName, metric)
}
}
}
@@ -349,4 +351,3 @@ case class MultiMetrics(metrics: Map[String, EvaluationMetrics]) extends EvaluationMetrics {
}
override def toString: String = JsonUtils.toJsonString(this.toMap, pretty = true)
}

core/src/main/scala/com/salesforce/op/evaluators/OpBinaryClassificationEvaluator.scala
@@ -53,7 +53,7 @@ import org.slf4j.LoggerFactory

private[op] class OpBinaryClassificationEvaluator
(
-  override val name: String = OpEvaluatorNames.binary,
+  override val name: EvalMetric = OpEvaluatorNames.Binary,
override val isLargerBetter: Boolean = true,
override val uid: String = UID[OpBinaryClassificationEvaluator],
val numBins: Int = 100
@@ -179,7 +179,6 @@ case class BinaryClassificationMetrics
@JsonDeserialize(contentAs = classOf[java.lang.Double])
falsePositiveRateByThreshold: Seq[Double]
) extends EvaluationMetrics {

def rocCurve: Seq[(Double, Double)] = recallByThreshold.zip(falsePositiveRateByThreshold)
def prCurve: Seq[(Double, Double)] = precisionByThreshold.zip(recallByThreshold)
}
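The two curve helpers are simple zips over the aligned per-threshold sequences. For example, with illustrative values (not from a real run):

```scala
val precisionByThreshold = Seq(1.0, 0.9, 0.75)
val recallByThreshold = Seq(0.2, 0.6, 1.0)
val falsePositiveRateByThreshold = Seq(0.0, 0.1, 0.5)

// rocCurve pairs recall (true positive rate) with false positive rate per threshold
recallByThreshold.zip(falsePositiveRateByThreshold)  // Seq((0.2,0.0), (0.6,0.1), (1.0,0.5))

// prCurve pairs precision with recall per threshold
precisionByThreshold.zip(recallByThreshold)          // Seq((1.0,0.2), (0.9,0.6), (0.75,1.0))
```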
core/src/main/scala/com/salesforce/op/evaluators/OpEvaluatorBase.scala
@@ -42,6 +42,8 @@ import org.apache.spark.ml.param._
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.Metadata

import scala.util.Try


/**
* Trait for labelCol param
@@ -123,15 +125,13 @@ trait EvaluationMetrics extends JsonLike {
* @return a map from metric name to metric value
*/
def toMap: Map[String, Any] = JsonUtils.toMap(JsonUtils.toJsonTree(this))

/**
* Convert metrics into metadata for saving
* @return metadata
*/
def toMetadata: Metadata = this.toMap.toMetadata
}
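The `toMap`/`toMetadata` pair is the serialization seam this commit cares about: every `EvaluationMetrics` case class serializes through JSON into a `Map`, which an implicit enrichment (assumed here to come from the project's metadata utilities) turns into Spark `Metadata` for storage with the model. A sketch using `SingleMetric`, which the custom evaluators above construct; the map keys are assumptions based on the field names:

```scala
// SingleMetric(name, value) is built by the custom evaluators in Evaluators.scala
val metric = SingleMetric("myErrorRate", 0.042)

val asMap: Map[String, Any] = metric.toMap  // e.g. Map("name" -> "myErrorRate", "value" -> 0.042)
val meta: Metadata = metric.toMetadata      // Spark Metadata, ready to save with the model
```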


/**
* Base Interface for OpEvaluator to be used in Evaluator creation. Can be used for both OP and spark
* eval (so with workflows and cross validation).
@@ -143,7 +143,7 @@ abstract class OpEvaluatorBase[T <: EvaluationMetrics] extends Evaluator
/**
* Name of evaluator
*/
-  val name: String = "Eval"
+  val name: EvalMetric

/**
* Evaluate function that returns a class or value with the calculated metric value(s).
@@ -271,7 +271,7 @@ abstract class OpRegressionEvaluatorBase[T <: EvaluationMetrics]
/**
* Eval metric
*/
-trait EvalMetric extends Serializable {
+trait EvalMetric extends EnumEntry with Serializable {
/**
* Spark metric name
*/
@@ -281,6 +281,21 @@ trait EvalMetric extends Serializable {
* Human friendly metric name
*/
def humanFriendlyName: String

}

/**
* Eval metric companion object
*/
object EvalMetric {

def withNameInsensitive(name: String): EvalMetric = {
BinaryClassEvalMetrics.withNameInsensitiveOption(name)
.orElse(MultiClassEvalMetrics.withNameInsensitiveOption(name))
.orElse(RegressionEvalMetrics.withNameInsensitiveOption(name))
.orElse(OpEvaluatorNames.withNameInsensitiveOption(name))
.getOrElse(OpEvaluatorNames.Custom(name, name))
}
}
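The companion's lookup tries each enum in turn and falls back to a `Custom` entry rather than throwing, which keeps deserialization of stored metadata safe. A sketch of the behavior; entry names are the case object names, matched case-insensitively:

```scala
EvalMetric.withNameInsensitive("auroc")        // BinaryClassEvalMetrics.AuROC
EvalMetric.withNameInsensitive("binary")       // OpEvaluatorNames.Binary
EvalMetric.withNameInsensitive("no-such-one")  // OpEvaluatorNames.Custom("no-such-one", "no-such-one")
```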

/**
@@ -290,7 +305,7 @@ sealed abstract class ClassificationEvalMetric
(
val sparkEntryName: String,
val humanFriendlyName: String
-) extends EnumEntry with EvalMetric
+) extends EvalMetric

/**
* Binary Classification Metrics
@@ -302,7 +317,11 @@ object BinaryClassEvalMetrics extends Enum[ClassificationEvalMetric] {
case object F1 extends ClassificationEvalMetric("f1", "f1")
case object Error extends ClassificationEvalMetric("accuracy", "error")
case object AuROC extends ClassificationEvalMetric("areaUnderROC", "area under ROC")
-  case object AuPR extends ClassificationEvalMetric("areaUnderPR", "area under PR")
+  case object AuPR extends ClassificationEvalMetric("areaUnderPR", "area under precision-recall")
+  case object TP extends ClassificationEvalMetric("TP", "true positive")
+  case object TN extends ClassificationEvalMetric("TN", "true negative")
+  case object FP extends ClassificationEvalMetric("FP", "false positive")
+  case object FN extends ClassificationEvalMetric("FN", "false negative")
}

/**
@@ -325,8 +344,11 @@ sealed abstract class RegressionEvalMetric
(
val sparkEntryName: String,
val humanFriendlyName: String
-) extends EnumEntry with EvalMetric
+) extends EvalMetric

/**
* Regression Metrics
*/
object RegressionEvalMetrics extends Enum[RegressionEvalMetric] {
val values: Seq[RegressionEvalMetric] = findValues
case object RootMeanSquaredError extends RegressionEvalMetric("rmse", "root mean square error")
@@ -335,11 +357,34 @@ object RegressionEvalMetrics extends Enum[RegressionEvalMetric] {
case object MeanAbsoluteError extends RegressionEvalMetric("mae", "mean absolute error")
}


/**
* GeneralMetrics
*/
sealed abstract class OpEvaluatorNames
(
val sparkEntryName: String,
val humanFriendlyName: String
) extends EvalMetric

/**
* Contains evaluator names used in logging
*/
-case object OpEvaluatorNames {
-  val binary = "binEval"
-  val multi = "multiEval"
-  val regression = "regEval"
+object OpEvaluatorNames extends Enum[OpEvaluatorNames] {
+  val values: Seq[OpEvaluatorNames] = findValues
+
+  case object Binary extends OpEvaluatorNames("binEval", "binary evaluation metrics")
+
+  case object Multi extends OpEvaluatorNames("multiEval", "multiclass evaluation metrics")
+
+  case object Regression extends OpEvaluatorNames("regEval", "regression evaluation metrics")
+
+  case class Custom(name: String, humanName: String) extends OpEvaluatorNames(name, humanName) {
+    override def entryName: String = name.toLowerCase
+  }
+
+  override def withName(name: String): OpEvaluatorNames = Try(super.withName(name)).getOrElse(Custom(name, name))
+
+  override def withNameInsensitive(name: String): OpEvaluatorNames = super.withNameInsensitiveOption(name)
+    .getOrElse(Custom(name, name))
}
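Overriding `withName`/`withNameInsensitive` to absorb unknown names into `Custom` is what lets metadata written before this change, with free-form string names, load back without errors. For example:

```scala
OpEvaluatorNames.withName("Binary")                 // OpEvaluatorNames.Binary
OpEvaluatorNames.withName("someOldEvalName")        // Custom("someOldEvalName", "someOldEvalName"), no exception
OpEvaluatorNames.withNameInsensitive("REGRESSION")  // OpEvaluatorNames.Regression
```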
core/src/main/scala/com/salesforce/op/evaluators/OpMultiClassificationEvaluator.scala
@@ -59,7 +59,7 @@ import scala.collection.mutable
*/
private[op] class OpMultiClassificationEvaluator
(
-  override val name: String = OpEvaluatorNames.multi,
+  override val name: EvalMetric = OpEvaluatorNames.Multi,
override val isLargerBetter: Boolean = true,
override val uid: String = UID[OpMultiClassificationEvaluator]
) extends OpMultiClassificationEvaluatorBase[MultiClassificationMetrics](uid) {
@@ -307,3 +307,4 @@ case class ThresholdMetrics
@JsonDeserialize(keyAs = classOf[java.lang.Integer])
noPredictionCounts: Map[Int, Seq[Long]]
) extends EvaluationMetrics

core/src/main/scala/com/salesforce/op/evaluators/OpRegressionEvaluator.scala
@@ -48,7 +48,7 @@ import org.slf4j.LoggerFactory

private[op] class OpRegressionEvaluator
(
-  override val name: String = OpEvaluatorNames.regression,
+  override val name: EvalMetric = OpEvaluatorNames.Regression,
override val isLargerBetter: Boolean = false,
override val uid: String = UID[OpRegressionEvaluator]
) extends OpRegressionEvaluatorBase[RegressionMetrics](uid) {
core/src/main/scala/com/salesforce/op/stages/impl/package.scala
@@ -30,7 +30,9 @@

package com.salesforce.op.stages

import com.salesforce.op.utils.json.JsonLike
import enumeratum.EnumEntry
import org.apache.spark.sql.types.Metadata

package object impl {

@@ -39,4 +41,7 @@ package object impl {
*/
trait ModelsToTry extends EnumEntry with Serializable

trait MetadataLike {
def toMetadata(): Metadata
}
}
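The new `MetadataLike` trait gives the model selector's case classes (the subject of this commit) a uniform hook for serializing themselves into Spark `Metadata`. A hedged sketch of an implementer; the case class and its fields are hypothetical:

```scala
import org.apache.spark.sql.types.{Metadata, MetadataBuilder}

// Hypothetical case class implementing the new trait
case class SelectedModelInfo(modelName: String, metricValue: Double) extends MetadataLike {
  def toMetadata(): Metadata = new MetadataBuilder()
    .putString("modelName", modelName)
    .putDouble("metricValue", metricValue)
    .build()
}

SelectedModelInfo("OpLogisticRegression", 0.87).toMetadata()
```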