Implement UMCG server/worker API. This is an early RFC patch - the code seems to work, but more testing is needed. Gaps I plan to address before this is ready for a detailed review: - preemption/interrupt handling; - better documentation/comments; - tracing; - additional testing; - corner cases like abnormal process/task termination; - in some cases where I kill the task (umcg_segv), returning an error may be more appropriate. All in all, please focus more on the high-level approach and less on things like variable names, (doc) comments, or indentation. Signed-off-by: Peter Oskolkov <posk@xxxxxxxxxx> --- include/linux/mm_types.h | 5 + include/linux/syscalls.h | 5 + kernel/fork.c | 11 + kernel/sched/core.c | 11 + kernel/sched/umcg.c | 786 ++++++++++++++++++++++++++++++++++++++- kernel/sched/umcg.h | 54 +++ mm/init-mm.c | 4 + 7 files changed, 865 insertions(+), 11 deletions(-) diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h index 6613b26a8894..5ca7b7d55775 100644 --- a/include/linux/mm_types.h +++ b/include/linux/mm_types.h @@ -562,6 +562,11 @@ struct mm_struct { #ifdef CONFIG_IOMMU_SUPPORT u32 pasid; #endif + +#ifdef CONFIG_UMCG + spinlock_t umcg_lock; + struct list_head umcg_groups; +#endif } __randomize_layout; /* diff --git a/include/linux/syscalls.h b/include/linux/syscalls.h index 15de3e34ccee..2781659daaf1 100644 --- a/include/linux/syscalls.h +++ b/include/linux/syscalls.h @@ -1059,6 +1059,11 @@ asmlinkage long umcg_wait(u32 flags, const struct __kernel_timespec __user *time asmlinkage long umcg_wake(u32 flags, u32 next_tid); asmlinkage long umcg_swap(u32 wake_flags, u32 next_tid, u32 wait_flags, const struct __kernel_timespec __user *timeout); +asmlinkage long umcg_create_group(u32 api_version, u64 flags); +asmlinkage long umcg_destroy_group(u32 group_id); +asmlinkage long umcg_poll_worker(u32 flags, struct umcg_task __user **ut); +asmlinkage long umcg_run_worker(u32 flags, u32 worker_tid, + struct umcg_task __user **ut); /* * 
Architecture-specific system calls diff --git a/kernel/fork.c b/kernel/fork.c index ace4631b5b54..3a2a7950df8e 100644 --- a/kernel/fork.c +++ b/kernel/fork.c @@ -1026,6 +1026,10 @@ static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p, seqcount_init(&mm->write_protect_seq); mmap_init_lock(mm); INIT_LIST_HEAD(&mm->mmlist); +#ifdef CONFIG_UMCG + spin_lock_init(&mm->umcg_lock); + INIT_LIST_HEAD(&mm->umcg_groups); +#endif mm->core_state = NULL; mm_pgtables_bytes_init(mm); mm->map_count = 0; @@ -1102,6 +1106,13 @@ static inline void __mmput(struct mm_struct *mm) list_del(&mm->mmlist); spin_unlock(&mmlist_lock); } +#ifdef CONFIG_UMCG + if (!list_empty(&mm->umcg_groups)) { + spin_lock(&mm->umcg_lock); + list_del(&mm->umcg_groups); + spin_unlock(&mm->umcg_lock); + } +#endif if (mm->binfmt) module_put(mm->binfmt->module); mmdrop(mm); diff --git a/kernel/sched/core.c b/kernel/sched/core.c index 462104f13c28..e657a35655b1 100644 --- a/kernel/sched/core.c +++ b/kernel/sched/core.c @@ -26,6 +26,7 @@ #include "pelt.h" #include "smp.h" +#include "umcg.h" /* * Export tracepoints that act as a bare tracehook (ie: have no trace event @@ -6012,10 +6013,20 @@ static inline void sched_submit_work(struct task_struct *tsk) */ if (blk_needs_flush_plug(tsk)) blk_schedule_flush_plug(tsk); + +#ifdef CONFIG_UMCG + if (rcu_access_pointer(tsk->umcg_task_data)) + umcg_on_block(); +#endif } static void sched_update_worker(struct task_struct *tsk) { +#ifdef CONFIG_UMCG + if (rcu_access_pointer(tsk->umcg_task_data)) + umcg_on_wake(); +#endif + if (tsk->flags & (PF_WQ_WORKER | PF_IO_WORKER)) { if (tsk->flags & PF_WQ_WORKER) wq_worker_running(tsk); diff --git a/kernel/sched/umcg.c b/kernel/sched/umcg.c index 2d718433c773..38cba772322d 100644 --- a/kernel/sched/umcg.c +++ b/kernel/sched/umcg.c @@ -21,6 +21,12 @@ static int __api_version(u32 requested) return 1; } +static int umcg_segv(int res) +{ + force_sig(SIGSEGV); + return res; +} + /** * sys_umcg_api_version - query UMCG 
API versions that are supported. * @api_version: Requested API version. @@ -54,6 +60,78 @@ static int put_state(struct umcg_task __user *ut, u32 state) return put_user(state, (u32 __user *)ut); } +static void umcg_lock_pair(struct task_struct *server, + struct task_struct *worker) +{ + spin_lock(&server->alloc_lock); + spin_lock_nested(&worker->alloc_lock, SINGLE_DEPTH_NESTING); +} + +static void umcg_unlock_pair(struct task_struct *server, + struct task_struct *worker) +{ + spin_unlock(&worker->alloc_lock); + spin_unlock(&server->alloc_lock); +} + +static void umcg_detach_peer(void) +{ + struct task_struct *server, *worker; + struct umcg_task_data *utd; + + rcu_read_lock(); + task_lock(current); + utd = rcu_dereference(current->umcg_task_data); + + if (!utd || !rcu_dereference(utd->peer)) { + task_unlock(current); + goto out; + } + + switch (utd->task_type) { + case UMCG_TT_SERVER: + server = current; + worker = rcu_dereference(utd->peer); + break; + + case UMCG_TT_WORKER: + worker = current; + server = rcu_dereference(utd->peer); + break; + + default: + task_unlock(current); + printk(KERN_WARNING "umcg_detach_peer: unexpected task type\n"); + umcg_segv(0); + goto out; + } + task_unlock(current); + + if (!server || !worker) + goto out; + + umcg_lock_pair(server, worker); + + utd = rcu_dereference(server->umcg_task_data); + if (WARN_ON(!utd)) { + umcg_segv(0); + goto out_pair; + } + rcu_assign_pointer(utd->peer, NULL); + + utd = rcu_dereference(worker->umcg_task_data); + if (WARN_ON(!utd)) { + umcg_segv(0); + goto out_pair; + } + rcu_assign_pointer(utd->peer, NULL); + +out_pair: + umcg_unlock_pair(server, worker); +out: + rcu_read_unlock(); +} + static int register_core_task(u32 api_version, struct umcg_task __user *umcg_task) { struct umcg_task_data *utd; @@ -73,6 +151,7 @@ static int register_core_task(u32 api_version, struct umcg_task __user *umcg_tas utd->umcg_task = umcg_task; utd->task_type = UMCG_TT_CORE; utd->api_version = api_version; + 
RCU_INIT_POINTER(utd->peer, NULL); if (put_state(umcg_task, UMCG_TASK_RUNNING)) { kfree(utd); @@ -86,6 +165,106 @@ static int register_core_task(u32 api_version, struct umcg_task __user *umcg_tas return 0; } +static int add_task_to_group(u32 api_version, u32 group_id, + struct umcg_task __user *umcg_task, + enum umcg_task_type task_type, u32 new_state) +{ + struct mm_struct *mm = current->mm; + struct umcg_task_data *utd = NULL; + struct umcg_group *group = NULL; + struct umcg_group *list_entry; + int ret = -EINVAL; + u32 state; + + if (get_state(umcg_task, &state)) + return -EFAULT; + + if (state != UMCG_TASK_NONE) + return -EINVAL; + + if (put_state(umcg_task, new_state)) + return -EFAULT; + +retry_once: + group = NULL; /* Do not reuse a group found before dropping RCU. */ + rcu_read_lock(); + list_for_each_entry_rcu(list_entry, &mm->umcg_groups, list) { + if (list_entry->group_id == group_id) { + group = list_entry; + break; + } + } + + if (!group || group->api_version != api_version) + goto out_rcu; + + spin_lock(&group->lock); + if (group->nr_tasks < 0) /* The group is being destroyed. 
*/ + goto out_group; + + if (!utd) { + utd = kzalloc(sizeof(struct umcg_task_data), GFP_NOWAIT); + if (!utd) { + spin_unlock(&group->lock); + rcu_read_unlock(); + + utd = kzalloc(sizeof(struct umcg_task_data), GFP_KERNEL); + if (!utd) { + ret = -ENOMEM; + goto out; + } + + goto retry_once; + } + } + + utd->self = current; + utd->group = group; + utd->umcg_task = umcg_task; + utd->task_type = task_type; + utd->api_version = api_version; + RCU_INIT_POINTER(utd->peer, NULL); + + INIT_LIST_HEAD(&utd->list); + group->nr_tasks++; + + task_lock(current); + rcu_assign_pointer(current->umcg_task_data, utd); + task_unlock(current); + + ret = 0; + +out_group: + spin_unlock(&group->lock); + +out_rcu: + rcu_read_unlock(); + if (ret && utd) + kfree(utd); + +out: + if (ret) + put_state(umcg_task, UMCG_TASK_NONE); + else + schedule(); /* Trigger umcg_on_wake(). 
*/ + + return ret; +} + +static int register_worker(u32 api_version, u32 group_id, + struct umcg_task __user *umcg_task) +{ + return add_task_to_group(api_version, group_id, umcg_task, + UMCG_TT_WORKER, UMCG_TASK_UNBLOCKED); +} + +static int register_server(u32 api_version, u32 group_id, + struct umcg_task __user *umcg_task) +{ + return add_task_to_group(api_version, group_id, umcg_task, + UMCG_TT_SERVER, UMCG_TASK_PROCESSING); +} + /** * sys_umcg_register_task - register the current task as a UMCG task. * @api_version: The expected/desired API version of the syscall. @@ -122,6 +301,10 @@ SYSCALL_DEFINE4(umcg_register_task, u32, api_version, u32, flags, u32, group_id, if (group_id != UMCG_NOID) return -EINVAL; return register_core_task(api_version, umcg_task); + case UMCG_REGISTER_WORKER: + return register_worker(api_version, group_id, umcg_task); + case UMCG_REGISTER_SERVER: + return register_server(api_version, group_id, umcg_task); default: return -EINVAL; } @@ -146,9 +329,39 @@ SYSCALL_DEFINE1(umcg_unregister_task, u32, flags) if (!utd || flags) goto out; + if (!utd->group) { + ret = 0; + goto out; + } + + if (utd->task_type == UMCG_TT_WORKER) { + struct task_struct *server = rcu_dereference(utd->peer); + + if (server) { + umcg_detach_peer(); + if (WARN_ON(!wake_up_process(server))) { + umcg_segv(0); + goto out; + } + } + } else { + if (WARN_ON(utd->task_type != UMCG_TT_SERVER)) { + umcg_segv(0); + goto out; + } + + umcg_detach_peer(); + } + + spin_lock(&utd->group->lock); task_lock(current); + rcu_assign_pointer(current->umcg_task_data, NULL); + + --utd->group->nr_tasks; + task_unlock(current); + spin_unlock(&utd->group->lock); ret = 0; @@ -164,6 +377,7 @@ SYSCALL_DEFINE1(umcg_unregister_task, u32, flags) static int do_context_switch(struct task_struct *next) { struct umcg_task_data *utd = rcu_access_pointer(current->umcg_task_data); + bool prev_wait_flag; /* See comment in do_wait() below. 
*/ /* * It is important to set_current_state(TASK_INTERRUPTIBLE) before @@ -173,34 +387,56 @@ static int do_context_switch(struct task_struct *next) */ set_current_state(TASK_INTERRUPTIBLE); - WRITE_ONCE(utd->in_wait, true); - + prev_wait_flag = utd->in_wait; + if (!prev_wait_flag) + WRITE_ONCE(utd->in_wait, true); + - if (!try_to_wake_up(next, TASK_NORMAL, WF_CURRENT_CPU)) - return -EAGAIN; + if (!try_to_wake_up(next, TASK_NORMAL, WF_CURRENT_CPU)) { + /* Undo the state changes made above before bailing out. */ + __set_current_state(TASK_RUNNING); + if (!prev_wait_flag) + WRITE_ONCE(utd->in_wait, false); + return -EAGAIN; + } freezable_schedule(); - WRITE_ONCE(utd->in_wait, false); + if (!prev_wait_flag) + WRITE_ONCE(utd->in_wait, false); if (signal_pending(current)) return -EINTR; + /* TODO: deal with non-fatal interrupts. */ return 0; } static int do_wait(void) { struct umcg_task_data *utd = rcu_access_pointer(current->umcg_task_data); + /* + * freezable_schedule() below can recursively call do_wait() if + * this is a worker that needs a server. As the wait flag is only + * used by the outermost wait/wake (and swap) syscalls, modify it only + * in the outermost do_wait() instead of using a counter. + * + * Note that the nesting level is at most two, as utd->in_workqueue + * is used to prevent further nesting. + */ + bool prev_wait_flag; if (!utd) return -EINVAL; - WRITE_ONCE(utd->in_wait, true); + prev_wait_flag = utd->in_wait; + if (!prev_wait_flag) + WRITE_ONCE(utd->in_wait, true); set_current_state(TASK_INTERRUPTIBLE); freezable_schedule(); - WRITE_ONCE(utd->in_wait, false); + if (!prev_wait_flag) + WRITE_ONCE(utd->in_wait, false); if (signal_pending(current)) return -EINTR; @@ -214,7 +450,7 @@ static int do_wait(void) * @timeout: The absolute timeout of the wait. Not supported yet. * Must be NULL. * - * Sleep until woken, interrupted, or @timeout expires. + * Sleep until woken or @timeout expires. 
* * Return: * 0 - Ok; @@ -229,6 +465,7 @@ SYSCALL_DEFINE2(umcg_wait, u32, flags, const struct __kernel_timespec __user *, timeout) { struct umcg_task_data *utd; + struct task_struct *server = NULL; if (flags) return -EINVAL; @@ -242,8 +479,14 @@ SYSCALL_DEFINE2(umcg_wait, u32, flags, return -EINVAL; } + if (utd->task_type == UMCG_TT_WORKER) + server = rcu_dereference(utd->peer); + rcu_read_unlock(); + if (server) + return do_context_switch(server); + return do_wait(); } @@ -252,7 +495,7 @@ SYSCALL_DEFINE2(umcg_wait, u32, flags, * @flags: Reserved. * @next_tid: The ID of the task to wake. * - * Wake @next identified by @next_tid. @next must be either a UMCG core + * Wake task next identified by @next_tid. @next must be either a UMCG core * task or a UMCG worker task. * * Return: @@ -265,7 +508,7 @@ SYSCALL_DEFINE2(umcg_wait, u32, flags, SYSCALL_DEFINE2(umcg_wake, u32, flags, u32, next_tid) { struct umcg_task_data *next_utd; - struct task_struct *next; + struct task_struct *next, *next_peer; int ret = -EINVAL; if (!next_tid) @@ -282,11 +525,29 @@ SYSCALL_DEFINE2(umcg_wake, u32, flags, u32, next_tid) if (!next_utd) goto out; + if (next_utd->task_type == UMCG_TT_SERVER) + goto out; + if (!READ_ONCE(next_utd->in_wait)) { ret = -EAGAIN; goto out; } + next_peer = rcu_dereference(next_utd->peer); + if (next_peer) { + if (next_peer == current) + umcg_detach_peer(); + else { + /* + * Waking a worker with an assigned server is not + * permitted, unless the waking is done by the assigned + * server. 
*/ + umcg_segv(0); + goto out; + } + } + ret = wake_up_process(next); put_task_struct(next); if (ret) @@ -348,7 +609,7 @@ SYSCALL_DEFINE4(umcg_swap, u32, wake_flags, u32, next_tid, u32, wait_flags, } next_utd = rcu_dereference(next->umcg_task_data); - if (!next_utd) { + if (!next_utd || next_utd->group != curr_utd->group) { ret = -EINVAL; goto out; } @@ -358,6 +619,25 @@ SYSCALL_DEFINE4(umcg_swap, u32, wake_flags, u32, next_tid, u32, wait_flags, goto out; } + /* Move the server from curr to next, if appropriate. */ + if (curr_utd->task_type == UMCG_TT_WORKER) { + struct task_struct *server = rcu_dereference(curr_utd->peer); + if (server) { + struct umcg_task_data *server_utd = + rcu_dereference(server->umcg_task_data); + + if (rcu_access_pointer(next_utd->peer)) { + ret = -EAGAIN; + goto out; + } + umcg_detach_peer(); + umcg_lock_pair(server, next); + rcu_assign_pointer(server_utd->peer, next); + rcu_assign_pointer(next_utd->peer, server); + umcg_unlock_pair(server, next); + } + } + rcu_read_unlock(); return do_context_switch(next); @@ -366,3 +646,487 @@ SYSCALL_DEFINE4(umcg_swap, u32, wake_flags, u32, next_tid, u32, wait_flags, rcu_read_unlock(); return ret; } + +/** + * sys_umcg_create_group - create a UMCG group + * @api_version: Requested API version. + * @flags: Reserved. 
+ * + * Return: + * >= 0 - the group ID + * -EOPNOTSUPP - @api_version is not supported + * -EINVAL - @flags is not valid + * -ENOMEM - not enough memory + */ +SYSCALL_DEFINE2(umcg_create_group, u32, api_version, u64, flags) +{ + int ret; + struct umcg_group *group; + struct umcg_group *list_entry; + struct mm_struct *mm = current->mm; + + if (flags) + return -EINVAL; + + if (__api_version(api_version)) + return -EOPNOTSUPP; + + group = kzalloc(sizeof(struct umcg_group), GFP_KERNEL); + if (!group) + return -ENOMEM; + + spin_lock_init(&group->lock); + INIT_LIST_HEAD(&group->list); + INIT_LIST_HEAD(&group->waiters); + group->flags = flags; + group->api_version = api_version; + + spin_lock(&mm->umcg_lock); + + list_for_each_entry(list_entry, &mm->umcg_groups, list) { + if (list_entry->group_id >= group->group_id) + group->group_id = list_entry->group_id + 1; + } + + list_add_rcu(&group->list, &mm->umcg_groups); + + ret = group->group_id; + spin_unlock(&mm->umcg_lock); + + return ret; +} + +/** + * sys_umcg_destroy_group - destroy a UMCG group + * @group_id: The ID of the group to destroy. + * + * The group must be empty, i.e. have no registered servers or workers. + * + * Return: + * 0 - success; + * -ESRCH - group not found; + * -EBUSY - the group has registered workers or servers. + */ +SYSCALL_DEFINE1(umcg_destroy_group, u32, group_id) +{ + int ret = 0; + struct umcg_group *group = NULL; + struct umcg_group *list_entry; + struct mm_struct *mm = current->mm; + + spin_lock(&mm->umcg_lock); + list_for_each_entry(list_entry, &mm->umcg_groups, list) { + if (list_entry->group_id == group_id) { + group = list_entry; + break; + } + } + + if (group == NULL) { + ret = -ESRCH; + goto out; + } + + spin_lock(&group->lock); + + if (group->nr_tasks > 0) { + ret = -EBUSY; + spin_unlock(&group->lock); + goto out; + } + + /* Tell group rcu readers that the group is going to be deleted. 
*/ + group->nr_tasks = -1; + + spin_unlock(&group->lock); + + list_del_rcu(&group->list); + kfree_rcu(group, rcu); + +out: + spin_unlock(&mm->umcg_lock); + return ret; +} + +/** + * sys_umcg_poll_worker - poll an UNBLOCKED worker + * @flags: reserved; + * @ut: the control struct umcg_task of the polled worker. + * + * The current task must be a UMCG server in POLLING state; if there are + * UNBLOCKED workers in the server's group, take the earliest queued, + * mark the worker as RUNNABLE and return. + * + * If there are no unblocked workers, the syscall waits for one to become + * available. + * + * Return: + * 0 - Ok; + * -EINTR - a signal was received; + * -EINVAL - one of the parameters is wrong, or a precondition was not met. + */ +SYSCALL_DEFINE2(umcg_poll_worker, u32, flags, struct umcg_task __user **, ut) +{ + struct umcg_group *group; + struct task_struct *worker; + struct task_struct *server = current; + struct umcg_task __user *result; + struct umcg_task_data *worker_utd, *server_utd; + + if (flags) + return -EINVAL; + + rcu_read_lock(); + + server_utd = rcu_dereference(server->umcg_task_data); + + if (!server_utd || server_utd->task_type != UMCG_TT_SERVER) { + rcu_read_unlock(); + return -EINVAL; + } + + umcg_detach_peer(); + + group = server_utd->group; + + spin_lock(&group->lock); + + if (group->nr_waiting_workers == 0) { /* Queue the server. 
*/ + ++group->nr_waiting_pollers; + list_add_tail(&server_utd->list, &group->waiters); + set_current_state(TASK_INTERRUPTIBLE); + spin_unlock(&group->lock); + rcu_read_unlock(); + + freezable_schedule(); + + rcu_read_lock(); + server_utd = rcu_dereference(server->umcg_task_data); + + if (!list_empty(&server_utd->list)) { + spin_lock(&group->lock); + list_del_init(&server_utd->list); + --group->nr_waiting_pollers; + spin_unlock(&group->lock); + } + + if (signal_pending(current)) { + rcu_read_unlock(); + return -EINTR; + } + + worker = rcu_dereference(server_utd->peer); + if (worker) { + worker_utd = rcu_dereference(worker->umcg_task_data); + result = worker_utd ? worker_utd->umcg_task : NULL; + } else + result = NULL; + + rcu_read_unlock(); + + if (put_user(result, ut)) + return umcg_segv(-EFAULT); + return 0; + } + + /* Pick up the first worker. */ + worker_utd = list_first_entry(&group->waiters, struct umcg_task_data, + list); + list_del_init(&worker_utd->list); + worker = worker_utd->self; + --group->nr_waiting_workers; + + umcg_lock_pair(server, worker); + spin_unlock(&group->lock); + + if (WARN_ON(rcu_access_pointer(server_utd->peer) || + rcu_access_pointer(worker_utd->peer))) { + /* This is unexpected. */ + umcg_unlock_pair(server, worker); + rcu_read_unlock(); + return umcg_segv(-EINVAL); + } + rcu_assign_pointer(server_utd->peer, worker); + rcu_assign_pointer(worker_utd->peer, current); + + umcg_unlock_pair(server, worker); + + result = worker_utd->umcg_task; + rcu_read_unlock(); + + if (put_state(result, UMCG_TASK_RUNNABLE)) + return umcg_segv(-EFAULT); + + if (put_user(result, ut)) + return umcg_segv(-EFAULT); + + return 0; +} + +/** + * sys_umcg_run_worker - "run" a RUNNABLE worker as a server + * @flags: reserved; + * @worker_tid: tid of the worker to run; + * @ut: the control struct umcg_task of the worker that blocked + * during this "run". + * + * The worker must be in RUNNABLE state. 
The server (=current task) + * wakes the worker and blocks; when the worker, or one of the workers + * in umcg_swap chain, blocks, the server is woken and the syscall returns + * with ut indicating the blocked worker. + * + * If the worker exits or unregisters itself, the syscall succeeds with + * ut == NULL. + * + * Return: + * 0 - Ok; + * -EINTR - a signal was received; + * -EINVAL - one of the parameters is wrong, or a precondition was not met. + */ +SYSCALL_DEFINE3(umcg_run_worker, u32, flags, u32, worker_tid, + struct umcg_task __user **, ut) +{ + int ret = -EINVAL; + struct task_struct *worker; + struct task_struct *server = current; + struct umcg_task __user *result = NULL; + struct umcg_task_data *worker_utd; + struct umcg_task_data *server_utd; + struct umcg_task __user *server_ut; + struct umcg_task __user *worker_ut; + + if (!ut) + return -EINVAL; + + rcu_read_lock(); + server_utd = rcu_dereference(server->umcg_task_data); + + if (!server_utd || server_utd->task_type != UMCG_TT_SERVER) + goto out_rcu; + + if (flags) + goto out_rcu; + + worker = find_get_task_by_vpid(worker_tid); + if (!worker) { + ret = -ESRCH; + goto out_rcu; + } + + worker_utd = rcu_dereference(worker->umcg_task_data); + if (!worker_utd) + goto out_put; + + if (!READ_ONCE(worker_utd->in_wait)) { + ret = -EAGAIN; + goto out_put; + } + + if (server_utd->group != worker_utd->group) + goto out_put; + + if (rcu_access_pointer(server_utd->peer) != worker) + umcg_detach_peer(); + + if (!rcu_access_pointer(server_utd->peer)) { + umcg_lock_pair(server, worker); + if (rcu_access_pointer(worker_utd->peer)) { + /* The worker is assigned to a different server. 
*/ + umcg_unlock_pair(server, worker); + ret = -EAGAIN; + goto out_put; + } + rcu_assign_pointer(server_utd->peer, worker); + rcu_assign_pointer(worker_utd->peer, server); + umcg_unlock_pair(server, worker); + } + + server_ut = server_utd->umcg_task; + worker_ut = worker_utd->umcg_task; + + rcu_read_unlock(); + + ret = do_context_switch(worker); + put_task_struct(worker); + if (ret) + return ret; + + rcu_read_lock(); + worker = rcu_dereference(server_utd->peer); + if (worker) { + worker_utd = 
rcu_dereference(worker->umcg_task_data); + if (worker_utd) + result = worker_utd->umcg_task; + } + rcu_read_unlock(); + + if (put_user(result, ut)) + return -EFAULT; + return 0; + +out_put: + put_task_struct(worker); +out_rcu: + rcu_read_unlock(); + return ret; +} + +void umcg_on_block(void) +{ + struct umcg_task_data *utd = rcu_access_pointer(current->umcg_task_data); + struct umcg_task __user *ut; + struct task_struct *server; + u32 state; + + if (utd->task_type != UMCG_TT_WORKER || utd->in_workqueue) + return; + + ut = utd->umcg_task; + + if (get_user(state, (u32 __user *)ut)) { + if (signal_pending(current)) + return; + umcg_segv(0); + return; + } + + if (state != UMCG_TASK_RUNNING) + return; + + state = UMCG_TASK_BLOCKED; + if (put_user(state, (u32 __user *)ut)) { + umcg_segv(0); + return; + } + + rcu_read_lock(); + server = rcu_dereference(utd->peer); + /* Wake @server under RCU so its task_struct stays valid. */ + if (server) + WARN_ON(!try_to_wake_up(server, TASK_NORMAL, WF_CURRENT_CPU)); + rcu_read_unlock(); +} + +/* Return true to return to the user, false to keep waiting. */ +static bool process_unblocked_worker(void) +{ + struct umcg_task_data *utd; + struct umcg_group *group; + + rcu_read_lock(); + + utd = rcu_dereference(current->umcg_task_data); + group = utd->group; + + spin_lock(&group->lock); + if (!list_empty(&utd->list)) { + /* This was a spurious wakeup or an interrupt, do nothing. */ + spin_unlock(&group->lock); + rcu_read_unlock(); + do_wait(); + return false; + } + + if (group->nr_waiting_pollers > 0) { /* Wake a server. 
*/ + struct task_struct *server; + struct umcg_task_data *server_utd = list_first_entry( + &group->waiters, struct umcg_task_data, list); + + list_del_init(&server_utd->list); + server = server_utd->self; + --group->nr_waiting_pollers; + + umcg_lock_pair(server, current); + spin_unlock(&group->lock); + + if (WARN_ON(rcu_access_pointer(server_utd->peer) || + rcu_access_pointer(utd->peer))) { + umcg_unlock_pair(server, current); + rcu_read_unlock(); + umcg_segv(0); + return true; + } + rcu_assign_pointer(server_utd->peer, current); + rcu_assign_pointer(utd->peer, server); + + umcg_unlock_pair(server, current); + rcu_read_unlock(); + + if (put_state(utd->umcg_task, UMCG_TASK_RUNNABLE)) { + umcg_segv(0); + return true; + } + + do_context_switch(server); + return false; + } + + /* Add to the queue. */ + ++group->nr_waiting_workers; + list_add_tail(&utd->list, &group->waiters); + spin_unlock(&group->lock); + rcu_read_unlock(); + + do_wait(); + + smp_rmb(); + if (!list_empty(&utd->list)) { + spin_lock(&group->lock); + list_del_init(&utd->list); + --group->nr_waiting_workers; + spin_unlock(&group->lock); + } + + return false; +} + +void umcg_on_wake(void) +{ + struct umcg_task_data *utd; + struct umcg_task __user *ut; + bool should_break = false; + + /* current->umcg_task_data is modified only from current. 
*/ + utd = rcu_access_pointer(current->umcg_task_data); + if (utd->task_type != UMCG_TT_WORKER || utd->in_workqueue) + return; + + do { + u32 state; + + if (fatal_signal_pending(current)) + return; + + if (signal_pending(current)) + return; + + ut = utd->umcg_task; + + if (get_state(ut, &state)) { + if (signal_pending(current)) + return; + goto segv; + } + + if (state == UMCG_TASK_RUNNING && rcu_access_pointer(utd->peer)) + return; + + if (state == UMCG_TASK_BLOCKED || state == UMCG_TASK_RUNNING) { + state = UMCG_TASK_UNBLOCKED; + if (put_state(ut, state)) + goto segv; + } else if (state != UMCG_TASK_UNBLOCKED) { + goto segv; + } + + utd->in_workqueue = true; + should_break = process_unblocked_worker(); + utd->in_workqueue = false; + if (should_break) + return; + + } while (!should_break); + +segv: + umcg_segv(0); +} diff --git a/kernel/sched/umcg.h b/kernel/sched/umcg.h index 6791d570f622..92012a1674ab 100644 --- a/kernel/sched/umcg.h +++ b/kernel/sched/umcg.h @@ -8,6 +8,34 @@ #include <linux/sched.h> #include <linux/umcg.h> +struct umcg_group { + struct list_head list; + u32 group_id; /* Never changes. */ + u32 api_version; /* Never changes. */ + u64 flags; /* Never changes. */ + + spinlock_t lock; + + /* + * One of the counters below is always zero. The non-zero counter + * indicates the number of elements in @waiters below. + */ + int nr_waiting_workers; + int nr_waiting_pollers; + + /* + * The list below either contains UNBLOCKED workers waiting + * for the userspace to poll or run them if nr_waiting_workers > 0, + * or polling servers waiting for unblocked workers if + * nr_waiting_pollers > 0. + */ + struct list_head waiters; + + int nr_tasks; /* The total number of tasks registered. */ + + struct rcu_head rcu; +}; + enum umcg_task_type { UMCG_TT_CORE = 1, UMCG_TT_SERVER = 2, @@ -32,11 +60,37 @@ struct umcg_task_data { */ u32 api_version; + /* NULL for core API tasks. Never changes. 
*/ + struct umcg_group *group; + + /* + * If this is a server task, points to its assigned worker, if any; + * if this is a worker task, points to its assigned server, if any. + * + * Protected by alloc_lock of the task owning this struct. + * + * Always either NULL, or the server and the worker point to each other. + * Locking order: first lock the server, then the worker. + * + * Either the worker or the server should be the current task when + * this field is changed, with the exception of sys_umcg_swap. + */ + struct task_struct __rcu *peer; + + /* Used in umcg_group.waiters. */ + struct list_head list; + + /* Used by curr in umcg_on_block/wake to prevent nesting/recursion. */ + bool in_workqueue; + /* * Used by wait/wake routines to handle races. Written only by current. */ bool in_wait; }; +void umcg_on_block(void); +void umcg_on_wake(void); + #endif /* CONFIG_UMCG */ #endif /* _KERNEL_SCHED_UMCG_H */ diff --git a/mm/init-mm.c b/mm/init-mm.c index 153162669f80..85e4a8ecfd91 100644 --- a/mm/init-mm.c +++ b/mm/init-mm.c @@ -36,6 +36,10 @@ struct mm_struct init_mm = { .page_table_lock = __SPIN_LOCK_UNLOCKED(init_mm.page_table_lock), .arg_lock = __SPIN_LOCK_UNLOCKED(init_mm.arg_lock), .mmlist = LIST_HEAD_INIT(init_mm.mmlist), +#ifdef CONFIG_UMCG + .umcg_lock = __SPIN_LOCK_UNLOCKED(init_mm.umcg_lock), + .umcg_groups = LIST_HEAD_INIT(init_mm.umcg_groups), +#endif .user_ns = &init_user_ns, .cpu_bitmap = CPU_BITS_NONE, INIT_MM_CONTEXT(init_mm) -- 2.31.1.818.g46aad6cb9e-goog