Re: [PATCH v2 2/5] mm: abstract the vma_merge()/split_vma() pattern for mprotect() et al.

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



* Lorenzo Stoakes <lstoakes@xxxxxxxxx> [231009 16:53]:
> mprotect() and other functions which change VMA parameters over a range
> each employ the following pattern:
> 
> 1. Attempt to merge the range with adjacent VMAs.
> 2. If this fails, and the range spans a subset of the VMA, split it
>    accordingly.
> 
> This is open-coded and duplicated in each case. Also, in each case, most of
> the parameters passed to vma_merge() remain the same.
> 
> Create a new function, vma_modify(), which abstracts this operation,
> accepting only those parameters which can be changed.
> 
> To avoid the mess of invoking each function call with unnecessary
> parameters, create inline wrapper functions for each of the modify
> operations, parameterised only by what is required to perform the action.
> 
> Note that the userfaultfd_release() case works even though it does not
> split VMAs: since start is set to vma->vm_start and end is set to
> vma->vm_end, the split logic never triggers.
> 
> In addition, since we calculate pgoff as
> vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT), and
> start - vma->vm_start is 0 in this instance, the pgoff passed to
> vma_merge() is still vma->vm_pgoff, so this invocation remains unchanged.
> 
> Signed-off-by: Lorenzo Stoakes <lstoakes@xxxxxxxxx>
> ---
>  fs/userfaultfd.c   | 69 +++++++++++++++-------------------------------
>  include/linux/mm.h | 60 ++++++++++++++++++++++++++++++++++++++++
>  mm/madvise.c       | 32 ++++++---------------
>  mm/mempolicy.c     | 22 +++------------
>  mm/mlock.c         | 27 +++++-------------
>  mm/mmap.c          | 45 ++++++++++++++++++++++++++++++
>  mm/mprotect.c      | 35 +++++++----------------
>  7 files changed, 157 insertions(+), 133 deletions(-)
> 
> diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
> index a7c6ef764e63..ba44a67a0a34 100644
> --- a/fs/userfaultfd.c
> +++ b/fs/userfaultfd.c
> @@ -927,11 +927,10 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
>  			continue;
>  		}
>  		new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
> -		prev = vma_merge(&vmi, mm, prev, vma->vm_start, vma->vm_end,
> -				 new_flags, vma->anon_vma,
> -				 vma->vm_file, vma->vm_pgoff,
> -				 vma_policy(vma),
> -				 NULL_VM_UFFD_CTX, anon_vma_name(vma));
> +		prev = vma_modify_flags_uffd(&vmi, prev, vma, vma->vm_start,
> +					     vma->vm_end, new_flags,
> +					     NULL_VM_UFFD_CTX);
> +
>  		if (prev) {
>  			vma = prev;
>  		} else {
> @@ -1331,7 +1330,6 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
>  	unsigned long start, end, vma_end;
>  	struct vma_iterator vmi;
>  	bool wp_async = userfaultfd_wp_async_ctx(ctx);
> -	pgoff_t pgoff;
>  
>  	user_uffdio_register = (struct uffdio_register __user *) arg;
>  
> @@ -1484,28 +1482,17 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
>  		vma_end = min(end, vma->vm_end);
>  
>  		new_flags = (vma->vm_flags & ~__VM_UFFD_FLAGS) | vm_flags;
> -		pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> -		prev = vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
> -				 vma->anon_vma, vma->vm_file, pgoff,
> -				 vma_policy(vma),
> -				 ((struct vm_userfaultfd_ctx){ ctx }),
> -				 anon_vma_name(vma));
> -		if (prev) {
> -			/* vma_merge() invalidated the mas */
> -			vma = prev;
> -			goto next;
> -		}
> -		if (vma->vm_start < start) {
> -			ret = split_vma(&vmi, vma, start, 1);
> -			if (ret)
> -				break;
> -		}
> -		if (vma->vm_end > end) {
> -			ret = split_vma(&vmi, vma, end, 0);
> -			if (ret)
> -				break;
> +		prev = vma_modify_flags_uffd(&vmi, prev, vma, start, vma_end,
> +					     new_flags,
> +					     (struct vm_userfaultfd_ctx){ctx});
> +		if (IS_ERR(prev)) {
> +			ret = PTR_ERR(prev);
> +			break;
>  		}
> -	next:
> +
> +		if (prev)
> +			vma = prev; /* vma_merge() invalidated the mas */

This comment is stale.  The maple state lives in the vma iterator, which
is passed through, so there is no separate mas for vma_merge() to
invalidate.  I missed this during the vma iterator conversion.

> +
>  		/*
>  		 * In the vma_merge() successful mprotect-like case 8:
>  		 * the next vma was merged into the current one and
> @@ -1568,7 +1555,6 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
>  	const void __user *buf = (void __user *)arg;
>  	struct vma_iterator vmi;
>  	bool wp_async = userfaultfd_wp_async_ctx(ctx);
> -	pgoff_t pgoff;
>  
>  	ret = -EFAULT;
>  	if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
> @@ -1671,26 +1657,15 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
>  			uffd_wp_range(vma, start, vma_end - start, false);
>  
>  		new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
> -		pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> -		prev = vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
> -				 vma->anon_vma, vma->vm_file, pgoff,
> -				 vma_policy(vma),
> -				 NULL_VM_UFFD_CTX, anon_vma_name(vma));
> -		if (prev) {
> -			vma = prev;
> -			goto next;
> -		}
> -		if (vma->vm_start < start) {
> -			ret = split_vma(&vmi, vma, start, 1);
> -			if (ret)
> -				break;
> -		}
> -		if (vma->vm_end > end) {
> -			ret = split_vma(&vmi, vma, end, 0);
> -			if (ret)
> -				break;
> +		prev = vma_modify_flags_uffd(&vmi, prev, vma, start, vma_end,
> +					     new_flags, NULL_VM_UFFD_CTX);
> +		if (IS_ERR(prev)) {
> +			ret = PTR_ERR(prev);
> +			break;
>  		}
> -	next:
> +
> +		if (prev)
> +			vma = prev;
>  		/*
>  		 * In the vma_merge() successful mprotect-like case 8:
>  		 * the next vma was merged into the current one and
> diff --git a/include/linux/mm.h b/include/linux/mm.h
> index a7b667786cde..83ee1f35febe 100644
> --- a/include/linux/mm.h
> +++ b/include/linux/mm.h
> @@ -3253,6 +3253,66 @@ extern struct vm_area_struct *copy_vma(struct vm_area_struct **,
>  	unsigned long addr, unsigned long len, pgoff_t pgoff,
>  	bool *need_rmap_locks);
>  extern void exit_mmap(struct mm_struct *);
> +struct vm_area_struct *vma_modify(struct vma_iterator *vmi,
> +				  struct vm_area_struct *prev,
> +				  struct vm_area_struct *vma,
> +				  unsigned long start, unsigned long end,
> +				  unsigned long vm_flags,
> +				  struct mempolicy *policy,
> +				  struct vm_userfaultfd_ctx uffd_ctx,
> +				  struct anon_vma_name *anon_name);
> +
> +/* We are about to modify the VMA's flags. */
> +static inline struct vm_area_struct
> +*vma_modify_flags(struct vma_iterator *vmi,
> +		  struct vm_area_struct *prev,
> +		  struct vm_area_struct *vma,
> +		  unsigned long start, unsigned long end,
> +		  unsigned long new_flags)
> +{
> +	return vma_modify(vmi, prev, vma, start, end, new_flags,
> +			  vma_policy(vma), vma->vm_userfaultfd_ctx,
> +			  anon_vma_name(vma));
> +}
> +
> +/* We are about to modify the VMA's flags and/or anon_name. */
> +static inline struct vm_area_struct
> +*vma_modify_flags_name(struct vma_iterator *vmi,
> +		       struct vm_area_struct *prev,
> +		       struct vm_area_struct *vma,
> +		       unsigned long start,
> +		       unsigned long end,
> +		       unsigned long new_flags,
> +		       struct anon_vma_name *new_name)
> +{
> +	return vma_modify(vmi, prev, vma, start, end, new_flags,
> +			  vma_policy(vma), vma->vm_userfaultfd_ctx, new_name);
> +}
> +
> +/* We are about to modify the VMA's memory policy. */
> +static inline struct vm_area_struct
> +*vma_modify_policy(struct vma_iterator *vmi,
> +		   struct vm_area_struct *prev,
> +		   struct vm_area_struct *vma,
> +		   unsigned long start, unsigned long end,
> +		   struct mempolicy *new_pol)
> +{
> +	return vma_modify(vmi, prev, vma, start, end, vma->vm_flags,
> +			  new_pol, vma->vm_userfaultfd_ctx, anon_vma_name(vma));
> +}
> +
> +/* We are about to modify the VMA's flags and/or uffd context. */
> +static inline struct vm_area_struct
> +*vma_modify_flags_uffd(struct vma_iterator *vmi,
> +		       struct vm_area_struct *prev,
> +		       struct vm_area_struct *vma,
> +		       unsigned long start, unsigned long end,
> +		       unsigned long new_flags,
> +		       struct vm_userfaultfd_ctx new_ctx)
> +{
> +	return vma_modify(vmi, prev, vma, start, end, new_flags,
> +			  vma_policy(vma), new_ctx, anon_vma_name(vma));
> +}
>  
>  static inline int check_data_rlimit(unsigned long rlim,
>  				    unsigned long new,
> diff --git a/mm/madvise.c b/mm/madvise.c
> index a4a20de50494..801d3c1bb7b3 100644
> --- a/mm/madvise.c
> +++ b/mm/madvise.c
> @@ -141,7 +141,7 @@ static int madvise_update_vma(struct vm_area_struct *vma,
>  {
>  	struct mm_struct *mm = vma->vm_mm;
>  	int error;
> -	pgoff_t pgoff;
> +	struct vm_area_struct *merged;
>  	VMA_ITERATOR(vmi, mm, start);
>  
>  	if (new_flags == vma->vm_flags && anon_vma_name_eq(anon_vma_name(vma), anon_name)) {
> @@ -149,30 +149,16 @@ static int madvise_update_vma(struct vm_area_struct *vma,
>  		return 0;
>  	}
>  
> -	pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> -	*prev = vma_merge(&vmi, mm, *prev, start, end, new_flags,
> -			  vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
> -			  vma->vm_userfaultfd_ctx, anon_name);
> -	if (*prev) {
> -		vma = *prev;
> -		goto success;
> -	}
> -
> -	*prev = vma;
> -
> -	if (start != vma->vm_start) {
> -		error = split_vma(&vmi, vma, start, 1);
> -		if (error)
> -			return error;
> -	}
> +	merged = vma_modify_flags_name(&vmi, *prev, vma, start, end, new_flags,
> +				       anon_name);
> +	if (IS_ERR(merged))
> +		return PTR_ERR(merged);
>  
> -	if (end != vma->vm_end) {
> -		error = split_vma(&vmi, vma, end, 0);
> -		if (error)
> -			return error;
> -	}
> +	if (merged)
> +		vma = *prev = merged;
> +	else
> +		*prev = vma;
>  
> -success:
>  	/* vm_flags is protected by the mmap_lock held in write mode. */
>  	vma_start_write(vma);
>  	vm_flags_reset(vma, new_flags);
> diff --git a/mm/mempolicy.c b/mm/mempolicy.c
> index b01922e88548..6b2e99db6dd5 100644
> --- a/mm/mempolicy.c
> +++ b/mm/mempolicy.c
> @@ -786,8 +786,6 @@ static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma,
>  {
>  	struct vm_area_struct *merged;
>  	unsigned long vmstart, vmend;
> -	pgoff_t pgoff;
> -	int err;
>  
>  	vmend = min(end, vma->vm_end);
>  	if (start > vma->vm_start) {
> @@ -802,27 +800,15 @@ static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma,
>  		return 0;
>  	}
>  
> -	pgoff = vma->vm_pgoff + ((vmstart - vma->vm_start) >> PAGE_SHIFT);
> -	merged = vma_merge(vmi, vma->vm_mm, *prev, vmstart, vmend, vma->vm_flags,
> -			 vma->anon_vma, vma->vm_file, pgoff, new_pol,
> -			 vma->vm_userfaultfd_ctx, anon_vma_name(vma));
> +	merged =  vma_modify_policy(vmi, *prev, vma, vmstart, vmend, new_pol);
> +	if (IS_ERR(merged))
> +		return PTR_ERR(merged);
> +
>  	if (merged) {
>  		*prev = merged;
>  		return vma_replace_policy(merged, new_pol);
>  	}
>  
> -	if (vma->vm_start != vmstart) {
> -		err = split_vma(vmi, vma, vmstart, 1);
> -		if (err)
> -			return err;
> -	}
> -
> -	if (vma->vm_end != vmend) {
> -		err = split_vma(vmi, vma, vmend, 0);
> -		if (err)
> -			return err;
> -	}
> -
>  	*prev = vma;
>  	return vma_replace_policy(vma, new_pol);
>  }
> diff --git a/mm/mlock.c b/mm/mlock.c
> index 42b6865f8f82..ae83a33c387e 100644
> --- a/mm/mlock.c
> +++ b/mm/mlock.c
> @@ -476,10 +476,10 @@ static int mlock_fixup(struct vma_iterator *vmi, struct vm_area_struct *vma,
>  	       unsigned long end, vm_flags_t newflags)
>  {
>  	struct mm_struct *mm = vma->vm_mm;
> -	pgoff_t pgoff;
>  	int nr_pages;
>  	int ret = 0;
>  	vm_flags_t oldflags = vma->vm_flags;
> +	struct vm_area_struct *merged;
>  
>  	if (newflags == oldflags || (oldflags & VM_SPECIAL) ||
>  	    is_vm_hugetlb_page(vma) || vma == get_gate_vma(current->mm) ||
> @@ -487,28 +487,15 @@ static int mlock_fixup(struct vma_iterator *vmi, struct vm_area_struct *vma,
>  		/* don't set VM_LOCKED or VM_LOCKONFAULT and don't count */
>  		goto out;
>  
> -	pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> -	*prev = vma_merge(vmi, mm, *prev, start, end, newflags,
> -			vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
> -			vma->vm_userfaultfd_ctx, anon_vma_name(vma));
> -	if (*prev) {
> -		vma = *prev;
> -		goto success;
> -	}
> -
> -	if (start != vma->vm_start) {
> -		ret = split_vma(vmi, vma, start, 1);
> -		if (ret)
> -			goto out;
> +	merged = vma_modify_flags(vmi, *prev, vma, start, end, newflags);
> +	if (IS_ERR(merged)) {
> +		ret = PTR_ERR(merged);
> +		goto out;
>  	}
>  
> -	if (end != vma->vm_end) {
> -		ret = split_vma(vmi, vma, end, 0);
> -		if (ret)
> -			goto out;
> -	}
> +	if (merged)
> +		vma = *prev = merged;
>  
> -success:
>  	/*
>  	 * Keep track of amount of locked VM.
>  	 */
> diff --git a/mm/mmap.c b/mm/mmap.c
> index 673429ee8a9e..22d968affc07 100644
> --- a/mm/mmap.c
> +++ b/mm/mmap.c
> @@ -2437,6 +2437,51 @@ int split_vma(struct vma_iterator *vmi, struct vm_area_struct *vma,
>  	return __split_vma(vmi, vma, addr, new_below);
>  }
>  
> +/*
> + * We are about to modify one or multiple of a VMA's flags, policy, userfaultfd
> + * context and anonymous VMA name within the range [start, end).
> + *
> + * As a result, we might be able to merge the newly modified VMA range with an
> + * adjacent VMA with identical properties.
> + *
> + * If no merge is possible and the range does not span the entirety of the VMA,
> + * we then need to split the VMA to accommodate the change.
> + */
> +struct vm_area_struct *vma_modify(struct vma_iterator *vmi,
> +				  struct vm_area_struct *prev,
> +				  struct vm_area_struct *vma,
> +				  unsigned long start, unsigned long end,
> +				  unsigned long vm_flags,
> +				  struct mempolicy *policy,
> +				  struct vm_userfaultfd_ctx uffd_ctx,
> +				  struct anon_vma_name *anon_name)
> +{
> +	pgoff_t pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> +	struct vm_area_struct *merged;
> +
> +	merged = vma_merge(vmi, vma->vm_mm, prev, start, end, vm_flags,
> +			   vma->anon_vma, vma->vm_file, pgoff, policy,
> +			   uffd_ctx, anon_name);
> +	if (merged)
> +		return merged;
> +
> +	if (vma->vm_start < start) {
> +		int err = split_vma(vmi, vma, start, 1);
> +
> +		if (err)
> +			return ERR_PTR(err);
> +	}
> +
> +	if (vma->vm_end > end) {
> +		int err = split_vma(vmi, vma, end, 0);
> +
> +		if (err)
> +			return ERR_PTR(err);
> +	}
> +
> +	return NULL;
> +}
> +
>  /*
>   * do_vmi_align_munmap() - munmap the aligned region from @start to @end.
>   * @vmi: The vma iterator
> diff --git a/mm/mprotect.c b/mm/mprotect.c
> index b94fbb45d5c7..6f85d99682ab 100644
> --- a/mm/mprotect.c
> +++ b/mm/mprotect.c
> @@ -581,7 +581,7 @@ mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
>  	long nrpages = (end - start) >> PAGE_SHIFT;
>  	unsigned int mm_cp_flags = 0;
>  	unsigned long charged = 0;
> -	pgoff_t pgoff;
> +	struct vm_area_struct *merged;
>  	int error;
>  
>  	if (newflags == oldflags) {
> @@ -625,34 +625,19 @@ mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
>  		}
>  	}
>  
> -	/*
> -	 * First try to merge with previous and/or next vma.
> -	 */
> -	pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> -	*pprev = vma_merge(vmi, mm, *pprev, start, end, newflags,
> -			   vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
> -			   vma->vm_userfaultfd_ctx, anon_vma_name(vma));
> -	if (*pprev) {
> -		vma = *pprev;
> -		VM_WARN_ON((vma->vm_flags ^ newflags) & ~VM_SOFTDIRTY);
> -		goto success;
> +	merged = vma_modify_flags(vmi, *pprev, vma, start, end, newflags);
> +	if (IS_ERR(merged)) {
> +		error = PTR_ERR(merged);
> +		goto fail;
>  	}
>  
> -	*pprev = vma;
> -
> -	if (start != vma->vm_start) {
> -		error = split_vma(vmi, vma, start, 1);
> -		if (error)
> -			goto fail;
> -	}
> -
> -	if (end != vma->vm_end) {
> -		error = split_vma(vmi, vma, end, 0);
> -		if (error)
> -			goto fail;
> +	if (merged) {
> +		vma = *pprev = merged;
> +		VM_WARN_ON((vma->vm_flags ^ newflags) & ~VM_SOFTDIRTY);
> +	} else {
> +		*pprev = vma;
>  	}
>  
> -success:
>  	/*
>  	 * vm_flags and vm_page_prot are protected by the mmap_lock
>  	 * held in write mode.
> -- 
> 2.42.0
> 




[Index of Archives]     [Linux ARM Kernel]     [Linux ARM]     [Linux Omap]     [Fedora ARM]     [IETF Announce]     [Bugtraq]     [Linux OMAP]     [Linux MIPS]     [eCos]     [Asterisk Internet PBX]     [Linux API]

  Powered by Linux