Skip to content

Commit

Permalink
add boolean set method to DJL
Browse files Browse the repository at this point in the history
  • Loading branch information
Qing Lan committed Mar 25, 2021
1 parent d075028 commit c5d211a
Show file tree
Hide file tree
Showing 10 changed files with 101 additions and 21 deletions.
20 changes: 20 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,16 @@ default void set(NDIndex index, Function<NDArray, NDArray> function) {
set(index, function.apply(array));
}

/**
* Sets the {@code NDArray} by boolean mask.
*
* @param index the boolean {@code NDArray} that indicates what to get
* @param value the value to replace with
*/
default void set(NDArray index, Number value) {
set(new NDIndex().addBooleanIndex(index), value);
}

/**
* Sets the specified scalar in this {@code NDArray} with the given value.
*
Expand Down Expand Up @@ -3546,6 +3556,16 @@ default NDArray argSort(int axis) {
*/
NDArray cumSum(int axis);

/**
* Replace the handle of the NDArray with the other. The NDArray used for replacement will be
* killed.
*
* <p>Please use with caution, this method will make the input argument unusable.
*
* @param replaced the handle provider that will be killed
*/
void intern(NDArray replaced);

/**
* Returns the boolean {@code NDArray} with value {@code true} where this {@code NDArray}'s
* entries are infinite, or {@code false} where they are not infinite.
Expand Down
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,12 @@ default NDArray cumSum(int axis) {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default void intern(NDArray replaced) {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default NDArray isInfinite() {
Expand Down
17 changes: 15 additions & 2 deletions api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package ai.djl.ndarray.index;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.full.NDIndexFullPick;
Expand Down Expand Up @@ -90,9 +91,11 @@ public NDArray get(NDArray array, NDIndex index) {
*
* @param array the array to set
* @param indices a boolean array where true indicates values to update
* @param value the value to set with
* @param value the value to set with when condition is true
*/
public abstract void set(NDArray array, NDIndexBooleans indices, NDArray value);
public void set(NDArray array, NDIndexBooleans indices, NDArray value) {
array.intern(NDArrays.where(indices.getIndex(), value, array));
}

/**
* Sets the values of the array at the index locations with an array.
Expand Down Expand Up @@ -144,6 +147,16 @@ public void set(NDArray array, NDIndex index, Number value) {
set(array, fullSlice, value);
return;
}
// use booleanMask for NDIndexBooleans case
List<NDIndexElement> indices = index.getIndices();
if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) {
if (indices.size() != 1) {
throw new IllegalArgumentException(
"set() currently didn't support more that one boolean NDArray");
}
set(array, (NDIndexBooleans) indices.get(0), array.getManager().create(value));
return;
}
throw new UnsupportedOperationException(
"set() currently supports all, fixed, and slices indices");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,17 @@ public void testIsNaN() {
}
}

@Test
public void testIntern() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray original = manager.ones(new Shape(3, 2));
NDArray replaced = manager.zeros(new Shape(5, 5));
original.intern(replaced);
Assert.assertEquals(original, manager.zeros(new Shape(5, 5)));
Assert.assertThrows(IllegalStateException.class, replaced::toFloatArray);
}
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void testBooleanMask() {
try (NDManager manager = NDManager.newBaseManager()) {
Expand Down Expand Up @@ -277,6 +288,18 @@ public void testSet() {
}
}

@Test
public void testSetBoolean() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray array = manager.arange(0f, 6f).reshape(2, 3);
array.set(array.gte(2), 10f);
NDArray expected = manager.create(new float[] {0, 1, 10, 10, 10, 10}, new Shape(2, 3));
Assert.assertEquals(array, expected);
array.set(array.lt(-1), -1f);
Assert.assertEquals(array, expected);
}
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void testSequenceMask() {
try (NDManager manager = NDManager.newBaseManager()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,17 @@ public NDArray cumSum(int axis) {
return manager.invoke("_np_cumsum", this, params);
}

/** {@inheritDoc} */
@Override
public void intern(NDArray replaced) {
MxNDArray arr = (MxNDArray) replaced;
Pointer oldHandle = handle.getAndSet(arr.handle.getAndSet(null));
JnaUtils.waitToRead(oldHandle);
JnaUtils.freeNdArray(oldHandle);
// dereference old ndarray
arr.close();
}

/** {@inheritDoc} */
@Override
public NDArray isInfinite() {
Expand Down Expand Up @@ -1613,8 +1624,7 @@ public void close() {
if (pointer != null) {
JnaUtils.waitToRead(pointer);
JnaUtils.freeNdArray(pointer);
manager.detachInternal(getUid());
manager = null;
}
manager.detachInternal(getUid());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,13 @@ public NDArray where(NDArray condition, NDArray other) {
(condition.getDataType() == DataType.BOOLEAN)
? condition.toType(DataType.INT32, false)
: condition;
if (array.getDataType() != other.getDataType()) {
throw new IllegalArgumentException(
"DataType mismatch, required "
+ array.getDataType()
+ " actual "
+ other.getDataType());
}
if (!array.shapeEquals(other)) {
Shape res = deriveBroadcastedShape(array.getShape(), other.getShape());
array1 = (!res.equals(array.getShape())) ? array.broadcast(res) : array;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.index.full.NDIndexFullSlice;
import ai.djl.ndarray.types.Shape;
Expand Down Expand Up @@ -88,12 +87,6 @@ public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) {
}
}

/** {@inheritDoc} */
@Override
public void set(NDArray array, NDIndexBooleans indices, NDArray value) {
throw new UnsupportedOperationException("Setting with a boolean mask is not yet supported");
}

/** {@inheritDoc} */
@Override
public void set(NDArray array, NDIndexFullSlice fullSlice, Number value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,16 @@ public PtNDArray cumSum(int axis) {
return JniUtils.cumSum(this, axis);
}

/** {@inheritDoc} */
@Override
public void intern(NDArray replaced) {
PtNDArray arr = (PtNDArray) replaced;
Long oldHandle = handle.getAndSet(arr.handle.getAndSet(null));
JniUtils.deleteNDArray(oldHandle);
// dereference old ndarray
arr.close();
}

/** {@inheritDoc} */
@Override
public PtNDArray isInfinite() {
Expand Down Expand Up @@ -1439,9 +1449,8 @@ public void close() {
Long pointer = handle.getAndSet(null);
if (pointer != null) {
JniUtils.deleteNDArray(pointer);
manager.detachInternal(getUid());
manager = null;
dataRef = null;
}
manager.detachInternal(getUid());
dataRef = null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1530,6 +1530,12 @@ public NDArray cumSum(int axis) {
}
}

/** {@inheritDoc} */
@Override
public void intern(NDArray replaced) {
throw new UnsupportedOperationException("Not implemented");
}

/** {@inheritDoc} */
@Override
public NDArray cumSum() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.index.full.NDIndexFullSlice;
import java.util.Arrays;
Expand Down Expand Up @@ -62,12 +61,6 @@ public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) {
throw new UnsupportedOperationException("Tensor cannot be modified after creation");
}

/** {@inheritDoc} */
@Override
public void set(NDArray array, NDIndexBooleans indices, NDArray value) {
throw new UnsupportedOperationException("Tensor cannot be modified after creation");
}

/** {@inheritDoc} */
@Override
public void set(NDArray array, NDIndexFullSlice fullSlice, Number value) {
Expand Down

0 comments on commit c5d211a

Please sign in to comment.