Propagate the new owner parameter to several internal core functions, as well as to the ib_umem_odp_get() kernel interface function. The owner's pid is now saved in the per_mm struct instead of being derived from current, and the mm of the address space that owns the memory region, also kept in per_mm, is then used by ib_umem_odp_map_dma_pages() when resolving an ODP page fault. Signed-off-by: Joel Nider <joeln@xxxxxxxxxx> --- drivers/infiniband/core/umem.c | 4 +-- drivers/infiniband/core/umem_odp.c | 50 ++++++++++++++++++-------------------- drivers/infiniband/hw/mlx5/odp.c | 6 ++++- include/rdma/ib_umem_odp.h | 6 +++-- 4 files changed, 35 insertions(+), 31 deletions(-) diff --git a/drivers/infiniband/core/umem.c b/drivers/infiniband/core/umem.c index 9646cee..77874e5 100644 --- a/drivers/infiniband/core/umem.c +++ b/drivers/infiniband/core/umem.c @@ -142,7 +142,7 @@ struct ib_umem *ib_umem_get(struct ib_ucontext *context, unsigned long addr, mmgrab(mm); if (access & IB_ACCESS_ON_DEMAND) { - ret = ib_umem_odp_get(to_ib_umem_odp(umem), access); + ret = ib_umem_odp_get(to_ib_umem_odp(umem), access, owner); if (ret) goto umem_kfree; return umem; @@ -200,7 +200,7 @@ struct ib_umem *ib_umem_get(struct ib_ucontext *context, unsigned long addr, mm, cur_base, min_t(unsigned long, npages, PAGE_SIZE / sizeof(struct page *)), - gup_flags, page_list, vma_list, NULL); + gup_flags, page_list, vma_list); if (ret < 0) { up_read(&mm->mmap_sem); goto umem_release; diff --git a/drivers/infiniband/core/umem_odp.c b/drivers/infiniband/core/umem_odp.c index a4ec430..49826070 100644 --- a/drivers/infiniband/core/umem_odp.c +++ b/drivers/infiniband/core/umem_odp.c @@ -227,7 +227,8 @@ static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp) } static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx, - struct mm_struct *mm) + struct mm_struct *mm, + struct pid *owner) { struct ib_ucontext_per_mm *per_mm; int ret; @@ -241,12 +242,8 @@ static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx, 
per_mm->umem_tree = 
RB_ROOT_CACHED; init_rwsem(&per_mm->umem_rwsem); per_mm->active = ctx->invalidate_range; - - rcu_read_lock(); - per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID); - rcu_read_unlock(); - - WARN_ON(mm != current->mm); + per_mm->tgid = get_pid(owner); /* hold our own reference, as the replaced get_task_pid() did */ + mmgrab(per_mm->mm); per_mm->mn.ops = &ib_umem_notifiers; ret = mmu_notifier_register(&per_mm->mn, per_mm->mm); @@ -265,7 +262,7 @@ static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx, return ERR_PTR(ret); } -static int get_per_mm(struct ib_umem_odp *umem_odp) +static int get_per_mm(struct ib_umem_odp *umem_odp, struct pid *owner) { struct ib_ucontext *ctx = umem_odp->umem.context; struct ib_ucontext_per_mm *per_mm; @@ -280,7 +277,7 @@ static int get_per_mm(struct ib_umem_odp *umem_odp) goto found; } - per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm); + per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm, owner); if (IS_ERR(per_mm)) { mutex_unlock(&ctx->per_mm_list_lock); return PTR_ERR(per_mm); @@ -333,7 +330,8 @@ void put_per_mm(struct ib_umem_odp *umem_odp) } struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm, - unsigned long addr, size_t size) + unsigned long addr, size_t size, + struct mm_struct *owner_mm) { struct ib_ucontext *ctx = per_mm->context; struct ib_umem_odp *odp_data; @@ -345,12 +343,14 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm, if (!odp_data) return ERR_PTR(-ENOMEM); umem = &odp_data->umem; + umem->context = ctx; umem->length = size; umem->address = addr; umem->page_shift = PAGE_SHIFT; umem->writable = 1; umem->is_odp = 1; + umem->owning_mm = owner_mm; odp_data->per_mm = per_mm; mutex_init(&odp_data->umem_mutex); @@ -389,13 +389,9 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm, } EXPORT_SYMBOL(ib_alloc_odp_umem); -int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) +int ib_umem_odp_get(struct 
ib_umem_odp *umem_odp, int access, struct pid *owner) { struct ib_umem *umem = 
&umem_odp->umem; - /* - * NOTE: This must called in a process context where umem->owning_mm - * == current->mm - */ struct mm_struct *mm = umem->owning_mm; int ret_val; @@ -437,7 +433,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) } } - ret_val = get_per_mm(umem_odp); + ret_val = get_per_mm(umem_odp, owner); if (ret_val) goto out_dma_list; add_umem_to_per_mm(umem_odp); @@ -574,8 +570,8 @@ static int ib_umem_odp_map_dma_single_page( * the return value. * @access_mask: bit mask of the requested access permissions for the given * range. - * @current_seq: the MMU notifiers sequance value for synchronization with - * invalidations. the sequance number is read from + * @current_seq: the MMU notifiers sequence value for synchronization with + * invalidations. the sequence number is read from * umem_odp->notifiers_seq before calling this function */ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt, @@ -584,7 +580,7 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt, { struct ib_umem *umem = &umem_odp->umem; struct task_struct *owning_process = NULL; - struct mm_struct *owning_mm = umem_odp->umem.owning_mm; + struct mm_struct *owning_mm; struct page **local_page_list = NULL; u64 page_mask, off; int j, k, ret = 0, start_idx, npages = 0, page_shift; @@ -609,12 +605,13 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt, bcnt += off; /* Charge for the first page offset as well. */ /* - * owning_process is allowed to be NULL, this means somehow the mm is - * existing beyond the lifetime of the originating process.. Presumably + * owning_process may be NULL, because the mm can + * exist independently of the originating process. * mmget_not_zero will fail in this case. 
*/ owning_process = get_pid_task(umem_odp->per_mm->tgid, PIDTYPE_PID); - if (WARN_ON(!mmget_not_zero(umem_odp->umem.owning_mm))) { + owning_mm = umem_odp->per_mm->mm; + if (WARN_ON(!mmget_not_zero(owning_mm))) { ret = -EINVAL; goto out_put_task; } @@ -632,15 +629,16 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt, down_read(&owning_mm->mmap_sem); /* - * Note: this might result in redundent page getting. We can + * Note: this might result in redundant page getting. We can * avoid this by checking dma_list to be 0 before calling - * get_user_pages. However, this make the code much more + * get_user_pages. However, this makes the code much more * complex (and doesn't gain us much performance in most use * cases). */ - npages = get_user_pages_remote(owning_process, owning_mm, + npages = get_user_pages_remote_longterm(owning_process, + owning_mm, user_virt, gup_num_pages, - flags, local_page_list, NULL, NULL); + flags, local_page_list, NULL); up_read(&owning_mm->mmap_sem); if (npages < 0) { diff --git a/drivers/infiniband/hw/mlx5/odp.c b/drivers/infiniband/hw/mlx5/odp.c index c317e18..1abc917 100644 --- a/drivers/infiniband/hw/mlx5/odp.c +++ b/drivers/infiniband/hw/mlx5/odp.c @@ -439,8 +439,12 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr, if (nentries) nentries++; } else { + struct mm_struct *owner_mm = current->mm; + + if (mr->umem->owning_mm) + owner_mm = mr->umem->owning_mm; odp = ib_alloc_odp_umem(odp_mr->per_mm, addr, - MLX5_IMR_MTT_SIZE); + MLX5_IMR_MTT_SIZE, owner_mm); if (IS_ERR(odp)) { mutex_unlock(&odp_mr->umem_mutex); return ERR_CAST(odp); diff --git a/include/rdma/ib_umem_odp.h b/include/rdma/ib_umem_odp.h index 0b1446f..28099e6 100644 --- a/include/rdma/ib_umem_odp.h +++ b/include/rdma/ib_umem_odp.h @@ -102,9 +102,11 @@ struct ib_ucontext_per_mm { struct rcu_head rcu; }; -int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access); +int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access, + 
struct pid *owner); struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm, - unsigned long addr, size_t size); + unsigned long addr, size_t size, + struct mm_struct *owner_mm); void ib_umem_odp_release(struct ib_umem_odp *umem_odp); /* -- 2.7.4