From c5d211ad73e6b37379d10dd45d6362e2eaafb2bb Mon Sep 17 00:00:00 2001 From: Qing Lan Date: Wed, 24 Mar 2021 16:57:02 -0700 Subject: [PATCH] add boolean set method to DJL --- api/src/main/java/ai/djl/ndarray/NDArray.java | 20 ++++++++++++++++ .../java/ai/djl/ndarray/NDArrayAdapter.java | 6 +++++ .../ai/djl/ndarray/index/NDArrayIndexer.java | 17 ++++++++++++-- .../tests/ndarray/NDArrayOtherOpTest.java | 23 +++++++++++++++++++ .../java/ai/djl/mxnet/engine/MxNDArray.java | 14 +++++++++-- .../java/ai/djl/mxnet/engine/MxNDArrayEx.java | 7 ++++++ .../ai/djl/mxnet/engine/MxNDArrayIndexer.java | 7 ------ .../java/ai/djl/pytorch/engine/PtNDArray.java | 15 +++++++++--- .../ai/djl/tensorflow/engine/TfNDArray.java | 6 +++++ .../tensorflow/engine/TfNDArrayIndexer.java | 7 ------ 10 files changed, 101 insertions(+), 21 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 20e64a79e0f..fe4059792f2 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -469,6 +469,16 @@ default void set(NDIndex index, Function function) { set(index, function.apply(array)); } + /** + * Sets the {@code NDArray} by boolean mask. + * + * @param index the boolean {@code NDArray} that indicates what to get + * @param value the value to replace with + */ + default void set(NDArray index, Number value) { + set(new NDIndex().addBooleanIndex(index), value); + } + /** * Sets the specified scalar in this {@code NDArray} with the given value. * @@ -3546,6 +3556,16 @@ default NDArray argSort(int axis) { */ NDArray cumSum(int axis); + /** + * Replace the handle of the NDArray with the other. The NDArray used for replacement will be + * killed. + * + *

Please use with caution, this method will make the input argument unusable. + * + * @param replaced the handle provider that will be killed + */ + void intern(NDArray replaced); + /** * Returns the boolean {@code NDArray} with value {@code true} where this {@code NDArray}'s * entries are infinite, or {@code false} where they are not infinite. diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index 19324823983..7eb3d934ce7 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -783,6 +783,12 @@ default NDArray cumSum(int axis) { throw new UnsupportedOperationException(UNSUPPORTED_MSG); } + /** {@inheritDoc} */ + @Override + default void intern(NDArray replaced) { + throw new UnsupportedOperationException(UNSUPPORTED_MSG); + } + /** {@inheritDoc} */ @Override default NDArray isInfinite() { diff --git a/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java b/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java index b7f6e704ced..31f4008120f 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java @@ -13,6 +13,7 @@ package ai.djl.ndarray.index; import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; import ai.djl.ndarray.index.dim.NDIndexBooleans; import ai.djl.ndarray.index.dim.NDIndexElement; import ai.djl.ndarray.index.full.NDIndexFullPick; @@ -90,9 +91,11 @@ public NDArray get(NDArray array, NDIndex index) { * * @param array the array to set * @param indices a boolean array where true indicates values to update - * @param value the value to set with + * @param value the value to set with when condition is true */ - public abstract void set(NDArray array, NDIndexBooleans indices, NDArray value); + public void set(NDArray array, NDIndexBooleans indices, NDArray value) { + array.intern(NDArrays.where(indices.getIndex(), value, array)); + } /** * Sets the values of the array at the index locations with an array. @@ -144,6 +147,16 @@ public void set(NDArray array, NDIndex index, Number value) { set(array, fullSlice, value); return; } + // use booleanMask for NDIndexBooleans case + List indices = index.getIndices(); + if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) { + if (indices.size() != 1) { + throw new IllegalArgumentException( + "set() currently didn't support more that one boolean NDArray"); + } + set(array, (NDIndexBooleans) indices.get(0), array.getManager().create(value)); + return; + } throw new UnsupportedOperationException( "set() currently supports all, fixed, and slices indices"); } diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java index 047c6493737..185b8b3b3a7 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java @@ -218,6 +218,17 @@ public void testIsNaN() { } } + @Test + public void testIntern() { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray original = manager.ones(new Shape(3, 2)); + NDArray replaced = manager.zeros(new Shape(5, 5)); + original.intern(replaced); + Assert.assertEquals(original, manager.zeros(new Shape(5, 5))); + Assert.assertThrows(IllegalStateException.class, replaced::toFloatArray); + } + } + @Test(expectedExceptions = IllegalArgumentException.class) public void testBooleanMask() { try (NDManager manager = NDManager.newBaseManager()) { @@ -277,6 +288,18 @@ public void testSet() { } } + @Test + public void testSetBoolean() { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray array = manager.arange(0f, 6f).reshape(2, 3); + array.set(array.gte(2), 10f); + NDArray expected = manager.create(new float[] {0, 1, 10, 10, 10, 10}, new Shape(2, 3)); + Assert.assertEquals(array, expected); + array.set(array.lt(-1), -1f); + Assert.assertEquals(array, expected); + } + } + @Test(expectedExceptions = IllegalArgumentException.class) public void testSequenceMask() { try (NDManager manager = NDManager.newBaseManager()) { diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index 0da2badf00d..b89a27e0a90 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -1230,6 +1230,17 @@ public NDArray cumSum(int axis) { return manager.invoke("_np_cumsum", this, params); } + /** {@inheritDoc} */ + @Override + public void intern(NDArray replaced) { + MxNDArray arr = (MxNDArray) replaced; + Pointer oldHandle = handle.getAndSet(arr.handle.getAndSet(null)); + JnaUtils.waitToRead(oldHandle); + JnaUtils.freeNdArray(oldHandle); + // dereference old ndarray + arr.close(); + } + /** {@inheritDoc} */ @Override public NDArray isInfinite() { @@ -1613,8 +1624,7 @@ public void close() { if (pointer != null) { JnaUtils.waitToRead(pointer); JnaUtils.freeNdArray(pointer); - manager.detachInternal(getUid()); - manager = null; } + manager.detachInternal(getUid()); } } diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java index 626838a441a..158a451b74f 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java @@ -971,6 +971,13 @@ public NDArray where(NDArray condition, NDArray other) { (condition.getDataType() == DataType.BOOLEAN) ? condition.toType(DataType.INT32, false) : condition; + if (array.getDataType() != other.getDataType()) { + throw new IllegalArgumentException( + "DataType mismatch, required " + + array.getDataType() + + " actual " + + other.getDataType()); + } if (!array.shapeEquals(other)) { Shape res = deriveBroadcastedShape(array.getShape(), other.getShape()); array1 = (!res.equals(array.getShape())) ? array.broadcast(res) : array; diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java index f9adddf40cc..128b2296f65 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java @@ -15,7 +15,6 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.index.NDArrayIndexer; -import ai.djl.ndarray.index.dim.NDIndexBooleans; import ai.djl.ndarray.index.full.NDIndexFullPick; import ai.djl.ndarray.index.full.NDIndexFullSlice; import ai.djl.ndarray.types.Shape; @@ -88,12 +87,6 @@ public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) { } } - /** {@inheritDoc} */ - @Override - public void set(NDArray array, NDIndexBooleans indices, NDArray value) { - throw new UnsupportedOperationException("Setting with a boolean mask is not yet supported"); - } - /** {@inheritDoc} */ @Override public void set(NDArray array, NDIndexFullSlice fullSlice, Number value) { diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 5b7b330d8df..479fcca7b4a 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -1108,6 +1108,16 @@ public PtNDArray cumSum(int axis) { return JniUtils.cumSum(this, axis); } + /** {@inheritDoc} */ + @Override + public void intern(NDArray replaced) { + PtNDArray arr = (PtNDArray) replaced; + Long oldHandle = handle.getAndSet(arr.handle.getAndSet(null)); + JniUtils.deleteNDArray(oldHandle); + // dereference old ndarray + arr.close(); + } + /** {@inheritDoc} */ @Override public PtNDArray isInfinite() { @@ -1439,9 +1449,8 @@ public void close() { Long pointer = handle.getAndSet(null); if (pointer != null) { JniUtils.deleteNDArray(pointer); - manager.detachInternal(getUid()); - manager = null; - dataRef = null; } + manager.detachInternal(getUid()); + dataRef = null; } } diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index abee383da74..69c8b39c2bb 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -1530,6 +1530,12 @@ public NDArray cumSum(int axis) { } } + /** {@inheritDoc} */ + @Override + public void intern(NDArray replaced) { + throw new UnsupportedOperationException("Not implemented"); + } + /** {@inheritDoc} */ @Override public NDArray cumSum() { diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java index 10bbe684161..d15c435f032 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java @@ -14,7 +14,6 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.index.NDArrayIndexer; -import ai.djl.ndarray.index.dim.NDIndexBooleans; import ai.djl.ndarray.index.full.NDIndexFullPick; import ai.djl.ndarray.index.full.NDIndexFullSlice; import java.util.Arrays; @@ -62,12 +61,6 @@ public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) { throw new UnsupportedOperationException("Tensor cannot be modified after creation"); } - /** {@inheritDoc} */ - @Override - public void set(NDArray array, NDIndexBooleans indices, NDArray value) { - throw new UnsupportedOperationException("Tensor cannot be modified after creation"); - } - /** {@inheritDoc} */ @Override public void set(NDArray array, NDIndexFullSlice fullSlice, Number value) {