Skip to content

Commit

Permalink
Add tensorboard support (intel-analytics#2667)
Browse files Browse the repository at this point in the history
* add tensorboard support

* fix style

* fix style2

* fix test
  • Loading branch information
jenniew authored Aug 11, 2020
1 parent 4c5a496 commit 6647762
Showing 1 changed file with 28 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import com.intel.analytics.bigdl.{Criterion, Module}
import com.intel.analytics.bigdl.dataset.MiniBatch
import com.intel.analytics.bigdl.optim._
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.visualization.{TrainSummary, ValidationSummary}
import com.intel.analytics.zoo.feature.{DiskFeatureSet, DistributedFeatureSet, FeatureSet}
import com.intel.analytics.zoo.pipeline.api.keras.models.InternalDistriOptimizer
import org.apache.log4j.Logger
Expand Down Expand Up @@ -72,6 +73,10 @@ class Estimator[T: ClassTag] private[zoo](
protected val gradientClipping: ArrayBuffer[GradientClipping] =
new ArrayBuffer[GradientClipping]()

protected var logDir: String = null

protected var appName: String = null

/**
* Clear gradient clipping parameters. In this case, gradient clipping will not be applied.
* In order to take effect, it needs to be called before fit.
Expand Down Expand Up @@ -100,6 +105,22 @@ class Estimator[T: ClassTag] private[zoo](
def setGradientClippingByL2Norm(clipNorm: Double): Unit = {
this.gradientClipping.append(L2NormClipping(clipNorm))
}

def setTensorBoard(logDir: String, appName: String): Unit = {
this.logDir = logDir
this.appName = appName
}

def getTrainSummary(tag: String): Array[(Long, Float, Double)] = {
this.internalEstimator.asInstanceOf[InternalDistriOptimizer[T]].getTrainSummary(tag)
}

def getValidationSummary(tag: String): Array[(Long, Float, Double)] = {
this.internalEstimator.asInstanceOf[InternalDistriOptimizer[T]].getValidationSummary(tag)
}



/**
* Train model with provided trainSet and criterion.
* The training will end until the endTrigger is triggered.
Expand Down Expand Up @@ -128,6 +149,13 @@ class Estimator[T: ClassTag] private[zoo](
.setCheckpointDir(modelDir)
.setOptimMethods(optimMethods)
.setNumOfSlice(d.numOfSlice)
if ((logDir != null) && (appName != null)) {
val trainSummary = TrainSummary(logDir, appName)
val valSummary = ValidationSummary(logDir, appName)
internalEstimator.asInstanceOf[Optimizer[_, _]]
.setTrainSummary(trainSummary)
.setValidationSummary(valSummary)
}
}
case _ => throw new IllegalArgumentException("Unsupported FeatureSet type.")
}
Expand Down

0 comments on commit 6647762

Please sign in to comment.