From: Jason Gunthorpe <jgg@xxxxxxxxxxxx> This is the first step to make ODP use the owning_mm that is now part of struct ib_umem. Each ODP umem is linked to a single per_mm structure, which, in turn, is linked to a single mm, via the embedded mmu_notifier. This first patch introduces the structure and reworks everything to use it. This also needs to introduce tgid into the ib_ucontext_per_mm, as get_user_pages_remote() requires the originating task for statistics tracking. Signed-off-by: Jason Gunthorpe <jgg@xxxxxxxxxxxx> Signed-off-by: Leon Romanovsky <leonro@xxxxxxxxxxxx> --- drivers/infiniband/core/umem_odp.c | 127 +++++++++++++++++++---------------- drivers/infiniband/core/uverbs_cmd.c | 9 +-- drivers/infiniband/hw/mlx5/odp.c | 43 +++++++----- include/rdma/ib_umem_odp.h | 2 + include/rdma/ib_verbs.h | 32 +++++---- 5 files changed, 120 insertions(+), 93 deletions(-) diff --git a/drivers/infiniband/core/umem_odp.c b/drivers/infiniband/core/umem_odp.c index 42272b2bf595..6bf3fc0c12a1 100644 --- a/drivers/infiniband/core/umem_odp.c +++ b/drivers/infiniband/core/umem_odp.c @@ -115,34 +115,35 @@ static void ib_umem_notifier_end_account(struct ib_umem_odp *umem_odp) } /* Account for a new mmu notifier in an ib_ucontext. */ -static void ib_ucontext_notifier_start_account(struct ib_ucontext *context) +static void +ib_ucontext_notifier_start_account(struct ib_ucontext_per_mm *per_mm) { - atomic_inc(&context->notifier_count); + atomic_inc(&per_mm->notifier_count); } /* Account for a terminating mmu notifier in an ib_ucontext. * * Must be called with the ib_ucontext->umem_rwsem semaphore unlocked, since * the function takes the semaphore itself. 
*/ -static void ib_ucontext_notifier_end_account(struct ib_ucontext *context) +static void ib_ucontext_notifier_end_account(struct ib_ucontext_per_mm *per_mm) { - int zero_notifiers = atomic_dec_and_test(&context->notifier_count); + int zero_notifiers = atomic_dec_and_test(&per_mm->notifier_count); if (zero_notifiers && - !list_empty(&context->no_private_counters)) { + !list_empty(&per_mm->no_private_counters)) { /* No currently running mmu notifiers. Now is the chance to * add private accounting to all previously added umems. */ struct ib_umem_odp *odp_data, *next; /* Prevent concurrent mmu notifiers from working on the * no_private_counters list. */ - down_write(&context->umem_rwsem); + down_write(&per_mm->umem_rwsem); /* Read the notifier_count again, with the umem_rwsem * semaphore taken for write. */ - if (!atomic_read(&context->notifier_count)) { + if (!atomic_read(&per_mm->notifier_count)) { list_for_each_entry_safe(odp_data, next, - &context->no_private_counters, + &per_mm->no_private_counters, no_private_counters) { mutex_lock(&odp_data->umem_mutex); odp_data->mn_counters_active = true; @@ -152,7 +153,7 @@ static void ib_ucontext_notifier_end_account(struct ib_ucontext *context) } } - up_write(&context->umem_rwsem); + up_write(&per_mm->umem_rwsem); } } @@ -179,19 +180,20 @@ static int ib_umem_notifier_release_trampoline(struct ib_umem_odp *umem_odp, static void ib_umem_notifier_release(struct mmu_notifier *mn, struct mm_struct *mm) { - struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn); + struct ib_ucontext_per_mm *per_mm = + container_of(mn, struct ib_ucontext_per_mm, mn); - if (!context->invalidate_range) + if (!per_mm->context->invalidate_range) return; - ib_ucontext_notifier_start_account(context); - down_read(&context->umem_rwsem); - rbt_ib_umem_for_each_in_range(&context->umem_tree, 0, + ib_ucontext_notifier_start_account(per_mm); + down_read(&per_mm->umem_rwsem); + rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, 0, ULLONG_MAX, 
ib_umem_notifier_release_trampoline, true, NULL); - up_read(&context->umem_rwsem); + up_read(&per_mm->umem_rwsem); } static int invalidate_page_trampoline(struct ib_umem_odp *item, u64 start, @@ -217,23 +219,24 @@ static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn, unsigned long end, bool blockable) { - struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn); + struct ib_ucontext_per_mm *per_mm = + container_of(mn, struct ib_ucontext_per_mm, mn); int ret; - if (!context->invalidate_range) + if (!per_mm->context->invalidate_range) return 0; if (blockable) - down_read(&context->umem_rwsem); - else if (!down_read_trylock(&context->umem_rwsem)) + down_read(&per_mm->umem_rwsem); + else if (!down_read_trylock(&per_mm->umem_rwsem)) return -EAGAIN; - ib_ucontext_notifier_start_account(context); - ret = rbt_ib_umem_for_each_in_range(&context->umem_tree, start, + ib_ucontext_notifier_start_account(per_mm); + ret = rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, start, end, invalidate_range_start_trampoline, blockable, NULL); - up_read(&context->umem_rwsem); + up_read(&per_mm->umem_rwsem); return ret; } @@ -250,9 +253,10 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn, unsigned long start, unsigned long end) { - struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn); + struct ib_ucontext_per_mm *per_mm = + container_of(mn, struct ib_ucontext_per_mm, mn); - if (!context->invalidate_range) + if (!per_mm->context->invalidate_range) return; /* @@ -260,12 +264,12 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn, * in ib_umem_notifier_invalidate_range_start so we shouldn't really block * here. But this is ugly and fragile. 
*/ - down_read(&context->umem_rwsem); - rbt_ib_umem_for_each_in_range(&context->umem_tree, start, + down_read(&per_mm->umem_rwsem); + rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, start, end, invalidate_range_end_trampoline, true, NULL); - up_read(&context->umem_rwsem); - ib_ucontext_notifier_end_account(context); + up_read(&per_mm->umem_rwsem); + ib_ucontext_notifier_end_account(per_mm); } static const struct mmu_notifier_ops ib_umem_notifiers = { @@ -277,6 +281,7 @@ static const struct mmu_notifier_ops ib_umem_notifiers = { struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context, unsigned long addr, size_t size) { + struct ib_ucontext_per_mm *per_mm; struct ib_umem_odp *odp_data; struct ib_umem *umem; int pages = size >> PAGE_SHIFT; @@ -292,6 +297,7 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context, umem->page_shift = PAGE_SHIFT; umem->writable = 1; umem->is_odp = 1; + odp_data->per_mm = per_mm = &context->per_mm; mutex_init(&odp_data->umem_mutex); init_completion(&odp_data->notifier_completion); @@ -310,15 +316,15 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context, goto out_page_list; } - down_write(&context->umem_rwsem); - context->odp_mrs_count++; - rbt_ib_umem_insert(&odp_data->interval_tree, &context->umem_tree); - if (likely(!atomic_read(&context->notifier_count))) + down_write(&per_mm->umem_rwsem); + per_mm->odp_mrs_count++; + rbt_ib_umem_insert(&odp_data->interval_tree, &per_mm->umem_tree); + if (likely(!atomic_read(&per_mm->notifier_count))) odp_data->mn_counters_active = true; else list_add(&odp_data->no_private_counters, - &context->no_private_counters); - up_write(&context->umem_rwsem); + &per_mm->no_private_counters); + up_write(&per_mm->umem_rwsem); return odp_data; @@ -334,6 +340,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) { struct ib_ucontext *context = umem_odp->umem.context; struct ib_umem *umem = &umem_odp->umem; + struct ib_ucontext_per_mm *per_mm; int ret_val; struct 
pid *our_pid; struct mm_struct *mm = get_task_mm(current); @@ -396,28 +403,30 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) * notification before the "current" task (and MM) is * destroyed. We use the umem_rwsem semaphore to synchronize. */ - down_write(&context->umem_rwsem); - context->odp_mrs_count++; + umem_odp->per_mm = per_mm = &context->per_mm; + + down_write(&per_mm->umem_rwsem); + per_mm->odp_mrs_count++; if (likely(ib_umem_start(umem) != ib_umem_end(umem))) rbt_ib_umem_insert(&umem_odp->interval_tree, - &context->umem_tree); - if (likely(!atomic_read(&context->notifier_count)) || - context->odp_mrs_count == 1) + &per_mm->umem_tree); + if (likely(!atomic_read(&per_mm->notifier_count)) || + per_mm->odp_mrs_count == 1) umem_odp->mn_counters_active = true; else list_add(&umem_odp->no_private_counters, - &context->no_private_counters); - downgrade_write(&context->umem_rwsem); + &per_mm->no_private_counters); + downgrade_write(&per_mm->umem_rwsem); - if (context->odp_mrs_count == 1) { + if (per_mm->odp_mrs_count == 1) { /* * Note that at this point, no MMU notifier is running - * for this context! + * for this per_mm! */ - atomic_set(&context->notifier_count, 0); - INIT_HLIST_NODE(&context->mn.hlist); - context->mn.ops = &ib_umem_notifiers; - ret_val = mmu_notifier_register(&context->mn, mm); + atomic_set(&per_mm->notifier_count, 0); + INIT_HLIST_NODE(&per_mm->mn.hlist); + per_mm->mn.ops = &ib_umem_notifiers; + ret_val = mmu_notifier_register(&per_mm->mn, mm); if (ret_val) { pr_err("Failed to register mmu_notifier %d\n", ret_val); ret_val = -EBUSY; @@ -425,7 +434,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) } } - up_read(&context->umem_rwsem); + up_read(&per_mm->umem_rwsem); /* * Note that doing an mmput can cause a notifier for the relevant mm. 
@@ -437,7 +446,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) return 0; out_mutex: - up_read(&context->umem_rwsem); + up_read(&per_mm->umem_rwsem); vfree(umem_odp->dma_list); out_page_list: vfree(umem_odp->page_list); @@ -449,7 +458,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) void ib_umem_odp_release(struct ib_umem_odp *umem_odp) { struct ib_umem *umem = &umem_odp->umem; - struct ib_ucontext *context = umem->context; + struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm; /* * Ensure that no more pages are mapped in the umem. @@ -460,11 +469,11 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp) ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem), ib_umem_end(umem)); - down_write(&context->umem_rwsem); + down_write(&per_mm->umem_rwsem); if (likely(ib_umem_start(umem) != ib_umem_end(umem))) rbt_ib_umem_remove(&umem_odp->interval_tree, - &context->umem_tree); - context->odp_mrs_count--; + &per_mm->umem_tree); + per_mm->odp_mrs_count--; if (!umem_odp->mn_counters_active) { list_del(&umem_odp->no_private_counters); complete_all(&umem_odp->notifier_completion); @@ -477,13 +486,13 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp) * that since we are doing it atomically, no other user could register * and unregister while we do the check. */ - downgrade_write(&context->umem_rwsem); - if (!context->odp_mrs_count) { + downgrade_write(&per_mm->umem_rwsem); + if (!per_mm->odp_mrs_count) { struct task_struct *owning_process = NULL; struct mm_struct *owning_mm = NULL; - owning_process = get_pid_task(context->tgid, - PIDTYPE_PID); + owning_process = + get_pid_task(umem_odp->umem.context->tgid, PIDTYPE_PID); if (owning_process == NULL) /* * The process is already dead, notifier were removed @@ -498,7 +507,7 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp) * removed already. 
*/ goto out_put_task; - mmu_notifier_unregister(&context->mn, owning_mm); + mmu_notifier_unregister(&per_mm->mn, owning_mm); mmput(owning_mm); @@ -506,7 +515,7 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp) put_task_struct(owning_process); } out: - up_read(&context->umem_rwsem); + up_read(&per_mm->umem_rwsem); vfree(umem_odp->dma_list); vfree(umem_odp->page_list); diff --git a/drivers/infiniband/core/uverbs_cmd.c b/drivers/infiniband/core/uverbs_cmd.c index 9c87c98a0f19..ce678e1008a4 100644 --- a/drivers/infiniband/core/uverbs_cmd.c +++ b/drivers/infiniband/core/uverbs_cmd.c @@ -124,10 +124,11 @@ ssize_t ib_uverbs_get_context(struct ib_uverbs_file *file, ucontext->cleanup_retryable = false; #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING - ucontext->umem_tree = RB_ROOT_CACHED; - init_rwsem(&ucontext->umem_rwsem); - ucontext->odp_mrs_count = 0; - INIT_LIST_HEAD(&ucontext->no_private_counters); + ucontext->per_mm.umem_tree = RB_ROOT_CACHED; + init_rwsem(&ucontext->per_mm.umem_rwsem); + ucontext->per_mm.odp_mrs_count = 0; + INIT_LIST_HEAD(&ucontext->per_mm.no_private_counters); + ucontext->per_mm.context = ucontext; if (!(ib_dev->attrs.device_cap_flags & IB_DEVICE_ON_DEMAND_PAGING)) ucontext->invalidate_range = NULL; diff --git a/drivers/infiniband/hw/mlx5/odp.c b/drivers/infiniband/hw/mlx5/odp.c index d4780bded74a..9982b5f4e598 100644 --- a/drivers/infiniband/hw/mlx5/odp.c +++ b/drivers/infiniband/hw/mlx5/odp.c @@ -61,13 +61,21 @@ static int check_parent(struct ib_umem_odp *odp, return mr && mr->parent == parent && !odp->dying; } +struct ib_ucontext_per_mm *mr_to_per_mm(struct mlx5_ib_mr *mr) +{ + if (WARN_ON(!mr || !mr->umem || !mr->umem->is_odp)) + return NULL; + + return to_ib_umem_odp(mr->umem)->per_mm; +} + static struct ib_umem_odp *odp_next(struct ib_umem_odp *odp) { struct mlx5_ib_mr *mr = odp->private, *parent = mr->parent; - struct ib_ucontext *ctx = odp->umem.context; + struct ib_ucontext_per_mm *per_mm = odp->per_mm; struct rb_node *rb; - 
down_read(&ctx->umem_rwsem); + down_read(&per_mm->umem_rwsem); while (1) { rb = rb_next(&odp->interval_tree.rb); if (!rb) @@ -79,19 +87,19 @@ static struct ib_umem_odp *odp_next(struct ib_umem_odp *odp) not_found: odp = NULL; end: - up_read(&ctx->umem_rwsem); + up_read(&per_mm->umem_rwsem); return odp; } -static struct ib_umem_odp *odp_lookup(struct ib_ucontext *ctx, - u64 start, u64 length, +static struct ib_umem_odp *odp_lookup(u64 start, u64 length, struct mlx5_ib_mr *parent) { + struct ib_ucontext_per_mm *per_mm = mr_to_per_mm(parent); struct ib_umem_odp *odp; struct rb_node *rb; - down_read(&ctx->umem_rwsem); - odp = rbt_ib_umem_lookup(&ctx->umem_tree, start, length); + down_read(&per_mm->umem_rwsem); + odp = rbt_ib_umem_lookup(&per_mm->umem_tree, start, length); if (!odp) goto end; @@ -108,7 +116,7 @@ static struct ib_umem_odp *odp_lookup(struct ib_ucontext *ctx, not_found: odp = NULL; end: - up_read(&ctx->umem_rwsem); + up_read(&per_mm->umem_rwsem); return odp; } @@ -116,7 +124,6 @@ void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t offset, size_t nentries, struct mlx5_ib_mr *mr, int flags) { struct ib_pd *pd = mr->ibmr.pd; - struct ib_ucontext *ctx = pd->uobject->context; struct mlx5_ib_dev *dev = to_mdev(pd->device); struct ib_umem_odp *odp; unsigned long va; @@ -131,8 +138,8 @@ void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t offset, return; } - odp = odp_lookup(ctx, offset * MLX5_IMR_MTT_SIZE, - nentries * MLX5_IMR_MTT_SIZE, mr); + odp = odp_lookup(offset * MLX5_IMR_MTT_SIZE, + nentries * MLX5_IMR_MTT_SIZE, mr); for (i = 0; i < nentries; i++, pklm++) { pklm->bcount = cpu_to_be32(MLX5_IMR_MTT_SIZE); @@ -368,7 +375,6 @@ static struct mlx5_ib_mr *implicit_mr_alloc(struct ib_pd *pd, static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr, u64 io_virt, size_t bcnt) { - struct ib_ucontext *ctx = mr->ibmr.pd->uobject->context; struct mlx5_ib_dev *dev = to_mdev(mr->ibmr.pd->device); struct ib_umem_odp *odp, *result = NULL; struct 
ib_umem_odp *odp_mr = to_ib_umem_odp(mr->umem); @@ -377,7 +383,7 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr, struct mlx5_ib_mr *mtt; mutex_lock(&odp_mr->umem_mutex); - odp = odp_lookup(ctx, addr, 1, mr); + odp = odp_lookup(addr, 1, mr); mlx5_ib_dbg(dev, "io_virt:%llx bcnt:%zx addr:%llx odp:%p\n", io_virt, bcnt, addr, odp); @@ -387,7 +393,8 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr, if (nentries) nentries++; } else { - odp = ib_alloc_odp_umem(ctx, addr, MLX5_IMR_MTT_SIZE); + odp = ib_alloc_odp_umem(odp_mr->umem.context, addr, + MLX5_IMR_MTT_SIZE); if (IS_ERR(odp)) { mutex_unlock(&odp_mr->umem_mutex); return ERR_CAST(odp); @@ -486,12 +493,12 @@ static int mr_leaf_free(struct ib_umem_odp *umem_odp, u64 start, u64 end, void mlx5_ib_free_implicit_mr(struct mlx5_ib_mr *imr) { - struct ib_ucontext *ctx = imr->ibmr.pd->uobject->context; + struct ib_ucontext_per_mm *per_mm = mr_to_per_mm(imr); - down_read(&ctx->umem_rwsem); - rbt_ib_umem_for_each_in_range(&ctx->umem_tree, 0, ULLONG_MAX, + down_read(&per_mm->umem_rwsem); + rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, 0, ULLONG_MAX, mr_leaf_free, true, imr); - up_read(&ctx->umem_rwsem); + up_read(&per_mm->umem_rwsem); wait_event(imr->q_leaf_free, !atomic_read(&imr->num_leaf_free)); } diff --git a/include/rdma/ib_umem_odp.h b/include/rdma/ib_umem_odp.h index 4519ea663df5..394ea6b68db7 100644 --- a/include/rdma/ib_umem_odp.h +++ b/include/rdma/ib_umem_odp.h @@ -44,6 +44,8 @@ struct umem_odp_node { struct ib_umem_odp { struct ib_umem umem; + struct ib_ucontext_per_mm *per_mm; + /* * An array of the pages included in the on-demand paging umem. 
* Indices of pages that are currently not mapped into the device will diff --git a/include/rdma/ib_verbs.h b/include/rdma/ib_verbs.h index 458e66877893..87296695dfee 100644 --- a/include/rdma/ib_verbs.h +++ b/include/rdma/ib_verbs.h @@ -1482,6 +1482,25 @@ struct ib_rdmacg_object { #endif }; +#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING +struct ib_ucontext_per_mm { + struct ib_ucontext *context; + + struct rb_root_cached umem_tree; + /* + * Protects .umem_rbroot and tree, as well as odp_mrs_count and + * mmu notifiers registration. + */ + struct rw_semaphore umem_rwsem; + + struct mmu_notifier mn; + atomic_t notifier_count; + /* A list of umems that don't have private mmu notifier counters yet. */ + struct list_head no_private_counters; + unsigned int odp_mrs_count; +}; +#endif + struct ib_ucontext { struct ib_device *device; struct ib_uverbs_file *ufile; @@ -1496,20 +1515,9 @@ struct ib_ucontext { struct pid *tgid; #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING - struct rb_root_cached umem_tree; - /* - * Protects .umem_rbroot and tree, as well as odp_mrs_count and - * mmu notifiers registration. - */ - struct rw_semaphore umem_rwsem; void (*invalidate_range)(struct ib_umem_odp *umem_odp, unsigned long start, unsigned long end); - - struct mmu_notifier mn; - atomic_t notifier_count; - /* A list of umems that don't have private mmu notifier counters yet. */ - struct list_head no_private_counters; - int odp_mrs_count; + struct ib_ucontext_per_mm per_mm; #endif struct ib_rdmacg_object cg_obj; -- 2.14.4