To make mempolicy modifiable by external tasks, we must first change the call stack to take a task as an argument. Modify the following function to require a task argument: set_mempolicy_home_node. First, we refactor set_mempolicy_home_node into __set_mempolicy_home_node, which accepts a task argument, and change the syscall definition to pass in (current). The only functional change in this patch is related to the way task->mm is acquired. Originally, set_mempolicy_home_node would acquire the mm directly via (current->mm). This is unsafe to do in a non-current context. However, using get_task_mm unconditionally would break the original functionality of set_mempolicy_home_node due to the following check in get_task_mm: if (mm) { if (task->flags & PF_KTHREAD) mm = NULL; else mmget(mm); } To retain the original behavior, if (task == current) we access task->mm directly, but if (task != current) we use get_task_mm to safely access the mm. In both cases we take a reference to the mm to keep the cleanup semantics simple. 
Signed-off-by: Gregory Price <gregory.price@xxxxxxxxxxxx> --- mm/mempolicy.c | 62 +++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/mm/mempolicy.c b/mm/mempolicy.c index 4519f39b1a07..540163f5d349 100644 --- a/mm/mempolicy.c +++ b/mm/mempolicy.c @@ -1521,39 +1521,67 @@ static long kernel_mbind(unsigned long start, unsigned long len, return do_mbind(start, len, lmode, mode_flags, &nodes, flags); } -SYSCALL_DEFINE4(set_mempolicy_home_node, unsigned long, start, unsigned long, len, - unsigned long, home_node, unsigned long, flags) +static long __set_mempolicy_home_node(struct task_struct *task, + unsigned long start, + unsigned long len, + unsigned long home_node, + unsigned long flags) { - struct mm_struct *mm = current->mm; + struct mm_struct *mm; struct vm_area_struct *vma, *prev; struct mempolicy *new, *old; unsigned long end; int err = -ENOENT; + + /* + * Behavior when task == current allows a task modifying itself + * to bypass the check in get_task_mm and acquire the mm directly + */ + if (task == current) { + mm = task->mm; + mmget(mm); + } else + mm = get_task_mm(task); + + if (!mm) + return -ENODEV; + VMA_ITERATOR(vmi, mm, start); start = untagged_addr(start); - if (start & ~PAGE_MASK) - return -EINVAL; + if (start & ~PAGE_MASK) { + err = -EINVAL; + goto mm_out; + } /* * flags is used for future extension if any. */ - if (flags != 0) - return -EINVAL; + if (flags != 0) { + err = -EINVAL; + goto mm_out; + } /* * Check home_node is online to avoid accessing uninitialized * NODE_DATA. 
*/ - if (home_node >= MAX_NUMNODES || !node_online(home_node)) - return -EINVAL; + if (home_node >= MAX_NUMNODES || !node_online(home_node)) { + err = -EINVAL; + goto mm_out; + } len = PAGE_ALIGN(len); end = start + len; - if (end < start) - return -EINVAL; - if (end == start) - return 0; + if (end < start) { + err = -EINVAL; + goto mm_out; + } + if (end == start) { + err = 0; + goto mm_out; + } + mmap_write_lock(mm); prev = vma_prev(&vmi); for_each_vma_range(vmi, vma, end) { @@ -1585,9 +1613,17 @@ SYSCALL_DEFINE4(set_mempolicy_home_node, unsigned long, start, unsigned long, le break; } mmap_write_unlock(mm); +mm_out: + mmput(mm); return err; } +SYSCALL_DEFINE4(set_mempolicy_home_node, unsigned long, start, unsigned long, len, + unsigned long, home_node, unsigned long, flags) +{ + return __set_mempolicy_home_node(current, start, len, home_node, flags); +} + SYSCALL_DEFINE6(mbind, unsigned long, start, unsigned long, len, unsigned long, mode, const unsigned long __user *, nmask, unsigned long, maxnode, unsigned int, flags) -- 2.39.1