diff --git a/ompi/mca/coll/acoll/LICENSE.md b/ompi/mca/coll/acoll/LICENSE.md new file mode 100644 index 00000000000..bc36e3a3a48 --- /dev/null +++ b/ompi/mca/coll/acoll/LICENSE.md @@ -0,0 +1,11 @@ +Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/ompi/mca/coll/acoll/Makefile.am b/ompi/mca/coll/acoll/Makefile.am new file mode 100644 index 00000000000..fdbd7edbbd2 --- /dev/null +++ b/ompi/mca/coll/acoll/Makefile.am @@ -0,0 +1,45 @@ +# +# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +AM_CPPFLAGS = $(coll_acoll_CPPFLAGS) + +sources = \ + coll_acoll.h \ + coll_acoll_utils.h \ + coll_acoll_allgather.c \ + coll_acoll_bcast.c \ + coll_acoll_gather.c \ + coll_acoll_reduce.c \ + coll_acoll_allreduce.c \ + coll_acoll_barrier.c \ + coll_acoll_component.c \ + coll_acoll_module.c + +# Make the output library in this directory, and name it either +# mca__.la (for DSO builds) or libmca__.la +# (for static builds). + +if MCA_BUILD_ompi_coll_acoll_DSO +component_noinst = +component_install = mca_coll_acoll.la +else +component_noinst = libmca_coll_acoll.la +component_install = +endif + +mcacomponentdir = $(ompilibdir) +mcacomponent_LTLIBRARIES = $(component_install) +mca_coll_acoll_la_SOURCES = $(sources) +mca_coll_acoll_la_LDFLAGS = -module -avoid-version $(coll_acoll_LDFLAGS) +mca_coll_acoll_la_LIBADD = $(top_builddir)/ompi/lib@OMPI_LIBMPI_NAME@.la $(coll_acoll_LIBS) + +noinst_LTLIBRARIES = $(component_noinst) +libmca_coll_acoll_la_SOURCES =$(sources) +libmca_coll_acoll_la_LIBADD = $(coll_acoll_LIBS) +libmca_coll_acoll_la_LDFLAGS = -module -avoid-version $(coll_acoll_LDFLAGS) diff --git a/ompi/mca/coll/acoll/README b/ompi/mca/coll/acoll/README new file mode 100644 index 00000000000..d5b5acae8f1 --- /dev/null +++ b/ompi/mca/coll/acoll/README @@ -0,0 +1,16 @@ +Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+
+$COPYRIGHT$
+
+Additional copyrights may follow
+
+$HEADER$
+
+===========================================================================
+
+AMD Coll (“acoll”) is a high-performance MPI collective component for the Open MPI library, optimized for AMD “Zen”-based processors. “acoll” targets communications within a single node of AMD “Zen”-based processors and provides the following commonly used collective algorithms: broadcast (MPI_Bcast), allreduce (MPI_Allreduce), reduce (MPI_Reduce), gather (MPI_Gather), allgather (MPI_Allgather), and barrier (MPI_Barrier).
+
+At present, “acoll” has been tested with Open MPI v5.0.2 and can be built as part of Open MPI.
+
+To run an application with acoll, use the following command line parameters:
+- mpirun --mca coll acoll,tuned,libnbc,basic --mca coll_acoll_priority 40 diff --git a/ompi/mca/coll/acoll/coll_acoll.h b/ompi/mca/coll/acoll/coll_acoll.h new file mode 100644 index 00000000000..ec2da029a6d --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll.h @@ -0,0 +1,227 @@
+/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil -*- */
+/*
+ * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#ifndef MCA_COLL_ACOLL_EXPORT_H
+#define MCA_COLL_ACOLL_EXPORT_H
+
+#include "ompi_config.h"
+
+#include "mpi.h"
+#include "ompi/communicator/communicator.h"
+#include "ompi/mca/coll/base/coll_base_functions.h"
+#include "ompi/mca/coll/coll.h"
+#include "ompi/mca/mca.h"
+#include "ompi/request/request.h"
+
+#ifdef HAVE_XPMEM_H
+#include "opal/mca/rcache/base/base.h"
+#include <xpmem.h>
+#endif
+
+#include "opal/mca/shmem/base/base.h"
+#include "opal/mca/shmem/shmem.h"
+
+BEGIN_C_DECLS
+
+/* Globally exported variables */
+OMPI_DECLSPEC extern const mca_coll_base_component_2_4_0_t mca_coll_acoll_component;
+extern int mca_coll_acoll_priority;
+extern int mca_coll_acoll_sg_size;
+extern int mca_coll_acoll_sg_scale;
+extern int mca_coll_acoll_node_size;
+extern int mca_coll_acoll_use_dynamic_rules;
+extern int mca_coll_acoll_mnode_enable;
+extern int mca_coll_acoll_bcast_lin0;
+extern int mca_coll_acoll_bcast_lin1;
+extern int mca_coll_acoll_bcast_lin2;
+extern int mca_coll_acoll_bcast_nonsg;
+extern int mca_coll_acoll_allgather_lin;
+extern int mca_coll_acoll_allgather_ring_1;
+
+/* API functions */
+int mca_coll_acoll_init_query(bool enable_progress_threads, bool enable_mpi_threads);
+mca_coll_base_module_t *mca_coll_acoll_comm_query(struct ompi_communicator_t *comm, int *priority);
+
+int mca_coll_acoll_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm);
+
+int mca_coll_acoll_allgather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
+                             void *rbuf, int rcount, struct ompi_datatype_t *rdtype,
+                             struct ompi_communicator_t *comm, mca_coll_base_module_t *module);
+
+int mca_coll_acoll_bcast(void *buff, int count, struct ompi_datatype_t *datatype, int root,
+                         struct ompi_communicator_t *comm, mca_coll_base_module_t *module);
+
+int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
+                                void *rbuf, int rcount, struct ompi_datatype_t *rdtype, int root,
+                                struct ompi_communicator_t *comm, mca_coll_base_module_t *module);
+
+int mca_coll_acoll_reduce_intra(const void *sbuf, void *rbuf, int count,
+                                struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root,
+                                struct ompi_communicator_t *comm, mca_coll_base_module_t *module);
+
+int
mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module); + +int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base_module_t *module); + +int mca_coll_acoll_ft_event(int status); + +END_C_DECLS + +#define MCA_COLL_ACOLL_MAX_CID 100 +#define MCA_COLL_ACOLL_ROOT_CHANGE_THRESH 10 + +typedef enum MCA_COLL_ACOLL_SG_SIZES { + MCA_COLL_ACOLL_SG_SIZE_1 = 8, + MCA_COLL_ACOLL_SG_SIZE_2 = 16 +} MCA_COLL_ACOLL_SG_SIZES; + +typedef enum MCA_COLL_ACOLL_SG_SCALES { + MCA_COLL_ACOLL_SG_SCALE_1 = 1, + MCA_COLL_ACOLL_SG_SCALE_2 = 2, + MCA_COLL_ACOLL_SG_SCALE_3 = 4, + MCA_COLL_ACOLL_SG_SCALE_4 = 8, + MCA_COLL_ACOLL_SG_SCALE_5 = 16 +} MCA_COLL_ACOLL_SG_SCALES; + +typedef enum MCA_COLL_ACOLL_SUBCOMMS { + MCA_COLL_ACOLL_NODE_L = 0, + MCA_COLL_ACOLL_INTRA, + MCA_COLL_ACOLL_SOCK_L, + MCA_COLL_ACOLL_NUMA_L, + MCA_COLL_ACOLL_L3_L, + MCA_COLL_ACOLL_LEAF, + MCA_COLL_ACOLL_NUM_SC +} MCA_COLL_ACOLL_SUBCOMMS; + +typedef enum MCA_COLL_ACOLL_LAYERS { + MCA_COLL_ACOLL_LYR_NODE = 0, + MCA_COLL_ACOLL_LYR_SOCKET, + MCA_COLL_ACOLL_NUM_LAYERS +} MCA_COLL_ACOLL_LAYERS; + +typedef enum MCA_COLL_ACOLL_BASE_LYRS { + MCA_COLL_ACOLL_L3CACHE = 0, + MCA_COLL_ACOLL_NUMA, + MCA_COLL_ACOLL_NUM_BASE_LYRS +} MCA_COLL_ACOLL_BASE_LYRS; + +typedef struct coll_acoll_data { +#ifdef HAVE_XPMEM_H + xpmem_segid_t *allseg_id; + xpmem_apid_t *all_apid; + void **allshm_sbuf; + void **allshm_rbuf; + void **xpmem_saddr; + void **xpmem_raddr; + mca_rcache_base_module_t **rcache; + void *scratch; +#endif + opal_shmem_ds_t *allshmseg_id; + void **allshmmmap_sbuf; + + int comm_size; + int l1_local_rank; + int l2_local_rank; + int l1_gp_size; + int *l1_gp; + int *l2_gp; + int l2_gp_size; + int offset[4]; + int sync[2]; +} coll_acoll_data_t; + +typedef struct coll_acoll_subcomms { + ompi_communicator_t *local_comm; + ompi_communicator_t *local_r_comm; + ompi_communicator_t *leader_comm; + ompi_communicator_t *subgrp_comm; + ompi_communicator_t *numa_comm; + ompi_communicator_t *base_comm[MCA_COLL_ACOLL_NUM_BASE_LYRS][MCA_COLL_ACOLL_NUM_LAYERS]; + ompi_communicator_t *orig_comm; + ompi_communicator_t *socket_comm; + ompi_communicator_t *socket_ldr_comm; + int num_nodes; + int derived_node_size; + int is_root_node; + int is_root_sg; + int is_root_numa; + int is_root_socket; + int local_root[MCA_COLL_ACOLL_NUM_LAYERS]; + int outer_grp_root; + int subgrp_root; + int numa_root; + int socket_ldr_root; + int base_root[MCA_COLL_ACOLL_NUM_BASE_LYRS][MCA_COLL_ACOLL_NUM_LAYERS]; + int base_rank[MCA_COLL_ACOLL_NUM_BASE_LYRS]; + int socket_rank; + int subgrp_size; + int initialized; + int prev_init_root; + int num_root_change; + + ompi_communicator_t *numa_comm_ldrs; + ompi_communicator_t *node_comm; + ompi_communicator_t *inter_comm; + int cid; + coll_acoll_data_t *data; + bool initialized_data; + bool initialized_shm_data; +#ifdef HAVE_XPMEM_H + uint64_t xpmem_buf_size; + int without_xpmem; + int xpmem_use_sr_buf; +#endif + +} coll_acoll_subcomms_t; + +typedef struct coll_acoll_reserve_mem { + void *reserve_mem; + uint64_t reserve_mem_size; + bool reserve_mem_allocate; + bool reserve_mem_in_use; +} coll_acoll_reserve_mem_t; + +struct mca_coll_acoll_module_t { + mca_coll_base_module_t super; + MCA_COLL_ACOLL_SG_SIZES sg_size; + MCA_COLL_ACOLL_SG_SCALES sg_scale; + int sg_cnt; + // Todo: Remove log2 variables + int log2_sg_cnt; + int node_cnt; + int log2_node_cnt; + int use_dyn_rules; + // Todo: Use 
substructure for every API related ones + int use_mnode; + int use_lin0; + int use_lin1; + int use_lin2; + int mnode_sg_size; + int mnode_log2_sg_size; + int allg_lin; + int allg_ring; + coll_acoll_subcomms_t subc[MCA_COLL_ACOLL_MAX_CID]; + coll_acoll_reserve_mem_t reserve_mem_s; +}; + +#ifdef HAVE_XPMEM_H +struct acoll_xpmem_rcache_reg_t { + mca_rcache_base_registration_t base; + void *xpmem_vaddr; +}; +#endif + +typedef struct mca_coll_acoll_module_t mca_coll_acoll_module_t; +OMPI_DECLSPEC OBJ_CLASS_DECLARATION(mca_coll_acoll_module_t); + +#endif /* MCA_COLL_ACOLL_EXPORT_H */ diff --git a/ompi/mca/coll/acoll/coll_acoll_allgather.c b/ompi/mca/coll/acoll/coll_acoll_allgather.c new file mode 100644 index 00000000000..36f9fe2b1a6 --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_allgather.c @@ -0,0 +1,624 @@ +/* -*- Mode: C; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + + +#include "mpi.h" +#include "ompi/constants.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/mca/coll/base/coll_base_functions.h" +#include "ompi/mca/coll/base/coll_base_util.h" +#include "ompi/mca/coll/base/coll_tags.h" +#include "ompi/mca/coll/coll.h" +#include "ompi/mca/pml/pml.h" +#include "opal/util/bit_ops.h" +#include "coll_acoll.h" +#include "coll_acoll_utils.h" + +static inline int log_sg_bcast_intra(void *buff, int count, struct ompi_datatype_t *datatype, + int rank, int dim, int size, int sg_size, int cur_base, + int sg_start, struct ompi_communicator_t *comm, + mca_coll_base_module_t *module, ompi_request_t **preq, + int *nreqs) +{ + int msb_pos, sub_rank, peer, err; + int i, mask; + int end_sg, end_peer; + + end_sg = sg_start + sg_size - 1; + if (end_sg >= size) { + end_sg = size - 1; + } + end_peer = (end_sg - cur_base) % sg_size; + sub_rank = (rank - cur_base + sg_size) % sg_size; + + msb_pos = opal_hibit(sub_rank, dim); + --dim; + + /* Receive data from parent in the sg tree. */ + if (sub_rank > 0) { + assert(msb_pos >= 0); + peer = (sub_rank & ~(1 << msb_pos)); + if (peer > end_peer) { + peer = (((peer + cur_base - sg_start) % sg_size) + sg_start); + } else { + peer = peer + cur_base; + } + + err = MCA_PML_CALL( + recv(buff, count, datatype, peer, MCA_COLL_BASE_TAG_ALLGATHER, comm, MPI_STATUS_IGNORE)); + if (MPI_SUCCESS != err) { + return err; + } + } + + for (i = msb_pos + 1, mask = 1 << i; i <= dim; ++i, mask <<= 1) { + peer = sub_rank | mask; + if (peer >= sg_size) { + continue; + } + if (peer >= end_peer) { + peer = (((peer + cur_base - sg_start) % sg_size) + sg_start); + } else { + peer = peer + cur_base; + } + /* Checks to ensure that the sends are limited to the necessary ones. + It also ensures 'preq' not exceeding the max allocated. 
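+       For example (illustrative numbers only, ignoring the end-of-subgroup wrap-around handled above): with sg_size = 8, a rank with sub_rank = 5 (binary 101) receives from sub_rank 1 (its MSB cleared) and issues no further sends, while sub_rank 0 sends to sub_ranks 1, 2 and 4 in successive steps.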
*/ + if ((peer < size) && (peer != rank) && (peer != cur_base)) { + *nreqs = *nreqs + 1; + err = MCA_PML_CALL(isend(buff, count, datatype, peer, MCA_COLL_BASE_TAG_ALLGATHER, + MCA_PML_BASE_SEND_STANDARD, comm, preq++)); + if (MPI_SUCCESS != err) { + return err; + } + } + } + + return err; +} + +static inline int lin_sg_bcast_intra(void *buff, int count, struct ompi_datatype_t *datatype, + int rank, int dim, int size, int sg_size, int cur_base, + int sg_start, struct ompi_communicator_t *comm, + mca_coll_base_module_t *module, ompi_request_t **preq, + int *nreqs) +{ + int peer; + int err; + int sg_end; + + sg_end = sg_start + sg_size - 1; + if (sg_end >= size) { + sg_end = size - 1; + } + + if (rank == cur_base) { + for (peer = sg_start; peer <= sg_end; peer++) { + if (peer == cur_base) { + continue; + } + *nreqs = *nreqs + 1; + err = MCA_PML_CALL(isend(buff, count, datatype, peer, MCA_COLL_BASE_TAG_ALLGATHER, + MCA_PML_BASE_SEND_STANDARD, comm, preq++)); + if (MPI_SUCCESS != err) { + return err; + } + } + } else { + err = MCA_PML_CALL(recv(buff, count, datatype, cur_base, MCA_COLL_BASE_TAG_ALLGATHER, comm, + MPI_STATUS_IGNORE)); + if (MPI_SUCCESS != err) { + return err; + } + } + + return err; +} + +/* + * sg_bcast_intra + * + * Function: broadcast operation within a subgroup + * Accepts: Arguments of MPI_Bcast() plus subgroup params + * Returns: MPI_SUCCESS or error code + * + * Description: O(N) or O(log(N)) algorithm based on count. + * + * Memory: No additional memory requirements beyond user-supplied buffers. + * + */ +static inline int sg_bcast_intra(void *buff, int count, struct ompi_datatype_t *datatype, int rank, + int dim, int size, int sg_size, int cur_base, int sg_start, + struct ompi_communicator_t *comm, mca_coll_base_module_t *module, + ompi_request_t **preq, int *nreqs) +{ + int err; + size_t total_dsize, dsize; + + ompi_datatype_type_size(datatype, &dsize); + total_dsize = dsize * (unsigned long) count; + + if (total_dsize <= 8192) { + err = log_sg_bcast_intra(buff, count, datatype, rank, dim, size, sg_size, cur_base, + sg_start, comm, module, preq, nreqs); + } else { + err = lin_sg_bcast_intra(buff, count, datatype, rank, dim, size, sg_size, cur_base, + sg_start, comm, module, preq, nreqs); + } + return err; +} + +/* + * coll_allgather_decision_fixed + * + * Function: Choose optimal allgather algorithm + * + * Description: Based on no. of processes and message size, chooses whether + * or not to use subgroups. If subgroup based algorithm is not, + * chosen, further decides if [ring|lin] allgather is to be used. + * + */ +static inline void coll_allgather_decision_fixed(int size, size_t total_dsize, int sg_size, + int *use_ring, int *use_lin) +{ + *use_ring = 0; + *use_lin = 0; + if (size <= (sg_size << 1)) { + if (total_dsize >= 1048576) { + *use_lin = 1; + } + } else if (size <= (sg_size << 2)) { + if ((total_dsize >= 4096) && (total_dsize < 32768)) { + *use_ring = 1; + } else if (total_dsize >= 1048576) { + *use_lin = 1; + } + } else if (size <= (sg_size << 3)) { + if ((total_dsize >= 4096) && (total_dsize < 32768)) { + *use_ring = 1; + } + } else { + if (total_dsize >= 4096) { + *use_ring = 1; + } + } +} + +/* + * rd_allgather_sub + * + * Function: Uses recursive doubling based allgather for the group. + * Group can be all ranks in a subgroup or base ranks across + * subgroups. + * + * Description: Implementation logic of recursive doubling reused from + * ompi_coll_base_allgather_intra_recursivedoubling(). 
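+ * For example, with grp_size = 4, rank r exchanges blocks with r ^ 1 in step 0 and with r ^ 2 in step 1; the exchanged count doubles each step (count << i), so after log2(grp_size) steps every participant holds all grp_size blocks.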
+ * + */ +static inline int rd_allgather_sub(void *rbuf, struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, int count, int send_blk_loc, + int rank, int virtual_rank, int grp_size, const int across_sg, + int sg_start, int sg_size, ptrdiff_t rext) +{ + int err; + /* At step i, rank r exchanges message with rank (r ^ 2^i) */ + for (int dist = 0x1, i = 0; dist < grp_size; dist <<= 1, i++) { + int remote = virtual_rank ^ dist; + int recv_blk_loc = virtual_rank < remote ? send_blk_loc + dist : send_blk_loc - dist; + int sr_cnt = ((ptrdiff_t) count) << i; + char *tmpsend = (char *) rbuf + (ptrdiff_t) send_blk_loc * (ptrdiff_t) count * rext; + char *tmprecv = (char *) rbuf + (ptrdiff_t) recv_blk_loc * (ptrdiff_t) count * rext; + int peer = across_sg ? remote * sg_size : remote + sg_start; + if (virtual_rank >= remote) { + send_blk_loc -= dist; + } + + /* Sendreceive */ + err = ompi_coll_base_sendrecv(tmpsend, sr_cnt, rdtype, peer, MCA_COLL_BASE_TAG_ALLGATHER, + tmprecv, sr_cnt, rdtype, peer, MCA_COLL_BASE_TAG_ALLGATHER, + comm, MPI_STATUS_IGNORE, rank); + if (MPI_SUCCESS != err) { + return err; + } + } + + return err; +} + +static inline int mca_coll_acoll_allgather_intra(const void *sbuf, int scount, + struct ompi_datatype_t *sdtype, void *rbuf, + int rcount, struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + int i; + int err; + int size; + int rank, adj_rank; + int sg_id, num_sgs, is_pow2_num_sgs; + int sg_start, sg_end; + int sg_size, log2_sg_size; + int subgrp_size, last_subgrp_size; + ptrdiff_t rlb, rext; + char *tmpsend = NULL, *tmprecv = NULL; + int sendto, recvfrom; + int num_data_blks, data_blk_size[2] = {0}, blk_ofst[2] = {0}; + int bcount; + int last_subgrp_rcnt; + int brank, last_brank; + int use_rd_base, use_ring_sg; + int use_ring = 0, use_lin = 0; + int nreqs; + ompi_request_t **preq, **reqs; + size_t dsize; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + + err = ompi_datatype_get_extent(rdtype, &rlb, &rext); + if (MPI_SUCCESS != err) { + return err; + } + + ompi_datatype_type_size(rdtype, &dsize); + size = ompi_comm_size(comm); + rank = ompi_comm_rank(comm); + sg_size = acoll_module->sg_cnt; + log2_sg_size = acoll_module->log2_sg_cnt; + + /* Handle non MPI_IN_PLACE */ + tmprecv = (char *) rbuf + (ptrdiff_t) rank * (ptrdiff_t) rcount * rext; + if (MPI_IN_PLACE != sbuf) { + tmpsend = (char *) sbuf; + err = ompi_datatype_sndrcv(tmpsend, scount, sdtype, tmprecv, rcount, rdtype); + if (MPI_SUCCESS != err) { + return err; + } + } + + /* Derive subgroup parameters */ + sg_id = rank >> log2_sg_size; + num_sgs = (size + sg_size - 1) >> log2_sg_size; + sg_start = sg_id << log2_sg_size; + sg_end = sg_start + sg_size; + if (sg_end > size) { + sg_end = size; + } + subgrp_size = sg_end - sg_start; + last_subgrp_size = size - ((num_sgs - 1) << log2_sg_size); + last_subgrp_rcnt = rcount * last_subgrp_size; + use_ring_sg = (subgrp_size != sg_size) ? 
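+    /* Worked example of the subgroup arithmetic above (illustrative sizes): size = 20 with sg_size = 8 (log2_sg_size = 3) gives num_sgs = 3 and last_subgrp_size = 20 - 16 = 4. For rank 18: sg_id = 2, sg_start = 16, sg_end = 20, so subgrp_size = 4 != sg_size and the ring variant is used within that subgroup. */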
1 : 0; + bcount = rcount << log2_sg_size; + + /* Override subgroup params based on data size */ + coll_allgather_decision_fixed(size, dsize * (unsigned long) rcount, sg_size, &use_ring, + &use_lin); + + if (use_lin) { + err = ompi_coll_base_allgather_intra_basic_linear(sbuf, scount, sdtype, rbuf, rcount, + rdtype, comm, module); + return err; + } + if (use_ring) { + sg_size = sg_end = subgrp_size = size; + num_sgs = 1; + use_ring_sg = 1; + sg_start = 0; + } + + /* Do ring/recursive doubling based allgather within subgroup */ + adj_rank = rank - sg_start; + if (use_ring_sg) { + recvfrom = ((adj_rank - 1 + subgrp_size) % subgrp_size) + sg_start; + sendto = ((adj_rank + 1) % subgrp_size) + sg_start; + + /* Loop over ranks in subgroup */ + for (i = 0; i < (subgrp_size - 1); i++) { + int recv_peer = ((adj_rank - i - 1 + subgrp_size) % subgrp_size) + sg_start; + int send_peer = ((adj_rank - i + subgrp_size) % subgrp_size) + sg_start; + + tmprecv = (char *) rbuf + (ptrdiff_t) recv_peer * (ptrdiff_t) rcount * rext; + tmpsend = (char *) rbuf + (ptrdiff_t) send_peer * (ptrdiff_t) rcount * rext; + + /* Sendreceive */ + err = ompi_coll_base_sendrecv(tmpsend, rcount, rdtype, sendto, + MCA_COLL_BASE_TAG_ALLGATHER, tmprecv, rcount, rdtype, + recvfrom, MCA_COLL_BASE_TAG_ALLGATHER, comm, + MPI_STATUS_IGNORE, rank); + if (MPI_SUCCESS != err) { + return err; + } + } + } else { + err = rd_allgather_sub(rbuf, rdtype, comm, rcount, rank, rank, adj_rank, sg_size, 0, + sg_start, sg_size, rext); + if (MPI_SUCCESS != err) { + return err; + } + } + + /* Return if all ranks belong to single subgroup */ + if (num_sgs == 1) { + /* All done */ + return err; + } + + /* Do ring/rd based allgather across start ranks of subgroups */ + is_pow2_num_sgs = 0; + if (num_sgs == (1 << opal_hibit(num_sgs, comm->c_cube_dim))) { + is_pow2_num_sgs = 1; + } + use_rd_base = is_pow2_num_sgs ? ((last_subgrp_rcnt == bcount) ? 1 : 0) : 0; + + brank = sg_id; + last_brank = num_sgs - 1; + + /* Use ring for non-power of 2 cases */ + if (!(rank & (sg_size - 1)) && !use_rd_base) { + recvfrom = ((brank - 1 + num_sgs) % num_sgs) << log2_sg_size; + sendto = ((brank + 1) % num_sgs) << log2_sg_size; + + /* Loop over subgroups */ + for (i = 0; i < (num_sgs - 1); i++) { + int recv_peer = ((brank - i - 1 + num_sgs) % num_sgs); + int send_peer = ((brank - i + num_sgs) % num_sgs); + int scnt = (send_peer == last_brank) ? last_subgrp_rcnt : bcount; + int rcnt = (recv_peer == last_brank) ? 
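+            /* The last subgroup may be smaller than sg_size, so transfers whose peer block is last_brank carry last_subgrp_rcnt elements instead of the full bcount. */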
last_subgrp_rcnt : bcount; + + tmprecv = (char *) rbuf + (ptrdiff_t) recv_peer * (ptrdiff_t) bcount * rext; + tmpsend = (char *) rbuf + (ptrdiff_t) send_peer * (ptrdiff_t) bcount * rext; + + recv_peer <<= log2_sg_size; + send_peer <<= log2_sg_size; + + /* Sendreceive */ + err = ompi_coll_base_sendrecv(tmpsend, scnt, rdtype, sendto, + MCA_COLL_BASE_TAG_ALLGATHER, tmprecv, rcnt, rdtype, + recvfrom, MCA_COLL_BASE_TAG_ALLGATHER, comm, + MPI_STATUS_IGNORE, rank); + if (MPI_SUCCESS != err) { + return err; + } + } + } else if (!(rank & (sg_size - 1))) { + /* Use recursive doubling for power of 2 cases */ + err = rd_allgather_sub(rbuf, rdtype, comm, bcount, brank, rank, brank, num_sgs, 1, sg_start, + sg_size, rext); + if (MPI_SUCCESS != err) { + return err; + } + } + /* Now all base ranks have the full data */ + /* Do broadcast within subgroups from the base ranks for the extra data */ + if (sg_id == 0) { + num_data_blks = 1; + data_blk_size[0] = bcount * (num_sgs - 2) + last_subgrp_rcnt; + blk_ofst[0] = bcount; + } else if (sg_id == num_sgs - 1) { + if (last_subgrp_size < 2) { + return err; + } + num_data_blks = 1; + data_blk_size[0] = bcount * (num_sgs - 1); + blk_ofst[0] = 0; + } else { + num_data_blks = 2; + data_blk_size[0] = bcount * sg_id; + data_blk_size[1] = bcount * (num_sgs - sg_id - 2) + last_subgrp_rcnt; + blk_ofst[0] = 0; + blk_ofst[1] = bcount * (sg_id + 1); + } + reqs = ompi_coll_base_comm_get_reqs(module->base_data, size); + if (NULL == reqs) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + nreqs = 0; + preq = reqs; + /* Loop over data blocks */ + for (i = 0; i < num_data_blks; i++) { + char *buff = (char *) rbuf + (ptrdiff_t) blk_ofst[i] * rext; + int sg_dim = opal_hibit(subgrp_size - 1, comm->c_cube_dim); + if ((1 << sg_dim) < subgrp_size) { + sg_dim++; + } + /* The size parameters to sg_bcast_intra ensures that the no. of send + requests do not exceed the max allocated. */ + err = sg_bcast_intra(buff, data_blk_size[i], rdtype, rank, sg_dim, size, sg_size, sg_start, + sg_start, comm, module, preq, &nreqs); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + /* Start and wait on all requests. */ + if (nreqs > 0) { + err = ompi_request_wait_all(nreqs, reqs, MPI_STATUSES_IGNORE); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + } + } + + /* All done */ + return err; +} + +/* + * mca_coll_acoll_allgather + * + * Function: Allgather operation using subgroup based algorithm + * Accepts: Same arguments as MPI_Allgather() + * Returns: MPI_SUCCESS or error code + * + * Description: Allgather is performed across and within subgroups. + * Subgroups can be 1 or more based on size and count. + * + * Memory: No additional memory requirements beyond user-supplied buffers. 
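+ * The flow: each node first gathers its local blocks (allgather on the intra-node communicator), node leaders then exchange node-sized blocks via ring or recursive doubling, and each leader finally broadcasts the blocks it received to its node. A leader of a middle node (0 < node_id < num_nodes - 1) holds two disjoint remote ranges, [0, bcount * node_id) and [bcount * (node_id + 1), end), which is why up to two broadcast blocks are derived in the body below.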
+ * + */ +int mca_coll_acoll_allgather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, + void *rbuf, int rcount, struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, mca_coll_base_module_t *module) +{ + int i; + int err; + int size; + int rank; + int num_nodes, node_start, node_end, node_id; + int node_size, last_node_size; + ptrdiff_t rlb, rext; + char *tmpsend = NULL, *tmprecv = NULL; + int sendto, recvfrom; + int num_data_blks, data_blk_size[2] = {0}, blk_ofst[2] = {0}; + int bcount; + int last_subgrp_rcnt; + int brank, last_brank; + int use_rd_base; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + coll_acoll_subcomms_t *subc; + int cid = ompi_comm_get_local_cid(comm); + char *local_rbuf; + ompi_communicator_t *intra_comm; + + /* Fallback to ring if cid is beyond supported limit */ + if (cid >= MCA_COLL_ACOLL_MAX_CID) { + return ompi_coll_base_allgather_intra_ring(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, + module); + } + + subc = &acoll_module->subc[cid]; + size = ompi_comm_size(comm); + if (!subc->initialized && size > 2) { + err = mca_coll_acoll_comm_split_init(comm, acoll_module, 0); + if (MPI_SUCCESS != err) { + return err; + } + } + + err = ompi_datatype_get_extent(rdtype, &rlb, &rext); + if (MPI_SUCCESS != err) { + return err; + } + + rank = ompi_comm_rank(comm); + node_size = size > 2 ? subc->derived_node_size : size; + + /* Derive node parameters */ + num_nodes = (size + node_size - 1) / node_size; + node_id = rank / node_size; + node_start = node_id * node_size; + node_end = node_start + node_size; + if (node_end > size) { + node_end = size; + } + last_node_size = size - (num_nodes - 1) * node_size; + + /* Call intra */ + local_rbuf = (char *) rbuf + (ptrdiff_t) node_start * (ptrdiff_t) rcount * rext; + if (size <= 2) { + intra_comm = comm; + } else { + if (num_nodes > 1) { + assert(subc->local_r_comm != NULL); + } + intra_comm = num_nodes == 1 ? comm : subc->local_r_comm; + } + err = mca_coll_acoll_allgather_intra(sbuf, scount, sdtype, local_rbuf, rcount, rdtype, + intra_comm, module); + if (MPI_SUCCESS != err) { + return err; + } + + /* Return if intra-node communicator */ + if ((num_nodes == 1) || (size <= 2)) { + /* All done */ + return err; + } + + /* Handle inter-case by first doing allgather across node leaders */ + bcount = node_size * rcount; + last_subgrp_rcnt = last_node_size * rcount; + + /* Perform allgather across node leaders */ + if (rank == node_start) { + int is_pow2_num_nodes = 0; + if (num_nodes == (1 << opal_hibit(num_nodes, comm->c_cube_dim))) { + is_pow2_num_nodes = 1; + } + use_rd_base = is_pow2_num_nodes ? ((last_node_size == node_size) ? 1 : 0) : 0; + brank = node_id; + last_brank = num_nodes - 1; + + /* Use ring for non-power of 2 cases */ + if (!use_rd_base) { + recvfrom = ((brank - 1 + num_nodes) % num_nodes) * node_size; + sendto = ((brank + 1) % num_nodes) * node_size; + + /* Loop over nodes */ + for (i = 0; i < (num_nodes - 1); i++) { + int recv_peer = ((brank - i - 1 + num_nodes) % num_nodes); + int send_peer = ((brank - i + num_nodes) % num_nodes); + int scnt = (send_peer == last_brank) ? last_subgrp_rcnt : bcount; + int rcnt = (recv_peer == last_brank) ? 
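+                /* Illustrative sizing: size = 20, node_size = 8 and rcount = 4 give num_nodes = 3, bcount = 32 and last_subgrp_rcnt = 16; transfers whose peer is the last node (last_brank) carry 16 elements, all others 32. */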
last_subgrp_rcnt : bcount; + + tmprecv = (char *) rbuf + (ptrdiff_t) recv_peer * (ptrdiff_t) bcount * rext; + tmpsend = (char *) rbuf + (ptrdiff_t) send_peer * (ptrdiff_t) bcount * rext; + recv_peer *= node_size; + send_peer *= node_size; + + /* Sendreceive */ + err = ompi_coll_base_sendrecv(tmpsend, scnt, rdtype, sendto, + MCA_COLL_BASE_TAG_ALLGATHER, tmprecv, rcnt, rdtype, + recvfrom, MCA_COLL_BASE_TAG_ALLGATHER, comm, + MPI_STATUS_IGNORE, rank); + if (MPI_SUCCESS != err) { + return err; + } + } + } else { + /* Use recursive doubling for power of 2 cases */ + err = rd_allgather_sub(rbuf, rdtype, comm, bcount, brank, rank, brank, num_nodes, 1, + node_start, node_size, rext); + if (MPI_SUCCESS != err) { + return err; + } + } + } /* End of if inter leader */ + + /* Do intra node broadcast */ + if (node_id == 0) { + num_data_blks = 1; + data_blk_size[0] = bcount * (num_nodes - 2) + last_subgrp_rcnt; + blk_ofst[0] = bcount; + } else if (node_id == num_nodes - 1) { + if (last_node_size < 2) { + return err; + } + num_data_blks = 1; + data_blk_size[0] = bcount * (num_nodes - 1); + blk_ofst[0] = 0; + } else { + num_data_blks = 2; + data_blk_size[0] = bcount * node_id; + data_blk_size[1] = bcount * (num_nodes - node_id - 2) + last_subgrp_rcnt; + blk_ofst[0] = 0; + blk_ofst[1] = bcount * (node_id + 1); + } + /* Loop over data blocks */ + for (i = 0; i < num_data_blks; i++) { + char *buff = (char *) rbuf + (ptrdiff_t) blk_ofst[i] * rext; + err = (comm)->c_coll->coll_bcast(buff, data_blk_size[i], rdtype, 0, subc->local_r_comm, + module); + if (MPI_SUCCESS != err) { + return err; + } + } + + /* All done */ + return err; +} diff --git a/ompi/mca/coll/acoll/coll_acoll_allreduce.c b/ompi/mca/coll/acoll/coll_acoll_allreduce.c new file mode 100644 index 00000000000..c03e559b0e4 --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_allreduce.c @@ -0,0 +1,560 @@ +/* -*- Mode: C; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + + +#include "mpi.h" +#include "ompi/communicator/communicator.h" +#include "ompi/constants.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/mca/coll/base/coll_tags.h" +#include "ompi/mca/coll/coll.h" +#include "ompi/mca/pml/pml.h" +#include "ompi/op/op.h" +#include "opal/util/bit_ops.h" +#include "coll_acoll.h" +#include "coll_acoll_utils.h" + + +void mca_coll_acoll_sync(coll_acoll_data_t *data, int offset, int *group, int gp_size, int rank, int up); +int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module, int intra); + + +static inline int coll_allreduce_decision_fixed(int comm_size, size_t msg_size) +{ + int alg = 3; + if (msg_size <= 256) { + alg = 1; + } else if (msg_size <= 1045876) { + alg = 2; + } else if (msg_size <= 4194304) { + alg = 3; + } else if (msg_size <= 8388608) { + alg = 0; + } else { + alg = 3; + } + return alg; +} + +#ifdef HAVE_XPMEM_H +static inline int mca_coll_acoll_reduce_xpmem_h(const void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + int size; + size_t total_dsize, dsize; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + + coll_acoll_subcomms_t *subc; + int cid = ompi_comm_get_local_cid(comm); + subc = &acoll_module->subc[cid]; + coll_acoll_init(module, comm, subc->data); + coll_acoll_data_t *data = subc->data; + if (NULL == data) { + return -1; + } + + size = ompi_comm_size(comm); + int rank = ompi_comm_rank(comm); + ompi_datatype_type_size(dtype, &dsize); + total_dsize = dsize * (unsigned long) count; + + int l1_gp_size = data->l1_gp_size; + int *l1_gp = data->l1_gp; + int *l2_gp = data->l2_gp; + int l2_gp_size = data->l2_gp_size; + + int l1_local_rank = data->l1_local_rank; + int l2_local_rank = data->l2_local_rank; + char *tmp_sbuf = NULL; + char *tmp_rbuf = NULL; + if (!subc->xpmem_use_sr_buf) { + tmp_rbuf = (char *) data->scratch; + tmp_sbuf = (char *) data->scratch + (subc->xpmem_buf_size) / 2; + if ((sbuf == MPI_IN_PLACE)) { + memcpy(tmp_sbuf, rbuf, total_dsize); + } else { + memcpy(tmp_sbuf, sbuf, total_dsize); + } + } else { + tmp_sbuf = (char *) sbuf; + tmp_rbuf = (char *) rbuf; + if (sbuf == MPI_IN_PLACE) { + tmp_sbuf = (char *) rbuf; + } + } + void *sbuf_vaddr[1] = {tmp_sbuf}; + void *rbuf_vaddr[1] = {tmp_rbuf}; + int err = MPI_SUCCESS; + + err = comm->c_coll->coll_allgather(sbuf_vaddr, sizeof(void *), MPI_BYTE, data->allshm_sbuf, + sizeof(void *), MPI_BYTE, comm, + comm->c_coll->coll_allgather_module); + if (err != MPI_SUCCESS) { + return err; + } + + err = comm->c_coll->coll_allgather(rbuf_vaddr, sizeof(void *), MPI_BYTE, data->allshm_rbuf, + sizeof(void *), MPI_BYTE, comm, + comm->c_coll->coll_allgather_module); + if (err != MPI_SUCCESS) { + return err; + } + + register_and_cache(size, total_dsize, rank, data); + + /* reduce to the local group leader */ + int chunk = count / l1_gp_size; + int my_count_size = (l1_local_rank == (l1_gp_size - 1)) ? 
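+    /* Chunk split (illustrative): count = 1002 with l1_gp_size = 8 gives chunk = 125; each local rank reduces its own 125-element slice and the last local rank also picks up the remainder, 125 + 2 = 127 elements. */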
chunk + count % l1_gp_size : chunk; + + if (rank == l1_gp[0]) { + if (sbuf != MPI_IN_PLACE) + memcpy(tmp_rbuf, sbuf, my_count_size * dsize); + + for (int i = 1; i < l1_gp_size; i++) { + ompi_op_reduce(op, (char *) data->xpmem_saddr[l1_gp[i]] + chunk * l1_local_rank * dsize, + (char *) tmp_rbuf + chunk * l1_local_rank * dsize, my_count_size, dtype); + } + } else { + ompi_3buff_op_reduce(op, + (char *) data->xpmem_saddr[l1_gp[0]] + chunk * l1_local_rank * dsize, + (char *) tmp_sbuf + chunk * l1_local_rank * dsize, + (char *) data->xpmem_raddr[l1_gp[0]] + chunk * l1_local_rank * dsize, + my_count_size, dtype); + for (int i = 1; i < l1_gp_size; i++) { + if (i == l1_local_rank) { + continue; + } + ompi_op_reduce(op, (char *) data->xpmem_saddr[l1_gp[i]] + chunk * l1_local_rank * dsize, + (char *) data->xpmem_raddr[l1_gp[0]] + chunk * l1_local_rank * dsize, + my_count_size, dtype); + } + } + err = ompi_coll_base_barrier_intra_tree(comm, module); + if (err != MPI_SUCCESS) { + return err; + } + + /* perform reduce to 0 */ + int local_size = l2_gp_size; + if ((rank == l1_gp[0]) && (local_size > 1)) { + chunk = count / local_size; + + my_count_size = (l2_local_rank == (local_size - 1)) ? chunk + (count % local_size) : chunk; + + if (l2_local_rank == 0) { + for (int i = 1; i < local_size; i++) { + ompi_op_reduce(op, (char *) data->xpmem_raddr[l2_gp[i]], (char *) tmp_rbuf, + my_count_size, dtype); + } + } else { + for (int i = 1; i < local_size; i++) { + if (i == l2_local_rank) { + continue; + } + + ompi_op_reduce(op, + (char *) data->xpmem_raddr[l2_gp[i]] + chunk * l2_local_rank * dsize, + (char *) data->xpmem_raddr[0] + chunk * l2_local_rank * dsize, + my_count_size, dtype); + } + ompi_op_reduce(op, (char *) tmp_rbuf + chunk * l2_local_rank * dsize, + (char *) data->xpmem_raddr[0] + chunk * l2_local_rank * dsize, + my_count_size, dtype); + } + } + + err = ompi_coll_base_barrier_intra_tree(comm, module); + if (!subc->xpmem_use_sr_buf) { + memcpy(rbuf, tmp_rbuf, total_dsize); + } + return err; +} + +static inline int mca_coll_acoll_allreduce_xpmem_f(const void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + int size; + size_t total_dsize, dsize; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + + coll_acoll_subcomms_t *subc; + int cid = ompi_comm_get_local_cid(comm); + subc = &acoll_module->subc[cid]; + coll_acoll_init(module, comm, subc->data); + coll_acoll_data_t *data = subc->data; + if (NULL == data) { + return -1; + } + + size = ompi_comm_size(comm); + ompi_datatype_type_size(dtype, &dsize); + total_dsize = dsize * (unsigned long) count; + + char *tmp_sbuf = NULL; + char *tmp_rbuf = NULL; + if (!subc->xpmem_use_sr_buf) { + tmp_rbuf = (char *) data->scratch; + tmp_sbuf = (char *) data->scratch + (subc->xpmem_buf_size) / 2; + if ((sbuf == MPI_IN_PLACE)) { + memcpy(tmp_sbuf, rbuf, total_dsize); + } else { + memcpy(tmp_sbuf, sbuf, total_dsize); + } + } else { + tmp_sbuf = (char *) sbuf; + tmp_rbuf = (char *) rbuf; + if (sbuf == MPI_IN_PLACE) { + tmp_sbuf = (char *) rbuf; + } + } + void *sbuf_vaddr[1] = {tmp_sbuf}; + void *rbuf_vaddr[1] = {tmp_rbuf}; + int err = MPI_SUCCESS; + int rank = ompi_comm_rank(comm); + + err = comm->c_coll->coll_allgather(sbuf_vaddr, sizeof(void *), MPI_BYTE, data->allshm_sbuf, + sizeof(void *), MPI_BYTE, comm, + comm->c_coll->coll_allgather_module); + if (err != MPI_SUCCESS) { + return err; + } + err = 
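+    /* As with the send buffers above, each rank publishes the virtual address of its receive buffer; once register_and_cache() attaches the peers' segments over XPMEM, data->xpmem_saddr[i] and data->xpmem_raddr[i] give direct load/store access to rank i's buffers. */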
comm->c_coll->coll_allgather(rbuf_vaddr, sizeof(void *), MPI_BYTE, data->allshm_rbuf, + sizeof(void *), MPI_BYTE, comm, + comm->c_coll->coll_allgather_module); + + if (err != MPI_SUCCESS) { + return err; + } + + register_and_cache(size, total_dsize, rank, data); + + int chunk = count / size; + int my_count_size = (rank == (size - 1)) ? (count / size) + count % size : count / size; + if (rank == 0) { + if (sbuf != MPI_IN_PLACE) + memcpy(tmp_rbuf, sbuf, my_count_size * dsize); + } else { + ompi_3buff_op_reduce(op, (char *) data->xpmem_saddr[0] + chunk * rank * dsize, + (char *) tmp_sbuf + chunk * rank * dsize, + (char *) tmp_rbuf + chunk * rank * dsize, my_count_size, dtype); + } + + err = ompi_coll_base_barrier_intra_tree(comm, module); + if (err != MPI_SUCCESS) { + return err; + } + + for (int i = 1; i < size; i++) { + if (rank == i) { + continue; + } + ompi_op_reduce(op, (char *) data->xpmem_saddr[i] + chunk * rank * dsize, + (char *) tmp_rbuf + chunk * rank * dsize, my_count_size, dtype); + } + err = ompi_coll_base_barrier_intra_tree(comm, module); + if (err != MPI_SUCCESS) { + return err; + } + + int tmp = chunk * dsize; + for (int i = 0; i < size; i++) { + if (subc->xpmem_use_sr_buf && (rank == i)) { + continue; + } + my_count_size = (i == (size - 1)) ? (count / size) + count % size : count / size; + int tmp1 = i * tmp; + char *dst = (char *) rbuf + tmp1; + char *src = (char *) data->xpmem_raddr[i] + tmp1; + memcpy(dst, src, my_count_size * dsize); + } + + err = ompi_coll_base_barrier_intra_tree(comm, module); + + return err; +} +#endif + +void mca_coll_acoll_sync(coll_acoll_data_t *data, int offset, int *group, int gp_size, int rank, + int up) +{ + volatile int *tmp, tmp0; + tmp = (int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + + 64 * rank); /* ToDo: 64 should be replace by cacheline size */ + tmp0 = __atomic_load_n((int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + + 64 * group[0]), + __ATOMIC_RELAXED); /* ToDo: 64 should be replace by cacheline size */ + + opal_atomic_wmb(); + + int val; + if (up == 1) { + val = data->sync[0]; + } else { + val = data->sync[1]; + } + + if (rank == group[0]) { + __atomic_store_n((int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + + 64 * group[0]), + val, __ATOMIC_RELAXED); + } + + while (tmp0 != val) { + tmp0 = __atomic_load_n((int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + + 64 * group[0]), + __ATOMIC_RELAXED); + } + + if (rank != group[0]) { + val++; + __atomic_store_n(tmp, val, __ATOMIC_RELAXED); + } + opal_atomic_wmb(); + if (rank == group[0]) { + for (int i = 1; i < gp_size; i++) { + volatile int tmp1 = __atomic_load_n( + (int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + 64 * group[i]), + __ATOMIC_RELAXED); // ToDo: 64 should be replace by cacheline size + while (tmp1 == val) { + tmp1 = __atomic_load_n((int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + + 64 * group[i]), + __ATOMIC_RELAXED); + } + opal_atomic_wmb(); + } + ++val; + __atomic_store_n(tmp, val, __ATOMIC_RELAXED); + } else { + volatile int tmp1 = __atomic_load_n( + (int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + 64 * group[0]), + __ATOMIC_RELAXED); // ToDo: 64 should be replace by cacheline size + while (tmp1 != val) { + tmp1 = __atomic_load_n((int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + + 64 * group[0]), + __ATOMIC_RELAXED); + } + } + if (up == 1) { + data->sync[0] = val; + } else { + data->sync[1] = val; + } +} + +int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, int 
count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module, int intra) +{ + size_t dsize; + int err = MPI_SUCCESS; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + coll_acoll_subcomms_t *subc; + int cid = ompi_comm_get_local_cid(comm); + subc = &acoll_module->subc[cid]; + coll_acoll_init(module, comm, subc->data); + coll_acoll_data_t *data = subc->data; + if (NULL == data) { + return -1; + } + + int rank = ompi_comm_rank(comm); + ompi_datatype_type_size(dtype, &dsize); + + int l1_gp_size = data->l1_gp_size; + int *l1_gp = data->l1_gp; + int *l2_gp = data->l2_gp; + int l2_gp_size = data->l2_gp_size; + + int l1_local_rank = data->l1_local_rank; + int l2_local_rank = data->l2_local_rank; + int comm_id = ompi_comm_get_local_cid(comm); + + int offset1 = data->offset[0]; + int offset2 = data->offset[1]; + int tshm_offset = data->offset[2]; + int shm_offset = data->offset[3]; + + int local_size; + + if (rank == l1_gp[0]) { + if (l2_gp_size > 1) { + mca_coll_acoll_sync(data, offset2, l2_gp, l2_gp_size, rank, 3); + } + } + + if (MPI_IN_PLACE == sbuf) { + memcpy((char *) data->allshmmmap_sbuf[l1_gp[0]] + shm_offset, rbuf, count * dsize); + } else { + memcpy((char *) data->allshmmmap_sbuf[l1_gp[0]] + shm_offset, sbuf, count * dsize); + } + + mca_coll_acoll_sync(data, offset1, l1_gp, l1_gp_size, rank, 1); + + if (rank == l1_gp[0]) { + memcpy((char *) data->allshmmmap_sbuf[l1_gp[l1_local_rank]], + (char *) data->allshmmmap_sbuf[l1_gp[0]] + shm_offset, count * dsize); + for (int i = 1; i < l1_gp_size; i++) { + ompi_op_reduce(op, + (char *) data->allshmmmap_sbuf[l1_gp[0]] + tshm_offset + + l1_gp[i] * 8 * 1024, + (char *) data->allshmmmap_sbuf[l1_gp[l1_local_rank]], count, dtype); + } + memcpy(rbuf, data->allshmmmap_sbuf[l1_gp[l1_local_rank]], count * dsize); + } + + if (rank == l1_gp[0]) { + if (l2_gp_size > 1) { + mca_coll_acoll_sync(data, offset2, l2_gp, l2_gp_size, rank, 3); + } + } + + /* perform allreduce across leaders */ + local_size = l2_gp_size; + if (local_size > 1) { + if (rank == l1_gp[0]) { + for (int i = 0; i < local_size; i++) { + if (i == l2_local_rank) { + continue; + } + ompi_op_reduce(op, (char *) data->allshmmmap_sbuf[l2_gp[i]], (char *) rbuf, count, + dtype); + } + } + } + + if (intra && (ompi_comm_size(acoll_module->subc[comm_id].numa_comm) > 1)) { + err = mca_coll_acoll_bcast(rbuf, count, dtype, 0, acoll_module->subc[comm_id].numa_comm, module); + } + return err; +} + +int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, mca_coll_base_module_t *module) +{ + int size, alg, err; + int num_nodes; + size_t total_dsize, dsize; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + size = ompi_comm_size(comm); + ompi_datatype_type_size(dtype, &dsize); + total_dsize = dsize * (unsigned long) count; + + if (size == 1) { + if (MPI_IN_PLACE != sbuf) { + memcpy((char *) rbuf, sbuf, total_dsize); + } + return MPI_SUCCESS; + } + + coll_acoll_subcomms_t *subc; + int cid = ompi_comm_get_local_cid(comm); + subc = &acoll_module->subc[cid]; + + /* Falling back to recursivedoubling for non-commutative operators to be safe */ + if (!ompi_op_is_commute(op)) { + return ompi_coll_base_allreduce_intra_recursivedoubling(sbuf, rbuf, count, dtype, op, comm, + module); + } + + /* Fallback to knomial if cid is beyond supported limit */ + if (cid >= MCA_COLL_ACOLL_MAX_CID) 
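+    /* subc is a fixed array of MCA_COLL_ACOLL_MAX_CID (100) entries indexed by the communicator's local CID; communicators beyond that bound cannot be cached here and fall through to a base algorithm. */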
{ + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, comm, + module); + } + + subc = &acoll_module->subc[cid]; + if (!subc->initialized) { + err = mca_coll_acoll_comm_split_init(comm, acoll_module, 0); + if (MPI_SUCCESS != err) + return err; + } + + num_nodes = subc->num_nodes; + + alg = coll_allreduce_decision_fixed(size, total_dsize); + + if (num_nodes == 1) { + if (total_dsize < 32) { + return ompi_coll_base_allreduce_intra_recursivedoubling(sbuf, rbuf, count, dtype, op, + comm, module); + } else if (total_dsize < 512) { + return mca_coll_acoll_allreduce_small_msgs_h(sbuf, rbuf, count, dtype, op, comm, module, + 1); + } else if (total_dsize <= 2048) { + return ompi_coll_base_allreduce_intra_recursivedoubling(sbuf, rbuf, count, dtype, op, + comm, module); + } else if (total_dsize < 65536) { + if (alg == 1) { + return ompi_coll_base_allreduce_intra_recursivedoubling(sbuf, rbuf, count, dtype, + op, comm, module); + } else if (alg == 2) { + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, + op, comm, module); + } else { /*alg == 3 */ + return ompi_coll_base_allreduce_intra_ring_segmented(sbuf, rbuf, count, dtype, op, + comm, module, 0); + } + } else if (total_dsize < 4194304) { +#ifdef HAVE_XPMEM_H + if (((subc->xpmem_use_sr_buf != 0) || (subc->xpmem_buf_size > 2 * total_dsize)) && (subc->without_xpmem != 1)) { + return mca_coll_acoll_allreduce_xpmem_f(sbuf, rbuf, count, dtype, op, comm, module); + } else { + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, + op, comm, module); + } +#else + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, + comm, module); +#endif + } else if (total_dsize <= 16777216) { +#ifdef HAVE_XPMEM_H + if (((subc->xpmem_use_sr_buf != 0) || (subc->xpmem_buf_size > 2 * total_dsize)) && (subc->without_xpmem != 1)) { + mca_coll_acoll_reduce_xpmem_h(sbuf, rbuf, count, dtype, op, comm, module); + return mca_coll_acoll_bcast(rbuf, count, dtype, 0, comm, module); + } else { + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, + op, comm, module); + } +#else + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, + comm, module); +#endif + } else { +#ifdef HAVE_XPMEM_H + if (((subc->xpmem_use_sr_buf != 0) || (subc->xpmem_buf_size > 2 * total_dsize)) && (subc->without_xpmem != 1)) { + return mca_coll_acoll_allreduce_xpmem_f(sbuf, rbuf, count, dtype, op, comm, module); + } else { + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, + op, comm, module); + } +#else + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, + comm, module); +#endif + } + + } else { + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, comm, + module); + } + return MPI_SUCCESS; +} diff --git a/ompi/mca/coll/acoll/coll_acoll_barrier.c b/ompi/mca/coll/acoll/coll_acoll_barrier.c new file mode 100644 index 00000000000..a138027f444 --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_barrier.c @@ -0,0 +1,223 @@ +/* -*- Mode: C; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + + +#include "mpi.h" +#include "ompi/constants.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/mca/coll/base/coll_base_functions.h" +#include "ompi/mca/coll/base/coll_tags.h" +#include "ompi/mca/coll/coll.h" +#include "ompi/mca/pml/pml.h" +#include "opal/util/bit_ops.h" +#include "coll_acoll.h" +#include "coll_acoll_utils.h" + +static int mca_coll_acoll_barrier_recv_subc(struct ompi_communicator_t *comm, + mca_coll_base_module_t *module, ompi_request_t **reqs, + int *nreqs, int root) +{ + int rank = ompi_comm_rank(comm); + int size = ompi_comm_size(comm); + int err = MPI_SUCCESS; + + if (rank < 0) { + return err; + } + + /* Non-zero ranks receive zero-byte message from rank 0 */ + if (rank != root) { + err = MCA_PML_CALL( + recv(NULL, 0, MPI_BYTE, root, MCA_COLL_BASE_TAG_BARRIER, comm, MPI_STATUS_IGNORE)); + if (MPI_SUCCESS != err) { + return err; + } + } else if (rank == root) { + ompi_request_t **preq = reqs; + *nreqs = 0; + for (int i = 0; i < size; i++) { + if (i == root) { + continue; + } + *nreqs = *nreqs + 1; + err = MCA_PML_CALL(isend(NULL, 0, MPI_BYTE, i, MCA_COLL_BASE_TAG_BARRIER, + MCA_PML_BASE_SEND_STANDARD, comm, preq++)); + if (MPI_SUCCESS != err) { + return err; + } + } + err = ompi_request_wait_all(*nreqs, reqs, MPI_STATUSES_IGNORE); + if (MPI_SUCCESS != err) { + return err; + } + } + + return err; +} + +static int mca_coll_acoll_barrier_send_subc(struct ompi_communicator_t *comm, + mca_coll_base_module_t *module, ompi_request_t **reqs, + int *nreqs, int root) +{ + int rank = ompi_comm_rank(comm); + int size = ompi_comm_size(comm); + int err = MPI_SUCCESS; + + if (rank < 0) { + return err; + } + + /* Non-zero ranks send zero-byte message to rank 0 */ + if (rank != root) { + err = MCA_PML_CALL(send(NULL, 0, MPI_BYTE, root, MCA_COLL_BASE_TAG_BARRIER, + MCA_PML_BASE_SEND_STANDARD, comm)); + if (MPI_SUCCESS != err) { + return err; + } + } else if (rank == root) { + ompi_request_t **preq = reqs; + *nreqs = 0; + for (int i = 0; i < size; i++) { + if (i == root) { + continue; + } + *nreqs = *nreqs + 1; + err = MCA_PML_CALL( + irecv(NULL, 0, MPI_BYTE, i, MCA_COLL_BASE_TAG_BARRIER, comm, preq++)); + if (MPI_SUCCESS != err) { + return err; + } + } + err = ompi_request_wait_all(*nreqs, reqs, MPI_STATUSES_IGNORE); + if (MPI_SUCCESS != err) { + return err; + } + } + + return err; +} + +/* + * mca_coll_acoll_barrier_intra + * + * Function: Barrier operation using subgroup based algorithm + * Accepts: Same arguments as MPI_Barrier() + * Returns: MPI_SUCCESS or error code + * + * Description: Step 1 - All leaf ranks of a subgroup send to base rank. + * Step 2 - All base ranks send to rank 0. + * Step 3 - Base rank sends to leaf ranks. + * + * Limitations: None + * + * Memory: No additional memory requirements beyond user-supplied buffers. 
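+ * For example, with 2 nodes of 2 subgroups each, leaf ranks first signal their subgroup root (Step 1), subgroup and base roots signal rank 0 through the base and leader communicators (Step 2), and the release then retraces the same edges in reverse (Step 3), so every rank has both signalled up and been released before returning.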
+ * + */ +int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base_module_t *module) +{ + int size, ssize, bsize; + int err = MPI_SUCCESS; + int nreqs = 0; + ompi_request_t **reqs; + int num_nodes; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + coll_acoll_subcomms_t *subc; + int cid = ompi_comm_get_local_cid(comm); + + /* Fallback to linear if cid is beyond supported limit */ + if (cid >= MCA_COLL_ACOLL_MAX_CID) { + return ompi_coll_base_barrier_intra_basic_linear(comm, module); + } + + subc = &acoll_module->subc[cid]; + size = ompi_comm_size(comm); + if (size == 1) { + return err; + } + if (!subc->initialized && size > 1) { + err = mca_coll_acoll_comm_split_init(comm, acoll_module, 0); + if (MPI_SUCCESS != err) { + return err; + } + } + num_nodes = size > 1 ? subc->num_nodes : 1; + + reqs = ompi_coll_base_comm_get_reqs(module->base_data, size); + if (NULL == reqs) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + + ssize = ompi_comm_size(subc->subgrp_comm); + bsize = ompi_comm_size(subc->base_comm[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE]); + + /* Sends from leaf ranks at subgroup level */ + if (ssize > 1) { + err = mca_coll_acoll_barrier_send_subc(subc->subgrp_comm, module, reqs, &nreqs, + subc->subgrp_root); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + /* Sends from leaf ranks at base rank level */ + if ((bsize > 1) && (subc->base_root[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE] != -1)) { + err = mca_coll_acoll_barrier_send_subc( + subc->base_comm[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE], module, reqs, &nreqs, + subc->base_root[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE]); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + /* Sends from leaf ranks at node leader level */ + if ((num_nodes > 1) && (subc->outer_grp_root != -1)) { + err = mca_coll_acoll_barrier_send_subc(subc->leader_comm, module, reqs, &nreqs, + subc->outer_grp_root); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + + /* Leaf ranks at node leader level receive from root */ + if ((num_nodes > 1) && (subc->outer_grp_root != -1)) { + err = mca_coll_acoll_barrier_recv_subc(subc->leader_comm, module, reqs, &nreqs, + subc->outer_grp_root); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + /* Leaf ranks at base rank level receive from inter leader */ + if ((bsize > 1) && (subc->base_root[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE] != -1)) { + err = mca_coll_acoll_barrier_recv_subc( + subc->base_comm[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE], module, reqs, &nreqs, + subc->base_root[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE]); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + /* Leaf ranks at subgroup level to receive from base ranks */ + if (ssize > 1) { + err = mca_coll_acoll_barrier_recv_subc(subc->subgrp_comm, module, reqs, &nreqs, + subc->subgrp_root); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + + /* All done */ + ompi_coll_base_free_reqs(reqs, nreqs); + return err; +} diff --git a/ompi/mca/coll/acoll/coll_acoll_bcast.c b/ompi/mca/coll/acoll/coll_acoll_bcast.c new file mode 100644 index 00000000000..0fd64ffec69 --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_bcast.c @@ -0,0 +1,540 @@ +/* -*- Mode: C; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2024 Advanced Micro 
Devices, Inc. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + +#include "mpi.h" +#include "ompi/constants.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/mca/coll/base/coll_base_functions.h" +#include "ompi/mca/coll/base/coll_tags.h" +#include "ompi/mca/coll/coll.h" +#include "ompi/mca/pml/pml.h" +#include "opal/util/bit_ops.h" +#include "coll_acoll.h" +#include "coll_acoll_utils.h" + +typedef int (*bcast_subc_func)(void *buff, int count, struct ompi_datatype_t *datatype, int root, + struct ompi_communicator_t *comm, ompi_request_t **preq, int *nreqs, + int world_rank); + +/* + * bcast_binomial + * + * Function: Broadcast operation using balanced binomial tree + * + * Description: Core logic of implementation is derived from that in + * "basic" component. + */ +static int bcast_binomial(void *buff, int count, struct ompi_datatype_t *datatype, int root, + struct ompi_communicator_t *comm, ompi_request_t **preq, int *nreqs, + int world_rank) +{ + int msb_pos, sub_rank, peer, err = MPI_SUCCESS; + int size, rank, dim; + int i, mask; + + size = ompi_comm_size(comm); + rank = ompi_comm_rank(comm); + dim = comm->c_cube_dim; + sub_rank = (rank - root + size) % size; + + msb_pos = opal_hibit(sub_rank, dim); + --dim; + + /* Receive data from parent in the subgroup tree. */ + if (sub_rank > 0) { + assert(msb_pos >= 0); + peer = ((sub_rank & ~(1 << msb_pos)) + root) % size; + + err = MCA_PML_CALL( + recv(buff, count, datatype, peer, MCA_COLL_BASE_TAG_BCAST, comm, MPI_STATUS_IGNORE)); + if (MPI_SUCCESS != err) { + return err; + } + } + + for (i = msb_pos + 1, mask = 1 << i; i <= dim; ++i, mask <<= 1) { + peer = sub_rank | mask; + if (peer < size) { + peer = (peer + root) % size; + *nreqs = *nreqs + 1; + + err = MCA_PML_CALL(isend(buff, count, datatype, peer, MCA_COLL_BASE_TAG_BCAST, + MCA_PML_BASE_SEND_STANDARD, comm, preq++)); + if (MPI_SUCCESS != err) { + return err; + } + } + } + + return err; +} + +static int bcast_flat_tree(void *buff, int count, struct ompi_datatype_t *datatype, int root, + struct ompi_communicator_t *comm, ompi_request_t **preq, int *nreqs, + int world_rank) +{ + int peer; + int err = MPI_SUCCESS; + int rank = ompi_comm_rank(comm); + int size = ompi_comm_size(comm); + + if (rank == root) { + for (peer = 0; peer < size; peer++) { + if (peer == root) { + continue; + } + *nreqs = *nreqs + 1; + err = MCA_PML_CALL(isend(buff, count, datatype, peer, MCA_COLL_BASE_TAG_BCAST, + MCA_PML_BASE_SEND_STANDARD, comm, preq++)); + if (MPI_SUCCESS != err) { + return err; + } + } + } else { + err = MCA_PML_CALL( + recv(buff, count, datatype, root, MCA_COLL_BASE_TAG_BCAST, comm, MPI_STATUS_IGNORE)); + if (MPI_SUCCESS != err) { + return err; + } + } + + return err; +} + +/* + * coll_bcast_decision_fixed + * + * Function: Choose optimal broadcast algorithm + * + * Description: Based on no. of processes and message size, chooses [log|lin] + * broadcast and subgroup size to be used. 
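+ * Broadly, lin_0, lin_1 and lin_2 select linear (1) versus binomial-tree (0) broadcast at successive levels of the hierarchy, use_0 enables the cross-node leader stage, and use_numa adds a NUMA-level stage for the largest intra-node messages.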
+ * + */ + +#define SET_BCAST_PARAMS(l0, l1, l2) \ + *lin_0 = l0; \ + *lin_1 = l1; \ + *lin_2 = l2; + +static inline void coll_bcast_decision_fixed(int size, size_t total_dsize, int node_size, + int *sg_cnt, int *use_0, int *use_numa, int *lin_0, + int *lin_1, int *lin_2, + mca_coll_acoll_module_t *acoll_module, + coll_acoll_subcomms_t *subc) +{ + int sg_size = *sg_cnt; + *use_0 = 0; + *lin_0 = 0; + *use_numa = 0; + if (size <= node_size) { + if (size <= sg_size) { + *sg_cnt = sg_size; + if (total_dsize <= 8192) { + SET_BCAST_PARAMS(0, 0, 0) + } else { + SET_BCAST_PARAMS(0, 1, 1) + } + } else if (size <= (sg_size << 1)) { + if (total_dsize <= 1024) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 8192) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 2097152) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 1, 1) + } else { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } + } else if (size <= (sg_size << 2)) { + if (total_dsize <= 1024) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 8192) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 32768) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 1, 1) + } else if (total_dsize <= 4194304) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 1, 1) + } else { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } + } else if (size <= (sg_size << 3)) { + if (total_dsize <= 1024) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 8192) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 262144) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 1, 1) + } else { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 1, 1) + } + } else if (size <= (sg_size << 4)) { + if (total_dsize <= 512) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 8192) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 262144) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 1, 1) + } else { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 1, 1) + } + } else { + if (total_dsize <= 512) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 8192) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 262144) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 1, 1) + } else if (total_dsize <= 16777216) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 1, 1) + } else { + *sg_cnt = sg_size; + *use_numa = 1; + SET_BCAST_PARAMS(0, 1, 1) + } + } + } else { + if (acoll_module->use_dyn_rules) { + *sg_cnt = acoll_module->mnode_sg_size; + *use_0 = acoll_module->use_mnode; + SET_BCAST_PARAMS(acoll_module->use_lin0, acoll_module->use_lin1, acoll_module->use_lin2) + } else { + int derived_node_size = subc->derived_node_size; + *use_0 = 1; + if (size <= (derived_node_size << 2)) { + size_t dsize_thresh[2][3] = {{512, 8192, 131072}, {128, 8192, 65536}}; + int thr_ind = (size <= (derived_node_size << 1)) ? 
0 : 1; + if (total_dsize <= dsize_thresh[thr_ind][0]) { + *sg_cnt = node_size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= dsize_thresh[thr_ind][1]) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= dsize_thresh[thr_ind][2]) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(1, 1, 1) + } else { + *sg_cnt = node_size; + SET_BCAST_PARAMS(1, 1, 1) + } + } else if (size <= (derived_node_size << 3)) { + if (total_dsize <= 1024) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 1) + } else if (total_dsize <= 8192) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(1, 0, 1) + } else if (total_dsize <= 65536) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(1, 1, 1) + } else if (total_dsize <= 2097152) { + *sg_cnt = node_size; + SET_BCAST_PARAMS(0, 1, 1) + } else { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } + } else if (size <= (derived_node_size << 4)) { + if (total_dsize <= 64) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 1, 1) + } else if (total_dsize <= 8192) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 1) + } else if (total_dsize <= 32768) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(1, 1, 1) + } else if (total_dsize <= 2097152) { + *sg_cnt = node_size; + SET_BCAST_PARAMS(0, 1, 1) + } else { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } + } else { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } + } + } +} + +static inline void coll_acoll_bcast_subcomms(struct ompi_communicator_t *comm, + coll_acoll_subcomms_t *subc, + struct ompi_communicator_t **subcomms, int *subc_roots, + int root, int num_nodes, int use_0, int no_sg, + int use_numa) +{ + /* Node leaders */ + if (use_0) { + subcomms[MCA_COLL_ACOLL_NODE_L] = subc->leader_comm; + subc_roots[MCA_COLL_ACOLL_NODE_L] = subc->outer_grp_root; + } + /* Intra comm */ + if ((num_nodes > 1) && use_0) { + subc_roots[MCA_COLL_ACOLL_INTRA] = subc->is_root_node + ? 
subc->local_root[MCA_COLL_ACOLL_LYR_NODE] + : 0; + subcomms[MCA_COLL_ACOLL_INTRA] = subc->local_comm; + } else { + subc_roots[MCA_COLL_ACOLL_INTRA] = root; + subcomms[MCA_COLL_ACOLL_INTRA] = comm; + } + /* Base ranks comm */ + if (no_sg) { + subcomms[MCA_COLL_ACOLL_L3_L] = subcomms[MCA_COLL_ACOLL_INTRA]; + subc_roots[MCA_COLL_ACOLL_L3_L] = subc_roots[MCA_COLL_ACOLL_INTRA]; + } else { + subcomms[MCA_COLL_ACOLL_L3_L] = subc->base_comm[MCA_COLL_ACOLL_L3CACHE] + [MCA_COLL_ACOLL_LYR_NODE]; + subc_roots[MCA_COLL_ACOLL_L3_L] = subc->base_root[MCA_COLL_ACOLL_L3CACHE] + [MCA_COLL_ACOLL_LYR_NODE]; + } + /* Subgroup comm */ + subcomms[MCA_COLL_ACOLL_LEAF] = subc->subgrp_comm; + subc_roots[MCA_COLL_ACOLL_LEAF] = subc->subgrp_root; + + /* Override with numa when needed */ + if (use_numa) { + subcomms[MCA_COLL_ACOLL_L3_L] = subc->base_comm[MCA_COLL_ACOLL_NUMA] + [MCA_COLL_ACOLL_LYR_NODE]; + subc_roots[MCA_COLL_ACOLL_L3_L] = subc->base_root[MCA_COLL_ACOLL_NUMA] + [MCA_COLL_ACOLL_LYR_NODE]; + subcomms[MCA_COLL_ACOLL_LEAF] = subc->numa_comm; + subc_roots[MCA_COLL_ACOLL_LEAF] = subc->numa_root; + } +} + +static int mca_coll_acoll_bcast_intra_node(void *buff, int count, struct ompi_datatype_t *datatype, + mca_coll_base_module_t *module, + coll_acoll_subcomms_t *subc, + struct ompi_communicator_t **subcomms, int *subc_roots, + int lin_1, int lin_2, int no_sg, int use_numa, + int world_rank) +{ + int size; + int rank; + int err; + int subgrp_size; + int is_base = 0; + int nreqs; + ompi_request_t **preq, **reqs; + struct ompi_communicator_t *comm = subcomms[MCA_COLL_ACOLL_INTRA]; + bcast_subc_func bcast_intra[2] = {&bcast_binomial, &bcast_flat_tree}; + + rank = ompi_comm_rank(comm); + size = ompi_comm_size(comm); + + reqs = ompi_coll_base_comm_get_reqs(module->base_data, size); + if (NULL == reqs) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + nreqs = 0; + preq = reqs; + err = MPI_SUCCESS; + if (no_sg) { + is_base = 1; + } else { + int ind = use_numa ? MCA_COLL_ACOLL_NUMA : MCA_COLL_ACOLL_L3CACHE; + is_base = rank == subc->base_rank[ind] ? 1 : 0; + } + + /* All base ranks receive from root */ + if (is_base) { + err = bcast_intra[lin_1](buff, count, datatype, subc_roots[MCA_COLL_ACOLL_L3_L], + subcomms[MCA_COLL_ACOLL_L3_L], preq, &nreqs, world_rank); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + + /* Start and wait on all requests. */ + if (nreqs > 0) { + err = ompi_request_wait_all(nreqs, reqs, MPI_STATUSES_IGNORE); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + } + } + + /* If single stage, return */ + if (no_sg) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + + subgrp_size = use_numa ? ompi_comm_size(subc->numa_comm) : subc->subgrp_size; + /* All leaf ranks receive from the respective base rank */ + if ((subgrp_size > 1) && !no_sg) { + err = bcast_intra[lin_2](buff, count, datatype, subc_roots[MCA_COLL_ACOLL_LEAF], + subcomms[MCA_COLL_ACOLL_LEAF], preq, &nreqs, world_rank); + } + + /* Start and wait on all requests. */ + if (nreqs > 0) { + err = ompi_request_wait_all(nreqs, reqs, MPI_STATUSES_IGNORE); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + } + } + + /* All done */ + ompi_coll_base_free_reqs(reqs, nreqs); + return err; +} + +/* + * mca_coll_acoll_bcast + * + * Function: Broadcast operation using subgroup based algorithm + * Accepts: Same arguments as MPI_Bcast() + * Returns: MPI_SUCCESS or error code + * + * Description: Broadcast is performed across and within subgroups. 
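+ * When multiple nodes are involved, an optional leader stage first + * broadcasts across node leaders before the intra-node stages run.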
+ * O(N) or O(log(N)) algorithm within subgroup based on count. + * Subgroups can be 1 or more based on size and count. + * + * Limitations: None + * + * Memory: No additional memory requirements beyond user-supplied buffers. + * + */ +int mca_coll_acoll_bcast(void *buff, int count, struct ompi_datatype_t *datatype, int root, + struct ompi_communicator_t *comm, mca_coll_base_module_t *module) +{ + int size; + int rank; + int err; + int nreqs; + ompi_request_t **preq, **reqs; + int sg_cnt, node_size; + int num_nodes; + int use_0 = 0; + int lin_0 = 0, lin_1 = 0, lin_2 = 0; + int use_numa = 0; + int no_sg; + size_t total_dsize, dsize; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + bcast_subc_func bcast_func[2] = {&bcast_binomial, &bcast_flat_tree}; + coll_acoll_subcomms_t *subc; + struct ompi_communicator_t *subcomms[MCA_COLL_ACOLL_NUM_SC] = {NULL}; + int subc_roots[MCA_COLL_ACOLL_NUM_SC] = {-1}; + int cid = ompi_comm_get_local_cid(comm); + + /* Fallback to knomial if cid is beyond supported limit */ + if (cid >= MCA_COLL_ACOLL_MAX_CID) { + return ompi_coll_base_bcast_intra_knomial(buff, count, datatype, root, comm, module, 0, 4); + } + + subc = &acoll_module->subc[cid]; + /* Fallback to knomial if the no. of root changes is beyond a threshold */ + if (subc->num_root_change > MCA_COLL_ACOLL_ROOT_CHANGE_THRESH) { + return ompi_coll_base_bcast_intra_knomial(buff, count, datatype, root, comm, module, 0, 4); + } + size = ompi_comm_size(comm); + if ((!subc->initialized || (root != subc->prev_init_root)) && size > 2) { + err = mca_coll_acoll_comm_split_init(comm, acoll_module, root); + if (MPI_SUCCESS != err) { + return err; + } + } + + ompi_datatype_type_size(datatype, &dsize); + total_dsize = dsize * (unsigned long) count; + rank = ompi_comm_rank(comm); + sg_cnt = acoll_module->sg_cnt; + if (size > 2) { + num_nodes = subc->num_nodes; + node_size = ompi_comm_size(subc->local_comm); + } else { + num_nodes = 1; + node_size = size; + } + + /* Use knomial for 8 or more nodes with non-large messages */ + if ((num_nodes >= 8 && total_dsize <= 65536) + || (num_nodes == 1 && size >= 256 && total_dsize < 16384)) { + return ompi_coll_base_bcast_intra_knomial(buff, count, datatype, root, comm, module, 0, 4); + } + + /* Determine the algorithm to be used based on size and count */ + /* sg_cnt determines subgroup based communication */ + /* lin_1 and lin_2 indicate whether to use linear or log based + sends/receives across and within subgroups respectively. */ + coll_bcast_decision_fixed(size, total_dsize, node_size, &sg_cnt, &use_0, &use_numa, &lin_0, + &lin_1, &lin_2, acoll_module, subc); + no_sg = (sg_cnt == node_size) ? 1 : 0; + if (size <= 2) + no_sg = 1; + + coll_acoll_bcast_subcomms(comm, subc, subcomms, subc_roots, root, num_nodes, use_0, no_sg, + use_numa); + + reqs = ompi_coll_base_comm_get_reqs(module->base_data, size); + if (NULL == reqs) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + nreqs = 0; + preq = reqs; + err = MPI_SUCCESS; + + if (use_0) { + if (subc_roots[MCA_COLL_ACOLL_NODE_L] != -1) { + err = bcast_func[lin_0](buff, count, datatype, subc_roots[MCA_COLL_ACOLL_NODE_L], + subcomms[MCA_COLL_ACOLL_NODE_L], preq, &nreqs, rank); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + } + + /* Start and wait on all requests. 
*/ + if (nreqs > 0) { + err = ompi_request_wait_all(nreqs, reqs, MPI_STATUSES_IGNORE); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + + err = mca_coll_acoll_bcast_intra_node(buff, count, datatype, module, subc, subcomms, subc_roots, + lin_1, lin_2, no_sg, use_numa, rank); + + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + + /* All done */ + ompi_coll_base_free_reqs(reqs, nreqs); + return err; +} diff --git a/ompi/mca/coll/acoll/coll_acoll_component.c b/ompi/mca/coll/acoll/coll_acoll_component.c new file mode 100644 index 00000000000..344af1e31cc --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_component.c @@ -0,0 +1,346 @@ +/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + * + */ + +#include "ompi_config.h" + +#include "mpi.h" +#include "ompi/mca/coll/coll.h" +#include "coll_acoll.h" + +/* + * Public string showing the coll ompi_acoll component version number + */ +const char *mca_coll_acoll_component_version_string + = "Open MPI acoll collective MCA component version " OMPI_VERSION; + +/* + * Global variables + */ +int mca_coll_acoll_priority = 40; +int mca_coll_acoll_sg_size = 8; +int mca_coll_acoll_sg_scale = 1; +int mca_coll_acoll_node_size = 128; +int mca_coll_acoll_use_dynamic_rules = 0; +int mca_coll_acoll_mnode_enable = 1; +int mca_coll_acoll_bcast_lin0 = 0; +int mca_coll_acoll_bcast_lin1 = 0; +int mca_coll_acoll_bcast_lin2 = 0; +int mca_coll_acoll_bcast_nonsg = 0; +int mca_coll_acoll_allgather_lin = 0; +int mca_coll_acoll_allgather_ring_1 = 0; +int mca_coll_acoll_reserve_memory_for_algo = 0; +uint64_t mca_coll_acoll_reserve_memory_size_for_algo = 128 * 32768; // 4 MB +uint64_t mca_coll_acoll_xpmem_buffer_size = 128 * 32768; + +/* By default, utilize xpmem-based algorithms where applicable when built with xpmem. 
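+ They can be disabled at run time via the without_xpmem MCA parameter, for + example (application binary name illustrative): + mpirun --mca coll acoll,tuned,libnbc,basic --mca coll_acoll_without_xpmem 1 ./app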
*/ +int mca_coll_acoll_without_xpmem = 0; +int mca_coll_acoll_xpmem_use_sr_buf = 1; + +/* + * Local function + */ +static int acoll_register(void); + +/* + * Instantiate the public struct with all of our public information + * and pointers to our public functions in it + */ + +const mca_coll_base_component_2_4_0_t mca_coll_acoll_component = { + + /* First, the mca_component_t struct containing meta information + * about the component itself */ + + .collm_version = { + MCA_COLL_BASE_VERSION_2_4_0, + + /* Component name and version */ + .mca_component_name = "acoll", + MCA_BASE_MAKE_VERSION(component, OMPI_MAJOR_VERSION, OMPI_MINOR_VERSION, + OMPI_RELEASE_VERSION), + + /* Component open and close functions */ + .mca_register_component_params = acoll_register, + }, + .collm_data = { + /* The component is checkpoint ready */ + MCA_BASE_METADATA_PARAM_CHECKPOINT + }, + + /* Initialization / querying functions */ + + .collm_init_query = mca_coll_acoll_init_query, + .collm_comm_query = mca_coll_acoll_comm_query, +}; + +static int acoll_register(void) +{ + /* Use a low priority, but allow other components to be lower */ + mca_coll_acoll_priority = 40; + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "priority", + "Priority of the acoll coll component", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_priority); + + /* Defaults on topology */ + (void) + mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "sg_size", + "Size of subgroup to be used for subgroup based algorithms", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_sg_size); + + (void) mca_base_component_var_register( + &mca_coll_acoll_component.collm_version, "sg_scale", + "Scale factor for effective subgroup size for subgroup based algorithms", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_sg_scale); + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "node_size", + "Size of node for multinode cases", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_node_size); + (void) + mca_base_component_var_register(&mca_coll_acoll_component.collm_version, + "use_dynamic_rules", + "Use dynamic selection of algorithms for multinode cases", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_use_dynamic_rules); + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "mnode_enable", + "Enable separate algorithm for multinode cases", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_mnode_enable); + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "bcast_lin0", + "Use lin/log for stage 0 of multinode algorithm", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_bcast_lin0); + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "bcast_lin1", + "Use lin/log for stage 1 of multinode algorithm", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_bcast_lin1); + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "bcast_lin2", + "Use lin/log for stage 2 of multinode algorithm", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + 
MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_bcast_lin2); + (void) mca_base_component_var_register( + &mca_coll_acoll_component.collm_version, "bcast_nonsg", + "Flag to turn on/off subgroup based algorithms for multinode", MCA_BASE_VAR_TYPE_INT, NULL, + 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_bcast_nonsg); + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "allgather_lin", + "Flag to indicate use of linear allgather for multinode", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_allgather_lin); + (void) + mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "allgather_ring_1", + "Flag to indicate use of ring/rd allgather for multinode", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_allgather_ring_1); + (void) mca_base_component_var_register( + &mca_coll_acoll_component.collm_version, "reserve_memory_for_algo", + "Flag to inform the acoll component to reserve/pre-allocate memory" + " for use inside collective algorithms.", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_reserve_memory_for_algo); + (void) mca_base_component_var_register( + &mca_coll_acoll_component.collm_version, "reserve_memory_size_for_algo", + "Size of memory to be allocated by acoll component to use as reserve " + "memory inside collective algorithms.", + MCA_BASE_VAR_TYPE_UINT64_T, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_reserve_memory_size_for_algo); + (void) mca_base_component_var_register( + &mca_coll_acoll_component.collm_version, "without_xpmem", + "By default, xpmem-based algorithms are used when applicable. " + "When this flag is set to 1, xpmem-based algorithms are disabled.", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_without_xpmem); + (void) mca_base_component_var_register( + &mca_coll_acoll_component.collm_version, "xpmem_buffer_size", + "Maximum size of memory that can be used for temporary buffers for " + "xpmem-based algorithms. By default these buffers are not created or " + "used unless xpmem_use_sr_buf is set to 0.", + MCA_BASE_VAR_TYPE_UINT64_T, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_xpmem_buffer_size); + (void) mca_base_component_var_register( + &mca_coll_acoll_component.collm_version, "xpmem_use_sr_buf", + "Uses application-provided send/recv buffers during xpmem registration " + "when set to 1 instead of temporary buffers. 
The send/recv buffers are " + "assumed to persist for the duration of the application.", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_xpmem_use_sr_buf); + + return OMPI_SUCCESS; +} + +/* + * Module constructor + */ +static void mca_coll_acoll_module_construct(mca_coll_acoll_module_t *module) +{ + for (int i = 0; i < MCA_COLL_ACOLL_MAX_CID; i++) { + coll_acoll_subcomms_t *subc = &module->subc[i]; + subc->initialized = 0; + subc->is_root_node = 0; + subc->is_root_sg = 0; + subc->is_root_numa = 0; + subc->outer_grp_root = -1; + subc->subgrp_root = 0; + subc->num_nodes = 1; + subc->prev_init_root = -1; + subc->num_root_change = 0; + subc->numa_root = 0; + subc->socket_ldr_root = -1; + subc->local_comm = NULL; + subc->local_r_comm = NULL; + subc->leader_comm = NULL; + subc->subgrp_comm = NULL; + subc->socket_comm = NULL; + subc->socket_ldr_comm = NULL; + for (int j = 0; j < MCA_COLL_ACOLL_NUM_LAYERS; j++) { + for (int k = 0; k < MCA_COLL_ACOLL_NUM_BASE_LYRS; k++) { + subc->base_comm[k][j] = NULL; + subc->base_root[k][j] = -1; + } + subc->local_root[j] = 0; + } + + subc->numa_comm = NULL; + subc->numa_comm_ldrs = NULL; + subc->node_comm = NULL; + subc->inter_comm = NULL; + subc->cid = -1; + subc->initialized_data = false; + subc->initialized_shm_data = false; + subc->data = NULL; +#ifdef HAVE_XPMEM_H + subc->xpmem_buf_size = mca_coll_acoll_xpmem_buffer_size; + subc->without_xpmem = mca_coll_acoll_without_xpmem; + subc->xpmem_use_sr_buf = mca_coll_acoll_xpmem_use_sr_buf; +#endif + } + + /* Reserve memory init. Lazy allocation of memory when needed. */ + (module->reserve_mem_s).reserve_mem = NULL; + (module->reserve_mem_s).reserve_mem_size = 0; + (module->reserve_mem_s).reserve_mem_allocate = false; + (module->reserve_mem_s).reserve_mem_in_use = false; + if ((0 != mca_coll_acoll_reserve_memory_for_algo) + && (0 < mca_coll_acoll_reserve_memory_size_for_algo) + && (false == ompi_mpi_thread_multiple)) { + (module->reserve_mem_s).reserve_mem_allocate = true; + (module->reserve_mem_s).reserve_mem_size = mca_coll_acoll_reserve_memory_size_for_algo; + } +} + +/* + * Module destructor + */ +static void mca_coll_acoll_module_destruct(mca_coll_acoll_module_t *module) +{ + + for (int i = 0; i < MCA_COLL_ACOLL_MAX_CID; i++) { + coll_acoll_subcomms_t *subc = &module->subc[i]; + if (subc->initialized_data) { + if (subc->initialized_shm_data) { + if (subc->orig_comm != NULL) { + opal_shmem_unlink( + &((subc->data)->allshmseg_id[ompi_comm_rank(subc->orig_comm)])); + opal_shmem_segment_detach( + &((subc->data)->allshmseg_id[ompi_comm_rank(subc->orig_comm)])); + } + } + coll_acoll_data_t *data = subc->data; + if (NULL != data) { +#ifdef HAVE_XPMEM_H + for (int j = 0; j < data->comm_size; j++) { + xpmem_release(data->all_apid[j]); + xpmem_remove(data->allseg_id[j]); + mca_rcache_base_module_destroy(data->rcache[j]); + } + + free(data->allseg_id); + data->allseg_id = NULL; + free(data->all_apid); + data->all_apid = NULL; + free(data->allshm_sbuf); + data->allshm_sbuf = NULL; + free(data->allshm_rbuf); + data->allshm_rbuf = NULL; + free(data->xpmem_saddr); + data->xpmem_saddr = NULL; + free(data->xpmem_raddr); + data->xpmem_raddr = NULL; + free(data->scratch); + data->scratch = NULL; + free(data->rcache); + data->rcache = NULL; +#endif + free(data->allshmseg_id); + data->allshmseg_id = NULL; + free(data->allshmmmap_sbuf); + data->allshmmmap_sbuf = NULL; + free(data->l1_gp); + data->l1_gp = NULL; + free(data->l2_gp); + data->l2_gp = NULL; + free(data); + data = 
NULL; + } + } + + if (subc->local_comm != NULL) { + ompi_comm_free(&(subc->local_comm)); + subc->local_comm = NULL; + } + + if (subc->local_r_comm != NULL) { + ompi_comm_free(&(subc->local_r_comm)); + subc->local_r_comm = NULL; + } + + if (subc->leader_comm != NULL) { + ompi_comm_free(&(subc->leader_comm)); + subc->leader_comm = NULL; + } + + if (subc->subgrp_comm != NULL) { + ompi_comm_free(&(subc->subgrp_comm)); + subc->subgrp_comm = NULL; + } + if (subc->socket_comm != NULL) { + ompi_comm_free(&(subc->socket_comm)); + subc->socket_comm = NULL; + } + + if (subc->socket_ldr_comm != NULL) { + ompi_comm_free(&(subc->socket_ldr_comm)); + subc->socket_ldr_comm = NULL; + } + for (int k = 0; k < MCA_COLL_ACOLL_NUM_BASE_LYRS; k++) { + for (int j = 0; j < MCA_COLL_ACOLL_NUM_LAYERS; j++) { + if (subc->base_comm[k][j] != NULL) { + ompi_comm_free(&(subc->base_comm[k][j])); + subc->base_comm[k][j] = NULL; + } + } + } + subc->initialized = 0; + } + + if ((true == (module->reserve_mem_s).reserve_mem_allocate) + && (NULL != (module->reserve_mem_s).reserve_mem)) { + free((module->reserve_mem_s).reserve_mem); + } +} + +OBJ_CLASS_INSTANCE(mca_coll_acoll_module_t, mca_coll_base_module_t, mca_coll_acoll_module_construct, + mca_coll_acoll_module_destruct); diff --git a/ompi/mca/coll/acoll/coll_acoll_gather.c b/ompi/mca/coll/acoll/coll_acoll_gather.c new file mode 100644 index 00000000000..91e1e6ff51f --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_gather.c @@ -0,0 +1,217 @@ +/* -*- Mode: C; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + +#include "mpi.h" +#include "ompi/constants.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/mca/coll/base/coll_base_functions.h" +#include "ompi/mca/coll/base/coll_tags.h" +#include "ompi/mca/coll/coll.h" +#include "ompi/mca/pml/pml.h" +#include "opal/util/bit_ops.h" +#include "coll_acoll.h" +#include "coll_acoll_utils.h" + +/* + * mca_coll_acoll_gather_intra + * + * Function: Gather operation using subgroup based algorithm + * Accepts: Same arguments as MPI_Gather() + * Returns: MPI_SUCCESS or error code + * + * Description: Gather is performed across and within subgroups. + * Subgroups can be 1 or more based on size and count. + * + * Limitations: Current implementation is optimal only for map-by core. + * + * Memory: The base rank of each subgroup may create a temporary buffer. 
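+ * Data flows in up to three hops: leaf ranks send to their subgroup's + * base rank, base ranks forward the aggregated data to the node-local + * root, and node-local roots send to the root.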
+ * + */ +int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, + void *rbuf, int rcount, struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, mca_coll_base_module_t *module) +{ + int i, err, rank, size; + char *wkg = NULL, *workbuf = NULL; + MPI_Status status; + MPI_Aint sextent, sgap = 0, ssize; + MPI_Aint rextent = 0; + int total_recv = 0; + int sg_cnt, node_cnt; + int cur_sg, root_sg; + int cur_node, root_node; + int is_base, is_local_root; + int startr, endr, inc; + int startn, endn; + int num_nodes; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + coll_acoll_reserve_mem_t *reserve_mem_gather = &(acoll_module->reserve_mem_s); + + size = ompi_comm_size(comm); + rank = ompi_comm_rank(comm); + + sg_cnt = acoll_module->sg_cnt; + node_cnt = acoll_module->node_cnt; + num_nodes = (size + node_cnt - 1) / node_cnt; + /* For small messages on 8 or more nodes, fall back to the flat (non-hierarchical) path */ + if (num_nodes >= 8 && (rcount < 262144)) { + node_cnt = size; + sg_cnt = size; + num_nodes = 1; + } + + /* Setup root for receive */ + if (rank == root) { + ompi_datatype_type_extent(rdtype, &rextent); + /* Just use the recv buffer */ + wkg = (char *) rbuf; + if (sbuf != MPI_IN_PLACE) { + MPI_Aint root_ofst = rextent * (ptrdiff_t) (rcount * root); + err = ompi_datatype_sndrcv((void *) sbuf, scount, sdtype, wkg + (ptrdiff_t) root_ofst, + rcount, rdtype); + if (MPI_SUCCESS != err) { + return err; + } + } + total_recv = rcount; + } + + /* Setup base ranks of non-root subgroups for receive */ + cur_sg = rank / sg_cnt; + root_sg = root / sg_cnt; + is_base = (rank % sg_cnt == 0) && (cur_sg != root_sg); + startr = (rank / sg_cnt) * sg_cnt; + cur_node = rank / node_cnt; + root_node = root / node_cnt; + is_local_root = (rank % node_cnt == 0) && (cur_node != root_node); + startn = (rank / node_cnt) * node_cnt; + + if (is_base) { + int64_t buf_size = is_local_root ? (int64_t) scount * node_cnt : (int64_t) scount * sg_cnt; + ompi_datatype_type_extent(sdtype, &sextent); + ssize = opal_datatype_span(&sdtype->super, buf_size, &sgap); + if (cur_sg != root_sg) { + char *tmprecv = NULL; + workbuf = (char *) coll_acoll_buf_alloc(reserve_mem_gather, ssize + sgap); + if (NULL == workbuf) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + wkg = workbuf - sgap; + tmprecv = wkg + sextent * (ptrdiff_t) (rcount * (rank - startr)); + /* local copy to workbuf */ + err = ompi_datatype_sndrcv((void *) sbuf, scount, sdtype, tmprecv, scount, sdtype); + if (MPI_SUCCESS != err) { + return err; + } + } + rdtype = sdtype; + rcount = scount; + rextent = sextent; + total_recv = rcount; + } else if (rank != root) { + wkg = (char *) sbuf; + total_recv = scount; + } + + /* All base ranks receive from other ranks in their respective subgroup */ + endr = startr + sg_cnt; + if (endr > size) { + endr = size; + } + inc = (rank == root) ? ((root != 0) ? 0 : 1) : 1; + if (is_base || (rank == root)) { + for (i = startr + inc; i < endr; i++) { + char *tmprecv = NULL; + if (i == root) { + continue; + } + if (rank == root) { + tmprecv = wkg + rextent * (ptrdiff_t) (rcount * i); + } else { + tmprecv = wkg + rextent * (ptrdiff_t) (rcount * (i - startr)); + } + err = MCA_PML_CALL( + recv(tmprecv, rcount, rdtype, i, MCA_COLL_BASE_TAG_GATHER, comm, &status)); + total_recv += rcount; + } + } else { + int peer = (cur_sg == root_sg) ? 
root : startr; + err = MCA_PML_CALL(send(sbuf, scount, sdtype, peer, MCA_COLL_BASE_TAG_GATHER, + MCA_PML_BASE_SEND_STANDARD, comm)); + return err; + } + + /* All base ranks send to local root */ + endn = startn + node_cnt; + if (endn > size) { + endn = size; + } + if (sg_cnt < size) { + int local_root = (root_node == cur_node) ? root : startn; + for (i = startn; i < endn; i += sg_cnt) { + int i_sg = i / sg_cnt; + if ((rank != local_root) && (rank == i) && is_base) { + err = MCA_PML_CALL(send(workbuf - sgap, total_recv, sdtype, local_root, + MCA_COLL_BASE_TAG_GATHER, MCA_PML_BASE_SEND_STANDARD, + comm)); + } + if ((rank == local_root) && (rank != i) && (i_sg != root_sg)) { + int recv_amt = (i + sg_cnt > size) ? rcount * (size - i) : rcount * sg_cnt; + MPI_Aint rcv_ofst = rextent * (ptrdiff_t) (rcount * (i - startn)); + + err = MCA_PML_CALL(recv(wkg + (ptrdiff_t) rcv_ofst, recv_amt, rdtype, i, + MCA_COLL_BASE_TAG_GATHER, comm, &status)); + total_recv += recv_amt; + } + if (MPI_SUCCESS != err) { + if (NULL != workbuf) { + coll_acoll_buf_free(reserve_mem_gather, workbuf); + } + return err; + } + } + } + + /* All local root ranks send to root */ + if (node_cnt < size && num_nodes > 1) { + for (i = 0; i < size; i += node_cnt) { + int i_node = i / node_cnt; + if ((rank != root) && (rank == i) && is_base) { + err = MCA_PML_CALL(send(workbuf - sgap, total_recv, sdtype, root, + MCA_COLL_BASE_TAG_GATHER, MCA_PML_BASE_SEND_STANDARD, + comm)); + } + if ((rank == root) && (rank != i) && (i_node != root_node)) { + int recv_amt = (i + node_cnt > size) ? rcount * (size - i) : rcount * node_cnt; + MPI_Aint rcv_ofst = rextent * (ptrdiff_t) (rcount * i); + + err = MCA_PML_CALL(recv((char *) rbuf + (ptrdiff_t) rcv_ofst, recv_amt, rdtype, i, + MCA_COLL_BASE_TAG_GATHER, comm, &status)); + total_recv += recv_amt; + } + if (MPI_SUCCESS != err) { + if (NULL != workbuf) { + coll_acoll_buf_free(reserve_mem_gather, workbuf); + } + return err; + } + } + } + + if (NULL != workbuf) { + coll_acoll_buf_free(reserve_mem_gather, workbuf); + } + + /* All done */ + return MPI_SUCCESS; +} diff --git a/ompi/mca/coll/acoll/coll_acoll_module.c b/ompi/mca/coll/acoll/coll_acoll_module.c new file mode 100644 index 00000000000..b3b2afddc8b --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_module.c @@ -0,0 +1,200 @@ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + +#include + +#include "mpi.h" +#include "ompi/mca/coll/base/base.h" +#include "ompi/mca/coll/coll.h" +#include "coll_acoll.h" + + +static int acoll_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm); +static int acoll_module_disable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm); + +/* + * Initial query function that is invoked during MPI_INIT, allowing + * this component to disqualify itself if it doesn't support the + * required level of thread support. 
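+ * acoll imposes no additional constraints on the thread level, so this + * query simply reports success.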
+ */ +int mca_coll_acoll_init_query(bool enable_progress_threads, bool enable_mpi_threads) +{ + /* Nothing to do */ + return OMPI_SUCCESS; +} + + + +#define ACOLL_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__module->super.coll_##__api) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "acoll"); \ + } \ + } while (0) + +#define ACOLL_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "acoll"); \ + } \ + } while (0) + +/* + * Invoked when there's a new communicator that has been created. + * Look at the communicator and decide which set of functions and + * priority we want to return. + */ +mca_coll_base_module_t *mca_coll_acoll_comm_query(struct ompi_communicator_t *comm, int *priority) +{ + mca_coll_acoll_module_t *acoll_module; + + acoll_module = OBJ_NEW(mca_coll_acoll_module_t); + if (NULL == acoll_module) { + return NULL; + } + + if (OMPI_COMM_IS_INTER(comm)) { + *priority = 0; + return NULL; + } + if (OMPI_COMM_IS_INTRA(comm) && ompi_comm_size(comm) < 2) { + *priority = 0; + return NULL; + } + + *priority = mca_coll_acoll_priority; + + /* Set topology params */ + acoll_module->sg_scale = mca_coll_acoll_sg_scale; + acoll_module->sg_size = mca_coll_acoll_sg_size; + acoll_module->sg_cnt = mca_coll_acoll_sg_size / mca_coll_acoll_sg_scale; + acoll_module->node_cnt = mca_coll_acoll_node_size; + if (mca_coll_acoll_sg_size == MCA_COLL_ACOLL_SG_SIZE_1) { + assert((acoll_module->sg_cnt == 1) || (acoll_module->sg_cnt == 2) + || (acoll_module->sg_cnt == 4) || (acoll_module->sg_cnt == 8)); + } + if (mca_coll_acoll_sg_size == MCA_COLL_ACOLL_SG_SIZE_2) { + assert((acoll_module->sg_cnt == 1) || (acoll_module->sg_cnt == 2) + || (acoll_module->sg_cnt == 4) || (acoll_module->sg_cnt == 8) + || (acoll_module->sg_cnt == 16)); + } + + switch (acoll_module->sg_cnt) { + case 1: + acoll_module->log2_sg_cnt = 0; + break; + case 2: + acoll_module->log2_sg_cnt = 1; + break; + case 4: + acoll_module->log2_sg_cnt = 2; + break; + case 8: + acoll_module->log2_sg_cnt = 3; + break; + case 16: + acoll_module->log2_sg_cnt = 4; + break; + default: + assert(0); + break; + } + + switch (acoll_module->node_cnt) { + case 96: + case 128: + acoll_module->log2_node_cnt = 7; + break; + case 192: + acoll_module->log2_node_cnt = 8; + break; + case 64: + acoll_module->log2_node_cnt = 6; + break; + case 32: + acoll_module->log2_node_cnt = 5; + break; + default: + assert(0); + break; + } + + acoll_module->use_dyn_rules = mca_coll_acoll_use_dynamic_rules; + acoll_module->use_mnode = mca_coll_acoll_mnode_enable; + acoll_module->use_lin0 = mca_coll_acoll_bcast_lin0; + acoll_module->use_lin1 = mca_coll_acoll_bcast_lin1; + acoll_module->use_lin2 = mca_coll_acoll_bcast_lin2; + if (mca_coll_acoll_bcast_nonsg) { + acoll_module->mnode_sg_size = acoll_module->node_cnt; + acoll_module->mnode_log2_sg_size = acoll_module->log2_node_cnt; + } else { + acoll_module->mnode_sg_size = acoll_module->sg_cnt; + acoll_module->mnode_log2_sg_size = acoll_module->log2_sg_cnt; + } + acoll_module->allg_lin = mca_coll_acoll_allgather_lin; + acoll_module->allg_ring = mca_coll_acoll_allgather_ring_1; + + /* Choose whether to use [intra|inter], and [subgroup|normal]-based + * algorithms. 
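+ * Note that the derived subgroup count must be a power of two (up to 16, + * depending on the configured subgroup size) and the node size one of the + * supported values (32/64/96/128/192); the asserts above enforce this.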
*/ + acoll_module->super.coll_module_enable = acoll_module_enable; + acoll_module->super.coll_module_disable = acoll_module_disable; + + acoll_module->super.coll_allgather = mca_coll_acoll_allgather; + acoll_module->super.coll_allreduce = mca_coll_acoll_allreduce_intra; + acoll_module->super.coll_barrier = mca_coll_acoll_barrier_intra; + acoll_module->super.coll_bcast = mca_coll_acoll_bcast; + acoll_module->super.coll_gather = mca_coll_acoll_gather_intra; + acoll_module->super.coll_reduce = mca_coll_acoll_reduce_intra; + + return &(acoll_module->super); +} + +/* + * Init module on the communicator + */ +static int acoll_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm) +{ + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + + /* prepare the placeholder for the array of request* */ + module->base_data = OBJ_NEW(mca_coll_base_comm_t); + if (NULL == module->base_data) { + return OMPI_ERROR; + } + + ACOLL_INSTALL_COLL_API(comm, acoll_module, allgather); + ACOLL_INSTALL_COLL_API(comm, acoll_module, allreduce); + ACOLL_INSTALL_COLL_API(comm, acoll_module, barrier); + ACOLL_INSTALL_COLL_API(comm, acoll_module, bcast); + ACOLL_INSTALL_COLL_API(comm, acoll_module, gather); + ACOLL_INSTALL_COLL_API(comm, acoll_module, reduce); + + /* All done */ + return OMPI_SUCCESS; +} + +static int acoll_module_disable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm) +{ + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + + ACOLL_UNINSTALL_COLL_API(comm, acoll_module, allgather); + ACOLL_UNINSTALL_COLL_API(comm, acoll_module, allreduce); + ACOLL_UNINSTALL_COLL_API(comm, acoll_module, barrier); + ACOLL_UNINSTALL_COLL_API(comm, acoll_module, bcast); + ACOLL_UNINSTALL_COLL_API(comm, acoll_module, gather); + ACOLL_UNINSTALL_COLL_API(comm, acoll_module, reduce); + + return OMPI_SUCCESS; +} diff --git a/ompi/mca/coll/acoll/coll_acoll_reduce.c b/ompi/mca/coll/acoll/coll_acoll_reduce.c new file mode 100644 index 00000000000..bda55f41123 --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_reduce.c @@ -0,0 +1,397 @@ +/* -*- Mode: C; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + +#include "mpi.h" +#include "ompi/constants.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/mca/coll/base/coll_tags.h" +#include "ompi/mca/coll/coll.h" +#include "ompi/mca/pml/pml.h" +#include "ompi/op/op.h" +#include "opal/util/bit_ops.h" +#include "coll_acoll.h" +#include "coll_acoll_utils.h" + +static inline int coll_reduce_decision_fixed(int comm_size, size_t msg_size) +{ + /* Set default to topology aware algorithm */ + int alg = 0; + if (comm_size <= 8) { + /* Linear */ + alg = 1; + } else if (msg_size <= 8192) { + alg = 0; + } else if (msg_size <= 262144) { + /* Binomial */ + alg = 2; + } else if (msg_size <= 8388608 && comm_size < 64) { + alg = 1; + } else if (msg_size <= 8388608 && comm_size <= 128) { + /* In order binary */ + alg = 3; + } else { + alg = 2; + } + return alg; +} + +static inline int coll_acoll_reduce_topo(const void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + int root, struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + int ret = MPI_SUCCESS, rank, sz; + int cid = ompi_comm_get_local_cid(comm); + + ptrdiff_t dsize, gap = 0; + char *free_buffer = NULL; + char *pml_buffer = NULL; + char *tmp_rbuf = NULL; + char *tmp_sbuf = NULL; + + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + coll_acoll_subcomms_t *subc = &acoll_module->subc[cid]; + coll_acoll_reserve_mem_t *reserve_mem_rbuf_reduce = &(acoll_module->reserve_mem_s); + + rank = ompi_comm_rank(comm); + + tmp_sbuf = (char *) sbuf; + if ((sbuf == MPI_IN_PLACE) && (rank == root)) { + tmp_sbuf = (char *) rbuf; + } + + int i; + int ind1 = MCA_COLL_ACOLL_L3CACHE; + int ind2 = MCA_COLL_ACOLL_LYR_NODE; + int is_base = rank == subc->base_rank[ind1] ? 
1 : 0; + int bound = subc->subgrp_size; + + sz = ompi_comm_size(subc->base_comm[ind1][ind2]); + dsize = opal_datatype_span(&dtype->super, count, &gap); + if (rank == root) { + tmp_rbuf = rbuf; + } else if (is_base) { + tmp_rbuf = (char *) coll_acoll_buf_alloc(reserve_mem_rbuf_reduce, dsize); + if (NULL == tmp_rbuf) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + } + + if (is_base) { + ret = ompi_datatype_copy_content_same_ddt(dtype, count, (char *) tmp_rbuf, + (char *) tmp_sbuf); + free_buffer = (char *) malloc(dsize); + if (NULL == free_buffer) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + pml_buffer = free_buffer - gap; + } + + /* if not a local root, send the message to the local root */ + if (!is_base) { + ret = MCA_PML_CALL(send(tmp_sbuf, count, dtype, subc->subgrp_root, MCA_COLL_BASE_TAG_REDUCE, + MCA_PML_BASE_SEND_STANDARD, subc->subgrp_comm)); + } + + /* if local root, receive the message from other ranks within that group */ + if (is_base) { + for (i = 0; i < bound; i++) { + if (i == subc->subgrp_root) { + continue; + } + ret = MCA_PML_CALL(recv(pml_buffer, count, dtype, i, MCA_COLL_BASE_TAG_REDUCE, + subc->subgrp_comm, MPI_STATUS_IGNORE)); + ompi_op_reduce(op, pml_buffer, tmp_rbuf, count, dtype); + } + } + /* perform reduction at root */ + if (is_base && (sz > 1)) { + if (rank != root) { + ret = MCA_PML_CALL(send(tmp_rbuf, count, dtype, subc->base_root[ind1][ind2], + MCA_COLL_BASE_TAG_REDUCE, MCA_PML_BASE_SEND_STANDARD, + subc->base_comm[ind1][ind2])); + if (ret != MPI_SUCCESS) { + free(pml_buffer); + if (NULL != tmp_rbuf) { + coll_acoll_buf_free(reserve_mem_rbuf_reduce, tmp_rbuf); + } + return ret; + } + } + if (rank == root) { + for (i = 0; i < sz; i++) { + if (i == subc->base_root[ind1][ind2]) { + continue; + } + ret = MCA_PML_CALL(recv(pml_buffer, count, dtype, i, MCA_COLL_BASE_TAG_REDUCE, + subc->base_comm[ind1][ind2], MPI_STATUS_IGNORE)); + if (ret != MPI_SUCCESS) { + free(pml_buffer); + return ret; + } + ompi_op_reduce(op, pml_buffer, rbuf, count, dtype); + } + } + } + + /* cleanup: free the temporary buffers used at the base ranks */ + if (is_base && (sz > 1)) { + free(pml_buffer); + if (rank != root && NULL != tmp_rbuf) { + coll_acoll_buf_free(reserve_mem_rbuf_reduce, tmp_rbuf); + } + } + + return ret; +} + +#ifdef HAVE_XPMEM_H +static inline int mca_coll_acoll_reduce_xpmem(const void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + int root, struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + int size; + size_t total_dsize, dsize; + ptrdiff_t gap = 0; + + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + + coll_acoll_subcomms_t *subc; + int cid = ompi_comm_get_local_cid(comm); + subc = &acoll_module->subc[cid]; + coll_acoll_init(module, comm, subc->data); + coll_acoll_reserve_mem_t *reserve_mem_rbuf_reduce = NULL; + if (subc->xpmem_use_sr_buf != 0) { + reserve_mem_rbuf_reduce = &(acoll_module->reserve_mem_s); + } + coll_acoll_data_t *data = subc->data; + if (NULL == data) { + return -1; + } + + size = ompi_comm_size(comm); + int rank = ompi_comm_rank(comm); + ompi_datatype_type_size(dtype, &dsize); + total_dsize = opal_datatype_span(&dtype->super, count, &gap); + + int l1_gp_size = data->l1_gp_size; + int *l1_gp = data->l1_gp; + int *l2_gp = data->l2_gp; + int l2_gp_size = data->l2_gp_size; + + int l1_local_rank = data->l1_local_rank; + int l2_local_rank = data->l2_local_rank; + + char *tmp_sbuf = NULL; + char *tmp_rbuf = NULL; + + if (subc->xpmem_use_sr_buf == 0) { + tmp_rbuf = (char *) data->scratch; + tmp_sbuf = (char 
*) data->scratch + (subc->xpmem_buf_size) / 2; + if ((sbuf == MPI_IN_PLACE) && (rank == root)) { + memcpy(tmp_sbuf, rbuf, total_dsize); + } else { + memcpy(tmp_sbuf, sbuf, total_dsize); + } + } else { + tmp_sbuf = (char *) sbuf; + if ((sbuf == MPI_IN_PLACE) && (rank == root)) { + tmp_sbuf = (char *) rbuf; + } + + if (rank == root) { + tmp_rbuf = rbuf; + } else { + tmp_rbuf = (char *) coll_acoll_buf_alloc(reserve_mem_rbuf_reduce, total_dsize); + if (NULL == tmp_rbuf) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + } + } + void *sbuf_vaddr[1] = {tmp_sbuf}; + void *rbuf_vaddr[1] = {tmp_rbuf}; + + int ret; + + ret = comm->c_coll->coll_allgather(sbuf_vaddr, sizeof(void *), MPI_BYTE, data->allshm_sbuf, + sizeof(void *), MPI_BYTE, comm, + comm->c_coll->coll_allgather_module); + if (ret != MPI_SUCCESS) { + return ret; + } + ret = comm->c_coll->coll_allgather(rbuf_vaddr, sizeof(void *), MPI_BYTE, data->allshm_rbuf, + sizeof(void *), MPI_BYTE, comm, + comm->c_coll->coll_allgather_module); + + if (ret != MPI_SUCCESS) { + return ret; + } + + register_and_cache(size, total_dsize, rank, data); + + /* reduce to the group leader */ + int chunk = count / l1_gp_size; + int my_count_size = (l1_local_rank == (l1_gp_size - 1)) ? chunk + count % l1_gp_size : chunk; + + if (rank == l1_gp[0]) { + if (sbuf != MPI_IN_PLACE) + memcpy(tmp_rbuf, sbuf, my_count_size * dsize); + for (int i = 1; i < l1_gp_size; i++) { + ompi_op_reduce(op, (char *) data->xpmem_saddr[l1_gp[i]] + chunk * l1_local_rank * dsize, + (char *) tmp_rbuf + chunk * l1_local_rank * dsize, my_count_size, dtype); + } + } else { + ompi_3buff_op_reduce(op, + (char *) data->xpmem_saddr[l1_gp[0]] + chunk * l1_local_rank * dsize, + (char *) tmp_sbuf + chunk * l1_local_rank * dsize, + (char *) data->xpmem_raddr[l1_gp[0]] + chunk * l1_local_rank * dsize, + my_count_size, dtype); + for (int i = 1; i < l1_gp_size; i++) { + if (i == l1_local_rank) { + continue; + } + ompi_op_reduce(op, (char *) data->xpmem_saddr[l1_gp[i]] + chunk * l1_local_rank * dsize, + (char *) data->xpmem_raddr[l1_gp[0]] + chunk * l1_local_rank * dsize, + my_count_size, dtype); + } + } + ompi_coll_base_barrier_intra_tree(comm, module); + + /* perform reduce to 0 */ + int local_size = l2_gp_size; + if ((rank == l1_gp[0]) && (local_size > 1)) { + chunk = count / local_size; + my_count_size = (l2_local_rank == (local_size - 1)) ? 
chunk + (count % local_size) : chunk; + + if (l2_local_rank == 0) { + for (int i = 1; i < local_size; i++) { + ompi_op_reduce(op, (char *) data->xpmem_raddr[l2_gp[i]], (char *) tmp_rbuf, + my_count_size, dtype); + } + } else { + for (int i = 1; i < local_size; i++) { + if (i == l2_local_rank) { + continue; + } + ompi_op_reduce(op, + (char *) data->xpmem_raddr[l2_gp[i]] + chunk * l2_local_rank * dsize, + (char *) data->xpmem_raddr[0] + chunk * l2_local_rank * dsize, + my_count_size, dtype); + } + ompi_op_reduce(op, (char *) tmp_rbuf + chunk * l2_local_rank * dsize, + (char *) data->xpmem_raddr[0] + chunk * l2_local_rank * dsize, + my_count_size, dtype); + } + } + ompi_coll_base_barrier_intra_tree(comm, module); + if (subc->xpmem_use_sr_buf == 0) { + if (rank == root) { + memcpy(rbuf, tmp_rbuf, total_dsize); + } + } else { + if ((rank != root) && (subc->xpmem_use_sr_buf != 0)) { + coll_acoll_buf_free(reserve_mem_rbuf_reduce, tmp_rbuf); + } + } + + return MPI_SUCCESS; +} +#endif + +int mca_coll_acoll_reduce_intra(const void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root, + struct ompi_communicator_t *comm, mca_coll_base_module_t *module) +{ + int size, alg; + int num_nodes, ret; + size_t total_dsize, dsize; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + + size = ompi_comm_size(comm); + if (size < 4) + return ompi_coll_base_reduce_intra_basic_linear(sbuf, rbuf, count, dtype, op, root, comm, + module); + + /* Falling back to inorder binary for non-commutative operators to be safe */ + if (!ompi_op_is_commute(op)) { + return ompi_coll_base_reduce_intra_in_order_binary(sbuf, rbuf, count, dtype, op, root, comm, + module, 0, 0); + } + if (root != 0) { // ToDo: support non-zero root + return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, root, comm, + module, 0, 0); + } + + ompi_datatype_type_size(dtype, &dsize); + total_dsize = dsize * (unsigned long) count; + + alg = coll_reduce_decision_fixed(size, total_dsize); + + coll_acoll_subcomms_t *subc; + int cid = ompi_comm_get_local_cid(comm); + + /* Fallback to binomial if cid is beyond supported limit */ + if (cid >= MCA_COLL_ACOLL_MAX_CID) { + return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, root, comm, + module, 0, 0); + } + + subc = &acoll_module->subc[cid]; + if (!subc->initialized || (root != subc->prev_init_root)) { + ret = mca_coll_acoll_comm_split_init(comm, acoll_module, 0); + if (MPI_SUCCESS != ret) { + return ret; + } + } + + num_nodes = subc->num_nodes; + + if (num_nodes == 1) { + if (total_dsize < 262144) { + if (alg == -1 /* topo-aware path disabled due to interaction with xpmem implementation causing issues */) { + return coll_acoll_reduce_topo(sbuf, rbuf, count, dtype, op, root, comm, module); + } else if (alg == 1) { + return ompi_coll_base_reduce_intra_basic_linear(sbuf, rbuf, count, dtype, op, root, + comm, module); + } else if (alg == 2) { + return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, root, + comm, module, 0, 0); + } else { /*(alg == 3)*/ + return ompi_coll_base_reduce_intra_in_order_binary(sbuf, rbuf, count, dtype, op, + root, comm, module, 0, 0); + } + } else { +#ifdef HAVE_XPMEM_H + if ((((subc->xpmem_use_sr_buf != 0) + && (acoll_module->reserve_mem_s).reserve_mem_allocate + && ((acoll_module->reserve_mem_s).reserve_mem_size >= total_dsize)) + || ((subc->xpmem_use_sr_buf == 0) && (subc->xpmem_buf_size > 2 * total_dsize))) + && (subc->without_xpmem != 1)) { + return 
mca_coll_acoll_reduce_xpmem(sbuf, rbuf, count, dtype, op, root, comm, + module); + } else { + return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, + root, comm, module, 0, 0); + } +#else + return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, root, + comm, module, 0, 0); +#endif + } + } else { + return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, root, comm, + module, 0, 0); + } + return MPI_SUCCESS; +} diff --git a/ompi/mca/coll/acoll/coll_acoll_utils.h b/ompi/mca/coll/acoll/coll_acoll_utils.h new file mode 100644 index 00000000000..e69c6dfa859 --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_utils.h @@ -0,0 +1,769 @@ +/* -*- Mode: C; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + +#include "mpi.h" +#include "ompi/communicator/communicator.h" +#include "ompi/mca/coll/base/coll_base_functions.h" +#include "opal/include/opal/align.h" + +/* Function to allocate scratch buffer */ +static inline void *coll_acoll_buf_alloc(coll_acoll_reserve_mem_t *reserve_mem_ptr, uint64_t size) +{ + void *temp_ptr = NULL; + /* If requested size is within the pre-allocated range, use the + pre-allocated buffer if not in use. */ + if ((true == reserve_mem_ptr->reserve_mem_allocate) + && (size <= reserve_mem_ptr->reserve_mem_size) + && (false == reserve_mem_ptr->reserve_mem_in_use)) { + if (NULL == reserve_mem_ptr->reserve_mem) { + reserve_mem_ptr->reserve_mem = malloc(reserve_mem_ptr->reserve_mem_size); + } + temp_ptr = reserve_mem_ptr->reserve_mem; + + /* Mark the buffer as "in use" */ + if (NULL != temp_ptr) { + reserve_mem_ptr->reserve_mem_in_use = true; + } + } else { + /* If the requested size is greater than that of the pre-allocated + buffer or if the pre-allocated buffer is in use, create a new buffer */ + temp_ptr = malloc(size); + } + + return temp_ptr; +} + +/* Function to free scratch buffer */ +static inline void coll_acoll_buf_free(coll_acoll_reserve_mem_t *reserve_mem_ptr, void *ptr) +{ + /* Free the buffer only if it is not the reserved (pre-allocated) one */ + if ((false == reserve_mem_ptr->reserve_mem_allocate) + || (false == reserve_mem_ptr->reserve_mem_in_use)) { + if (NULL != ptr) { + free(ptr); + } + } else if (reserve_mem_ptr->reserve_mem == ptr) { + /* Mark the reserved buffer as free to be used */ + reserve_mem_ptr->reserve_mem_in_use = false; + } +} + +/* Function to compare integer elements */ +static int compare_values(const void *ptra, const void *ptrb) +{ + int a = *((int *) ptra); + int b = *((int *) ptrb); + + if (a < b) { + return -1; + } else if (a > b) { + return 1; + } + + return 0; +} + +/* Function to map ranks from parent communicator to sub-communicator */ +static inline int comm_grp_ranks_local(ompi_communicator_t *comm, ompi_communicator_t *local_comm, + int *is_root_node, int *local_root, int **ranks_buf, + int root) +{ + ompi_group_t *local_grp, *grp; + int local_size = ompi_comm_size(local_comm); + int *ranks = malloc(local_size * sizeof(int)); + int *local_ranks = malloc(local_size * sizeof(int)); + int i, err; + + /* Create parent (comm) and sub-comm (local_comm) groups */ + err = ompi_comm_group(comm, &grp); + err = ompi_comm_group(local_comm, &local_grp); + /* Initialize ranks for sub-communicator (local_comm) */ + for (i = 0; i < local_size; i++) { + local_ranks[i] = i; + } + + /* Translate the ranks between the two communicators */ + err = 
ompi_group_translate_ranks(local_grp, local_size, local_ranks, grp, ranks); + if (ranks_buf != NULL) { + *ranks_buf = malloc(local_size * sizeof(int)); + memcpy(*ranks_buf, ranks, local_size * sizeof(int)); + } + + /* Derive the 'local_root' which is the equivalent rank for 'root' of + 'comm' in 'local_comm' */ + for (i = 0; i < local_size; i++) { + if (ranks[i] == root) { + *is_root_node = 1; + *local_root = i; + break; + } + } + + err = ompi_group_free(&grp); + err = ompi_group_free(&local_grp); + free(ranks); + free(local_ranks); + + return err; +} + +static inline int mca_coll_acoll_create_base_comm(ompi_communicator_t **parent_comm, + coll_acoll_subcomms_t *subc, int color, int rank, + int *root, int base_lyr) +{ + int i; + int err; + + for (i = 0; i < MCA_COLL_ACOLL_NUM_LAYERS; i++) { + int is_root_node = 0; + + /* Create base comm */ + err = ompi_comm_split(parent_comm[i], color, rank, &subc->base_comm[base_lyr][i], false); + if (MPI_SUCCESS != err) + return err; + + /* Find out local rank of root in base comm */ + err = comm_grp_ranks_local(parent_comm[i], subc->base_comm[base_lyr][i], &is_root_node, + &subc->base_root[base_lyr][i], NULL, root[i]); + } + return err; +} + +static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm, + mca_coll_acoll_module_t *acoll_module, int root) +{ + opal_info_t comm_info; + mca_coll_base_module_allreduce_fn_t coll_allreduce_org = (comm)->c_coll->coll_allreduce; + mca_coll_base_module_allgather_fn_t coll_allgather_org = (comm)->c_coll->coll_allgather; + mca_coll_base_module_bcast_fn_t coll_bcast_org = (comm)->c_coll->coll_bcast; + mca_coll_base_module_allreduce_fn_t coll_allreduce_loc, coll_allreduce_soc; + mca_coll_base_module_allgather_fn_t coll_allgather_loc, coll_allgather_soc; + mca_coll_base_module_bcast_fn_t coll_bcast_loc, coll_bcast_soc; + coll_acoll_subcomms_t *subc; + int err; + int size = ompi_comm_size(comm); + int rank = ompi_comm_rank(comm); + int cid = ompi_comm_get_local_cid(comm); + if (cid >= MCA_COLL_ACOLL_MAX_CID) { + return MPI_SUCCESS; + } + + /* Derive subcomm structure */ + subc = &acoll_module->subc[cid]; + subc->cid = cid; + subc->orig_comm = comm; + + (comm)->c_coll->coll_allgather = ompi_coll_base_allgather_intra_ring; + (comm)->c_coll->coll_allreduce = ompi_coll_base_allreduce_intra_recursivedoubling; + (comm)->c_coll->coll_bcast = ompi_coll_base_bcast_intra_basic_linear; + if (!subc->initialized) { + OBJ_CONSTRUCT(&comm_info, opal_info_t); + opal_info_set(&comm_info, "ompi_comm_coll_preference", "libnbc,basic,^acoll"); + /* Create node-level subcommunicator */ + err = ompi_comm_split_type(comm, MPI_COMM_TYPE_SHARED, 0, &comm_info, &(subc->local_comm)); + if (MPI_SUCCESS != err) { + return err; + } + /* Create socket-level subcommunicator */ + err = ompi_comm_split_type(comm, OMPI_COMM_TYPE_SOCKET, 0, &comm_info, + &(subc->socket_comm)); + if (MPI_SUCCESS != err) { + return err; + } + OBJ_DESTRUCT(&comm_info); + OBJ_CONSTRUCT(&comm_info, opal_info_t); + opal_info_set(&comm_info, "ompi_comm_coll_preference", "libnbc,basic,^acoll"); + /* Create subgroup-level subcommunicator */ + err = ompi_comm_split_type(comm, OMPI_COMM_TYPE_L3CACHE, 0, &comm_info, + &(subc->subgrp_comm)); + if (MPI_SUCCESS != err) { + return err; + } + err = ompi_comm_split_type(comm, OMPI_COMM_TYPE_NUMA, 0, &comm_info, &(subc->numa_comm)); + if (MPI_SUCCESS != err) { + return err; + } + subc->subgrp_size = ompi_comm_size(subc->subgrp_comm); + OBJ_DESTRUCT(&comm_info); + + /* Derive the no. 
of nodes */ + if (size == ompi_comm_size(subc->local_comm)) { + subc->num_nodes = 1; + } else { + int *size_list_buf = (int *) malloc(size * sizeof(int)); + int num_nodes = 0; + int local_size = ompi_comm_size(subc->local_comm); + /* Perform allgather so that all ranks know the sizes of the nodes + to which all other ranks belong */ + err = (comm)->c_coll->coll_allgather(&local_size, 1, MPI_INT, size_list_buf, 1, MPI_INT, + comm, &acoll_module->super); + if (MPI_SUCCESS != err) { + free(size_list_buf); + return err; + } + /* Find the no. of nodes by counting each node only once. + * E.g., if there are 3 nodes with 2, 3 and 4 ranks on each node, + * first sort the size array so that the array elements are + * {2,2,3,3,3,4,4,4,4}. Read the value at the start of the array, + * offset the array by the read value, increment the counter, + * and repeat the process till end of array is reached. */ + qsort(size_list_buf, size, sizeof(int), compare_values); + for (int i = 0; i < size;) { + int ofst = size_list_buf[i]; + num_nodes++; + i += ofst; + } + subc->num_nodes = num_nodes; + free(size_list_buf); + } + } + /* Common initializations */ + { + subc->outer_grp_root = -1; + subc->subgrp_root = 0; + subc->is_root_sg = 0; + subc->is_root_numa = 0; + subc->numa_root = 0; + subc->is_root_socket = 0; + subc->socket_ldr_root = -1; + + if (subc->initialized) { + if (subc->num_nodes > 1) { + ompi_comm_free(&(subc->leader_comm)); + subc->leader_comm = NULL; + } + ompi_comm_free(&(subc->socket_ldr_comm)); + subc->socket_ldr_comm = NULL; + } + for (int i = 0; i < MCA_COLL_ACOLL_NUM_LAYERS; i++) { + if (subc->initialized) { + ompi_comm_free(&(subc->base_comm[MCA_COLL_ACOLL_L3CACHE][i])); + subc->base_comm[MCA_COLL_ACOLL_L3CACHE][i] = NULL; + ompi_comm_free(&(subc->base_comm[MCA_COLL_ACOLL_NUMA][i])); + subc->base_comm[MCA_COLL_ACOLL_NUMA][i] = NULL; + } + subc->base_root[MCA_COLL_ACOLL_L3CACHE][i] = -1; + subc->base_root[MCA_COLL_ACOLL_NUMA][i] = -1; + } + /* Store original collectives for local and socket comms */ + coll_allreduce_loc = (subc->local_comm)->c_coll->coll_allreduce; + coll_allgather_loc = (subc->local_comm)->c_coll->coll_allgather; + coll_bcast_loc = (subc->local_comm)->c_coll->coll_bcast; + (subc->local_comm)->c_coll->coll_allgather = ompi_coll_base_allgather_intra_ring; + (subc->local_comm)->c_coll->coll_allreduce + = ompi_coll_base_allreduce_intra_recursivedoubling; + (subc->local_comm)->c_coll->coll_bcast = ompi_coll_base_bcast_intra_basic_linear; + coll_allreduce_soc = (subc->socket_comm)->c_coll->coll_allreduce; + coll_allgather_soc = (subc->socket_comm)->c_coll->coll_allgather; + coll_bcast_soc = (subc->socket_comm)->c_coll->coll_bcast; + (subc->socket_comm)->c_coll->coll_allgather = ompi_coll_base_allgather_intra_ring; + (subc->socket_comm)->c_coll->coll_allreduce + = ompi_coll_base_allreduce_intra_recursivedoubling; + (subc->socket_comm)->c_coll->coll_bcast = ompi_coll_base_bcast_intra_basic_linear; + } + + /* Further subcommunicators based on root */ + if (subc->num_nodes > 1) { + int local_rank = ompi_comm_rank(subc->local_comm); + int color = MPI_UNDEFINED; + int is_root_node = 0, is_root_socket = 0; + int local_root = 0; + int *subgrp_ranks = NULL, *numa_ranks = NULL, *socket_ranks = NULL; + ompi_communicator_t *parent_comm[MCA_COLL_ACOLL_NUM_LAYERS]; + + /* Initializations */ + subc->local_root[MCA_COLL_ACOLL_LYR_NODE] = 0; + subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET] = 0; + + /* Find out the local rank of root */ + err = comm_grp_ranks_local(comm, subc->local_comm, 
&subc->is_root_node, + &subc->local_root[MCA_COLL_ACOLL_LYR_NODE], NULL, root); + + /* Create subcommunicator with leader ranks */ + color = 1; + if (!subc->is_root_node && (local_rank == 0)) { + color = 0; + } + if (rank == root) { + color = 0; + } + err = ompi_comm_split(comm, color, rank, &subc->leader_comm, false); + if (MPI_SUCCESS != err) { + return err; + } + + /* Find out local rank of root in leader comm */ + err = comm_grp_ranks_local(comm, subc->leader_comm, &is_root_node, &subc->outer_grp_root, + NULL, root); + + /* Find out local rank of root in socket comm */ + if (subc->is_root_node) { + local_root = subc->local_root[MCA_COLL_ACOLL_LYR_NODE]; + } + err = comm_grp_ranks_local(subc->local_comm, subc->socket_comm, &subc->is_root_socket, + &subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET], &socket_ranks, + local_root); + + /* Create subcommunicator with socket leaders */ + subc->socket_rank = subc->is_root_socket == 1 ? local_root : socket_ranks[0]; + color = local_rank == subc->socket_rank ? 0 : 1; + err = ompi_comm_split(subc->local_comm, color, local_rank, &subc->socket_ldr_comm, false); + if (MPI_SUCCESS != err) + return err; + + /* Find out local rank of root in socket leader comm */ + err = comm_grp_ranks_local(subc->local_comm, subc->socket_ldr_comm, &is_root_socket, + &subc->socket_ldr_root, NULL, local_root); + + /* Find out local rank of root in subgroup comm */ + err = comm_grp_ranks_local(subc->local_comm, subc->subgrp_comm, &subc->is_root_sg, + &subc->subgrp_root, &subgrp_ranks, local_root); + + /* Create subcommunicator with base ranks */ + subc->base_rank[MCA_COLL_ACOLL_L3CACHE] = subc->is_root_sg == 1 ? local_root + : subgrp_ranks[0]; + color = local_rank == subc->base_rank[MCA_COLL_ACOLL_L3CACHE] ? 0 : 1; + parent_comm[MCA_COLL_ACOLL_LYR_NODE] = subc->local_comm; + parent_comm[MCA_COLL_ACOLL_LYR_SOCKET] = subc->socket_comm; + err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, local_rank, + subc->local_root, MCA_COLL_ACOLL_L3CACHE); + + /* Find out local rank of root in numa comm */ + err = comm_grp_ranks_local(subc->local_comm, subc->numa_comm, &subc->is_root_numa, + &subc->numa_root, &numa_ranks, local_root); + + subc->base_rank[MCA_COLL_ACOLL_NUMA] = subc->is_root_numa == 1 ? local_root : numa_ranks[0]; + color = local_rank == subc->base_rank[MCA_COLL_ACOLL_NUMA] ? 0 : 1; + err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, local_rank, + subc->local_root, MCA_COLL_ACOLL_NUMA); + + if (socket_ranks != NULL) { + free(socket_ranks); + socket_ranks = NULL; + } + if (subgrp_ranks != NULL) { + free(subgrp_ranks); + subgrp_ranks = NULL; + } + if (numa_ranks != NULL) { + free(numa_ranks); + numa_ranks = NULL; + } + } else { + /* Intra node case */ + int color; + int is_root_socket = 0; + int *subgrp_ranks = NULL, *numa_ranks = NULL, *socket_ranks = NULL; + ompi_communicator_t *parent_comm[MCA_COLL_ACOLL_NUM_LAYERS]; + + /* Initializations */ + subc->local_root[MCA_COLL_ACOLL_LYR_NODE] = root; + subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET] = 0; + + /* Find out local rank of root in socket comm */ + err = comm_grp_ranks_local(comm, subc->socket_comm, &subc->is_root_socket, + &subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET], &socket_ranks, + root); + + /* Create subcommunicator with socket leaders */ + subc->socket_rank = subc->is_root_socket == 1 ? root : socket_ranks[0]; + color = rank == subc->socket_rank ? 
0 : 1;
+        err = ompi_comm_split(comm, color, rank, &subc->socket_ldr_comm, false);
+        if (MPI_SUCCESS != err) {
+            return err;
+        }
+
+        /* Find out local rank of root in socket leader comm */
+        err = comm_grp_ranks_local(comm, subc->socket_ldr_comm, &is_root_socket,
+                                   &subc->socket_ldr_root, NULL, root);
+
+        /* Find out local rank of root in subgroup comm */
+        err = comm_grp_ranks_local(comm, subc->subgrp_comm, &subc->is_root_sg, &subc->subgrp_root,
+                                   &subgrp_ranks, root);
+
+        /* Create subcommunicator with base ranks */
+        subc->base_rank[MCA_COLL_ACOLL_L3CACHE] = subc->is_root_sg == 1 ? root : subgrp_ranks[0];
+        color = rank == subc->base_rank[MCA_COLL_ACOLL_L3CACHE] ? 0 : 1;
+        parent_comm[MCA_COLL_ACOLL_LYR_NODE] = subc->local_comm;
+        parent_comm[MCA_COLL_ACOLL_LYR_SOCKET] = subc->socket_comm;
+        err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, rank, subc->local_root,
+                                              MCA_COLL_ACOLL_L3CACHE);
+        if (MPI_SUCCESS != err) {
+            return err;
+        }
+
+        /* Create subcommunicator with NUMA leaders */
+        int numa_rank = ompi_comm_rank(subc->numa_comm);
+        color = (numa_rank == 0) ? 0 : 1;
+        err = ompi_comm_split(subc->local_comm, color, rank, &subc->numa_comm_ldrs, false);
+        if (MPI_SUCCESS != err) {
+            return err;
+        }
+
+        /* Find out local rank of root in numa comm */
+        err = comm_grp_ranks_local(comm, subc->numa_comm, &subc->is_root_numa, &subc->numa_root,
+                                   &numa_ranks, root);
+
+        subc->base_rank[MCA_COLL_ACOLL_NUMA] = subc->is_root_numa == 1 ? root : numa_ranks[0];
+        color = rank == subc->base_rank[MCA_COLL_ACOLL_NUMA] ? 0 : 1;
+        err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, rank, subc->local_root,
+                                              MCA_COLL_ACOLL_NUMA);
+        if (MPI_SUCCESS != err) {
+            return err;
+        }
+
+        if (socket_ranks != NULL) {
+            free(socket_ranks);
+            socket_ranks = NULL;
+        }
+        if (subgrp_ranks != NULL) {
+            free(subgrp_ranks);
+            subgrp_ranks = NULL;
+        }
+        if (numa_ranks != NULL) {
+            free(numa_ranks);
+            numa_ranks = NULL;
+        }
+    }
+
+    /* Restore originals for local and socket comms */
+    (subc->local_comm)->c_coll->coll_allreduce = coll_allreduce_loc;
+    (subc->local_comm)->c_coll->coll_allgather = coll_allgather_loc;
+    (subc->local_comm)->c_coll->coll_bcast = coll_bcast_loc;
+    (subc->socket_comm)->c_coll->coll_allreduce = coll_allreduce_soc;
+    (subc->socket_comm)->c_coll->coll_allgather = coll_allgather_soc;
+    (subc->socket_comm)->c_coll->coll_bcast = coll_bcast_soc;
+
+    /* For collectives where order is important (like gather, allgather),
+     * split based on ranks. This is optimal for global communicators with
+     * equal split among nodes, but suboptimal for other cases. 
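+     * For example, with 16 ranks spread evenly across 2 nodes, the derived
+     * node_size below is (16 + 2 - 1) / 2 = 8, so ranks 0-7 take color 0 and
+     * ranks 8-15 take color 1 when local_r_comm is formed.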
+ */
+    if (!subc->initialized) {
+        if (subc->num_nodes > 1) {
+            int node_size = (size + subc->num_nodes - 1) / subc->num_nodes;
+            int color = rank / node_size;
+            err = ompi_comm_split(comm, color, rank, &subc->local_r_comm, false);
+            if (MPI_SUCCESS != err) {
+                return err;
+            }
+        }
+        subc->derived_node_size = (size + subc->num_nodes - 1) / subc->num_nodes;
+    }
+
+    /* Restore originals */
+    (comm)->c_coll->coll_allreduce = coll_allreduce_org;
+    (comm)->c_coll->coll_allgather = coll_allgather_org;
+    (comm)->c_coll->coll_bcast = coll_bcast_org;
+
+    /* Init done */
+    subc->initialized = 1;
+    if (root != subc->prev_init_root) {
+        subc->num_root_change++;
+    }
+    subc->prev_init_root = root;
+
+    return err;
+}
+
+#ifdef HAVE_XPMEM_H
+/* rcache registration hook: attach the exporting rank's XPMEM segment so
+ * that the range 'base' .. 'base' + 'size' becomes locally addressable */
+static inline int mca_coll_acoll_xpmem_register(void *xpmem_apid, void *base, size_t size,
+                                                mca_rcache_base_registration_t *reg)
+{
+    struct xpmem_addr xpmem_addr;
+    xpmem_addr.apid = *((xpmem_apid_t *) xpmem_apid);
+    xpmem_addr.offset = (uintptr_t) base;
+    struct acoll_xpmem_rcache_reg_t *xpmem_reg = (struct acoll_xpmem_rcache_reg_t *) reg;
+    xpmem_reg->xpmem_vaddr = xpmem_attach(xpmem_addr, size, NULL);
+
+    if ((void *) -1 == xpmem_reg->xpmem_vaddr) {
+        return -1;
+    }
+    return 0;
+}
+
+/* rcache deregistration hook: detach a previously attached XPMEM range */
+static inline int mca_coll_acoll_xpmem_deregister(void *xpmem_apid,
+                                                  mca_rcache_base_registration_t *reg)
+{
+    return xpmem_detach(((struct acoll_xpmem_rcache_reg_t *) reg)->xpmem_vaddr);
+}
+#endif
+
+/* One-time, per-communicator setup of the XPMEM registrations and the
+ * shared-memory segment used by the acoll collectives */
+static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communicator_t *comm,
+                                  coll_acoll_data_t *data)
+{
+    int size, ret = 0, rank, line;
+
+    mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;
+    coll_acoll_subcomms_t *subc;
+    int cid = ompi_comm_get_local_cid(comm);
+    subc = &acoll_module->subc[cid];
+    if (subc->initialized_data) {
+        return ret;
+    }
+    subc->cid = cid;
+    /* Zero-initialize so that the error path below can free the pointer
+     * members unconditionally */
+    data = (coll_acoll_data_t *) calloc(1, sizeof(coll_acoll_data_t));
+    if (NULL == data) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    size = ompi_comm_size(comm);
+    rank = ompi_comm_rank(comm);
+    data->comm_size = size;
+
+#ifdef HAVE_XPMEM_H
+    if (subc->xpmem_use_sr_buf == 0) {
+        data->scratch = (char *) malloc(subc->xpmem_buf_size);
+        if (NULL == data->scratch) {
+            line = __LINE__;
+            ret = OMPI_ERR_OUT_OF_RESOURCE;
+            goto error_hndl;
+        }
+    } else {
+        data->scratch = NULL;
+    }
+
+    xpmem_segid_t seg_id;
+    data->allseg_id = (xpmem_segid_t *) malloc(sizeof(xpmem_segid_t) * size);
+    if (NULL == data->allseg_id) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    data->all_apid = (xpmem_apid_t *) malloc(sizeof(xpmem_apid_t) * size);
+    if (NULL == data->all_apid) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    data->allshm_sbuf = (void **) malloc(sizeof(void *) * size);
+    if (NULL == data->allshm_sbuf) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    data->allshm_rbuf = (void **) malloc(sizeof(void *) * size);
+    if (NULL == data->allshm_rbuf) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    data->xpmem_saddr = (void **) malloc(sizeof(void *) * size);
+    if (NULL == data->xpmem_saddr) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    data->xpmem_raddr = (void **) malloc(sizeof(void *) * size);
+    if (NULL == data->xpmem_raddr) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    data->rcache = (mca_rcache_base_module_t **) 
malloc(sizeof(mca_rcache_base_module_t *) * size);
+    if (NULL == data->rcache) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    seg_id = xpmem_make(0, XPMEM_MAXADDR_SIZE, XPMEM_PERMIT_MODE, (void *) 0666);
+    if (seg_id == -1) {
+        line = __LINE__;
+        ret = -1;
+        goto error_hndl;
+    }
+
+    /* Exchange XPMEM segment ids so that every rank can attach to its peers */
+    ret = comm->c_coll->coll_allgather(&seg_id, sizeof(xpmem_segid_t), MPI_BYTE, data->allseg_id,
+                                       sizeof(xpmem_segid_t), MPI_BYTE, comm,
+                                       comm->c_coll->coll_allgather_module);
+    if (MPI_SUCCESS != ret) {
+        line = __LINE__;
+        goto error_hndl;
+    }
+
+    /* snprintf truncates the rcache name if it exceeds rc_name */
+    char rc_name[50];
+    for (int i = 0; i < size; i++) {
+        if (rank != i) {
+            data->all_apid[i] = xpmem_get(data->allseg_id[i], XPMEM_RDWR, XPMEM_PERMIT_MODE,
+                                          (void *) 0666);
+            if (data->all_apid[i] == -1) {
+                line = __LINE__;
+                ret = -1;
+                goto error_hndl;
+            }
+            snprintf(rc_name, sizeof(rc_name), "acoll_%d_%d_%d", cid, rank, i);
+            mca_rcache_base_resources_t rcache_element
+                = {.cache_name = rc_name,
+                   .reg_data = &data->all_apid[i],
+                   .sizeof_reg = sizeof(struct acoll_xpmem_rcache_reg_t),
+                   .register_mem = mca_coll_acoll_xpmem_register,
+                   .deregister_mem = mca_coll_acoll_xpmem_deregister};
+
+            data->rcache[i] = mca_rcache_base_module_create("grdma", NULL, &rcache_element);
+            if (data->rcache[i] == NULL) {
+                ret = -1;
+                line = __LINE__;
+                goto error_hndl;
+            }
+        }
+    }
+#endif
+
+    /* temporary variables */
+    int tmp1, tmp2, tmp3 = 0;
+    comm_grp_ranks_local(comm, subc->numa_comm, &tmp1, &tmp2, &data->l1_gp, tmp3);
+    data->l1_gp_size = ompi_comm_size(subc->numa_comm);
+    data->l1_local_rank = ompi_comm_rank(subc->numa_comm);
+
+    comm_grp_ranks_local(comm, subc->numa_comm_ldrs, &tmp1, &tmp2, &data->l2_gp, tmp3);
+    data->l2_gp_size = ompi_comm_size(subc->numa_comm_ldrs);
+    data->l2_local_rank = ompi_comm_rank(subc->numa_comm_ldrs);
+
+    /* Offsets into the shared segment: 16 KB leader scratch, per-rank L1
+     * sync flags, per-rank L2 sync flags, then this rank's 8 KB data slot */
+    data->offset[0] = 16 * 1024;
+    data->offset[1] = data->offset[0] + size * 64;
+    data->offset[2] = data->offset[1] + size * 64;
+    data->offset[3] = data->offset[2] + rank * 8 * 1024;
+    data->allshmseg_id = (opal_shmem_ds_t *) malloc(sizeof(opal_shmem_ds_t) * size);
+    if (NULL == data->allshmseg_id) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    data->allshmmmap_sbuf = (void **) malloc(sizeof(void *) * size);
+    if (NULL == data->allshmmmap_sbuf) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    data->sync[0] = 0;
+    data->sync[1] = 0;
+    char *shfn = NULL;
+
+    /* Only the leaders need to allocate shared memory;
+     * the remaining ranks move their data into their leader's shm */
+    if (data->l1_gp[0] == rank) {
+        subc->initialized_shm_data = true;
+        ret = asprintf(&shfn, "/dev/shm/acoll_coll_shmem_seg.%u.%x.%d:%d-%d", geteuid(),
+                       OPAL_PROC_MY_NAME.jobid, ompi_comm_rank(MPI_COMM_WORLD),
+                       ompi_comm_get_local_cid(comm), ompi_comm_size(comm));
+    }
+
+    if (ret < 0) {
+        line = __LINE__;
+        goto error_hndl;
+    }
+
+    opal_shmem_ds_t seg_ds;
+    if (data->l1_gp[0] == rank) {
+        /* Assuming the cacheline size is 64 bytes */
+        long memsize
+            = (16 * 1024 /* scratch leader */ + 64 * size /* sync variables l1 group*/
+               + 64 * size /* sync variables l2 group*/ + 8 * 1024 * size /*data from ranks*/);
+        ret = opal_shmem_segment_create(&seg_ds, shfn, memsize);
+        free(shfn);
+    }
+
+    if (ret != OPAL_SUCCESS) {
+        opal_output_verbose(MCA_BASE_VERBOSE_ERROR, ompi_coll_base_framework.framework_output,
+                            "coll:acoll: Error: Could not create shared memory segment");
+        line = __LINE__;
+        goto error_hndl;
+    }
+
+    /* Exchange the shared-memory segment descriptors */
+    ret = comm->c_coll->coll_allgather(&seg_ds, sizeof(opal_shmem_ds_t), MPI_BYTE,
+                                       data->allshmseg_id, sizeof(opal_shmem_ds_t), MPI_BYTE, comm,
+                                       comm->c_coll->coll_allgather_module);
+    if (MPI_SUCCESS != ret) {
+        line = __LINE__;
+        goto error_hndl;
+    }
+
+    if (data->l1_gp[0] != rank) {
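+        /* Non-leader ranks do not create a shared segment of their own:
+         * they attach only their L1 leader's segment and stage data through
+         * it, while each leader attaches the segments of all its L2
+         * (NUMA leader) peers. */
+        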
data->allshmmmap_sbuf[data->l1_gp[0]] = opal_shmem_segment_attach( + &data->allshmseg_id[data->l1_gp[0]]); + } else { + for (int i = 0; i < data->l2_gp_size; i++) { + data->allshmmmap_sbuf[data->l2_gp[i]] = opal_shmem_segment_attach( + &data->allshmseg_id[data->l2_gp[i]]); + } + } + + int offset = 16 * 1024; + memset(((char *) data->allshmmmap_sbuf[data->l1_gp[0]]) + offset + 64 * rank, 0, 64); + if (data->l1_gp[0] == rank) { + memset(((char *) data->allshmmmap_sbuf[data->l2_gp[0]]) + (offset + 64 * size) + 64 * rank, + 0, 64); + } + + subc->initialized_data = true; + subc->data = data; + ompi_coll_base_barrier_intra_tree(comm, module); + + return MPI_SUCCESS; +error_hndl: + (void) line; + if (NULL != data) { +#ifdef HAVE_XPMEM_H + free(data->allseg_id); + data->allseg_id = NULL; + free(data->all_apid); + data->all_apid = NULL; + free(data->allshm_sbuf); + data->allshm_sbuf = NULL; + free(data->allshm_rbuf); + data->allshm_rbuf = NULL; + free(data->xpmem_saddr); + data->xpmem_saddr = NULL; + free(data->xpmem_raddr); + data->xpmem_raddr = NULL; + free(data->rcache); + data->rcache = NULL; + free(data->scratch); + data->scratch = NULL; +#endif + free(data->allshmseg_id); + data->allshmseg_id = NULL; + free(data->allshmmmap_sbuf); + data->allshmmmap_sbuf = NULL; + free(data->l1_gp); + data->l1_gp = NULL; + free(data->l2_gp); + data->l2_gp = NULL; + free(data); + data = NULL; + } + return ret; +} + +#ifdef HAVE_XPMEM_H +static inline void register_and_cache(int size, size_t total_dsize, int rank, + coll_acoll_data_t *data) +{ + uintptr_t base, bound; + for (int i = 0; i < size; i++) { + if (rank != i) { + mca_rcache_base_module_t *rcache_i = data->rcache[i]; + int access_flags = 0; + struct acoll_xpmem_rcache_reg_t *sbuf_reg = NULL, *rbuf_reg = NULL; + base = OPAL_DOWN_ALIGN((uintptr_t) data->allshm_sbuf[i], 4096, uintptr_t); + bound = OPAL_ALIGN((uintptr_t) data->allshm_sbuf[i] + total_dsize, 4096, uintptr_t); + int ret = rcache_i->rcache_register(rcache_i, (void *) base, bound - base, access_flags, + MCA_RCACHE_ACCESS_ANY, + (mca_rcache_base_registration_t **) &sbuf_reg); + + if (ret != 0) { + sbuf_reg = NULL; + return; + } + data->xpmem_saddr[i] = (void *) ((uintptr_t) sbuf_reg->xpmem_vaddr + + ((uintptr_t) data->allshm_sbuf[i] + - (uintptr_t) sbuf_reg->base.base)); + + base = OPAL_DOWN_ALIGN((uintptr_t) data->allshm_rbuf[i], 4096, uintptr_t); + bound = OPAL_ALIGN((uintptr_t) data->allshm_rbuf[i] + total_dsize, 4096, uintptr_t); + ret = rcache_i->rcache_register(rcache_i, (void *) base, bound - base, access_flags, + MCA_RCACHE_ACCESS_ANY, + (mca_rcache_base_registration_t **) &rbuf_reg); + + if (ret != 0) { + rbuf_reg = NULL; + return; + } + data->xpmem_raddr[i] = (void *) ((uintptr_t) rbuf_reg->xpmem_vaddr + + ((uintptr_t) data->allshm_rbuf[i] + - (uintptr_t) rbuf_reg->base.base)); + } else { + data->xpmem_saddr[i] = data->allshm_sbuf[i]; + data->xpmem_raddr[i] = data->allshm_rbuf[i]; + } + } +} +#endif diff --git a/ompi/mca/coll/acoll/configure.m4 b/ompi/mca/coll/acoll/configure.m4 new file mode 100644 index 00000000000..339b34c567c --- /dev/null +++ b/ompi/mca/coll/acoll/configure.m4 @@ -0,0 +1,18 @@ +# +# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+# $COPYRIGHT$
+#
+# Additional copyrights may follow
+#
+# $HEADER$
+#
+
+AC_DEFUN([MCA_ompi_coll_acoll_CONFIG],[
+    AC_CONFIG_FILES([ompi/mca/coll/acoll/Makefile])
+
+    # XPMEM is optional: the component is built either way, and the check
+    # fills in coll_acoll_CPPFLAGS/LDFLAGS/LIBS when XPMEM is available
+    OPAL_CHECK_XPMEM([coll_acoll], [should_build=1], [should_build=1])
+
+    AC_SUBST([coll_acoll_CPPFLAGS])
+    AC_SUBST([coll_acoll_LDFLAGS])
+    AC_SUBST([coll_acoll_LIBS])
+])dnl
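
The node-count pass in mca_coll_acoll_comm_split_init above can be exercised
on its own. The following minimal, self-contained sketch illustrates that
sort-and-stride technique; the compare_values comparator here is an assumed
stand-in for the component's own, which is defined outside the hunks shown
above.

    #include <stdio.h>
    #include <stdlib.h>

    /* Assumed ascending-order comparator, as qsort requires. */
    static int compare_values(const void *a, const void *b)
    {
        return (*(const int *) a) - (*(const int *) b);
    }

    int main(void)
    {
        /* Per-rank node sizes gathered from 9 ranks on 3 nodes (2+3+4). */
        int size_list[] = {4, 2, 3, 4, 2, 3, 4, 3, 4};
        int size = 9, num_nodes = 0;

        qsort(size_list, size, sizeof(int), compare_values);
        /* Sorted: {2,2,3,3,3,4,4,4,4}. Each node contributes a run whose
         * length equals its own size, so stepping forward by the value at
         * the current index counts every node exactly once. */
        for (int i = 0; i < size;) {
            i += size_list[i];
            num_nodes++;
        }
        printf("num_nodes = %d\n", num_nodes); /* prints 3 */
        return 0;
    }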