Re: [PATCH 1/1] userfaultfd: allow get_mempolicy(MPOL_F_NODE|MPOL_F_ADDR) to trigger userfaults

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



On Fri, Aug 31, 2018 at 05:48:48PM -0400, Andrea Arcangeli wrote:
> get_mempolicy(MPOL_F_NODE|MPOL_F_ADDR) called get_user_pages(), which
> does not wait for userfaults before failing, so it would hit a SIGBUS
> instead. Using get_user_pages_locked/unlocked instead allows
> userfaults to resolve the fault and fill the hole before
> get_mempolicy grabs the node id of the page.
> 
> Reported-by: Maxime Coquelin <maxime.coquelin@xxxxxxxxxx>
> Tested-by: Dr. David Alan Gilbert <dgilbert@xxxxxxxxxx>
> Signed-off-by: Andrea Arcangeli <aarcange@xxxxxxxxxx>

Reviewed-by: Mike Rapoport <rppt@xxxxxxxxxxxxxxxxxx>

> ---
>  mm/mempolicy.c | 24 +++++++++++++++++++-----
>  1 file changed, 19 insertions(+), 5 deletions(-)
> 
> diff --git a/mm/mempolicy.c b/mm/mempolicy.c
> index 01f1a14facc4..a7f7f5415936 100644
> --- a/mm/mempolicy.c
> +++ b/mm/mempolicy.c
> @@ -797,16 +797,19 @@ static void get_policy_nodemask(struct mempolicy *p, nodemask_t *nodes)
>  	}
>  }
> 
> -static int lookup_node(unsigned long addr)
> +static int lookup_node(struct mm_struct *mm, unsigned long addr)
>  {
>  	struct page *p;
>  	int err;
> 
> -	err = get_user_pages(addr & PAGE_MASK, 1, 0, &p, NULL);
> +	int locked = 1;
> +	err = get_user_pages_locked(addr & PAGE_MASK, 1, 0, &p, &locked);
>  	if (err >= 0) {
>  		err = page_to_nid(p);
>  		put_page(p);
>  	}
> +	if (locked)
> +		up_read(&mm->mmap_sem);
>  	return err;
>  }
> 
> @@ -817,7 +820,7 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
>  	int err;
>  	struct mm_struct *mm = current->mm;
>  	struct vm_area_struct *vma = NULL;
> -	struct mempolicy *pol = current->mempolicy;
> +	struct mempolicy *pol = current->mempolicy, *pol_refcount = NULL;
> 
>  	if (flags &
>  		~(unsigned long)(MPOL_F_NODE|MPOL_F_ADDR|MPOL_F_MEMS_ALLOWED))
> @@ -857,7 +860,16 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
> 
>  	if (flags & MPOL_F_NODE) {
>  		if (flags & MPOL_F_ADDR) {
> -			err = lookup_node(addr);
> +			/*
> +			 * Take a refcount on the mpol, lookup_node()
> 			 * will drop the mmap_sem, so after calling
> +			 * lookup_node() only "pol" remains valid, "vma"
> +			 * is stale.
> +			 */
> +			pol_refcount = pol;
> +			vma = NULL;
> +			mpol_get(pol);
> +			err = lookup_node(mm, addr);
>  			if (err < 0)
>  				goto out;
>  			*policy = err;
> @@ -892,7 +904,9 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
>   out:
>  	mpol_cond_put(pol);
>  	if (vma)
> -		up_read(&current->mm->mmap_sem);
> +		up_read(&mm->mmap_sem);
> +	if (pol_refcount)
> +		mpol_put(pol_refcount);
>  	return err;
>  }
> 
> 

-- 
Sincerely yours,
Mike.




[Index of Archives]     [Linux ARM Kernel]     [Linux ARM]     [Linux Omap]     [Fedora ARM]     [IETF Announce]     [Bugtraq]     [Linux OMAP]     [Linux MIPS]     [eCos]     [Asterisk Internet PBX]     [Linux API]

  Powered by Linux