Skip to content

Commit

Permalink
#2063: Clean up casts
Browse files Browse the repository at this point in the history
  • Loading branch information
thearusable committed Mar 27, 2024
1 parent ec83490 commit e47b78d
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 28 deletions.
8 changes: 4 additions & 4 deletions src/vt/collective/scatter/scatter.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ void Scatter::scatter(
auto scatter_msg =
makeMessageSz<ScatterMsg>(combined_size, combined_size, elm_size);
vtAssert(total_size == combined_size, "Sizes must be consistent");
auto ptr = reinterpret_cast<char*>(scatter_msg.get()) + sizeof(ScatterMsg);
auto ptr = reinterpret_cast<std::byte*>(scatter_msg.get()) + sizeof(ScatterMsg);
#if vt_check_enabled(memory_pool)
auto remaining_size =
thePool()->remainingSize(reinterpret_cast<std::byte*>(scatter_msg.get()));
Expand All @@ -84,11 +84,11 @@ void Scatter::scatter(
print_ptr(ptr), remaining_size
);
auto const& root_node = 0;
auto nptr = applyScatterRecur(root_node, reinterpret_cast<std::byte*>(ptr), elm_size, size_fn, data_fn);
auto nptr = applyScatterRecur(root_node, ptr, elm_size, size_fn, data_fn);
vt_debug_print(
verbose, scatter, "Scatter::scatter: incremented size={}\n", nptr - reinterpret_cast<std::byte*>(ptr)
verbose, scatter, "Scatter::scatter: incremented size={}\n", nptr - ptr
);
vtAssert(nptr == reinterpret_cast<std::byte*>(ptr + combined_size), "nptr must match size");
vtAssert(nptr == ptr + combined_size, "nptr must match size");
auto const& handler = auto_registry::makeScatterHandler<MessageT, f>();
auto const& this_node = theContext()->getNode();
scatter_msg->user_han = handler;
Expand Down
20 changes: 10 additions & 10 deletions src/vt/messaging/active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ struct MultiMsg : vt::Message {
}

void ActiveMessenger::handleChunkedMultiMsg(MultiMsg* msg) {
auto buf = reinterpret_cast<std::byte*>(thePool()->alloc(msg->getSize()));
auto buf = thePool()->alloc(msg->getSize());

auto const size = msg->getSize();
auto const info = msg->getInfo();
Expand Down Expand Up @@ -703,7 +703,7 @@ bool ActiveMessenger::recvDataMsgBuffer(
MPI_Get_count(&stat, MPI_BYTE, &num_probe_bytes);

std::byte* buf = user_buf == nullptr ?
reinterpret_cast<std::byte*>(thePool()->alloc(num_probe_bytes)) :
thePool()->alloc(num_probe_bytes) :
user_buf;

NodeType const sender = stat.MPI_SOURCE;
Expand Down Expand Up @@ -760,7 +760,7 @@ void ActiveMessenger::recvDataDirect(
std::vector<MPI_Request> reqs;
reqs.resize(nchunks);

char* cbuf = reinterpret_cast<char*>(buf);
std::byte* cbuf = buf;
MsgSizeType remainder = len;
auto const max_per_send = theConfig()->vt_max_mpi_send_size;
for (int i = 0; i < nchunks; i++) {
Expand Down Expand Up @@ -801,7 +801,7 @@ void ActiveMessenger::recvDataDirect(
}

InProgressDataIRecv recv{
reinterpret_cast<std::byte*>(cbuf), len, from, std::move(reqs), is_user_buf ? buf : nullptr, dealloc,
cbuf, len, from, std::move(reqs), is_user_buf ? buf : nullptr, dealloc,
next, prio
};

Expand Down Expand Up @@ -998,7 +998,7 @@ bool ActiveMessenger::tryProcessIncomingActiveMsg() {
if (flag == 1) {
MPI_Get_count(&stat, MPI_BYTE, &num_probe_bytes);

char* buf = reinterpret_cast<char*>(thePool()->alloc(num_probe_bytes));
std::byte* buf = thePool()->alloc(num_probe_bytes);

NodeType const sender = stat.MPI_SOURCE;

Expand Down Expand Up @@ -1032,7 +1032,7 @@ bool ActiveMessenger::tryProcessIncomingActiveMsg() {
#endif
}

InProgressIRecv recv_holder{reinterpret_cast<std::byte*>(buf), num_probe_bytes, sender, req};
InProgressIRecv recv_holder{buf, num_probe_bytes, sender, req};

int num_mpi_tests = 0;
auto done = recv_holder.test(num_mpi_tests);
Expand All @@ -1051,7 +1051,7 @@ bool ActiveMessenger::tryProcessIncomingActiveMsg() {
}

void ActiveMessenger::finishPendingActiveMsgAsyncRecv(InProgressIRecv* irecv) {
char* buf = reinterpret_cast<char*>(irecv->buf);
std::byte* buf = irecv->buf;
auto num_probe_bytes = irecv->probe_bytes;
auto sender = irecv->sender;

Expand Down Expand Up @@ -1089,18 +1089,18 @@ void ActiveMessenger::finishPendingActiveMsgAsyncRecv(InProgressIRecv* irecv) {
if (put_tag == PutPackedTag) {
auto const put_size = envelopeGetPutSize(msg->env);
auto const msg_size = num_probe_bytes - put_size;
char* put_ptr = buf + msg_size;
std::byte* put_ptr = buf + msg_size;

if (!is_term || vt_check_enabled(print_term_msgs)) {
vt_debug_print(
verbose, active,
"finishPendingActiveMsgAsyncRecv: packed put: ptr={}, msg_size={}, "
"put_size={}\n",
put_ptr, msg_size, put_size
print_ptr(put_ptr), msg_size, put_size
);
}

envelopeSetPutPtrOnly(msg->env, reinterpret_cast<std::byte*>(put_ptr));
envelopeSetPutPtrOnly(msg->env, put_ptr);
put_finished = true;
} else {
/*bool const put_delivered = */recvDataMsg(
Expand Down
2 changes: 1 addition & 1 deletion src/vt/pool/pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ std::byte* Pool::defaultAlloc(size_t const& num_bytes, size_t const& oversize) {
}

void Pool::defaultDealloc(std::byte* const ptr) {
std::free(reinterpret_cast<void*>(ptr));
std::free(ptr);
}

std::byte* Pool::alloc(size_t const& num_bytes, size_t oversize) {
Expand Down
18 changes: 9 additions & 9 deletions src/vt/rdma/rdma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ RDMAManager::RDMAManager()
);
} else {
theMsg()->recvDataMsgBuffer(
msg->nchunks, reinterpret_cast<std::byte*>(get_ptr), msg->mpi_tag_to_recv, msg->send_back, true,
msg->nchunks, get_ptr, msg->mpi_tag_to_recv, msg->send_back, true,
[get_ptr_action]{
vt_debug_print(
normal, rdma,
Expand Down Expand Up @@ -178,9 +178,9 @@ RDMAManager::RDMAManager()
);

if (direct) {
auto data_ptr = reinterpret_cast<char*>(msg) + sizeof(PutMessage);
auto data_ptr = reinterpret_cast<std::byte*>(msg) + sizeof(PutMessage);
theRDMA()->triggerPutRecvData(
msg->rdma_handle, msg_tag, reinterpret_cast<std::byte*>(data_ptr), msg->num_bytes, msg->offset, [=]{
msg->rdma_handle, msg_tag, data_ptr, msg->num_bytes, msg->offset, [=]{
vt_debug_print(
normal, rdma,
"put_data: after put trigger: send_back={}\n", send_back
Expand Down Expand Up @@ -238,10 +238,10 @@ RDMAManager::RDMAManager()
});
} else {
auto const& put_ptr_offset =
msg->offset != no_byte ? reinterpret_cast<char*>(put_ptr) + msg->offset : reinterpret_cast<char*>(put_ptr);
msg->offset != no_byte ? put_ptr + msg->offset : put_ptr;
// do a direct recv into the user buffer
theMsg()->recvDataMsgBuffer(
msg->nchunks, reinterpret_cast<std::byte*>(put_ptr_offset), recv_tag, recv_node, true, []{},
msg->nchunks, put_ptr_offset, recv_tag, recv_node, true, []{},
[=](RDMA_GetType ptr, ActionType deleter){
vt_debug_print(
normal, rdma,
Expand Down Expand Up @@ -780,7 +780,7 @@ void RDMAManager::putRegionTypeless(
auto const& elm_size = region.elm_size;
auto const& rlo = region.lo;
auto const& roffset = lo - rlo;
auto const& ptr_offset = reinterpret_cast<char*>(ptr) + (roffset * elm_size);
auto const& ptr_offset = ptr + (roffset * elm_size);
auto const& block_offset = (lo - blk_lo) * elm_size;

vt_debug_print(
Expand All @@ -794,7 +794,7 @@ void RDMAManager::putRegionTypeless(
remote_action->addDep();

putData(
han, reinterpret_cast<std::byte*>(ptr_offset), (hi-lo)*elm_size, block_offset, no_tag, elm_size,
han, ptr_offset, (hi-lo)*elm_size, block_offset, no_tag, elm_size,
[=]{ remote_action->release(); }, node
);
});
Expand Down Expand Up @@ -842,7 +842,7 @@ void RDMAManager::getRegionTypeless(
auto const& elm_size = region.elm_size;
auto const& rlo = region.lo;
auto const& roffset = lo - rlo;
auto const& ptr_offset = reinterpret_cast<char*>(ptr) + (roffset * elm_size);
auto const& ptr_offset = ptr + (roffset * elm_size);
auto const& block_offset = (lo - blk_lo) * elm_size;

vt_debug_print(
Expand All @@ -855,7 +855,7 @@ void RDMAManager::getRegionTypeless(
action->addDep();

getDataIntoBuf(
han, reinterpret_cast<std::byte*>(ptr_offset), (hi-lo)*elm_size, block_offset, no_tag, [=]{
han, ptr_offset, (hi-lo)*elm_size, block_offset, no_tag, [=]{
auto const& my_node = theContext()->getNode();
vt_debug_print(
normal, rdma,
Expand Down
2 changes: 1 addition & 1 deletion src/vt/rdma/state/rdma_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ bool State::testReadyPutData(TagType const& tag) {
"To use default handler ptr, bytes must be set"
);

std::memcpy(reinterpret_cast<char*>(state.ptr) + req_offset, in_ptr, req_num_bytes);
std::memcpy(state.ptr + req_offset, in_ptr, req_num_bytes);
}

void State::getData(
Expand Down
4 changes: 2 additions & 2 deletions src/vt/serialization/messaging/serialized_messenger.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,8 @@ template <typename MsgT, typename BaseT>
if (node != dest) {
auto sys_msg = makeMessage<SerialWrapperMsgType<MsgT>>();
auto send_serialized = [=](Active::SendFnType send){
auto void_ptr = reinterpret_cast<std::byte*>(ptr);
auto ret = send(RDMA_GetType{void_ptr, ptr_size}, dest, no_tag);
auto byte_ptr = reinterpret_cast<std::byte*>(ptr);
auto ret = send(RDMA_GetType{byte_ptr, ptr_size}, dest, no_tag);
EventType event = ret.getEvent();
theEvent()->attachAction(event, [=]{ std::free(ptr); });
sys_msg->data_recv_tag = ret.getTag();
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/pool/test_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ TEST_F(TestPool, pool_alloc) {

for (size_t cur_bytes = 1; cur_bytes < max_bytes; cur_bytes *= 2) {
std::byte* ptr = testPool->alloc(cur_bytes);
std::memset(reinterpret_cast<void*>(ptr), init_val, cur_bytes);
std::memset(ptr, init_val, cur_bytes);
//fmt::print("alloc {} bytes, ptr={}\n", cur_bytes, ptr);
EXPECT_NE(ptr, nullptr);
for (size_t i = 0; i < cur_bytes; i++) {
Expand Down

0 comments on commit e47b78d

Please sign in to comment.