On Thu, 14 Sep 2023 19:54:56 -0400 Gregory Price <gourry.memverge@xxxxxxxxx> wrote: > sys_set_mempolicy is limited by its current argument structure > (mode, nodes, flags) to implementing policies that can be described > in that manner. > > Implement set/get_mempolicy2 with a new mempolicy_args structure > which encapsulates the old behavior, and allows for new mempolicies > which may require additional information. > > Signed-off-by: Gregory Price <gregory.price@xxxxxxxxxxxx> Some random comments inline. Jonathan > --- > arch/x86/entry/syscalls/syscall_32.tbl | 2 + > arch/x86/entry/syscalls/syscall_64.tbl | 2 + > include/linux/syscalls.h | 2 + > include/uapi/asm-generic/unistd.h | 10 +- > include/uapi/linux/mempolicy.h | 32 ++++ > mm/mempolicy.c | 215 ++++++++++++++++++++++++- > 6 files changed, 261 insertions(+), 2 deletions(-) > > diff --git a/arch/x86/entry/syscalls/syscall_32.tbl b/arch/x86/entry/syscalls/syscall_32.tbl > index 2d0b1bd866ea..a72ef588a704 100644 > --- a/arch/x86/entry/syscalls/syscall_32.tbl > +++ b/arch/x86/entry/syscalls/syscall_32.tbl > @@ -457,3 +457,5 @@ > 450 i386 set_mempolicy_home_node sys_set_mempolicy_home_node > 451 i386 cachestat sys_cachestat > 452 i386 fchmodat2 sys_fchmodat2 > +454 i386 set_mempolicy2 sys_set_mempolicy2 > +455 i386 get_mempolicy2 sys_get_mempolicy2 > diff --git a/arch/x86/entry/syscalls/syscall_64.tbl b/arch/x86/entry/syscalls/syscall_64.tbl > index 1d6eee30eceb..ec54064de8b3 100644 > --- a/arch/x86/entry/syscalls/syscall_64.tbl > +++ b/arch/x86/entry/syscalls/syscall_64.tbl > @@ -375,6 +375,8 @@ > 451 common cachestat sys_cachestat > 452 common fchmodat2 sys_fchmodat2 > 453 64 map_shadow_stack sys_map_shadow_stack > +454 common set_mempolicy2 sys_set_mempolicy2 > +455 common get_mempolicy2 sys_get_mempolicy2 > > # > # Due to a historical design error, certain syscalls are numbered differently > diff --git a/include/linux/syscalls.h b/include/linux/syscalls.h > index 22bc6bc147f8..d50a452954ae 100644 > --- 
a/include/linux/syscalls.h > +++ b/include/linux/syscalls.h > @@ -813,6 +813,8 @@ asmlinkage long sys_get_mempolicy(int __user *policy, > unsigned long addr, unsigned long flags); > asmlinkage long sys_set_mempolicy(int mode, const unsigned long __user *nmask, > unsigned long maxnode); > +asmlinkage long sys_get_mempolicy2(struct mempolicy_args __user *args); > +asmlinkage long sys_set_mempolicy2(struct mempolicy_args __user *args); > asmlinkage long sys_migrate_pages(pid_t pid, unsigned long maxnode, > const unsigned long __user *from, > const unsigned long __user *to); > diff --git a/include/uapi/asm-generic/unistd.h b/include/uapi/asm-generic/unistd.h > index abe087c53b4b..397dcf804941 100644 > --- a/include/uapi/asm-generic/unistd.h > +++ b/include/uapi/asm-generic/unistd.h > @@ -823,8 +823,16 @@ __SYSCALL(__NR_cachestat, sys_cachestat) > #define __NR_fchmodat2 452 > __SYSCALL(__NR_fchmodat2, sys_fchmodat2) > > +/* CONFIG_MMU only */ > +#ifndef __ARCH_NOMMU > +#define __NR_set_mempolicy 454 > +__SYSCALL(__NR_set_mempolicy2, sys_set_mempolicy2) > +#define __NR_set_mempolicy 455 > +__SYSCALL(__NR_get_mempolicy2, sys_get_mempolicy2) > +#endif > + > #undef __NR_syscalls > -#define __NR_syscalls 453 > +#define __NR_syscalls 456 +3 for 2 additions? 
> > /* > * 32 bit systems traditionally used different > diff --git a/include/uapi/linux/mempolicy.h b/include/uapi/linux/mempolicy.h > index 046d0ccba4cd..53650f69db2b 100644 > --- a/include/uapi/linux/mempolicy.h > +++ b/include/uapi/linux/mempolicy.h > @@ -23,9 +23,41 @@ enum { > MPOL_INTERLEAVE, > MPOL_LOCAL, > MPOL_PREFERRED_MANY, > + MPOL_LEGACY, /* set_mempolicy limited to above modes */ > MPOL_MAX, /* always last member of enum */ > }; > > +struct mempolicy_args { > + int err; > + unsigned short mode; > + unsigned long *nodemask; > + unsigned long maxnode; > + unsigned short flags; > + struct { > + /* Memory allowed */ > + struct { > + int err; > + unsigned long maxnode; > + unsigned long *nodemask; > + } allowed; > + /* Address information */ > + struct { > + int err; > + unsigned long addr; > + unsigned long node; > + unsigned short mode; > + unsigned short flags; > + } addr; > + /* Interleave */ > + } get; > + /* Mode specific settings */ > + union { > + struct { > + unsigned long next_node; /* get only */ > + } interleave; > + }; > +}; > + > /* Flags for set_mempolicy */ > #define MPOL_F_STATIC_NODES (1 << 15) > #define MPOL_F_RELATIVE_NODES (1 << 14) > diff --git a/mm/mempolicy.c b/mm/mempolicy.c > index f49337f6f300..1cf7709400f1 100644 > --- a/mm/mempolicy.c > +++ b/mm/mempolicy.c > @@ -1483,7 +1483,7 @@ static inline int sanitize_mpol_flags(int *mode, unsigned short *flags) > *flags = *mode & MPOL_MODE_FLAGS; > *mode &= ~MPOL_MODE_FLAGS; > > - if ((unsigned int)(*mode) >= MPOL_MAX) > + if ((unsigned int)(*mode) >= MPOL_LEGACY) > return -EINVAL; > if ((*flags & MPOL_F_STATIC_NODES) && (*flags & MPOL_F_RELATIVE_NODES)) > return -EINVAL; > @@ -1614,6 +1614,219 @@ SYSCALL_DEFINE3(set_mempolicy, int, mode, const unsigned long __user *, nmask, > return kernel_set_mempolicy(mode, nmask, maxnode); > } > > +static long do_set_mempolicy2(struct mempolicy_args *args) > +{ > + struct mempolicy *new = NULL; > + nodemask_t nodes; > + int err; > + > + if 
(args->mode <= MPOL_LEGACY) > + return -EINVAL; > + > + if (args->mode >= MPOL_MAX) > + return -EINVAL; > + > + err = get_nodes(&nodes, args->nodemask, args->maxnode); > + if (err) > + return err; > + > + new = mpol_new(args->mode, args->flags, &nodes); > + if (IS_ERR(new)) { > + err = PTR_ERR(new); > + goto out; I'd expect mpol_new() to be side-effect free on error, so return PTR_ERR(new); should be fine? > + } > + > + switch (args->mode) { > + default: > + BUG(); > + } > + > + if (err) > + goto out; > + > + err = swap_mempolicy(new, &nodes); > +out: > + if (err && new) As IS_ERR(new) is true here (and an error pointer is non-NULL), I think this puts the policy even if mpol_new() returned an error. That seems unwise. I'd push this block below a return 0 anyway, so as to avoid error handling in the good path. > + mpol_put(new); > + return err; > +}; > + > +static bool mempolicy2_args_valid(struct mempolicy_args *kargs) > +{ > + /* Legacy modes are routed through the legacy interface */ > + if (kargs->mode <= MPOL_LEGACY) > + return false; > + > + if (kargs->mode >= MPOL_MAX) > + return false; > + > + return true; This is a range check, so I think it is equally clear (and shorter) as: /* Legacy modes are routed through the legacy interface */ return kargs->mode > MPOL_LEGACY && kargs->mode < MPOL_MAX; > +} > + > +static long kernel_set_mempolicy2(const struct mempolicy_args __user *uargs, > + size_t usize) > +{ > + struct mempolicy_args kargs; > + int err; > + > + if (usize != sizeof(kargs)) As below, maybe allow for a bigger size, with the assumption that we'll ignore what is in the extra space. 
> + return -EINVAL; > + > + err = copy_struct_from_user(&kargs, sizeof(kargs), uargs, usize); > + if (err) > + return err; > + > + /* If the mode is legacy, use the legacy path */ > + if (kargs.mode < MPOL_LEGACY) { > + int legacy_mode = kargs.mode | kargs.flags; > + const unsigned long __user *lnmask = kargs.nodemask; > + unsigned long maxnode = kargs.maxnode; > + > + return kernel_set_mempolicy(legacy_mode, lnmask, maxnode); > + } > + > + if (!mempolicy2_args_valid(&kargs)) > + return -EINVAL; > + > + return do_set_mempolicy2(&kargs); > +} > + > +SYSCALL_DEFINE2(set_mempolicy2, const struct mempolicy_args __user *, args, > + size_t, size) > +{ > + return kernel_set_mempolicy2(args, size); > +} > + > +/* Gets extended mempolicy information */ > +static long do_get_mempolicy2(struct mempolicy_args *kargs) > +{ > + struct mempolicy *pol = current->mempolicy; > + nodemask_t knodes; > + int err = 0; > + > + kargs->err = 0; > + kargs->mode = pol->mode; > + /* Mask off internal flags */ > + kargs->flags = (pol->flags & MPOL_MODE_FLAGS); Excessive brackets. > + > + if (kargs->nodemask) { > + if (mpol_store_user_nodemask(pol)) { > + knodes = pol->w.user_nodemask; > + } else { > + task_lock(current); > + get_policy_nodemask(pol, &knodes); > + task_unlock(current); > + } > + err = copy_nodes_to_user(kargs->nodemask, > + kargs->maxnode, > + &knodes); Can wrap this less. > + if (err) return err ? > + return -EINVAL; > + } > + > + > + if (kargs->get.allowed.nodemask) { > + kargs->get.allowed.err = 0; > + task_lock(current); > + knodes = cpuset_current_mems_allowed; > + task_unlock(current); > + err = copy_nodes_to_user(kargs->get.allowed.nodemask, > + kargs->get.allowed.maxnode, > + &knodes); > + kargs->get.allowed.err = err ? err : 0; > + kargs->err |= err ? 
err : 1; if (err) { kargs->get.allowed.err = err; kargs->err |= err; } else { kargs->get.allowed.err = 0; kargs->err = 1; Not particularly obvious why 1 is used here, and if you get an error later it's going to be messy, as the result will be 1 |= err_code } > + } > + > + if (kargs->get.addr.addr) { > + struct mempolicy *addr_pol = NULL; Why init here - I think it's always set before use. > + struct vm_area_struct *vma = NULL; Why init here? > + struct mm_struct *mm = current->mm; > + unsigned long addr = kargs->get.addr.addr; > + > + kargs->get.addr.err = 0; I'd set this only in the good path. You overwrite it in the bad paths anyway, so just move it down below the error checks. > + > + /* > + * Do NOT fall back to task policy if the > + * vma/shared policy at addr is NULL. We > + * want to return MPOL_DEFAULT in this case. > + */ > + mmap_read_lock(mm); > + vma = vma_lookup(mm, addr); > + if (!vma) { > + mmap_read_unlock(mm); > + kargs->get.addr.err = -EFAULT; > + kargs->err |= err ? err : 2; > + goto mode_info; > + } > + if (vma->vm_ops && vma->vm_ops->get_policy) > + addr_pol = vma->vm_ops->get_policy(vma, addr); > + else > + addr_pol = vma->vm_policy; > + > + kargs->get.addr.mode = addr_pol->mode; > + /* Mask off internal flags */ > + kargs->get.addr.flags = (pol->flags & MPOL_MODE_FLAGS); > + > + /* > + * Take a refcount on the mpol, because we are about to > + * drop the mmap_lock, after which only "pol" remains > + * valid, "vma" is stale. > + */ > + vma = NULL; > + mpol_get(addr_pol); > + mmap_read_unlock(mm); > + err = lookup_node(mm, addr); > + mpol_put(addr_pol); > + if (err < 0) { > + kargs->get.addr.err = err; > + kargs->err |= err ? err : 4; > + goto mode_info; > + } > + kargs->get.addr.node = err; It is confusing to call something that isn't an error 'err'. I'd use a different local variable for this and set err = rc in the error path only. You could set get.addr.err = 0; down here, as this is the only path on which it remains 0 if you set it earlier. 
> + } > + > +mode_info: > + switch (kargs->mode) { > + case MPOL_INTERLEAVE: > + kargs->interleave.next_node = next_node_in(current->il_prev, > + pol->nodes); > + break; > + default: > + break; > + } > + > + return err; > +} > + > +static long kernel_get_mempolicy2(struct mempolicy_args __user *uargs, > + size_t usize) > +{ > + struct mempolicy_args kargs; > + int err; > + > + if (usize != sizeof(struct mempolicy_args)) Use sizeof(kargs), for the same reason as below. I'm not sure of the convention here, but is it wise to leave the option for a newer userspace to send a larger struct, knowing that fields in it might be ignored by an older kernel? > + return -EINVAL; > + > + err = copy_struct_from_user(&kargs, sizeof(kargs), uargs, usize); > + if (err) > + return err; > + > + /* Get the extended memory policy information (kargs.ext) */ > + err = do_get_mempolicy2(&kargs); > + if (err) > + return err; > + > + err = copy_to_user(uargs, &kargs, sizeof(struct mempolicy_args)); > + > + return err; return copy_to_user(uargs, &kargs, sizeof(kargs)); You are inconsistent in your use of sizeof. Better to pick one style, and given both are used, I'd go with the sizeof(thing) rather than the sizeof(type) option; it also gives shorter lines ;) > +} > + > +SYSCALL_DEFINE2(get_mempolicy2, struct mempolicy_args __user *, policy, > + size_t, size) > +{ > + return kernel_get_mempolicy2(policy, size); > +} > + > static int kernel_migrate_pages(pid_t pid, unsigned long maxnode, > const unsigned long __user *old_nodes, > const unsigned long __user *new_nodes)