Skip to content

Commit

Permalink
move validateDataType into BaseNDManager
Browse files Browse the repository at this point in the history
Change-Id: Iac7155e469cc5c2918c4452eb95b4c9a2ef9cb43
  • Loading branch information
frankfliu committed Sep 1, 2022
1 parent 1d3fdea commit 65c2548
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 87 deletions.
13 changes: 11 additions & 2 deletions api/src/main/java/ai/djl/ndarray/BaseNDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,18 @@ NDManager getAlternativeManager() {
* @throws IllegalArgumentException if buffer size is invalid
*/
public static void validateBufferSize(Buffer buffer, DataType dataType, int expected) {
boolean isByteBuffer = buffer instanceof ByteBuffer;
DataType type = DataType.fromBuffer(buffer);
if (!isByteBuffer && type != dataType) {
throw new IllegalArgumentException(
"The input data type: "
+ type
+ " does not match target array data type: "
+ target);
}

int remaining = buffer.remaining();
int expectedSize =
buffer instanceof ByteBuffer ? dataType.getNumOfBytes() * expected : expected;
int expectedSize = isByteBuffer ? dataType.getNumOfBytes() * expected : expected;
if (remaining < expectedSize) {
throw new IllegalArgumentException(
"The NDArray size is: " + expected + ", but buffer size is: " + remaining);
Expand Down
4 changes: 2 additions & 2 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,9 @@ default Number[] toArray() {
/**
* Sets this {@code NDArray} value from {@link Buffer}.
*
* @param data the input buffered data
* @param buffer the input buffered data
*/
void set(Buffer data);
void set(Buffer buffer);

/**
* Sets this {@code NDArray} value from an array of floats.
Expand Down
4 changes: 2 additions & 2 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ public NDArray get(NDIndex index) {

/** {@inheritDoc} */
@Override
public void set(Buffer data) {
NDArray array = manager.create(data, getShape(), getDataType());
public void set(Buffer buffer) {
NDArray array = manager.create(buffer, getShape(), getDataType());
intern(array);
array.detach();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,31 +302,18 @@ public ByteBuffer toByteBuffer() {

/** {@inheritDoc} */
@Override
public void set(Buffer data) {
DataType arrayType = getDataType();
DataType inputType = DataType.fromBuffer(data);
if (arrayType != inputType) {
DataType[] types = {DataType.UINT8, DataType.INT8, DataType.BOOLEAN};
if (Arrays.stream(types).noneMatch(x -> x == arrayType)
|| Arrays.stream(types).noneMatch(x -> x == inputType)) {
throw new IllegalArgumentException(
"The input data type: "
+ inputType
+ " does not match target array data type: "
+ arrayType);
}
}

public void set(Buffer buffer) {
int size = Math.toIntExact(size());
BaseNDManager.validateBufferSize(data, getDataType(), size);
if (data.isDirect()) {
JnaUtils.syncCopyFromCPU(getHandle(), data, size);
DataType type = getDataType();
BaseNDManager.validateBufferSize(buffer, type, size);
if (buffer.isDirect()) {
JnaUtils.syncCopyFromCPU(getHandle(), buffer, size);
return;
}

ByteBuffer buf = manager.allocateDirect(size * getDataType().getNumOfBytes());
BaseNDManager.copyBuffer(data, buf);
JnaUtils.syncCopyFromCPU(getHandle(), buf, size);
ByteBuffer bb = manager.allocateDirect(size * type.getNumOfBytes());
BaseNDManager.copyBuffer(buffer, bb);
JnaUtils.syncCopyFromCPU(getHandle(), bb, size);
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,37 +225,24 @@ public String[] toStringArray(Charset charset) {

/** {@inheritDoc} */
@Override
public void set(Buffer data) {
DataType arrayType = getDataType();
DataType inputType = DataType.fromBuffer(data);
if (arrayType != inputType) {
DataType[] types = {DataType.UINT8, DataType.INT8, DataType.BOOLEAN};
if (Arrays.stream(types).noneMatch(x -> x == arrayType)
|| Arrays.stream(types).noneMatch(x -> x == inputType)) {
throw new IllegalArgumentException(
"The input data type: "
+ inputType
+ " does not match target array data type: "
+ arrayType);
}
}

public void set(Buffer buffer) {
int size = Math.toIntExact(size());
BaseNDManager.validateBufferSize(data, getDataType(), size);
DataType type = getDataType();
BaseNDManager.validateBufferSize(buffer, type, size);
// TODO how do we handle the exception happened in the middle
dataRef = null;
if (data.isDirect() && data instanceof ByteBuffer) {
if (buffer.isDirect() && buffer instanceof ByteBuffer) {
// If NDArray is on the GPU, it is native code responsibility to control the data life
// cycle
if (!getDevice().isGpu()) {
dataRef = (ByteBuffer) data;
dataRef = (ByteBuffer) buffer;
}
JniUtils.set(this, (ByteBuffer) data);
JniUtils.set(this, (ByteBuffer) buffer);
return;
}
// int8, uint8, boolean use ByteBuffer, so need to explicitly input DataType
ByteBuffer buf = manager.allocateDirect(size * inputType.getNumOfBytes());
BaseNDManager.copyBuffer(data, buf);
ByteBuffer buf = manager.allocateDirect(size * type.getNumOfBytes());
BaseNDManager.copyBuffer(buffer, buf);

// If NDArray is on the GPU, it is native code responsibility to control the data life cycle
if (!getDevice().isGpu()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,34 +186,21 @@ public ByteBuffer toByteBuffer() {

/** {@inheritDoc} */
@Override
public void set(Buffer data) {
DataType arrayType = getDataType();
DataType inputType = DataType.fromBuffer(data);
if (arrayType != inputType) {
DataType[] types = {DataType.UINT8, DataType.INT8, DataType.BOOLEAN};
if (Arrays.stream(types).noneMatch(x -> x == arrayType)
|| Arrays.stream(types).noneMatch(x -> x == inputType)) {
throw new IllegalArgumentException(
"The input data type: "
+ inputType
+ " does not match target array data type: "
+ arrayType);
}
}

public void set(Buffer buffer) {
if (getDevice().isGpu()) {
// TODO: Implement set for GPU
throw new UnsupportedOperationException("GPU Tensor cannot be modified after creation");
}
int size = Math.toIntExact(getShape().size());
BaseNDManager.validateBufferSize(data, arrayType, size);
if (data instanceof ByteBuffer) {
JavacppUtils.setByteBuffer(getHandle(), (ByteBuffer) data);
DataType type = getDataType();
BaseNDManager.validateBufferSize(buffer, type, size);
if (buffer instanceof ByteBuffer) {
JavacppUtils.setByteBuffer(getHandle(), (ByteBuffer) buffer);
return;
}
ByteBuffer buf = getManager().allocateDirect(size * arrayType.getNumOfBytes());
BaseNDManager.copyBuffer(data, buf);
JavacppUtils.setByteBuffer(getHandle(), buf);
ByteBuffer bb = getManager().allocateDirect(size * type.getNumOfBytes());
BaseNDManager.copyBuffer(bb, bb);
JavacppUtils.setByteBuffer(getHandle(), bb);
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.UUID;

/** {@code TrtNDArray} is the TensorRT implementation of {@link NDArray}. */
Expand Down Expand Up @@ -63,23 +62,9 @@ public ByteBuffer toByteBuffer() {

/** {@inheritDoc} */
@Override
public void set(Buffer data) {
DataType arrayType = getDataType();
DataType inputType = DataType.fromBuffer(data);
if (arrayType != inputType) {
DataType[] types = {DataType.UINT8, DataType.INT8, DataType.BOOLEAN};
if (Arrays.stream(types).noneMatch(x -> x == arrayType)
|| Arrays.stream(types).noneMatch(x -> x == inputType)) {
throw new IllegalArgumentException(
"The input data type: "
+ inputType
+ " does not match target array data type: "
+ arrayType);
}
}

public void set(Buffer buffer) {
int size = Math.toIntExact(shape.size());
BaseNDManager.validateBufferSize(data, dataType, size);
BaseNDManager.copyBuffer(data, this.data);
BaseNDManager.validateBufferSize(buffer, dataType, size);
BaseNDManager.copyBuffer(buffer, data);
}
}

0 comments on commit 65c2548

Please sign in to comment.