-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added CopyFromBytes and CopyToBytes convenience methods to NDArray. Fixed typos. #4970
Changes from 2 commits
0b6f69a
ab71424
46f2dec
176102b
b3f5f98
c0a4726
db247f1
dccb6d4
8abe01f
7d1d4fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,19 +68,37 @@ class NDArray : public ObjectRef { | |
/*! | ||
* \brief Copy data content from another array. | ||
* \param other The source array to be copied from. | ||
* \note The copy may happen asynchrously if it involves a GPU context. | ||
* \note The copy may happen asynchronously if it involves a GPU context. | ||
* TVMSynchronize is necessary. | ||
*/ | ||
inline void CopyFrom(const DLTensor* other); | ||
inline void CopyFrom(const NDArray& other); | ||
/*! | ||
* \brief Copy data content from a byte buffer. | ||
* \param data The source bytes to be copied from. | ||
* \param nbytes The size of the buffer in bytes | ||
* Must be equal to the size of the NDArray. | ||
* \note The copy may happen asynchronously if it involves a GPU context. | ||
* TVMSynchronize is necessary. | ||
*/ | ||
inline void CopyFromBytes(const void* data, size_t nbytes); | ||
/*! | ||
* \brief Copy data content into another array. | ||
* \param other The source array to be copied from. | ||
* \note The copy may happen asynchrously if it involves a GPU context. | ||
* \note The copy may happen asynchronously if it involves a GPU context. | ||
* TVMSynchronize is necessary. | ||
*/ | ||
inline void CopyTo(DLTensor* other) const; | ||
inline void CopyTo(const NDArray& other) const; | ||
/*! | ||
* \brief Copy data content into a byte buffer | ||
* \param data The source bytes to be copied from. | ||
* \param nbytes The size of the data buffer. | ||
* Must be equal to the size of the NDArray. | ||
* \note The copy may happen asynchronously if it involves a GPU context. | ||
* TVMSynchronize is necessary. | ||
*/ | ||
inline void CopyToBytes(void* data, size_t nbytes) const; | ||
/*! | ||
* \brief Copy the data to another context. | ||
* \param ctx The target context. | ||
|
@@ -182,7 +200,7 @@ class NDArray : public ObjectRef { | |
|
||
/*! | ||
* \brief Save a DLTensor to stream | ||
* \param strm The outpu stream | ||
* \param strm The output stream | ||
* \param tensor The tensor to be saved. | ||
*/ | ||
inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor); | ||
|
@@ -205,7 +223,7 @@ class NDArray::ContainerBase { | |
DLTensor dl_tensor; | ||
|
||
/*! | ||
* \brief addtional context, reserved for recycling | ||
* \brief additional context, reserved for recycling | ||
* \note We can attach additional content here | ||
* which the current container depend on | ||
* (e.g. reference to original memory when creating views). | ||
|
@@ -306,6 +324,26 @@ inline void NDArray::CopyFrom(const NDArray& other) { | |
CopyFromTo(&(other.get_mutable()->dl_tensor), &(get_mutable()->dl_tensor)); | ||
} | ||
|
||
inline void NDArray::CopyFromBytes(const void* data, size_t nbytes) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To avoid duplication, let us make it a non-line member function, and make uses of implementation in the C API here https://github.com/apache/incubator-tvm/blob/master/src/runtime/ndarray.cc#L289 After that, we can safely redirect the C API calls into these C++ APIs There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, I have it calling the TVMArrayCopyFrom/ToBytes methods. Are you saying to put the implementation into ndarray.cc ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, I misunderstood your original message. Let me fix |
||
CHECK(data != nullptr); | ||
CHECK(data_ != nullptr); | ||
|
||
// Make a temporary copy of the dltensor | ||
DLTensor input = get_mutable()->dl_tensor; | ||
|
||
CHECK_EQ(nbytes, GetDataSize(input)) | ||
<< "CopyFromBytes: The size must exactly match"; | ||
|
||
TVMContext cpu_ctx; | ||
cpu_ctx.device_type = kDLCPU; | ||
cpu_ctx.device_id = 0; | ||
// Overwrite our temporary dltensor with a CPU context | ||
input.ctx = cpu_ctx; | ||
input.data = const_cast<void*>(data); | ||
|
||
CopyFrom(&input); | ||
} | ||
|
||
inline void NDArray::CopyTo(DLTensor* other) const { | ||
CHECK(data_ != nullptr); | ||
CopyFromTo(&(get_mutable()->dl_tensor), other); | ||
|
@@ -317,6 +355,26 @@ inline void NDArray::CopyTo(const NDArray& other) const { | |
CopyFromTo(&(get_mutable()->dl_tensor), &(other.get_mutable()->dl_tensor)); | ||
} | ||
|
||
inline void NDArray::CopyToBytes(void* data, size_t nbytes) const { | ||
CHECK(data != nullptr); | ||
CHECK(data_ != nullptr); | ||
|
||
// Make a temporary copy of the dltensor | ||
DLTensor output = get_mutable()->dl_tensor; | ||
|
||
CHECK_EQ(nbytes, GetDataSize(output)) | ||
<< "CopyToBytes: The size must exactly match"; | ||
|
||
TVMContext cpu_ctx; | ||
cpu_ctx.device_type = kDLCPU; | ||
cpu_ctx.device_id = 0; | ||
// Overwrite our temporary dltensor with a CPU context | ||
output.ctx = cpu_ctx; | ||
output.data = data; | ||
|
||
CopyTo(&output); | ||
} | ||
|
||
inline NDArray NDArray::CopyTo(const DLContext& ctx) const { | ||
CHECK(data_ != nullptr); | ||
const DLTensor* dptr = operator->(); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Declare as TVM_DLL as it is a member function that is not inlined
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed!