forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-1406] Mllib pmml model export
See PDF attached to the JIRA issue 1406. The contribution is my original work and I license the work to the project under the project's open source license. Author: Vincenzo Selvaggio <[email protected]> Author: Xiangrui Meng <[email protected]> Author: selvinsource <[email protected]> Closes apache#3062 from selvinsource/mllib_pmml_model_export_SPARK-1406 and squashes the following commits: 852aac6 [Vincenzo Selvaggio] [SPARK-1406] Update JPMML version to 1.1.15 in LICENSE file 085cf42 [Vincenzo Selvaggio] [SPARK-1406] Added Double Min and Max Fixed scala style 30165c4 [Vincenzo Selvaggio] [SPARK-1406] Fixed extreme cases for logit 7a5e0ec [Vincenzo Selvaggio] [SPARK-1406] Binary classification for SVM and Logistic Regression cfcb596 [Vincenzo Selvaggio] [SPARK-1406] Throw IllegalArgumentException when exporting a multinomial logistic regression 25dce33 [Vincenzo Selvaggio] [SPARK-1406] Update code to latest pmml model dea98ca [Vincenzo Selvaggio] [SPARK-1406] Exclude transitive dependency for pmml model 66b7c12 [Vincenzo Selvaggio] [SPARK-1406] Updated pmml model lib to 1.1.15, latest Java 6 compatible a0a55f7 [Vincenzo Selvaggio] Merge pull request #2 from mengxr/SPARK-1406 3c22f79 [Xiangrui Meng] more code style e2313df [Vincenzo Selvaggio] Merge pull request #1 from mengxr/SPARK-1406 472d757 [Xiangrui Meng] fix code style 1676e15 [Vincenzo Selvaggio] fixed scala issue e2ffae8 [Vincenzo Selvaggio] fixed scala style b8823b0 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406 b25bbf7 [Vincenzo Selvaggio] [SPARK-1406] Added export of pmml to distributed file system using the spark context 7a949d0 [Vincenzo Selvaggio] [SPARK-1406] Fixed scala style f46c75c [Vincenzo Selvaggio] [SPARK-1406] Added PMMLExportable to supported models 7b33b4e [Vincenzo Selvaggio] [SPARK-1406] Added a PMMLExportable interface Restructured code in a new package mllib.pmml Supported models implements the new PMMLExportable interface: LogisticRegression, SVM, KMeansModel, LinearRegression, RidgeRegression, Lasso d559ec5 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406 8fe12bb [Vincenzo Selvaggio] [SPARK-1406] Adjusted logistic regression export description and target categories 03bc3a5 [Vincenzo Selvaggio] added logistic regression da2ec11 [Vincenzo Selvaggio] [SPARK-1406] added linear SVM PMML export 82f2131 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406 19adf29 [Vincenzo Selvaggio] [SPARK-1406] Fixed scala style 1faf985 [Vincenzo Selvaggio] [SPARK-1406] Added target field to the regression model for completeness Adjusted unit test to deal with this change 3ae8ae5 [Vincenzo Selvaggio] [SPARK-1406] Adjusted imported order according to the guidelines c67ce81 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406 78515ec [Vincenzo Selvaggio] [SPARK-1406] added pmml export for LinearRegressionModel, RidgeRegressionModel and LassoModel e29dfb9 [Vincenzo Selvaggio] removed version, by default is set to 4.2 (latest from jpmml) removed copyright ae8b993 [Vincenzo Selvaggio] updated some commented tests to use the new ModelExporter object reordered the imports df8a89e [Vincenzo Selvaggio] added pmml version to pmml model changed the copyright to spark a1b4dc3 [Vincenzo Selvaggio] updated imports 834ca44 [Vincenzo Selvaggio] reordered the import accordingly to the guidelines 349a76b [Vincenzo Selvaggio] new helper object to serialize the models to pmml format c3ef9b8 [Vincenzo Selvaggio] set it to private 6357b98 [Vincenzo Selvaggio] set it to private e1eb251 [Vincenzo Selvaggio] removed serialization part, this will be part of the ModelExporter helper object aba5ee1 [Vincenzo Selvaggio] fixed cluster export cd6c07c [Vincenzo Selvaggio] fixed scala style to run tests f75b988 [Vincenzo Selvaggio] Merge remote-tracking branch 'origin/master' into mllib_pmml_model_export_SPARK-1406 07a29bf [selvinsource] Update LICENSE 8841439 [Vincenzo Selvaggio] adjust scala style in order to compile 1433b11 [Vincenzo Selvaggio] complete suite tests 8e71b8d [Vincenzo Selvaggio] kmeans pmml export implementation 9bc494f [Vincenzo Selvaggio] added scala suite tests added saveLocalFile to ModelExport trait 226e184 [Vincenzo Selvaggio] added javadoc and export model type in case there is a need to support other types of export (not just PMML) a0e3679 [Vincenzo Selvaggio] export and pmml export traits kmeans test implementation
- Loading branch information
1 parent
4459514
commit 254e050
Showing
18 changed files
with
774 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
74 changes: 74 additions & 0 deletions
74
mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.pmml | ||
|
||
import java.io.{File, OutputStream, StringWriter} | ||
import javax.xml.transform.stream.StreamResult | ||
|
||
import org.jpmml.model.JAXBUtil | ||
|
||
import org.apache.spark.SparkContext | ||
import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory | ||
|
||
/** | ||
* Export model to the PMML format | ||
* Predictive Model Markup Language (PMML) is an XML-based file format | ||
* developed by the Data Mining Group (www.dmg.org). | ||
*/ | ||
trait PMMLExportable { | ||
|
||
/** | ||
* Export the model to the stream result in PMML format | ||
*/ | ||
private def toPMML(streamResult: StreamResult): Unit = { | ||
val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this) | ||
JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult) | ||
} | ||
|
||
/** | ||
* Export the model to a local file in PMML format | ||
*/ | ||
def toPMML(localPath: String): Unit = { | ||
toPMML(new StreamResult(new File(localPath))) | ||
} | ||
|
||
/** | ||
* Export the model to a directory on a distributed file system in PMML format | ||
*/ | ||
def toPMML(sc: SparkContext, path: String): Unit = { | ||
val pmml = toPMML() | ||
sc.parallelize(Array(pmml), 1).saveAsTextFile(path) | ||
} | ||
|
||
/** | ||
* Export the model to the OutputStream in PMML format | ||
*/ | ||
def toPMML(outputStream: OutputStream): Unit = { | ||
toPMML(new StreamResult(outputStream)) | ||
} | ||
|
||
/** | ||
* Export the model to a String in PMML format | ||
*/ | ||
def toPMML(): String = { | ||
val writer = new StringWriter | ||
toPMML(new StreamResult(writer)) | ||
writer.toString | ||
} | ||
|
||
} |
90 changes: 90 additions & 0 deletions
90
...c/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.pmml.export | ||
|
||
import scala.{Array => SArray} | ||
|
||
import org.dmg.pmml._ | ||
|
||
import org.apache.spark.mllib.regression.GeneralizedLinearModel | ||
|
||
/** | ||
* PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel | ||
*/ | ||
private[mllib] class BinaryClassificationPMMLModelExport( | ||
model : GeneralizedLinearModel, | ||
description : String, | ||
normalizationMethod : RegressionNormalizationMethodType, | ||
threshold: Double) | ||
extends PMMLModelExport { | ||
|
||
populateBinaryClassificationPMML() | ||
|
||
/** | ||
* Export the input LogisticRegressionModel or SVMModel to PMML format. | ||
*/ | ||
private def populateBinaryClassificationPMML(): Unit = { | ||
pmml.getHeader.setDescription(description) | ||
|
||
if (model.weights.size > 0) { | ||
val fields = new SArray[FieldName](model.weights.size) | ||
val dataDictionary = new DataDictionary | ||
val miningSchema = new MiningSchema | ||
val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1") | ||
var interceptNO = threshold | ||
if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) { | ||
if (threshold <= 0) { | ||
interceptNO = Double.MinValue | ||
} else if (threshold >= 1) { | ||
interceptNO = Double.MaxValue | ||
} else { | ||
interceptNO = -math.log(1 / threshold - 1) | ||
} | ||
} | ||
val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0") | ||
val regressionModel = new RegressionModel() | ||
.withFunctionName(MiningFunctionType.CLASSIFICATION) | ||
.withMiningSchema(miningSchema) | ||
.withModelName(description) | ||
.withNormalizationMethod(normalizationMethod) | ||
.withRegressionTables(regressionTableYES, regressionTableNO) | ||
|
||
for (i <- 0 until model.weights.size) { | ||
fields(i) = FieldName.create("field_" + i) | ||
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) | ||
miningSchema | ||
.withMiningFields(new MiningField(fields(i)) | ||
.withUsageType(FieldUsageType.ACTIVE)) | ||
regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) | ||
} | ||
|
||
// add target field | ||
val targetField = FieldName.create("target") | ||
dataDictionary | ||
.withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) | ||
miningSchema | ||
.withMiningFields(new MiningField(targetField) | ||
.withUsageType(FieldUsageType.TARGET)) | ||
|
||
dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) | ||
|
||
pmml.setDataDictionary(dataDictionary) | ||
pmml.withModels(regressionModel) | ||
} | ||
} | ||
} |
75 changes: 75 additions & 0 deletions
75
.../src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.pmml.export | ||
|
||
import scala.{Array => SArray} | ||
|
||
import org.dmg.pmml._ | ||
|
||
import org.apache.spark.mllib.regression.GeneralizedLinearModel | ||
|
||
/** | ||
* PMML Model Export for GeneralizedLinearModel abstract class | ||
*/ | ||
private[mllib] class GeneralizedLinearPMMLModelExport( | ||
model: GeneralizedLinearModel, | ||
description: String) | ||
extends PMMLModelExport { | ||
|
||
populateGeneralizedLinearPMML(model) | ||
|
||
/** | ||
* Export the input GeneralizedLinearModel model to PMML format. | ||
*/ | ||
private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = { | ||
pmml.getHeader.setDescription(description) | ||
|
||
if (model.weights.size > 0) { | ||
val fields = new SArray[FieldName](model.weights.size) | ||
val dataDictionary = new DataDictionary | ||
val miningSchema = new MiningSchema | ||
val regressionTable = new RegressionTable(model.intercept) | ||
val regressionModel = new RegressionModel() | ||
.withFunctionName(MiningFunctionType.REGRESSION) | ||
.withMiningSchema(miningSchema) | ||
.withModelName(description) | ||
.withRegressionTables(regressionTable) | ||
|
||
for (i <- 0 until model.weights.size) { | ||
fields(i) = FieldName.create("field_" + i) | ||
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) | ||
miningSchema | ||
.withMiningFields(new MiningField(fields(i)) | ||
.withUsageType(FieldUsageType.ACTIVE)) | ||
regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) | ||
} | ||
|
||
// for completeness add target field | ||
val targetField = FieldName.create("target") | ||
dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) | ||
miningSchema | ||
.withMiningFields(new MiningField(targetField) | ||
.withUsageType(FieldUsageType.TARGET)) | ||
|
||
dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) | ||
|
||
pmml.setDataDictionary(dataDictionary) | ||
pmml.withModels(regressionModel) | ||
} | ||
} | ||
} |
Oops, something went wrong.