Skip to content

Commit

Permalink
Cleanup model copying (#503)
Browse files Browse the repository at this point in the history
* Remove extra parameters from the InferenceModel#copy

The copyWeights and saveOptimizerState parameters are unused in the OnnxInferenceModel, and since InferenceModel is not required to have a name, copiedModelName does not always make sense either. It's better to have a more generic copy method in the interface, and move these parameters to the appropriate implementing classes.

* Remove extra parameters from the TensorFlowInferenceModel#copy

TensorFlowInferenceModel does not have any layers, so it only makes sense to copy weights. And since the model class is not trainable, optimizer state does not need to be copied.

* Support copying optimizer variables in Sequential and Functional models
  • Loading branch information
juliabeliaeva authored Jan 2, 2023
1 parent 02f4e6f commit 2ab09f0
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,9 @@ public interface InferenceModel : AutoCloseable {
public fun reshape(vararg dims: Long)

/**
* Creates a copy.
* Creates a copy of this model.
*
* @param [copiedModelName] Set up this name to make a copy with a new name.
* @return A copied inference model.
*/
public fun copy(
copiedModelName: String? = null,
saveOptimizerState: Boolean = false,
copyWeights: Boolean = true
): InferenceModel
public fun copy(): InferenceModel
}
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,8 @@ public open class OnnxInferenceModel private constructor(
}
}

override fun copy(
copiedModelName: String?,
saveOptimizerState: Boolean,
copyWeights: Boolean
): OnnxInferenceModel {
override fun copy(): OnnxInferenceModel {
val model = OnnxInferenceModel(modelSource)
model.name = copiedModelName
if (inputShape != null) {
model.reshape(*inputDimensions)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.freeze
import org.jetbrains.kotlinx.dl.api.core.layer.setOutputShape
import org.jetbrains.kotlinx.dl.api.core.layer.weights
import org.jetbrains.kotlinx.dl.api.core.util.sortTopologically
import org.jetbrains.kotlinx.dl.api.inference.keras.*
import org.tensorflow.Operand
Expand Down Expand Up @@ -43,28 +42,26 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) {
return input to output[layers.last()]!!
}

/** Returns a copy of this model. */
// TODO: support saveOptimizerState=true with assignment of intermediate optimizer state
public fun copy(saveOptimizerState: Boolean = false, copyWeights: Boolean = true): Functional {
val serializedModel = serializeModel(true)
val deserializedModel = deserializeFunctionalModel(serializedModel)
if (!copyWeights) {
return deserializedModel
} else {
// TODO: make deep copies, not just links
deserializedModel.compile(
optimizer = this.optimizer,
loss = this.loss,
metrics = this.metrics
)

deserializedModel.layers.forEach {
it.weights = this.getLayer(it.name).weights
}

deserializedModel.isModelInitialized = true
/** Creates a copy of this model with weights copied and optimizer state omitted. */
override fun copy(): Functional =
    copy(copiedModelName = null, copyOptimizerState = false, copyWeights = true)

return deserializedModel
/**
 * Creates a copy of this model.
 *
 * The copy is produced by serializing this model's configuration and
 * deserializing it into a fresh [Functional] instance.
 *
 * @param [copiedModelName] a name for the copy
 * @param [copyOptimizerState] whether optimizer state needs to be copied
 * @param [copyWeights] whether model weights need to be copied
 * @return A copied inference model.
 */
public fun copy(copiedModelName: String? = null,
                copyOptimizerState: Boolean = false,
                copyWeights: Boolean = true
): Functional {
    val modelCopy = deserializeFunctionalModel(serializeModel(true))
    if (copiedModelName != null) modelCopy.name = copiedModelName
    if (copyWeights) copyWeightsTo(modelCopy, copyOptimizerState)
    return modelCopy
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,27 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
variable.initializerOperation.run(session)
}

/**
 * Copies the weights of this model into the given [model].
 *
 * The target [model] is compiled with this model's optimizer, loss and metrics,
 * and each of its layers is assigned the weights of the layer with the same
 * name in this model. Note that weights are currently shared by reference
 * rather than deep-copied (see the TODO below).
 *
 * @param [model] the model to copy the weights into; it is marked as initialized on completion
 * @param [copyOptimizerState] if `true`, the optimizer variables found in this model's
 *   graph are also copied into [model], and its optimizer state is marked as initialized
 */
protected fun copyWeightsTo(model: GraphTrainableModel, copyOptimizerState: Boolean) {
// TODO: make deep copies, not just links
model.compile(
optimizer = this.optimizer,
loss = this.loss,
metrics = this.metrics
)

// Link each target layer to the weights of the same-named layer in this model.
model.layers.forEach {
it.weights = this.getLayer(it.name).weights
}

if (copyOptimizerState) {
// Only variables recognized as optimizer state are transferred here.
val optimizerVariables = kGraph.variableNames().filter(::isOptimizerVariable)
copyVariablesToModel(model, optimizerVariables)
model.isOptimizerVariableInitialized = true
}

model.isModelInitialized = true
}

/**
* Return layer by [layerName].
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package org.jetbrains.kotlinx.dl.api.core
import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.setOutputShape
import org.jetbrains.kotlinx.dl.api.core.layer.weights
import org.jetbrains.kotlinx.dl.api.inference.keras.*
import org.tensorflow.Operand
import org.tensorflow.op.core.Placeholder
Expand Down Expand Up @@ -41,27 +40,26 @@ public class Sequential(vararg layers: Layer) : GraphTrainableModel(*layers) {
return input to output
}

/** Returns a copy of this model. */
// TODO: implement the saving of optimizer state
public fun copy(saveOptimizerState: Boolean = false, copyWeights: Boolean = true): Sequential {
val serializedModel = serializeModel(true)
val deserializedModel = deserializeSequentialModel(serializedModel)
if (!copyWeights) {
return deserializedModel
} else {
deserializedModel.compile(
optimizer = this.optimizer,
loss = this.loss,
metrics = this.metrics
)

deserializedModel.layers.forEach {
it.weights = this.getLayer(it.name).weights
}

deserializedModel.isModelInitialized = true
/** Creates a copy of this model with weights copied and optimizer state omitted. */
override fun copy(): Sequential =
    copy(copiedModelName = null, copyOptimizerState = false, copyWeights = true)

return deserializedModel
/**
 * Creates a copy of this model.
 *
 * The copy is produced by serializing this model's configuration and
 * deserializing it into a fresh [Sequential] instance.
 *
 * @param [copiedModelName] a name for the copy
 * @param [copyOptimizerState] whether optimizer state needs to be copied
 * @param [copyWeights] whether model weights need to be copied
 * @return A copied inference model.
 */
public fun copy(copiedModelName: String? = null,
                copyOptimizerState: Boolean = false,
                copyWeights: Boolean = true
): Sequential {
    val modelCopy = deserializeSequentialModel(serializeModel(true))
    if (copiedModelName != null) modelCopy.name = copiedModelName
    if (copyWeights) copyWeightsTo(modelCopy, copyOptimizerState)
    return modelCopy
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,12 @@ public open class TensorFlowInferenceModel : InferenceModel {
}
}

override fun copy(
copiedModelName: String?,
saveOptimizerState: Boolean, // TODO, check this case
copyWeights: Boolean
): TensorFlowInferenceModel {
/** Creates a copy of this model without assigning a new name to it. */
override fun copy(): TensorFlowInferenceModel = copy(copiedModelName = null)

/** Returns a copy of this model. */
public fun copy(copiedModelName: String? = null): TensorFlowInferenceModel {
val model = TensorFlowInferenceModel()
model.kGraph = this.kGraph.copy()
model.tf = Ops.create(model.kGraph.tfGraph)
Expand All @@ -158,27 +159,21 @@ public open class TensorFlowInferenceModel : InferenceModel {
model.input = input
model.output = output
if (copiedModelName != null) model.name = name
// TODO: check that tensors are closed after usage
if (copyWeights) {
val modelWeightsExtractorRunner = session.runner()
val variableNames = kGraph.variableNames()
check(variableNames.isNotEmpty()) {
"Found 0 variable names in TensorFlow graph $kGraph. " +
"If copied model has no weights, set flag `copyWeights` to `false`."
}
copyVariablesToModel(model, kGraph.variableNames())
model.isModelInitialized = true
return model
}

val variableNamesToCopy = variableNames.filter { variableName ->
saveOptimizerState || !isOptimizerVariable(variableName)
}
variableNamesToCopy.forEach(modelWeightsExtractorRunner::fetch)
val modelWeights = variableNamesToCopy.zip(modelWeightsExtractorRunner.run()).toMap()
protected fun copyVariablesToModel(model: TensorFlowInferenceModel, variableNames: List<String>) {
if (variableNames.isEmpty()) return

model.loadVariables(modelWeights.keys) { variableName, _ ->
modelWeights[variableName]!!.use { it.convertTensorToMultiDimArray() }
}
val modelWeightsExtractorRunner = session.runner()
variableNames.forEach(modelWeightsExtractorRunner::fetch)
val modelWeights = variableNames.zip(modelWeightsExtractorRunner.run()).toMap()

model.loadVariables(modelWeights.keys) { variableName, _ ->
modelWeights[variableName]!!.use { it.convertTensorToMultiDimArray() }
}
model.isModelInitialized = true
return model
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ class TransferLearningTest : IntegrationTest() {

it.loadWeights(hdfFile)

val copy = it.copy()
val copy = it.copy(copyOptimizerState = false, copyWeights = true)
assertTrue(copy.layers.size == 11)
copy.close()

Expand Down

0 comments on commit 2ab09f0

Please sign in to comment.