diff --git a/drivers/infiniband/core/ucma.c b/drivers/infiniband/core/ucma.c index 4a42a95be6013..0ef2f6a3b8df3 100644 --- a/drivers/infiniband/core/ucma.c +++ b/drivers/infiniband/core/ucma.c @@ -132,7 +132,7 @@ static inline struct ucma_context *_ucma_find_context(int id, ctx = idr_find(&ctx_idr, id); if (!ctx) ctx = ERR_PTR(-ENOENT); - else if (ctx->file != file) + else if (ctx->file != file || !ctx->cm_id) ctx = ERR_PTR(-EINVAL); return ctx; } @@ -454,6 +454,7 @@ static ssize_t ucma_create_id(struct ucma_file *file, const char __user *inbuf, struct rdma_ucm_create_id cmd; struct rdma_ucm_create_id_resp resp; struct ucma_context *ctx; + struct rdma_cm_id *cm_id; enum ib_qp_type qp_type; int ret; @@ -474,10 +475,10 @@ static ssize_t ucma_create_id(struct ucma_file *file, const char __user *inbuf, return -ENOMEM; ctx->uid = cmd.uid; - ctx->cm_id = rdma_create_id(current->nsproxy->net_ns, - ucma_event_handler, ctx, cmd.ps, qp_type); - if (IS_ERR(ctx->cm_id)) { - ret = PTR_ERR(ctx->cm_id); + cm_id = rdma_create_id(current->nsproxy->net_ns, + ucma_event_handler, ctx, cmd.ps, qp_type); + if (IS_ERR(cm_id)) { + ret = PTR_ERR(cm_id); goto err1; } @@ -487,10 +488,12 @@ static ssize_t ucma_create_id(struct ucma_file *file, const char __user *inbuf, ret = -EFAULT; goto err2; } + + ctx->cm_id = cm_id; return 0; err2: - rdma_destroy_id(ctx->cm_id); + rdma_destroy_id(cm_id); err1: mutex_lock(&mut); idr_remove(&ctx_idr, ctx->id);