diff --git a/include/ofi_coll.h b/include/ofi_coll.h index d6532f0dd7d..f5f8d103f86 100644 --- a/include/ofi_coll.h +++ b/include/ofi_coll.h @@ -162,4 +162,8 @@ struct util_coll_operation { uint64_t flags; }; +int coll_cq_init(struct fid_domain *domain, struct fi_cq_attr *attr, + struct fid_cq **cq_fid, ofi_cq_progress_func progress, + void *context); + #endif // _OFI_COLL_H_ diff --git a/prov/coll/src/coll_cq.c b/prov/coll/src/coll_cq.c index 3c279b5113e..cadf5783cf9 100644 --- a/prov/coll/src/coll_cq.c +++ b/prov/coll/src/coll_cq.c @@ -68,18 +68,32 @@ static struct fi_ops_cq coll_cq_ops = { int coll_cq_open(struct fid_domain *domain, struct fi_cq_attr *attr, struct fid_cq **cq_fid, void *context) +{ + return coll_cq_init(domain, attr, cq_fid, &ofi_cq_progress, context); +} + +int coll_cq_init(struct fid_domain *domain, + struct fi_cq_attr *attr, struct fid_cq **cq_fid, + ofi_cq_progress_func progress, void *context) { struct coll_cq *cq; struct fi_peer_cq_context *peer_context = context; int ret; + const struct coll_domain *coll_domain; + const struct fi_provider* provider; + + coll_domain = container_of(domain, struct coll_domain, util_domain.domain_fid.fid); + provider = coll_domain->util_domain.fabric->prov; + + if (!attr || !(attr->flags & FI_PEER)) { - FI_WARN(&coll_prov, FI_LOG_CORE, "FI_PEER flag required\n"); + FI_WARN(provider, FI_LOG_CORE, "FI_PEER flag required\n"); return -EINVAL; } if (!peer_context || peer_context->size < sizeof(*peer_context)) { - FI_WARN(&coll_prov, FI_LOG_CORE, "invalid peer CQ context\n"); + FI_WARN(provider, FI_LOG_CORE, "invalid peer CQ context\n"); return -EINVAL; } @@ -89,7 +103,7 @@ int coll_cq_open(struct fid_domain *domain, struct fi_cq_attr *attr, cq->peer_cq = peer_context->cq; - ret = ofi_cq_init(&coll_prov, domain, attr, &cq->util_cq, &ofi_cq_progress, + ret = ofi_cq_init(provider, domain, attr, &cq->util_cq, &ofi_cq_progress, context); if (ret) goto err;