Skip to content

Commit

Permalink
fix code style
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Apr 20, 2015
1 parent 1676e15 commit 472d757
Show file tree
Hide file tree
Showing 10 changed files with 273 additions and 405 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ import java.io.File
import java.io.OutputStream
import java.io.StringWriter
import javax.xml.transform.stream.StreamResult

import org.jpmml.model.JAXBUtil

import org.apache.spark.SparkContext
import org.apache.spark.mllib.pmml.export.PMMLModelExport
import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory

/**
Expand All @@ -34,42 +35,42 @@ import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory
trait PMMLExportable {

/**
* Export the model to the stream result in PMML format
*/
* Export the model to the stream result in PMML format
*/
private def toPMML(streamResult: StreamResult): Unit = {
val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this)
JAXBUtil.marshalPMML(pmmlModelExport.getPmml(), streamResult)
JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult)
}

/**
* Export the model to a local File in PMML format
*/
* Export the model to a local file in PMML format
*/
def toPMML(localPath: String): Unit = {
toPMML(new StreamResult(new File(localPath)))
}

/**
* Export the model to a distributed file in PMML format
*/
* Export the model to a directory on a distributed file system in PMML format
*/
def toPMML(sc: SparkContext, path: String): Unit = {
val pmml = toPMML()
sc.parallelize(Array(pmml),1).saveAsTextFile(path)
sc.parallelize(Array(pmml), 1).saveAsTextFile(path)
}

/**
* Export the model to the Outputtream in PMML format
*/
* Export the model to the OutputStream in PMML format
*/
def toPMML(outputStream: OutputStream): Unit = {
toPMML(new StreamResult(outputStream))
}

/**
* Export the model to a String in PMML format
*/
* Export the model to a String in PMML format
*/
def toPMML(): String = {
var writer = new StringWriter();
val writer = new StringWriter
toPMML(new StreamResult(writer))
return writer.toString();
writer.toString
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,10 @@

package org.apache.spark.mllib.pmml.export

import org.dmg.pmml.DataDictionary
import org.dmg.pmml.DataField
import org.dmg.pmml.DataType
import org.dmg.pmml.FieldName
import org.dmg.pmml.FieldUsageType
import org.dmg.pmml.MiningField
import org.dmg.pmml.MiningFunctionType
import org.dmg.pmml.MiningSchema
import org.dmg.pmml.NumericPredictor
import org.dmg.pmml.OpType
import org.dmg.pmml.RegressionModel
import org.dmg.pmml.RegressionTable
import scala.{Array => SArray}

import org.dmg.pmml._

import org.apache.spark.mllib.regression.GeneralizedLinearModel

/**
Expand All @@ -39,55 +31,43 @@ private[mllib] class GeneralizedLinearPMMLModelExport(
description : String)
extends PMMLModelExport{

populateGeneralizedLinearPMML(model)

/**
* Export the input GeneralizedLinearModel model to PMML format
* Export the input GeneralizedLinearModel model to PMML format.
*/
populateGeneralizedLinearPMML(model)

private def populateGeneralizedLinearPMML(model : GeneralizedLinearModel): Unit = {
private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = {
pmml.getHeader.setDescription(description)

pmml.getHeader().setDescription(description)

if(model.weights.size > 0){

val fields = new Array[FieldName](model.weights.size)

val dataDictionary = new DataDictionary()

val miningSchema = new MiningSchema()

val fields = new SArray[FieldName](model.weights.size)
val dataDictionary = new DataDictionary
val miningSchema = new MiningSchema
val regressionTable = new RegressionTable(model.intercept)

val regressionModel = new RegressionModel(miningSchema,MiningFunctionType.REGRESSION)
.withModelName(description).withRegressionTables(regressionTable)

for ( i <- 0 until model.weights.size) {
.withModelName(description)
.withRegressionTables(regressionTable)

for (i <- 0 until model.weights.size) {
fields(i) = FieldName.create("field_" + i)
dataDictionary
.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.withMiningFields(new MiningField(fields(i))
.withUsageType(FieldUsageType.ACTIVE))
.withMiningFields(new MiningField(fields(i))
.withUsageType(FieldUsageType.ACTIVE))
regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
}

// for completeness add target field
val targetField = FieldName.create("target");
dataDictionary
.withDataFields(
new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)
)
miningSchema
val targetField = FieldName.create("target")
dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.withMiningFields(new MiningField(targetField)
.withUsageType(FieldUsageType.TARGET))
dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size())

dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)

pmml.setDataDictionary(dataDictionary)
pmml.withModels(regressionModel)

}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,90 +17,64 @@

package org.apache.spark.mllib.pmml.export

import org.dmg.pmml.Array.Type
import org.dmg.pmml.Cluster
import org.dmg.pmml.ClusteringField
import org.dmg.pmml.ClusteringModel
import org.dmg.pmml.ClusteringModel.ModelClass
import org.dmg.pmml.CompareFunctionType
import org.dmg.pmml.ComparisonMeasure
import org.dmg.pmml.ComparisonMeasure.Kind
import org.dmg.pmml.DataDictionary
import org.dmg.pmml.DataField
import org.dmg.pmml.DataType
import org.dmg.pmml.FieldName
import org.dmg.pmml.FieldUsageType
import org.dmg.pmml.MiningField
import org.dmg.pmml.MiningFunctionType
import org.dmg.pmml.MiningSchema
import org.dmg.pmml.OpType
import org.dmg.pmml.SquaredEuclidean
import scala.{Array => SArray}

import org.dmg.pmml._

import org.apache.spark.mllib.clustering.KMeansModel

/**
* PMML Model Export for KMeansModel class
*/
private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{

populateKMeansPMML(model)

/**
* Export the input KMeansModel model to PMML format
* Export the input KMeansModel model to PMML format.
*/
populateKMeansPMML(model)

private def populateKMeansPMML(model : KMeansModel): Unit = {

pmml.getHeader().setDescription("k-means clustering")

if(model.clusterCenters.length > 0){

val clusterCenter = model.clusterCenters(0)

val fields = new Array[FieldName](clusterCenter.size)

val dataDictionary = new DataDictionary()

val miningSchema = new MiningSchema()

val comparisonMeasure = new ComparisonMeasure()
.withKind(Kind.DISTANCE)
.withMeasure(new SquaredEuclidean()
)

val clusteringModel = new ClusteringModel(miningSchema, comparisonMeasure,
MiningFunctionType.CLUSTERING, ModelClass.CENTER_BASED, model.clusterCenters.length)
pmml.getHeader.setDescription("k-means clustering")

if (model.clusterCenters.length > 0) {
val clusterCenter = model.clusterCenters(0)
val fields = new SArray[FieldName](clusterCenter.size)
val dataDictionary = new DataDictionary
val miningSchema = new MiningSchema
val comparisonMeasure = new ComparisonMeasure()
.withKind(ComparisonMeasure.Kind.DISTANCE)
.withMeasure(new SquaredEuclidean())
val clusteringModel = new ClusteringModel(miningSchema, comparisonMeasure,
MiningFunctionType.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED,
model.clusterCenters.length)
.withModelName("k-means")

for ( i <- 0 until clusterCenter.size) {
fields(i) = FieldName.create("field_" + i)
dataDictionary
.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.withMiningFields(new MiningField(fields(i))
.withUsageType(FieldUsageType.ACTIVE))
clusteringModel.withClusteringFields(
new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF)
)
}

dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size())

for ( i <- 0 until model.clusterCenters.size ) {
val cluster = new Cluster()
.withName("cluster_" + i)
.withArray(new org.dmg.pmml.Array()
.withType(Type.REAL)
.withN(clusterCenter.size)
.withValue(model.clusterCenters(i).toArray.mkString(" ")))
// we don't have the size of the single cluster but only the centroids (withValue)
// .withSize(value)
clusteringModel.withClusters(cluster)
}

pmml.setDataDictionary(dataDictionary)
pmml.withModels(clusteringModel)

}


for (i <- 0 until clusterCenter.size) {
fields(i) = FieldName.create("field_" + i)
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.withMiningFields(new MiningField(fields(i))
.withUsageType(FieldUsageType.ACTIVE))
clusteringModel.withClusteringFields(
new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF))
}

dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)

for (i <- 0 until model.clusterCenters.length) {
val cluster = new Cluster()
.withName("cluster_" + i)
.withArray(new org.dmg.pmml.Array()
.withType(Array.Type.REAL)
.withN(clusterCenter.size)
.withValue(model.clusterCenters(i).toArray.mkString(" ")))
// we don't have the size of the single cluster but only the centroids (withValue)
// .withSize(value)
clusteringModel.withClusters(cluster)
}

pmml.setDataDictionary(dataDictionary)
pmml.withModels(clusteringModel)
}
}

}
Loading

0 comments on commit 472d757

Please sign in to comment.