Skip to content

Commit

Permalink
Merge pull request openucx#15 from bureddy/refactor_mem_detect
Browse files Browse the repository at this point in the history
refactor mem address domain detection
  • Loading branch information
bureddy authored Sep 19, 2017
2 parents 841e4d0 + e006a7d commit 5007e2c
Show file tree
Hide file tree
Showing 19 changed files with 150 additions and 123 deletions.
11 changes: 7 additions & 4 deletions src/ucp/api/ucp_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,13 @@ typedef struct ucp_rkey *ucp_rkey_h;
*/
typedef struct ucp_mem *ucp_mem_h;

typedef struct ucp_addr_dn {
uint64_t mask;
} ucp_addr_dn_h;

/*
* @ingroup UCP_ADDR_DN
* @brief UCP Address Domain
*
* Address Domain handle is an opaque object representing a memory address domain.
*/
typedef struct ucp_addr_dn *ucp_addr_dn_h;

/**
* @ingroup UCP_MEM
Expand Down
38 changes: 27 additions & 11 deletions src/ucp/core/ucp_mm.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ static ucp_mem_t ucp_mem_dummy_handle = {
.md_map = 0
};

/*
 * Shared placeholder address-domain handle: empty MD map, id set to
 * UCT_MD_ADDR_DOMAIN_LAST. ucp_addr_domain_detect_mds() stores the address of
 * this object in its output handle before probing any MD, and ucp_dt_unpack()
 * compares a request's handle against this address to select the plain
 * memcpy path.
 * NOTE(review): this is a statically allocated object, so code that releases
 * address-domain handles presumably must not free it -- confirm at the
 * release sites.
 */
ucp_addr_dn_t ucp_addr_dn_dummy_handle = {
.md_map = 0,
.id = UCT_MD_ADDR_DOMAIN_LAST
};


/**
* Unregister memory from all memory domains.
* Save in *alloc_md_memh_p the memory handle of the allocating MD, if such exists.
Expand Down Expand Up @@ -107,28 +113,38 @@ static ucs_status_t ucp_memh_reg_mds(ucp_context_h context, ucp_mem_h memh,
}

ucs_status_t ucp_addr_domain_detect_mds(ucp_context_h context, void *addr,
ucp_addr_dn_h *addr_dn)
ucp_addr_dn_h *addr_dn_h)
{
ucs_status_t status;
unsigned md_index;
uint64_t dn_mask;
uct_addr_domain_t domain_id = UCT_MD_ADDR_DOMAIN_DEFAULT;

*addr_dn_h = &ucp_addr_dn_dummy_handle;

addr_dn->mask = 0;
/*TODO: return if no MDs with address domain detect */

for (md_index = 0; md_index < context->num_mds; ++md_index) {
if (context->tl_mds[md_index].attr.cap.flags & UCT_MD_FLAG_ADDR_DN) {
if(!(addr_dn->mask & context->tl_mds[md_index].attr.cap.addr_dn_mask)) {
dn_mask = 0;

status = uct_md_mem_detect(context->tl_mds[md_index].md, addr, &dn_mask);
if (status != UCS_OK) {
return status;
if (domain_id == UCT_MD_ADDR_DOMAIN_DEFAULT) {
status = uct_md_mem_detect(context->tl_mds[md_index].md, addr);
if (status == UCS_OK) {
domain_id = context->tl_mds[md_index].attr.cap.addr_dn;

*addr_dn_h = ucs_malloc(sizeof(ucp_addr_dn_t), "ucp_addr_dn_h");
if (*addr_dn_h == NULL) {
return UCS_ERR_NO_MEMORY;
}

(*addr_dn_h)->id = domain_id;
(*addr_dn_h)->md_map = UCS_BIT(md_index);
}
} else {
if (domain_id == context->tl_mds[md_index].attr.cap.addr_dn) {
(*addr_dn_h)->md_map |= UCS_BIT(md_index);
}
addr_dn->mask |= dn_mask;
}
}
}

return UCS_OK;
}
/**
Expand Down
18 changes: 17 additions & 1 deletion src/ucp/core/ucp_mm.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ typedef struct ucp_mem_desc {
} ucp_mem_desc_t;


/**
* Memory Address Domain descriptor.
* Contains domain information of the memory address it belongs to.
*/
typedef struct ucp_addr_dn {
ucp_md_map_t md_map; /* Which MDs own this address domain */
uct_addr_domain_t id; /* Address domain index */
} ucp_addr_dn_t;

void ucp_rkey_resolve_inner(ucp_rkey_h rkey, ucp_ep_h ep);

ucs_status_t ucp_mpool_malloc(ucs_mpool_t *mp, size_t *size_p, void **chunk_p);
Expand All @@ -72,8 +81,15 @@ void ucp_mpool_free(ucs_mpool_t *mp, void *chunk);

void ucp_mpool_obj_init(ucs_mpool_t *mp, void *obj, void *chunk);

/**
* Detects the address domain on all MDs. Skips detection on subsequent MDs
* once an MD has successfully detected the domain.
**/
ucs_status_t ucp_addr_domain_detect_mds(ucp_context_h context, void *addr,
ucp_addr_dn_h *addr_dn);
ucp_addr_dn_h *addr_dn_h);


extern ucp_addr_dn_t ucp_addr_dn_dummy_handle;

static UCS_F_ALWAYS_INLINE uct_mem_h
ucp_memh2uct(ucp_mem_h memh, ucp_md_index_t md_idx)
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/core/ucp_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ typedef void (*ucp_request_callback_t)(ucp_request_t *req);
struct ucp_request {
ucs_status_t status; /* Operation status */
uint16_t flags; /* Request flags */
ucp_addr_dn_h dn_mask; /* Memory domain mask */
ucp_addr_dn_h addr_dn_h; /* Memory domain handle */

union {
struct {
Expand Down
105 changes: 56 additions & 49 deletions src/ucp/dt/dt.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,62 +48,69 @@ size_t ucp_dt_pack(ucp_datatype_t datatype, void *dest, const void *src,


static UCS_F_ALWAYS_INLINE ucs_status_t ucp_dn_dt_unpack(ucp_request_t *req, void *buffer, size_t buffer_size,
const void *recv_data, size_t recv_length)
const void *recv_data, size_t recv_length)
{
ucp_worker_h worker = req->recv.worker;
ucp_context_h context = worker->context;
unsigned md_index;
ucs_status_t status;
ucp_ep_h ep = ucp_worker_ep_find(worker, worker->uuid);

for (md_index = 0; md_index < context->num_mds; md_index++) {
if (context->tl_mds[md_index].attr.cap.addr_dn_mask & req->dn_mask.mask) {
ucp_ep_h ep = ucp_worker_ep_find(worker, worker->uuid);
uct_mem_h memh;
uct_iov_t iov;
void *rkey_buffer;
size_t rkey_buffer_size;
ucp_rkey_h rkey;
ucp_lane_index_t lane;

status = uct_md_mem_reg(context->tl_mds[md_index].md, buffer, buffer_size,
0, &memh);
if (status != UCS_OK) {
uct_md_mem_dereg(context->tl_mds[md_index].md, memh);
ucs_error("Failed to reg address %p with md %s", buffer,
context->tl_mds[md_index].rsc.md_name);
return status;
}

ucp_rkey_pack(context, memh, &rkey_buffer, &rkey_buffer_size);
ucp_ep_rkey_unpack(ep, rkey_buffer, &rkey);
ucp_rkey_buffer_release(rkey_buffer);

iov.buffer = buffer;
iov.length = buffer_size;
iov.count = 1;
iov.memh = memh;

lane = rkey->cache.rma_lane;

status = uct_ep_put_zcopy(ep->uct_eps[lane], &iov, 1, (uint64_t)recv_data,
rkey->cache.rma_rkey, NULL);
if (status != UCS_OK) {
ucp_rkey_destroy(rkey);
uct_md_mem_dereg(context->tl_mds[md_index].md, memh);
ucs_error("Failed to perform uct_ep_put_zcopy to address %p", recv_data);
return status;
}

ucp_rkey_destroy(rkey);

status = uct_md_mem_dereg(context->tl_mds[md_index].md, memh);
if (status != UCS_OK) {
ucs_error("Failed to dereg address %p with md %s", buffer,
context->tl_mds[md_index].rsc.md_name);
return status;
}
break;

if (!(UCS_BIT(md_index) & req->addr_dn_h->md_map)) {
continue;
}

/*TODO check if put-zcopy there on iface */

uct_mem_h memh;
uct_iov_t iov;

// void *rkey_buffer;
// size_t rkey_buffer_size;
// ucp_rkey_h rkey;
// ucp_lane_index_t lane;

status = uct_md_mem_reg(context->tl_mds[md_index].md, buffer, buffer_size,
0, &memh);
if (status != UCS_OK) {
uct_md_mem_dereg(context->tl_mds[md_index].md, memh);
ucs_error("Failed to reg address %p with md %s", buffer,
context->tl_mds[md_index].rsc.md_name);
return status;
}

// ucp_rkey_pack(context, memh, &rkey_buffer, &rkey_buffer_size);
// ucp_ep_rkey_unpack(ep, rkey_buffer, &rkey);
// ucp_rkey_buffer_release(rkey_buffer);

ucs_assert(buffer_size >= recv_length);
iov.buffer = (void *)recv_data;
iov.length = recv_length;
iov.count = 1;
iov.memh = UCT_MEM_HANDLE_NULL;

//lane = rkey->cache.rma_lane;

status = uct_ep_put_zcopy(ep->uct_eps[0], &iov, 1, (uint64_t)buffer,
(uct_rkey_t )memh, NULL);
if (status != UCS_OK) {
// ucp_rkey_destroy(rkey);
uct_md_mem_dereg(context->tl_mds[md_index].md, memh);
ucs_error("Failed to perform uct_ep_put_zcopy to address %p", recv_data);
return status;
}

//ucp_rkey_destroy(rkey);

status = uct_md_mem_dereg(context->tl_mds[md_index].md, memh);
if (status != UCS_OK) {
ucs_error("Failed to dereg address %p with md %s", buffer,
context->tl_mds[md_index].rsc.md_name);
return status;
}
break;
}

return UCS_OK;
Expand All @@ -128,7 +135,7 @@ ucs_status_t ucp_dt_unpack(ucp_request_t *req, ucp_datatype_t datatype, void *bu

switch (datatype & UCP_DATATYPE_CLASS_MASK) {
case UCP_DATATYPE_CONTIG:
if (!(req->dn_mask.mask)) {
if (req->addr_dn_h == &ucp_addr_dn_dummy_handle) {
UCS_PROFILE_NAMED_CALL("memcpy_recv", memcpy, buffer + offset,
recv_data, recv_length);
return UCS_OK;
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/tag/tag_recv.c
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ ucp_tag_recv_request_init(ucp_request_t *req, ucp_worker_h worker, void* buffer,
req->recv.state.offset = 0;
req->recv.worker = worker;

ucp_addr_domain_detect_mds(worker->context, buffer, &(req->dn_mask));
ucp_addr_domain_detect_mds(worker->context, buffer, &(req->addr_dn_h));

switch (datatype & UCP_DATATYPE_CLASS_MASK) {
case UCP_DATATYPE_IOV:
Expand Down
28 changes: 9 additions & 19 deletions src/ucp/tag/tag_send.c
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ ucp_tag_send_req(ucp_request_t *req, size_t count, ssize_t max_short,
static void ucp_tag_send_req_init(ucp_request_t* req, ucp_ep_h ep,
const void* buffer, uintptr_t datatype,
ucp_tag_t tag, uint16_t flags,
ucp_addr_dn_h addr_dn)
ucp_addr_dn_h addr_dn_h)
{
req->flags = flags;
req->send.ep = ep;
Expand All @@ -223,7 +223,7 @@ static void ucp_tag_send_req_init(ucp_request_t* req, ucp_ep_h ep,
req->send.tag = tag;
req->send.reg_rsc = UCP_NULL_RESOURCE;
req->send.state.offset = 0;
req->dn_mask = addr_dn;
req->addr_dn_h = addr_dn_h;
VALGRIND_MAKE_MEM_UNDEFINED(&req->send.uct_comp.count,
sizeof(req->send.uct_comp.count));

Expand All @@ -241,18 +241,16 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_send_nb,
ucp_request_t *req;
size_t length;
ucs_status_ptr_t ret;
ucp_addr_dn_h addr_dn;
uint64_t mask;
unsigned md_index;
ucp_addr_dn_h addr_dn_h;

UCP_THREAD_CS_ENTER_CONDITIONAL(&ep->worker->mt_lock);

ucs_trace_req("send_nb buffer %p count %zu tag %"PRIx64" to %s cb %p",
buffer, count, tag, ucp_ep_peer_name(ep), cb);

ucp_addr_domain_detect_mds(ep->worker->context, (void *)buffer, &addr_dn);
ucp_addr_domain_detect_mds(ep->worker->context, (void *)buffer, &addr_dn_h);

if (!addr_dn.mask && ucs_likely(UCP_DT_IS_CONTIG(datatype))) {
if (addr_dn_h->id == UCT_MD_ADDR_DOMAIN_DEFAULT && ucs_likely(UCP_DT_IS_CONTIG(datatype))) {
length = ucp_contig_dt_length(datatype, count);
if (ucs_likely((ssize_t)length <= ucp_ep_config(ep)->tag.eager.max_short)) {
status = UCS_PROFILE_CALL(ucp_tag_send_eager_short, ep, tag, buffer,
Expand All @@ -271,17 +269,11 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_send_nb,
goto out;
}

ucp_tag_send_req_init(req, ep, buffer, datatype, tag, 0, addr_dn);

mask = addr_dn.mask;
md_index = 0;
if (addr_dn.mask) {
CONVERT_BITMASK_TO_INDEX(mask, md_index);
}
ucp_tag_send_req_init(req, ep, buffer, datatype, tag, 0, addr_dn_h);

ret = ucp_tag_send_req(req, count,
addr_dn.mask ? ucp_ep_config(ep)->dn[md_index].tag.eager.max_short : ucp_ep_config(ep)->tag.eager.max_short,
addr_dn.mask ? ucp_ep_config(ep)->dn[md_index].tag.eager.zcopy_thresh : ucp_ep_config(ep)->tag.eager.zcopy_thresh,
addr_dn_h->id != UCT_MD_ADDR_DOMAIN_DEFAULT ? ucp_ep_config(ep)->dn[addr_dn_h->id].tag.eager.max_short : ucp_ep_config(ep)->tag.eager.max_short,
addr_dn_h->id != UCT_MD_ADDR_DOMAIN_DEFAULT ? ucp_ep_config(ep)->dn[addr_dn_h->id].tag.eager.zcopy_thresh : ucp_ep_config(ep)->tag.eager.zcopy_thresh,
ucp_ep_config(ep)->tag.rndv.rma_thresh,
ucp_ep_config(ep)->tag.rndv.am_thresh,
cb, ucp_ep_config(ep)->tag.proto);
Expand All @@ -297,7 +289,6 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_send_sync_nb,
{
ucp_request_t *req;
ucs_status_ptr_t ret;
ucp_addr_dn_h addr_dn;

UCP_THREAD_CS_ENTER_CONDITIONAL(&ep->worker->mt_lock);

Expand All @@ -318,10 +309,9 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_send_sync_nb,
/* Remote side needs to send reply, so have it connect to us */
ucp_ep_connect_remote(ep);

ucp_addr_domain_detect_mds(ep->worker->context, (void *)buffer, &addr_dn);

ucp_tag_send_req_init(req, ep, buffer, datatype, tag, UCP_REQUEST_FLAG_SYNC,
addr_dn);
&ucp_addr_dn_dummy_handle);

ret = ucp_tag_send_req(req, count,
-1, /* disable short method */
Expand Down
17 changes: 10 additions & 7 deletions src/uct/api/uct.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,12 @@ enum {
* @ingroup UCT_MD
* @brief Memory address domains.
*/
enum {
UCT_MD_ADDR_DOMAIN_CUDA = UCS_BIT(0) /**< NVIDIA CUDA domain */
};
typedef enum {
UCT_MD_ADDR_DOMAIN_CUDA = 0, /**< NVIDIA CUDA domain */
UCT_MD_ADDR_DOMAIN_DEFAULT, /**< Default system domain */
UCT_MD_ADDR_DOMAIN_LAST = UCT_MD_ADDR_DOMAIN_DEFAULT /**< Same value as UCT_MD_ADDR_DOMAIN_DEFAULT */

} uct_addr_domain_t;


/**
Expand Down Expand Up @@ -533,7 +536,7 @@ struct uct_md_attr {
size_t max_alloc; /**< Maximal allocation size */
size_t max_reg; /**< Maximal registration size */
uint64_t flags; /**< UCT_MD_FLAG_xx */
uint64_t addr_dn_mask; /**< Supported addr domains */
uct_addr_domain_t addr_dn; /**< Supported addr domain */
struct {
size_t max_short;
} eager;
Expand Down Expand Up @@ -1278,13 +1281,13 @@ ucs_status_t uct_md_mem_dereg(uct_md_h md, uct_mem_h memh);
* @ingroup UCT_MD
* @brief Detect memory on the memory domain.
*
* Detect memory on the memory domain. Return memory domain in domain mask.
* Detect memory on the memory domain.
* Returns UCS_OK if the address belongs to the MD's address domain.
*
* @param [in] md Memory domain to detect memory on.
* @param [in] addr Memory address to detect.
* @param [out] dn_mask Filled with memory domain mask.
*/
ucs_status_t uct_md_mem_detect(uct_md_h md, void *addr, uint64_t *dn_mask);
ucs_status_t uct_md_mem_detect(uct_md_h md, void *addr);

/**
* @ingroup UCT_MD
Expand Down
4 changes: 2 additions & 2 deletions src/uct/base/uct_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ ucs_status_t uct_md_mem_dereg(uct_md_h md, uct_mem_h memh)
return md->ops->mem_dereg(md, memh);
}

ucs_status_t uct_md_mem_detect(uct_md_h md, void *addr, uint64_t *dn_mask)
ucs_status_t uct_md_mem_detect(uct_md_h md, void *addr)
{
return md->ops->mem_detect(md, addr, dn_mask);
return md->ops->mem_detect(md, addr);
}
Loading

0 comments on commit 5007e2c

Please sign in to comment.