Don't use the mm_struct linked list or vma->vm_next, in preparation for their removal. Signed-off-by: Liam R. Howlett <Liam.Howlett@xxxxxxxxxx> --- fs/userfaultfd.c | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c index 000b457ad0876..b0eac6fb3e8e2 100644 --- a/fs/userfaultfd.c +++ b/fs/userfaultfd.c @@ -598,14 +598,16 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx, if (release_new_ctx) { struct vm_area_struct *vma; struct mm_struct *mm = release_new_ctx->mm; + MA_STATE(mas, &mm->mm_mt, 0, 0); /* the various vma->vm_userfaultfd_ctx still points to it */ mmap_write_lock(mm); - for (vma = mm->mmap; vma; vma = vma->vm_next) + mas_for_each(&mas, vma, ULONG_MAX) { if (vma->vm_userfaultfd_ctx.ctx == release_new_ctx) { vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; vma->vm_flags &= ~(VM_UFFD_WP | VM_UFFD_MISSING); } + } mmap_write_unlock(mm); userfaultfd_ctx_put(release_new_ctx); @@ -790,7 +792,9 @@ int userfaultfd_unmap_prep(struct vm_area_struct *vma, unsigned long start, unsigned long end, struct list_head *unmaps) { - for ( ; vma && vma->vm_start < end; vma = vma->vm_next) { + MA_STATE(mas, &vma->vm_mm->mm_mt, vma->vm_start, vma->vm_start); + + mas_for_each(&mas, vma, end) { struct userfaultfd_unmap_ctx *unmap_ctx; struct userfaultfd_ctx *ctx = vma->vm_userfaultfd_ctx.ctx; @@ -840,6 +844,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) /* len == 0 means wake all */ struct userfaultfd_wake_range range = { .len = 0, }; unsigned long new_flags; + MA_STATE(mas, &mm->mm_mt, 0, 0); WRITE_ONCE(ctx->released, true); @@ -856,7 +861,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) */ mmap_write_lock(mm); prev = NULL; - for (vma = mm->mmap; vma; vma = vma->vm_next) { + mas_for_each(&mas, vma, ULONG_MAX) { cond_resched(); BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^ !!(vma->vm_flags & (VM_UFFD_MISSING | VM_UFFD_WP))); @@ -1270,6 +1275,7 @@ 
static int userfaultfd_register(struct userfaultfd_ctx *ctx, bool found; bool basic_ioctls; unsigned long start, end, vma_end; + MA_STATE(mas, &mm->mm_mt, 0, 0); user_uffdio_register = (struct uffdio_register __user *) arg; @@ -1328,7 +1334,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, */ found = false; basic_ioctls = false; - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) { + mas_set(&mas, vma->vm_start); + mas_for_each(&mas, cur, end) { cond_resched(); BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ @@ -1444,7 +1451,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, skip: prev = vma; start = vma->vm_end; - vma = vma->vm_next; + vma = vma_next(mm, vma); } while (vma && vma->vm_start < end); out_unlock: mmap_write_unlock(mm); @@ -1485,6 +1492,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, bool found; unsigned long start, end, vma_end; const void __user *buf = (void __user *)arg; + MA_STATE(mas, &mm->mm_mt, 0, 0); ret = -EFAULT; if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister))) @@ -1528,7 +1536,9 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, */ found = false; ret = -EINVAL; - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) { + mas_set(&mas, vma->vm_start); + + mas_for_each(&mas, cur, end) { cond_resched(); BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ @@ -1614,7 +1624,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, skip: prev = vma; start = vma->vm_end; - vma = vma->vm_next; + vma = vma_next(mm, vma); } while (vma && vma->vm_start < end); out_unlock: mmap_write_unlock(mm); -- 2.28.0