diff --git a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java index 51ee27ad019..3c9bb3f89d7 100644 --- a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java +++ b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java @@ -68,7 +68,10 @@ public RecurrentBlock(BaseBuilder builder) { bidirectional = builder.bidirectional; returnState = builder.returnState; - Parameter.Type[] parameterTypes = {Parameter.Type.WEIGHT, Parameter.Type.BIAS}; + Parameter.Type[] parameterTypes = + hasBiases + ? new Parameter.Type[] {Parameter.Type.WEIGHT, Parameter.Type.BIAS} + : new Parameter.Type[] {Parameter.Type.WEIGHT}; String[] directions = {"l"}; if (builder.bidirectional) { directions = new String[] {"l", "r"}; diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java index 9e490df7b1f..0cb0701a9d9 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java @@ -835,7 +835,11 @@ public NDList lstm( boolean training, boolean bidirectional, boolean batchFirst) { - int numParams = numLayers * ((hasBiases) ? 4 : 2) * ((bidirectional) ? 2 : 1); + if (!hasBiases) { + throw new UnsupportedOperationException( + "Setting hasBias to be false is not supported on MXNet engine."); + } + int numParams = numLayers * 4 * (bidirectional ? 2 : 1); Preconditions.checkArgument( params.size() == numParams, "The size of Params is incorrect expect "