Skip to content

Commit

Permalink
Reimplement SoftShrinkActivation using tf.where operand. This also fi…
Browse files Browse the repository at this point in the history
…xes failing tests.

Related-To: Kotlin#170
  • Loading branch information
michalharakal committed Nov 30, 2021
1 parent c6d3ad5 commit d21d0c3
Showing 1 changed file with 18 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -516,22 +516,25 @@ public class HardShrinkActivation(public val lower: Float = -0.5f, public val up
*/
public class SoftShrinkActivation(public val lower: Float = -0.5f, public val upper: Float = 0.5f) : Activation {
override fun apply(tf: Ops, features: Operand<Float>): Operand<Float> {
require(upper < lower) {
"The value of lambda must be no less than zero"
}
val maskLower = tf.math.minimum(features, tf.constant(lower)) != tf.constant(lower)
val maskUpper = tf.math.maximum(features, tf.constant(upper)) != tf.constant(upper)
val mask = (maskLower || maskUpper)
return when (mask) {
false -> tf.constant(0) as Operand<Float>
true -> {
if (maskUpper)
tf.math.add(features, tf.constant(upper))
else tf.math.sub(
features, tf.constant(lower)
)
}
require((lower < upper) && (lower < 0) && (upper > 0)) {
"The boundary values have to be non zero and the lower bound has to be lower as the upper"
}
val zeros = tf.math.mul(features, tf.constant(0f))
val valuesBelowLower = tf.where3(
tf.math.less(features, tf.constant(lower)),
tf.math.sub(
features, tf.constant(lower)
),
zeros
)
val valuesAboveUpper = tf.where3(
tf.math.less(tf.constant(upper), features),
tf.math.sub(
features, tf.constant(upper)
),
zeros
)
return tf.math.add(valuesBelowLower, valuesAboveUpper)
}
}

Expand Down

0 comments on commit d21d0c3

Please sign in to comment.