Merge pull request #22 from yiheng/master
Some small changes
yiheng-wang-intel authored Sep 25, 2016
2 parents a41815c + 47f2416
commit 7f9790b
Showing 4 changed files with 23 additions and 177 deletions.
File 1 of 4: object ImageNetLocal
@@ -21,7 +21,7 @@ import java.awt.color.ColorSpace
import java.util

import com.intel.analytics.sparkdl.nn.ClassNLLCriterion
-import com.intel.analytics.sparkdl.optim.SGD
+import com.intel.analytics.sparkdl.optim.{EvaluateMethods, SGD}
import com.intel.analytics.sparkdl.tensor.Tensor
import com.intel.analytics.sparkdl.utils.{File, T}

@@ -49,160 +49,9 @@ object ImageNetLocal {
println(s"[${(System.nanoTime() - startTime) / 1e9}s] $msg")
}

-def runDouble(donkey: Donkey, dataSet: DataSets, netType: String, classNum: Int,
+def run(donkey: Donkey, dataSet: DataSets, netType: String, classNum: Int,
labelsMap: Map[String, Double], testInterval: Int, donkeyVal: Donkey,
-dataSetVal: DataSets, batchSize: Int): Unit = {
-// Compute Mean on amount of samples
-val samples = 10000
-log(s"Start to calculate Mean on $samples samples")
-var (meanR, meanG, meanB) = Array.tabulate(samples)(n => {
-print(".")
-val data = donkey.pull
-dataSet.post(data._2)
-ImageNetUtils.computeMean(data._1, data._2.dataOffset)
-}).reduce((a, b) => (a._1 + b._1, a._2 + b._2, a._3 + b._3))
-meanR /= samples
-meanG /= samples
-meanB /= samples
-println()
-
-// Compute std on amount of samples
-log(s"Start to calculate std on $samples samples")
-var (varR, varG, varB) = Array.tabulate(samples)(n => {
-print(".")
-val data = donkey.pull
-dataSet.post(data._2)
-ImageNetUtils.computeVar(data._1, meanR, meanG, meanB, data._2.dataOffset)
-}).reduce((a, b) => (a._1 + b._1, a._2 + b._2, a._3 + b._3))
-varR /= samples
-varG /= samples
-varB /= samples
-
-val model = netType match {
-case "alexnet" => AlexNet.getModel[Double](classNum)
-case "googlenet" => GoogleNet.getModel[Double](classNum)
-case "googlenet-bn" => GoogleNet.getModel[Double](classNum, "googlenet-bn")
-case "googlenet-cf" => GoogleNet.getModelCaffe[Double](classNum)
-case _ => throw new IllegalArgumentException
-}
-val (weights, grad) = model.getParameters()
-println(s"modelsize ${weights.nElement()}")
-println(model)
-val criterion = new ClassNLLCriterion[Double]()
-val epochNum = 90
-val featureShape = Array(3, 224, 224)
-val targetShape = Array(1)
-val sgd = new SGD[Double]
-val state = T("momentum" -> 0.9, "dampening" -> 0.0)
-val stageImgs = new util.ArrayDeque[Image](batchSize)
-val input = Tensor[Double](batchSize, 3, 224, 224)
-val target = Tensor[Double](batchSize)
-val iter = ImageNetUtils.toTensorDouble(
-donkey.map(d => {
-stageImgs.push(d._2)
-(labelsMap(d._2.label), d._1)
-}),
-featureShape,
-targetShape,
-batchSize,
-(meanR, meanG, meanB),
-(varR, varG, varB),
-input,
-target
-)
-
-val stageImgsVal = new util.ArrayDeque[Image](batchSize)
-val iterVal = ImageNetUtils.toTensorDouble(
-donkeyVal.map(d => {
-stageImgsVal.push(d._2)
-(labelsMap(d._2.label), d._1)
-}),
-featureShape,
-targetShape,
-batchSize,
-(meanR, meanG, meanB),
-(varR, varG, varB),
-input,
-target
-)
-
-log(s"meanR is $meanR meanG is $meanG meanB is $meanB")
-log(s"varR is $varR varG is $varG varB is $varB")
-log("Start to train...")
-
-var wallClockTime = 0L
-for (i <- 1 to epochNum) {
-println(s"Epoch[$i] Train")
-
-for (regime <- regimes(netType)) {
-if (i >= regime._1 && i <= regime._2) {
-state("learningRate") = regime._3
-state("weightDecay") = regime._4
-}
-}
-
-var j = 0
-var c = 0
-model.training()
-while (j < dataSet.getTotal) {
-val start = System.nanoTime()
-val (input, target) = iter.next()
-val readImgTime = System.nanoTime()
-model.zeroGradParameters()
-val output = model.forward(input)
-val loss = criterion.forward(output, target)
-val gradOutput = criterion.backward(output, target)
-model.backward(input, gradOutput)
-sgd.optimize(_ => (loss, grad), weights, state, state)
-val end = System.nanoTime()
-wallClockTime += end - start
-log(s"Epoch[$i][Iteration $c $j/${dataSet.getTotal}][Wall Clock ${wallClockTime / 1e9}s]" +
-s" loss is $loss time ${(end - start) / 1e9}s read " +
-s"time ${(readImgTime - start) / 1e9}s train time ${(end - readImgTime) / 1e9}s." +
-s" Throughput is ${input.size(1).toDouble / (end - start) * 1e9} img / second")
-while (!stageImgs.isEmpty) {
-dataSet.post(stageImgs.poll())
-}
-j += input.size(1)
-c += 1
-}
-
-if (i % testInterval == 0) {
-model.evaluate()
-var correct = 0
-var k = 0
-while (k < dataSetVal.getTotal) {
-val (input, target) = iterVal.next()
-val output = model.forward(input)
-output.max(2)._2.squeeze().map(target, (a, b) => {
-if (a == b) {
-correct += 1
-}
-a
-})
-while (!stageImgsVal.isEmpty) {
-dataSetVal.post(stageImgsVal.poll())
-}
-k += input.size(1)
-}
-
-val accuracy = correct.toDouble / dataSetVal.getTotal
-println(s"[Wall Clock ${wallClockTime / 1e9}s] Accuracy is $accuracy")
-
-// Save model to a file each epoch
-File.save(model, s"${netType}${accuracy}.model${i}", true)
-File.save(state, s"${netType}${accuracy}.state${i}", true)
-}
-
-log("shuffle")
-dataSet.shuffle
-log("shuffle end")
-}
-}
-
-def runFloat(donkey: Donkey, dataSet: DataSets, netType: String, classNum: Int,
-labelsMap: Map[String, Double], testInterval: Int, donkeyVal: Donkey,
-dataSetVal: DataSets, batchSize: Int): Unit = {
+dataSetVal: DataSets, batchSize: Int, modelPath : String): Unit = {
// Compute Mean on amount of samples
val samples = 10000
log(s"Start to calculate Mean on $samples samples")
@@ -327,25 +176,27 @@ object ImageNetLocal {

if (i % testInterval == 0) {
model.evaluate()
-var correct = 0
+var top1Correct = 0
+var top5Correct = 0
var k = 0
while (k < dataSetVal.getTotal) {
val (input, target) = iterVal.next()
val output = model.forward(input)
-output.max(2)._2.squeeze().map(target, (a, b) => {
-if (a == b) {
-correct += 1
-}
-a
-})
+top1Correct += EvaluateMethods.calcAccuracy(output, target)._1
+top5Correct += EvaluateMethods.calcTop5Accuracy(output, target)._1
while (!stageImgsVal.isEmpty) {
dataSetVal.post(stageImgsVal.poll())
}
k += input.size(1)
}

-val accuracy = correct.toDouble / dataSetVal.getTotal
-println(s"[Wall Clock ${wallClockTime / 1e9}s] Accuracy is $accuracy")
+val top1Accuracy = top1Correct.toDouble / dataSetVal.getTotal
+val top5Accuracy = top5Correct.toDouble / dataSetVal.getTotal
+println(s"[Wall Clock ${wallClockTime / 1e9}s] Top-1 Accuracy is $top1Accuracy")
+println(s"[Wall Clock ${wallClockTime / 1e9}s] Top-5 Accuracy is $top5Accuracy")
+println(s"Save model and state to $modelPath-$i")
+File.save(model, modelPath + s"-$i.model")
+File.save(state, modelPath + s"-$i.state")
}

log("shuffle")
@@ -371,8 +222,8 @@ object ImageNetLocal {
val testInterval = args(4).toInt
val netType = args(5)
val classNum = args(6).toInt
-val dataType = args(7)
-val batchSize = args(8).toInt
+val batchSize = args(7).toInt
+val modelPath = args(8)

val dataSet = new DataSets(path, classNum, labelsMap)
val donkey = new Donkey(parallelism, dataSet)
@@ -383,12 +234,7 @@
dataSet.shuffle
log("shuffle end")

-dataType match {
-case "double" => runDouble(donkey, dataSet, netType, classNum, labelsMap, testInterval,
-donkeyVal, dataSetVal, batchSize)
-case "float" => runFloat(donkey, dataSet, netType, classNum, labelsMap, testInterval,
-donkeyVal, dataSetVal, batchSize)
-case _ => throw new IllegalArgumentException
-}
+run(donkey, dataSet, netType, classNum, labelsMap, testInterval,
+donkeyVal, dataSetVal, batchSize, modelPath)
}
}
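
Note on the evaluation change above: the hand-rolled top-1 check (output.max(2) followed by an element-wise compare against the target) is replaced by EvaluateMethods.calcAccuracy and EvaluateMethods.calcTop5Accuracy, which, from the way they are used, return a tuple whose first element is the number of correct predictions in the batch. A minimal plain-Scala sketch of top-k counting follows; the names and the use of arrays in place of sparkdl Tensors are illustrative, not the library's implementation.

// Illustrative sketch only, not sparkdl's EvaluateMethods.
object TopKAccuracySketch {
  // A prediction is top-k correct when the true label is among the
  // k highest-scoring classes in the output row.
  def correctInTopK(scores: Array[Double], label: Int, k: Int): Boolean =
    scores.zipWithIndex.sortBy(-_._1).take(k).exists(_._2 == label)

  def main(args: Array[String]): Unit = {
    val outputs = Array(
      Array(0.1, 0.7, 0.2), // highest score: class 1
      Array(0.5, 0.2, 0.3)) // highest score: class 0, runner-up: class 2
    val labels = Array(1, 2)

    var top1Correct = 0
    var topKCorrect = 0
    for ((row, i) <- outputs.zipWithIndex) {
      if (correctInTopK(row, labels(i), 1)) top1Correct += 1
      if (correctInTopK(row, labels(i), 2)) topKCorrect += 1 // k = 5 on ImageNet
    }
    println(s"top-1 accuracy: ${top1Correct.toDouble / labels.length}") // 0.5
    println(s"top-2 accuracy: ${topKCorrect.toDouble / labels.length}") // 1.0
  }
}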
File 2 of 4: object GoogleNet_v2
@@ -232,7 +232,7 @@ object GoogleNet_v2 {

val conv3 = new Sequential[D]
conv3.add(new SpatialConvolution[D](inputSize, config[Table](2)(1), 1, 1, 1, 1)
-.setName(namePrefix + "3x3_s2"))
+.setName(namePrefix + "3x3_reduce"))
conv3.add(new SpatialBatchNormalization(config[Table](2)(1), 1e-3)
.setName(namePrefix + "3x3_reduce/bn"))
conv3.add(new ReLU[D](true). setName(namePrefix + "3x3_reduce/bn/sc/relu"))
File 3 of 4: object Perf
@@ -79,7 +79,7 @@ object Perf {

def performance[T: ClassTag](param: Params)(implicit tn: TensorNumeric[T]): Unit = {
val (model, input) = param.module match {
case "alexnet" => (AlexNet(1000), Tensor[T](param.batchSize, 3, 224, 224))
case "alexnet" => (AlexNet(1000), Tensor[T](param.batchSize, 3, 227, 227))
case "alexnetowt" => (AlexNet_OWT(1000), Tensor[T](param.batchSize, 3, 224, 224))
case "googlenet_v1" => (GoogleNet_v1(1000), Tensor[T](param.batchSize, 3, 224, 224))
case "googlenet_v2" => (GoogleNet_v2(1000), Tensor[T](param.batchSize, 3, 224, 224))
@@ -139,8 +139,6 @@ object Perf {
}
}

-case class TestCase[T](input: Tensor[T], target: Tensor[T], model: Module[T])
-
case class Params(
batchSize: Int = 128,
iteration: Int = 10,
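
A note on the AlexNet input-size fix above: AlexNet's first layer is an 11x11 convolution with stride 4 and no padding, so a 227x227 input divides evenly into a 55x55 feature map while 224x224 does not (the familiar 224-vs-227 discrepancy between the original paper and the Caffe implementation). The arithmetic, as a small sketch:

// Output length of one dimension after a convolution:
// floor((in + 2 * pad - kernel) / stride) + 1
def convOut(in: Int, kernel: Int, stride: Int, pad: Int = 0): Int =
  (in + 2 * pad - kernel) / stride + 1

println(convOut(227, 11, 4)) // 55: 227 - 11 = 216, a multiple of 4
println(convOut(224, 11, 4)) // 54: 224 - 11 = 213, leaves a remainder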
File 4 of 4: class EpochOptimizerSpec
@@ -20,7 +20,7 @@ package com.intel.analytics.sparkdl.optim
import com.intel.analytics.sparkdl.nn._
import com.intel.analytics.sparkdl.ps.{AllReduceParameterManager, OneReduceParameterManager}
import com.intel.analytics.sparkdl.tensor.{Storage, Tensor}
-import com.intel.analytics.sparkdl.utils.{Engine, T}
+import com.intel.analytics.sparkdl.utils.{RandomGenerator, Engine, T}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkContext
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}
@@ -38,6 +38,7 @@ class EpochOptimizerSpec extends FlatSpec with Matchers with BeforeAndAfter {
"An Artificial Neural Network with MSE and LBFGS" should "be trained with good result" in {
Logger.getLogger("org").setLevel(Level.WARN)
Logger.getLogger("akka").setLevel(Level.WARN)
+RandomGenerator.RNG.setSeed(1000)

sc = new SparkContext("local[1]", "SerialOptimizerSpec")

@@ -98,6 +99,7 @@ class EpochOptimizerSpec extends FlatSpec with Matchers with BeforeAndAfter {
Logger.getLogger("org").setLevel(Level.WARN)
Logger.getLogger("akka").setLevel(Level.WARN)

+RandomGenerator.RNG.setSeed(1000)
sc = new SparkContext("local[1]", "SerialOptimizerSpec")

// Prepare two kinds of input and their corresponding label
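
The two setSeed(1000) additions above pin sparkdl's global random generator before each spec builds its model, so weight initialization, and with it the trained result the assertions check, is reproducible from run to run. A minimal illustration of the idea using scala.util.Random (only the seeding pattern is shown; RandomGenerator is sparkdl's own utility):

import scala.util.Random

// Two RNGs with the same seed produce identical draws, so randomly
// initialized weights are bit-for-bit reproducible across test runs.
val a = new Random(1000)
val b = new Random(1000)
val initA = Array.fill(4)(a.nextGaussian())
val initB = Array.fill(4)(b.nextGaussian())
assert(initA.sameElements(initB))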
