From 3409554f939fe919a221993ee0c41308bcce782b Mon Sep 17 00:00:00 2001 From: shizhibao Date: Thu, 10 Dec 2020 10:52:20 +0800 Subject: [PATCH] Support allreduce non-contiguous datatype --- ompi/mca/coll/ucx/coll_ucx_module.c | 11 ++++++ ompi/mca/coll/ucx/coll_ucx_op.c | 56 +++++++++++++++++++---------- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/ompi/mca/coll/ucx/coll_ucx_module.c b/ompi/mca/coll/ucx/coll_ucx_module.c index ca11f7ae585..09c996b1974 100644 --- a/ompi/mca/coll/ucx/coll_ucx_module.c +++ b/ompi/mca/coll/ucx/coll_ucx_module.c @@ -331,6 +331,16 @@ static int mca_coll_ucg_datatype_convert(ompi_datatype_t *mpi_dt, return 0; } +static ptrdiff_t coll_ucx_datatype_span(void *dt_ext, int count, ptrdiff_t *gap) +{ + struct ompi_datatype_t *dtype = (struct ompi_datatype_t *)dt_ext; + ptrdiff_t dsize, gp= 0; + + dsize = opal_datatype_span(&dtype->super, count, &gp); + *gap = gp; + return dsize; +} + static void mca_coll_ucg_init_group_param(struct ompi_communicator_t *comm, ucg_group_params_t *args) { args->member_count = ompi_comm_size(comm); @@ -341,6 +351,7 @@ static void mca_coll_ucg_init_group_param(struct ompi_communicator_t *comm, ucg_ args->cb_group_obj = comm; args->op_is_commute_f = ompi_op_is_commute; args->mpi_dt_convert = mca_coll_ucg_datatype_convert; + args->mpi_datatype_span = coll_ucx_datatype_span; } static void mca_coll_ucg_arg_free(struct ompi_communicator_t *comm, ucg_group_params_t *args) diff --git a/ompi/mca/coll/ucx/coll_ucx_op.c b/ompi/mca/coll/ucx/coll_ucx_op.c index a6d2ab4cfb3..f14c4481f34 100644 --- a/ompi/mca/coll/ucx/coll_ucx_op.c +++ b/ompi/mca/coll/ucx/coll_ucx_op.c @@ -96,6 +96,36 @@ int mca_coll_ucx_start(size_t count, ompi_request_t** requests) ((char *)alloca(mca_coll_ucx_component.request_size) + \ mca_coll_ucx_component.request_size); +static int coll_ucx_allreduce_pre_init(struct ompi_datatype_t *dtype, int count, void *sbuf, + void *rbuf, char **inplace_buff, ptrdiff_t *gap) +{ + ptrdiff_t dsize, gp, lb = 0; + char *inpbuf = NULL; + int err; + + ompi_datatype_type_lb(dtype, &lb); + if ((dtype->super.flags & OPAL_DATATYPE_FLAG_CONTIGUOUS) && + (dtype->super.flags & OPAL_DATATYPE_FLAG_NO_GAPS) && + (lb == 0)) { + return UCS_OK; + } + + dsize = opal_datatype_span(&dtype->super, count, &gp); + if (sbuf == MPI_IN_PLACE && dsize != 0) { + inpbuf = (char *)malloc(dsize); + if (inpbuf == NULL) { + return UCS_ERR_NO_MEMORY; + } + *inplace_buff = inpbuf; + *gap = gp; + err = ompi_datatype_copy_content_same_ddt(dtype, count, inpbuf - gp, (char *)rbuf); + } else { + err = ompi_datatype_copy_content_same_ddt(dtype, count, (char *)rbuf, (char *)sbuf); + } + + return (err == MPI_SUCCESS) ? UCS_OK : UCS_ERR_INVALID_PARAM; +} + int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, @@ -104,8 +134,8 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count, mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t *)module; char *inplace_buff = NULL; ucg_coll_h coll = NULL; - ptrdiff_t extent, dsize, gap = 0; - int err; + ptrdiff_t extent, gap = 0; + char *sbuf_rel = NULL; ompi_datatype_type_extent(dtype, &extent); ucs_status_t ret = mca_coll_ucx_check_total_data_size((size_t)extent, count); @@ -114,29 +144,17 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count, return OMPI_ERROR; } - dsize = opal_datatype_span(&dtype->super, count, &gap); - if (sbuf == MPI_IN_PLACE && dsize != 0) { - inplace_buff = (char *)malloc(dsize); - if (inplace_buff == NULL) { - return OMPI_ERR_OUT_OF_RESOURCE; - } - sbuf = inplace_buff - gap; - err = ompi_datatype_copy_content_same_ddt(dtype, count, (char *)sbuf, (char *)rbuf); - } else { - err = ompi_datatype_copy_content_same_ddt(dtype, count, (char *)rbuf, (char *)sbuf); - } - if (err != MPI_SUCCESS) { - if (inplace_buff != NULL) { - free(inplace_buff); - } - return err; + ret = coll_ucx_allreduce_pre_init(dtype, count, sbuf, rbuf, &inplace_buff, &gap); + if (ret != UCS_OK) { + goto exit; } + sbuf_rel = (inplace_buff == NULL) ? sbuf : inplace_buff - gap; COLL_UCX_TRACE("%s", sbuf, rbuf, count, dtype, comm, "allreduce START"); ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module); - ret = ucg_coll_allreduce_init(sbuf, rbuf, count, (size_t)extent, dtype, ucx_module->ucg_group, 0, + ret = ucg_coll_allreduce_init(sbuf_rel, rbuf, count, (size_t)extent, dtype, ucx_module->ucg_group, 0, op, 0, 0, &coll); if (OPAL_UNLIKELY(ret != UCS_OK)) { COLL_UCX_ERROR("ucx allreduce init failed: %s", ucs_status_string(ret));