To make mbind-applied mempolicies modifiable by external tasks, we must
first change the do_mbind callstack to take a task as an argument.

This patch changes the following functions:
  do_mbind
  kernel_mbind
  get_vma_policy

and adds the following:
  get_task_vma_policy

get_vma_policy is changed into a wrapper of get_task_vma_policy which
passes current as the task argument, retaining the existing behavior for
callers of get_vma_policy.

do_mbind is modified as follows:
  1) the way task->mm is acquired is changed to be safe for non-current
     tasks, while the original behavior for (task == current) is retained.
  2) we take a reference to the mm so that the task lock can be dropped.
  3) the task lock must now be held when calling get_task_policy, to
     ensure we acquire and reference the policy safely.
  4) get_task_vma_policy is called instead of get_vma_policy. This
     requires taking the task_lock because of the new semantics.

Change to acquiring task->mm:
When (task == current), using get_task_mm would prevent a kernel task
from modifying or accessing information about its own VMAs. So in this
scenario, we simply access and reference the mm directly, since the
mempolicy information is being accessed in the context of the task
itself. For reference, get_task_mm does:

  if (mm) {
    if (task->flags & PF_KTHREAD)
      mm = NULL;
    else
      mmget(mm);
  }

Accessing the mm directly for current retains the existing behavior.

Change to get_task_vma_policy locking behavior:
Since task->mempolicy is no longer guaranteed to be stable, any time we
acquire a policy via get_task_vma_policy we must hold the task_lock and
take a reference accordingly, regardless of where the policy ultimately
came from. A similar pattern can be seen in do_get_mempolicy, where a
reference is taken and a conditional release is made to handle the case
where a shared policy is acquired.

In the case of do_mbind, we don't actually need to take a reference to
the policy, as we only call get_task_vma_policy to find the ilx.
In this case, we only need to call mpol_cond_put immediately, so that if
get_task_vma_policy returns a shared policy, the reference it took is
dropped (a shared mpol is always returned already referenced).

If the target task has no mm, do_mbind releases the new policy and
returns -ENODEV immediately, rather than jumping to mpol_out, whose
cleanup undoes setup that has not yet happened at that point.

Signed-off-by: Gregory Price <gregory.price@xxxxxxxxxxxx>
---
 mm/mempolicy.c | 92 +++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 69 insertions(+), 23 deletions(-)

diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index 540163f5d349..3d2171ac4098 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -422,6 +422,10 @@ static bool migrate_folio_add(struct folio *folio, struct list_head *foliolist,
 				unsigned long flags);
 static nodemask_t *policy_nodemask(gfp_t gfp, struct mempolicy *pol,
 				pgoff_t ilx, int *nid);
+static struct mempolicy *get_task_vma_policy(struct task_struct *task,
+					     struct vm_area_struct *vma,
+					     unsigned long addr, int order,
+					     pgoff_t *ilx);
 
 static bool strictly_unmovable(unsigned long flags)
 {
@@ -1250,11 +1254,12 @@ static struct folio *alloc_migration_target_by_mpol(struct folio *src,
 }
 #endif
 
-static long do_mbind(unsigned long start, unsigned long len,
-		     unsigned short mode, unsigned short mode_flags,
-		     nodemask_t *nmask, unsigned long flags)
+static long do_mbind(struct task_struct *task, unsigned long start,
+		     unsigned long len, unsigned short mode,
+		     unsigned short mode_flags, nodemask_t *nmask,
+		     unsigned long flags)
 {
-	struct mm_struct *mm = current->mm;
+	struct mm_struct *mm;
 	struct vm_area_struct *vma, *prev;
 	struct vma_iterator vmi;
 	struct migration_mpol mmpol;
@@ -1287,6 +1292,21 @@ static long do_mbind(unsigned long start, unsigned long len,
 	if (IS_ERR(new))
 		return PTR_ERR(new);
 
+	/*
+	 * get_task_mm() refuses PF_KTHREAD tasks, so a task acting on itself
+	 * references its own mm directly to keep the original behavior.
+	 */
+	if (task == current) {
+		mm = task->mm;
+		if (mm)
+			mmget(mm);
+	} else
+		mm = get_task_mm(task);
+	if (!mm) {
+		mpol_put(new);	/* only @new is held; skip mpol_out cleanup */
+		return -ENODEV;
+	}
+
 	/*
 	 * If we are using the default policy then operation
 	 * on discontinuous address spaces is okay after all
@@ -1300,7 +1320,9 @@ static long do_mbind(unsigned long start, unsigned long len,
 		NODEMASK_SCRATCH(scratch);
 		if (scratch) {
 			mmap_write_lock(mm);
-			err = mpol_set_nodemask(current, new, nmask, scratch);
+			task_lock(task);
+			err = mpol_set_nodemask(task, new, nmask, scratch);
+			task_unlock(task);
 			if (err)
 				mmap_write_unlock(mm);
 		} else
@@ -1308,7 +1330,7 @@ static long do_mbind(unsigned long start, unsigned long len,
 		NODEMASK_SCRATCH_FREE(scratch);
 	}
 	if (err)
-		goto mpol_out;
+		goto mm_out;
 
 	/*
 	 * Lock the VMAs before scanning for pages to migrate,
@@ -1333,8 +1355,10 @@ static long do_mbind(unsigned long start, unsigned long len,
 	if (!err && !list_empty(&pagelist)) {
 		/* Convert MPOL_DEFAULT's NULL to task or default policy */
 		if (!new) {
-			new = get_task_policy(current);
+			task_lock(task);
+			new = get_task_policy(task);
 			mpol_get(new);
+			task_unlock(task);
 		}
 		mmpol.pol = new;
 		mmpol.ilx = 0;
@@ -1365,8 +1389,11 @@ static long do_mbind(unsigned long start, unsigned long len,
 		if (addr != -EFAULT) {
 			order = compound_order(page);
 			/* We already know the pol, but not the ilx */
-			mpol_cond_put(get_vma_policy(vma, addr, order,
-						     &mmpol.ilx));
+			task_lock(task);
+			mpol_cond_put(get_task_vma_policy(task, vma,
+							  addr, order,
+							  &mmpol.ilx));
+			task_unlock(task);
 			/* Set base from which to increment by index */
 			mmpol.ilx -= page->index >> order;
 		}
@@ -1386,6 +1413,8 @@ static long do_mbind(unsigned long start, unsigned long len,
 		err = -EIO;
 	if (!list_empty(&pagelist))
 		putback_movable_pages(&pagelist);
+mm_out:
+	mmput(mm);
 mpol_out:
 	mpol_put(new);
 	if (flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL))
@@ -1500,8 +1529,9 @@ static inline int sanitize_mpol_flags(int *mode, unsigned short *flags)
 	return 0;
 }
 
-static long kernel_mbind(unsigned long start, unsigned long len,
-			 unsigned long mode, const unsigned long __user *nmask,
+static long kernel_mbind(struct task_struct *task, unsigned long start,
+			 unsigned long len, unsigned long mode,
+			 const unsigned long __user *nmask,
 			 unsigned long maxnode, unsigned int flags)
 {
 	unsigned short mode_flags;
@@ -1518,7 +1548,7 @@ static long kernel_mbind(unsigned long start, unsigned long len,
 	if (err)
 		return err;
 
-	return do_mbind(start, len, lmode, mode_flags, &nodes, flags);
+	return do_mbind(task, start, len, lmode, mode_flags, &nodes, flags);
 }
 
 static long __set_mempolicy_home_node(struct task_struct *task,
@@ -1628,7 +1658,7 @@ SYSCALL_DEFINE6(mbind, unsigned long, start, unsigned long, len,
 		unsigned long, mode, const unsigned long __user *, nmask,
 		unsigned long, maxnode, unsigned int, flags)
 {
-	return kernel_mbind(start, len, mode, nmask, maxnode, flags);
+	return kernel_mbind(current, start, len, mode, nmask, maxnode, flags);
 }
 
 /* Set the process memory policy */
@@ -1827,6 +1857,31 @@ struct mempolicy *__get_vma_policy(struct vm_area_struct *vma,
 		vma->vm_ops->get_policy(vma, addr, ilx) : vma->vm_policy;
 }
 
+/*
+ * Task variant of get_vma_policy() for internal use. Returns the vma's own
+ * policy if it has one, otherwise @task's policy; get_vma_policy() is the
+ * wrapper that passes current.
+ *
+ * Except via get_vma_policy() (current, unchanged behavior), the caller must
+ * hold task_lock(@task) so @task->mempolicy is stable while it is read.
+ */
+static struct mempolicy *get_task_vma_policy(struct task_struct *task,
+					     struct vm_area_struct *vma,
+					     unsigned long addr, int order,
+					     pgoff_t *ilx)
+{
+	struct mempolicy *pol;
+
+	pol = __get_vma_policy(vma, addr, ilx);
+	if (!pol)
+		pol = get_task_policy(task);
+	if (pol->mode == MPOL_INTERLEAVE) {
+		*ilx += vma->vm_pgoff >> order;
+		*ilx += (addr - vma->vm_start) >> (PAGE_SHIFT + order);
+	}
+	return pol;
+}
+
 /*
  * get_vma_policy(@vma, @addr, @order, @ilx)
  * @vma: virtual memory area whose policy is sought
@@ -1844,16 +1899,7 @@ struct mempolicy *__get_vma_policy(struct vm_area_struct *vma,
 struct mempolicy *get_vma_policy(struct vm_area_struct *vma,
 				 unsigned long addr, int order, pgoff_t *ilx)
 {
-	struct mempolicy *pol;
-
-	pol = __get_vma_policy(vma, addr, ilx);
-	if (!pol)
-		pol = get_task_policy(current);
-	if (pol->mode == MPOL_INTERLEAVE) {
-		*ilx += vma->vm_pgoff >> order;
-		*ilx += (addr - vma->vm_start) >> (PAGE_SHIFT + order);
-	}
-	return pol;
+	return get_task_vma_policy(current, vma, addr, order, ilx);
 }
 
 bool vma_policy_mof(struct vm_area_struct *vma)
-- 
2.39.1