Refactoring. TopologyModel trait as a common model interface
avulanov committed May 15, 2015
1 parent 7f15956 commit cb03fe0
Showing 2 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala
@@ -453,7 +453,7 @@ object FeedForwardModel {

/* Neural network gradient. Does nothing but calling Model's gradient
* */
-class ANNGradient(topology: FeedForwardTopology, dataStacker: DataStacker) extends Gradient {
+class ANNGradient(topology: Topology, dataStacker: DataStacker) extends Gradient {

override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
val gradient = Vectors.zeros(weights.size)
@@ -464,7 +464,6 @@ class ANNGradient(topology: FeedForwardTopology, dataStacker: DataStacker) exten
override def compute(data: Vector, label: Double, weights: Vector,
cumGradient: Vector): Double = {
val (input, target, realBatchSize) = dataStacker.unstack(data)
-//val model = FeedForwardModel(topology, weights)
val model = topology.getInstance(weights)
model.computeGradient(input, target, cumGradient, realBatchSize)
}
@@ -528,11 +527,11 @@ private class ANNUpdater extends Updater {
}
/* MLlib-style trainer class that trains a network given the data and topology
* */
-class FeedForwardTrainer (topology: FeedForwardTopology, val inputSize: Int,
+class FeedForwardTrainer (topology: Topology, val inputSize: Int,
val outputSize: Int) extends Serializable {

// TODO: what if we need to pass random seed?
-private var _weights = FeedForwardModel(topology).weights()
+private var _weights = topology.getInstance(11L).weights()//FeedForwardModel(topology).weights()
private var _batchSize = 1
private var dataStacker = new DataStacker(_batchSize, inputSize, outputSize)
private var _gradient: Gradient = new ANNGradient(topology, dataStacker)
@@ -595,9 +594,10 @@ class FeedForwardTrainer (topology: FeedForwardTopology, val inputSize: Int,
}
}

-def train(data: RDD[(Vector, Vector)]): FeedForwardModel = {
+def train(data: RDD[(Vector, Vector)]): TopologyModel = {
val newWeights = optimizer.optimize(dataStacker.stack(data), getWeights)
-FeedForwardModel(topology, newWeights)
+//FeedForwardModel(topology, newWeights)
+topology.getInstance(newWeights)
}

}
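
For reference, the commit message and the call sites in this diff (topology.getInstance(weights), topology.getInstance(11L), model.weights(), model.computeGradient(input, target, cumGradient, realBatchSize)) imply a trait pair roughly like the sketch below. This is only an inference from the visible code, not the definitions actually added in Layer.scala; in particular the breeze matrix type used for the stacked batch data is an assumption.

// Sketch inferred from the call sites above, not the committed source.
import breeze.linalg.{DenseMatrix => BDM}
import org.apache.spark.mllib.linalg.Vector

trait Topology extends Serializable {
  // Instantiate a model from an explicit (flattened) weight vector.
  def getInstance(weights: Vector): TopologyModel
  // Instantiate a model with randomly initialized weights derived from a seed.
  def getInstance(seed: Long): TopologyModel
}

trait TopologyModel extends Serializable {
  // Flattened weights of all layers, as consumed by the optimizer.
  def weights(): Vector
  // Accumulate the gradient of a stacked batch into cumGradient and return the loss.
  def computeGradient(data: BDM[Double], target: BDM[Double],
    cumGradient: Vector, realBatchSize: Int): Double
}

FeedForwardTopology and FeedForwardModel would then extend these traits, which is what lets ANNGradient, FeedForwardTrainer and (below) ANNClassifierModel be typed against Topology/TopologyModel instead of the concrete classes.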
4 changes: 2 additions & 2 deletions (second changed file; path not shown in this view)
@@ -17,7 +17,7 @@

package org.apache.spark.mllib.classification

-import org.apache.spark.mllib.ann.{FeedForwardTrainer, FeedForwardModel}
+import org.apache.spark.mllib.ann.{TopologyModel, Topology, FeedForwardTrainer, FeedForwardModel}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
@@ -37,7 +37,7 @@ object LabelConverter {
}
}

-class ANNClassifierModel (val annModel: FeedForwardModel)
+class ANNClassifierModel (val annModel: TopologyModel)
extends ClassificationModel with Serializable {
/**
* Predict values for the given data set using the model trained.
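
To illustrate how the pieces fit after this change: the trainer is constructed against the Topology trait, its train method now returns a TopologyModel, and that is exactly what ANNClassifierModel wraps. The helper below and its parameters are hypothetical; only the constructor signatures and the train() return type come from the diff above.

// Hypothetical helper; only the FeedForwardTrainer constructor, the train()
// return type and the ANNClassifierModel constructor are taken from the diff.
import org.apache.spark.mllib.ann.{FeedForwardTrainer, Topology, TopologyModel}
import org.apache.spark.mllib.classification.ANNClassifierModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

def fitClassifier(topology: Topology, data: RDD[(Vector, Vector)],
                  inputSize: Int, outputSize: Int): ANNClassifierModel = {
  val trainer = new FeedForwardTrainer(topology, inputSize, outputSize)
  // train() returns the trait type, so this code never touches FeedForwardModel.
  val model: TopologyModel = trainer.train(data)
  new ANNClassifierModel(model)
}

Typing the classifier and the trainer against the traits means an alternative network implementation can be trained and wrapped without changing either class.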
