Skip to content

Commit

Permalink
fix: memory leak in model.predictImageSet. (intel-analytics#2557)
Browse files Browse the repository at this point in the history
* fix: memory leak in `model.predictImageSet`.

There are three causes of the memory leak.

1. Repeated allocations in bigquant, which will be fixed in BigDL-core.
2. Modules are cloned repeatedly but never released: `model.predictImageSet`
   creates a new Predictor on every call.
3. Shared weights.

This patch adds a `StorageManager`, which contains a concurrent hash map
that tracks all allocations of native memory/resources and prevents
duplicate releases. It is also helpful for debugging.

* fix: delete .

* refactor:  as the API for AbstractModule

* fix: distribute predictor memory leak

* fix: move delete operation to ModelBroadcast

* refinement per review

* fix ut

* fix scala version issue
  • Loading branch information
i8run authored and wzhongyuan committed Jun 28, 2018
1 parent 24db1da commit 2263e40
Show file tree
Hide file tree
Showing 15 changed files with 263 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,19 @@

package com.intel.analytics.bigdl.models.utils

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
import java.util.UUID

import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.{Container, Graph}
import com.intel.analytics.bigdl.tensor.{QuantizedTensor, QuantizedType, Storage, Tensor}
import com.intel.analytics.bigdl.nn.Container
import com.intel.analytics.bigdl.nn.quantized.StorageManager
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.tensor._
import com.intel.analytics.bigdl.utils.Util._
import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast
import com.intel.analytics.bigdl.utils.Util._

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

/**
Expand All @@ -38,10 +43,11 @@ import scala.reflect.ClassTag
class ModelBroadcast[T: ClassTag](applyProtoBuffer: Boolean = false)
(implicit ev: TensorNumeric[T]) extends Serializable {

private var broadcastModel: Broadcast[Module[T]] = _
private var broadcastModel: Broadcast[ModelInfo[T]] = _
private var broadcastConsts: Broadcast[Map[String, Tensor[_]]] = _
private var broadcastParameters: Broadcast[Array[Tensor[T]]] = _

private[bigdl] val uuid: String = UUID.randomUUID().toString

/**
* broadcast the model
Expand All @@ -53,26 +59,45 @@ class ModelBroadcast[T: ClassTag](applyProtoBuffer: Boolean = false)
* @return this
*/
def broadcast(sc: SparkContext, model: Module[T]): this.type = {
CachedModels.deleteAll(uuid) // delete the models on driver

if (applyProtoBuffer) {
broadcastModel = sc.broadcast(model)
broadcastModel = sc.broadcast(ModelInfo(uuid, model))
} else {
// Clone the model so that the original model is left untouched.
// Otherwise, the original model's resources would be released.
val newModel = model.cloneModule()
CachedModels.add(uuid, newModel)

// broadcast Consts
if (model.isInstanceOf[Container[_, _, T]]) {
val moduleConsts = getAndClearConsts(model.asInstanceOf[Container[_, _, T]])
if (newModel.isInstanceOf[Container[_, _, T]]) {
val moduleConsts = getAndClearConsts(newModel.asInstanceOf[Container[_, _, T]])
// TODO: broadcast Const, model structure and weight in the same broadcast.
broadcastConsts = sc.broadcast(moduleConsts)
}

// broadcast weight and model
val weightsBias = getAndClearWeightBias(model.parameters())
broadcastModel = sc.broadcast(model.cloneModule())
val weightsBias = getAndClearWeightBias(newModel.parameters())

// We broadcast the weights and the model separately because of the memory limit of serialization.
// Because broadcast is evaluated lazily, we must clone the model structure (without weights)
// first; the weights are put back into the model after the broadcast call.
// Cloning a quantized model allocates native memory for its `QuantizedTensor`s,
// so we release the clone right away.
val cloned = newModel.cloneModule()
cloned.release()
CachedModels.add(uuid, cloned)

broadcastModel = sc.broadcast(ModelInfo[T](uuid, cloned))
broadcastParameters = sc.broadcast(weightsBias)

putWeightBias(weightsBias, model)
initGradWeightBias(weightsBias, model)
putWeightBias(weightsBias, newModel)
initGradWeightBias(weightsBias, newModel)
}
this
}


/**
* get the broadcast model
* put the weight and bias back to the model
Expand All @@ -81,14 +106,21 @@ class ModelBroadcast[T: ClassTag](applyProtoBuffer: Boolean = false)
* @return model
*/
def value(initGradient: Boolean = false): Module[T] = {
CachedModels.deleteAll(uuid)
if (applyProtoBuffer) {
val localModel = broadcastModel.value.clone(false)
val localModel = broadcastModel.value.model.clone(false)
val uuid = broadcastModel.value.uuid
CachedModels.add(uuid, localModel)

if (initGradient) {
initGradWeightBias(getWeightBias(localModel.parameters()), localModel)
}
localModel
} else {
val localModel = broadcastModel.value.cloneModule()
val localModel = broadcastModel.value.model.cloneModule()
val uuid = broadcastModel.value.uuid
CachedModels.add(uuid, localModel)

// share weight
putWeightBias(broadcastParameters.value, localModel)
// share Consts
Expand Down Expand Up @@ -141,13 +173,64 @@ class ModelBroadcast[T: ClassTag](applyProtoBuffer: Boolean = false)
Array()
}
}

}


/**
 * Companion factory for [[ModelBroadcast]].
 */
object ModelBroadcast {
  /**
   * Builds a new [[ModelBroadcast]].
   *
   * @param applyProtoBuffer passed unchanged to the `ModelBroadcast` constructor
   * @param ev numeric operations for the element type `T`
   * @tparam T element type of the model
   * @return a freshly constructed `ModelBroadcast[T]`
   */
  def apply[@specialized(Float, Double) T: ClassTag](
      applyProtoBuffer: Boolean = false)(implicit ev: TensorNumeric[T]): ModelBroadcast[T] =
    new ModelBroadcast[T](applyProtoBuffer)
}

/**
 * Serializable holder pairing a model with the uuid of the `ModelBroadcast` that ships it.
 *
 * `model` is `@transient`, so default Java serialization skips it. The custom
 * `writeObject` / `readObject` pair below moves the model across the wire explicitly and
 * registers every instance it creates in `CachedModels` under `uuid`.
 *
 * @param uuid  key under which models created by (de)serialization are registered
 * @param model the wrapped model; restored by `readObject` after deserialization
 */
private[bigdl] class ModelInfo[T: ClassTag](val uuid: String, @transient var model: Module[T])(
  implicit ev: TensorNumeric[T]) extends Serializable {

  /**
   * Writes the non-transient fields, then a clone of `model`.
   * The clone is registered in `CachedModels` under `uuid`.
   */
  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()
    val cloned = model.cloneModule()
    out.writeObject(cloned)
    CachedModels.add(uuid, cloned)
  }

  /**
   * Mirror of `writeObject`: restores the non-transient fields, then reads the model that
   * `writeObject` appended. Without this hook the transient `model` field would stay `null`
   * on any deserialized instance. The restored model is registered in `CachedModels` under
   * `uuid`, matching what `writeObject` does for its clone.
   */
  @throws(classOf[IOException])
  @throws(classOf[ClassNotFoundException])
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    // The element type is erased at runtime, so this cast cannot be checked; it relies on
    // `writeObject` having written a `Module[T]` at this position in the stream.
    model = in.readObject().asInstanceOf[Module[T]]
    CachedModels.add(uuid, model)
  }
}

/**
 * Companion factory for [[ModelInfo]].
 */
private[bigdl] object ModelInfo {
  /**
   * Wraps `model` together with `uuid` in a new [[ModelInfo]].
   *
   * @param uuid  identifier stored alongside the model
   * @param model the model to wrap
   * @return a new `ModelInfo[T]`
   */
  def apply[T: ClassTag](uuid: String, model: Module[T])(
    implicit ev: TensorNumeric[T]): ModelInfo[T] = {
    new ModelInfo[T](uuid, model)
  }
}

/**
 * Process-wide registry of models grouped by a string key (the uuid of a `ModelBroadcast`),
 * kept so that their `release()` method can be called later.
 *
 * Both public methods synchronize on this object.
 */
object CachedModels {
  import java.util.concurrent.ConcurrentHashMap

  import scala.collection._
  import scala.collection.convert.decorateAsScala._
  import scala.language.existentials

  // NOTE: the alias name is misspelled ("Modles"); it is public, so it is kept as is.
  type Modles = ArrayBuffer[Module[_]]

  private val cachedModels: concurrent.Map[String, Modles] =
    new ConcurrentHashMap[String, Modles]().asScala

  /**
   * Registers `model` under `uuid`, creating the bucket for `uuid` if it does not exist yet.
   *
   * @param uuid  key the model is stored under
   * @param model the model to remember
   */
  def add[T: ClassTag](uuid: String, model: Module[T])( implicit ev: TensorNumeric[T]): Unit =
    CachedModels.synchronized {
      val models = cachedModels.getOrElseUpdate(uuid, ArrayBuffer.empty[Module[_]])
      models += model
    }

  /**
   * Calls `release()` on every model registered under a key other than `currentKey`
   * and removes those keys. Models registered under `currentKey` are left untouched.
   *
   * The type parameter and the implicit are not used by the body; they are kept so that
   * existing call sites keep compiling.
   *
   * NOTE(review): every other key is released, including keys that belong to other
   * `ModelBroadcast` instances which may still be in use - confirm this is intended.
   *
   * @param currentKey the only key whose models are kept
   */
  def deleteAll[T: ClassTag](currentKey: String)(implicit ev: TensorNumeric[T]): Unit =
    CachedModels.synchronized {
      // Snapshot the keys first so entries are not removed while the key view is iterated.
      val staleKeys = cachedModels.keys.filter(_ != currentKey).toList
      for (key <- staleKeys) {
        cachedModels.get(key).foreach(_.foreach(_.release()))
        cachedModels.remove(key)
      }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -227,4 +227,8 @@ abstract class Container[A <: Activity : ClassTag,
super.checkDuplicate(record)
if (!skipDuplicateCheck()) modules.foreach(_.checkDuplicate(record))
}

/**
 * Releases this container by calling `release()` on each of its child modules in order.
 */
override def release(): Unit = {
  for (child <- modules) {
    child.release()
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -665,8 +665,11 @@ abstract class AbstractModule[A <: Activity: ClassTag, B <: Activity: ClassTag,
Predictor(this, featurePaddingParam, batchPerPartition)
.predictImage(distributedImageFrame, outputLayer, shareBuffer, predictKey)
case localImageFrame: LocalImageFrame =>
LocalPredictor(this, featurePaddingParam, batchPerPartition)
.predictImage(localImageFrame, outputLayer, shareBuffer, predictKey)
val predictor = LocalPredictor(this, featurePaddingParam, batchPerPartition)
val imageFrame = predictor.predictImage(localImageFrame, outputLayer, shareBuffer,
predictKey)
predictor.shutdown()
imageFrame
}
}

Expand Down Expand Up @@ -1106,5 +1109,11 @@ abstract class AbstractModule[A <: Activity: ClassTag, B <: Activity: ClassTag,
* @return
*/
private[nn] def skipDuplicateCheck(): Boolean = false

/**
 * Releases resources held by this module.
 *
 * If the model contains native resources such as aligned memory, they must be released
 * manually, because the JVM garbage collector cannot release them reliably.
 * This default implementation does nothing.
 */
def release(): Unit = ()
}

Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ object Desc {
desc
}

// add every native memory allocation.
StorageManager.add(desc, params.getType)

desc
}

Expand Down Expand Up @@ -178,3 +181,28 @@ object QuantParams {
val THRESHOLD = 127.0f
}

/**
 * Bookkeeping record that `StorageManager` keeps for one native storage handle.
 *
 * @param descType the `DescType` recorded when the handle was registered
 * @param isFreed  whether the handle has been marked as freed
 */
private[bigdl] case class StorageInfo(descType: DescType, isFreed: Boolean)

/**
 * Registry of native storage handles, used to make sure each handle is freed at most once.
 *
 * Handles are registered with `add`; `checkAndSet` flips a handle from "live" to "freed".
 * Entries are never removed from the registry by the code in this object.
 */
private[bigdl] object StorageManager {
  import java.util.concurrent.ConcurrentHashMap
  private val nativeStorages: ConcurrentHashMap[Long, StorageInfo] = new ConcurrentHashMap()

  /**
   * Tells whether `nativeStorage` is marked as freed.
   *
   * A handle that was never registered is reported as freed (there is nothing tracked
   * that could still be released) instead of failing with a NullPointerException.
   *
   * @param nativeStorage the native handle to look up
   * @return `true` if the handle is marked freed or is unknown to this registry
   */
  def isFreed(nativeStorage: Long): Boolean = {
    // ConcurrentHashMap.get returns null for a missing key, hence the Option wrapper.
    Option(nativeStorages.get(nativeStorage)).map(_.isFreed).getOrElse(true)
  }

  /**
   * Atomically marks `nativeStorage` as freed.
   *
   * @param nativeStorage the native handle to mark
   * @return `true` only for the caller that flipped the handle from live to freed, i.e. the
   *         one that should actually free the memory; `false` if the handle was already
   *         marked freed or was never registered
   */
  def checkAndSet(nativeStorage: Long): Boolean = {
    Option(nativeStorages.get(nativeStorage)).exists { info =>
      // replace(key, expected, updated) only succeeds while the entry is still "not freed".
      nativeStorages.replace(nativeStorage,
        StorageInfo(info.descType, false), StorageInfo(info.descType, true))
    }
  }

  /**
   * @return an immutable copy of the registry contents at the time of the call
   */
  def get(): Map[Long, StorageInfo] = {
    import scala.collection.JavaConverters._
    nativeStorages.asScala.toMap
  }

  /**
   * Registers `key` as a live (not freed) handle, overwriting any previous entry for it.
   *
   * @param key      the native handle
   * @param descType the descriptor type to record for the handle
   */
  def add(key: Long, descType: DescType): Unit = {
    nativeStorages.put(key, StorageInfo(descType, false))
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ private[bigdl] class Linear[T: ClassTag](
s"quantized.${getPrintName()}($inputSize -> $outputSize)"
}

def release(): Unit = {
override def release(): Unit = {
weight.release()
data.release()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ private[bigdl] class SpatialConvolution[T: ClassTag](
s" $kernelH, $strideW, $strideH, $padW, $padH, $nGroup)"
}

def release(): Unit = {
override def release(): Unit = {
weight.foreach(_.asInstanceOf[QuantizedTensor[T]].release())
data.release()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ package com.intel.analytics.bigdl.optim

import com.intel.analytics.bigdl._
import com.intel.analytics.bigdl.dataset.{SampleToMiniBatch, _}
import com.intel.analytics.bigdl.nn.Container
import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.nn.quantized.QuantizedModule
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.{Engine, MklBlas, Util}
import com.intel.analytics.bigdl.utils.Util._
import com.intel.analytics.bigdl.transform.vision.image.{ImageFeature, ImageFrame, LocalImageFrame}
import com.intel.analytics.bigdl.utils.Util._
import com.intel.analytics.bigdl.utils.{Engine, MklBlas, Util}
import org.apache.log4j.Logger

import scala.reflect.ClassTag
Expand Down Expand Up @@ -58,15 +60,19 @@ class LocalPredictor[T: ClassTag] private[optim](model: Module[T],
case _ => throw new IllegalArgumentException
}

// we should clone a new model which has no impact to origin model
private val clonedModel = model.cloneModule()

private val workingModels = {
val weightsBias = Util.getAndClearWeightBias(model.parameters())

val weightsBias = Util.getAndClearWeightBias(clonedModel.parameters())
val models = (1 to subModelNumber).map(_ => {
val submodel = model.cloneModule().evaluate()
val submodel = clonedModel.cloneModule().evaluate()
putWeightBias(weightsBias, submodel)
submodel
}).toArray
Util.putWeightBias(weightsBias, model)
Util.initGradWeightBias(weightsBias, model)
Util.putWeightBias(weightsBias, clonedModel)
Util.initGradWeightBias(weightsBias, clonedModel)
models
}

Expand Down Expand Up @@ -176,6 +182,14 @@ class LocalPredictor[T: ClassTag] private[optim](model: Module[T],

ImageFrame.array(result.toArray)
}

/**
 * Releases all native resources held by this predictor: first every working model,
 * then the cloned model.
 */
def shutdown(): Unit = {
  for (worker <- workingModels) {
    worker.release()
  }
  clonedModel.release()
}
}


Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

package com.intel.analytics.bigdl.optim

import java.util.UUID

import com.intel.analytics.bigdl._
import com.intel.analytics.bigdl.dataset.{MiniBatch, PaddingParam, Sample, SampleToMiniBatch, Transformer, Utils, DataSet => _}
import com.intel.analytics.bigdl.models.utils.ModelBroadcast
import com.intel.analytics.bigdl.models.utils.{CachedModels, ModelBroadcast, ModelInfo}
import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
Expand Down Expand Up @@ -161,6 +163,8 @@ class Predictor[T: ClassTag] private[optim](
partitionNum = Some(partitionNum),
featurePaddingParam = featurePaddingParam))
dataSet.mapPartitions { partition =>
CachedModels.add(modelBroad.uuid, model)

val localModel = modelBroad.value()
val localTransformer = otherBroad.value.cloneTransformer()
val miniBatch = localTransformer(partition)
Expand Down Expand Up @@ -192,6 +196,9 @@ class Predictor[T: ClassTag] private[optim](
partitionNum = Some(partitionNum),
featurePaddingParam = featurePaddingParam), shareBuffer)
val result = rdd.mapPartitions(partition => {
// By default, the `model` will be deserialized on worker, which will create new resources.
CachedModels.add(modelBroad.uuid, model)

val localModel = modelBroad.value()
val localToBatch = toBatchBroad.value._1.cloneTransformer()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2204,6 +2204,8 @@ private[tensor] class DenseTensor[@specialized T: ClassTag](
this.apply1(a => ev.digamma(a))
}

/** Always throws `IllegalArgumentException`: a dense tensor is not a `QuantizedTensor`. */
override private[bigdl] def toQuantizedTensor: QuantizedTensor[T] = {
  throw new IllegalArgumentException("DenseTensor cannot be cast to QuantizedTensor")
}
}

object DenseTensor {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ private[bigdl] class QuantizedTensor[T: ClassTag](
}

def release(): this.type = {
if (desc != 0) {
if (desc != 0 && StorageManager.checkAndSet(desc)) {
BigQuant.FreeMemory(desc)
}
desc = 0L
Expand Down Expand Up @@ -270,6 +270,8 @@ private[bigdl] class QuantizedTensor[T: ClassTag](

override def getTensorNumeric(): TensorNumeric[T] = ev

/** Returns this tensor itself, since it already is a `QuantizedTensor[T]`. */
override def toQuantizedTensor: QuantizedTensor[T] = this

@throws(classOf[IOException])
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,9 @@ private[tensor] class SparseTensor[@specialized(Float, Double) T: ClassTag](

override def sumSquare(): T =
throw new UnsupportedOperationException(s"SparseTensor: Unimplemented method")

/** Always throws `IllegalArgumentException`: a sparse tensor is not a `QuantizedTensor`. */
override private[bigdl] def toQuantizedTensor: QuantizedTensor[T] = {
  throw new IllegalArgumentException("SparseTensor cannot be cast to QuantizedTensor")
}
}

object SparseTensor{
Expand Down
Loading

0 comments on commit 2263e40

Please sign in to comment.