From: "Liam R. Howlett" <Liam.Howlett@xxxxxxxxxx> Don't use the mm_struct linked list (mm->mmap) or vma->vm_next, in preparation for their removal. Walk the VMAs with mas_for_each() and vma_next() instead. Signed-off-by: Liam R. Howlett <Liam.Howlett@xxxxxxxxxx> --- fs/userfaultfd.c | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c index 003f0d31743e..bd181f922999 100644 --- a/fs/userfaultfd.c +++ b/fs/userfaultfd.c @@ -606,14 +606,18 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx, if (release_new_ctx) { struct vm_area_struct *vma; struct mm_struct *mm = release_new_ctx->mm; + MA_STATE(mas, &mm->mm_mt, 0, 0); /* the various vma->vm_userfaultfd_ctx still points to it */ mmap_write_lock(mm); - for (vma = mm->mmap; vma; vma = vma->vm_next) + mas_lock(&mas); + mas_for_each(&mas, vma, ULONG_MAX) { if (vma->vm_userfaultfd_ctx.ctx == release_new_ctx) { vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; vma->vm_flags &= ~__VM_UFFD_FLAGS; } + } + mas_unlock(&mas); mmap_write_unlock(mm); userfaultfd_ctx_put(release_new_ctx); @@ -798,7 +802,10 @@ int userfaultfd_unmap_prep(struct vm_area_struct *vma, unsigned long start, unsigned long end, struct list_head *unmaps) { - for ( ; vma && vma->vm_start < end; vma = vma->vm_next) { + MA_STATE(mas, &vma->vm_mm->mm_mt, vma->vm_start, vma->vm_start); + + rcu_read_lock(); + mas_for_each(&mas, vma, end) { struct userfaultfd_unmap_ctx *unmap_ctx; struct userfaultfd_ctx *ctx = vma->vm_userfaultfd_ctx.ctx; @@ -817,6 +824,7 @@ int userfaultfd_unmap_prep(struct vm_area_struct *vma, unmap_ctx->end = end; list_add_tail(&unmap_ctx->list, unmaps); } + rcu_read_unlock(); return 0; } @@ -848,6 +856,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) /* len == 0 means wake all */ struct userfaultfd_wake_range range = { .len = 0, }; unsigned long new_flags; + MA_STATE(mas, &mm->mm_mt, 0, 0); WRITE_ONCE(ctx->released, true); @@ -863,9 +872,11 @@ static int userfaultfd_release(struct inode *inode, 
struct file *file) * taking the mmap_lock for writing. */ mmap_write_lock(mm); + mas_lock(&mas); prev = NULL; - for (vma = mm->mmap; vma; vma = vma->vm_next) { + mas_for_each(&mas, vma, ULONG_MAX) { cond_resched(); + BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^ !!(vma->vm_flags & __VM_UFFD_FLAGS)); if (vma->vm_userfaultfd_ctx.ctx != ctx) { @@ -885,6 +896,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) vma->vm_flags = new_flags; vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; } + mas_unlock(&mas); mmap_write_unlock(mm); mmput(mm); wakeup: @@ -1281,6 +1293,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, bool found; bool basic_ioctls; unsigned long start, end, vma_end; + MA_STATE(mas, &mm->mm_mt, 0, 0); user_uffdio_register = (struct uffdio_register __user *) arg; @@ -1323,6 +1336,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, goto out; mmap_write_lock(mm); + mas_lock(&mas); vma = find_vma_prev(mm, start, &prev); if (!vma) goto out_unlock; @@ -1348,7 +1362,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, */ found = false; basic_ioctls = false; - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) { + mas_set(&mas, vma->vm_start); + mas_for_each(&mas, cur, end) { cond_resched(); BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ @@ -1466,9 +1481,10 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, skip: prev = vma; start = vma->vm_end; - vma = vma->vm_next; + vma = vma_next(mm, vma); } while (vma && vma->vm_start < end); out_unlock: + mas_unlock(&mas); mmap_write_unlock(mm); mmput(mm); if (!ret) { @@ -1511,6 +1527,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, bool found; unsigned long start, end, vma_end; const void __user *buf = (void __user *)arg; + MA_STATE(mas, &mm->mm_mt, 0, 0); ret = -EFAULT; if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister))) @@ -1529,6 +1546,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, goto out; 
mmap_write_lock(mm); + mas_lock(&mas); vma = find_vma_prev(mm, start, &prev); if (!vma) goto out_unlock; @@ -1554,7 +1572,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, */ found = false; ret = -EINVAL; - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) { + mas_set(&mas, vma->vm_start); + mas_for_each(&mas, cur, end) { cond_resched(); BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ @@ -1640,9 +1659,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, skip: prev = vma; start = vma->vm_end; - vma = vma->vm_next; + vma = vma_next(mm, vma); } while (vma && vma->vm_start < end); out_unlock: + mas_unlock(&mas); mmap_write_unlock(mm); mmput(mm); out: -- 2.30.2