Added missing saving functions for ReLU and ELU activation layers (JetBrains#78)

dosier committed Jun 14, 2021
1 parent 218fb57 commit 4592920
Showing 1 changed file with 25 additions and 5 deletions.
@@ -12,10 +12,7 @@ import org.jetbrains.kotlinx.dl.api.core.Sequential
 import org.jetbrains.kotlinx.dl.api.core.activation.Activations
 import org.jetbrains.kotlinx.dl.api.core.initializer.*
 import org.jetbrains.kotlinx.dl.api.core.layer.Layer
-import org.jetbrains.kotlinx.dl.api.core.layer.activation.PReLU
-import org.jetbrains.kotlinx.dl.api.core.layer.activation.LeakyReLU
-import org.jetbrains.kotlinx.dl.api.core.layer.activation.Softmax
-import org.jetbrains.kotlinx.dl.api.core.layer.activation.ThresholdedReLU
+import org.jetbrains.kotlinx.dl.api.core.layer.activation.*
 import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.*
 import org.jetbrains.kotlinx.dl.api.core.layer.core.ActivationLayer
 import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
@@ -28,6 +25,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D
 import org.jetbrains.kotlinx.dl.api.core.regularizer.L2L1
 import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
 import org.jetbrains.kotlinx.dl.api.inference.keras.config.*
+import org.tensorflow.op.nn.Elu
 import java.io.File

 /**
@@ -86,6 +84,8 @@ private fun convertToKerasLayer(layer: Layer, isKerasFullyCompatible: Boolean, i
         is Input -> createKerasInput(layer)
         is BatchNorm -> createKerasBatchNorm(layer, isKerasFullyCompatible)
         is ActivationLayer -> createKerasActivationLayer(layer)
+        is ELU -> createKerasELU(layer)
+        is ReLU -> createKerasReLU(layer)
         is PReLU -> createKerasPReLULayer(layer, isKerasFullyCompatible)
         is LeakyReLU -> createKerasLeakyReLU(layer)
         is ThresholdedReLU -> createKerasThresholdedReLULayer(layer)
@@ -221,6 +221,26 @@ private fun createKerasActivationLayer(layer: ActivationLayer): KerasLayer {
     return KerasLayer(class_name = LAYER_ACTIVATION, config = configX)
 }

+private fun createKerasReLU(layer: ReLU): KerasLayer {
+    val configX = LayerConfig(
+        dtype = DATATYPE_FLOAT32,
+        max_value = layer.maxValue?.toDouble(),
+        negative_slope = layer.negativeSlope.toDouble(),
+        threshold = layer.threshold.toDouble(),
+        name = layer.name
+    )
+    return KerasLayer(class_name = LAYER_RELU, config = configX)
+}
+
+private fun createKerasELU(layer: ELU): KerasLayer {
+    val configX = LayerConfig(
+        dtype = DATATYPE_FLOAT32,
+        alpha = layer.alpha.toDouble(),
+        name = layer.name
+    )
+    return KerasLayer(class_name = LAYER_ELU, config = configX)
+}
+
 private fun createKerasPReLULayer(layer: PReLU, isKerasFullyCompatible: Boolean): KerasLayer {
     val configX = LayerConfig(
         dtype = DATATYPE_FLOAT32,
@@ -604,4 +624,4 @@ private fun createKerasZeroPadding2D(layer: ZeroPadding2D): KerasLayer {
         padding = KerasPadding.ZeroPadding2D(layer.padding)
     )
     return KerasLayer(class_name = LAYER_ZERO_PADDING_2D, config = configX)
-}
\ No newline at end of file
+}
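For context, a minimal usage sketch (not part of this commit) of how the new exporters get exercised when a model containing ReLU and ELU layers is serialized to a Keras JSON config. It assumes KotlinDL's saveModelConfiguration / Sequential.loadModelConfiguration extensions from the same org.jetbrains.kotlinx.dl.api.inference.keras package, and the layer constructor parameter names implied by the diff (maxValue, negativeSlope, threshold, alpha):

import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.layer.activation.ELU
import org.jetbrains.kotlinx.dl.api.core.layer.activation.ReLU
import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.inference.keras.loadModelConfiguration
import org.jetbrains.kotlinx.dl.api.inference.keras.saveModelConfiguration
import java.io.File

fun main() {
    // A small model using the two layer types this commit adds exporters for.
    val model = Sequential.of(
        Input(4),
        Dense(8),
        ReLU(maxValue = 6.0f, negativeSlope = 0.01f, threshold = 0.0f),
        ELU(alpha = 1.0f)
    )

    // Saving the configuration now dispatches through convertToKerasLayer
    // to createKerasReLU and createKerasELU instead of failing on an
    // unsupported layer type.
    val configFile = File("model.json")
    model.saveModelConfiguration(configFile)

    // The written Keras JSON can be loaded back into a fresh Sequential.
    val restored = Sequential.loadModelConfiguration(configFile)
}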
