Skip to content

Commit

Permalink
Addressing reviewers comments: Scala style
Browse files Browse the repository at this point in the history
  • Loading branch information
avulanov committed Sep 15, 2014
1 parent cf4222b commit 517a594
Showing 1 changed file with 34 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.SparkContext._
*/
class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {

private lazy val numDocs: Long = predictionAndLabels.count
private lazy val numDocs: Long = predictionAndLabels.count()

private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>
labels}.distinct.count
Expand All @@ -36,59 +36,68 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
* (for equal sets of labels)
*/
lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) =>
predictions == labels}.count.toDouble / numDocs
predictions == labels
}.count().toDouble / numDocs

/**
* Returns accuracy
*/
lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) =>
labels.intersect(predictions).size.toDouble / labels.union(predictions).size}.sum / numDocs
labels.intersect(predictions).size.toDouble / labels.union(predictions).size
}.sum / numDocs

/**
* Returns Hamming-loss
*/
lazy val hammingLoss: Double = (predictionAndLabels.map { case (predictions, labels) =>
labels.diff(predictions).size + predictions.diff(labels).size}.
sum).toDouble / (numDocs * numLabels)
lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) =>
labels.size + predictions.size - 2 * labels.intersect(predictions).size
}.sum / (numDocs * numLabels)

/**
* Returns document-based precision averaged by the number of documents
*/
lazy val precision: Double = (predictionAndLabels.map { case (predictions, labels) =>
lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) =>
if (predictions.size > 0) {
predictions.intersect(labels).size.toDouble / predictions.size
} else 0
}.sum) / numDocs
} else {
0
}
}.sum / numDocs

/**
* Returns document-based recall averaged by the number of documents
*/
lazy val recall: Double = (predictionAndLabels.map { case (predictions, labels) =>
labels.intersect(predictions).size.toDouble / labels.size}.sum) / numDocs
lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) =>
labels.intersect(predictions).size.toDouble / labels.size
}.sum / numDocs

/**
* Returns document-based f1-measure averaged by the number of documents
*/
lazy val f1Measure: Double = (predictionAndLabels.map { case (predictions, labels) =>
2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)}.sum) / numDocs
lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) =>
2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)
}.sum / numDocs


private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
predictions.intersect(labels).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
predictions.intersect(labels)
}.countByValue()

private lazy val fpPerClass = predictionAndLabels.flatMap { case(predictions, labels) =>
predictions.diff(labels).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
private lazy val fpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
predictions.diff(labels)
}.countByValue()

private lazy val fnPerClass = predictionAndLabels.flatMap{ case(predictions, labels) =>
labels.diff(predictions).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
private lazy val fnPerClass = predictionAndLabels.flatMap { case(predictions, labels) =>
labels.diff(predictions)
}.countByValue()

/**
* Returns precision for a given label (category)
* @param label the label.
*/
def precision(label: Double) = {
val tp = tpPerClass(label)
val fp = fpPerClass.getOrElse(label, 0)
val fp = fpPerClass.getOrElse(label, 0L)
if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
}

Expand All @@ -98,7 +107,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
*/
def recall(label: Double) = {
val tp = tpPerClass(label)
val fn = fnPerClass.getOrElse(label, 0)
val fn = fnPerClass.getOrElse(label, 0L)
if (tp + fn == 0) 0 else tp.toDouble / (tp + fn)
}

Expand All @@ -112,16 +121,16 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
if((p + r) == 0) 0 else 2 * p * r / (p + r)
}

private lazy val sumTp = tpPerClass.foldLeft(0L){ case (sum, (_, tp)) => sum + tp}
private lazy val sumFpClass = fpPerClass.foldLeft(0L){ case (sum, (_, fp)) => sum + fp}
private lazy val sumFnClass = fnPerClass.foldLeft(0L){ case (sum, (_, fn)) => sum + fn}
private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp }
private lazy val sumFpClass = fpPerClass.foldLeft(0L) { case (sum, (_, fp)) => sum + fp }
private lazy val sumFnClass = fnPerClass.foldLeft(0L) { case (sum, (_, fn)) => sum + fn }

/**
* Returns micro-averaged label-based precision
* (equals to micro-averaged document-based precision)
*/
lazy val microPrecision = {
val sumFp = fpPerClass.foldLeft(0L){ case(sumFp, (_, fp)) => sumFp + fp}
val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp}
sumTp.toDouble / (sumTp + sumFp)
}

Expand All @@ -130,7 +139,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
* (equals to micro-averaged document-based recall)
*/
lazy val microRecall = {
val sumFn = fnPerClass.foldLeft(0.0){ case(sumFn, (_, fn)) => sumFn + fn}
val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn}
sumTp.toDouble / (sumTp + sumFn)
}

Expand Down

0 comments on commit 517a594

Please sign in to comment.