diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 761e3e0e2651..127f8cf416ad 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -403,10 +403,11 @@ NDManager getAlternativeManager() { * @param expected the expected size * @throws IllegalArgumentException if buffer size is invalid */ - public static void validateBufferSize(Buffer buffer, DataType dataType, int expected) { + public static void validateBuffer(Buffer buffer, DataType dataType, int expected) { boolean isByteBuffer = buffer instanceof ByteBuffer; DataType type = DataType.fromBuffer(buffer); - if (!isByteBuffer && type != dataType) { + if (type != dataType && !isByteBuffer) { + // It's ok if type != datatype and buffer is ByteBuffer, since buffer will be copied into ByteBuffer throw new IllegalArgumentException( "The input data type: " + type diff --git a/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java b/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java index c77aadb10b84..aa834b84fb9c 100644 --- a/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java +++ b/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java @@ -81,7 +81,7 @@ public NDArray create(Buffer data, Shape shape, DataType dataType) { throw new UnsupportedOperationException("DlrNDArray only supports float32."); } int size = Math.toIntExact(shape.size()); - BaseNDManager.validateBufferSize(data, dataType, size); + BaseNDManager.validateBuffer(data, dataType, size); if (data instanceof ByteBuffer) { return new DlrNDArray(this, alternativeManager, (ByteBuffer) data, shape, dataType); } diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index 5c7b0621eaad..7453557f8529 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -305,7 +305,7 @@ public ByteBuffer toByteBuffer() { public void set(Buffer buffer) { int size = Math.toIntExact(size()); DataType type = getDataType(); - BaseNDManager.validateBufferSize(buffer, type, size); + BaseNDManager.validateBuffer(buffer, type, size); if (buffer.isDirect()) { JnaUtils.syncCopyFromCPU(getHandle(), buffer, size); return; diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java index e884d125ab68..0136e3e77b5b 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java @@ -73,7 +73,7 @@ public OrtNDArray create(Buffer data, Shape shape, DataType dataType) { "Use NDManager.create(String[], Shape) to create String NDArray."); } int size = Math.toIntExact(shape.size()); - BaseNDManager.validateBufferSize(data, dataType, size); + BaseNDManager.validateBuffer(data, dataType, size); OnnxTensor tensor = OrtUtils.toTensor(env, data, shape, dataType); return new OrtNDArray(this, alternativeManager, tensor); } diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java index a4c38ebdb789..8bfe0e850c1e 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java @@ -91,7 +91,7 @@ public PpNDArray createInternal(ByteBuffer data, long handle) { @Override public PpNDArray create(Buffer data, Shape shape, DataType dataType) { int size = Math.toIntExact(shape.size()); - BaseNDManager.validateBufferSize(data, dataType, size); + BaseNDManager.validateBuffer(data, dataType, size); if (data.isDirect() && data instanceof ByteBuffer) { return JniUtils.createNdArray(this, (ByteBuffer) data, shape, dataType); } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 1b613bf90b08..2078f2a8584d 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -228,7 +228,7 @@ public String[] toStringArray(Charset charset) { public void set(Buffer buffer) { int size = Math.toIntExact(size()); DataType type = getDataType(); - BaseNDManager.validateBufferSize(buffer, type, size); + BaseNDManager.validateBuffer(buffer, type, size); // TODO how do we handle the exception happened in the middle dataRef = null; if (buffer.isDirect() && buffer instanceof ByteBuffer) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index 307edb414196..979a0493e316 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -66,7 +66,7 @@ public PtNDArray create(Shape shape, DataType dataType) { @Override public PtNDArray create(Buffer data, Shape shape, DataType dataType) { int size = Math.toIntExact(shape.size()); - BaseNDManager.validateBufferSize(data, dataType, size); + BaseNDManager.validateBuffer(data, dataType, size); if (data.isDirect() && data instanceof ByteBuffer) { return JniUtils.createNdFromByteBuffer( this, (ByteBuffer) data, shape, dataType, SparseFormat.DENSE, device); diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index d21a4a657b24..2438643450ed 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -193,7 +193,7 @@ public void set(Buffer buffer) { } int size = Math.toIntExact(getShape().size()); DataType type = getDataType(); - BaseNDManager.validateBufferSize(buffer, type, size); + BaseNDManager.validateBuffer(buffer, type, size); if (buffer instanceof ByteBuffer) { JavacppUtils.setByteBuffer(getHandle(), (ByteBuffer) buffer); return; diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java index 407d67baacfc..622b3c44c272 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java @@ -85,7 +85,7 @@ public TfNDArray create(Buffer data, Shape shape, DataType dataType) { "Use NDManager.create(String[], Charset, Shape) to create String NDArray."); } int size = Math.toIntExact(shape.size()); - BaseNDManager.validateBufferSize(data, dataType, size); + BaseNDManager.validateBuffer(data, dataType, size); if (data.isDirect() && data instanceof ByteBuffer) { TFE_TensorHandle handle = JavacppUtils.createTFETensorFromByteBuffer( diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDArray.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDArray.java index 188174a230ea..3461e35ef387 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDArray.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDArray.java @@ -64,7 +64,7 @@ public ByteBuffer toByteBuffer() { @Override public void set(Buffer buffer) { int size = Math.toIntExact(shape.size()); - BaseNDManager.validateBufferSize(buffer, dataType, size); + BaseNDManager.validateBuffer(buffer, dataType, size); BaseNDManager.copyBuffer(buffer, data); } } diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDManager.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDManager.java index 6fa8c1967c1d..2307e32d7f15 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDManager.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDManager.java @@ -72,7 +72,7 @@ public TrtNDManager newSubManager(Device dev) { @Override public TrtNDArray create(Buffer data, Shape shape, DataType dataType) { int size = Math.toIntExact(shape.size()); - BaseNDManager.validateBufferSize(data, dataType, size); + BaseNDManager.validateBuffer(data, dataType, size); if (data.isDirect() && data instanceof ByteBuffer) { return new TrtNDArray(this, alternativeManager, (ByteBuffer) data, shape, dataType); } diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java index 70a321c4afb7..d0b9f8d20e44 100644 --- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java @@ -63,7 +63,7 @@ TfLiteNDArray createInternal(Tensor tensor) { @Override public TfLiteNDArray create(Buffer data, Shape shape, DataType dataType) { int size = Math.toIntExact(shape.size()); - BaseNDManager.validateBufferSize(data, dataType, size); + BaseNDManager.validateBuffer(data, dataType, size); if (data.isDirect() && data instanceof ByteBuffer) { return new TfLiteNDArray(this, alternativeManager, (ByteBuffer) data, shape, dataType); }