Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tensor setter feature with advanced indexing on PyTorch engine #1755

Merged
merged 4 commits into from
Jul 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ default NDArray get(NDManager manager, long... indices) {
*
* @param index select the entries of an {@code NDArray}
* @param data numbers to assign to the indexed entries
* @return The NDArray with updated values
* @return the NDArray with updated values
*/
NDArray put(NDArray index, NDArray data);

Expand Down
99 changes: 45 additions & 54 deletions api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,89 +79,80 @@ public NDArray get(NDArray array, NDIndex index) {
}

/**
* Sets the values of the array at the fullSlice with an array.
*
* @param array the array to set
* @param fullSlice the fullSlice of the index to set in the array
* @param value the value to set with
*/
public abstract void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value);

/**
* Sets the values of the array at the boolean locations with an array.
*
* @param array the array to set
* @param indices a boolean array where true indicates values to update
* @param value the value to set with when condition is true
*/
public void set(NDArray array, NDIndexBooleans indices, NDArray value) {
array.intern(NDArrays.where(indices.getIndex(), value, array));
}

/**
* Sets the values of the array at the index locations with an array.
* Sets the entries of array at the indexed locations with the parameter value. The value can be
* only Number or NDArray.
*
* @param array the array to set
* @param index the index to set at in the array
* @param value the value to set with
*/
public void set(NDArray array, NDIndex index, NDArray value) {
public void set(NDArray array, NDIndex index, Object value) {
NDIndexFullSlice fullSlice =
NDIndexFullSlice.fromIndex(index, array.getShape()).orElse(null);
if (fullSlice != null) {
if (value instanceof Number) {
set(array, fullSlice, (Number) value);
} else if (value instanceof NDArray) {
set(array, fullSlice, (NDArray) value);
} else {
throw new IllegalArgumentException(
"The type of value to assign cannot be other than NDArray and Number.");
}
return;
}

List<NDIndexElement> indices = index.getIndices();
if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) {
if (indices.size() != 1) {
throw new IllegalArgumentException(
"get() currently didn't support more that one boolean NDArray");
"set() currently doesn't support more than one boolean NDArray");
}
if (value instanceof Number) {
set(
array,
(NDIndexBooleans) indices.get(0),
array.getManager().create((Number) value));
} else if (value instanceof NDArray) {
set(array, (NDIndexBooleans) indices.get(0), (NDArray) value);
} else {
throw new IllegalArgumentException(
"The type of value to assign cannot be other than NDArray and Number.");
}
set(array, (NDIndexBooleans) indices.get(0), value);
}

NDIndexFullSlice fullSlice =
NDIndexFullSlice.fromIndex(index, array.getShape()).orElse(null);
if (fullSlice != null) {
set(array, fullSlice, value);
return;
}
throw new UnsupportedOperationException(
"set() currently supports all, fixed, and slices indices");
}

/**
* Sets the values of the array at the fullSlice with a number.
* Sets the values of the array at the fullSlice with an array.
*
* @param array the array to set
* @param fullSlice the fullSlice of the index to set in the array
* @param value the value to set with
*/
public abstract void set(NDArray array, NDIndexFullSlice fullSlice, Number value);
public abstract void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value);

/**
* Sets the values of the array at the index locations with a number.
* Sets the values of the array at the boolean locations with an array.
*
* @param array the array to set
* @param index the index to set at in the array
* @param value the value to set with
* @param indices a boolean array where true indicates values to update
* @param value the value to set with when condition is true
*/
public void set(NDArray array, NDIndex index, Number value) {
NDIndexFullSlice fullSlice =
NDIndexFullSlice.fromIndex(index, array.getShape()).orElse(null);
if (fullSlice != null) {
set(array, fullSlice, value);
return;
}
// use booleanMask for NDIndexBooleans case
List<NDIndexElement> indices = index.getIndices();
if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) {
if (indices.size() != 1) {
throw new IllegalArgumentException(
"set() currently didn't support more that one boolean NDArray");
}
set(array, (NDIndexBooleans) indices.get(0), array.getManager().create(value));
return;
}
throw new UnsupportedOperationException(
"set() currently supports all, fixed, and slices indices");
public void set(NDArray array, NDIndexBooleans indices, NDArray value) {
array.intern(NDArrays.where(indices.getIndex(), value, array));
}

/**
* Sets the values of the array at the fullSlice with a number.
*
* @param array the array to set
* @param fullSlice the fullSlice of the index to set in the array
* @param value the value to set with
*/
public abstract void set(NDArray array, NDIndexFullSlice fullSlice, Number value);

/**
* Sets a scalar value in the array at the indexed location.
*
Expand Down
5 changes: 5 additions & 0 deletions api/src/main/java/ai/djl/ndarray/index/NDIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ public NDIndex() {
* // Uses null to add an extra axis to the output array
* assertEquals(a.get(new NDIndex(":2, null, 0, :2")).getShape(), new Shape(2, 1, 2));
*
* // Gets entries of an NDArray with mixed index
* index1 = manager.create(new long[] {0, 1, 1}, new Shape(2));
* bool1 = manager.create(new boolean[] {true, false, true});
* assertEquals(a.get(new NDIndex(":{}, {}, {}, {}" 2, index1, bool1, null).getShape(), new Shape(2, 2, 1));
*
* </pre>
*
* @param indices a comma separated list of indices corresponding to either subsections,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,25 @@ public NDArray get(NDArray array, NDIndex index) {
}
}

/** {@inheritDoc} */
@Override
public void set(NDArray array, NDIndex index, Object data) {
PtNDArray ptArray =
array instanceof PtNDArray
? (PtNDArray) array
: manager.create(
array.toByteBuffer(), array.getShape(), array.getDataType());

if (data instanceof Number) {
JniUtils.indexAdvPut(ptArray, index, (PtNDArray) manager.create((Number) data));
} else if (data instanceof NDArray) {
JniUtils.indexAdvPut(ptArray, index, (PtNDArray) data);
} else {
throw new IllegalArgumentException(
"The type of value to assign cannot be other than NDArray and Number.");
}
}

/** {@inheritDoc} */
@Override
public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,69 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index) {

return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchIndexReturn(ndArray.getHandle(), torchIndexHandle));
PyTorchLibrary.LIB.torchIndexAdvGet(ndArray.getHandle(), torchIndexHandle));
}

@SuppressWarnings("OptionalGetWithoutIsPresent")
public static void indexAdvPut(PtNDArray ndArray, NDIndex index, PtNDArray data) {
if (ndArray == null) {
return;
}

// Index aggregation
List<NDIndexElement> indices = index.getIndices();
long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size());
ListIterator<NDIndexElement> it = indices.listIterator();
while (it.hasNext()) {
if (it.nextIndex() == index.getEllipsisIndex()) {
PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
}

NDIndexElement elem = it.next();
if (elem instanceof NDIndexNull) {
PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false);
} else if (elem instanceof NDIndexSlice) {
Long min = ((NDIndexSlice) elem).getMin();
Long max = ((NDIndexSlice) elem).getMax();
Long step = ((NDIndexSlice) elem).getStep();
int nullSliceBin = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0);
// nullSliceBin encodes whether the slice (min, max) is null:
// is_null == 1, ! is_null == 0;
// 0b11 == 3, 0b10 = 2, ...
PyTorchLibrary.LIB.torchIndexAppendSlice(
torchIndexHandle,
min == null ? 0 : min,
max == null ? 0 : max,
step == null ? 1 : step,
nullSliceBin);
} else if (elem instanceof NDIndexAll) {
PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, 0, 0, 1, 3);
} else if (elem instanceof NDIndexFixed) {
PyTorchLibrary.LIB.torchIndexAppendFixed(
torchIndexHandle, ((NDIndexFixed) elem).getIndex());
} else if (elem instanceof NDIndexBooleans) {
PtNDArray indexArr = (PtNDArray) ((NDIndexBooleans) elem).getIndex();
PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle());
} else if (elem instanceof NDIndexTake) {
PtNDArray indexArr = (PtNDArray) ((NDIndexTake) elem).getIndex();
if (indexArr.getDataType() != DataType.INT64) {
indexArr = indexArr.toType(DataType.INT64, true);
}
PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle());
} else if (elem instanceof NDIndexPick) {
// Backward compatible
NDIndexFullPick fullPick =
NDIndexFullPick.fromIndex(index, ndArray.getShape()).get();
pick(ndArray, ndArray.getManager().from(fullPick.getIndices()), fullPick.getAxis());
return;
}
}
if (indices.size() == index.getEllipsisIndex()) {
PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
}

PyTorchLibrary.LIB.torchIndexAdvPut(
ndArray.getHandle(), torchIndexHandle, data.getHandle());
}

public static void indexSet(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ native void torchIndexPut(
long[] maxIndices,
long[] stepIndices);

native void torchIndexAdvPut(long handle, long torchIndexHandle, long data);

native void torchSet(long handle, ByteBuffer data);

native long torchSlice(long handle, long dim, long start, long end, long step);
Expand Down Expand Up @@ -605,7 +607,7 @@ native void sgdUpdate(

native long torchIndexInit(int size);

native long torchIndexReturn(long handle, long torchIndexHandle);
native long torchIndexAdvGet(long handle, long torchIndexHandle);

native void torchIndexAppendNoneEllipsis(long torchIndexHandle, boolean isEllipsis);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexInit(JN
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexReturn(
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAdvGet(
JNIEnv* env, jobject jthis, jlong jhandle, jlong jtorch_index_handle) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
Expand Down Expand Up @@ -194,6 +194,16 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexPut(JNIE
API_END()
}

JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAdvPut(
JNIEnv* env, jobject jthis, jlong jhandle, jlong jtorch_index_handle, jlong jdata_handle) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const auto* data_ptr = reinterpret_cast<torch::Tensor*>(jdata_handle);
auto* index_ptr = reinterpret_cast<std::vector<torch::indexing::TensorIndex>*>(jtorch_index_handle);
((torch::Tensor) *tensor_ptr).index_put_(*index_ptr, *data_ptr);
API_END()
}

JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchSet(
JNIEnv* env, jobject jthis, jlong jhandle, jobject jbuffer) {
API_BEGIN()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,15 @@ public void testSetArray() {
expected = manager.create(new int[] {0, 1, 9, 10, 4, 5, 11, 12}, new Shape(2, 4));
original.set(new NDIndex(":, 2:"), manager.arange(9, 13).reshape(2, 2));
Assert.assertEquals(original, expected);

// set by index array
original = manager.arange(1, 10).reshape(3, 3);
NDArray index = manager.create(new long[] {0, 1}, new Shape(2));
value = manager.create(new int[] {666, 777, 888, 999}, new Shape(2, 2));
original.set(new NDIndex("{}, :{}", index, 2), value);
expected =
manager.create(new int[] {666, 777, 3, 888, 999, 6, 7, 8, 9}, new Shape(3, 3));
Assert.assertEquals(original, expected);
}
}

Expand Down Expand Up @@ -188,6 +197,14 @@ public void testSetNumber() {
expected = manager.create(new float[] {1, 1, 1, 3}).reshape(2, 2);
original.set(new NDIndex("..., 0"), 1);
Assert.assertEquals(original, expected);

// set by index array
original = manager.arange(1, 10).reshape(3, 3);
NDArray index = manager.create(new long[] {0, 1}, new Shape(2));
original.set(new NDIndex("{}, :{}", index, 2), 666);
expected =
manager.create(new int[] {666, 666, 3, 666, 666, 6, 7, 8, 9}, new Shape(3, 3));
Assert.assertEquals(original, expected);
}
}

Expand Down