Skip to content

Commit

Permalink
Merge pull request openucx#37 from Mellanox/support_io_offload/allocator_ucp_uct_impl
Browse files Browse the repository at this point in the history

Support io_offload, memory allocator ucp & uct impl
  • Loading branch information
yosefe authored Sep 22, 2022
2 parents a66e616 + f48e69d commit 10e77ae
Show file tree
Hide file tree
Showing 33 changed files with 872 additions and 290 deletions.
37 changes: 20 additions & 17 deletions src/ucp/api/ucp.h
Original file line number Diff line number Diff line change
Expand Up @@ -1239,6 +1239,25 @@ typedef size_t (*ucp_mem_allocator_cb_t)(void *arg, size_t num_of_buffers,
void **buffers, ucp_mem_h *memh);


/**
 * User-defined memory allocator configuration, supplied via
 * ucp_worker_params_t::user_allocator. UCX uses it to obtain receive
 * buffers; each buffer holds one active-message fragment.
 */
typedef struct {
    /**
     * User memory allocator get buf function used by UCX in post receive.
     * Signature per ucp_mem_allocator_cb_t: receives @a arg, a requested
     * number of buffers, and output arrays for buffer pointers and memory
     * handles, and returns a size_t (presumably the number of buffers
     * actually provided — confirm against the callback's documentation).
     */
    ucp_mem_allocator_cb_t cb;

    /**
     * User-defined argument for the allocator callback; passed back
     * unmodified as the callback's first parameter.
     */
    void *arg;

    /**
     * User memory allocator payload's buffer size, in bytes.
     * This will be the size of the active message fragment.
     */
    size_t buffer_size;
} ucp_user_mem_allocator_t;


/**
* @ingroup UCP_WORKER
* @brief Tuning parameters for the UCP worker.
Expand Down Expand Up @@ -1347,23 +1366,7 @@ typedef struct ucp_worker_params {
/**
* User defined memory allocator
*/
struct {
/**
* User memory allocator get buf function used by UCX in post receive.
*/
ucp_mem_allocator_cb_t cb;

/**
* User-defined argument for the allocator callback.
*/
void *arg;

/**
* User memory allocator payload's buffer size.
* This will be the size of the active message fragment.
*/
size_t buffer_size;
} user_allocator;
ucp_user_mem_allocator_t user_allocator;
} ucp_worker_params_t;


Expand Down
85 changes: 38 additions & 47 deletions src/ucp/core/ucp_am.c
Original file line number Diff line number Diff line change
Expand Up @@ -877,8 +877,6 @@ ucp_am_send_req(ucp_request_t *req, size_t count,
return UCS_STATUS_PTR(status);
}

ucs_assert(ucp_am_send_req_total_size(req) >= rndv_thresh);

status = ucp_am_send_start_rndv(req, param);
if (status != UCS_OK) {
return UCS_STATUS_PTR(status);
Expand Down Expand Up @@ -1198,7 +1196,8 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_am_recv_data_nbx,
static UCS_F_ALWAYS_INLINE ucs_status_t
ucp_am_invoke_cb(ucp_worker_h worker, uint16_t am_id, void *user_hdr,
uint32_t user_hdr_length, void *data, size_t data_length,
ucp_ep_h reply_ep, uint64_t recv_flags)
ucp_ep_h reply_ep, uint64_t recv_flags,
const uct_am_callback_params_t *uct_cb_params)
{
ucp_am_entry_t *am_cb = &ucs_array_elem(&worker->am.cbs, am_id);
ucp_am_recv_param_t param;
Expand All @@ -1211,6 +1210,7 @@ ucp_am_invoke_cb(ucp_worker_h worker, uint16_t am_id, void *user_hdr,
if (ucs_likely(am_cb->flags & UCP_AM_CB_PRIV_FLAG_NBX)) {
param.recv_attr = recv_flags;
param.reply_ep = reply_ep;
param.payload = uct_cb_params->payload;

return am_cb->cb(am_cb->context, user_hdr, user_hdr_length, data,
data_length, &param);
Expand All @@ -1223,16 +1223,16 @@ ucp_am_invoke_cb(ucp_worker_h worker, uint16_t am_id, void *user_hdr,
return UCS_OK;
}

flags = (recv_flags & UCP_AM_RECV_ATTR_FLAG_DATA) ?
UCP_CB_PARAM_FLAG_DATA : 0;
flags = (recv_flags & UCP_AM_RECV_ATTR_FLAG_DATA) ? UCP_CB_PARAM_FLAG_DATA :
0;

return am_cb->cb_old(am_cb->context, data, data_length, reply_ep, flags);
}

static UCS_F_ALWAYS_INLINE ucs_status_t ucp_am_handler_common(
ucp_worker_h worker, ucp_am_hdr_t *am_hdr, size_t total_length,
ucp_ep_h reply_ep, unsigned am_flags, uint64_t recv_flags,
const char *name)
const char *name, const uct_am_callback_params_t *params)
{
ucp_recv_desc_t *desc = NULL;
uint16_t am_id = am_hdr->am_id;
Expand All @@ -1241,7 +1241,7 @@ static UCS_F_ALWAYS_INLINE ucs_status_t ucp_am_handler_common(
void *data = am_hdr + 1;
size_t data_length = total_length -
(sizeof(*am_hdr) + am_hdr->header_length);
void *user_hdr = UCS_PTR_BYTE_OFFSET(data, data_length);
void *user_hdr = UCS_PTR_BYTE_OFFSET(params->payload, data_length);
ucs_status_t desc_status = UCS_OK;
ucs_status_t status;

Expand Down Expand Up @@ -1277,12 +1277,13 @@ static UCS_F_ALWAYS_INLINE ucs_status_t ucp_am_handler_common(
worker, am_id);
return UCS_OK;
}
data = desc + 1;
desc->length = data_length;
recv_flags |= UCP_AM_RECV_ATTR_FLAG_DATA;
}

status = ucp_am_invoke_cb(worker, am_id, user_hdr, user_hdr_size, data,
data_length, reply_ep, recv_flags);
status = ucp_am_invoke_cb(worker, am_id, user_hdr, user_hdr_size,
desc + 1, data_length, reply_ep,
recv_flags, params);
if (desc == NULL) {
if (ucs_unlikely(status == UCS_INPROGRESS)) {
ucs_error("can't hold data, FLAG_DATA flag is not set");
Expand All @@ -1306,34 +1307,36 @@ static UCS_F_ALWAYS_INLINE ucs_status_t ucp_am_handler_common(
}

UCS_PROFILE_FUNC(ucs_status_t, ucp_am_handler_reply,
(am_arg, am_data, am_length, am_flags, params), void *am_arg,
void *am_data, size_t am_length, unsigned am_flags,
uct_am_callback_params_t *params)
(am_arg, am_data, am_length, am_flags, params),
void *am_arg, void *am_data, size_t am_length,
unsigned am_flags, uct_am_callback_params_t *params)
{
ucp_am_hdr_t *hdr = (ucp_am_hdr_t*)am_data;
ucp_worker_h worker = (ucp_worker_h)am_arg;
ucp_am_reply_ftr_t *ftr = UCS_PTR_BYTE_OFFSET(am_data,
am_length - sizeof(*ftr));
size_t reply_ftr_offset = am_length - sizeof(ucp_am_reply_ftr_t) -
sizeof(ucp_am_hdr_t);
ucp_am_reply_ftr_t *ftr = UCS_PTR_BYTE_OFFSET(params->payload,
reply_ftr_offset);
ucp_ep_h reply_ep;

UCP_WORKER_GET_VALID_EP_BY_ID(&reply_ep, worker, ftr->ep_id, return UCS_OK,
"AM (reply proto)");

return ucp_am_handler_common(worker, hdr, am_length - sizeof(ftr), reply_ep,
am_flags, UCP_AM_RECV_ATTR_FIELD_REPLY_EP,
"am_handler_reply");
"am_handler_reply", params);
}

UCS_PROFILE_FUNC(ucs_status_t, ucp_am_handler,
(am_arg, am_data, am_length, am_flags, params), void *am_arg,
void *am_data, size_t am_length, unsigned am_flags,
uct_am_callback_params_t *params)
(am_arg, am_data, am_length, am_flags, params),
void *am_arg, void *am_data, size_t am_length,
unsigned am_flags, uct_am_callback_params_t *params)
{
ucp_worker_h worker = am_arg;
ucp_am_hdr_t *hdr = am_data;

return ucp_am_handler_common(worker, hdr, am_length, NULL, am_flags, 0ul,
"am_handler");
"am_handler", params);
}

static UCS_F_ALWAYS_INLINE ucp_recv_desc_t *
Expand Down Expand Up @@ -1432,7 +1435,7 @@ ucp_am_handle_unfinished(ucp_worker_h worker, ucp_recv_desc_t *first_rdesc,
status = ucp_am_invoke_cb(worker, am_id, user_hdr,
user_hdr_length,
payload, total_size,
reply_ep, recv_flags);
reply_ep, recv_flags, NULL);
if (!ucp_am_rdesc_in_progress(first_rdesc, status)) {
/* user does not need to hold this data */
ucp_am_release_long_desc(first_rdesc);
Expand All @@ -1444,9 +1447,9 @@ ucp_am_handle_unfinished(ucp_worker_h worker, ucp_recv_desc_t *first_rdesc,
}

UCS_PROFILE_FUNC(ucs_status_t, ucp_am_long_first_handler,
(am_arg, am_data, am_length, am_flags, params), void *am_arg,
void *am_data, size_t am_length, unsigned am_flags,
uct_am_callback_params_t *params)
(am_arg, am_data, am_length, am_flags, params),
void *am_arg, void *am_data, size_t am_length,
unsigned am_flags, uct_am_callback_params_t *params)
{
ucp_worker_h worker = am_arg;
ucp_am_hdr_t *hdr = am_data;
Expand Down Expand Up @@ -1478,7 +1481,7 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_am_long_first_handler,
return ucp_am_handler_common(worker, hdr,
am_length - sizeof(*first_ftr), ep,
am_flags, recv_flags,
"am_long_first_handler");
"am_long_first_handler", params);
}

/* This is the first fragment, other fragments (if arrived) should be on
Expand Down Expand Up @@ -1561,9 +1564,9 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_am_long_first_handler,
}

UCS_PROFILE_FUNC(ucs_status_t, ucp_am_long_middle_handler,
(am_arg, am_data, am_length, am_flags, params), void *am_arg,
void *am_data, size_t am_length, unsigned am_flags,
uct_am_callback_params_t *params)
(am_arg, am_data, am_length, am_flags, params),
void *am_arg, void *am_data, size_t am_length,
unsigned am_flags, uct_am_callback_params_t *params)
{
ucp_worker_h worker = am_arg;
ucp_am_mid_hdr_t *mid_hdr = am_data;
Expand Down Expand Up @@ -1610,18 +1613,18 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_am_long_middle_handler,
return status;
}

ucs_status_t ucp_am_rndv_process_rts(void *arg, void *data, size_t length,
unsigned tl_flags)
void ucp_am_rndv_process_rts(void *arg, void *data, size_t length,
unsigned tl_flags)
{
ucp_rndv_rts_hdr_t *rts = data;
ucp_recv_desc_t *desc = data;
ucp_rndv_rts_hdr_t *rts = (ucp_rndv_rts_hdr_t*)(desc + 1);
ucp_worker_h worker = arg;
ucp_am_hdr_t *am = ucp_am_hdr_from_rts(rts);
uint16_t am_id = am->am_id;
ucp_recv_desc_t *desc = NULL;
ucp_am_entry_t *am_cb = &ucs_array_elem(&worker->am.cbs, am_id);
ucp_ep_h ep;
ucp_am_recv_param_t param;
ucs_status_t status, desc_status;
ucs_status_t status;
void *hdr;

if (ENABLE_PARAMS_CHECK && !(am_cb->flags & UCP_AM_CB_PRIV_FLAG_NBX)) {
Expand Down Expand Up @@ -1649,17 +1652,6 @@ ucs_status_t ucp_am_rndv_process_rts(void *arg, void *data, size_t length,
hdr = NULL;
}

desc_status = ucp_recv_desc_init(worker, data, length, 0, tl_flags, 0,
UCP_RECV_DESC_FLAG_RNDV |
UCP_RECV_DESC_FLAG_AM_CB_INPROGRESS, 0, 1,
"am_rndv_process_rts", &desc);
if (ucs_unlikely(UCS_STATUS_IS_ERR(desc_status))) {
ucs_error("worker %p could not allocate descriptor for active"
" message RTS on callback %u", worker, am_id);
status = UCS_ERR_NO_MEMORY;
goto out_send_ats;
}

param.recv_attr = UCP_AM_RECV_ATTR_FLAG_RNDV |
ucp_am_hdr_reply_ep(worker, am->flags, ep,
&param.reply_ep);
Expand All @@ -1672,7 +1664,7 @@ ucs_status_t ucp_am_rndv_process_rts(void *arg, void *data, size_t length,
ucs_status_string(status));

desc->flags &= ~UCP_RECV_DESC_FLAG_AM_CB_INPROGRESS;
return desc_status;
return;
} else if (desc->flags & UCP_RECV_DESC_FLAG_RECV_STARTED) {
/* User initiated rendezvous receive in the callback and it is
* already completed. No need to save the descriptor for further use
Expand All @@ -1689,6 +1681,7 @@ ucs_status_t ucp_am_rndv_process_rts(void *arg, void *data, size_t length,
/* Some error occurred or user does not need this data. Send ATS back to the
* sender to complete its send request. */
ucp_am_rndv_send_ats(worker, rts, status);
ucp_recv_desc_release(desc);

out:
if (desc != NULL) {
Expand All @@ -1705,8 +1698,6 @@ ucs_status_t ucp_am_rndv_process_rts(void *arg, void *data, size_t length,
ucp_recv_desc_release(desc);
}
}

return UCS_OK;
}

UCP_DEFINE_AM(UCP_FEATURE_AM, UCP_AM_ID_AM_SINGLE,
Expand Down
4 changes: 2 additions & 2 deletions src/ucp/core/ucp_am.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ size_t ucp_am_max_header_size(ucp_worker_h worker);

ucs_status_t ucp_proto_progress_am_rndv_rts(uct_pending_req_t *self);

ucs_status_t ucp_am_rndv_process_rts(void *arg, void *data, size_t length,
unsigned tl_flags);
void ucp_am_rndv_process_rts(void *arg, void *data, size_t length,
unsigned tl_flags);

#endif
2 changes: 1 addition & 1 deletion src/ucp/core/ucp_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ static ucs_config_field_t ucp_context_config_table[] = {
"multiple rails. Must be greater than 0.",
ucs_offsetof(ucp_context_config_t, min_rndv_chunk_size), UCS_CONFIG_TYPE_MEMUNITS},

{"RNDV_SCHEME", "auto",
{"RNDV_SCHEME", "get_zcopy",
"Communication scheme in RNDV protocol.\n"
" get_zcopy - use get_zcopy scheme in RNDV protocol.\n"
" put_zcopy - use put_zcopy scheme in RNDV protocol.\n"
Expand Down
41 changes: 11 additions & 30 deletions src/ucp/core/ucp_request.c
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,6 @@ ucs_status_t ucp_request_send_start(ucp_request_t *req, ssize_t max_short,
const ucp_request_param_t *param)
{
ucs_status_t status;
int multi;

req->status = UCS_INPROGRESS;

Expand All @@ -602,20 +601,20 @@ ucs_status_t ucp_request_send_start(ucp_request_t *req, ssize_t max_short,
req->send.uct.func = proto->contig_short;
UCS_PROFILE_REQUEST_EVENT(req, "start_contig_short", req->send.length);
return UCS_OK;
} else if (length < zcopy_thresh) {
// disable bcopy multi fragment proto
} else if ((length < zcopy_thresh) &&
(length <= (msg_config->max_bcopy - proto->only_hdr_size))) {
/* bcopy */
ucp_request_send_state_reset(req, NULL, UCP_REQUEST_SEND_PROTO_BCOPY_AM);
ucs_assert(msg_config->max_bcopy >= proto->only_hdr_size);
if (length <= (msg_config->max_bcopy - proto->only_hdr_size)) {
req->send.uct.func = proto->bcopy_single;
UCS_PROFILE_REQUEST_EVENT(req, "start_bcopy_single", req->send.length);
} else {
ucp_request_init_multi_proto(req, proto->bcopy_multi,
"start_bcopy_multi");
}
req->send.uct.func = proto->bcopy_single;
UCS_PROFILE_REQUEST_EVENT(req, "start_bcopy_single", req->send.length);

return UCS_OK;
} else if (length < zcopy_max) {
// disable zcopy multi fragment proto
} else if ((length < zcopy_max) &&
ucs_likely(length <
msg_config->max_zcopy - proto->only_hdr_size)) {
/* zcopy */
ucp_request_send_state_reset(req, proto->zcopy_completion,
UCP_REQUEST_SEND_PROTO_ZCOPY_AM);
Expand All @@ -630,26 +629,8 @@ ucs_status_t ucp_request_send_start(ucp_request_t *req, ssize_t max_short,
return status;
}

if (ucs_unlikely(length > msg_config->max_zcopy - proto->only_hdr_size)) {
multi = 1;
} else if (ucs_unlikely(UCP_DT_IS_IOV(req->send.datatype))) {
if (dt_count <= (msg_config->max_iov - priv_iov_count)) {
multi = 0;
} else {
multi = ucp_dt_iov_count_nonempty(req->send.buffer, dt_count) >
(msg_config->max_iov - priv_iov_count);
}
} else {
multi = 0;
}

if (multi) {
ucp_request_init_multi_proto(req, proto->zcopy_multi,
"start_zcopy_multi");
} else {
req->send.uct.func = proto->zcopy_single;
UCS_PROFILE_REQUEST_EVENT(req, "start_zcopy_single", req->send.length);
}
req->send.uct.func = proto->zcopy_single;
UCS_PROFILE_REQUEST_EVENT(req, "start_zcopy_single", req->send.length);

return UCS_OK;
}
Expand Down
Loading

0 comments on commit 10e77ae

Please sign in to comment.