Now that the udata is passed to all the ib_xxx object creation APIs, and the additional function 'rdma_get_ucontext' can retrieve the ib_ucontext from the ib_udata stored in uverbs_attr_bundle, we can finally start removing the drivers' dependency on ib_xxx->uobject->context. Signed-off-by: Shamir Rabinovitch <shamir.rabinovitch@xxxxxxxxxx> --- drivers/infiniband/hw/bnxt_re/ib_verbs.c | 33 +++++-- drivers/infiniband/hw/cxgb3/iwch_provider.c | 4 +- drivers/infiniband/hw/cxgb4/qp.c | 8 +- drivers/infiniband/hw/hns/hns_roce_qp.c | 15 +++- drivers/infiniband/hw/i40iw/i40iw_verbs.c | 15 +++- drivers/infiniband/hw/mlx4/doorbell.c | 6 ++ drivers/infiniband/hw/mlx4/mr.c | 9 +- drivers/infiniband/hw/mlx4/qp.c | 92 +++++++++++++------- drivers/infiniband/hw/mlx4/srq.c | 13 ++- drivers/infiniband/hw/mlx5/qp.c | 68 +++++++++++---- drivers/infiniband/hw/mlx5/srq.c | 13 ++- drivers/infiniband/hw/mthca/mthca_provider.c | 28 ++++-- drivers/infiniband/hw/mthca/mthca_qp.c | 19 ++-- drivers/infiniband/hw/mthca/mthca_srq.c | 25 ++++-- drivers/infiniband/hw/nes/nes_verbs.c | 13 ++- drivers/infiniband/hw/qedr/verbs.c | 8 +- drivers/infiniband/hw/usnic/usnic_ib_verbs.c | 7 +- drivers/infiniband/sw/rdmavt/qp.c | 10 ++- drivers/infiniband/sw/rdmavt/srq.c | 10 ++- drivers/infiniband/sw/rxe/rxe_qp.c | 5 +- drivers/infiniband/sw/rxe/rxe_verbs.c | 5 +- 21 files changed, 287 insertions(+), 119 deletions(-) diff --git a/drivers/infiniband/hw/bnxt_re/ib_verbs.c b/drivers/infiniband/hw/bnxt_re/ib_verbs.c index 9bc637e49faa..d811efe49e77 100644 --- a/drivers/infiniband/hw/bnxt_re/ib_verbs.c +++ b/drivers/infiniband/hw/bnxt_re/ib_verbs.c @@ -733,11 +733,16 @@ struct ib_ah *bnxt_re_create_ah(struct ib_pd *ib_pd, /* Write AVID to shared page.
*/ if (udata) { - struct ib_ucontext *ib_uctx = ib_pd->uobject->context; + struct ib_ucontext *ib_uctx; struct bnxt_re_ucontext *uctx; unsigned long flag; u32 *wrptr; + ib_uctx = rdma_get_ucontext(udata); + if (IS_ERR(ib_uctx)) { + rc = PTR_ERR(ib_uctx); + goto fail; + } uctx = container_of(ib_uctx, struct bnxt_re_ucontext, ib_uctx); spin_lock_irqsave(&uctx->sh_lock, flag); wrptr = (u32 *)(uctx->shpg + BNXT_RE_AVID_OFFT); @@ -883,10 +888,15 @@ static int bnxt_re_init_user_qp(struct bnxt_re_dev *rdev, struct bnxt_re_pd *pd, struct bnxt_qplib_qp *qplib_qp = &qp->qplib_qp; struct ib_umem *umem; int bytes = 0; - struct ib_ucontext *context = pd->ib_pd.uobject->context; - struct bnxt_re_ucontext *cntx = container_of(context, - struct bnxt_re_ucontext, - ib_uctx); + struct bnxt_re_ucontext *cntx; + struct ib_ucontext *context; + + context = rdma_get_ucontext(udata); + if (IS_ERR(context)) + return PTR_ERR(context); + + cntx = container_of(context, struct bnxt_re_ucontext, ib_uctx); + if (ib_copy_from_udata(&ureq, udata, sizeof(ureq))) return -EFAULT; @@ -1360,10 +1370,15 @@ static int bnxt_re_init_user_srq(struct bnxt_re_dev *rdev, struct bnxt_qplib_srq *qplib_srq = &srq->qplib_srq; struct ib_umem *umem; int bytes = 0; - struct ib_ucontext *context = pd->ib_pd.uobject->context; - struct bnxt_re_ucontext *cntx = container_of(context, - struct bnxt_re_ucontext, - ib_uctx); + struct bnxt_re_ucontext *cntx; + struct ib_ucontext *context; + + context = rdma_get_ucontext(udata); + if (IS_ERR(context)) + return PTR_ERR(context); + + cntx = container_of(context, struct bnxt_re_ucontext, ib_uctx); + if (ib_copy_from_udata(&ureq, udata, sizeof(ureq))) return -EFAULT; diff --git a/drivers/infiniband/hw/cxgb3/iwch_provider.c b/drivers/infiniband/hw/cxgb3/iwch_provider.c index 07c20cd07f33..b692cd2bbe92 100644 --- a/drivers/infiniband/hw/cxgb3/iwch_provider.c +++ b/drivers/infiniband/hw/cxgb3/iwch_provider.c @@ -796,6 +796,7 @@ static struct ib_qp *iwch_create_qp(struct ib_pd *pd, 
struct iwch_cq *schp; struct iwch_cq *rchp; struct iwch_create_qp_resp uresp; + struct ib_ucontext *ib_ucontext; int wqsize, sqsize, rqsize; struct iwch_ucontext *ucontext; @@ -836,7 +837,8 @@ static struct ib_qp *iwch_create_qp(struct ib_pd *pd, * Kernel users need more wq space for fastreg WRs which can take * 2 WR fragments. */ - ucontext = udata ? to_iwch_ucontext(pd->uobject->context) : NULL; + ib_ucontext = rdma_get_ucontext(udata); + ucontext = IS_ERR(ib_ucontext) ? NULL : to_iwch_ucontext(ib_ucontext); if (!ucontext && wqsize < (rqsize + (2 * sqsize))) wqsize = roundup_pow_of_two(rqsize + roundup_pow_of_two(attrs->cap.max_send_wr * 2)); diff --git a/drivers/infiniband/hw/cxgb4/qp.c b/drivers/infiniband/hw/cxgb4/qp.c index 03f4c66c2659..d2c452b4282c 100644 --- a/drivers/infiniband/hw/cxgb4/qp.c +++ b/drivers/infiniband/hw/cxgb4/qp.c @@ -2137,6 +2137,7 @@ struct ib_qp *c4iw_create_qp(struct ib_pd *pd, struct ib_qp_init_attr *attrs, struct c4iw_create_qp_resp uresp; unsigned int sqsize, rqsize = 0; struct c4iw_ucontext *ucontext; + struct ib_ucontext *ib_ucontext; int ret; struct c4iw_mm_entry *sq_key_mm, *rq_key_mm = NULL, *sq_db_key_mm; struct c4iw_mm_entry *rq_db_key_mm = NULL, *ma_sync_key_mm = NULL; @@ -2170,7 +2171,8 @@ struct ib_qp *c4iw_create_qp(struct ib_pd *pd, struct ib_qp_init_attr *attrs, if (sqsize < 8) sqsize = 8; - ucontext = udata ? to_c4iw_ucontext(pd->uobject->context) : NULL; + ib_ucontext = rdma_get_ucontext(udata); + ucontext = IS_ERR(ib_ucontext) ? 
NULL : to_c4iw_ucontext(ib_ucontext); qhp = kzalloc(sizeof(*qhp), GFP_KERNEL); if (!qhp) @@ -2697,6 +2699,7 @@ struct ib_srq *c4iw_create_srq(struct ib_pd *pd, struct ib_srq_init_attr *attrs, struct c4iw_create_srq_resp uresp; struct c4iw_ucontext *ucontext; struct c4iw_mm_entry *srq_key_mm, *srq_db_key_mm; + struct ib_ucontext *ib_ucontext; int rqsize; int ret; int wr_len; @@ -2719,7 +2722,8 @@ struct ib_srq *c4iw_create_srq(struct ib_pd *pd, struct ib_srq_init_attr *attrs, rqsize = attrs->attr.max_wr + 1; rqsize = roundup_pow_of_two(max_t(u16, rqsize, 16)); - ucontext = udata ? to_c4iw_ucontext(pd->uobject->context) : NULL; + ib_ucontext = rdma_get_ucontext(udata); + ucontext = IS_ERR(ib_ucontext) ? NULL : to_c4iw_ucontext(ib_ucontext); srq = kzalloc(sizeof(*srq), GFP_KERNEL); if (!srq) diff --git a/drivers/infiniband/hw/hns/hns_roce_qp.c b/drivers/infiniband/hw/hns/hns_roce_qp.c index accf9ce1507d..b9c25eaf2d75 100644 --- a/drivers/infiniband/hw/hns/hns_roce_qp.c +++ b/drivers/infiniband/hw/hns/hns_roce_qp.c @@ -542,12 +542,19 @@ static int hns_roce_create_qp_common(struct hns_roce_dev *hr_dev, struct device *dev = hr_dev->dev; struct hns_roce_ib_create_qp ucmd; struct hns_roce_ib_create_qp_resp resp = {}; + struct ib_ucontext *ucontext; unsigned long qpn = 0; int ret = 0; u32 page_shift; u32 npages; int i; + ucontext = rdma_get_ucontext(udata); + if (IS_ERR(ucontext)) { + ret = PTR_ERR(ucontext); + goto err_out; + } + mutex_init(&hr_qp->mutex); spin_lock_init(&hr_qp->sq.lock); spin_lock_init(&hr_qp->rq.lock); @@ -653,7 +660,7 @@ static int hns_roce_create_qp_common(struct hns_roce_dev *hr_dev, (udata->outlen >= sizeof(resp)) && hns_roce_qp_has_sq(init_attr)) { ret = hns_roce_db_map_user( - to_hr_ucontext(ib_pd->uobject->context), udata, + to_hr_ucontext(ucontext), udata, ucmd.sdb_addr, &hr_qp->sdb); if (ret) { dev_err(dev, "sq record doorbell map failed!\n"); @@ -669,7 +676,7 @@ static int hns_roce_create_qp_common(struct hns_roce_dev *hr_dev, (udata->outlen >= 
sizeof(resp)) && hns_roce_qp_has_rq(init_attr)) { ret = hns_roce_db_map_user( - to_hr_ucontext(ib_pd->uobject->context), udata, + to_hr_ucontext(ucontext), udata, ucmd.db_addr, &hr_qp->rdb); if (ret) { dev_err(dev, "rq record doorbell map failed!\n"); @@ -815,7 +822,7 @@ static int hns_roce_create_qp_common(struct hns_roce_dev *hr_dev, (udata->outlen >= sizeof(resp)) && hns_roce_qp_has_rq(init_attr)) hns_roce_db_unmap_user( - to_hr_ucontext(ib_pd->uobject->context), + to_hr_ucontext(ucontext), &hr_qp->rdb); } else { kfree(hr_qp->sq.wrid); @@ -829,7 +836,7 @@ static int hns_roce_create_qp_common(struct hns_roce_dev *hr_dev, (udata->outlen >= sizeof(resp)) && hns_roce_qp_has_sq(init_attr)) hns_roce_db_unmap_user( - to_hr_ucontext(ib_pd->uobject->context), + to_hr_ucontext(ucontext), &hr_qp->sdb); err_mtt: diff --git a/drivers/infiniband/hw/i40iw/i40iw_verbs.c b/drivers/infiniband/hw/i40iw/i40iw_verbs.c index 12b31a8440be..194cd911c9de 100644 --- a/drivers/infiniband/hw/i40iw/i40iw_verbs.c +++ b/drivers/infiniband/hw/i40iw/i40iw_verbs.c @@ -580,11 +580,16 @@ static struct ib_qp *i40iw_create_qp(struct ib_pd *ibpd, struct i40iw_create_qp_info *qp_info; struct i40iw_cqp_request *cqp_request; struct cqp_commands_info *cqp_info; + struct ib_ucontext *ib_ucontext; struct i40iw_qp_host_ctx_info *ctx_info; struct i40iwarp_offload_info *iwarp_info; unsigned long flags; + ib_ucontext = rdma_get_ucontext(udata); + if (udata && IS_ERR(ib_ucontext)) + return ERR_CAST(ib_ucontext); + if (iwdev->closing) return ERR_PTR(-ENODEV); @@ -674,7 +679,7 @@ static struct ib_qp *i40iw_create_qp(struct ib_pd *ibpd, } iwqp->ctx_info.qp_compl_ctx = req.user_compl_ctx; iwqp->user_mode = 1; - ucontext = to_ucontext(ibpd->uobject->context); + ucontext = to_ucontext(ib_ucontext); if (req.user_wqe_buffers) { struct i40iw_pbl *iwpbl; @@ -1832,6 +1837,7 @@ static struct ib_mr *i40iw_reg_user_mr(struct ib_pd *pd, struct i40iw_pd *iwpd = to_iwpd(pd); struct i40iw_device *iwdev = to_iwdev(pd->device); 
struct i40iw_ucontext *ucontext; + struct ib_ucontext *ib_ucontext; struct i40iw_pble_alloc *palloc; struct i40iw_pbl *iwpbl; struct i40iw_mr *iwmr; @@ -1847,6 +1853,12 @@ static struct ib_mr *i40iw_reg_user_mr(struct ib_pd *pd, int ret; int pg_shift; + ib_ucontext = rdma_get_ucontext(udata); + if (IS_ERR(ib_ucontext)) + return ERR_CAST(ib_ucontext); + + ucontext = to_ucontext(ib_ucontext); + if (iwdev->closing) return ERR_PTR(-ENODEV); @@ -1872,7 +1884,6 @@ static struct ib_mr *i40iw_reg_user_mr(struct ib_pd *pd, iwmr->region = region; iwmr->ibmr.pd = pd; iwmr->ibmr.device = pd->device; - ucontext = to_ucontext(pd->uobject->context); iwmr->page_size = PAGE_SIZE; iwmr->page_msk = PAGE_MASK; diff --git a/drivers/infiniband/hw/mlx4/doorbell.c b/drivers/infiniband/hw/mlx4/doorbell.c index 3aab71b29ce8..1a4c6d3f8078 100644 --- a/drivers/infiniband/hw/mlx4/doorbell.c +++ b/drivers/infiniband/hw/mlx4/doorbell.c @@ -48,6 +48,9 @@ int mlx4_ib_db_map_user(struct mlx4_ib_ucontext *context, struct mlx4_ib_user_db_page *page; int err = 0; + if (!context) + return -EINVAL; + mutex_lock(&context->db_page_mutex); list_for_each_entry(page, &context->db_page_list, list) @@ -84,6 +87,9 @@ int mlx4_ib_db_map_user(struct mlx4_ib_ucontext *context, void mlx4_ib_db_unmap_user(struct mlx4_ib_ucontext *context, struct mlx4_db *db) { + if (WARN_ON(!context)) + return; + mutex_lock(&context->db_page_mutex); if (!--db->u.user_page->refcnt) { diff --git a/drivers/infiniband/hw/mlx4/mr.c b/drivers/infiniband/hw/mlx4/mr.c index 56639ecd53ad..17744bb1b7a0 100644 --- a/drivers/infiniband/hw/mlx4/mr.c +++ b/drivers/infiniband/hw/mlx4/mr.c @@ -367,8 +367,7 @@ int mlx4_ib_umem_calc_optimal_mtt_size(struct ib_umem *umem, u64 start_va, return block_shift; } -static struct ib_umem *mlx4_get_umem_mr(struct ib_ucontext *context, - struct ib_udata *udata, u64 start, +static struct ib_umem *mlx4_get_umem_mr(struct ib_udata *udata, u64 start, u64 length, u64 virt_addr, int access_flags) { @@ -416,7 +415,7 
@@ struct ib_mr *mlx4_ib_reg_user_mr(struct ib_pd *pd, u64 start, u64 length, if (!mr) return ERR_PTR(-ENOMEM); - mr->umem = mlx4_get_umem_mr(pd->uobject->context, udata, start, length, + mr->umem = mlx4_get_umem_mr(udata, start, length, virt_addr, access_flags); if (IS_ERR(mr->umem)) { err = PTR_ERR(mr->umem); @@ -507,8 +506,8 @@ int mlx4_ib_rereg_user_mr(struct ib_mr *mr, int flags, mlx4_mr_rereg_mem_cleanup(dev->dev, &mmr->mmr); ib_umem_release(mmr->umem); mmr->umem = - mlx4_get_umem_mr(mr->uobject->context, udata, start, - length, virt_addr, mr_access_flags); + mlx4_get_umem_mr(udata, start, length, virt_addr, + mr_access_flags); if (IS_ERR(mmr->umem)) { err = PTR_ERR(mmr->umem); /* Prevent mlx4_ib_dereg_mr from free'ing invalid pointer */ diff --git a/drivers/infiniband/hw/mlx4/qp.c b/drivers/infiniband/hw/mlx4/qp.c index e38bab50cecf..91bb55dd93af 100644 --- a/drivers/infiniband/hw/mlx4/qp.c +++ b/drivers/infiniband/hw/mlx4/qp.c @@ -52,7 +52,8 @@ static void mlx4_ib_lock_cqs(struct mlx4_ib_cq *send_cq, struct mlx4_ib_cq *recv_cq); static void mlx4_ib_unlock_cqs(struct mlx4_ib_cq *send_cq, struct mlx4_ib_cq *recv_cq); -static int _mlx4_ib_modify_wq(struct ib_wq *ibwq, enum ib_wq_state new_state); +static int _mlx4_ib_modify_wq(struct ib_wq *ibwq, enum ib_wq_state new_statem, + struct ib_udata *udata); enum { MLX4_IB_ACK_REQ_FREQ = 8, @@ -778,10 +779,15 @@ static struct ib_qp *_mlx4_ib_create_qp_rss(struct ib_pd *pd, static int mlx4_ib_alloc_wqn(struct mlx4_ib_ucontext *context, struct mlx4_ib_qp *qp, int range_size, int *wqn) { - struct mlx4_ib_dev *dev = to_mdev(context->ibucontext.device); struct mlx4_wqn_range *range; + struct mlx4_ib_dev *dev; int err = 0; + if (!context) + return -EINVAL; + + dev = to_mdev(context->ibucontext.device); + mutex_lock(&context->wqn_ranges_mutex); range = list_first_entry_or_null(&context->wqn_ranges_list, @@ -828,8 +834,13 @@ static int mlx4_ib_alloc_wqn(struct mlx4_ib_ucontext *context, static void mlx4_ib_release_wqn(struct 
mlx4_ib_ucontext *context, struct mlx4_ib_qp *qp, bool dirty_release) { - struct mlx4_ib_dev *dev = to_mdev(context->ibucontext.device); struct mlx4_wqn_range *range; + struct mlx4_ib_dev *dev; + + if (!WARN_ON(context)) + return; + + dev = to_mdev(context->ibucontext.device); mutex_lock(&context->wqn_ranges_mutex); @@ -867,6 +878,11 @@ static int create_qp_common(struct mlx4_ib_dev *dev, struct ib_pd *pd, struct mlx4_ib_cq *mcq; unsigned long flags; int range_size = 0; + struct mlx4_ib_ucontext *context; + struct ib_ucontext *ib_ucontext; + + ib_ucontext = rdma_get_ucontext(udata); + context = IS_ERR(ib_ucontext) ? NULL : to_mucontext(ib_ucontext); /* When tunneling special qps, we use a plain UD qp */ if (sqpn) { @@ -1038,7 +1054,7 @@ static int create_qp_common(struct mlx4_ib_dev *dev, struct ib_pd *pd, if (qp_has_rq(init_attr)) { err = mlx4_ib_db_map_user( - to_mucontext(pd->uobject->context), udata, + context, udata, (src == MLX4_IB_QP_SRC) ? ucmd.qp.db_addr : ucmd.wq.db_addr, &qp->db); @@ -1112,8 +1128,7 @@ static int create_qp_common(struct mlx4_ib_dev *dev, struct ib_pd *pd, } } } else if (src == MLX4_IB_RWQ_SRC) { - err = mlx4_ib_alloc_wqn(to_mucontext(pd->uobject->context), qp, - range_size, &qpn); + err = mlx4_ib_alloc_wqn(context, qp, range_size, &qpn); if (err) goto err_wrid; } else { @@ -1184,8 +1199,7 @@ static int create_qp_common(struct mlx4_ib_dev *dev, struct ib_pd *pd, if (qp->flags & MLX4_IB_QP_NETIF) mlx4_ib_steer_qp_free(dev, qpn, 1); else if (src == MLX4_IB_RWQ_SRC) - mlx4_ib_release_wqn(to_mucontext(pd->uobject->context), - qp, 0); + mlx4_ib_release_wqn(context, qp, 0); else mlx4_qp_release_range(dev->dev, qpn, 1); } @@ -1195,7 +1209,7 @@ static int create_qp_common(struct mlx4_ib_dev *dev, struct ib_pd *pd, err_wrid: if (udata) { if (qp_has_rq(init_attr)) - mlx4_ib_db_unmap_user(to_mucontext(pd->uobject->context), &qp->db); + mlx4_ib_db_unmap_user(context, &qp->db); } else { kvfree(qp->sq.wrid); kvfree(qp->rq.wrid); @@ -1942,7 +1956,8 @@ 
static u8 gid_type_to_qpc(enum ib_gid_type gid_type) * Go over all RSS QP's childes (WQs) and apply their HW state according to * their logic state if the RSS QP is the first RSS QP associated for the WQ. */ -static int bringup_rss_rwqs(struct ib_rwq_ind_table *ind_tbl, u8 port_num) +static int bringup_rss_rwqs(struct ib_rwq_ind_table *ind_tbl, u8 port_num, + struct ib_udata *udata) { int err = 0; int i; @@ -1966,7 +1981,7 @@ static int bringup_rss_rwqs(struct ib_rwq_ind_table *ind_tbl, u8 port_num) } wq->port = port_num; if ((wq->rss_usecnt == 0) && (ibwq->state == IB_WQS_RDY)) { - err = _mlx4_ib_modify_wq(ibwq, IB_WQS_RDY); + err = _mlx4_ib_modify_wq(ibwq, IB_WQS_RDY, udata); if (err) { mutex_unlock(&wq->mutex); break; @@ -1988,7 +2003,8 @@ static int bringup_rss_rwqs(struct ib_rwq_ind_table *ind_tbl, u8 port_num) if ((wq->rss_usecnt == 1) && (ibwq->state == IB_WQS_RDY)) - if (_mlx4_ib_modify_wq(ibwq, IB_WQS_RESET)) + if (_mlx4_ib_modify_wq(ibwq, IB_WQS_RESET, + udata)) pr_warn("failed to reverse WQN=0x%06x\n", ibwq->wq_num); wq->rss_usecnt--; @@ -2000,7 +2016,8 @@ static int bringup_rss_rwqs(struct ib_rwq_ind_table *ind_tbl, u8 port_num) return err; } -static void bring_down_rss_rwqs(struct ib_rwq_ind_table *ind_tbl) +static void bring_down_rss_rwqs(struct ib_rwq_ind_table *ind_tbl, + struct ib_udata *udata) { int i; @@ -2011,7 +2028,7 @@ static void bring_down_rss_rwqs(struct ib_rwq_ind_table *ind_tbl) mutex_lock(&wq->mutex); if ((wq->rss_usecnt == 1) && (ibwq->state == IB_WQS_RDY)) - if (_mlx4_ib_modify_wq(ibwq, IB_WQS_RESET)) + if (_mlx4_ib_modify_wq(ibwq, IB_WQS_RESET, udata)) pr_warn("failed to reverse WQN=%x\n", ibwq->wq_num); wq->rss_usecnt--; @@ -2043,12 +2060,14 @@ static void fill_qp_rss_context(struct mlx4_qp_context *context, static int __mlx4_ib_modify_qp(void *src, enum mlx4_ib_source_type src_type, const struct ib_qp_attr *attr, int attr_mask, - enum ib_qp_state cur_state, enum ib_qp_state new_state) + enum ib_qp_state cur_state, + enum 
ib_qp_state new_state, + struct ib_udata *udata) { - struct ib_uobject *ibuobject; struct ib_srq *ibsrq; const struct ib_gid_attr *gid_attr = NULL; struct ib_rwq_ind_table *rwq_ind_tbl; + struct ib_ucontext *ibucontext; enum ib_qp_type qp_type; struct mlx4_ib_dev *dev; struct mlx4_ib_qp *qp; @@ -2065,7 +2084,6 @@ static int __mlx4_ib_modify_qp(void *src, enum mlx4_ib_source_type src_type, struct ib_wq *ibwq; ibwq = (struct ib_wq *)src; - ibuobject = ibwq->uobject; ibsrq = NULL; rwq_ind_tbl = NULL; qp_type = IB_QPT_RAW_PACKET; @@ -2076,7 +2094,6 @@ static int __mlx4_ib_modify_qp(void *src, enum mlx4_ib_source_type src_type, struct ib_qp *ibqp; ibqp = (struct ib_qp *)src; - ibuobject = ibqp->uobject; ibsrq = ibqp->srq; rwq_ind_tbl = ibqp->rwq_ind_tbl; qp_type = ibqp->qp_type; @@ -2161,12 +2178,17 @@ static int __mlx4_ib_modify_qp(void *src, enum mlx4_ib_source_type src_type, context->param3 |= cpu_to_be32(1 << 30); } - if (ibuobject) + if (udata) { + ibucontext = rdma_get_ucontext(udata); + if (IS_ERR(ibucontext)) { + err = PTR_ERR(ibucontext); + goto out; + } context->usr_page = cpu_to_be32( mlx4_to_hw_uar_index(dev->dev, - to_mucontext(ibuobject->context) + to_mucontext(ibucontext) ->uar.index)); - else + } else context->usr_page = cpu_to_be32( mlx4_to_hw_uar_index(dev->dev, dev->priv_uar.index)); @@ -2297,7 +2319,7 @@ static int __mlx4_ib_modify_qp(void *src, enum mlx4_ib_source_type src_type, context->cqn_recv = cpu_to_be32(recv_cq->mcq.cqn); /* Set "fast registration enabled" for all kernel QPs */ - if (!ibuobject) + if (!udata) context->params1 |= cpu_to_be32(1 << 11); if (attr_mask & IB_QP_RNR_RETRY) { @@ -2434,7 +2456,7 @@ static int __mlx4_ib_modify_qp(void *src, enum mlx4_ib_source_type src_type, else sqd_event = 0; - if (!ibuobject && + if (!udata && cur_state == IB_QPS_RESET && new_state == IB_QPS_INIT) context->rlkey_roce_mode |= (1 << 4); @@ -2445,7 +2467,7 @@ static int __mlx4_ib_modify_qp(void *src, enum mlx4_ib_source_type src_type, * headroom is 
stamped so that the hardware doesn't start * processing stale work requests. */ - if (!ibuobject && + if (!udata && cur_state == IB_QPS_RESET && new_state == IB_QPS_INIT) { struct mlx4_wqe_ctrl_seg *ctrl; @@ -2509,7 +2531,7 @@ static int __mlx4_ib_modify_qp(void *src, enum mlx4_ib_source_type src_type, * entries and reinitialize the QP. */ if (new_state == IB_QPS_RESET) { - if (!ibuobject) { + if (!udata) { mlx4_ib_cq_clean(recv_cq, qp->mqp.qpn, ibsrq ? to_msrq(ibsrq) : NULL); if (send_cq != recv_cq) @@ -2735,16 +2757,17 @@ static int _mlx4_ib_modify_qp(struct ib_qp *ibqp, struct ib_qp_attr *attr, } if (ibqp->rwq_ind_tbl && (new_state == IB_QPS_INIT)) { - err = bringup_rss_rwqs(ibqp->rwq_ind_tbl, attr->port_num); + err = bringup_rss_rwqs(ibqp->rwq_ind_tbl, attr->port_num, + udata); if (err) goto out; } err = __mlx4_ib_modify_qp(ibqp, MLX4_IB_QP_SRC, attr, attr_mask, - cur_state, new_state); + cur_state, new_state, udata); if (ibqp->rwq_ind_tbl && err) - bring_down_rss_rwqs(ibqp->rwq_ind_tbl); + bring_down_rss_rwqs(ibqp->rwq_ind_tbl, udata); if (mlx4_is_bonded(dev->dev) && (attr_mask & IB_QP_PORT)) attr->port_num = 1; @@ -4122,7 +4145,8 @@ static int ib_wq2qp_state(enum ib_wq_state state) } } -static int _mlx4_ib_modify_wq(struct ib_wq *ibwq, enum ib_wq_state new_state) +static int _mlx4_ib_modify_wq(struct ib_wq *ibwq, enum ib_wq_state new_state, + struct ib_udata *udata) { struct mlx4_ib_qp *qp = to_mqp((struct ib_qp *)ibwq); enum ib_qp_state qp_cur_state; @@ -4146,7 +4170,8 @@ static int _mlx4_ib_modify_wq(struct ib_wq *ibwq, enum ib_wq_state new_state) attr_mask = IB_QP_PORT; err = __mlx4_ib_modify_qp(ibwq, MLX4_IB_RWQ_SRC, &attr, - attr_mask, IB_QPS_RESET, IB_QPS_INIT); + attr_mask, IB_QPS_RESET, IB_QPS_INIT, + udata); if (err) { pr_debug("WQN=0x%06x failed to apply RST->INIT on the HW QP\n", ibwq->wq_num); @@ -4158,12 +4183,13 @@ static int _mlx4_ib_modify_wq(struct ib_wq *ibwq, enum ib_wq_state new_state) attr_mask = 0; err = __mlx4_ib_modify_qp(ibwq, 
MLX4_IB_RWQ_SRC, NULL, attr_mask, - qp_cur_state, qp_new_state); + qp_cur_state, qp_new_state, udata); if (err && (qp_cur_state == IB_QPS_INIT)) { qp_new_state = IB_QPS_RESET; if (__mlx4_ib_modify_qp(ibwq, MLX4_IB_RWQ_SRC, NULL, - attr_mask, IB_QPS_INIT, IB_QPS_RESET)) { + attr_mask, IB_QPS_INIT, IB_QPS_RESET, + udata)) { pr_warn("WQN=0x%06x failed with reverting HW's resources failure\n", ibwq->wq_num); qp_new_state = IB_QPS_INIT; @@ -4226,7 +4252,7 @@ int mlx4_ib_modify_wq(struct ib_wq *ibwq, struct ib_wq_attr *wq_attr, * WQ, so we can apply its port on the WQ. */ if (qp->rss_usecnt) - err = _mlx4_ib_modify_wq(ibwq, new_state); + err = _mlx4_ib_modify_wq(ibwq, new_state, udata); if (!err) ibwq->state = new_state; diff --git a/drivers/infiniband/hw/mlx4/srq.c b/drivers/infiniband/hw/mlx4/srq.c index 498588eac051..0551e6732d22 100644 --- a/drivers/infiniband/hw/mlx4/srq.c +++ b/drivers/infiniband/hw/mlx4/srq.c @@ -76,6 +76,7 @@ struct ib_srq *mlx4_ib_create_srq(struct ib_pd *pd, struct mlx4_ib_srq *srq; struct mlx4_wqe_srq_next_seg *next; struct mlx4_wqe_data_seg *scatter; + struct ib_ucontext *ib_ucontext; u32 cqn; u16 xrcdn; int desc_size; @@ -108,6 +109,12 @@ struct ib_srq *mlx4_ib_create_srq(struct ib_pd *pd, if (udata) { struct mlx4_ib_create_srq ucmd; + ib_ucontext = rdma_get_ucontext(udata); + if (IS_ERR(ib_ucontext)) { + err = PTR_ERR(ib_ucontext); + goto err_srq; + } + if (ib_copy_from_udata(&ucmd, udata, sizeof ucmd)) { err = -EFAULT; goto err_srq; @@ -128,8 +135,8 @@ struct ib_srq *mlx4_ib_create_srq(struct ib_pd *pd, if (err) goto err_mtt; - err = mlx4_ib_db_map_user(to_mucontext(pd->uobject->context), - udata, ucmd.db_addr, &srq->db); + err = mlx4_ib_db_map_user(to_mucontext(ib_ucontext), udata, + ucmd.db_addr, &srq->db); if (err) goto err_mtt; } else { @@ -202,7 +209,7 @@ struct ib_srq *mlx4_ib_create_srq(struct ib_pd *pd, err_wrid: if (udata) - mlx4_ib_db_unmap_user(to_mucontext(pd->uobject->context), &srq->db); + 
mlx4_ib_db_unmap_user(to_mucontext(ib_ucontext), &srq->db); else kvfree(srq->wrid); diff --git a/drivers/infiniband/hw/mlx5/qp.c b/drivers/infiniband/hw/mlx5/qp.c index 529e76f67cb6..ccaa88a4eb9f 100644 --- a/drivers/infiniband/hw/mlx5/qp.c +++ b/drivers/infiniband/hw/mlx5/qp.c @@ -696,12 +696,17 @@ static int create_user_rq(struct mlx5_ib_dev *dev, struct ib_pd *pd, struct ib_udata *udata, struct mlx5_ib_rwq *rwq, struct mlx5_ib_create_wq *ucmd) { + struct ib_ucontext *ib_ucontext; int page_shift = 0; int npages; u32 offset = 0; int ncont = 0; int err; + ib_ucontext = rdma_get_ucontext(udata); + if (IS_ERR(ib_ucontext)) + return PTR_ERR(ib_ucontext); + if (!ucmd->buf_addr) return -EINVAL; @@ -730,7 +735,7 @@ static int create_user_rq(struct mlx5_ib_dev *dev, struct ib_pd *pd, (unsigned long long)ucmd->buf_addr, rwq->buf_size, npages, page_shift, ncont, offset); - err = mlx5_ib_db_map_user(to_mucontext(pd->uobject->context), udata, + err = mlx5_ib_db_map_user(to_mucontext(ib_ucontext), udata, ucmd->db_addr, &rwq->db); if (err) { mlx5_ib_dbg(dev, "map failed\n"); @@ -759,6 +764,7 @@ static int create_user_qp(struct mlx5_ib_dev *dev, struct ib_pd *pd, struct mlx5_ib_create_qp_resp *resp, int *inlen, struct mlx5_ib_qp_base *base) { + struct ib_ucontext *ib_ucontext; struct mlx5_ib_ucontext *context; struct mlx5_ib_create_qp ucmd; struct mlx5_ib_ubuffer *ubuffer = &base->ubuffer; @@ -773,13 +779,17 @@ static int create_user_qp(struct mlx5_ib_dev *dev, struct ib_pd *pd, int err; u16 uid; + ib_ucontext = rdma_get_ucontext(udata); + if (IS_ERR(ib_ucontext)) + return PTR_ERR(ib_ucontext); + err = ib_copy_from_udata(&ucmd, udata, sizeof(ucmd)); if (err) { mlx5_ib_dbg(dev, "copy failed\n"); return err; } - context = to_mucontext(pd->uobject->context); + context = to_mucontext(ib_ucontext); if (ucmd.flags & MLX5_QP_FLAG_BFREG_INDEX) { uar_index = bfregn_to_uar_index(dev, &context->bfregi, ucmd.bfreg_index, true); @@ -1818,6 +1828,7 @@ static int create_qp_common(struct 
mlx5_ib_dev *dev, struct ib_pd *pd, int inlen = MLX5_ST_SZ_BYTES(create_qp_in); struct mlx5_core_dev *mdev = dev->mdev; struct mlx5_ib_create_qp_resp resp = {}; + struct ib_ucontext *context; struct mlx5_ib_cq *send_cq; struct mlx5_ib_cq *recv_cq; unsigned long flags; @@ -1918,8 +1929,12 @@ static int create_qp_common(struct mlx5_ib_dev *dev, struct ib_pd *pd, MLX5_QP_FLAG_PACKET_BASED_CREDIT_MODE)) return -EINVAL; - err = get_qp_user_index(to_mucontext(pd->uobject->context), - &ucmd, udata->inlen, &uidx); + context = rdma_get_ucontext(udata); + if (IS_ERR(context)) + return PTR_ERR(context); + + err = get_qp_user_index(to_mucontext(context), &ucmd, + udata->inlen, &uidx); if (err) return err; @@ -2403,8 +2418,10 @@ static const char *ib_qp_type_str(enum ib_qp_type type) static struct ib_qp *mlx5_ib_create_dct(struct ib_pd *pd, struct ib_qp_init_attr *attr, - struct mlx5_ib_create_qp *ucmd) + struct mlx5_ib_create_qp *ucmd, + struct ib_udata *udata) { + struct ib_ucontext *context; struct mlx5_ib_qp *qp; int err = 0; u32 uidx = MLX5_IB_DEFAULT_UIDX; @@ -2413,8 +2430,12 @@ static struct ib_qp *mlx5_ib_create_dct(struct ib_pd *pd, if (!attr->srq || !attr->recv_cq) return ERR_PTR(-EINVAL); - err = get_qp_user_index(to_mucontext(pd->uobject->context), - ucmd, sizeof(*ucmd), &uidx); + context = rdma_get_ucontext(udata); + if (IS_ERR(context)) + return ERR_CAST(context); + + err = get_qp_user_index(to_mucontext(context), ucmd, sizeof(*ucmd), + &uidx); if (err) return ERR_PTR(err); @@ -2490,6 +2511,7 @@ struct ib_qp *mlx5_ib_create_qp(struct ib_pd *pd, struct ib_qp_init_attr *verbs_init_attr, struct ib_udata *udata) { + struct ib_ucontext *context; struct mlx5_ib_dev *dev; struct mlx5_ib_qp *qp; u16 xrcdn = 0; @@ -2504,9 +2526,16 @@ struct ib_qp *mlx5_ib_create_qp(struct ib_pd *pd, if (!udata) { mlx5_ib_dbg(dev, "Raw Packet QP is not supported for kernel consumers\n"); return ERR_PTR(-EINVAL); - } else if (!to_mucontext(pd->uobject->context)->cqe_version) { - 
mlx5_ib_dbg(dev, "Raw Packet QP is only supported for CQE version > 0\n"); - return ERR_PTR(-EINVAL); + } else { + context = rdma_get_ucontext(udata); + if (IS_ERR(context)) + return ERR_CAST(context); + + if (!to_mucontext(context)-> + cqe_version) { + mlx5_ib_dbg(dev, "Raw Packet QP is only supported for CQE version > 0\n"); + return ERR_PTR(-EINVAL); + } } } } else { @@ -2536,7 +2565,7 @@ struct ib_qp *mlx5_ib_create_qp(struct ib_pd *pd, return ERR_PTR(-EINVAL); } } else { - return mlx5_ib_create_dct(pd, init_attr, &ucmd); + return mlx5_ib_create_dct(pd, init_attr, &ucmd, udata); } } @@ -3174,13 +3203,16 @@ static int modify_raw_packet_qp(struct mlx5_ib_dev *dev, struct mlx5_ib_qp *qp, static unsigned int get_tx_affinity(struct mlx5_ib_dev *dev, struct mlx5_ib_pd *pd, struct mlx5_ib_qp_base *qp_base, - u8 port_num) + u8 port_num, + struct ib_udata *udata) { struct mlx5_ib_ucontext *ucontext = NULL; unsigned int tx_port_affinity; + struct ib_ucontext *context; - if (pd && pd->ibpd.uobject && pd->ibpd.uobject->context) - ucontext = to_mucontext(pd->ibpd.uobject->context); + context = rdma_get_ucontext(udata); + if (!IS_ERR(context)) + ucontext = to_mucontext(context); if (ucontext) { tx_port_affinity = (unsigned int)atomic_add_return( @@ -3205,7 +3237,8 @@ static unsigned int get_tx_affinity(struct mlx5_ib_dev *dev, static int __mlx5_ib_modify_qp(struct ib_qp *ibqp, const struct ib_qp_attr *attr, int attr_mask, enum ib_qp_state cur_state, enum ib_qp_state new_state, - const struct mlx5_ib_modify_qp *ucmd) + const struct mlx5_ib_modify_qp *ucmd, + struct ib_udata *udata) { static const u16 optab[MLX5_QP_NUM_STATE][MLX5_QP_NUM_STATE] = { [MLX5_QP_STATE_RST] = { @@ -3296,7 +3329,8 @@ static int __mlx5_ib_modify_qp(struct ib_qp *ibqp, (ibqp->qp_type == IB_QPT_XRC_TGT)) { if (dev->lag_active) { u8 p = mlx5_core_native_port_num(dev->mdev); - tx_affinity = get_tx_affinity(dev, pd, base, p); + tx_affinity = get_tx_affinity(dev, pd, base, p, + udata); context->flags |= 
cpu_to_be32(tx_affinity << 24); } } @@ -3779,7 +3813,7 @@ int mlx5_ib_modify_qp(struct ib_qp *ibqp, struct ib_qp_attr *attr, } err = __mlx5_ib_modify_qp(ibqp, attr, attr_mask, cur_state, - new_state, &ucmd); + new_state, &ucmd, udata); out: mutex_unlock(&qp->mutex); diff --git a/drivers/infiniband/hw/mlx5/srq.c b/drivers/infiniband/hw/mlx5/srq.c index 22bd774e0b4e..827e58c729a6 100644 --- a/drivers/infiniband/hw/mlx5/srq.c +++ b/drivers/infiniband/hw/mlx5/srq.c @@ -47,6 +47,7 @@ static int create_srq_user(struct ib_pd *pd, struct mlx5_ib_srq *srq, { struct mlx5_ib_dev *dev = to_mdev(pd->device); struct mlx5_ib_create_srq ucmd = {}; + struct ib_ucontext *context; size_t ucmdlen; int err; int npages; @@ -55,6 +56,10 @@ static int create_srq_user(struct ib_pd *pd, struct mlx5_ib_srq *srq, u32 offset; u32 uidx = MLX5_IB_DEFAULT_UIDX; + context = rdma_get_ucontext(udata); + if (IS_ERR(context)) + return PTR_ERR(context); + ucmdlen = min(udata->inlen, sizeof(ucmd)); if (ib_copy_from_udata(&ucmd, udata, ucmdlen)) { @@ -71,8 +76,8 @@ static int create_srq_user(struct ib_pd *pd, struct mlx5_ib_srq *srq, return -EINVAL; if (in->type != IB_SRQT_BASIC) { - err = get_srq_user_index(to_mucontext(pd->uobject->context), - &ucmd, udata->inlen, &uidx); + err = get_srq_user_index(to_mucontext(context), &ucmd, + udata->inlen, &uidx); if (err) return err; } @@ -103,8 +108,8 @@ static int create_srq_user(struct ib_pd *pd, struct mlx5_ib_srq *srq, mlx5_ib_populate_pas(dev, srq->umem, page_shift, in->pas, 0); - err = mlx5_ib_db_map_user(to_mucontext(pd->uobject->context), udata, - ucmd.db_addr, &srq->db); + err = mlx5_ib_db_map_user(to_mucontext(context), udata, ucmd.db_addr, + &srq->db); if (err) { mlx5_ib_dbg(dev, "map doorbell failed\n"); goto err_in; diff --git a/drivers/infiniband/hw/mthca/mthca_provider.c b/drivers/infiniband/hw/mthca/mthca_provider.c index 63003b4d2485..4d6a17bda8b4 100644 --- a/drivers/infiniband/hw/mthca/mthca_provider.c +++
b/drivers/infiniband/hw/mthca/mthca_provider.c @@ -446,6 +446,7 @@ static struct ib_srq *mthca_create_srq(struct ib_pd *pd, { struct mthca_create_srq ucmd; struct mthca_ucontext *context = NULL; + struct ib_ucontext *ib_ucontext; struct mthca_srq *srq; int err; @@ -457,7 +458,12 @@ static struct ib_srq *mthca_create_srq(struct ib_pd *pd, return ERR_PTR(-ENOMEM); if (udata) { - context = to_mucontext(pd->uobject->context); + ib_ucontext = rdma_get_ucontext(udata); + if (IS_ERR(ib_ucontext)) { + err = PTR_ERR(ib_ucontext); + goto err_free; + } + context = to_mucontext(ib_ucontext); if (ib_copy_from_udata(&ucmd, udata, sizeof ucmd)) { err = -EFAULT; @@ -520,6 +526,7 @@ static struct ib_qp *mthca_create_qp(struct ib_pd *pd, struct ib_qp_init_attr *init_attr, struct ib_udata *udata) { + struct ib_ucontext *ib_ucontext; struct mthca_create_qp ucmd; struct mthca_qp *qp; int err; @@ -532,14 +539,19 @@ static struct ib_qp *mthca_create_qp(struct ib_pd *pd, case IB_QPT_UC: case IB_QPT_UD: { - struct mthca_ucontext *context; + struct mthca_ucontext *context = NULL; qp = kmalloc(sizeof *qp, GFP_KERNEL); if (!qp) return ERR_PTR(-ENOMEM); if (udata) { - context = to_mucontext(pd->uobject->context); + ib_ucontext = rdma_get_ucontext(udata); + if (IS_ERR(ib_ucontext)) { + kfree(qp); + return ERR_CAST(ib_ucontext); + } + context = to_mucontext(ib_ucontext); if (ib_copy_from_udata(&ucmd, udata, sizeof ucmd)) { kfree(qp); @@ -577,9 +589,7 @@ static struct ib_qp *mthca_create_qp(struct ib_pd *pd, init_attr->qp_type, init_attr->sq_sig_type, &init_attr->cap, qp, udata); - if (err && udata) { - context = to_mucontext(pd->uobject->context); - + if (err && context) { mthca_unmap_user_db(to_mdev(pd->device), &context->uar, context->db_tab, @@ -907,6 +917,7 @@ static struct ib_mr *mthca_reg_user_mr(struct ib_pd *pd, u64 start, u64 length, u64 virt, int acc, struct ib_udata *udata) { struct mthca_dev *dev = to_mdev(pd->device); + struct ib_ucontext *context; struct scatterlist *sg; struct mthca_mr *mr; struct
mthca_reg_mr ucmd; @@ -917,12 +928,15 @@ static struct ib_mr *mthca_reg_user_mr(struct ib_pd *pd, u64 start, u64 length, int write_mtt_size; if (udata->inlen < sizeof ucmd) { - if (!to_mucontext(pd->uobject->context)->reg_mr_warned) { + context = rdma_get_ucontext(udata); + if (IS_ERR(context)) + return ERR_CAST(context); + if (!to_mucontext(context)->reg_mr_warned) { mthca_warn(dev, "Process '%s' did not pass in MR attrs.\n", current->comm); mthca_warn(dev, " Update libmthca to fix this.\n"); } - ++to_mucontext(pd->uobject->context)->reg_mr_warned; + ++to_mucontext(context)->reg_mr_warned; ucmd.mr_attrs = 0; } else if (ib_copy_from_udata(&ucmd, udata, sizeof ucmd)) return ERR_PTR(-EFAULT); diff --git a/drivers/infiniband/hw/mthca/mthca_qp.c b/drivers/infiniband/hw/mthca/mthca_qp.c index 4e5b5cc17f1d..ea0ee6b5572c 100644 --- a/drivers/infiniband/hw/mthca/mthca_qp.c +++ b/drivers/infiniband/hw/mthca/mthca_qp.c @@ -554,7 +554,9 @@ static int mthca_path_set(struct mthca_dev *dev, const struct rdma_ah_attr *ah, static int __mthca_modify_qp(struct ib_qp *ibqp, const struct ib_qp_attr *attr, int attr_mask, - enum ib_qp_state cur_state, enum ib_qp_state new_state) + enum ib_qp_state cur_state, + enum ib_qp_state new_state, + struct ib_udata *udata) { struct mthca_dev *dev = to_mdev(ibqp->device); struct mthca_qp *qp = to_mqp(ibqp); @@ -563,6 +565,7 @@ static int __mthca_modify_qp(struct ib_qp *ibqp, struct mthca_qp_context *qp_context; u32 sqd_event = 0; int err = -EINVAL; + struct ib_ucontext *context; mailbox = mthca_alloc_mailbox(dev, GFP_KERNEL); if (IS_ERR(mailbox)) { @@ -618,10 +621,15 @@ static int __mthca_modify_qp(struct ib_qp *ibqp, /* leave arbel_sched_queue as 0 */ - if (qp->ibqp.uobject) + if (udata) { + context = rdma_get_ucontext(udata); + if (IS_ERR(context)) { + err = PTR_ERR(context); + goto out_mailbox; + } qp_context->usr_page = - cpu_to_be32(to_mucontext(qp->ibqp.uobject->context)->uar.index); - else + cpu_to_be32(to_mucontext(context)->uar.index); +
} else qp_context->usr_page = cpu_to_be32(dev->driver_uar.index); qp_context->local_qpn = cpu_to_be32(qp->qpn); if (attr_mask & IB_QP_DEST_QPN) { @@ -913,7 +921,8 @@ int mthca_modify_qp(struct ib_qp *ibqp, struct ib_qp_attr *attr, int attr_mask, goto out; } - err = __mthca_modify_qp(ibqp, attr, attr_mask, cur_state, new_state); + err = __mthca_modify_qp(ibqp, attr, attr_mask, cur_state, new_state, + udata); out: mutex_unlock(&qp->mutex); diff --git a/drivers/infiniband/hw/mthca/mthca_srq.c b/drivers/infiniband/hw/mthca/mthca_srq.c index b8333c79e3fa..04979b0ee7fd 100644 --- a/drivers/infiniband/hw/mthca/mthca_srq.c +++ b/drivers/infiniband/hw/mthca/mthca_srq.c @@ -96,18 +96,23 @@ static void mthca_tavor_init_srq_context(struct mthca_dev *dev, struct mthca_pd *pd, struct mthca_srq *srq, struct mthca_tavor_srq_context *context, - bool is_user) + struct ib_udata *udata) { + struct ib_ucontext *ib_ucontext; + memset(context, 0, sizeof *context); context->wqe_base_ds = cpu_to_be64(1 << (srq->wqe_shift - 4)); context->state_pd = cpu_to_be32(pd->pd_num); context->lkey = cpu_to_be32(srq->mr.ibmr.lkey); - if (is_user) + if (udata) { + ib_ucontext = rdma_get_ucontext(udata); + if (WARN_ON(IS_ERR(ib_ucontext))) + return; context->uar = - cpu_to_be32(to_mucontext(pd->ibpd.uobject->context)->uar.index); - else + cpu_to_be32(to_mucontext(ib_ucontext)->uar.index); + } else context->uar = cpu_to_be32(dev->driver_uar.index); } @@ -115,8 +120,9 @@ static void mthca_arbel_init_srq_context(struct mthca_dev *dev, struct mthca_pd *pd, struct mthca_srq *srq, struct mthca_arbel_srq_context *context, - bool is_user) + struct ib_udata *udata) { + struct ib_ucontext *ib_ucontext; int logsize, max; memset(context, 0, sizeof *context); @@ -131,10 +137,13 @@ static void mthca_arbel_init_srq_context(struct mthca_dev *dev, context->lkey = cpu_to_be32(srq->mr.ibmr.lkey); context->db_index = cpu_to_be32(srq->db_index); context->logstride_usrpage = cpu_to_be32((srq->wqe_shift - 4) << 29); - if
(is_user) + if (udata) { + ib_ucontext = rdma_get_ucontext(udata); + if (WARN_ON(IS_ERR(ib_ucontext))) + return; context->logstride_usrpage |= - cpu_to_be32(to_mucontext(pd->ibpd.uobject->context)->uar.index); - else + cpu_to_be32(to_mucontext(ib_ucontext)->uar.index); + } else context->logstride_usrpage |= cpu_to_be32(dev->driver_uar.index); context->eq_pd = cpu_to_be32(MTHCA_EQ_ASYNC << 24 | pd->pd_num); } diff --git a/drivers/infiniband/hw/nes/nes_verbs.c b/drivers/infiniband/hw/nes/nes_verbs.c index 034156f7e9ed..a33b3fdae682 100644 --- a/drivers/infiniband/hw/nes/nes_verbs.c +++ b/drivers/infiniband/hw/nes/nes_verbs.c @@ -983,6 +983,7 @@ static struct ib_qp *nes_create_qp(struct ib_pd *ibpd, struct nes_vnic *nesvnic = to_nesvnic(ibpd->device); struct nes_device *nesdev = nesvnic->nesdev; struct nes_adapter *nesadapter = nesdev->nesadapter; + struct ib_ucontext *context; struct nes_qp *nesqp; struct nes_cq *nescq; struct nes_ucontext *nes_ucontext; @@ -1066,9 +1067,10 @@ static struct ib_qp *nes_create_qp(struct ib_pd *ibpd, } if (req.user_qp_buffer) nesqp->nesuqp_addr = req.user_qp_buffer; - if (udata && (ibpd->uobject->context)) { + context = rdma_get_ucontext(udata); + if (!IS_ERR(context)) { nesqp->user_mode = 1; - nes_ucontext = to_nesucontext(ibpd->uobject->context); + nes_ucontext = to_nesucontext(context); if (virt_wqs) { err = 1; list_for_each_entry(nespbl, &nes_ucontext->qp_reg_mem_list, list) { @@ -1089,7 +1091,6 @@ static struct ib_qp *nes_create_qp(struct ib_pd *ibpd, } } - nes_ucontext = to_nesucontext(ibpd->uobject->context); nesqp->mmap_sq_db_index = find_next_zero_bit(nes_ucontext->allocated_wqs, NES_MAX_USER_WQ_REGIONS, nes_ucontext->first_free_wq); @@ -2111,6 +2112,7 @@ static struct ib_mr *nes_reg_user_mr(struct ib_pd *pd, u64 start, u64 length, struct ib_mr *ibmr = ERR_PTR(-EINVAL); struct scatterlist *sg; struct nes_ucontext *nes_ucontext; + struct ib_ucontext *context; struct nes_pbl *nespbl; struct nes_mr *nesmr; struct ib_umem *region;
@@ -2382,8 +2384,15 @@ static struct ib_mr *nes_reg_user_mr(struct ib_pd *pd, u64 start, u64 length, kfree(nespbl); return ERR_PTR(-ENOMEM); } + context = rdma_get_ucontext(udata); + if (IS_ERR(context)) { + ib_umem_release(region); + kfree(nesmr); + kfree(nespbl); + return ERR_CAST(context); + } nesmr->region = region; - nes_ucontext = to_nesucontext(pd->uobject->context); + nes_ucontext = to_nesucontext(context); pbl_depth = region->length >> 12; pbl_depth += (region->length & (4096-1)) ? 1 : 0; nespbl->pbl_size = pbl_depth*sizeof(u64); diff --git a/drivers/infiniband/hw/qedr/verbs.c b/drivers/infiniband/hw/qedr/verbs.c index b27ff9408507..e676a36b6b5d 100644 --- a/drivers/infiniband/hw/qedr/verbs.c +++ b/drivers/infiniband/hw/qedr/verbs.c @@ -1433,7 +1433,6 @@ struct ib_srq *qedr_create_srq(struct ib_pd *ibpd, struct qedr_pd *pd = get_qedr_pd(ibpd); struct qedr_create_srq_ureq ureq = {}; u64 pbl_base_addr, phy_prod_pair_addr; - struct ib_ucontext *ib_ctx = NULL; struct qedr_srq_hwq_info *hw_srq; u32 page_cnt, page_size; struct qedr_srq *srq; @@ -1458,9 +1457,7 @@ struct ib_srq *qedr_create_srq(struct ib_pd *ibpd, hw_srq->max_wr = init_attr->attr.max_wr; hw_srq->max_sges = init_attr->attr.max_sge; - if (udata && ibpd->uobject && ibpd->uobject->context) { - ib_ctx = ibpd->uobject->context; - + if (udata && !IS_ERR(rdma_get_ucontext(udata))) { if (ib_copy_from_udata(&ureq, udata, sizeof(ureq))) { DP_ERR(dev, "create srq: problem copying data from user space\n"); @@ -1698,13 +1695,10 @@ static int qedr_create_user_qp(struct qedr_dev *dev, struct qed_rdma_create_qp_in_params in_params; struct qed_rdma_create_qp_out_params out_params; struct qedr_pd *pd = get_qedr_pd(ibpd); - struct ib_ucontext *ib_ctx = NULL; struct qedr_create_qp_ureq ureq; int alloc_and_init = rdma_protocol_roce(&dev->ibdev, 1); int rc = -EINVAL; - ib_ctx = ibpd->uobject->context; - memset(&ureq, 0, sizeof(ureq)); rc = ib_copy_from_udata(&ureq, udata, sizeof(ureq)); if (rc) { diff --git a/drivers/infiniband/hw/usnic/usnic_ib_verbs.c
b/drivers/infiniband/hw/usnic/usnic_ib_verbs.c index 432e6f6599fa..f471e7f270c0 100644 --- a/drivers/infiniband/hw/usnic/usnic_ib_verbs.c +++ b/drivers/infiniband/hw/usnic/usnic_ib_verbs.c @@ -501,10 +501,15 @@ struct ib_qp *usnic_ib_create_qp(struct ib_pd *pd, struct usnic_vnic_res_spec res_spec; struct usnic_ib_create_qp_cmd cmd; struct usnic_transport_spec trans_spec; + struct ib_ucontext *context; usnic_dbg("\n"); - ucontext = to_uucontext(pd->uobject->context); + context = rdma_get_ucontext(udata); + if (IS_ERR(context)) + return ERR_CAST(context); + + ucontext = to_uucontext(context); us_ibdev = to_usdev(pd->device); if (init_attr->create_flags) diff --git a/drivers/infiniband/sw/rdmavt/qp.c b/drivers/infiniband/sw/rdmavt/qp.c index a1bd8cfc2c25..ace8e640aee1 100644 --- a/drivers/infiniband/sw/rdmavt/qp.c +++ b/drivers/infiniband/sw/rdmavt/qp.c @@ -948,6 +948,7 @@ struct ib_qp *rvt_create_qp(struct ib_pd *ibpd, struct ib_qp_init_attr *init_attr, struct ib_udata *udata) { + struct ib_ucontext *context; struct rvt_qp *qp; int err; struct rvt_swqe *swq = NULL; @@ -1127,8 +1128,13 @@ struct ib_qp *rvt_create_qp(struct ib_pd *ibpd, } else { u32 s = sizeof(struct rvt_rwq) + qp->r_rq.size * sz; - qp->ip = rvt_create_mmap_info(rdi, s, - ibpd->uobject->context, + context = rdma_get_ucontext(udata); + if (IS_ERR(context)) { + ret = ERR_CAST(context); + goto bail_qpn; + } + + qp->ip = rvt_create_mmap_info(rdi, s, context, qp->r_rq.wq); if (!qp->ip) { ret = ERR_PTR(-ENOMEM); diff --git a/drivers/infiniband/sw/rdmavt/srq.c b/drivers/infiniband/sw/rdmavt/srq.c index 78e06fc456c5..fa4c6bc32ec6 100644 --- a/drivers/infiniband/sw/rdmavt/srq.c +++ b/drivers/infiniband/sw/rdmavt/srq.c @@ -77,6 +77,7 @@ struct ib_srq *rvt_create_srq(struct ib_pd *ibpd, struct ib_udata *udata) { struct rvt_dev_info *dev = ib_to_rvt(ibpd->device); + struct ib_ucontext *context; struct rvt_srq *srq; u32 sz; struct ib_srq *ret; @@ -118,9 +119,14 @@ struct ib_srq *rvt_create_srq(struct ib_pd *ibpd,
int err; u32 s = sizeof(struct rvt_rwq) + srq->rq.size * sz; + context = rdma_get_ucontext(udata); + if (IS_ERR(context)) { + ret = ERR_CAST(context); + goto bail_wq; + } + srq->ip = - rvt_create_mmap_info(dev, s, ibpd->uobject->context, - srq->rq.wq); + rvt_create_mmap_info(dev, s, context, srq->rq.wq); if (!srq->ip) { ret = ERR_PTR(-ENOMEM); goto bail_wq; diff --git a/drivers/infiniband/sw/rxe/rxe_qp.c b/drivers/infiniband/sw/rxe/rxe_qp.c index fd86fd2fbb26..a6a5f223ffb3 100644 --- a/drivers/infiniband/sw/rxe/rxe_qp.c +++ b/drivers/infiniband/sw/rxe/rxe_qp.c @@ -343,7 +343,10 @@ int rxe_qp_from_init(struct rxe_dev *rxe, struct rxe_qp *qp, struct rxe_pd *pd, struct rxe_cq *rcq = to_rcq(init->recv_cq); struct rxe_cq *scq = to_rcq(init->send_cq); struct rxe_srq *srq = init->srq ? to_rsrq(init->srq) : NULL; - struct ib_ucontext *context = udata ? ibpd->uobject->context : NULL; + struct ib_ucontext *context = rdma_get_ucontext(udata); + + if (IS_ERR(context)) + context = NULL; rxe_add_ref(pd); rxe_add_ref(rcq); diff --git a/drivers/infiniband/sw/rxe/rxe_verbs.c b/drivers/infiniband/sw/rxe/rxe_verbs.c index 3d01247a28db..0d6e5af21797 100644 --- a/drivers/infiniband/sw/rxe/rxe_verbs.c +++ b/drivers/infiniband/sw/rxe/rxe_verbs.c @@ -331,9 +331,12 @@ static struct ib_srq *rxe_create_srq(struct ib_pd *ibpd, struct rxe_dev *rxe = to_rdev(ibpd->device); struct rxe_pd *pd = to_rpd(ibpd); struct rxe_srq *srq; - struct ib_ucontext *context = udata ? ibpd->uobject->context : NULL; + struct ib_ucontext *context = rdma_get_ucontext(udata); struct rxe_create_srq_resp __user *uresp = NULL; + if (IS_ERR(context)) + context = NULL; + if (udata) { if (udata->outlen < sizeof(*uresp)) return ERR_PTR(-EINVAL); -- 2.17.2