From 472d75777c98439522a6e3226cfb83f7cc1dd00d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 20 Apr 2015 13:20:58 -0700 Subject: [PATCH] fix code style --- .../spark/mllib/pmml/PMMLExportable.scala | 31 +++-- .../GeneralizedLinearPMMLModelExport.scala | 70 ++++------ .../pmml/export/KMeansPMMLModelExport.scala | 122 +++++++--------- .../LogisticRegressionPMMLModelExport.scala | 79 ++++------- .../mllib/pmml/export/PMMLModelExport.scala | 24 ++-- .../pmml/export/PMMLModelExportFactory.scala | 37 +++-- ...eneralizedLinearPMMLModelExportSuite.scala | 131 ++++++++---------- .../export/KMeansPMMLModelExportSuite.scala | 36 ++--- ...gisticRegressionPMMLModelExportSuite.scala | 51 +++---- .../export/PMMLModelExportFactorySuite.scala | 97 +++++-------- 10 files changed, 273 insertions(+), 405 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala index 988271fae292f..938a7998cdf5f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala @@ -21,9 +21,10 @@ import java.io.File import java.io.OutputStream import java.io.StringWriter import javax.xml.transform.stream.StreamResult + import org.jpmml.model.JAXBUtil + import org.apache.spark.SparkContext -import org.apache.spark.mllib.pmml.export.PMMLModelExport import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory /** @@ -34,42 +35,42 @@ import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory trait PMMLExportable { /** - * Export the model to the stream result in PMML format - */ + * Export the model to the stream result in PMML format + */ private def toPMML(streamResult: StreamResult): Unit = { val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this) - JAXBUtil.marshalPMML(pmmlModelExport.getPmml(), streamResult) + JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult) } /** - * Export the model to a local File in PMML format - */ + * Export the model to a local file in PMML format + */ def toPMML(localPath: String): Unit = { toPMML(new StreamResult(new File(localPath))) } /** - * Export the model to a distributed file in PMML format - */ + * Export the model to a directory on a distributed file system in PMML format + */ def toPMML(sc: SparkContext, path: String): Unit = { val pmml = toPMML() - sc.parallelize(Array(pmml),1).saveAsTextFile(path) + sc.parallelize(Array(pmml), 1).saveAsTextFile(path) } /** - * Export the model to the Outputtream in PMML format - */ + * Export the model to the OutputStream in PMML format + */ def toPMML(outputStream: OutputStream): Unit = { toPMML(new StreamResult(outputStream)) } /** - * Export the model to a String in PMML format - */ + * Export the model to a String in PMML format + */ def toPMML(): String = { - var writer = new StringWriter(); + val writer = new StringWriter toPMML(new StreamResult(writer)) - return writer.toString(); + writer.toString } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala index 94bbd705a9b69..baab1a2dbf963 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala @@ -17,18 +17,10 @@ package org.apache.spark.mllib.pmml.export -import org.dmg.pmml.DataDictionary -import org.dmg.pmml.DataField -import org.dmg.pmml.DataType -import org.dmg.pmml.FieldName -import org.dmg.pmml.FieldUsageType -import org.dmg.pmml.MiningField -import org.dmg.pmml.MiningFunctionType -import org.dmg.pmml.MiningSchema -import org.dmg.pmml.NumericPredictor -import org.dmg.pmml.OpType -import org.dmg.pmml.RegressionModel -import org.dmg.pmml.RegressionTable +import scala.{Array => SArray} + +import org.dmg.pmml._ + import org.apache.spark.mllib.regression.GeneralizedLinearModel /** @@ -39,55 +31,43 @@ private[mllib] class GeneralizedLinearPMMLModelExport( description : String) extends PMMLModelExport{ + populateGeneralizedLinearPMML(model) + /** - * Export the input GeneralizedLinearModel model to PMML format + * Export the input GeneralizedLinearModel model to PMML format. */ - populateGeneralizedLinearPMML(model) - - private def populateGeneralizedLinearPMML(model : GeneralizedLinearModel): Unit = { + private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = { + pmml.getHeader.setDescription(description) - pmml.getHeader().setDescription(description) - if(model.weights.size > 0){ - - val fields = new Array[FieldName](model.weights.size) - - val dataDictionary = new DataDictionary() - - val miningSchema = new MiningSchema() - + val fields = new SArray[FieldName](model.weights.size) + val dataDictionary = new DataDictionary + val miningSchema = new MiningSchema val regressionTable = new RegressionTable(model.intercept) - val regressionModel = new RegressionModel(miningSchema,MiningFunctionType.REGRESSION) - .withModelName(description).withRegressionTables(regressionTable) - - for ( i <- 0 until model.weights.size) { + .withModelName(description) + .withRegressionTables(regressionTable) + + for (i <- 0 until model.weights.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary - .withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } // for completeness add target field - val targetField = FieldName.create("target"); - dataDictionary - .withDataFields( - new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE) - ) - miningSchema + val targetField = FieldName.create("target") + dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema .withMiningFields(new MiningField(targetField) .withUsageType(FieldUsageType.TARGET)) - - dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size()) - + + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + pmml.setDataDictionary(dataDictionary) pmml.withModels(regressionModel) - } - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala index 901fbb6858a20..c12b275b2185c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala @@ -17,24 +17,10 @@ package org.apache.spark.mllib.pmml.export -import org.dmg.pmml.Array.Type -import org.dmg.pmml.Cluster -import org.dmg.pmml.ClusteringField -import org.dmg.pmml.ClusteringModel -import org.dmg.pmml.ClusteringModel.ModelClass -import org.dmg.pmml.CompareFunctionType -import org.dmg.pmml.ComparisonMeasure -import org.dmg.pmml.ComparisonMeasure.Kind -import org.dmg.pmml.DataDictionary -import org.dmg.pmml.DataField -import org.dmg.pmml.DataType -import org.dmg.pmml.FieldName -import org.dmg.pmml.FieldUsageType -import org.dmg.pmml.MiningField -import org.dmg.pmml.MiningFunctionType -import org.dmg.pmml.MiningSchema -import org.dmg.pmml.OpType -import org.dmg.pmml.SquaredEuclidean +import scala.{Array => SArray} + +import org.dmg.pmml._ + import org.apache.spark.mllib.clustering.KMeansModel /** @@ -42,65 +28,53 @@ import org.apache.spark.mllib.clustering.KMeansModel */ private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{ + populateKMeansPMML(model) + /** - * Export the input KMeansModel model to PMML format + * Export the input KMeansModel model to PMML format. */ - populateKMeansPMML(model) - private def populateKMeansPMML(model : KMeansModel): Unit = { - - pmml.getHeader().setDescription("k-means clustering") - - if(model.clusterCenters.length > 0){ - - val clusterCenter = model.clusterCenters(0) - - val fields = new Array[FieldName](clusterCenter.size) - - val dataDictionary = new DataDictionary() - - val miningSchema = new MiningSchema() - - val comparisonMeasure = new ComparisonMeasure() - .withKind(Kind.DISTANCE) - .withMeasure(new SquaredEuclidean() - ) - - val clusteringModel = new ClusteringModel(miningSchema, comparisonMeasure, - MiningFunctionType.CLUSTERING, ModelClass.CENTER_BASED, model.clusterCenters.length) + pmml.getHeader.setDescription("k-means clustering") + + if (model.clusterCenters.length > 0) { + val clusterCenter = model.clusterCenters(0) + val fields = new SArray[FieldName](clusterCenter.size) + val dataDictionary = new DataDictionary + val miningSchema = new MiningSchema + val comparisonMeasure = new ComparisonMeasure() + .withKind(ComparisonMeasure.Kind.DISTANCE) + .withMeasure(new SquaredEuclidean()) + val clusteringModel = new ClusteringModel(miningSchema, comparisonMeasure, + MiningFunctionType.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED, + model.clusterCenters.length) .withModelName("k-means") - - for ( i <- 0 until clusterCenter.size) { - fields(i) = FieldName.create("field_" + i) - dataDictionary - .withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) - miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - clusteringModel.withClusteringFields( - new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF) - ) - } - - dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size()) - - for ( i <- 0 until model.clusterCenters.size ) { - val cluster = new Cluster() - .withName("cluster_" + i) - .withArray(new org.dmg.pmml.Array() - .withType(Type.REAL) - .withN(clusterCenter.size) - .withValue(model.clusterCenters(i).toArray.mkString(" "))) - // we don't have the size of the single cluster but only the centroids (withValue) - // .withSize(value) - clusteringModel.withClusters(cluster) - } - - pmml.setDataDictionary(dataDictionary) - pmml.withModels(clusteringModel) - - } - + + for (i <- 0 until clusterCenter.size) { + fields(i) = FieldName.create("field_" + i) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + clusteringModel.withClusteringFields( + new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF)) + } + + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + + for (i <- 0 until model.clusterCenters.length) { + val cluster = new Cluster() + .withName("cluster_" + i) + .withArray(new org.dmg.pmml.Array() + .withType(Array.Type.REAL) + .withN(clusterCenter.size) + .withValue(model.clusterCenters(i).toArray.mkString(" "))) + // we don't have the size of the single cluster but only the centroids (withValue) + // .withSize(value) + clusteringModel.withClusters(cluster) + } + + pmml.setDataDictionary(dataDictionary) + pmml.withModels(clusteringModel) + } } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala index 0b1d1d465b939..75c28e1c03514 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala @@ -17,19 +17,10 @@ package org.apache.spark.mllib.pmml.export -import org.dmg.pmml.DataDictionary -import org.dmg.pmml.DataField -import org.dmg.pmml.DataType -import org.dmg.pmml.FieldName -import org.dmg.pmml.FieldUsageType -import org.dmg.pmml.MiningField -import org.dmg.pmml.MiningFunctionType -import org.dmg.pmml.MiningSchema -import org.dmg.pmml.NumericPredictor -import org.dmg.pmml.OpType -import org.dmg.pmml.RegressionModel -import org.dmg.pmml.RegressionTable -import org.dmg.pmml.RegressionNormalizationMethodType +import scala.{Array => SArray} + +import org.dmg.pmml._ + import org.apache.spark.mllib.classification.LogisticRegressionModel /** @@ -40,62 +31,46 @@ private[mllib] class LogisticRegressionPMMLModelExport( description : String) extends PMMLModelExport{ + populateLogisticRegressionPMML(model) + /** * Export the input LogisticRegressionModel model to PMML format */ - populateLogisticRegressionPMML(model) - private def populateLogisticRegressionPMML(model : LogisticRegressionModel): Unit = { + pmml.getHeader.setDescription(description) - pmml.getHeader().setDescription(description) - - if(model.weights.size > 0){ - - val fields = new Array[FieldName](model.weights.size) - - val dataDictionary = new DataDictionary() - - val miningSchema = new MiningSchema() - - val regressionTableYES = new RegressionTable(model.intercept) - .withTargetCategory("1") - - val regressionTableNO = new RegressionTable(0.0) - .withTargetCategory("0") - - val regressionModel = new RegressionModel(miningSchema,MiningFunctionType.CLASSIFICATION) - .withModelName(description) - .withNormalizationMethod(RegressionNormalizationMethodType.LOGIT) - .withRegressionTables(regressionTableYES, regressionTableNO) - - for ( i <- 0 until model.weights.size) { + if (model.weights.size > 0) { + val fields = new SArray[FieldName](model.weights.size) + val dataDictionary = new DataDictionary + val miningSchema = new MiningSchema + val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1") + val regressionTableNO = new RegressionTable(0.0).withTargetCategory("0") + val regressionModel = new RegressionModel(miningSchema, MiningFunctionType.CLASSIFICATION) + .withModelName(description) + .withNormalizationMethod(RegressionNormalizationMethodType.LOGIT) + .withRegressionTables(regressionTableYES, regressionTableNO) + + for (i <- 0 until model.weights.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary - .withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - regressionTableYES - .withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } // add target field - val targetField = FieldName.create("target"); + val targetField = FieldName.create("target") dataDictionary - .withDataFields( - new DataField(targetField, OpType.CATEGORICAL, DataType.STRING) - ) - miningSchema + .withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) + miningSchema .withMiningFields(new MiningField(targetField) .withUsageType(FieldUsageType.TARGET)) - dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size()) + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) pmml.setDataDictionary(dataDictionary) pmml.withModels(regressionModel) - } - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index 14ab5e0d2c7b6..ebdeae50bb32f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -19,11 +19,10 @@ package org.apache.spark.mllib.pmml.export import java.text.SimpleDateFormat import java.util.Date + import scala.beans.BeanProperty -import org.dmg.pmml.Application -import org.dmg.pmml.Header -import org.dmg.pmml.PMML -import org.dmg.pmml.Timestamp + +import org.dmg.pmml.{Application, Header, PMML, Timestamp} private[mllib] trait PMMLModelExport { @@ -31,19 +30,18 @@ private[mllib] trait PMMLModelExport { * Holder of the exported model in PMML format */ @BeanProperty - val pmml: PMML = new PMML(); + val pmml: PMML = new PMML - setHeader(pmml); + setHeader(pmml) - private def setHeader(pmml : PMML): Unit = { - val version = getClass().getPackage().getImplementationVersion() + private def setHeader(pmml: PMML): Unit = { + val version = getClass.getPackage.getImplementationVersion val app = new Application().withName("Apache Spark MLlib").withVersion(version) val timestamp = new Timestamp() - .withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + .withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) val header = new Header() - .withApplication(app) - .withTimestamp(timestamp) + .withApplication(app) + .withTimestamp(timestamp) pmml.setHeader(header) - } - + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala index a33ac14bbc446..0c374a46fb562 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -31,25 +31,24 @@ private[mllib] object PMMLModelExportFactory { * taking as input the machine learning model (for example KMeansModel). */ def createPMMLModelExport(model: Any): PMMLModelExport = { - return model match{ - case kmeans: KMeansModel => - new KMeansPMMLModelExport(kmeans) - case linearRegression: LinearRegressionModel => - new GeneralizedLinearPMMLModelExport(linearRegression, "linear regression") - case ridgeRegression: RidgeRegressionModel => - new GeneralizedLinearPMMLModelExport(ridgeRegression, "ridge regression") - case lassoRegression: LassoModel => - new GeneralizedLinearPMMLModelExport(lassoRegression, "lasso regression") - case svm: SVMModel => - new GeneralizedLinearPMMLModelExport( - svm, - "linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise") - case logisticRegression: LogisticRegressionModel => - new LogisticRegressionPMMLModelExport(logisticRegression, "logistic regression") - case _ => - throw new IllegalArgumentException("PMML Export not supported for model: " - + model.getClass) - } + model match { + case kmeans: KMeansModel => + new KMeansPMMLModelExport(kmeans) + case linear: LinearRegressionModel => + new GeneralizedLinearPMMLModelExport(linear, "linear regression") + case ridge: RidgeRegressionModel => + new GeneralizedLinearPMMLModelExport(ridge, "ridge regression") + case lasso: LassoModel => + new GeneralizedLinearPMMLModelExport(lasso, "lasso regression") + case svm: SVMModel => + new GeneralizedLinearPMMLModelExport(svm, + "linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise") + case logistic: LogisticRegressionModel => + new LogisticRegressionPMMLModelExport(logistic, "logistic regression") + case _ => + throw new IllegalArgumentException( + "PMML Export not supported for model: " + model.getClass.getName) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala index dd50112eba7cb..f48d39f889cd3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala @@ -19,96 +19,87 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel import org.scalatest.FunSuite + import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.regression.LassoModel import org.apache.spark.mllib.regression.LinearRegressionModel import org.apache.spark.mllib.regression.RidgeRegressionModel import org.apache.spark.mllib.util.LinearDataGenerator -class GeneralizedLinearPMMLModelExportSuite extends FunSuite{ - - test("GeneralizedLinearPMMLModelExport generate PMML format") { +class GeneralizedLinearPMMLModelExportSuite extends FunSuite { - // arrange models to test - val linearInput = LinearDataGenerator.generateLinearInput( - 3.0, Array(10.0, 10.0), 1, 17) - val linearRegressionModel = new LinearRegressionModel( - linearInput(0).features, linearInput(0).label); - val ridgeRegressionModel = new RidgeRegressionModel( - linearInput(0).features, linearInput(0).label); - val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label); - val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label); - - // act by exporting the model to the PMML format + test("linear regression pmml export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val linearRegressionModel = + new LinearRegressionModel(linearInput(0).features, linearInput(0).label) val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel) // assert that the PMML format is as expected assert(linearModelExport.isInstanceOf[PMMLModelExport]) - var pmml = linearModelExport.asInstanceOf[PMMLModelExport].getPmml() - assert(pmml.getHeader().getDescription() === "linear regression") + val pmml = linearModelExport.getPmml + assert(pmml.getHeader.getDescription === "linear regression") // check that the number of fields match the weights size - assert(pmml.getDataDictionary().getNumberOfFields() - === linearRegressionModel.weights.size + 1) - // this verify that there is a model attached to the pmml object - // and the model is a regression one - // it also verifies that the pmml model has a regression table - // with the same number of predictors of the model weights - assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] - .getRegressionTables().get(0).getNumericPredictors().size() - === linearRegressionModel.weights.size) - - // act + assert(pmml.getDataDictionary.getNumberOfFields === linearRegressionModel.weights.size + 1) + // This verifies that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table with the same number of + // predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === linearRegressionModel.weights.size) + } + + test("ridge regression pmml export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val ridgeRegressionModel = + new RidgeRegressionModel(linearInput(0).features, linearInput(0).label) val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel) // assert that the PMML format is as expected assert(ridgeModelExport.isInstanceOf[PMMLModelExport]) - pmml = ridgeModelExport.asInstanceOf[PMMLModelExport].getPmml() - assert(pmml.getHeader().getDescription() === "ridge regression") + val pmml = ridgeModelExport.getPmml + assert(pmml.getHeader.getDescription === "ridge regression") // check that the number of fields match the weights size - assert(pmml.getDataDictionary().getNumberOfFields() === ridgeRegressionModel.weights.size + 1) - // this verify that there is a model attached to the pmml object - // and the model is a regression one - // it also verifies that the pmml model has a regression table - // with the same number of predictors of the model weights - assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] - .getRegressionTables().get(0).getNumericPredictors().size() - === ridgeRegressionModel.weights.size) - - // act - val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) + assert(pmml.getDataDictionary.getNumberOfFields === ridgeRegressionModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table with the same number of + // predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === ridgeRegressionModel.weights.size) + } + + test("lasso pmml export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label) + val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) // assert that the PMML format is as expected assert(lassoModelExport.isInstanceOf[PMMLModelExport]) - pmml = lassoModelExport.asInstanceOf[PMMLModelExport].getPmml() - assert(pmml.getHeader().getDescription() === "lasso regression") + val pmml = lassoModelExport.getPmml + assert(pmml.getHeader.getDescription === "lasso regression") // check that the number of fields match the weights size - assert(pmml.getDataDictionary().getNumberOfFields() === lassoModel.weights.size + 1) - // this verify that there is a model attached to the pmml object - // and the model is a regression one - // it also verifies that the pmml model has a regression table - // with the same number of predictors of the model weights - assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] - .getRegressionTables().get(0).getNumericPredictors().size() === lassoModel.weights.size) - - // act + assert(pmml.getDataDictionary.getNumberOfFields === lassoModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table with the same number of + // predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === lassoModel.weights.size) + } + + test("svm pmml export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) // assert that the PMML format is as expected assert(svmModelExport.isInstanceOf[PMMLModelExport]) - pmml = svmModelExport.asInstanceOf[PMMLModelExport].getPmml() - assert(pmml.getHeader().getDescription() - === "linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise") + val pmml = svmModelExport.getPmml + assert(pmml.getHeader.getDescription + === "linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise") // check that the number of fields match the weights size - assert(pmml.getDataDictionary().getNumberOfFields() === svmModel.weights.size + 1) - // this verify that there is a model attached to the pmml object - // and the model is a regression one - // it also verifies that the pmml model has a regression table - // with the same number of predictors of the model weights - assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] - .getRegressionTables().get(0).getNumericPredictors().size() === svmModel.weights.size) - - // manual checking - // linearRegressionModel.toPMML("/tmp/linearregression.xml") - // ridgeRegressionModel.toPMML("/tmp/ridgeregression.xml") - // lassoModel.toPMML("/tmp/lassoregression.xml") - // svmModel.toPMML("/tmp/linearsvm.xml") - - } - + assert(pmml.getDataDictionary.getNumberOfFields === svmModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table with the same number of + // predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === svmModel.weights.size) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala index 83def5ace0cbc..f34e2a210a9fd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala @@ -19,41 +19,31 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.ClusteringModel import org.scalatest.FunSuite + import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors -class KMeansPMMLModelExportSuite extends FunSuite{ +class KMeansPMMLModelExportSuite extends FunSuite { - test("KMeansPMMLModelExport generate PMML format") { - + test("KMeansPMMLModelExport generate PMML format") { // arrange model to test val clusterCenters = Array( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), - Vectors.dense(1.0, 4.0, 6.0) - ) - val kmeansModel = new KMeansModel(clusterCenters); + Vectors.dense(1.0, 4.0, 6.0)) + val kmeansModel = new KMeansModel(clusterCenters) - // act by exporting the model to the PMML format val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel) // assert that the PMML format is as expected assert(modelExport.isInstanceOf[PMMLModelExport]) - val pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml() - assert(pmml.getHeader().getDescription() === "k-means clustering") + val pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml + assert(pmml.getHeader.getDescription === "k-means clustering") // check that the number of fields match the single vector size - assert(pmml.getDataDictionary().getNumberOfFields() === clusterCenters(0).size) - // this verify that there is a model attached to the pmml object - // and the model is a clustering one - // it also verifies that the pmml model has the same number of clusters of the spark model - assert(pmml.getModels().get(0).asInstanceOf[ClusteringModel].getNumberOfClusters() - === clusterCenters.size) - - // manual checking - // kmeansModel.toPMML("/tmp/kmeans.xml") - // kmeansModel.toPMML(System.out) - // System.out.println(kmeansModel.toPMML()) - - } - + assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size) + // This verify that there is a model attached to the pmml object and the model is a clustering + // one. It also verifies that the pmml model has the same number of clusters of the spark model. + val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] + assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala index ca5d8ca8b2f5b..af642702ed942 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala @@ -19,47 +19,34 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel import org.scalatest.FunSuite + import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.util.LinearDataGenerator class LogisticRegressionPMMLModelExportSuite extends FunSuite{ - test("LogisticRegressionPMMLModelExport generate PMML format") { - - // arrange models to test - val linearInput = LinearDataGenerator.generateLinearInput( - 3.0, Array(10.0, 10.0), 1, 17) - val logisticRegressionModel = new LogisticRegressionModel( - linearInput(0).features, linearInput(0).label); + test("LogisticRegressionPMMLModelExport generate PMML format") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val logisticRegressionModel = + new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) - // act by exporting the model to the PMML format - val logisticModelExport = PMMLModelExportFactory - .createPMMLModelExport(logisticRegressionModel) + val logisticModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) + // assert that the PMML format is as expected assert(logisticModelExport.isInstanceOf[PMMLModelExport]) - var pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml() - assert(pmml.getHeader().getDescription() === "logistic regression") + val pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml + assert(pmml.getHeader.getDescription === "logistic regression") // check that the number of fields match the weights size - assert( - pmml.getDataDictionary().getNumberOfFields() === logisticRegressionModel.weights.size + 1) - // this verify that there is a model attached to the pmml object - // and the model is a regression one - // it also verifies that the pmml model has a regression table (for target category 1) - // with the same number of predictors of the model weights - assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] - .getRegressionTables().get(0).getTargetCategory() === "1") - assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] - .getRegressionTables().get(0).getNumericPredictors().size() - === logisticRegressionModel.weights.size) + assert(pmml.getDataDictionary.getNumberOfFields === logisticRegressionModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table (for target category 1) + // with the same number of predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1") + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === logisticRegressionModel.weights.size) // verify if there is a second table with target category 0 and no predictors - assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] - .getRegressionTables().get(1).getTargetCategory() === "0") - assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] - .getRegressionTables().get(1).getNumericPredictors().size() === 0) - - // manual checking - // logisticRegressionModel.toPMML("/tmp/logisticregression.xml") - + assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0") + assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0) } - } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index a54cd247a120f..b466e08d09e6d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.pmml.export import org.scalatest.FunSuite + import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.clustering.KMeansModel @@ -27,90 +28,62 @@ import org.apache.spark.mllib.regression.LinearRegressionModel import org.apache.spark.mllib.regression.RidgeRegressionModel import org.apache.spark.mllib.util.LinearDataGenerator -class PMMLModelExportFactorySuite extends FunSuite{ +class PMMLModelExportFactorySuite extends FunSuite { - test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") { - - // arrange + test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") { val clusterCenters = Array( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), - Vectors.dense(1.0, 4.0, 6.0) - ) - val kmeansModel = new KMeansModel(clusterCenters); - - // act + Vectors.dense(1.0, 4.0, 6.0)) + val kmeansModel = new KMeansModel(clusterCenters) + val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel) - // assert assert(modelExport.isInstanceOf[KMeansPMMLModelExport]) - } test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a " + "LinearRegressionModel, RidgeRegressionModel, LassoModel or SVMModel") { - - // arrange - val linearInput = LinearDataGenerator.generateLinearInput( - 3.0, Array(10.0, 10.0), 1, 17) - val linearRegressionModel = new LinearRegressionModel( - linearInput(0).features, linearInput(0).label) - val ridgeRegressionModel = new RidgeRegressionModel( - linearInput(0).features, linearInput(0).label) - val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label) - val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) - - // act - val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel) - // assert - assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) - // act - val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel) - // assert - assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) - - // act - val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) - // assert - assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) - - // act - val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) - // assert - assert(svmModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) - + val linearRegressionModel = + new LinearRegressionModel(linearInput(0).features, linearInput(0).label) + val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel) + assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + + val ridgeRegressionModel = + new RidgeRegressionModel(linearInput(0).features, linearInput(0).label) + val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel) + assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + + + val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label) + val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) + assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) + assert(svmModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) } - - test("PMMLModelExportFactory create LogisticRegressionPMMLModelExport " - + "when passing a LogisticRegressionModel") { - - // arrange - val linearInput = LinearDataGenerator.generateLinearInput( - 3.0, Array(10.0, 10.0), 1, 17) - val logisticRegressionModel = new LogisticRegressionModel( - linearInput(0).features, linearInput(0).label); - // act - val logisticRegressionModelExport = PMMLModelExportFactory - .createPMMLModelExport(logisticRegressionModel) - // assert + test("PMMLModelExportFactory create LogisticRegressionPMMLModelExport " + + "when passing a LogisticRegressionModel") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val logisticRegressionModel = + new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) + + val logisticRegressionModelExport = + PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) + assert(logisticRegressionModelExport.isInstanceOf[LogisticRegressionPMMLModelExport]) - } test("PMMLModelExportFactory throw IllegalArgumentException " + "when passing an unsupported model") { + val invalidModel = new Object - // arrange - val invalidModel = new Object; - - // assert intercept[IllegalArgumentException] { - // act PMMLModelExportFactory.createPMMLModelExport(invalidModel) } - } - }