From f8a96f169369344ef7370576e4961c74146eeb44 Mon Sep 17 00:00:00 2001 From: markbookk Date: Wed, 7 Apr 2021 09:36:04 -0400 Subject: [PATCH] Fixing sequenceMask error when n dimension is 2 --- .../src/main/java/ai/djl/mxnet/engine/MxNDArray.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index b89a27e0a90..c1ab38dea4b 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -386,9 +386,9 @@ public NDArray booleanMask(NDArray index, int axis) { /** {@inheritDoc} */ @Override public NDArray sequenceMask(NDArray sequenceLength, float value) { - if (getShape().dimension() < 3 || getShape().isScalar() || getShape().hasZeroDimension()) { + if (getShape().dimension() < 2 || getShape().isScalar() || getShape().hasZeroDimension()) { throw new IllegalArgumentException( - "sequenceMask is not supported for NDArray with less than 3 dimensions"); + "sequenceMask is not supported for NDArray with less than 2 dimensions"); } Shape expectedSequenceLengthShape = new Shape(getShape().get(0)); if (!sequenceLength.getShape().equals(expectedSequenceLengthShape)) {