From e006a7d6c1a03333dace209ec696c126e0968b86 Mon Sep 17 00:00:00 2001 From: Devendar Bureddy Date: Tue, 19 Sep 2017 13:35:40 -0700 Subject: [PATCH] refactor mem address domain detection --- src/ucp/api/ucp_def.h | 11 ++- src/ucp/core/ucp_mm.c | 38 +++++++--- src/ucp/core/ucp_mm.h | 18 ++++- src/ucp/core/ucp_request.h | 2 +- src/ucp/dt/dt.c | 105 ++++++++++++++------------ src/ucp/tag/tag_recv.c | 2 +- src/ucp/tag/tag_send.c | 28 +++---- src/uct/api/uct.h | 17 +++-- src/uct/base/uct_md.c | 4 +- src/uct/base/uct_md.h | 2 +- src/uct/cuda/cuda_copy/cuda_copy_md.c | 18 ++--- src/uct/cuda/gdr_copy/gdr_copy_md.c | 14 ++-- src/uct/ib/base/ib_md.c | 2 +- src/uct/rocm/rocm_cma_md.c | 2 +- src/uct/sm/cma/cma_md.c | 2 +- src/uct/sm/knem/knem_md.c | 2 +- src/uct/sm/mm/mm_md.c | 2 +- src/uct/sm/self/self_md.c | 2 +- src/uct/ugni/base/ugni_md.c | 2 +- 19 files changed, 150 insertions(+), 123 deletions(-) diff --git a/src/ucp/api/ucp_def.h b/src/ucp/api/ucp_def.h index ab26cff81e7..64274fa51e6 100644 --- a/src/ucp/api/ucp_def.h +++ b/src/ucp/api/ucp_def.h @@ -136,10 +136,13 @@ typedef struct ucp_rkey *ucp_rkey_h; */ typedef struct ucp_mem *ucp_mem_h; -typedef struct ucp_addr_dn { - uint64_t mask; -} ucp_addr_dn_h; - +/* + * @ingroup UCP_ADDR_DN + * @brief UCP Address Domain + * + * Address Domain handle is an opaque object representing a memory address domain +*/ +typedef struct ucp_addr_dn *ucp_addr_dn_h; /** * @ingroup UCP_MEM diff --git a/src/ucp/core/ucp_mm.c b/src/ucp/core/ucp_mm.c index 9f66bf642a6..669edd49a75 100644 --- a/src/ucp/core/ucp_mm.c +++ b/src/ucp/core/ucp_mm.c @@ -24,6 +24,12 @@ static ucp_mem_t ucp_mem_dummy_handle = { .md_map = 0 }; +ucp_addr_dn_t ucp_addr_dn_dummy_handle = { + .md_map = 0, + .id = UCT_MD_ADDR_DOMAIN_LAST +}; + + /** * Unregister memory from all memory domains. * Save in *alloc_md_memh_p the memory handle of the allocating MD, if such exists. 
@@ -107,28 +113,38 @@ static ucs_status_t ucp_memh_reg_mds(ucp_context_h context, ucp_mem_h memh, } ucs_status_t ucp_addr_domain_detect_mds(ucp_context_h context, void *addr, - ucp_addr_dn_h *addr_dn) + ucp_addr_dn_h *addr_dn_h) { ucs_status_t status; unsigned md_index; - uint64_t dn_mask; + uct_addr_domain_t domain_id = UCT_MD_ADDR_DOMAIN_DEFAULT; + + *addr_dn_h = &ucp_addr_dn_dummy_handle; - addr_dn->mask = 0; + /*TODO: return if no MDs with address domain detect */ for (md_index = 0; md_index < context->num_mds; ++md_index) { if (context->tl_mds[md_index].attr.cap.flags & UCT_MD_FLAG_ADDR_DN) { - if(!(addr_dn->mask & context->tl_mds[md_index].attr.cap.addr_dn_mask)) { - dn_mask = 0; - - status = uct_md_mem_detect(context->tl_mds[md_index].md, addr, &dn_mask); - if (status != UCS_OK) { - return status; + if (domain_id == UCT_MD_ADDR_DOMAIN_DEFAULT) { + status = uct_md_mem_detect(context->tl_mds[md_index].md, addr); + if (status == UCS_OK) { + domain_id = context->tl_mds[md_index].attr.cap.addr_dn; + + *addr_dn_h = ucs_malloc(sizeof(ucp_addr_dn_t), "ucp_addr_dn_h"); + if (*addr_dn_h == NULL) { + return UCS_ERR_NO_MEMORY; + } + + (*addr_dn_h)->id = domain_id; + (*addr_dn_h)->md_map = UCS_BIT(md_index); + } + } else { + if (domain_id == context->tl_mds[md_index].attr.cap.addr_dn) { + (*addr_dn_h)->md_map |= UCS_BIT(md_index); } - addr_dn->mask |= dn_mask; } } } - return UCS_OK; } /** diff --git a/src/ucp/core/ucp_mm.h b/src/ucp/core/ucp_mm.h index 0223fdcabc0..2516a07f073 100644 --- a/src/ucp/core/ucp_mm.h +++ b/src/ucp/core/ucp_mm.h @@ -64,6 +64,15 @@ typedef struct ucp_mem_desc { } ucp_mem_desc_t; +/** + * Memory Address Domain descriptor. + * Contains domain information of the memory address it belongs to. 
+ */ +typedef struct ucp_addr_dn { + ucp_md_map_t md_map; /* Which MDs own this addr domain */ + uct_addr_domain_t id; /* Address domain index */ +} ucp_addr_dn_t; + void ucp_rkey_resolve_inner(ucp_rkey_h rkey, ucp_ep_h ep); ucs_status_t ucp_mpool_malloc(ucs_mpool_t *mp, size_t *size_p, void **chunk_p); @@ -72,8 +81,15 @@ void ucp_mpool_free(ucs_mpool_t *mp, void *chunk); void ucp_mpool_obj_init(ucs_mpool_t *mp, void *obj, void *chunk); +/** + * Detects the address domain on all MDs. Skips detection on subsequent MDs + * if it was successfully detected by an MD. +**/ ucs_status_t ucp_addr_domain_detect_mds(ucp_context_h context, void *addr, - ucp_addr_dn_h *addr_dn); + ucp_addr_dn_h *addr_dn_h); + + +extern ucp_addr_dn_t ucp_addr_dn_dummy_handle; static UCS_F_ALWAYS_INLINE uct_mem_h ucp_memh2uct(ucp_mem_h memh, ucp_md_index_t md_idx) diff --git a/src/ucp/core/ucp_request.h b/src/ucp/core/ucp_request.h index 2e206e73ff9..45209d775b7 100644 --- a/src/ucp/core/ucp_request.h +++ b/src/ucp/core/ucp_request.h @@ -78,7 +78,7 @@ typedef void (*ucp_request_callback_t)(ucp_request_t *req); struct ucp_request { ucs_status_t status; /* Operation status */ uint16_t flags; /* Request flags */ - ucp_addr_dn_h dn_mask; /* Memory domain mask */ + ucp_addr_dn_h addr_dn_h; /* Memory domain handle */ union { struct { diff --git a/src/ucp/dt/dt.c b/src/ucp/dt/dt.c index 1938ebea7e0..677fb7799cf 100644 --- a/src/ucp/dt/dt.c +++ b/src/ucp/dt/dt.c @@ -48,62 +48,69 @@ size_t ucp_dt_pack(ucp_datatype_t datatype, void *dest, const void *src, static UCS_F_ALWAYS_INLINE ucs_status_t ucp_dn_dt_unpack(ucp_request_t *req, void *buffer, size_t buffer_size, - const void *recv_data, size_t recv_length) + const void *recv_data, size_t recv_length) { ucp_worker_h worker = req->recv.worker; ucp_context_h context = worker->context; unsigned md_index; ucs_status_t status; + ucp_ep_h ep = ucp_worker_ep_find(worker, worker->uuid); for (md_index = 0; md_index < context->num_mds; md_index++) { if
(context->tl_mds[md_index].attr.cap.addr_dn_mask & req->dn_mask.mask) { - ucp_ep_h ep = ucp_worker_ep_find(worker, worker->uuid); - uct_mem_h memh; - uct_iov_t iov; - void *rkey_buffer; - size_t rkey_buffer_size; - ucp_rkey_h rkey; - ucp_lane_index_t lane; - - status = uct_md_mem_reg(context->tl_mds[md_index].md, buffer, buffer_size, - 0, &memh); - if (status != UCS_OK) { - uct_md_mem_dereg(context->tl_mds[md_index].md, memh); - ucs_error("Failed to reg address %p with md %s", buffer, - context->tl_mds[md_index].rsc.md_name); - return status; - } - - ucp_rkey_pack(context, memh, &rkey_buffer, &rkey_buffer_size); - ucp_ep_rkey_unpack(ep, rkey_buffer, &rkey); - ucp_rkey_buffer_release(rkey_buffer); - - iov.buffer = buffer; - iov.length = buffer_size; - iov.count = 1; - iov.memh = memh; - - lane = rkey->cache.rma_lane; - - status = uct_ep_put_zcopy(ep->uct_eps[lane], &iov, 1, (uint64_t)recv_data, - rkey->cache.rma_rkey, NULL); - if (status != UCS_OK) { - ucp_rkey_destroy(rkey); - uct_md_mem_dereg(context->tl_mds[md_index].md, memh); - ucs_error("Failed to perform uct_ep_put_zcopy to address %p", recv_data); - return status; - } - - ucp_rkey_destroy(rkey); - - status = uct_md_mem_dereg(context->tl_mds[md_index].md, memh); - if (status != UCS_OK) { - ucs_error("Failed to dereg address %p with md %s", buffer, - context->tl_mds[md_index].rsc.md_name); - return status; - } - break; + + if (!(UCS_BIT(md_index) & req->addr_dn_h->md_map)) { + continue; + } + + /*TODO check if put-zcopy there on iface */ + + uct_mem_h memh; + uct_iov_t iov; + + // void *rkey_buffer; + // size_t rkey_buffer_size; + // ucp_rkey_h rkey; + // ucp_lane_index_t lane; + + status = uct_md_mem_reg(context->tl_mds[md_index].md, buffer, buffer_size, + 0, &memh); + if (status != UCS_OK) { + uct_md_mem_dereg(context->tl_mds[md_index].md, memh); + ucs_error("Failed to reg address %p with md %s", buffer, + context->tl_mds[md_index].rsc.md_name); + return status; } + + // ucp_rkey_pack(context, memh, 
&rkey_buffer, &rkey_buffer_size); + // ucp_ep_rkey_unpack(ep, rkey_buffer, &rkey); + // ucp_rkey_buffer_release(rkey_buffer); + + ucs_assert(buffer_size >= recv_length); + iov.buffer = (void *)recv_data; + iov.length = recv_length; + iov.count = 1; + iov.memh = UCT_MEM_HANDLE_NULL; + + //lane = rkey->cache.rma_lane; + + status = uct_ep_put_zcopy(ep->uct_eps[0], &iov, 1, (uint64_t)buffer, + (uct_rkey_t )memh, NULL); + if (status != UCS_OK) { + // ucp_rkey_destroy(rkey); + uct_md_mem_dereg(context->tl_mds[md_index].md, memh); + ucs_error("Failed to perform uct_ep_put_zcopy to address %p", recv_data); + return status; + } + + //ucp_rkey_destroy(rkey); + + status = uct_md_mem_dereg(context->tl_mds[md_index].md, memh); + if (status != UCS_OK) { + ucs_error("Failed to dereg address %p with md %s", buffer, + context->tl_mds[md_index].rsc.md_name); + return status; + } + break; } return UCS_OK; @@ -128,7 +135,7 @@ ucs_status_t ucp_dt_unpack(ucp_request_t *req, ucp_datatype_t datatype, void *bu switch (datatype & UCP_DATATYPE_CLASS_MASK) { case UCP_DATATYPE_CONTIG: - if (!(req->dn_mask.mask)) { + if (req->addr_dn_h == &ucp_addr_dn_dummy_handle) { UCS_PROFILE_NAMED_CALL("memcpy_recv", memcpy, buffer + offset, recv_data, recv_length); return UCS_OK; diff --git a/src/ucp/tag/tag_recv.c b/src/ucp/tag/tag_recv.c index a5b1a491e8e..896841b6ecd 100644 --- a/src/ucp/tag/tag_recv.c +++ b/src/ucp/tag/tag_recv.c @@ -127,7 +127,7 @@ ucp_tag_recv_request_init(ucp_request_t *req, ucp_worker_h worker, void* buffer, req->recv.state.offset = 0; req->recv.worker = worker; - ucp_addr_domain_detect_mds(worker->context, buffer, &(req->dn_mask)); + ucp_addr_domain_detect_mds(worker->context, buffer, &(req->addr_dn_h)); switch (datatype & UCP_DATATYPE_CLASS_MASK) { case UCP_DATATYPE_IOV: diff --git a/src/ucp/tag/tag_send.c b/src/ucp/tag/tag_send.c index 135be708717..c2288d5a68f 100644 --- a/src/ucp/tag/tag_send.c +++ b/src/ucp/tag/tag_send.c @@ -214,7 +214,7 @@ ucp_tag_send_req(ucp_request_t 
*req, size_t count, ssize_t max_short, static void ucp_tag_send_req_init(ucp_request_t* req, ucp_ep_h ep, const void* buffer, uintptr_t datatype, ucp_tag_t tag, uint16_t flags, - ucp_addr_dn_h addr_dn) + ucp_addr_dn_h addr_dn_h) { req->flags = flags; req->send.ep = ep; @@ -223,7 +223,7 @@ static void ucp_tag_send_req_init(ucp_request_t* req, ucp_ep_h ep, req->send.tag = tag; req->send.reg_rsc = UCP_NULL_RESOURCE; req->send.state.offset = 0; - req->dn_mask = addr_dn; + req->addr_dn_h = addr_dn_h; VALGRIND_MAKE_MEM_UNDEFINED(&req->send.uct_comp.count, sizeof(req->send.uct_comp.count)); @@ -241,18 +241,16 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_send_nb, ucp_request_t *req; size_t length; ucs_status_ptr_t ret; - ucp_addr_dn_h addr_dn; - uint64_t mask; - unsigned md_index; + ucp_addr_dn_h addr_dn_h; UCP_THREAD_CS_ENTER_CONDITIONAL(&ep->worker->mt_lock); ucs_trace_req("send_nb buffer %p count %zu tag %"PRIx64" to %s cb %p", buffer, count, tag, ucp_ep_peer_name(ep), cb); - ucp_addr_domain_detect_mds(ep->worker->context, (void *)buffer, &addr_dn); + ucp_addr_domain_detect_mds(ep->worker->context, (void *)buffer, &addr_dn_h); - if (!addr_dn.mask && ucs_likely(UCP_DT_IS_CONTIG(datatype))) { + if (addr_dn_h->id == UCT_MD_ADDR_DOMAIN_DEFAULT && ucs_likely(UCP_DT_IS_CONTIG(datatype))) { length = ucp_contig_dt_length(datatype, count); if (ucs_likely((ssize_t)length <= ucp_ep_config(ep)->tag.eager.max_short)) { status = UCS_PROFILE_CALL(ucp_tag_send_eager_short, ep, tag, buffer, @@ -271,17 +269,11 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_send_nb, goto out; } - ucp_tag_send_req_init(req, ep, buffer, datatype, tag, 0, addr_dn); - - mask = addr_dn.mask; - md_index = 0; - if (addr_dn.mask) { - CONVERT_BITMASK_TO_INDEX(mask, md_index); - } + ucp_tag_send_req_init(req, ep, buffer, datatype, tag, 0, addr_dn_h); ret = ucp_tag_send_req(req, count, - addr_dn.mask ? ucp_ep_config(ep)->dn[md_index].tag.eager.max_short : ucp_ep_config(ep)->tag.eager.max_short, - addr_dn.mask ? 
ucp_ep_config(ep)->dn[md_index].tag.eager.zcopy_thresh : ucp_ep_config(ep)->tag.eager.zcopy_thresh, + addr_dn_h->id != UCT_MD_ADDR_DOMAIN_DEFAULT ? ucp_ep_config(ep)->dn[addr_dn_h->id].tag.eager.max_short : ucp_ep_config(ep)->tag.eager.max_short, + addr_dn_h->id != UCT_MD_ADDR_DOMAIN_DEFAULT ? ucp_ep_config(ep)->dn[addr_dn_h->id].tag.eager.zcopy_thresh : ucp_ep_config(ep)->tag.eager.zcopy_thresh, ucp_ep_config(ep)->tag.rndv.rma_thresh, ucp_ep_config(ep)->tag.rndv.am_thresh, cb, ucp_ep_config(ep)->tag.proto); @@ -297,7 +289,6 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_send_sync_nb, { ucp_request_t *req; ucs_status_ptr_t ret; - ucp_addr_dn_h addr_dn; UCP_THREAD_CS_ENTER_CONDITIONAL(&ep->worker->mt_lock); @@ -318,10 +309,9 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_send_sync_nb, /* Remote side needs to send reply, so have it connect to us */ ucp_ep_connect_remote(ep); - ucp_addr_domain_detect_mds(ep->worker->context, (void *)buffer, &addr_dn); ucp_tag_send_req_init(req, ep, buffer, datatype, tag, UCP_REQUEST_FLAG_SYNC, - addr_dn); + &ucp_addr_dn_dummy_handle); ret = ucp_tag_send_req(req, count, -1, /* disable short method */ diff --git a/src/uct/api/uct.h b/src/uct/api/uct.h index 3d6d5d11699..097ba9aa987 100644 --- a/src/uct/api/uct.h +++ b/src/uct/api/uct.h @@ -336,9 +336,12 @@ enum { * @ingroup UCT_MD * @brief Memory addr domoins. 
*/ -enum { - UCT_MD_ADDR_DOMAIN_CUDA = UCS_BIT(0) /**< NVIDIA CUDA domain */ -}; +typedef enum { + UCT_MD_ADDR_DOMAIN_CUDA = 0, /**< NVIDIA CUDA domain */ + UCT_MD_ADDR_DOMAIN_DEFAULT, /**< Default system domain */ + UCT_MD_ADDR_DOMAIN_LAST = UCT_MD_ADDR_DOMAIN_DEFAULT + +} uct_addr_domain_t; /** @@ -533,7 +536,7 @@ struct uct_md_attr { size_t max_alloc; /**< Maximal allocation size */ size_t max_reg; /**< Maximal registration size */ uint64_t flags; /**< UCT_MD_FLAG_xx */ - uint64_t addr_dn_mask; /**< Supported addr domains */ + uct_addr_domain_t addr_dn; /**< Supported addr domain */ struct { size_t max_short; } eager; @@ -1278,13 +1281,13 @@ ucs_status_t uct_md_mem_dereg(uct_md_h md, uct_mem_h memh); * @ingroup UCT_MD * @brief Detect memory on the memory domain. * - * Detect memory on the memory domain. Return memory domain in domain mask. + * Detect memory on the memory domain. + * Return UCS_OK if address belongs to MDs address domain * * @param [in] md Memory domain to register memory on. * @param [in] address Memory address to detect. - * @param [out] dn_mask Filled with memory domain mask. 
*/ -ucs_status_t uct_md_mem_detect(uct_md_h md, void *addr, uint64_t *dn_mask); +ucs_status_t uct_md_mem_detect(uct_md_h md, void *addr); /** * @ingroup UCT_MD diff --git a/src/uct/base/uct_md.c b/src/uct/base/uct_md.c index bb72307e3d8..c8ce1768678 100644 --- a/src/uct/base/uct_md.c +++ b/src/uct/base/uct_md.c @@ -507,7 +507,7 @@ ucs_status_t uct_md_mem_dereg(uct_md_h md, uct_mem_h memh) return md->ops->mem_dereg(md, memh); } -ucs_status_t uct_md_mem_detect(uct_md_h md, void *addr, uint64_t *dn_mask) +ucs_status_t uct_md_mem_detect(uct_md_h md, void *addr) { - return md->ops->mem_detect(md, addr, dn_mask); + return md->ops->mem_detect(md, addr); } diff --git a/src/uct/base/uct_md.h b/src/uct/base/uct_md.h index 9a61b488bdb..266f1be3700 100644 --- a/src/uct/base/uct_md.h +++ b/src/uct/base/uct_md.h @@ -131,7 +131,7 @@ struct uct_md_ops { ucs_status_t (*mkey_pack)(uct_md_h md, uct_mem_h memh, void *rkey_buffer); - ucs_status_t (*mem_detect)(uct_md_h md, void *addr, uint64_t *dn_mask); + ucs_status_t (*mem_detect)(uct_md_h md, void *addr); }; diff --git a/src/uct/cuda/cuda_copy/cuda_copy_md.c b/src/uct/cuda/cuda_copy/cuda_copy_md.c index acd1f7aac93..5e7d959046d 100644 --- a/src/uct/cuda/cuda_copy/cuda_copy_md.c +++ b/src/uct/cuda/cuda_copy/cuda_copy_md.c @@ -19,7 +19,7 @@ static ucs_status_t uct_cuda_copy_md_query(uct_md_h md, uct_md_attr_t *md_attr) { md_attr->cap.flags = UCT_MD_FLAG_REG | UCT_MD_FLAG_ADDR_DN; - md_attr->cap.addr_dn_mask = UCT_MD_ADDR_DOMAIN_CUDA; + md_attr->cap.addr_dn = UCT_MD_ADDR_DOMAIN_CUDA; md_attr->cap.max_alloc = 0; md_attr->cap.max_reg = ULONG_MAX; md_attr->cap.eager.max_short = -1; @@ -85,18 +85,15 @@ static ucs_status_t uct_cuda_copy_mem_dereg(uct_md_h md, uct_mem_h memh) return UCS_OK; } -static ucs_status_t uct_cuda_copy_mem_detect(uct_md_h md, void *addr, uint64_t *dn_mask) +static ucs_status_t uct_cuda_copy_mem_detect(uct_md_h md, void *addr) { -#if HAVE_CUDA int memory_type; cudaError_t cuda_err = cudaSuccess; struct 
cudaPointerAttributes attributes; CUresult cu_err = CUDA_SUCCESS; - (*dn_mask) = 0; - if (addr == NULL) { - return UCS_OK; + return UCS_ERR_INVALID_ADDR; } cu_err = cuPointerGetAttribute(&memory_type, @@ -106,16 +103,13 @@ static ucs_status_t uct_cuda_copy_mem_detect(uct_md_h md, void *addr, uint64_t * cuda_err = cudaPointerGetAttributes (&attributes, addr); if (cuda_err == cudaSuccess) { if (attributes.memoryType == cudaMemoryTypeDevice) { - (*dn_mask) = UCT_MD_ADDR_DOMAIN_CUDA; + return UCS_OK; } } } else if (memory_type == CU_MEMORYTYPE_DEVICE) { - (*dn_mask) = UCT_MD_ADDR_DOMAIN_CUDA; + return UCS_OK; } -#else - (*dn_mask) = 0; -#endif - return UCS_OK; + return UCS_ERR_INVALID_ADDR; } static ucs_status_t uct_cuda_copy_query_md_resources(uct_md_resource_desc_t **resources_p, diff --git a/src/uct/cuda/gdr_copy/gdr_copy_md.c b/src/uct/cuda/gdr_copy/gdr_copy_md.c index e0499003a65..b5c7f9ab6ca 100644 --- a/src/uct/cuda/gdr_copy/gdr_copy_md.c +++ b/src/uct/cuda/gdr_copy/gdr_copy_md.c @@ -47,7 +47,7 @@ static ucs_config_field_t uct_gdr_copy_md_config_table[] = { static ucs_status_t uct_gdr_copy_md_query(uct_md_h md, uct_md_attr_t *md_attr) { md_attr->cap.flags = UCT_MD_FLAG_REG | UCT_MD_FLAG_ADDR_DN; - md_attr->cap.addr_dn_mask = UCT_MD_ADDR_DOMAIN_CUDA; + md_attr->cap.addr_dn = UCT_MD_ADDR_DOMAIN_CUDA; md_attr->cap.max_alloc = 0; md_attr->cap.max_reg = ULONG_MAX; md_attr->cap.eager.max_short = -1; @@ -164,17 +164,15 @@ static ucs_status_t uct_gdr_copy_mem_dereg(uct_md_h uct_md, uct_mem_h memh) return status; } -static ucs_status_t uct_gdr_copy_mem_detect(uct_md_h md, void *addr, uint64_t *dn_mask) +static ucs_status_t uct_gdr_copy_mem_detect(uct_md_h md, void *addr) { int memory_type; cudaError_t cuda_err = cudaSuccess; struct cudaPointerAttributes attributes; CUresult cu_err = CUDA_SUCCESS; - (*dn_mask) = 0; - if (addr == NULL) { - return UCS_OK; + return UCS_ERR_INVALID_ADDR; } cu_err = cuPointerGetAttribute(&memory_type, @@ -184,14 +182,14 @@ static ucs_status_t 
uct_gdr_copy_mem_detect(uct_md_h md, void *addr, uint64_t *d cuda_err = cudaPointerGetAttributes (&attributes, addr); if (cuda_err == cudaSuccess) { if (attributes.memoryType == cudaMemoryTypeDevice) { - (*dn_mask) = UCT_MD_ADDR_DOMAIN_CUDA; + return UCS_OK; } } } else if (memory_type == CU_MEMORYTYPE_DEVICE) { - (*dn_mask) = UCT_MD_ADDR_DOMAIN_CUDA; + return UCS_OK; } - return UCS_OK; + return UCS_ERR_INVALID_ADDR; } static ucs_status_t uct_gdr_copy_query_md_resources(uct_md_resource_desc_t **resources_p, diff --git a/src/uct/ib/base/ib_md.c b/src/uct/ib/base/ib_md.c index e4b32404fc3..5e22982a89e 100644 --- a/src/uct/ib/base/ib_md.c +++ b/src/uct/ib/base/ib_md.c @@ -158,7 +158,7 @@ static ucs_status_t uct_ib_md_query(uct_md_h uct_md, uct_md_attr_t *md_attr) UCT_MD_FLAG_NEED_MEMH | UCT_MD_FLAG_NEED_RKEY | UCT_MD_FLAG_ADVISE; - md_attr->cap.addr_dn_mask = 0; + md_attr->cap.addr_dn = UCT_MD_ADDR_DOMAIN_DEFAULT; md_attr->rkey_packed_size = sizeof(uint64_t); if (md->config.enable_contig_pages && diff --git a/src/uct/rocm/rocm_cma_md.c b/src/uct/rocm/rocm_cma_md.c index 90b576ad379..7d70152b41a 100644 --- a/src/uct/rocm/rocm_cma_md.c +++ b/src/uct/rocm/rocm_cma_md.c @@ -30,7 +30,7 @@ static ucs_status_t uct_rocm_cma_md_query(uct_md_h md, uct_md_attr_t *md_attr) md_attr->rkey_packed_size = sizeof(uct_rocm_cma_key_t); md_attr->cap.flags = UCT_MD_FLAG_REG | UCT_MD_FLAG_NEED_RKEY; - md_attr->cap.addr_dn_mask = 0; + md_attr->cap.addr_dn = UCT_MD_ADDR_DOMAIN_DEFAULT; md_attr->cap.max_alloc = 0; md_attr->cap.max_reg = ULONG_MAX; diff --git a/src/uct/sm/cma/cma_md.c b/src/uct/sm/cma/cma_md.c index 43e4af39242..b5b4503f0f5 100644 --- a/src/uct/sm/cma/cma_md.c +++ b/src/uct/sm/cma/cma_md.c @@ -81,7 +81,7 @@ ucs_status_t uct_cma_md_query(uct_md_h md, uct_md_attr_t *md_attr) { md_attr->rkey_packed_size = 0; md_attr->cap.flags = UCT_MD_FLAG_REG; - md_attr->cap.addr_dn_mask = 0; + md_attr->cap.addr_dn = UCT_MD_ADDR_DOMAIN_DEFAULT; md_attr->cap.max_alloc = 0; md_attr->cap.max_reg = 
ULONG_MAX; md_attr->reg_cost.overhead = 9e-9; diff --git a/src/uct/sm/knem/knem_md.c b/src/uct/sm/knem/knem_md.c index 57b888e0e0e..8592a9fa55a 100644 --- a/src/uct/sm/knem/knem_md.c +++ b/src/uct/sm/knem/knem_md.c @@ -13,7 +13,7 @@ ucs_status_t uct_knem_md_query(uct_md_h md, uct_md_attr_t *md_attr) md_attr->rkey_packed_size = sizeof(uct_knem_key_t); md_attr->cap.flags = UCT_MD_FLAG_REG | UCT_MD_FLAG_NEED_RKEY; - md_attr->cap.addr_dn_mask = 0; + md_attr->cap.addr_dn = UCT_MD_ADDR_DOMAIN_DEFAULT; md_attr->cap.max_alloc = 0; md_attr->cap.max_reg = ULONG_MAX; md_attr->reg_cost.overhead = 1200.0e-9; diff --git a/src/uct/sm/mm/mm_md.c b/src/uct/sm/mm/mm_md.c index b48d4dcdf3d..1cb173d3d59 100644 --- a/src/uct/sm/mm/mm_md.c +++ b/src/uct/sm/mm/mm_md.c @@ -124,7 +124,7 @@ ucs_status_t uct_mm_md_query(uct_md_h md, uct_md_attr_t *md_attr) md_attr->reg_cost.growth = 0.007e-9; } md_attr->cap.flags |= UCT_MD_FLAG_NEED_RKEY; - md_attr->cap.addr_dn_mask = 0; + md_attr->cap.addr_dn = UCT_MD_ADDR_DOMAIN_DEFAULT; /* all mm md(s) support fixed memory alloc */ md_attr->cap.flags |= UCT_MD_FLAG_FIXED; md_attr->cap.max_alloc = ULONG_MAX; diff --git a/src/uct/sm/self/self_md.c b/src/uct/sm/self/self_md.c index 4bef485d12e..9e5a0386b14 100644 --- a/src/uct/sm/self/self_md.c +++ b/src/uct/sm/self/self_md.c @@ -10,7 +10,7 @@ static ucs_status_t uct_self_md_query(uct_md_h md, uct_md_attr_t *attr) { /* Dummy memory registration provided. 
No real memory handling exists */ attr->cap.flags = UCT_MD_FLAG_REG; - attr->cap.addr_dn_mask = 0; + attr->cap.addr_dn = UCT_MD_ADDR_DOMAIN_DEFAULT; attr->cap.max_alloc = 0; attr->cap.max_reg = ULONG_MAX; attr->rkey_packed_size = 0; /* uct_md_query adds UCT_MD_COMPONENT_NAME_MAX to this */ diff --git a/src/uct/ugni/base/ugni_md.c b/src/uct/ugni/base/ugni_md.c index 9c8e5d5bd70..065c339680c 100644 --- a/src/uct/ugni/base/ugni_md.c +++ b/src/uct/ugni/base/ugni_md.c @@ -34,7 +34,7 @@ static ucs_status_t uct_ugni_md_query(uct_md_h md, uct_md_attr_t *md_attr) md_attr->cap.flags = UCT_MD_FLAG_REG | UCT_MD_FLAG_NEED_MEMH | UCT_MD_FLAG_NEED_RKEY; - md_attr->cap.addr_dn_mask = 0; + md_attr->cap.addr_dn = UCT_MD_ADDR_DOMAIN_DEFAULT; md_attr->cap.max_alloc = 0; md_attr->cap.max_reg = ULONG_MAX; md_attr->reg_cost.overhead = 1000.0e-9;