Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tensor setter feature with advanced indexing on PyTorch engine #1755

Merged
merged 4 commits into from
Jul 6, 2022

Conversation

KexinFeng
Copy link
Contributor

@KexinFeng KexinFeng commented Jun 28, 2022

Description

Full support of setting with pytorch indexing

Add tensor setter feature with advanced indexing on PyTorch engine. This is the counterpart of the getter feature in one previous PR. Advanced indexing that supports all indexing features on PyTorch #1719

This PR succeeds Add put feature with linear indexing on PyTorch engine #1749.

Demo example

See the following code for demo

// get from integer array (higher rank included) or float array
original = manager.arange(1, 7f).reshape(-1, 2);
NDArray index = manager.create(new long[] {0, 0, 1, 2}, new Shape(2, 2));
NDArray indexFloat = manager.create(new float[] {0, 0, 1, 2}, new Shape(2, 2));
NDArray actual = original.get(index);
NDArray actual2 = original.get(indexFloat);
expected = manager.create(new float[] {1, 2, 1, 2, 3, 4, 5, 6}, new Shape(2, 2, 2));
Assert.assertEquals(actual, expected);
Assert.assertEquals(actual2, expected);
// indexing with boolean, slice, and integer array (higher rank included) or float array
original = manager.arange(3 * 3 * 3 * 3).reshape(3, 3, 3, 3);
NDArray bool1 = manager.create(new boolean[] {true, false, true});
NDArray index1 = manager.create(new long[] {2, 2}, new Shape(1, 2));
NDArray index2 = manager.create(new float[] {0, 1}, new Shape(1, 2));
actual = original.get(":{}, {}, {}, {}", 2, index1, bool1, index2);
expected = manager.create(new int[] {18, 25, 45, 52}, new Shape(2, 1, 2));
Assert.assertEquals(actual, expected);
// indexing with null, slice and integer array (higher rank included) or float array
original = manager.arange(3 * 3 * 3).reshape(3, 3, 3);
index1 = manager.create(new float[] {0, 1}, new Shape(2));
index2 = manager.create(new long[] {0, 0, 2, 1}, new Shape(2, 2));
actual = original.get(":{}, {}, {}, {}", 2, index1, index2, null);
expected = manager.create(new int[] {0, 3, 2, 4, 9, 12, 11, 13}, new Shape(2, 2, 2, 1));
Assert.assertEquals(actual, expected);

public void testSetArray() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
NDArray original = manager.create(new float[] {1, 2, 3, 4}, new Shape(2, 2));
NDArray expected = manager.create(new float[] {9, 10, 3, 4}, new Shape(2, 2));
NDArray value = manager.create(new float[] {9, 10});
original.set(new NDIndex(0), value);
Assert.assertEquals(original, expected);
original = manager.arange(0, 8).reshape(2, 4);
expected = manager.create(new int[] {0, 1, 9, 10, 4, 5, 11, 12}, new Shape(2, 4));
original.set(new NDIndex(":, 2:"), manager.arange(9, 13).reshape(2, 2));
Assert.assertEquals(original, expected);
// set by index array
original = manager.arange(1, 10).reshape(3, 3);
NDArray index = manager.create(new float[] {0, 1}, new Shape(2));
value = manager.create(new int[] {666, 777, 888, 999}, new Shape(2, 2));
original.set(new NDIndex("{}, :{}", index, 2), value);
expected =
manager.create(new int[] {666, 777, 3, 888, 999, 6, 7, 8, 9}, new Shape(3, 3));
Assert.assertEquals(original, expected);
}
}

// set by boolean index array
original = manager.arange(1, 10).reshape(3, 3);
NDArray index = manager.create(new boolean[] {true, false, true}, new Shape(3));
original.set(new NDIndex("{}", index), 666);
expected =
manager.create(
new int[] {666, 666, 666, 4, 5, 6, 666, 666, 666}, new Shape(3, 3));
Assert.assertEquals(original, expected);
// set by integer index array
original = manager.arange(1, 10).reshape(3, 3);
index = manager.create(new long[] {0, 1}, new Shape(2));
original.set(new NDIndex("{}, :{}", index, 2), 666);
expected =
manager.create(new int[] {666, 666, 3, 666, 666, 6, 7, 8, 9}, new Shape(3, 3));
Assert.assertEquals(original, expected);
original = manager.arange(1, 10).reshape(3, 3);
original.set(index, 666);
expected =
manager.create(
new int[] {666, 666, 666, 666, 666, 666, 7, 8, 9}, new Shape(3, 3));
Assert.assertEquals(original, expected);

The relavant preceding PRs

Advanced indexing that supports all indexing features on PyTorch #1719
Add put feature with linear indexing on PyTorch engine #1749

@KexinFeng KexinFeng changed the title Add tensor setter feature with advanced indexing on PyTorch engine [WIP] Add tensor setter feature with advanced indexing on PyTorch engine Jun 28, 2022
@KexinFeng KexinFeng marked this pull request as ready for review June 28, 2022 18:58
@codecov-commenter
Copy link

codecov-commenter commented Jun 29, 2022

Codecov Report

Merging #1755 (4686f4b) into master (bb5073f) will decrease coverage by 1.52%.
The diff coverage is 70.68%.

@@             Coverage Diff              @@
##             master    #1755      +/-   ##
============================================
- Coverage     72.08%   70.55%   -1.53%     
- Complexity     5126     5583     +457     
============================================
  Files           473      527      +54     
  Lines         21970    24789    +2819     
  Branches       2351     2699     +348     
============================================
+ Hits          15838    17491    +1653     
- Misses         4925     5962    +1037     
- Partials       1207     1336     +129     
Impacted Files Coverage Δ
api/src/main/java/ai/djl/modality/cv/Image.java 69.23% <ø> (-4.11%) ⬇️
...rc/main/java/ai/djl/modality/cv/MultiBoxPrior.java 76.00% <ø> (ø)
...rc/main/java/ai/djl/modality/cv/output/Joints.java 71.42% <ø> (ø)
.../main/java/ai/djl/modality/cv/output/Landmark.java 100.00% <ø> (ø)
...main/java/ai/djl/modality/cv/output/Rectangle.java 72.41% <0.00%> (ø)
...i/djl/modality/cv/translator/BigGANTranslator.java 21.42% <ø> (-5.24%) ⬇️
...odality/cv/translator/BigGANTranslatorFactory.java 33.33% <0.00%> (+8.33%) ⬆️
...nslator/InstanceSegmentationTranslatorFactory.java 14.28% <0.00%> (-3.90%) ⬇️
.../cv/translator/StyleTransferTranslatorFactory.java 40.00% <ø> (ø)
.../ai/djl/modality/cv/translator/YoloTranslator.java 8.33% <0.00%> (-0.50%) ⬇️
... and 411 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 22cb398...4686f4b. Read the comment docs.

@KexinFeng KexinFeng changed the title [WIP] Add tensor setter feature with advanced indexing on PyTorch engine Add tensor setter feature with advanced indexing on PyTorch engine Jun 29, 2022
Fix NDIndexTest

javadoc fix

Remove shape check

Add `put` Pt engine support

Add NDIndeTest fix NDArray.get(index)

testIndexationUsesSpecificManager add manager checking into PtNDArrayIndexer get(NDArray, NDIndex)

Add :engines:pytorch:pytorch-jni

change at::indexing to torch::indexing; testRuntimeOnly project(":engines:pytorch:pytorch-jni

Restore the get(NDArray, NDIndex)

Torch index type check: long, byte or boolean; restore testPick behaviour; The previous commit: PtNDManager.from() bug fixed.

bug fixed

bug fixed

Update api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNone.java

Co-authored-by: Frank Liu <[email protected]>

bug fixed

feed std::vector<> to tensor.index(ArrayRef<>)

code cleaning 2

code cleaning

mixed index getter on pytorch draft
@KexinFeng KexinFeng merged commit a7ee401 into deepjavalibrary:master Jul 6, 2022
@KexinFeng KexinFeng deleted the pt_torch_setter2 branch August 25, 2022 00:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants