Skip to content

Commit

Permalink
Rename validateBufferSize to validateBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Sep 3, 2022
1 parent dcb0c8e commit db9893e
Show file tree
Hide file tree
Showing 12 changed files with 14 additions and 13 deletions.
5 changes: 3 additions & 2 deletions api/src/main/java/ai/djl/ndarray/BaseNDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -403,10 +403,11 @@ NDManager getAlternativeManager() {
* @param expected the expected size
* @throws IllegalArgumentException if buffer size is invalid
*/
public static void validateBufferSize(Buffer buffer, DataType dataType, int expected) {
public static void validateBuffer(Buffer buffer, DataType dataType, int expected) {
boolean isByteBuffer = buffer instanceof ByteBuffer;
DataType type = DataType.fromBuffer(buffer);
if (!isByteBuffer && type != dataType) {
if (type != dataType && !isByteBuffer) {
// It's ok if type != datatype and buffer is ByteBuffer, since buffer will be copied into ByteBuffer
throw new IllegalArgumentException(
"The input data type: "
+ type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public NDArray create(Buffer data, Shape shape, DataType dataType) {
throw new UnsupportedOperationException("DlrNDArray only supports float32.");
}
int size = Math.toIntExact(shape.size());
BaseNDManager.validateBufferSize(data, dataType, size);
BaseNDManager.validateBuffer(data, dataType, size);
if (data instanceof ByteBuffer) {
return new DlrNDArray(this, alternativeManager, (ByteBuffer) data, shape, dataType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ public ByteBuffer toByteBuffer() {
public void set(Buffer buffer) {
int size = Math.toIntExact(size());
DataType type = getDataType();
BaseNDManager.validateBufferSize(buffer, type, size);
BaseNDManager.validateBuffer(buffer, type, size);
if (buffer.isDirect()) {
JnaUtils.syncCopyFromCPU(getHandle(), buffer, size);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public OrtNDArray create(Buffer data, Shape shape, DataType dataType) {
"Use NDManager.create(String[], Shape) to create String NDArray.");
}
int size = Math.toIntExact(shape.size());
BaseNDManager.validateBufferSize(data, dataType, size);
BaseNDManager.validateBuffer(data, dataType, size);
OnnxTensor tensor = OrtUtils.toTensor(env, data, shape, dataType);
return new OrtNDArray(this, alternativeManager, tensor);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public PpNDArray createInternal(ByteBuffer data, long handle) {
@Override
public PpNDArray create(Buffer data, Shape shape, DataType dataType) {
int size = Math.toIntExact(shape.size());
BaseNDManager.validateBufferSize(data, dataType, size);
BaseNDManager.validateBuffer(data, dataType, size);
if (data.isDirect() && data instanceof ByteBuffer) {
return JniUtils.createNdArray(this, (ByteBuffer) data, shape, dataType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ public String[] toStringArray(Charset charset) {
public void set(Buffer buffer) {
int size = Math.toIntExact(size());
DataType type = getDataType();
BaseNDManager.validateBufferSize(buffer, type, size);
BaseNDManager.validateBuffer(buffer, type, size);
// TODO how do we handle the exception happened in the middle
dataRef = null;
if (buffer.isDirect() && buffer instanceof ByteBuffer) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public PtNDArray create(Shape shape, DataType dataType) {
@Override
public PtNDArray create(Buffer data, Shape shape, DataType dataType) {
int size = Math.toIntExact(shape.size());
BaseNDManager.validateBufferSize(data, dataType, size);
BaseNDManager.validateBuffer(data, dataType, size);
if (data.isDirect() && data instanceof ByteBuffer) {
return JniUtils.createNdFromByteBuffer(
this, (ByteBuffer) data, shape, dataType, SparseFormat.DENSE, device);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ public void set(Buffer buffer) {
}
int size = Math.toIntExact(getShape().size());
DataType type = getDataType();
BaseNDManager.validateBufferSize(buffer, type, size);
BaseNDManager.validateBuffer(buffer, type, size);
if (buffer instanceof ByteBuffer) {
JavacppUtils.setByteBuffer(getHandle(), (ByteBuffer) buffer);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public TfNDArray create(Buffer data, Shape shape, DataType dataType) {
"Use NDManager.create(String[], Charset, Shape) to create String NDArray.");
}
int size = Math.toIntExact(shape.size());
BaseNDManager.validateBufferSize(data, dataType, size);
BaseNDManager.validateBuffer(data, dataType, size);
if (data.isDirect() && data instanceof ByteBuffer) {
TFE_TensorHandle handle =
JavacppUtils.createTFETensorFromByteBuffer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public ByteBuffer toByteBuffer() {
@Override
public void set(Buffer buffer) {
int size = Math.toIntExact(shape.size());
BaseNDManager.validateBufferSize(buffer, dataType, size);
BaseNDManager.validateBuffer(buffer, dataType, size);
BaseNDManager.copyBuffer(buffer, data);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public TrtNDManager newSubManager(Device dev) {
@Override
public TrtNDArray create(Buffer data, Shape shape, DataType dataType) {
int size = Math.toIntExact(shape.size());
BaseNDManager.validateBufferSize(data, dataType, size);
BaseNDManager.validateBuffer(data, dataType, size);
if (data.isDirect() && data instanceof ByteBuffer) {
return new TrtNDArray(this, alternativeManager, (ByteBuffer) data, shape, dataType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ TfLiteNDArray createInternal(Tensor tensor) {
@Override
public TfLiteNDArray create(Buffer data, Shape shape, DataType dataType) {
int size = Math.toIntExact(shape.size());
BaseNDManager.validateBufferSize(data, dataType, size);
BaseNDManager.validateBuffer(data, dataType, size);
if (data.isDirect() && data instanceof ByteBuffer) {
return new TfLiteNDArray(this, alternativeManager, (ByteBuffer) data, shape, dataType);
}
Expand Down

0 comments on commit db9893e

Please sign in to comment.