* Lokesh Gidra <lokeshgidra@xxxxxxxxxx> [240212 19:19]:
> All userfaultfd operations, except write-protect, opportunistically use
> per-vma locks to lock vmas. On failure, attempt again inside mmap_lock
> critical section.
>
> Write-protect operation requires mmap_lock as it iterates over multiple
> vmas.
>
> Signed-off-by: Lokesh Gidra <lokeshgidra@xxxxxxxxxx>
> ---
>  fs/userfaultfd.c              |  13 +-
>  include/linux/userfaultfd_k.h |   5 +-
>  mm/userfaultfd.c              | 392 ++++++++++++++++++++++++++--------
>  3 files changed, 312 insertions(+), 98 deletions(-)
>

...

> +
> +static __always_inline
> +struct vm_area_struct *find_vma_and_prepare_anon(struct mm_struct *mm,
> +						    unsigned long addr)
> +{
> +	struct vm_area_struct *vma;
> +
> +	mmap_assert_locked(mm);
> +	vma = vma_lookup(mm, addr);
> +	if (!vma)
> +		vma = ERR_PTR(-ENOENT);
> +	else if (!(vma->vm_flags & VM_SHARED) && anon_vma_prepare(vma))
> +		vma = ERR_PTR(-ENOMEM);

Nit: I just noticed that the code below says anon_vma_prepare() is
unlikely.

...

> +static struct vm_area_struct *lock_mm_and_find_dst_vma(struct mm_struct *dst_mm,
> +							  unsigned long dst_start,
> +							  unsigned long len)
> +{
> +	struct vm_area_struct *dst_vma;
> +	int err;
> +
> +	mmap_read_lock(dst_mm);
> +	dst_vma = find_vma_and_prepare_anon(dst_mm, dst_start);
> +	if (IS_ERR(dst_vma)) {
> +		err = PTR_ERR(dst_vma);

It's sort of odd you decode then re-encode this error, but it's correct
the way you have it written.  You could just encode ENOENT instead?

> +		goto out_unlock;
> +	}
> +
> +	if (validate_dst_vma(dst_vma, dst_start + len))
> +		return dst_vma;
> +
> +	err = -ENOENT;
> +out_unlock:
> +	mmap_read_unlock(dst_mm);
> +	return ERR_PTR(err);
>  }
> +#endif
>

...

> +static __always_inline
> +long find_vmas_mm_locked(struct mm_struct *mm,

An int return type would probably do here?

> +			 unsigned long dst_start,
> +			 unsigned long src_start,
> +			 struct vm_area_struct **dst_vmap,
> +			 struct vm_area_struct **src_vmap)
> +{
> +	struct vm_area_struct *vma;
> +
> +	mmap_assert_locked(mm);
> +	vma = find_vma_and_prepare_anon(mm, dst_start);
> +	if (IS_ERR(vma))
> +		return PTR_ERR(vma);
> +
> +	*dst_vmap = vma;
> +	/* Skip finding src_vma if src_start is in dst_vma */
> +	if (src_start >= vma->vm_start && src_start < vma->vm_end)
> +		goto out_success;
> +
> +	vma = vma_lookup(mm, src_start);
> +	if (!vma)
> +		return -ENOENT;
> +out_success:
> +	*src_vmap = vma;
> +	return 0;
> +}
> +
> +#ifdef CONFIG_PER_VMA_LOCK
> +static long find_and_lock_vmas(struct mm_struct *mm,

This could also be an int return type, or am I missing something?

...

> +	*src_vmap = lock_vma_under_rcu(mm, src_start);
> +	if (likely(*src_vmap))
> +		return 0;
> +
> +	/* Undo any locking and retry in mmap_lock critical section */
> +	vma_end_read(*dst_vmap);
> +
> +	mmap_read_lock(mm);
> +	err = find_vmas_mm_locked(mm, dst_start, src_start, dst_vmap, src_vmap);
> +	if (!err) {
> +		/*
> +		 * See comment in lock_vma() as to why not using
> +		 * vma_start_read() here.
> +		 */
> +		down_read(&(*dst_vmap)->vm_lock->lock);
> +		if (*dst_vmap != *src_vmap)
> +			down_read(&(*src_vmap)->vm_lock->lock);
> +	}
> +	mmap_read_unlock(mm);
> +	return err;
> +}
> +#else
> +static long lock_mm_and_find_vmas(struct mm_struct *mm,
> +				  unsigned long dst_start,
> +				  unsigned long src_start,
> +				  struct vm_area_struct **dst_vmap,
> +				  struct vm_area_struct **src_vmap)
> +{
> +	long err;
> +
> +	mmap_read_lock(mm);
> +	err = find_vmas_mm_locked(mm, dst_start, src_start, dst_vmap, src_vmap);
> +	if (err)
> +		mmap_read_unlock(mm);
> +	return err;
>  }
> +#endif

This section is much easier to understand.  Thanks.

Thanks,
Liam