On 12/1/21 15:30, Liam Howlett wrote:
> From: "Liam R. Howlett" <Liam.Howlett@xxxxxxxxxx>
>
> Don't use the mm_struct linked list or the vma->vm_next in prep for removal
>
> Signed-off-by: Liam R. Howlett <Liam.Howlett@xxxxxxxxxx>
> ---
>  fs/userfaultfd.c              | 49 ++++++++++++++++++++++-------------
>  include/linux/userfaultfd_k.h |  7 +++--
>  mm/mmap.c                     | 12 ++++-----
>  3 files changed, 40 insertions(+), 28 deletions(-)
>
> diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
> index 22bf14ab2d16..2880025598c7 100644
> --- a/fs/userfaultfd.c
> +++ b/fs/userfaultfd.c
> @@ -606,14 +606,16 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx,
>  	if (release_new_ctx) {
>  		struct vm_area_struct *vma;
>  		struct mm_struct *mm = release_new_ctx->mm;
> +		VMA_ITERATOR(vmi, mm, 0);
>
>  		/* the various vma->vm_userfaultfd_ctx still points to it */
>  		mmap_write_lock(mm);
> -		for (vma = mm->mmap; vma; vma = vma->vm_next)
> +		for_each_vma(vmi, vma) {
>  			if (vma->vm_userfaultfd_ctx.ctx == release_new_ctx) {
>  				vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
>  				vma->vm_flags &= ~__VM_UFFD_FLAGS;
>  			}
> +		}
>  		mmap_write_unlock(mm);
>
>  		userfaultfd_ctx_put(release_new_ctx);
> @@ -794,11 +796,13 @@ static bool has_unmap_ctx(struct userfaultfd_ctx *ctx, struct list_head *unmaps,
>  	return false;
>  }
>
> -int userfaultfd_unmap_prep(struct vm_area_struct *vma,
> -			   unsigned long start, unsigned long end,
> -			   struct list_head *unmaps)
> +int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start,
> +			   unsigned long end, struct list_head *unmaps)
>  {
> -	for ( ; vma && vma->vm_start < end; vma = vma->vm_next) {
> +	VMA_ITERATOR(vmi, mm, start);
> +	struct vm_area_struct *vma;
> +
> +	for_each_vma_range(vmi, vma, end) {
>  		struct userfaultfd_unmap_ctx *unmap_ctx;
>  		struct userfaultfd_ctx *ctx = vma->vm_userfaultfd_ctx.ctx;
>
> @@ -848,6 +852,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
>  	/* len == 0 means wake all */
>  	struct userfaultfd_wake_range range = { .len = 0, };
>  	unsigned long new_flags;
> +	MA_STATE(mas, &mm->mm_mt, 0, 0);

Again, it looks like this could also be a VMA_ITERATOR, consistent with the one above?

>
>  	WRITE_ONCE(ctx->released, true);
>
> @@ -864,7 +869,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
>  	 */
>  	mmap_write_lock(mm);
>  	prev = NULL;
> -	for (vma = mm->mmap; vma; vma = vma->vm_next) {
> +	mas_for_each(&mas, vma, ULONG_MAX) {
>  		cond_resched();
>  		BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^
>  		       !!(vma->vm_flags & __VM_UFFD_FLAGS));
> @@ -1281,6 +1286,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
>  	bool found;
>  	bool basic_ioctls;
>  	unsigned long start, end, vma_end;
> +	MA_STATE(mas, &mm->mm_mt, 0, 0);
>
>  	user_uffdio_register = (struct uffdio_register __user *) arg;
>
> @@ -1323,7 +1329,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
>  		goto out;
>
>  	mmap_write_lock(mm);
> -	vma = find_vma_prev(mm, start, &prev);
> +	mas_set(&mas, start);
> +	vma = mas_find(&mas, ULONG_MAX);
>  	if (!vma)
>  		goto out_unlock;
>
> @@ -1348,7 +1355,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
>  	 */
>  	found = false;
>  	basic_ioctls = false;
> -	for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) {
> +	for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
>  		cond_resched();
>
>  		BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
> @@ -1408,8 +1415,10 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
>  	}
>  	BUG_ON(!found);
>
> -	if (vma->vm_start < start)
> -		prev = vma;
> +	mas_set(&mas, start);
> +	prev = mas_prev(&mas, 0);
> +	if (prev != vma)
> +		mas_next(&mas, ULONG_MAX);

Hmm, this is tricky and has no comment explaining it...

>
>  	ret = 0;
>  	do {
> @@ -1466,8 +1475,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
>  	skip:
>  		prev = vma;
>  		start = vma->vm_end;
> -		vma = vma->vm_next;
> -	} while (vma && vma->vm_start < end);
> +		vma = mas_next(&mas, end - 1);
> +	} while (vma);
>  out_unlock:
>  	mmap_write_unlock(mm);
>  	mmput(mm);
> @@ -1511,6 +1520,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
>  	bool found;
>  	unsigned long start, end, vma_end;
>  	const void __user *buf = (void __user *)arg;
> +	MA_STATE(mas, &mm->mm_mt, 0, 0);
>
>  	ret = -EFAULT;
>  	if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
> @@ -1529,7 +1539,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
>  		goto out;
>
>  	mmap_write_lock(mm);
> -	vma = find_vma_prev(mm, start, &prev);
> +	mas_set(&mas, start);
> +	vma = mas_find(&mas, ULONG_MAX);
>  	if (!vma)
>  		goto out_unlock;
>
> @@ -1554,7 +1565,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
>  	 */
>  	found = false;
>  	ret = -EINVAL;
> -	for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) {
> +	for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
>  		cond_resched();
>
>  		BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
> @@ -1574,8 +1585,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
>  	}
>  	BUG_ON(!found);
>
> -	if (vma->vm_start < start)
> -		prev = vma;
> +	mas_set(&mas, start);
> +	prev = mas_prev(&mas, 0);
> +	if (prev != vma)
> +		mas_next(&mas, ULONG_MAX);

Same here: tricky and uncommented.

>
>  	ret = 0;
>  	do {
> @@ -1640,8 +1653,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
>  	skip:
>  		prev = vma;
>  		start = vma->vm_end;
> -		vma = vma->vm_next;
> -	} while (vma && vma->vm_start < end);
> +		vma = mas_next(&mas, end - 1);
> +	} while (vma);
>  out_unlock:
>  	mmap_write_unlock(mm);
>  	mmput(mm);
> diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h
> index 33cea484d1ad..e0b2ec2c20f2 100644
> --- a/include/linux/userfaultfd_k.h
> +++ b/include/linux/userfaultfd_k.h
> @@ -139,9 +139,8 @@ extern bool userfaultfd_remove(struct vm_area_struct *vma,
>  			       unsigned long start,
>  			       unsigned long end);
>
> -extern int userfaultfd_unmap_prep(struct vm_area_struct *vma,
> -				  unsigned long start, unsigned long end,
> -				  struct list_head *uf);
> +extern int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start,
> +				  unsigned long end, struct list_head *uf);
>  extern void userfaultfd_unmap_complete(struct mm_struct *mm,
>  				       struct list_head *uf);
>
> @@ -222,7 +221,7 @@ static inline bool userfaultfd_remove(struct vm_area_struct *vma,
>  	return true;
>  }
>
> -static inline int userfaultfd_unmap_prep(struct vm_area_struct *vma,
> +static inline int userfaultfd_unmap_prep(struct mm_struct *mm,
>  					 unsigned long start, unsigned long end,
>  					 struct list_head *uf)
>  {
> diff --git a/mm/mmap.c b/mm/mmap.c
> index 79b8494d83c6..dde74e0b195d 100644
> --- a/mm/mmap.c
> +++ b/mm/mmap.c
> @@ -2449,7 +2449,7 @@ do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma,
>  		 * split, despite we could. This is unlikely enough
>  		 * failure that it's not worth optimizing it for.
>  		 */
> -		int error = userfaultfd_unmap_prep(vma, start, end, uf);
> +		int error = userfaultfd_unmap_prep(mm, start, end, uf);
>
>  		if (error)
>  			return error;
> @@ -2938,10 +2938,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
>  		goto munmap_full_vma;
>  	}
>
> -	vma_init(&unmap, mm);
> -	unmap.vm_start = newbrk;
> -	unmap.vm_end = oldbrk;
> -	ret = userfaultfd_unmap_prep(&unmap, newbrk, oldbrk, uf);
> +	ret = userfaultfd_unmap_prep(mm, newbrk, oldbrk, uf);
>  	if (ret)
>  		return ret;
>  	ret = 1;
> @@ -2954,6 +2951,9 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
>  	}
>
>  	vma->vm_end = newbrk;
> +	vma_init(&unmap, mm);
> +	unmap.vm_start = newbrk;
> +	unmap.vm_end = oldbrk;
>  	if (vma_mas_remove(&unmap, mas))
>  		goto mas_store_fail;
>
> @@ -2963,7 +2963,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
>  	}
>
>  	unmap_pages = vma_pages(&unmap);
> -	if (unmap.vm_flags & VM_LOCKED) {
> +	if (vma->vm_flags & VM_LOCKED) {

Hmm, is this an unrelated bug fix? unmap didn't have any vm_flags set even before this patch, right?

>  		mm->locked_vm -= unmap_pages;
>  		munlock_vma_pages_range(&unmap, newbrk, oldbrk);
>  	}