To make a task's mempolicy fetchable by external tasks, we must first change the call stack to take a task as an argument.

Modify the following functions to require a task argument:

    do_get_mempolicy
    kernel_get_mempolicy

The way task->mm is acquired must change slightly to make this possible. Originally, do_get_mempolicy would acquire the mm directly via (current->mm). This is unsafe to do for a task other than current. However, unconditionally using get_task_mm would break the original functionality of do_get_mempolicy due to the following check in get_task_mm:

    if (mm) {
        if (task->flags & PF_KTHREAD)
            mm = NULL;
        else
            mmget(mm);
    }

Because of this check, get_task_mm returns NULL for any kernel thread, so a kernel thread querying its own policy would no longer get an mm to work with.

To retain the original behavior, if (task == current) we access task->mm directly, but if (task != current) we use get_task_mm to safely access the mm. We simplify the get/put mechanics by always taking a reference to the mm, even when (task == current).

Additionally, since the mempolicy will become externally modifiable, we need to take the task lock to read task->mempolicy safely, regardless of whether we are operating on current or not.

Signed-off-by: Gregory Price <gregory.price@xxxxxxxxxxxx> --- mm/mempolicy.c | 43 +++++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/mm/mempolicy.c b/mm/mempolicy.c index 9ea3e1bfc002..4519f39b1a07 100644 --- a/mm/mempolicy.c +++ b/mm/mempolicy.c @@ -899,8 +899,9 @@ static int lookup_node(struct mm_struct *mm, unsigned long addr) } /* Retrieve NUMA policy */ -static long do_get_mempolicy(int *policy, nodemask_t *nmask, - unsigned long addr, unsigned long flags) +static long do_get_mempolicy(struct task_struct *task, int *policy, + nodemask_t *nmask, unsigned long addr, + unsigned long flags) { int err; struct mm_struct *mm; @@ -915,9 +916,9 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask, if (flags & (MPOL_F_NODE|MPOL_F_ADDR)) return -EINVAL; *policy = 0; /* just so it's initialized */ - task_lock(current); - *nmask = cpuset_current_mems_allowed; - task_unlock(current); + task_lock(task); + *nmask = task->mems_allowed; + task_unlock(task); return 0; } @@ -928,7 +929,16 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask, * vma/shared policy at addr is NULL. We * want to return MPOL_DEFAULT in this case. 
*/ - mm = current->mm; + if (task == current) { + /* + * original behavior allows a kernel task changing its + * own policy to avoid the condition in get_task_mm, + * so we'll directly access + */ + mm = task->mm; + mmget(mm); + } else + mm = get_task_mm(task); mmap_read_lock(mm); vma = vma_lookup(mm, addr); if (!vma) { @@ -947,8 +957,10 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask, return -EINVAL; else { /* take a reference of the task policy now */ - pol = current->mempolicy; + task_lock(task); + pol = task->mempolicy; mpol_get(pol); + task_unlock(task); } if (!pol) { @@ -962,12 +974,13 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask, vma = NULL; mmap_read_unlock(mm); err = lookup_node(mm, addr); + mmput(mm); if (err < 0) goto out; *policy = err; - } else if (pol == current->mempolicy && + } else if (pol == task->mempolicy && pol->mode == MPOL_INTERLEAVE) { - *policy = next_node_in(current->il_prev, pol->nodes); + *policy = next_node_in(task->il_prev, pol->nodes); } else { err = -EINVAL; goto out; @@ -987,9 +1000,9 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask, if (mpol_store_user_nodemask(pol)) { *nmask = pol->w.user_nodemask; } else { - task_lock(current); + task_lock(task); get_policy_nodemask(pol, nmask); - task_unlock(current); + task_unlock(task); } } @@ -1704,7 +1717,8 @@ SYSCALL_DEFINE4(migrate_pages, pid_t, pid, unsigned long, maxnode, } /* Retrieve NUMA policy */ -static int kernel_get_mempolicy(int __user *policy, +static int kernel_get_mempolicy(struct task_struct *task, + int __user *policy, unsigned long __user *nmask, unsigned long maxnode, unsigned long addr, @@ -1719,7 +1733,7 @@ static int kernel_get_mempolicy(int __user *policy, addr = untagged_addr(addr); - err = do_get_mempolicy(&pval, &nodes, addr, flags); + err = do_get_mempolicy(task, &pval, &nodes, addr, flags); if (err) return err; @@ -1737,7 +1751,8 @@ SYSCALL_DEFINE5(get_mempolicy, int __user *, policy, unsigned long __user *, nmask, 
unsigned long, maxnode, unsigned long, addr, unsigned long, flags) { - return kernel_get_mempolicy(policy, nmask, maxnode, addr, flags); + return kernel_get_mempolicy(current, policy, nmask, maxnode, addr, + flags); } bool vma_migratable(struct vm_area_struct *vma) -- 2.39.1