From: Jason Gunthorpe <jgg@xxxxxxxxxxxx> Since ODP had a single struct mmu_notifier located in the ucontext it could only handle a single MM at a time, and this prevented it from using the new owning_mm system. With the prior rework it is now simple to let ODP track multiple MMs per ucontext. Finish the job so that the per_mm is allocated on an mm-by-mm basis and freed when the last umem is dropped from the ucontext. As a side effect the new saner locking removes the lockdep splat about nesting the umem_rwsem between mmu_notifier_unregister and ib_umem_odp_release. It also makes ODP work with multiple processes, across fork, etc. Signed-off-by: Jason Gunthorpe <jgg@xxxxxxxxxxxx> Signed-off-by: Leon Romanovsky <leonro@xxxxxxxxxxxx> --- drivers/infiniband/core/umem_odp.c | 301 +++++++++++++++++++---------------- drivers/infiniband/core/uverbs_cmd.c | 8 +- drivers/infiniband/hw/mlx5/main.c | 7 + drivers/infiniband/hw/mlx5/odp.c | 2 +- include/rdma/ib_umem_odp.h | 20 ++- include/rdma/ib_verbs.h | 22 +-- 6 files changed, 191 insertions(+), 169 deletions(-) diff --git a/drivers/infiniband/core/umem_odp.c b/drivers/infiniband/core/umem_odp.c index 6bf3fc0c12a1..0577f9ff600f 100644 --- a/drivers/infiniband/core/umem_odp.c +++ b/drivers/infiniband/core/umem_odp.c @@ -278,10 +278,135 @@ static const struct mmu_notifier_ops ib_umem_notifiers = { .invalidate_range_end = ib_umem_notifier_invalidate_range_end, }; -struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context, - unsigned long addr, size_t size) +static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp) +{ + struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm; + struct ib_umem *umem = &umem_odp->umem; + + down_write(&per_mm->umem_rwsem); + if (likely(ib_umem_start(umem) != ib_umem_end(umem))) + rbt_ib_umem_insert(&umem_odp->interval_tree, + &per_mm->umem_tree); + + if (likely(!atomic_read(&per_mm->notifier_count))) + umem_odp->mn_counters_active = true; + else + list_add(&umem_odp->no_private_counters, + 
&per_mm->no_private_counters); + up_write(&per_mm->umem_rwsem); +} + +static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp) +{ + struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm; + struct ib_umem *umem = &umem_odp->umem; + + down_write(&per_mm->umem_rwsem); + if (likely(ib_umem_start(umem) != ib_umem_end(umem))) + rbt_ib_umem_remove(&umem_odp->interval_tree, + &per_mm->umem_tree); + if (!umem_odp->mn_counters_active) { + list_del(&umem_odp->no_private_counters); + complete_all(&umem_odp->notifier_completion); + } + + up_write(&per_mm->umem_rwsem); +} + +static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx, + struct mm_struct *mm) { struct ib_ucontext_per_mm *per_mm; + int ret; + + per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL); + if (!per_mm) + return ERR_PTR(-ENOMEM); + + per_mm->context = ctx; + per_mm->mm = mm; + per_mm->umem_tree = RB_ROOT_CACHED; + init_rwsem(&per_mm->umem_rwsem); + INIT_LIST_HEAD(&per_mm->no_private_counters); + + rcu_read_lock(); + per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID); + rcu_read_unlock(); + + WARN_ON(mm != current->mm); + + per_mm->mn.ops = &ib_umem_notifiers; + ret = mmu_notifier_register(&per_mm->mn, per_mm->mm); + if (ret) { + dev_err(&ctx->device->dev, + "Failed to register mmu_notifier %d\n", ret); + goto out_pid; + } + + list_add(&per_mm->ucontext_list, &ctx->per_mm_list); + return per_mm; + +out_pid: + put_pid(per_mm->tgid); + kfree(per_mm); + return ERR_PTR(ret); +} + +static int get_per_mm(struct ib_umem_odp *umem_odp) +{ + struct ib_ucontext *ctx = umem_odp->umem.context; + struct ib_ucontext_per_mm *per_mm; + + /* + * Generally speaking we expect only one or two per_mm in this list, + * so no reason to optimize this search today. 
+ */ + mutex_lock(&ctx->per_mm_list_lock); + list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) { + if (per_mm->mm == umem_odp->umem.owning_mm) + goto found; + } + + per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm); + if (IS_ERR(per_mm)) { + mutex_unlock(&ctx->per_mm_list_lock); + return PTR_ERR(per_mm); + } + +found: + umem_odp->per_mm = per_mm; + per_mm->odp_mrs_count++; + mutex_unlock(&ctx->per_mm_list_lock); + + return 0; +} + +void put_per_mm(struct ib_umem_odp *umem_odp) +{ + struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm; + struct ib_ucontext *ctx = umem_odp->umem.context; + bool need_free; + + mutex_lock(&ctx->per_mm_list_lock); + umem_odp->per_mm = NULL; + per_mm->odp_mrs_count--; + need_free = per_mm->odp_mrs_count == 0; + if (need_free) + list_del(&per_mm->ucontext_list); + mutex_unlock(&ctx->per_mm_list_lock); + + if (!need_free) + return; + + mmu_notifier_unregister(&per_mm->mn, per_mm->mm); + put_pid(per_mm->tgid); + kfree(per_mm); +} + +struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm, + unsigned long addr, size_t size) +{ + struct ib_ucontext *ctx = per_mm->context; struct ib_umem_odp *odp_data; struct ib_umem *umem; int pages = size >> PAGE_SHIFT; @@ -291,13 +416,13 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context, if (!odp_data) return ERR_PTR(-ENOMEM); umem = &odp_data->umem; - umem->context = context; + umem->context = ctx; umem->length = size; umem->address = addr; umem->page_shift = PAGE_SHIFT; umem->writable = 1; umem->is_odp = 1; - odp_data->per_mm = per_mm = &context->per_mm; + odp_data->per_mm = per_mm; mutex_init(&odp_data->umem_mutex); init_completion(&odp_data->notifier_completion); @@ -316,15 +441,14 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context, goto out_page_list; } - down_write(&per_mm->umem_rwsem); + /* + * Caller must ensure that the umem_odp that the per_mm came from + * cannot be freed during the call to ib_alloc_odp_umem. 
+ */ + mutex_lock(&ctx->per_mm_list_lock); per_mm->odp_mrs_count++; - rbt_ib_umem_insert(&odp_data->interval_tree, &per_mm->umem_tree); - if (likely(!atomic_read(&per_mm->notifier_count))) - odp_data->mn_counters_active = true; - else - list_add(&odp_data->no_private_counters, - &per_mm->no_private_counters); - up_write(&per_mm->umem_rwsem); + mutex_unlock(&ctx->per_mm_list_lock); + add_umem_to_per_mm(odp_data); return odp_data; @@ -338,15 +462,13 @@ EXPORT_SYMBOL(ib_alloc_odp_umem); int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) { - struct ib_ucontext *context = umem_odp->umem.context; struct ib_umem *umem = &umem_odp->umem; - struct ib_ucontext_per_mm *per_mm; + /* + * NOTE: This must be called in a process context where umem->owning_mm + * == current->mm + */ + struct mm_struct *mm = umem->owning_mm; int ret_val; - struct pid *our_pid; - struct mm_struct *mm = get_task_mm(current); - - if (!mm) - return -EINVAL; if (access & IB_ACCESS_HUGETLB) { struct vm_area_struct *vma; @@ -366,16 +488,6 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) umem->hugetlb = 0; } - /* Prevent creating ODP MRs in child processes */ - rcu_read_lock(); - our_pid = get_task_pid(current->group_leader, PIDTYPE_PID); - rcu_read_unlock(); - put_pid(our_pid); - if (context->tgid != our_pid) { - ret_val = -EINVAL; - goto out_mm; - } - mutex_init(&umem_odp->umem_mutex); init_completion(&umem_odp->notifier_completion); @@ -384,10 +496,8 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) umem_odp->page_list = vzalloc(array_size(sizeof(*umem_odp->page_list), ib_umem_num_pages(umem))); - if (!umem_odp->page_list) { - ret_val = -ENOMEM; - goto out_mm; - } + if (!umem_odp->page_list) + return -ENOMEM; umem_odp->dma_list = vzalloc(array_size(sizeof(*umem_odp->dma_list), @@ -398,67 +508,23 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) } } - /* - * When using MMU notifiers, we will get a - * notification before the "current" task (and MM) is 
- * destroyed. We use the umem_rwsem semaphore to synchronize. - */ - umem_odp->per_mm = per_mm = &context->per_mm; - - down_write(&per_mm->umem_rwsem); - per_mm->odp_mrs_count++; - if (likely(ib_umem_start(umem) != ib_umem_end(umem))) - rbt_ib_umem_insert(&umem_odp->interval_tree, - &per_mm->umem_tree); - if (likely(!atomic_read(&per_mm->notifier_count)) || - per_mm->odp_mrs_count == 1) - umem_odp->mn_counters_active = true; - else - list_add(&umem_odp->no_private_counters, - &per_mm->no_private_counters); - downgrade_write(&per_mm->umem_rwsem); + ret_val = get_per_mm(umem_odp); + if (ret_val) + goto out_dma_list; + add_umem_to_per_mm(umem_odp); - if (per_mm->odp_mrs_count == 1) { - /* - * Note that at this point, no MMU notifier is running - * for this per_mm! - */ - atomic_set(&per_mm->notifier_count, 0); - INIT_HLIST_NODE(&per_mm->mn.hlist); - per_mm->mn.ops = &ib_umem_notifiers; - ret_val = mmu_notifier_register(&per_mm->mn, mm); - if (ret_val) { - pr_err("Failed to register mmu_notifier %d\n", ret_val); - ret_val = -EBUSY; - goto out_mutex; - } - } - - up_read(&per_mm->umem_rwsem); - - /* - * Note that doing an mmput can cause a notifier for the relevant mm. - * If the notifier is called while we hold the umem_rwsem, this will - * cause a deadlock. Therefore, we release the reference only after we - * released the semaphore. - */ - mmput(mm); return 0; -out_mutex: - up_read(&per_mm->umem_rwsem); +out_dma_list: vfree(umem_odp->dma_list); out_page_list: vfree(umem_odp->page_list); -out_mm: - mmput(mm); return ret_val; } void ib_umem_odp_release(struct ib_umem_odp *umem_odp) { struct ib_umem *umem = &umem_odp->umem; - struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm; /* * Ensure that no more pages are mapped in the umem. 
@@ -469,54 +535,8 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp) ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem), ib_umem_end(umem)); - down_write(&per_mm->umem_rwsem); - if (likely(ib_umem_start(umem) != ib_umem_end(umem))) - rbt_ib_umem_remove(&umem_odp->interval_tree, - &per_mm->umem_tree); - per_mm->odp_mrs_count--; - if (!umem_odp->mn_counters_active) { - list_del(&umem_odp->no_private_counters); - complete_all(&umem_odp->notifier_completion); - } - - /* - * Downgrade the lock to a read lock. This ensures that the notifiers - * (who lock the mutex for reading) will be able to finish, and we - * will be able to enventually obtain the mmu notifiers SRCU. Note - * that since we are doing it atomically, no other user could register - * and unregister while we do the check. - */ - downgrade_write(&per_mm->umem_rwsem); - if (!per_mm->odp_mrs_count) { - struct task_struct *owning_process = NULL; - struct mm_struct *owning_mm = NULL; - - owning_process = - get_pid_task(umem_odp->umem.context->tgid, PIDTYPE_PID); - if (owning_process == NULL) - /* - * The process is already dead, notifier were removed - * already. - */ - goto out; - - owning_mm = get_task_mm(owning_process); - if (owning_mm == NULL) - /* - * The process' mm is already dead, notifier were - * removed already. 
- */ - goto out_put_task; - mmu_notifier_unregister(&per_mm->mn, owning_mm); - - mmput(owning_mm); - -out_put_task: - put_task_struct(owning_process); - } -out: - up_read(&per_mm->umem_rwsem); - + remove_umem_from_per_mm(umem_odp); + put_per_mm(umem_odp); vfree(umem_odp->dma_list); vfree(umem_odp->page_list); } @@ -634,7 +654,7 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt, { struct ib_umem *umem = &umem_odp->umem; struct task_struct *owning_process = NULL; - struct mm_struct *owning_mm = NULL; + struct mm_struct *owning_mm = umem_odp->umem.owning_mm; struct page **local_page_list = NULL; u64 page_mask, off; int j, k, ret = 0, start_idx, npages = 0, page_shift; @@ -658,15 +678,14 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt, user_virt = user_virt & page_mask; bcnt += off; /* Charge for the first page offset as well. */ - owning_process = get_pid_task(umem->context->tgid, PIDTYPE_PID); - if (owning_process == NULL) { + /* + * owning_process is allowed to be NULL; this means the mm somehow + * exists beyond the lifetime of the originating process. Presumably + * mmget_not_zero will fail in this case. 
+ */ + owning_process = get_pid_task(umem_odp->per_mm->tgid, PIDTYPE_PID); + if (WARN_ON(!mmget_not_zero(umem_odp->umem.owning_mm))) { ret = -EINVAL; - goto out_no_task; - } - - owning_mm = get_task_mm(owning_process); - if (owning_mm == NULL) { - ret = -ENOENT; goto out_put_task; } @@ -738,8 +757,8 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt, mmput(owning_mm); out_put_task: - put_task_struct(owning_process); -out_no_task: + if (owning_process) + put_task_struct(owning_process); free_page((unsigned long)local_page_list); return ret; } diff --git a/drivers/infiniband/core/uverbs_cmd.c b/drivers/infiniband/core/uverbs_cmd.c index ce678e1008a4..d77b0b9793c7 100644 --- a/drivers/infiniband/core/uverbs_cmd.c +++ b/drivers/infiniband/core/uverbs_cmd.c @@ -124,12 +124,8 @@ ssize_t ib_uverbs_get_context(struct ib_uverbs_file *file, ucontext->cleanup_retryable = false; #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING - ucontext->per_mm.umem_tree = RB_ROOT_CACHED; - init_rwsem(&ucontext->per_mm.umem_rwsem); - ucontext->per_mm.odp_mrs_count = 0; - INIT_LIST_HEAD(&ucontext->per_mm.no_private_counters); - ucontext->per_mm.context = ucontext; - + mutex_init(&ucontext->per_mm_list_lock); + INIT_LIST_HEAD(&ucontext->per_mm_list); if (!(ib_dev->attrs.device_cap_flags & IB_DEVICE_ON_DEMAND_PAGING)) ucontext->invalidate_range = NULL; diff --git a/drivers/infiniband/hw/mlx5/main.c b/drivers/infiniband/hw/mlx5/main.c index 82fe859f485c..73782100750d 100644 --- a/drivers/infiniband/hw/mlx5/main.c +++ b/drivers/infiniband/hw/mlx5/main.c @@ -1888,6 +1888,13 @@ static int mlx5_ib_dealloc_ucontext(struct ib_ucontext *ibcontext) struct mlx5_ib_dev *dev = to_mdev(ibcontext->device); struct mlx5_bfreg_info *bfregi; +#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING + /* All umem's must be destroyed before destroying the ucontext. 
*/ + mutex_lock(&ibcontext->per_mm_list_lock); + WARN_ON(!list_empty(&ibcontext->per_mm_list)); + mutex_unlock(&ibcontext->per_mm_list_lock); +#endif + bfregi = &context->bfregi; mlx5_ib_dealloc_transport_domain(dev, context->tdn, context->devx_uid); diff --git a/drivers/infiniband/hw/mlx5/odp.c b/drivers/infiniband/hw/mlx5/odp.c index 9982b5f4e598..b04eb6775326 100644 --- a/drivers/infiniband/hw/mlx5/odp.c +++ b/drivers/infiniband/hw/mlx5/odp.c @@ -393,7 +393,7 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr, if (nentries) nentries++; } else { - odp = ib_alloc_odp_umem(odp_mr->umem.context, addr, + odp = ib_alloc_odp_umem(odp_mr->per_mm, addr, MLX5_IMR_MTT_SIZE); if (IS_ERR(odp)) { mutex_unlock(&odp_mr->umem_mutex); diff --git a/include/rdma/ib_umem_odp.h b/include/rdma/ib_umem_odp.h index 394ea6b68db7..259eb08dfc9e 100644 --- a/include/rdma/ib_umem_odp.h +++ b/include/rdma/ib_umem_odp.h @@ -91,8 +91,26 @@ static inline struct ib_umem_odp *to_ib_umem_odp(struct ib_umem *umem) #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING +struct ib_ucontext_per_mm { + struct ib_ucontext *context; + struct mm_struct *mm; + struct pid *tgid; + + struct rb_root_cached umem_tree; + /* Protects umem_tree */ + struct rw_semaphore umem_rwsem; + atomic_t notifier_count; + + struct mmu_notifier mn; + /* A list of umems that don't have private mmu notifier counters yet. 
*/ + struct list_head no_private_counters; + unsigned int odp_mrs_count; + + struct list_head ucontext_list; +}; + int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access); -struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context, +struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm, unsigned long addr, size_t size); void ib_umem_odp_release(struct ib_umem_odp *umem_odp); diff --git a/include/rdma/ib_verbs.h b/include/rdma/ib_verbs.h index 87296695dfee..1b2338b86d1b 100644 --- a/include/rdma/ib_verbs.h +++ b/include/rdma/ib_verbs.h @@ -1482,25 +1482,6 @@ struct ib_rdmacg_object { #endif }; -#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING -struct ib_ucontext_per_mm { - struct ib_ucontext *context; - - struct rb_root_cached umem_tree; - /* - * Protects .umem_rbroot and tree, as well as odp_mrs_count and - * mmu notifiers registration. - */ - struct rw_semaphore umem_rwsem; - - struct mmu_notifier mn; - atomic_t notifier_count; - /* A list of umems that don't have private mmu notifier counters yet. */ - struct list_head no_private_counters; - unsigned int odp_mrs_count; -}; -#endif - struct ib_ucontext { struct ib_device *device; struct ib_uverbs_file *ufile; @@ -1517,7 +1498,8 @@ struct ib_ucontext { #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING void (*invalidate_range)(struct ib_umem_odp *umem_odp, unsigned long start, unsigned long end); - struct ib_ucontext_per_mm per_mm; + struct mutex per_mm_list_lock; + struct list_head per_mm_list; #endif struct ib_rdmacg_object cg_obj; -- 2.14.4