diff --git a/scala/orca/src/main/scala/com/intel/analytics/bigdl/orca/net/TFNet.scala b/scala/orca/src/main/scala/com/intel/analytics/bigdl/orca/net/TFNet.scala index 46f711d04b9..3de9248871c 100644 --- a/scala/orca/src/main/scala/com/intel/analytics/bigdl/orca/net/TFNet.scala +++ b/scala/orca/src/main/scala/com/intel/analytics/bigdl/orca/net/TFNet.scala @@ -62,6 +62,9 @@ class TFNet(private val graphDef: TFGraphHolder, implicit val ev = TensorNumeric.NumericFloat implicit val tag: ClassTag[Float] = ClassTag.Float + System.setProperty("bigdl.ModelBroadcastFactory", + "com.intel.analytics.zoo.tfpark.TFModelBroadcastFactory") + @transient private lazy val tensorManager = new TFResourceManager() diff --git a/scala/orca/src/main/scala/com/intel/analytics/bigdl/orca/net/TFNetForInference.scala b/scala/orca/src/main/scala/com/intel/analytics/bigdl/orca/net/TFNetForInference.scala index e62ccd5b847..1919c5ef39a 100644 --- a/scala/orca/src/main/scala/com/intel/analytics/bigdl/orca/net/TFNetForInference.scala +++ b/scala/orca/src/main/scala/com/intel/analytics/bigdl/orca/net/TFNetForInference.scala @@ -49,6 +49,9 @@ private[zoo] class TFNetForInference(graphRunner: GraphRunner, implicit val ev = TensorNumeric.NumericFloat implicit val tag: ClassTag[Float] = ClassTag.Float + System.setProperty("bigdl.ModelBroadcastFactory", + "com.intel.analytics.zoo.tfpark.TFModelBroadcastFactory") + override def parameters(): (Array[Tensor[Float]], Array[Tensor[Float]]) = { (weights, gradWeights) }