Skip to content

Commit

Permalink
Merge pull request #2253 from DARMA-tasking/2063-convert-buffer-point…
Browse files Browse the repository at this point in the history
…ers-to-std-byte

#2063: Convert buffer pointers to `std::byte*`
  • Loading branch information
lifflander authored Apr 16, 2024
2 parents 119a0b2 + 463661b commit 084c6c3
Show file tree
Hide file tree
Showing 55 changed files with 213 additions and 217 deletions.
2 changes: 1 addition & 1 deletion examples/rdma/rdma_simple_get.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ static vt::RDMA_GetType test_get_fn(
this_node, num_bytes, tag
);
return vt::RDMA_GetType{
my_data.get() + tag, num_bytes == vt::no_byte ? sizeof(double)*10 : num_bytes
reinterpret_cast<std::byte*>(my_data.get() + tag), num_bytes == vt::no_byte ? sizeof(double)*10 : num_bytes
};
}

Expand Down
4 changes: 2 additions & 2 deletions examples/rdma/rdma_simple_put.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ static void put_data_fn(HandleMsg* msg) {
}

vt::theRDMA()->putData(
handle, local_data, sizeof(double)*local_data_len,
handle, reinterpret_cast<std::byte*>(local_data), sizeof(double)*local_data_len,
(this_node-1)*local_data_len, vt::no_tag, vt::rdma::rdma_default_byte_size,
[=]{
delete [] local_data;
Expand Down Expand Up @@ -111,7 +111,7 @@ static void put_handler_fn(
for (decltype(count) i = 0; i < count; i++) {
::fmt::print(
"{}: put_handler_fn: data[{}] = {}\n",
this_node, i, static_cast<double*>(in_ptr)[i]
this_node, i, reinterpret_cast<double*>(in_ptr)[i]
);
}

Expand Down
12 changes: 6 additions & 6 deletions src/vt/collective/scatter/scatter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,19 @@ Scatter::Scatter()
: tree::Tree(tree::tree_cons_tag_t)
{ }

char* Scatter::applyScatterRecur(
NodeType node, char* ptr, std::size_t elm_size, FuncSizeType size_fn,
std::byte* Scatter::applyScatterRecur(
NodeType node, std::byte* ptr, std::size_t elm_size, FuncSizeType size_fn,
FuncDataType data_fn
) {
// pre-order k-ary tree traversal for data layout
auto children = Tree::getChildren(node);
char* cur_ptr = ptr;
auto cur_ptr = ptr;
vt_debug_print(
normal, scatter,
"Scatter::applyScatterRecur: elm_size={}, ptr={}, node={}\n",
elm_size, print_ptr(ptr), node
);
data_fn(node, reinterpret_cast<void*>(cur_ptr));
data_fn(node, cur_ptr);
cur_ptr += elm_size;
for (auto&& child : children) {
vt_debug_print(
Expand Down Expand Up @@ -100,7 +100,7 @@ void Scatter::scatterIn(ScatterMsg* msg) {
child, num_children, child_bytes_size
);
auto const child_remaining_size =
thePool()->remainingSize(reinterpret_cast<void*>(child_msg.get()));
thePool()->remainingSize(reinterpret_cast<std::byte*>(child_msg.get()));
child_msg->user_han = user_handler;
auto ptr = reinterpret_cast<char*>(child_msg.get()) + sizeof(ScatterMsg);
vt_debug_print(
Expand All @@ -118,7 +118,7 @@ void Scatter::scatterIn(ScatterMsg* msg) {
});

auto const& active_fn = auto_registry::getScatterAutoHandler(user_handler);
active_fn->dispatch(in_base_ptr, nullptr);
active_fn->dispatch(reinterpret_cast<std::byte*>(in_base_ptr), nullptr);
}

/*static*/ void Scatter::scatterHandler(ScatterMsg* msg) {
Expand Down
6 changes: 3 additions & 3 deletions src/vt/collective/scatter/scatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ namespace vt { namespace collective { namespace scatter {
*/
struct Scatter : virtual collective::tree::Tree {
using FuncSizeType = std::function<std::size_t(NodeType)>;
using FuncDataType = std::function<void(NodeType, void*)>;
using FuncDataType = std::function<void(NodeType, std::byte*)>;

/**
* \internal \brief Construct a scatter manager
Expand Down Expand Up @@ -130,8 +130,8 @@ struct Scatter : virtual collective::tree::Tree {
*
* \return incremented point after scatter is complete
*/
char* applyScatterRecur(
NodeType node, char* ptr, std::size_t elm_size, FuncSizeType size_fn,
std::byte* applyScatterRecur(
NodeType node, std::byte* ptr, std::size_t elm_size, FuncSizeType size_fn,
FuncDataType data_fn
);

Expand Down
4 changes: 2 additions & 2 deletions src/vt/collective/scatter/scatter.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ void Scatter::scatter(
auto scatter_msg =
makeMessageSz<ScatterMsg>(combined_size, combined_size, elm_size);
vtAssert(total_size == combined_size, "Sizes must be consistent");
auto ptr = reinterpret_cast<char*>(scatter_msg.get()) + sizeof(ScatterMsg);
auto ptr = reinterpret_cast<std::byte*>(scatter_msg.get()) + sizeof(ScatterMsg);
#if vt_check_enabled(memory_pool)
auto remaining_size =
thePool()->remainingSize(reinterpret_cast<void*>(scatter_msg.get()));
thePool()->remainingSize(reinterpret_cast<std::byte*>(scatter_msg.get()));
vtAssertInfo(
remaining_size >= combined_size, "Remaining size must be sufficient",
total_size, combined_size, remaining_size, elm_size
Expand Down
2 changes: 1 addition & 1 deletion src/vt/configs/types/types_rdma.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

namespace vt {

using RDMA_PtrType = void *;
using RDMA_PtrType = std::byte *;
using RDMA_ElmType = uint64_t;
using RDMA_BlockType = int64_t;
using RDMA_HandleType = int64_t;
Expand Down
4 changes: 2 additions & 2 deletions src/vt/context/runnable_context/collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ struct Collection {
void resume();

private:
std::function<void(void* in_elm)> set_; /**< Set context function */
std::function<void(std::byte* in_elm)> set_; /**< Set context function */
std::function<void()> clear_; /**< Clear context function */
void* elm_ = nullptr; /**< The element (untyped) */
std::byte* elm_ = nullptr; /**< The element (untyped) */
};

}} /* end namespace vt::ctx */
Expand Down
6 changes: 3 additions & 3 deletions src/vt/context/runnable_context/collection.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ template <typename IndexT>
/*explicit*/ Collection::Collection(
vrt::collection::Indexable<IndexT>* elm
) {
elm_ = elm;
set_ = [](void* in_elm){
auto e = static_cast<vrt::collection::Indexable<IndexT>*>(in_elm);
elm_ = reinterpret_cast<std::byte*>(elm);
set_ = [](std::byte* in_elm){
auto e = reinterpret_cast<vrt::collection::Indexable<IndexT>*>(in_elm);
auto& idx_ = e->getIndex();
auto proxy_ = e->getProxy();
vrt::collection::CollectionContextHolder<IndexT>::set(&idx_, proxy_);
Expand Down
4 changes: 2 additions & 2 deletions src/vt/group/msg/group_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,13 @@ struct GroupListMsg : GroupInfoMsg<GroupMsg<::vt::PayloadMessage>> {
)
{
setPut(
in_range->getBound(),
reinterpret_cast<const std::byte*>(in_range->getBound()),
in_range->getSize() * sizeof(RangeType::BoundType)
);
}

RangeType getRange() {
auto const& ptr = static_cast<RangeType::BoundType*>(getPut());
auto const& ptr = reinterpret_cast<RangeType::BoundType*>(getPut());
return region::ShallowList(ptr, getCount());
}
};
Expand Down
42 changes: 21 additions & 21 deletions src/vt/messaging/active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,15 @@ trace::TraceEventIDType ActiveMessenger::makeTraceCreationSend(
}

MsgSizeType ActiveMessenger::packMsg(
MessageType* msg, MsgSizeType size, void* ptr, MsgSizeType ptr_bytes
MessageType* msg, MsgSizeType size, std::byte* ptr, MsgSizeType ptr_bytes
) {
vt_debug_print(
verbose, active,
"packMsg: msg_size={}, put_size={}, ptr={}\n",
size, ptr_bytes, print_ptr(ptr)
);

auto const can_grow = thePool()->tryGrowAllocation(msg, ptr_bytes);
auto const can_grow = thePool()->tryGrowAllocation(reinterpret_cast<std::byte*>(msg), ptr_bytes);
// Typically this should be checked by the caller in advance
vtAssert(can_grow, "not enough space to pack message" );

Expand Down Expand Up @@ -236,7 +236,7 @@ EventType ActiveMessenger::sendMsgBytesWithPut(
auto const& put_ptr = envelopeGetPutPtr(msg->env);
auto const& put_size = envelopeGetPutSize(msg->env);
bool const& memory_pool_active = thePool()->active_env();
auto const& rem_size = thePool()->remainingSize(msg);
auto const& rem_size = thePool()->remainingSize(reinterpret_cast<std::byte*>(msg));
/*
* Directly pack if the pool is active (which means it may have
* overallocated and the remaining size of the (envelope) buffer is
Expand All @@ -255,7 +255,7 @@ EventType ActiveMessenger::sendMsgBytesWithPut(
verbose, active,
"sendMsgBytesWithPut: (put) put_ptr={}, size:[msg={},put={},rem={}],"
"dest={}, max_pack_size={}, direct_buf_pack={}\n",
put_ptr, base.size(), put_size, rem_size, dest, max_pack_direct_size,
print_ptr(put_ptr), base.size(), put_size, rem_size, dest, max_pack_direct_size,
print_bool(direct_buf_pack)
);
}
Expand Down Expand Up @@ -300,7 +300,7 @@ struct MultiMsg : vt::Message {
}

void ActiveMessenger::handleChunkedMultiMsg(MultiMsg* msg) {
auto buf = static_cast<char*>(thePool()->alloc(msg->getSize()));
auto buf = thePool()->alloc(msg->getSize());

auto const size = msg->getSize();
auto const info = msg->getInfo();
Expand Down Expand Up @@ -376,7 +376,7 @@ EventType ActiveMessenger::sendMsgMPI(
auto tag = allocateNewTag();

// Send the actual data in multiple chunks
PtrLenPairType tup = std::make_tuple(untyped_msg, msg_size);
PtrLenPairType tup = std::make_tuple(reinterpret_cast<std::byte*>(untyped_msg), msg_size);
SendInfo info = sendData(tup, dest, tag);

auto event_id = info.getEvent();
Expand Down Expand Up @@ -539,7 +539,7 @@ SendInfo ActiveMessenger::sendData(
vt_debug_print(
terse, active,
"sendData: ptr={}, num_bytes={} dest={}, tag={}, send_tag={}\n",
data_ptr, num_bytes, dest, tag, send_tag
print_ptr(data_ptr), num_bytes, dest, tag, send_tag
);

vtAbortIf(
Expand All @@ -564,7 +564,7 @@ SendInfo ActiveMessenger::sendData(
std::tuple<EventType, int> ActiveMessenger::sendDataMPI(
PtrLenPairType const& payload, NodeType const& dest, TagType const& tag
) {
auto ptr = static_cast<char*>(std::get<0>(payload));
auto ptr = reinterpret_cast<char*>(std::get<0>(payload));
auto remainder = std::get<1>(payload);
int num_sends = 0;
std::vector<EventType> events;
Expand Down Expand Up @@ -670,7 +670,7 @@ bool ActiveMessenger::tryProcessDataMsgRecv() {
}

bool ActiveMessenger::recvDataMsgBuffer(
int nchunks, void* const user_buf, TagType const& tag,
int nchunks, std::byte* const user_buf, TagType const& tag,
NodeType const& node, bool const& enqueue, ActionType dealloc,
ContinuationDeleterType next, bool is_user_buf
) {
Expand All @@ -681,7 +681,7 @@ bool ActiveMessenger::recvDataMsgBuffer(
}

bool ActiveMessenger::recvDataMsgBuffer(
int nchunks, void* const user_buf, PriorityType priority, TagType const& tag,
int nchunks, std::byte* const user_buf, PriorityType priority, TagType const& tag,
NodeType const& node, bool const& enqueue, ActionType dealloc_user_buf,
ContinuationDeleterType next, bool is_user_buf
) {
Expand All @@ -702,9 +702,9 @@ bool ActiveMessenger::recvDataMsgBuffer(
if (flag == 1) {
MPI_Get_count(&stat, MPI_BYTE, &num_probe_bytes);

char* buf = user_buf == nullptr ?
static_cast<char*>(thePool()->alloc(num_probe_bytes)) :
static_cast<char*>(user_buf);
std::byte* buf = user_buf == nullptr ?
thePool()->alloc(num_probe_bytes) :
user_buf;

NodeType const sender = stat.MPI_SOURCE;

Expand Down Expand Up @@ -743,15 +743,15 @@ void ActiveMessenger::recvDataDirect(
int nchunks, TagType const tag, NodeType const from, MsgSizeType len,
ContinuationDeleterType next
) {
char* buf = static_cast<char*>(thePool()->alloc(len));
std::byte* buf = thePool()->alloc(len);

recvDataDirect(
nchunks, buf, tag, from, len, default_priority, nullptr, next, false
);
}

void ActiveMessenger::recvDataDirect(
int nchunks, void* const buf, TagType const tag, NodeType const from,
int nchunks, std::byte* const buf, TagType const tag, NodeType const from,
MsgSizeType len, PriorityType prio, ActionType dealloc,
ContinuationDeleterType next, bool is_user_buf
) {
Expand All @@ -760,7 +760,7 @@ void ActiveMessenger::recvDataDirect(
std::vector<MPI_Request> reqs;
reqs.resize(nchunks);

char* cbuf = static_cast<char*>(buf);
std::byte* cbuf = buf;
MsgSizeType remainder = len;
auto const max_per_send = theConfig()->vt_max_mpi_send_size;
for (int i = 0; i < nchunks; i++) {
Expand Down Expand Up @@ -838,7 +838,7 @@ void ActiveMessenger::finishPendingDataMsgAsyncRecv(InProgressDataIRecv* irecv)
vt_debug_print(
normal, active,
"finishPendingDataMsgAsyncRecv: continuation user_buf={}, buf={}\n",
user_buf, buf
print_ptr(user_buf), print_ptr(buf)
);

if (user_buf == nullptr) {
Expand Down Expand Up @@ -998,7 +998,7 @@ bool ActiveMessenger::tryProcessIncomingActiveMsg() {
if (flag == 1) {
MPI_Get_count(&stat, MPI_BYTE, &num_probe_bytes);

char* buf = static_cast<char*>(thePool()->alloc(num_probe_bytes));
std::byte* buf = thePool()->alloc(num_probe_bytes);

NodeType const sender = stat.MPI_SOURCE;

Expand Down Expand Up @@ -1051,7 +1051,7 @@ bool ActiveMessenger::tryProcessIncomingActiveMsg() {
}

void ActiveMessenger::finishPendingActiveMsgAsyncRecv(InProgressIRecv* irecv) {
char* buf = irecv->buf;
std::byte* buf = irecv->buf;
auto num_probe_bytes = irecv->probe_bytes;
auto sender = irecv->sender;

Expand Down Expand Up @@ -1089,14 +1089,14 @@ void ActiveMessenger::finishPendingActiveMsgAsyncRecv(InProgressIRecv* irecv) {
if (put_tag == PutPackedTag) {
auto const put_size = envelopeGetPutSize(msg->env);
auto const msg_size = num_probe_bytes - put_size;
char* put_ptr = buf + msg_size;
std::byte* put_ptr = buf + msg_size;

if (!is_term || vt_check_enabled(print_term_msgs)) {
vt_debug_print(
verbose, active,
"finishPendingActiveMsgAsyncRecv: packed put: ptr={}, msg_size={}, "
"put_size={}\n",
put_ptr, msg_size, put_size
print_ptr(put_ptr), msg_size, put_size
);
}

Expand Down
Loading

0 comments on commit 084c6c3

Please sign in to comment.