diff --git a/src/ucp/rma/amo_send.c b/src/ucp/rma/amo_send.c index 5ec3ea729cb..b44dc1bf10e 100644 --- a/src/ucp/rma/amo_send.c +++ b/src/ucp/rma/amo_send.c @@ -50,6 +50,21 @@ } +#define UCP_AMO_CHECK_PARAM_NBX(_context, _remote_addr, _size, _count, \ + _opcode, _last_opcode, _action) \ + { \ + if (ENABLE_PARAMS_CHECK) { \ + if ((_count) != 1) { \ + ucs_error("unsupported number of elements: %zu", (_count)); \ + _action; \ + } \ + } \ + \ + UCP_AMO_CHECK_PARAM(_context, _remote_addr, _size, _opcode, \ + _last_opcode, _action); \ + } + + static uct_atomic_op_t ucp_uct_op_table[] = { [UCP_ATOMIC_POST_OP_ADD] = UCT_ATOMIC_OP_ADD, [UCP_ATOMIC_POST_OP_AND] = UCT_ATOMIC_OP_AND, @@ -118,20 +133,57 @@ ucs_status_ptr_t ucp_atomic_fetch_nb(ucp_ep_h ep, ucp_atomic_fetch_op_t opcode, uint64_t value, void *result, size_t op_size, uint64_t remote_addr, ucp_rkey_h rkey, ucp_send_callback_t cb) +{ + ucp_request_param_t param = { + .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_DATATYPE, + .datatype = ucp_dt_make_contig(op_size), + .cb.send = (ucp_send_nbx_callback_t)cb + }; + + return ucp_atomic_fetch_nbx(ep, opcode, &value, result, 1, + remote_addr, rkey, ¶m); +} + +ucs_status_ptr_t +ucp_atomic_fetch_nbx(ucp_ep_h ep, ucp_atomic_fetch_op_t opcode, + const void *buffer, void *result, size_t count, + uint64_t remote_addr, ucp_rkey_h rkey, + const ucp_request_param_t *param) { ucs_status_ptr_t status_p; ucs_status_t status; ucp_request_t *req; + uint64_t value; + size_t op_size; - UCP_AMO_CHECK_PARAM(ep->worker->context, remote_addr, op_size, opcode, - UCP_ATOMIC_FETCH_OP_LAST, - return UCS_STATUS_PTR(UCS_ERR_INVALID_PARAM)); + if (ucs_unlikely(!(param->op_attr_mask & UCP_OP_ATTR_FIELD_DATATYPE))) { + ucs_error("missing atomic operation datatype"); + return UCS_STATUS_PTR(UCS_ERR_INVALID_PARAM); + } + + if (param->datatype == ucp_dt_make_contig(8)) { + value = *(uint64_t*)buffer; + op_size = sizeof(uint64_t); + } else if (param->datatype == ucp_dt_make_contig(4)) { + value = *(uint32_t*)buffer; + op_size = sizeof(uint32_t); + } else { + ucs_error("invalid atomic operation datatype: %zu", param->datatype); + return UCS_STATUS_PTR(UCS_ERR_INVALID_PARAM); + } + + UCP_AMO_CHECK_PARAM_NBX(ep->worker->context, remote_addr, op_size, + count, opcode, UCP_ATOMIC_FETCH_OP_LAST, + return UCS_STATUS_PTR(UCS_ERR_INVALID_PARAM)); UCP_WORKER_THREAD_CS_ENTER_CONDITIONAL(ep->worker); - ucs_trace_req("atomic_fetch_nb opcode %d value %"PRIu64" buffer %p size %zu" - " remote_addr %"PRIx64" rkey %p to %s cb %p", - opcode, value, result, op_size, remote_addr, rkey, - ucp_ep_peer_name(ep), cb); + ucs_trace_req("atomic_fetch_nb opcode %d buffer %p result %p " + "datatype %zu remote_addr %"PRIx64" rkey %p to %s cb %p", + opcode, buffer, result, param->datatype, remote_addr, rkey, + ucp_ep_peer_name(ep), + (param->op_attr_mask & UCP_OP_ATTR_FIELD_CALLBACK) ? + param->cb.send : NULL); status = UCP_RKEY_RESOLVE(rkey, ep, amo); if (status != UCS_OK) { @@ -139,31 +191,20 @@ ucs_status_ptr_t ucp_atomic_fetch_nb(ucp_ep_h ep, ucp_atomic_fetch_op_t opcode, goto out; } - req = ucp_request_get(ep->worker); - if (ucs_unlikely(NULL == req)) { - status_p = UCS_STATUS_PTR(UCS_ERR_NO_MEMORY); - goto out; - } + req = ucp_request_get_param(ep->worker, param, + {status_p = UCS_STATUS_PTR(UCS_ERR_NO_MEMORY); + goto out;}); ucp_amo_init_fetch(req, ep, result, ucp_uct_fop_table[opcode], op_size, remote_addr, rkey, value, rkey->cache.amo_proto); - status_p = ucp_rma_send_request_cb(req, cb); + status_p = ucp_rma_send_request(req, param); out: UCP_WORKER_THREAD_CS_EXIT_CONDITIONAL(ep->worker); return status_p; } -ucs_status_ptr_t -ucp_atomic_fetch_nbx(ucp_ep_h ep, ucp_atomic_fetch_op_t opcode, - const void *buffer, void *result, size_t count, - uint64_t remote_addr, ucp_rkey_h rkey, - const ucp_request_param_t *param) -{ - return UCS_STATUS_PTR(UCS_ERR_NOT_IMPLEMENTED); -} - ucs_status_t ucp_atomic_post(ucp_ep_h ep, ucp_atomic_post_op_t opcode, uint64_t value, size_t op_size, uint64_t remote_addr, ucp_rkey_h rkey) {