[fix] add data type check to array setter #1975
@@ -212,6 +212,16 @@ public NDArray get(NDIndex index) {
    /** {@inheritDoc} */
    @Override
    public void set(Buffer data) {
        DataType arrayType = getDataType();
I don't think we should check it like this:
- We should always allow a ByteBuffer to be set for any DataType
- Multiple DataTypes map to the same Buffer type, so we have to handle them differently
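The two points above can be sketched in a few lines of plain Java. This is only an illustration: `ElemType` and `BufferCheck.isCompatible` are hypothetical names, not DJL's `DataType` or any DJL API, and the mapping from element types to buffer classes is an assumption chosen to mirror the byte types mentioned in this review.

```java
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;

// Hypothetical stand-in for a tensor element type; this is NOT DJL's DataType.
enum ElemType {
    FLOAT32(FloatBuffer.class),
    INT32(IntBuffer.class),
    // Assumption for illustration: several element types share ByteBuffer.
    INT8(ByteBuffer.class),
    UINT8(ByteBuffer.class),
    BOOLEAN(ByteBuffer.class);

    final Class<? extends Buffer> bufferClass;

    ElemType(Class<? extends Buffer> bufferClass) {
        this.bufferClass = bufferClass;
    }
}

// Hypothetical helper, not part of any library.
final class BufferCheck {
    private BufferCheck() {}

    static boolean isCompatible(Buffer buffer, ElemType target) {
        // Raw bytes can carry any element type, so a ByteBuffer is always accepted.
        if (buffer instanceof ByteBuffer) {
            return true;
        }
        // A typed buffer is accepted only if it is the buffer class of the target type.
        return target.bufferClass.isInstance(buffer);
    }

    public static void main(String[] args) {
        System.out.println(isCompatible(ByteBuffer.allocate(4), ElemType.FLOAT32));
        System.out.println(isCompatible(FloatBuffer.allocate(1), ElemType.FLOAT32));
        System.out.println(isCompatible(IntBuffer.allocate(1), ElemType.FLOAT32));
    }
}
```

Because INT8, UINT8 and BOOLEAN all share `ByteBuffer` in this sketch, the buffer class alone cannot tell which of them the caller intended, which is why a check that infers a single type from the buffer would be too strict.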
I see that in the implementations of set(Buffer data) in the different engines, such as MXNet, PyTorch, and TensorFlow, the buffer is already converted to a ByteBuffer before being fed into the engine. Do you mean we should add this check to NDArrayAdapter?
Would checking for those cases go beyond the purpose of automatically converting the input data type into the target array data type?
Codecov Report
@@             Coverage Diff              @@
##             master    #1975      +/-   ##
============================================
- Coverage     72.08%   69.81%   -2.28%
- Complexity     5126     5899     +773
============================================
  Files           473      584     +111
  Lines         21970    26159    +4189
  Branches       2351     2824     +473
============================================
+ Hits          15838    18262    +2424
- Misses         4925     6521    +1596
- Partials       1207     1376     +169
Force-pushed from 65c2548 to 8bdf9a1
Change-Id: Iac7155e469cc5c2918c4452eb95b4c9a2ef9cb43
Force-pushed from 8bdf9a1 to dcb0c8e
@@ -404,9 +404,18 @@ NDManager getAlternativeManager() {
     * @throws IllegalArgumentException if buffer size is invalid
     */
    public static void validateBufferSize(Buffer buffer, DataType dataType, int expected) {
        boolean isByteBuffer = buffer instanceof ByteBuffer;
        DataType type = DataType.fromBuffer(buffer);
        if (!isByteBuffer && type != dataType) {
The following mismatch case will escape the check: type != dataType && buffer is ByteBuffer, but dataType is not one of the byte types = {DataType.UINT8, DataType.INT8, DataType.BOOLEAN}
- if (!isByteBuffer && type != dataType) {
+ if (type != dataType) {
+     DataType[] types = {DataType.UINT8, DataType.INT8, DataType.BOOLEAN};
+     if (!isByteBuffer || Arrays.stream(types).noneMatch(x -> x == dataType)) {
ByteBuffer should always be allowed. It's the memory representation of all data types. Every NDArray has a toByteBuffer() method. The JNI only accepts ByteBuffer; if we only accept a matching Buffer, we have to convert the ByteBuffer to a Buffer and then copy that Buffer back to a ByteBuffer to pass it to JNI. That's not efficient for the hybrid engine. We are trying to achieve zero copy between the PyTorch and ONNX engines, and that relies on ByteBuffer.
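The cost difference described in this comment can be seen with the standard java.nio API alone. The sketch below is not DJL code; `ByteBufferViews` and its two methods are hypothetical names used only to contrast a typed view over existing bytes (no copy) with building a new ByteBuffer from a typed buffer (a full copy).

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;

// Hypothetical helper class for illustration only.
final class ByteBufferViews {
    private ByteBufferViews() {}

    // A typed view shares memory with the underlying ByteBuffer: no data is copied.
    static FloatBuffer viewAsFloats(ByteBuffer bytes) {
        return bytes.asFloatBuffer();
    }

    // Going the other way requires allocating new memory and copying every element.
    static ByteBuffer copyToBytes(FloatBuffer floats) {
        ByteBuffer bytes = ByteBuffer
                .allocateDirect(floats.remaining() * Float.BYTES)
                .order(ByteOrder.nativeOrder());
        // duplicate() keeps the caller's position untouched while we read.
        bytes.asFloatBuffer().put(floats.duplicate());
        return bytes;
    }

    public static void main(String[] args) {
        ByteBuffer bytes = ByteBuffer.allocate(2 * Float.BYTES).order(ByteOrder.nativeOrder());
        FloatBuffer view = viewAsFloats(bytes);
        view.put(0, 1.5f);
        // The write through the view is visible in the original bytes.
        System.out.println(bytes.getFloat(0));

        ByteBuffer copy = copyToBytes(FloatBuffer.wrap(new float[] {2.5f, 3.5f}));
        System.out.println(copy.getFloat(Float.BYTES));
    }
}
```

If a setter accepted only typed buffers, a caller that already holds bytes would pay for the copy path before anything can be handed to native code that expects a ByteBuffer.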
Force-pushed from db9893e to b3beab7
Description
Fix #1970. The test code is therein.
Details
A potential place to convert the input data type to the target array data type is in
djl/api/src/main/java/ai/djl/ndarray/BaseNDManager.java
Line 428 in aeb9135
Here it does the job of:
But it is not ideal to convert the Buffer type inside Java:
https://stackoverflow.com/questions/38745123/bytebufferasshortbuffer-cannot-be-cast-to-java-nio-floatbuffer
Cannot cast 'java.nio.HeapIntBuffer' to 'java.nio.FloatBuffer'
So here we simply throw an exception when the data types don't match.
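The cast failure quoted above is easy to reproduce with plain java.nio. The sketch below is not the PR's code; `BufferCastDemo`, `castFails`, and `toFloats` are hypothetical names. It shows that an IntBuffer cannot be cast to a FloatBuffer, and that a real conversion means copying element by element.

```java
import java.nio.Buffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;

// Hypothetical demo class for illustration only.
final class BufferCastDemo {
    private BufferCastDemo() {}

    // Returns true if casting the given buffer to FloatBuffer throws ClassCastException.
    static boolean castFails(Buffer buffer) {
        try {
            FloatBuffer ignored = (FloatBuffer) buffer;
            return false;
        } catch (ClassCastException e) {
            return true;
        }
    }

    // Converting between typed buffers requires a new buffer and a per-element copy.
    static FloatBuffer toFloats(IntBuffer ints) {
        IntBuffer src = ints.duplicate(); // leave the caller's position untouched
        FloatBuffer out = FloatBuffer.allocate(src.remaining());
        while (src.hasRemaining()) {
            out.put((float) src.get());
        }
        out.flip();
        return out;
    }

    public static void main(String[] args) {
        IntBuffer ints = IntBuffer.wrap(new int[] {1, 2});
        System.out.println(castFails(ints));
        System.out.println(toFloats(ints).get(1));
    }
}
```

Rejecting a mismatched buffer with an exception avoids doing this kind of hidden copy on the caller's behalf.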