Skip to content

Commit

Permalink
Merge pull request open-mpi#10 from shizhibao/huawei
Browse files Browse the repository at this point in the history
Support allreduce non-contiguous datatype
  • Loading branch information
nsosnsos authored Dec 10, 2020
2 parents 9a64c16 + 3409554 commit e5c8471
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 19 deletions.
11 changes: 11 additions & 0 deletions ompi/mca/coll/ucx/coll_ucx_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,16 @@ static int mca_coll_ucg_datatype_convert(ompi_datatype_t *mpi_dt,
return 0;
}

static ptrdiff_t coll_ucx_datatype_span(void *dt_ext, int count, ptrdiff_t *gap)
{
struct ompi_datatype_t *dtype = (struct ompi_datatype_t *)dt_ext;
ptrdiff_t dsize, gp= 0;

dsize = opal_datatype_span(&dtype->super, count, &gp);
*gap = gp;
return dsize;
}

static void mca_coll_ucg_init_group_param(struct ompi_communicator_t *comm, ucg_group_params_t *args)
{
args->member_count = ompi_comm_size(comm);
Expand All @@ -341,6 +351,7 @@ static void mca_coll_ucg_init_group_param(struct ompi_communicator_t *comm, ucg_
args->cb_group_obj = comm;
args->op_is_commute_f = ompi_op_is_commute;
args->mpi_dt_convert = mca_coll_ucg_datatype_convert;
args->mpi_datatype_span = coll_ucx_datatype_span;
}

static void mca_coll_ucg_arg_free(struct ompi_communicator_t *comm, ucg_group_params_t *args)
Expand Down
56 changes: 37 additions & 19 deletions ompi/mca/coll/ucx/coll_ucx_op.c
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,36 @@ int mca_coll_ucx_start(size_t count, ompi_request_t** requests)
((char *)alloca(mca_coll_ucx_component.request_size) + \
mca_coll_ucx_component.request_size);

static int coll_ucx_allreduce_pre_init(struct ompi_datatype_t *dtype, int count, void *sbuf,
void *rbuf, char **inplace_buff, ptrdiff_t *gap)
{
ptrdiff_t dsize, gp, lb = 0;
char *inpbuf = NULL;
int err;

ompi_datatype_type_lb(dtype, &lb);
if ((dtype->super.flags & OPAL_DATATYPE_FLAG_CONTIGUOUS) &&
(dtype->super.flags & OPAL_DATATYPE_FLAG_NO_GAPS) &&
(lb == 0)) {
return UCS_OK;
}

dsize = opal_datatype_span(&dtype->super, count, &gp);
if (sbuf == MPI_IN_PLACE && dsize != 0) {
inpbuf = (char *)malloc(dsize);
if (inpbuf == NULL) {
return UCS_ERR_NO_MEMORY;
}
*inplace_buff = inpbuf;
*gap = gp;
err = ompi_datatype_copy_content_same_ddt(dtype, count, inpbuf - gp, (char *)rbuf);
} else {
err = ompi_datatype_copy_content_same_ddt(dtype, count, (char *)rbuf, (char *)sbuf);
}

return (err == MPI_SUCCESS) ? UCS_OK : UCS_ERR_INVALID_PARAM;
}

int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count,
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
struct ompi_communicator_t *comm,
Expand All @@ -104,8 +134,8 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count,
mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t *)module;
char *inplace_buff = NULL;
ucg_coll_h coll = NULL;
ptrdiff_t extent, dsize, gap = 0;
int err;
ptrdiff_t extent, gap = 0;
char *sbuf_rel = NULL;

ompi_datatype_type_extent(dtype, &extent);
ucs_status_t ret = mca_coll_ucx_check_total_data_size((size_t)extent, count);
Expand All @@ -114,29 +144,17 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count,
return OMPI_ERROR;
}

dsize = opal_datatype_span(&dtype->super, count, &gap);
if (sbuf == MPI_IN_PLACE && dsize != 0) {
inplace_buff = (char *)malloc(dsize);
if (inplace_buff == NULL) {
return OMPI_ERR_OUT_OF_RESOURCE;
}
sbuf = inplace_buff - gap;
err = ompi_datatype_copy_content_same_ddt(dtype, count, (char *)sbuf, (char *)rbuf);
} else {
err = ompi_datatype_copy_content_same_ddt(dtype, count, (char *)rbuf, (char *)sbuf);
}
if (err != MPI_SUCCESS) {
if (inplace_buff != NULL) {
free(inplace_buff);
}
return err;
ret = coll_ucx_allreduce_pre_init(dtype, count, sbuf, rbuf, &inplace_buff, &gap);
if (ret != UCS_OK) {
goto exit;
}
sbuf_rel = (inplace_buff == NULL) ? sbuf : inplace_buff - gap;

COLL_UCX_TRACE("%s", sbuf, rbuf, count, dtype, comm, "allreduce START");

ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module);

ret = ucg_coll_allreduce_init(sbuf, rbuf, count, (size_t)extent, dtype, ucx_module->ucg_group, 0,
ret = ucg_coll_allreduce_init(sbuf_rel, rbuf, count, (size_t)extent, dtype, ucx_module->ucg_group, 0,
op, 0, 0, &coll);
if (OPAL_UNLIKELY(ret != UCS_OK)) {
COLL_UCX_ERROR("ucx allreduce init failed: %s", ucs_status_string(ret));
Expand Down

0 comments on commit e5c8471

Please sign in to comment.