From 25dce336065e39da8492c6d8379a21ab3733cd45 Mon Sep 17 00:00:00 2001 From: Vincenzo Selvaggio Date: Wed, 22 Apr 2015 00:48:23 +0100 Subject: [PATCH] [SPARK-1406] Update code to latest pmml model --- .../pmml/export/GeneralizedLinearPMMLModelExport.scala | 4 +++- .../spark/mllib/pmml/export/KMeansPMMLModelExport.scala | 9 ++++++--- .../pmml/export/LogisticRegressionPMMLModelExport.scala | 4 +++- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala index 8c079d5aec42c..1874786af0002 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala @@ -44,7 +44,9 @@ private[mllib] class GeneralizedLinearPMMLModelExport( val dataDictionary = new DataDictionary val miningSchema = new MiningSchema val regressionTable = new RegressionTable(model.intercept) - val regressionModel = new RegressionModel(miningSchema, MiningFunctionType.REGRESSION) + val regressionModel = new RegressionModel() + .withFunctionName(MiningFunctionType.REGRESSION) + .withMiningSchema(miningSchema) .withModelName(description) .withRegressionTables(regressionTable) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala index c12b275b2185c..069e7afc9fca0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala @@ -44,10 +44,13 @@ private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLMode val comparisonMeasure = new ComparisonMeasure() .withKind(ComparisonMeasure.Kind.DISTANCE) .withMeasure(new SquaredEuclidean()) - val clusteringModel = new ClusteringModel(miningSchema, comparisonMeasure, - MiningFunctionType.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED, - model.clusterCenters.length) + val clusteringModel = new ClusteringModel() .withModelName("k-means") + .withMiningSchema(miningSchema) + .withComparisonMeasure(comparisonMeasure) + .withFunctionName(MiningFunctionType.CLUSTERING) + .withModelClass(ClusteringModel.ModelClass.CENTER_BASED) + .withNumberOfClusters(model.clusterCenters.length) for (i <- 0 until clusterCenter.size) { fields(i) = FieldName.create("field_" + i) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala index 6e818c7709bda..2bf4fa858b09b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala @@ -45,7 +45,9 @@ private[mllib] class LogisticRegressionPMMLModelExport( val miningSchema = new MiningSchema val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1") val regressionTableNO = new RegressionTable(0.0).withTargetCategory("0") - val regressionModel = new RegressionModel(miningSchema, MiningFunctionType.CLASSIFICATION) + val regressionModel = new RegressionModel() + .withFunctionName(MiningFunctionType.CLASSIFICATION) + .withMiningSchema(miningSchema) .withModelName(description) .withNormalizationMethod(RegressionNormalizationMethodType.LOGIT) .withRegressionTables(regressionTableYES, regressionTableNO)