Skip to content

Commit

Permalink
Python MKLDNN examples for CNN(LeNet) and RNN(LSTM) (#2932)
Browse files Browse the repository at this point in the history
  • Loading branch information
mengceng15 authored Oct 28, 2019
1 parent ad39cf2 commit 5abb1e2
Showing 1 changed file with 9 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList, M

import com.intel.analytics.bigdl._
import com.intel.analytics.bigdl.dataset.{Identity => DIdentity, Sample => JSample, _}
import com.intel.analytics.bigdl.nn.{PGCriterion, Zeros, _}
import com.intel.analytics.bigdl.nn.{PGCriterion, Sequential, Zeros, _}
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, _}
import com.intel.analytics.bigdl.numeric._
import com.intel.analytics.bigdl.optim.{Optimizer, _}
Expand Down Expand Up @@ -275,6 +275,10 @@ class PythonBigDL[T: ClassTag](implicit ev: TensorNumeric[T]) extends Serializab
Sequential[T]()
}

def toGraph(sequential: Sequential[T]): StaticGraph[T] = {
sequential.toGraph().asInstanceOf[StaticGraph[T]]
}

def createAttention(hiddenSize: Int, numHeads: Int, attentionDropout: Float): Attention[T] = {
Attention(hiddenSize, numHeads, attentionDropout)
}
Expand Down Expand Up @@ -2504,6 +2508,10 @@ class PythonBigDL[T: ClassTag](implicit ev: TensorNumeric[T]) extends Serializab
Engine.init
}

def getEngineType(): String = {
Engine.getEngineType().toString
}

def getNodeAndCoreNumber(): Array[Int] = {
Array(Engine.nodeNumber(), Engine.coreNumber())
}
Expand Down

0 comments on commit 5abb1e2

Please sign in to comment.