[SPARK-20348] [ML] Support squared hinge loss (L2 loss) for LinearSVC #17645

Closed · wants to merge 1 commit
mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -17,6 +17,8 @@

package org.apache.spark.ml.classification

import java.util.Locale

import scala.collection.mutable

import breeze.linalg.{DenseVector => BDV}
@@ -42,15 +44,35 @@ import org.apache.spark.sql.functions.{col, lit}
/** Params for linear SVM Classifier. */
private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
-  with HasThreshold with HasAggregationDepth
+  with HasThreshold with HasAggregationDepth {

/**
* Specifies the loss function. Currently "hinge" and "squared_hinge" are supported.
* "hinge" is the standard SVM loss (a.k.a. L1 loss) while "squared_hinge" is the square of
* the hinge loss (a.k.a. L2 loss).
*
* @see <a href="https://en.wikipedia.org/wiki/Hinge_loss">Hinge loss (Wikipedia)</a>
*
* @group param
*/
@Since("2.3.0")
final val lossFunction: Param[String] = new Param(this, "lossFunction", "Specifies the loss " +
Contributor: I'd prefer to move this out to shared params, since it can be used by other algorithms as well. Thanks.

Contributor (Author): Sure, we can do that. But I'm thinking maybe we should do an integrated refactor of the common optimization parameters some time in the future, either through shared params or another trait or abstract class.

Contributor: OK, let's leave it as is and refactor in the future. One minor issue: what about renaming it to loss? The corresponding param in sklearn.svm.LinearSVC is named loss. Thanks.

"function. hinge is the standard SVM loss while squared_hinge is the square of the hinge loss.",
(s: String) => LinearSVC.supportedLoss.contains(s.toLowerCase(Locale.ROOT)))

/** @group getParam */
@Since("2.3.0")
def getLossFunction: String = $(lossFunction)
}

/**
* :: Experimental ::
*
* <a href = "https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM">
* Linear SVM Classifier</a>
*
- * This binary classifier optimizes the Hinge Loss using the OWLQN optimizer.
+ * This binary classifier optimizes the Hinge Loss (or Squared Hinge Loss) using the
+ * OWLQN optimizer.
*
*/
@Since("2.2.0")
@@ -63,6 +85,15 @@ class LinearSVC @Since("2.2.0") (
@Since("2.2.0")
def this() = this(Identifiable.randomUID("linearsvc"))

/**
* Set the loss function. Default is "hinge".
*
* @group setParam
*/
@Since("2.3.0")
def setLossFunction(value: String): this.type = set(lossFunction, value)
setDefault(lossFunction -> "hinge")

/**
* Set the regularization parameter.
* Default is 0.0.
@@ -202,8 +233,8 @@ class LinearSVC @Since("2.2.0") (
val featuresStd = summarizer.variance.toArray.map(math.sqrt)
val regParamL2 = $(regParam)
val bcFeaturesStd = instances.context.broadcast(featuresStd)
-    val costFun = new LinearSVCCostFun(instances, $(fitIntercept),
-      $(standardization), bcFeaturesStd, regParamL2, $(aggregationDepth))
+    val costFun = new LinearSVCCostFun(instances, $(fitIntercept), $(standardization),
+      bcFeaturesStd, regParamL2, $(aggregationDepth), $(lossFunction).toLowerCase(Locale.ROOT))

def regParamL1Fun = (index: Int) => 0D
val optimizer = new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
@@ -260,6 +291,8 @@ object LinearSVC extends DefaultParamsReadable[LinearSVC] {

@Since("2.2.0")
override def load(path: String): LinearSVC = super.load(path)

private[classification] val supportedLoss = Array("hinge", "squared_hinge")
}

/**
@@ -355,15 +388,17 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
}

/**
- * LinearSVCCostFun implements Breeze's DiffFunction[T] for hinge loss function
+ * LinearSVCCostFun implements Breeze's DiffFunction[T] for the selected loss function
+ * ("hinge" or "squared_hinge").
*/
private class LinearSVCCostFun(
instances: RDD[Instance],
fitIntercept: Boolean,
standardization: Boolean,
bcFeaturesStd: Broadcast[Array[Double]],
regParamL2: Double,
-    aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
+    aggregationDepth: Int,
+    lossFunction: String) extends DiffFunction[BDV[Double]] {

override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
val coeffs = Vectors.fromBreeze(coefficients)
@@ -376,7 +411,7 @@ private class LinearSVCCostFun(
val combOp = (c1: LinearSVCAggregator, c2: LinearSVCAggregator) => c1.merge(c2)

instances.treeAggregate(
-      new LinearSVCAggregator(bcCoeffs, bcFeaturesStd, fitIntercept)
+      new LinearSVCAggregator(bcCoeffs, bcFeaturesStd, fitIntercept, lossFunction)
)(seqOp, combOp, aggregationDepth)
}

@@ -421,8 +456,9 @@ private class LinearSVCCostFun(
}

/**
- * LinearSVCAggregator computes the gradient and loss for hinge loss function, as used
- * in binary classification for instances in sparse or dense vector in an online fashion.
+ * LinearSVCAggregator computes the gradient and loss for the selected loss function
+ * ("hinge" or "squared_hinge"), as used in binary classification for instances in sparse or
+ * dense vectors in an online fashion.
*
* Two LinearSVCAggregator can be merged together to have a summary of loss and gradient of
* the corresponding joint dataset.
@@ -436,7 +472,8 @@ private class LinearSVCCostFun(
private class LinearSVCAggregator(
bcCoefficients: Broadcast[Vector],
bcFeaturesStd: Broadcast[Array[Double]],
-    fitIntercept: Boolean) extends Serializable {
+    fitIntercept: Boolean,
+    lossFunction: String) extends Serializable {

private val numFeatures: Int = bcFeaturesStd.value.length
private val numFeaturesPlusIntercept: Int = if (fitIntercept) numFeatures + 1 else numFeatures
@@ -477,16 +514,26 @@ private class LinearSVCAggregator(
sum
}
// Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x)))
// For hinge the gradient is -(2y - 1)*x; for squared_hinge it is 2*((2y - 1)*f_w(x) - 1)*(2y - 1)*x
val labelScaled = 2 * label - 1.0
    val loss = if (1.0 > labelScaled * dotProduct) {
-      weight * (1.0 - labelScaled * dotProduct)
+      val hingeLoss = 1.0 - labelScaled * dotProduct
+      lossFunction match {
+        case "hinge" => hingeLoss * weight
+        case "squared_hinge" => hingeLoss * hingeLoss * weight
+        case unexpected => throw new SparkException(
+          s"unexpected lossFunction in LinearSVCAggregator: $unexpected")
+      }
} else {
0.0
}

    if (1.0 > labelScaled * dotProduct) {
-      val gradientScale = -labelScaled * weight
+      val gradientScale = lossFunction match {
+        case "hinge" => -labelScaled * weight
+        // weight factor keeps the gradient consistent with the weighted loss above
+        case "squared_hinge" => (labelScaled * dotProduct - 1) * labelScaled * 2 * weight
+        case unexpected => throw new SparkException(
+          s"unexpected lossFunction in LinearSVCAggregator: $unexpected")
+      }
features.foreachActive { (index, value) =>
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
localGradientSumArray(index) += value * gradientScale / localFeaturesStd(index)
mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -75,12 +75,14 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
}

test("Linear SVC binary classification") {
-    val svm = new LinearSVC()
-    val model = svm.fit(smallBinaryDataset)
-    assert(model.transform(smallValidationDataset)
-      .where("prediction=label").count() > nPoints * 0.8)
-    val sparseModel = svm.fit(smallSparseBinaryDataset)
-    checkModels(model, sparseModel)
+    Array("hinge", "squared_hinge").foreach { loss =>
+      val svm = new LinearSVC().setLossFunction(loss)
+      val model = svm.fit(smallBinaryDataset)
+      assert(model.transform(smallValidationDataset)
+        .where("prediction=label").count() > nPoints * 0.8)
+      val sparseModel = svm.fit(smallSparseBinaryDataset)
+      checkModels(model, sparseModel)
+    }
}

test("Linear SVC binary classification with regularization") {
@@ -100,6 +102,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau

test("linear svc: default params") {
val lsvc = new LinearSVC()
assert(lsvc.getLossFunction === "hinge")
assert(lsvc.getRegParam === 0.0)
assert(lsvc.getMaxIter === 100)
assert(lsvc.getFitIntercept)
@@ -116,6 +119,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
model.transform(smallBinaryDataset)
.select("label", "prediction", "rawPrediction")
.collect()
assert(model.getLossFunction === "hinge")
assert(model.getThreshold === 0.0)
assert(model.getFeaturesCol === "features")
assert(model.getPredictionCol === "prediction")
@@ -125,6 +129,14 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
assert(model.numFeatures === 2)

MLTestingUtils.checkCopyAndUids(lsvc, model)

withClue("lossFunction should be case-insensitive") {
lsvc.setLossFunction("HINGE")
lsvc.setLossFunction("Squared_hinge")
intercept[IllegalArgumentException] {
val model = lsvc.setLossFunction("hing")
}
}
}

test("linear svc doesn't fit intercept when fitIntercept is off") {
@@ -140,7 +152,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("sparse coefficients in SVCAggregator") {
val bcCoefficients = spark.sparkContext.broadcast(Vectors.sparse(2, Array(0), Array(1.0)))
val bcFeaturesStd = spark.sparkContext.broadcast(Array(1.0))
-    val agg = new LinearSVCAggregator(bcCoefficients, bcFeaturesStd, true)
+    val agg = new LinearSVCAggregator(bcCoefficients, bcFeaturesStd, true, "hinge")
val thrown = withClue("LinearSVCAggregator cannot handle sparse coefficients") {
intercept[IllegalArgumentException] {
agg.add(Instance(1.0, 1.0, Vectors.dense(1.0)))
@@ -168,7 +180,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
dataset.as[LabeledPoint], estimator, modelEquals, 42L)
}

test("linearSVC comparison with R e1071 and scikit-learn") {
test("linearSVC with hinge loss comparison with R e1071 and scikit-learn (liblinear)") {
val trainer1 = new LinearSVC()
.setRegParam(0.00002) // set regParam = 2.0 / datasize / c
.setMaxIter(200)
@@ -223,6 +235,38 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
assert(model1.coefficients ~== coefficientsSK relTol 4E-3)
}

test("linearSVC with squared_hinge loss comparison with scikit-learn (liblinear)") {
val linearSVC = new LinearSVC()
.setLossFunction("squared_hinge")
.setRegParam(2.0 / 10 / 1000) // set regParam = 2.0 / datasize / c
.setMaxIter(80)
.setTol(1e-4)
val model = linearSVC.fit(binaryDataset.limit(1000))

/*
      Use the following Python code to load the data and train the model using the scikit-learn package.

import numpy as np
from sklearn import svm
f = open("path/spark/assembly/target/tmp/LinearSVC/binaryDataset/part-00000")
data = np.loadtxt(f, delimiter=",")[:1000]
X = data[:, 1:] # select columns 1 through end
y = data[:, 0] # select column 0 as label
clf = svm.LinearSVC(fit_intercept=True, C=10, loss='squared_hinge', tol=1e-4, random_state=42)
m = clf.fit(X, y)
      print(m.coef_)
      print(m.intercept_)

[[ 2.85136074 6.25310456 9.00668415 12.17750981]]
[ 2.93419973]
*/

val coefficientsSK = Vectors.dense(2.85136074, 6.25310456, 9.00668415, 12.17750981)
val interceptSK = 2.93419973
assert(model.intercept ~== interceptSK relTol 2E-2)
assert(model.coefficients ~== coefficientsSK relTol 2E-2)
}

test("read/write: SVM") {
def checkModelData(model: LinearSVCModel, model2: LinearSVCModel): Unit = {
assert(model.intercept === model2.intercept)
@@ -238,6 +282,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
object LinearSVCSuite {

val allParamSettings: Map[String, Any] = Map(
"lossFunction" -> "squared_hinge",
"regParam" -> 0.01,
"maxIter" -> 2, // intentionally small
"fitIntercept" -> true,