* Vlastimil Babka <vbabka@xxxxxxx> [220119 11:26]: > On 12/1/21 15:30, Liam Howlett wrote: > > From: "Liam R. Howlett" <Liam.Howlett@xxxxxxxxxx> > > > > Don't use the mm_struct linked list or the vma->vm_next in prep for removal > > > > Signed-off-by: Liam R. Howlett <Liam.Howlett@xxxxxxxxxx> > > --- > > fs/userfaultfd.c | 49 ++++++++++++++++++++++------------- > > include/linux/userfaultfd_k.h | 7 +++-- > > mm/mmap.c | 12 ++++----- > > 3 files changed, 40 insertions(+), 28 deletions(-) > > > > diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c > > index 22bf14ab2d16..2880025598c7 100644 > > --- a/fs/userfaultfd.c > > +++ b/fs/userfaultfd.c > > @@ -606,14 +606,16 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx, > > if (release_new_ctx) { > > struct vm_area_struct *vma; > > struct mm_struct *mm = release_new_ctx->mm; > > + VMA_ITERATOR(vmi, mm, 0); > > > > /* the various vma->vm_userfaultfd_ctx still points to it */ > > mmap_write_lock(mm); > > - for (vma = mm->mmap; vma; vma = vma->vm_next) > > + for_each_vma(vmi, vma) { > > if (vma->vm_userfaultfd_ctx.ctx == release_new_ctx) { > > vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; > > vma->vm_flags &= ~__VM_UFFD_FLAGS; > > } > > + } > > mmap_write_unlock(mm); > > > > userfaultfd_ctx_put(release_new_ctx); > > @@ -794,11 +796,13 @@ static bool has_unmap_ctx(struct userfaultfd_ctx *ctx, struct list_head *unmaps, > > return false; > > } > > > > -int userfaultfd_unmap_prep(struct vm_area_struct *vma, > > - unsigned long start, unsigned long end, > > - struct list_head *unmaps) > > +int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start, > > + unsigned long end, struct list_head *unmaps) > > { > > - for ( ; vma && vma->vm_start < end; vma = vma->vm_next) { > > + VMA_ITERATOR(vmi, mm, start); > > + struct vm_area_struct *vma; > > + > > + for_each_vma_range(vmi, vma, end) { > > struct userfaultfd_unmap_ctx *unmap_ctx; > > struct userfaultfd_ctx *ctx = vma->vm_userfaultfd_ctx.ctx; > > > > 
@@ -848,6 +852,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) > > /* len == 0 means wake all */ > > struct userfaultfd_wake_range range = { .len = 0, }; > > unsigned long new_flags; > > + MA_STATE(mas, &mm->mm_mt, 0, 0); > > Again, it looks like this could also be VMA_ITERATOR, consistent with the > one above? VMA_ITERATOR is for simple cases. This is not a simple case, although in this change it does appear to be one. I missed the mas_pause() that is needed when the maple state is invalidated by a successful vma_merge() in the mas_for_each() loop below. I will fix this. > > > > > WRITE_ONCE(ctx->released, true); > > > > @@ -864,7 +869,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) > > */ > > mmap_write_lock(mm); > > prev = NULL; > > - for (vma = mm->mmap; vma; vma = vma->vm_next) { > > + mas_for_each(&mas, vma, ULONG_MAX) { > > cond_resched(); > > BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^ > > !!(vma->vm_flags & __VM_UFFD_FLAGS)); > > @@ -1281,6 +1286,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, > > bool found; > > bool basic_ioctls; > > unsigned long start, end, vma_end; > > + MA_STATE(mas, &mm->mm_mt, 0, 0); > > > > user_uffdio_register = (struct uffdio_register __user *) arg; > > > > @@ -1323,7 +1329,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, > > goto out; > > > > mmap_write_lock(mm); > > - vma = find_vma_prev(mm, start, &prev); > > + mas_set(&mas, start); > > + vma = mas_find(&mas, ULONG_MAX); > > if (!vma) > > goto out_unlock; > > > > @@ -1348,7 +1355,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, > > */ > > found = false; > > basic_ioctls = false; > > - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) { > > + for (cur = vma; cur; cur = mas_next(&mas, end - 1)) { > > cond_resched(); > > > > BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ > > @@ -1408,8 +1415,10 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, > > } > > BUG_ON(!found); > > > > - if 
(vma->vm_start < start) > > - prev = vma; > > + mas_set(&mas, start); > > + prev = mas_prev(&mas, 0); > > + if (prev != vma) > > + mas_next(&mas, ULONG_MAX); > > Hmm non-commented tricky stuff... Oh, I did not see this as tricky. I will add a comment. Basically, I am setting the maple state to search for start; the mas_prev() that follows then gets the vma before start. > > > > > ret = 0; > > do { > > @@ -1466,8 +1475,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, > > skip: > > prev = vma; > > start = vma->vm_end; > > - vma = vma->vm_next; > > - } while (vma && vma->vm_start < end); > > + vma = mas_next(&mas, end - 1); > > + } while (vma); > > out_unlock: > > mmap_write_unlock(mm); > > mmput(mm); > > @@ -1511,6 +1520,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, > > bool found; > > unsigned long start, end, vma_end; > > const void __user *buf = (void __user *)arg; > > + MA_STATE(mas, &mm->mm_mt, 0, 0); > > > > ret = -EFAULT; > > if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister))) > > @@ -1529,7 +1539,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, > > goto out; > > > > mmap_write_lock(mm); > > - vma = find_vma_prev(mm, start, &prev); > > + mas_set(&mas, start); > > + vma = mas_find(&mas, ULONG_MAX); > > if (!vma) > > goto out_unlock; > > > > @@ -1554,7 +1565,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, > > */ > > found = false; > > ret = -EINVAL; > > - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) { > > + for (cur = vma; cur; cur = mas_next(&mas, end - 1)) { > > cond_resched(); > > > > BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ > > @@ -1574,8 +1585,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, > > } > > BUG_ON(!found); > > > > - if (vma->vm_start < start) > > - prev = vma; > > + mas_set(&mas, start); > > + prev = mas_prev(&mas, 0); > > + if (prev != vma) > > + mas_next(&mas, ULONG_MAX); > > Same here. 
I'll add the comment here too. > > > > > ret = 0; > > do { > > @@ -1640,8 +1653,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, > > skip: > > prev = vma; > > start = vma->vm_end; > > - vma = vma->vm_next; > > - } while (vma && vma->vm_start < end); > > + vma = mas_next(&mas, end - 1); > > + } while (vma); > > out_unlock: > > mmap_write_unlock(mm); > > mmput(mm); > > diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h > > index 33cea484d1ad..e0b2ec2c20f2 100644 > > --- a/include/linux/userfaultfd_k.h > > +++ b/include/linux/userfaultfd_k.h > > @@ -139,9 +139,8 @@ extern bool userfaultfd_remove(struct vm_area_struct *vma, > > unsigned long start, > > unsigned long end); > > > > -extern int userfaultfd_unmap_prep(struct vm_area_struct *vma, > > - unsigned long start, unsigned long end, > > - struct list_head *uf); > > +extern int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start, > > + unsigned long end, struct list_head *uf); > > extern void userfaultfd_unmap_complete(struct mm_struct *mm, > > struct list_head *uf); > > > > @@ -222,7 +221,7 @@ static inline bool userfaultfd_remove(struct vm_area_struct *vma, > > return true; > > } > > > > -static inline int userfaultfd_unmap_prep(struct vm_area_struct *vma, > > +static inline int userfaultfd_unmap_prep(struct mm_struct *mm, > > unsigned long start, unsigned long end, > > struct list_head *uf) > > { > > diff --git a/mm/mmap.c b/mm/mmap.c > > index 79b8494d83c6..dde74e0b195d 100644 > > --- a/mm/mmap.c > > +++ b/mm/mmap.c > > @@ -2449,7 +2449,7 @@ do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma, > > * split, despite we could. This is unlikely enough > > * failure that it's not worth optimizing it for. 
> > */ > > - int error = userfaultfd_unmap_prep(vma, start, end, uf); > > + int error = userfaultfd_unmap_prep(mm, start, end, uf); > > > > if (error) > > return error; > > @@ -2938,10 +2938,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, > > goto munmap_full_vma; > > } > > > > - vma_init(&unmap, mm); > > - unmap.vm_start = newbrk; > > - unmap.vm_end = oldbrk; > > - ret = userfaultfd_unmap_prep(&unmap, newbrk, oldbrk, uf); > > + ret = userfaultfd_unmap_prep(mm, newbrk, oldbrk, uf); > > if (ret) > > return ret; > > ret = 1; > > @@ -2954,6 +2951,9 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, > > } > > > > vma->vm_end = newbrk; > > + vma_init(&unmap, mm); > > + unmap.vm_start = newbrk; > > + unmap.vm_end = oldbrk; > > if (vma_mas_remove(&unmap, mas)) > > goto mas_store_fail; > > > > @@ -2963,7 +2963,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, > > } > > > > unmap_pages = vma_pages(&unmap); > > - if (unmap.vm_flags & VM_LOCKED) { > > + if (vma->vm_flags & VM_LOCKED) { > > Hmm is this an unrelated bug fix? As unmap didn't have any vm_flags set even > before this patch, right? Yes. Thanks, I must have merged it into the wrong commit. > > > mm->locked_vm -= unmap_pages; > > munlock_vma_pages_range(&unmap, newbrk, oldbrk); > > } >