diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index a32302bf5dfc8..116f0f6507852 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -57,7 +57,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam final val family: Param[String] = new Param(this, "family", "The name of family which is a description of the error distribution to be used in the " + s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.", - ParamValidators.inArray[String](supportedFamilyNames.toArray)) + (value: String) => supportedFamilyNames.contains(value.toLowerCase)) /** @group getParam */ @Since("2.0.0") @@ -74,7 +74,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam final val link: Param[String] = new Param(this, "link", "The name of link function " + "which provides the relationship between the linear predictor and the mean of the " + s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}", - ParamValidators.inArray[String](supportedLinkNames.toArray)) + (value: String) => supportedLinkNames.contains(value.toLowerCase)) /** @group getParam */ @Since("2.0.0") @@ -414,7 +414,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * @param name family name: "gaussian", "binomial", "poisson" or "gamma". */ def fromName(name: String): Family = { - name match { + name.toLowerCase match { case Gaussian.name => Gaussian case Binomial.name => Binomial case Poisson.name => Poisson @@ -626,7 +626,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * "inverse", "probit", "cloglog" or "sqrt". */ def fromName(name: String): Link = { - name match { + name.toLowerCase match { case Identity.name => Identity case Logit.name => Logit case Log.name => Log diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index ed24c1e16a130..9f3d643c2bb0c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -553,7 +553,7 @@ class GeneralizedLinearRegressionSuite for ((link, dataset) <- Seq(("inverse", datasetGammaInverse), ("identity", datasetGammaIdentity), ("log", datasetGammaLog))) { for (fitIntercept <- Seq(false, true)) { - val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link) + val trainer = new GeneralizedLinearRegression().setFamily("Gamma").setLink(link) .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") val model = trainer.fit(dataset) val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) @@ -990,7 +990,7 @@ class GeneralizedLinearRegressionSuite -0.6344390 0.3172195 0.2114797 -0.1586097 */ val trainer = new GeneralizedLinearRegression() - .setFamily("gamma") + .setFamily("Gamma") .setWeightCol("weight") val model = trainer.fit(datasetWithWeight)