3 files changed, 16 insertions(+), 14 deletions(-)
diff --git a/drivers/infiniband/hw/mlx4/cq.c b/drivers/infiniband/hw/mlx4/cq.c
index 4cd738aae53c..b12713fdde99 100644
--- a/drivers/infiniband/hw/mlx4/cq.c
+++ b/drivers/infiniband/hw/mlx4/cq.c
@@ -180,7 +180,8 @@ int mlx4_ib_create_cq(struct ib_cq *ibcq, const struct ib_cq_init_attr *attr,
struct mlx4_ib_dev *dev = to_mdev(ibdev);
struct mlx4_ib_cq *cq = to_mcq(ibcq);
struct mlx4_uar *uar;
- void *buf_addr;
+ void __user *ubuf_addr;
+ void *kbuf_addr;
int err;
struct mlx4_ib_ucontext *context = rdma_udata_to_drv_context(
udata, struct mlx4_ib_ucontext, ibucontext);
@@ -209,7 +210,8 @@ int mlx4_ib_create_cq(struct ib_cq *ibcq, const struct ib_cq_init_attr *attr,
goto err_cq;
}
- buf_addr = (void *)(unsigned long)ucmd.buf_addr;
+ ubuf_addr = u64_to_user_ptr(ucmd.buf_addr);
+ kbuf_addr = NULL;
err = mlx4_ib_get_cq_umem(dev, &cq->buf, &cq->umem,
ucmd.buf_addr, entries);
if (err)
@@ -235,7 +237,8 @@ int mlx4_ib_create_cq(struct ib_cq *ibcq, const struct ib_cq_init_attr *attr,
if (err)
goto err_db;
- buf_addr = &cq->buf.buf;
+ ubuf_addr = NULL;
+ kbuf_addr = &cq->buf.buf;
uar = &dev->priv_uar;
cq->mcq.usage = MLX4_RES_USAGE_DRIVER;
@@ -248,7 +251,7 @@ int mlx4_ib_create_cq(struct ib_cq *ibcq, const struct ib_cq_init_attr *attr,
&cq->mcq, vector, 0,
!!(cq->create_flags &
IB_UVERBS_CQ_FLAGS_TIMESTAMP_COMPLETION),
- buf_addr, !!udata);
+ ubuf_addr, kbuf_addr);
if (err)
goto err_dbmap;
diff --git a/drivers/net/ethernet/mellanox/mlx4/cq.c b/drivers/net/ethernet/mellanox/mlx4/cq.c
index 020cb8e2883f..22216f4e409b 100644
--- a/drivers/net/ethernet/mellanox/mlx4/cq.c
+++ b/drivers/net/ethernet/mellanox/mlx4/cq.c
@@ -287,7 +287,7 @@ static void mlx4_cq_free_icm(struct mlx4_dev *dev, int cqn)
__mlx4_cq_free_icm(dev, cqn);
}
-static int mlx4_init_user_cqes(void *buf, int entries, int cqe_size)
+static int mlx4_init_user_cqes(void __user *buf, int entries, int cqe_size)
{
int entries_per_copy = PAGE_SIZE / cqe_size;
size_t copy_size = array_size(entries, cqe_size);
@@ -307,7 +307,7 @@ static int mlx4_init_user_cqes(void *buf, int entries, int cqe_size)
if (copy_size > PAGE_SIZE) {
for (i = 0; i < entries / entries_per_copy; i++) {
- err = copy_to_user((void __user *)buf, init_ents, PAGE_SIZE) ?
+ err = copy_to_user(buf, init_ents, PAGE_SIZE) ?
-EFAULT : 0;
if (err)
goto out;
@@ -315,8 +315,7 @@ static int mlx4_init_user_cqes(void *buf, int entries, int cqe_size)
buf += PAGE_SIZE;
}
} else {
- err = copy_to_user((void __user *)buf, init_ents,
- copy_size) ?
+ err = copy_to_user(buf, init_ents, copy_size) ?
-EFAULT : 0;
}
@@ -343,7 +342,7 @@ static void mlx4_init_kernel_cqes(struct mlx4_buf *buf,
int mlx4_cq_alloc(struct mlx4_dev *dev, int nent,
struct mlx4_mtt *mtt, struct mlx4_uar *uar, u64 db_rec,
struct mlx4_cq *cq, unsigned vector, int collapsed,
- int timestamp_en, void *buf_addr, bool user_cq)
+ int timestamp_en, void __user *ubuf_addr, void *kbuf_addr)
{
bool sw_cq_init = dev->caps.flags2 & MLX4_DEV_CAP_FLAG2_SW_CQ_INIT;
struct mlx4_priv *priv = mlx4_priv(dev);
@@ -391,13 +390,13 @@ int mlx4_cq_alloc(struct mlx4_dev *dev, int nent,
cq_context->db_rec_addr = cpu_to_be64(db_rec);
if (sw_cq_init) {
- if (user_cq) {
- err = mlx4_init_user_cqes(buf_addr, nent,
+ if (ubuf_addr) {
+ err = mlx4_init_user_cqes(ubuf_addr, nent,
dev->caps.cqe_size);
if (err)
sw_cq_init = false;
- } else {
- mlx4_init_kernel_cqes(buf_addr, nent,
+ } else if (kbuf_addr) {
+ mlx4_init_kernel_cqes(kbuf_addr, nent,
dev->caps.cqe_size);
}
}
diff --git a/include/linux/mlx4/device.h b/include/linux/mlx4/device.h
index 6646634a0b9d..dd8f3396dcba 100644
--- a/include/linux/mlx4/device.h
+++ b/include/linux/mlx4/device.h
@@ -1126,7 +1126,7 @@ void mlx4_free_hwq_res(struct mlx4_dev *mdev, struct mlx4_hwq_resources *wqres,
int mlx4_cq_alloc(struct mlx4_dev *dev, int nent, struct mlx4_mtt *mtt,
struct mlx4_uar *uar, u64 db_rec, struct mlx4_cq *cq,
unsigned int vector, int collapsed, int timestamp_en,
- void *buf_addr, bool user_cq);
+ void __user *ubuf_addr, void *kbuf_addr);
void mlx4_cq_free(struct mlx4_dev *dev, struct mlx4_cq *cq);
int mlx4_qp_reserve_range(struct mlx4_dev *dev, int cnt, int align,
int *base, u8 flags, u8 usage);