[RFC PATCH v0.1 7/9] sched/umcg: add UMCG server/worker API (early RFC)

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



Implement UMCG server/worker API.

This is an early RFC patch - the code seems working, but
more testing is needed. Gaps I plan to address before this
is ready for a detailed review:

- preemption/interrupt handling;
- better documentation/comments;
- tracing;
- additional testing;
- corner cases like abnormal process/task termination;
- in some cases where I kill the task (umcg_segv), returning
an error may be more appropriate.

All in all, please focus more on the high-level approach
and less on things like variable names, (doc) comments, or indentation.

Signed-off-by: Peter Oskolkov <posk@xxxxxxxxxx>
---
 include/linux/mm_types.h |   5 +
 include/linux/syscalls.h |   5 +
 kernel/fork.c            |  11 +
 kernel/sched/core.c      |  11 +
 kernel/sched/umcg.c      | 764 ++++++++++++++++++++++++++++++++++++++-
 kernel/sched/umcg.h      |  54 +++
 mm/init-mm.c             |   4 +
 7 files changed, 845 insertions(+), 9 deletions(-)

diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index 6613b26a8894..5ca7b7d55775 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -562,6 +562,11 @@ struct mm_struct {
 #ifdef CONFIG_IOMMU_SUPPORT
 		u32 pasid;
 #endif
+
+#ifdef CONFIG_UMCG
+	spinlock_t umcg_lock;
+	struct list_head umcg_groups;
+#endif
 	} __randomize_layout;
 
 	/*
diff --git a/include/linux/syscalls.h b/include/linux/syscalls.h
index 15de3e34ccee..2781659daaf1 100644
--- a/include/linux/syscalls.h
+++ b/include/linux/syscalls.h
@@ -1059,6 +1059,11 @@ asmlinkage long umcg_wait(u32 flags, const struct __kernel_timespec __user *time
 asmlinkage long umcg_wake(u32 flags, u32 next_tid);
 asmlinkage long umcg_swap(u32 wake_flags, u32 next_tid, u32 wait_flags,
 				const struct __kernel_timespec __user *timeout);
+asmlinkage long umcg_create_group(u32 api_version, u64, flags);
+asmlinkage long umcg_destroy_group(u32 group_id);
+asmlinkage long umcg_poll_worker(u32 flags, struct umcg_task __user **ut);
+asmlinkage long umcg_run_worker(u32 flags, u32 worker_tid,
+		struct umcg_task __user **ut);
 
 /*
  * Architecture-specific system calls
diff --git a/kernel/fork.c b/kernel/fork.c
index ace4631b5b54..3a2a7950df8e 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -1026,6 +1026,10 @@ static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p,
 	seqcount_init(&mm->write_protect_seq);
 	mmap_init_lock(mm);
 	INIT_LIST_HEAD(&mm->mmlist);
+#ifdef CONFIG_UMCG
+	spin_lock_init(&mm->umcg_lock);
+	INIT_LIST_HEAD(&mm->umcg_groups);
+#endif
 	mm->core_state = NULL;
 	mm_pgtables_bytes_init(mm);
 	mm->map_count = 0;
@@ -1102,6 +1106,13 @@ static inline void __mmput(struct mm_struct *mm)
 		list_del(&mm->mmlist);
 		spin_unlock(&mmlist_lock);
 	}
+#ifdef CONFIG_UMCG
+	if (!list_empty(&mm->umcg_groups)) {
+		spin_lock(&mm->umcg_lock);
+		list_del(&mm->umcg_groups);
+		spin_unlock(&mm->umcg_lock);
+	}
+#endif
 	if (mm->binfmt)
 		module_put(mm->binfmt->module);
 	mmdrop(mm);
diff --git a/kernel/sched/core.c b/kernel/sched/core.c
index 462104f13c28..e657a35655b1 100644
--- a/kernel/sched/core.c
+++ b/kernel/sched/core.c
@@ -26,6 +26,7 @@
 
 #include "pelt.h"
 #include "smp.h"
+#include "umcg.h"
 
 /*
  * Export tracepoints that act as a bare tracehook (ie: have no trace event
@@ -6012,10 +6013,20 @@ static inline void sched_submit_work(struct task_struct *tsk)
 	 */
 	if (blk_needs_flush_plug(tsk))
 		blk_schedule_flush_plug(tsk);
+
+#ifdef CONFIG_UMCG
+	if (rcu_access_pointer(tsk->umcg_task_data))
+		umcg_on_block();
+#endif
 }
 
 static void sched_update_worker(struct task_struct *tsk)
 {
+#ifdef CONFIG_UMCG
+	if (rcu_access_pointer(tsk->umcg_task_data))
+		umcg_on_wake();
+#endif
+
 	if (tsk->flags & (PF_WQ_WORKER | PF_IO_WORKER)) {
 		if (tsk->flags & PF_WQ_WORKER)
 			wq_worker_running(tsk);
diff --git a/kernel/sched/umcg.c b/kernel/sched/umcg.c
index 2d718433c773..38cba772322d 100644
--- a/kernel/sched/umcg.c
+++ b/kernel/sched/umcg.c
@@ -21,6 +21,12 @@ static int __api_version(u32 requested)
 	return 1;
 }
 
+static int umcg_segv(int res)
+{
+	force_sig(SIGSEGV);
+	return res;
+}
+
 /**
  * sys_umcg_api_version - query UMCG API versions that are supported.
  * @api_version:          Requested API version.
@@ -54,6 +60,78 @@ static int put_state(struct umcg_task __user *ut, u32 state)
 	return put_user(state, (u32 __user *)ut);
 }
 
+static void umcg_lock_pair(struct task_struct *server,
+		struct task_struct *worker)
+{
+	spin_lock(&server->alloc_lock);
+	spin_lock_nested(&worker->alloc_lock, SINGLE_DEPTH_NESTING);
+}
+
+static void umcg_unlock_pair(struct task_struct *server,
+		struct task_struct *worker)
+{
+	spin_unlock(&worker->alloc_lock);
+	spin_unlock(&server->alloc_lock);
+}
+
+static void umcg_detach_peer(void)
+{
+	struct task_struct *server, *worker;
+	struct umcg_task_data *utd;
+
+	rcu_read_lock();
+	task_lock(current);
+	utd = rcu_dereference(current->umcg_task_data);
+
+	if (!utd || !rcu_dereference(utd->peer)) {
+		task_unlock(current);
+		goto out;
+	}
+
+	switch (utd->task_type) {
+	case UMCG_TT_SERVER:
+		server = current;
+		worker = rcu_dereference(utd->peer);
+		break;
+
+	case UMCG_TT_WORKER:
+		worker = current;
+		server = rcu_dereference(utd->peer);
+		break;
+
+	default:
+		task_unlock(current);
+		printk(KERN_WARNING "umcg_detach_peer: unexpected task type");
+		umcg_segv(0);
+		goto out;
+	}
+	task_unlock(current);
+
+	if (!server || !worker)
+		goto out;
+
+	umcg_lock_pair(server, worker);
+
+	utd = rcu_dereference(server->umcg_task_data);
+	if (WARN_ON(!utd)) {
+		umcg_segv(0);
+		goto out_pair;
+	}
+	rcu_assign_pointer(utd->peer, NULL);
+
+	utd = rcu_dereference(worker->umcg_task_data);
+	if (WARN_ON(!utd)) {
+		umcg_segv(0);
+		goto out_pair;
+	}
+	rcu_assign_pointer(utd->peer, NULL);
+
+out_pair:
+	umcg_unlock_pair(server, worker);
+out:
+	rcu_read_unlock();
+}
+
 static int register_core_task(u32 api_version, struct umcg_task __user *umcg_task)
 {
 	struct umcg_task_data *utd;
@@ -73,6 +151,7 @@ static int register_core_task(u32 api_version, struct umcg_task __user *umcg_tas
 	utd->umcg_task = umcg_task;
 	utd->task_type = UMCG_TT_CORE;
 	utd->api_version = api_version;
+	RCU_INIT_POINTER(utd->peer, NULL);
 
 	if (put_state(umcg_task, UMCG_TASK_RUNNING)) {
 		kfree(utd);
@@ -86,6 +165,105 @@ static int register_core_task(u32 api_version, struct umcg_task __user *umcg_tas
 	return 0;
 }
 
+static int add_task_to_group(u32 api_version, u32 group_id,
+		struct umcg_task __user *umcg_task,
+		enum umcg_task_type task_type, u32 new_state)
+{
+	struct mm_struct *mm = current->mm;
+	struct umcg_task_data *utd = NULL;
+	struct umcg_group *group = NULL;
+	struct umcg_group *list_entry;
+	int ret = -EINVAL;
+	u32 state;
+
+	if (get_state(umcg_task, &state))
+		return -EFAULT;
+
+	if (state != UMCG_TASK_NONE)
+		return -EINVAL;
+
+	if (put_state(umcg_task, new_state))
+		return -EFAULT;
+
+retry_once:
+	rcu_read_lock();
+	list_for_each_entry_rcu(list_entry, &mm->umcg_groups, list) {
+		if (list_entry->group_id == group_id) {
+			group = list_entry;
+			break;
+		}
+	}
+
+	if (!group || group->api_version != api_version)
+		goto out_rcu;
+
+	spin_lock(&group->lock);
+	if (group->nr_tasks < 0)  /* The groups is being destroyed. */
+		goto out_group;
+
+	if (!utd) {
+		utd = kzalloc(sizeof(struct umcg_task_data), GFP_NOWAIT);
+		if (!utd) {
+			spin_unlock(&group->lock);
+			rcu_read_unlock();
+
+			utd = kzalloc(sizeof(struct umcg_task_data), GFP_KERNEL);
+			if (!utd) {
+				ret = -ENOMEM;
+				goto out;
+			}
+
+			goto retry_once;
+		}
+	}
+
+	utd->self = current;
+	utd->group = group;
+	utd->umcg_task = umcg_task;
+	utd->task_type = task_type;
+	utd->api_version = api_version;
+	RCU_INIT_POINTER(utd->peer, NULL);
+
+	INIT_LIST_HEAD(&utd->list);
+	group->nr_tasks++;
+
+	task_lock(current);
+	rcu_assign_pointer(current->umcg_task_data, utd);
+	task_unlock(current);
+
+	ret = 0;
+
+out_group:
+	spin_unlock(&group->lock);
+
+out_rcu:
+	rcu_read_unlock();
+	if (ret && utd)
+		kfree(utd);
+
+out:
+	if (ret)
+		put_state(umcg_task, UMCG_TASK_NONE);
+	else
+		schedule();  /* Trigger umcg_on_wake(). */
+
+	return ret;
+}
+
+static int register_worker(u32 api_version, u32 group_id,
+		struct umcg_task __user *umcg_task)
+{
+	return add_task_to_group(api_version, group_id, umcg_task,
+				UMCG_TT_WORKER, UMCG_TASK_UNBLOCKED);
+}
+
+static int register_server(u32 api_version, u32 group_id,
+		struct umcg_task __user *umcg_task)
+{
+	return add_task_to_group(api_version, group_id, umcg_task,
+				UMCG_TT_SERVER, UMCG_TASK_PROCESSING);
+}
+
 /**
  * sys_umcg_register_task - register the current task as a UMCG task.
  * @api_version:       The expected/desired API version of the syscall.
@@ -122,6 +300,10 @@ SYSCALL_DEFINE4(umcg_register_task, u32, api_version, u32, flags, u32, group_id,
 		if (group_id != UMCG_NOID)
 			return -EINVAL;
 		return register_core_task(api_version, umcg_task);
+	case UMCG_REGISTER_WORKER:
+		return register_worker(api_version, group_id, umcg_task);
+	case UMCG_REGISTER_SERVER:
+		return register_server(api_version, group_id, umcg_task);
 	default:
 		return -EINVAL;
 	}
@@ -146,9 +328,39 @@ SYSCALL_DEFINE1(umcg_unregister_task, u32, flags)
 	if (!utd || flags)
 		goto out;
 
+	if (!utd->group) {
+		ret = 0;
+		goto out;
+	}
+
+	if (utd->task_type == UMCG_TT_WORKER) {
+		struct task_struct *server = rcu_dereference(utd->peer);
+
+		if (server) {
+			umcg_detach_peer();
+			if (WARN_ON(!wake_up_process(server))) {
+				umcg_segv(0);
+				goto out;
+			}
+		}
+	} else {
+		if (WARN_ON(utd->task_type != UMCG_TT_SERVER)) {
+			umcg_segv(0);
+			goto out;
+		}
+
+		umcg_detach_peer();
+	}
+
+	spin_lock(&utd->group->lock);
 	task_lock(current);
+
 	rcu_assign_pointer(current->umcg_task_data, NULL);
+
+	--utd->group->nr_tasks;
+
 	task_unlock(current);
+	spin_unlock(&utd->group->lock);
 
 	ret = 0;
 
@@ -164,6 +376,7 @@ SYSCALL_DEFINE1(umcg_unregister_task, u32, flags)
 static int do_context_switch(struct task_struct *next)
 {
 	struct umcg_task_data *utd = rcu_access_pointer(current->umcg_task_data);
+	bool prev_wait_flag;  /* See comment in do_wait() below. */
 
 	/*
 	 * It is important to set_current_state(TASK_INTERRUPTIBLE) before
@@ -173,34 +386,51 @@ static int do_context_switch(struct task_struct *next)
 	 */
 	set_current_state(TASK_INTERRUPTIBLE);
 
-	WRITE_ONCE(utd->in_wait, true);
-
+	prev_wait_flag = utd->in_wait;
+	if (!prev_wait_flag)
+		WRITE_ONCE(utd->in_wait, true);
+	
 	if (!try_to_wake_up(next, TASK_NORMAL, WF_CURRENT_CPU))
 		return -EAGAIN;
 
 	freezable_schedule();
 
-	WRITE_ONCE(utd->in_wait, false);
+	if (!prev_wait_flag)
+		WRITE_ONCE(utd->in_wait, false);
 
 	if (signal_pending(current))
 		return -EINTR;
 
+	/* TODO: deal with non-fatal interrupts. */
 	return 0;
 }
 
 static int do_wait(void)
 {
 	struct umcg_task_data *utd = rcu_access_pointer(current->umcg_task_data);
+	/*
+	 * freezable_schedule() below can recursively call do_wait() if
+	 * this is a worker that needs a server. As the wait flag is only
+	 * used by the outermost wait/wake (and swap) syscalls, modify it only
+	 * in the outermost do_wait() instead of using a counter.
+	 *
+	 * Note that the nesting level is at most two, as utd->in_workqueue
+	 * is used to prevent further nesting.
+	 */
+	bool prev_wait_flag;
 
 	if (!utd)
 		return -EINVAL;
 
-	WRITE_ONCE(utd->in_wait, true);
+	prev_wait_flag = utd->in_wait;
+	if (!prev_wait_flag)
+		WRITE_ONCE(utd->in_wait, true);
 
 	set_current_state(TASK_INTERRUPTIBLE);
 	freezable_schedule();
 
-	WRITE_ONCE(utd->in_wait, false);
+	if (!prev_wait_flag)
+		WRITE_ONCE(utd->in_wait, false);
 
 	if (signal_pending(current))
 		return -EINTR;
@@ -214,7 +444,7 @@ static int do_wait(void)
  * @timeout:       The absolute timeout of the wait. Not supported yet.
  *                 Must be NULL.
  *
- * Sleep until woken, interrupted, or @timeout expires.
+ * Sleep until woken or @timeout expires.
  *
  * Return:
  * 0           - Ok;
@@ -229,6 +459,7 @@ SYSCALL_DEFINE2(umcg_wait, u32, flags,
 		const struct __kernel_timespec __user *, timeout)
 {
 	struct umcg_task_data *utd;
+	struct task_struct *server = NULL;
 
 	if (flags)
 		return -EINVAL;
@@ -242,8 +473,14 @@ SYSCALL_DEFINE2(umcg_wait, u32, flags,
 		return -EINVAL;
 	}
 
+	if (utd->task_type == UMCG_TT_WORKER)
+		server = rcu_dereference(utd->peer);
+
 	rcu_read_unlock();
 
+	if (server)
+		return do_context_switch(server);
+
 	return do_wait();
 }
 
@@ -252,7 +489,7 @@ SYSCALL_DEFINE2(umcg_wait, u32, flags,
  * @flags:         Reserved.
  * @next_tid:      The ID of the task to wake.
  *
- * Wake @next identified by @next_tid. @next must be either a UMCG core
+ * Wake task next identified by @next_tid. @next must be either a UMCG core
  * task or a UMCG worker task.
  *
  * Return:
@@ -265,7 +502,7 @@ SYSCALL_DEFINE2(umcg_wait, u32, flags,
 SYSCALL_DEFINE2(umcg_wake, u32, flags, u32, next_tid)
 {
 	struct umcg_task_data *next_utd;
-	struct task_struct *next;
+	struct task_struct *next, *next_peer;
 	int ret = -EINVAL;
 
 	if (!next_tid)
@@ -282,11 +519,29 @@ SYSCALL_DEFINE2(umcg_wake, u32, flags, u32, next_tid)
 	if (!next_utd)
 		goto out;
 
+	if (next_utd->task_type == UMCG_TT_SERVER)
+		goto out;
+
 	if (!READ_ONCE(next_utd->in_wait)) {
 		ret = -EAGAIN;
 		goto out;
 	}
 
+	next_peer = rcu_dereference(next_utd->peer);
+	if (next_peer) {
+		if (next_peer == current)
+			umcg_detach_peer();
+		else {
+			/*
+			 * Waking a worker with an assigned server is not
+			 * permitted, unless the waking is done by the assigned
+			 * server.
+			 */
+			umcg_segv(0);
+			goto out;
+		}
+	}
+
 	ret = wake_up_process(next);
 	put_task_struct(next);
 	if (ret)
@@ -348,7 +603,7 @@ SYSCALL_DEFINE4(umcg_swap, u32, wake_flags, u32, next_tid, u32, wait_flags,
 	}
 
 	next_utd = rcu_dereference(next->umcg_task_data);
-	if (!next_utd) {
+	if (!next_utd || next_utd->group != curr_utd->group) {
 		ret = -EINVAL;
 		goto out;
 	}
@@ -358,6 +613,25 @@ SYSCALL_DEFINE4(umcg_swap, u32, wake_flags, u32, next_tid, u32, wait_flags,
 		goto out;
 	}
 
+	/* Move the server from curr to next, if appropriate. */
+	if (curr_utd->task_type == UMCG_TT_WORKER) {
+		struct task_struct *server = rcu_dereference(curr_utd->peer);
+		if (server) {
+			struct umcg_task_data *server_utd =
+				rcu_dereference(server->umcg_task_data);
+
+			if (rcu_access_pointer(next_utd->peer)) {
+				ret = -EAGAIN;
+				goto out;
+			}
+			umcg_detach_peer();
+			umcg_lock_pair(server, next);
+			rcu_assign_pointer(server_utd->peer, next);
+			rcu_assign_pointer(next_utd->peer, server);
+			umcg_unlock_pair(server, next);
+		}
+	}
+
 	rcu_read_unlock();
 
 	return do_context_switch(next);
@@ -366,3 +640,475 @@ SYSCALL_DEFINE4(umcg_swap, u32, wake_flags, u32, next_tid, u32, wait_flags,
 	rcu_read_unlock();
 	return ret;
 }
+
+/**
+ * sys_umcg_create_group - create a UMCG group
+ * @api_version:           Requested API version.
+ * @flags:                 Reserved.
+ *
+ * Return:
+ * >= 0                - the group ID
+ * -EOPNOTSUPP         - @api_version is not supported
+ * -EINVAL             - @flags is not valid
+ * -ENOMEM             - not enough memory
+ */
+SYSCALL_DEFINE2(umcg_create_group, u32, api_version, u64, flags)
+{
+	int ret;
+	struct umcg_group *group;
+	struct umcg_group *list_entry;
+	struct mm_struct *mm = current->mm;
+
+	if (flags)
+		return -EINVAL;
+
+	if (__api_version(api_version))
+		return -EOPNOTSUPP;
+
+	group = kzalloc(sizeof(struct umcg_group), GFP_KERNEL);
+	if (!group)
+		return -ENOMEM;
+
+	spin_lock_init(&group->lock);
+	INIT_LIST_HEAD(&group->list);
+	INIT_LIST_HEAD(&group->waiters);
+	group->flags = flags;
+	group->api_version = api_version;
+
+	spin_lock(&mm->umcg_lock);
+
+	list_for_each_entry_rcu(list_entry, &mm->umcg_groups, list) {
+		if (list_entry->group_id >= group->group_id)
+			group->group_id = list_entry->group_id + 1;
+	}
+
+	list_add_rcu(&mm->umcg_groups, &group->list);
+
+	ret = group->group_id;
+	spin_unlock(&mm->umcg_lock);
+
+	return ret;
+}
+
+/**
+ * sys_umcg_destroy_group - destroy a UMCG group
+ * @group_id: The ID of the group to destroy.
+ *
+ * The group must be empty, i.e. have no registered servers or workers.
+ *
+ * Return:
+ * 0       - success;
+ * -ESRCH  - group not found;
+ * -EBUSY  - the group has registered workers or servers.
+ */
+SYSCALL_DEFINE1(umcg_destroy_group, u32, group_id)
+{
+	int ret = 0;
+	struct umcg_group *group = NULL;
+	struct umcg_group *list_entry;
+	struct mm_struct *mm = current->mm;
+
+	spin_lock(&mm->umcg_lock);
+	list_for_each_entry_rcu(list_entry, &mm->umcg_groups, list) {
+		if (list_entry->group_id == group_id) {
+			group = list_entry;
+			break;
+		}
+	}
+
+	if (group == NULL) {
+		ret = -ESRCH;
+		goto out;
+	}
+
+	spin_lock(&group->lock);
+
+	if (group->nr_tasks > 0) {
+		ret = -EBUSY;
+		spin_unlock(&group->lock);
+		goto out;
+	}
+
+	/* Tell group rcu readers that the group is going to be deleted. */
+	group->nr_tasks = -1;
+
+	spin_unlock(&group->lock);
+
+	list_del_rcu(&group->list);
+	kfree_rcu(group, rcu);
+
+out:
+	spin_unlock(&mm->umcg_lock);
+	return ret;
+}
+
+/**
+ * sys_umcg_poll_worker - poll an UNBLOCKED worker
+ * @flags: reserved;
+ * @ut:    the control struct umcg_task of the polled worker.
+ *
+ * The current task must be a UMCG server in POLLING state; if there are
+ * UNBLOCKED workers in the server's group, take the earliest queued,
+ * mark the worker as RUNNABLE.and return.
+ *
+ * If there are no unblocked workers, the syscall waits for one to become
+ * available.
+ *
+ * Return:
+ * 0       - Ok;
+ * -EINTR  - a signal was received;
+ * -EINVAL - one of the parameters is wrong, or a precondition was not met.
+ */
+SYSCALL_DEFINE2(umcg_poll_worker, u32, flags, struct umcg_task __user **, ut)
+{
+	struct umcg_group *group;
+	struct task_struct *worker;
+	struct task_struct *server = current;
+	struct umcg_task __user *result;
+	struct umcg_task_data *worker_utd, *server_utd;
+
+	if (flags)
+		return -EINVAL;
+
+	rcu_read_lock();
+
+	server_utd = rcu_dereference(server->umcg_task_data);
+
+	if (!server_utd || server_utd->task_type != UMCG_TT_SERVER) {
+		rcu_read_unlock();
+		return -EINVAL;
+	}
+
+	umcg_detach_peer();
+
+	group = server_utd->group;
+
+	spin_lock(&group->lock);
+
+	if (group->nr_waiting_workers == 0) {  /* Queue the server. */
+		++group->nr_waiting_pollers;
+		list_add_tail(&server_utd->list, &group->waiters);
+		set_current_state(TASK_INTERRUPTIBLE);
+		spin_unlock(&group->lock);
+		rcu_read_unlock();
+
+		freezable_schedule();
+
+		rcu_read_lock();
+		server_utd = rcu_dereference(server->umcg_task_data);
+
+		if (!list_empty(&server_utd->list)) {
+			spin_lock(&group->lock);
+			list_del_init(&server_utd->list);
+			--group->nr_waiting_pollers;
+			spin_unlock(&group->lock);
+		}
+
+		if (signal_pending(current)) {
+			rcu_read_unlock();
+			return -EINTR;
+		}
+
+		worker = rcu_dereference(server_utd->peer);
+		if (worker) {
+			worker_utd = rcu_dereference(worker->umcg_task_data);
+			result = worker_utd->umcg_task;
+		} else
+			result = NULL;
+
+		rcu_read_unlock();
+
+		if (put_user(result, ut))
+			return umcg_segv(-EFAULT);
+		return 0;
+	}
+
+	/* Pick up the first worker. */
+	worker_utd = list_first_entry(&group->waiters, struct umcg_task_data,
+					list);
+	list_del_init(&worker_utd->list);
+	worker = worker_utd->self;
+	--group->nr_waiting_workers;
+
+	umcg_lock_pair(server, worker);
+	spin_unlock(&group->lock);
+
+	if (WARN_ON(rcu_access_pointer(server_utd->peer) ||
+			rcu_access_pointer(worker_utd->peer))) {
+		/* This is unexpected. */
+		rcu_read_unlock();
+		return umcg_segv(-EINVAL);
+	}
+	rcu_assign_pointer(server_utd->peer, worker);
+	rcu_assign_pointer(worker_utd->peer, current);
+
+	umcg_unlock_pair(server, worker);
+
+	result = worker_utd->umcg_task;
+	rcu_read_unlock();
+
+	if (put_state(result, UMCG_TASK_RUNNABLE))
+		return umcg_segv(-EFAULT);
+
+	if (put_user(result, ut))
+		return umcg_segv(-EFAULT);
+
+	return 0;
+}
+
+/**
+ * sys_umcg_run_worker - "run" a RUNNABLE worker as a server
+ * @flags:       reserved;
+ * @worker_tid:  tid of the worker to run;
+ * @ut:          the control struct umcg_task of the worker that blocked
+ *               during this "run".
+ *
+ * The worker must be in RUNNABLE state. The server (=current task)
+ * wakes the worker and blocks; when the worker, or one of the workers
+ * in umcg_swap chain, blocks, the server is woken and the syscall returns
+ * with ut indicating the blocked worker.
+ *
+ * If the worker exits or unregisters itself, the syscall succeeds with
+ * ut == NULL.
+ *
+ * Return:
+ * 0       - Ok;
+ * -EINTR  - a signal was received;
+ * -EINVAL - one of the parameters is wrong, or a precondition was not met.
+ */
+SYSCALL_DEFINE3(umcg_run_worker, u32, flags, u32, worker_tid,
+		struct umcg_task __user **, ut)
+{
+	int ret = -EINVAL;
+	struct task_struct *worker;
+	struct task_struct *server = current;
+	struct umcg_task __user *result = NULL;
+	struct umcg_task_data *worker_utd;
+	struct umcg_task_data *server_utd;
+	struct umcg_task __user *server_ut;
+	struct umcg_task __user *worker_ut;
+
+	if (!ut)
+		return -EINVAL;
+
+	rcu_read_lock();
+	server_utd = rcu_dereference(server->umcg_task_data);
+
+	if (!server_utd || server_utd->task_type != UMCG_TT_SERVER)
+		goto out_rcu;
+
+	if (flags)
+		goto out_rcu;
+
+	worker = find_get_task_by_vpid(worker_tid);
+	if (!worker) {
+		ret = -ESRCH;
+		goto out_rcu;
+	}
+
+	worker_utd = rcu_dereference(worker->umcg_task_data);
+	if (!worker_utd)
+		goto out_rcu;
+
+	if (!READ_ONCE(worker_utd->in_wait)) {
+		ret = -EAGAIN;
+		goto out_rcu;
+	}
+
+	if (server_utd->group != worker_utd->group)
+		goto out_rcu;
+
+	if (rcu_access_pointer(server_utd->peer) != worker)
+		umcg_detach_peer();
+
+	if (!rcu_access_pointer(server_utd->peer)) {
+		umcg_lock_pair(server, worker);
+		WARN_ON(worker_utd->peer);
+		rcu_assign_pointer(server_utd->peer, worker);
+		rcu_assign_pointer(worker_utd->peer, server);
+		umcg_unlock_pair(server, worker);
+	}
+
+	server_ut = server_utd->umcg_task;
+	worker_ut = server_utd->umcg_task;
+
+	rcu_read_unlock();
+
+	ret = do_context_switch(worker);
+	if (ret)
+		return ret;
+
+	rcu_read_lock();
+	worker = rcu_dereference(server_utd->peer);
+	if (worker) {
+		worker_utd = rcu_dereference(worker->umcg_task_data);
+		if (worker_utd)
+			result = worker_utd->umcg_task;
+	}
+	rcu_read_unlock();
+
+	if (put_user(result, ut))
+		return -EFAULT;
+	return 0;
+
+out_rcu:
+	rcu_read_unlock();
+	return ret;
+}
+
+void umcg_on_block(void)
+{
+	struct umcg_task_data *utd = rcu_access_pointer(current->umcg_task_data);
+	struct umcg_task __user *ut;
+	struct task_struct *server;
+	u32 state;
+
+	if (utd->task_type != UMCG_TT_WORKER || utd->in_workqueue)
+		return;
+
+	ut = utd->umcg_task;
+
+	if (get_user(state, (u32 __user *)ut)) {
+		if (signal_pending(current))
+			return;
+		umcg_segv(0);
+		return;
+	}
+
+	if (state != UMCG_TASK_RUNNING)
+		return;
+
+	state = UMCG_TASK_BLOCKED;
+	if (put_user(state, (u32 __user *)ut)) {
+		umcg_segv(0);
+		return;
+	}
+
+	rcu_read_lock();
+	server = rcu_dereference(utd->peer);
+	rcu_read_unlock();
+
+	if (server)
+		WARN_ON(!try_to_wake_up(server, TASK_NORMAL, WF_CURRENT_CPU));
+}
+
+/* Return true to return to the user, false to keep waiting. */
+static bool process_unblocked_worker(void)
+{
+	struct umcg_task_data *utd;
+	struct umcg_group *group;
+
+	rcu_read_lock();
+
+	utd = rcu_dereference(current->umcg_task_data);
+	group = utd->group;
+
+	spin_lock(&group->lock);
+	if (!list_empty(&utd->list)) {
+		/* This was a spurious wakeup or an interrupt, do nothing. */
+		spin_unlock(&group->lock);
+		rcu_read_unlock();
+		do_wait();
+		return false;
+	}
+
+	if (group->nr_waiting_pollers > 0) {  /* Wake a server. */
+		struct task_struct *server;
+		struct umcg_task_data *server_utd = list_first_entry(
+				&group->waiters, struct umcg_task_data, list);
+
+		list_del_init(&server_utd->list);
+		server = server_utd->self;
+		--group->nr_waiting_pollers;
+
+		umcg_lock_pair(server, current);
+		spin_unlock(&group->lock);
+
+		if (WARN_ON(server_utd->peer || utd->peer)) {
+			umcg_segv(0);
+			return true;
+		}
+		rcu_assign_pointer(server_utd->peer, current);
+		rcu_assign_pointer(utd->peer, server);
+
+		umcg_unlock_pair(server, current);
+		rcu_read_unlock();
+
+		if (put_state(utd->umcg_task, UMCG_TASK_RUNNABLE)) {
+			umcg_segv(0);
+			return true;
+		}
+
+		do_context_switch(server);
+		return false;
+	}
+
+	/* Add to the queue. */
+	++group->nr_waiting_workers;
+	list_add_tail(&utd->list, &group->waiters);
+	spin_unlock(&group->lock);
+	rcu_read_unlock();
+
+	do_wait();
+
+	smp_rmb();
+	if (!list_empty(&utd->list)) {
+		spin_lock(&group->lock);
+		list_del_init(&utd->list);
+		--group->nr_waiting_workers;
+		spin_unlock(&group->lock);
+	}
+
+	return false;
+}
+
+void umcg_on_wake(void)
+{
+	struct umcg_task_data *utd;
+	struct umcg_task __user *ut;
+	bool should_break = false;
+
+	/* current->umcg_task_data is modified only from current. */
+	utd = rcu_access_pointer(current->umcg_task_data);
+	if (utd->task_type != UMCG_TT_WORKER || utd->in_workqueue)
+		return;
+
+	do {
+		u32 state;
+
+		if (fatal_signal_pending(current))
+			return;
+
+		if (signal_pending(current))
+			return;
+
+		ut = utd->umcg_task;
+
+		if (get_state(ut, &state)) {
+			if (signal_pending(current))
+				return;
+			goto segv;
+		}
+
+		if (state == UMCG_TASK_RUNNING && rcu_access_pointer(utd->peer))
+			return;
+
+		if (state == UMCG_TASK_BLOCKED || state == UMCG_TASK_RUNNING) {
+			state = UMCG_TASK_UNBLOCKED;
+			if (put_state(ut, state))
+				goto segv;
+		} else if (state != UMCG_TASK_UNBLOCKED) {
+			goto segv;
+		}
+
+		utd->in_workqueue = true;
+		should_break = process_unblocked_worker();
+		utd->in_workqueue = false;
+		if (should_break)
+			return;
+
+	} while (!should_break);
+
+segv:
+	umcg_segv(0);
+}
diff --git a/kernel/sched/umcg.h b/kernel/sched/umcg.h
index 6791d570f622..92012a1674ab 100644
--- a/kernel/sched/umcg.h
+++ b/kernel/sched/umcg.h
@@ -8,6 +8,34 @@
 #include <linux/sched.h>
 #include <linux/umcg.h>
 
+struct umcg_group {
+	struct list_head list;
+	u32 group_id;     /* Never changes. */
+	u32 api_version;  /* Never changes. */
+	u64 flags;        /* Never changes. */
+
+	spinlock_t lock;
+
+	/*
+	 * One of the counters below is always zero. The non-zero counter
+	 * indicates the number of elements in @waiters below.
+	 */
+	int nr_waiting_workers;
+	int nr_waiting_pollers;
+
+	/*
+	 * The list below either contains UNBLOCKED workers waiting
+	 * for the userspace to poll or run them if nr_waiting_workers > 0,
+	 *  or polling servers waiting for unblocked workers if
+	 *  nr_waiting_pollers > 0.
+	 */
+	struct list_head waiters;
+
+	int nr_tasks;  /* The total number of tasks registered. */
+
+	struct rcu_head rcu;
+};
+
 enum umcg_task_type {
 	UMCG_TT_CORE	= 1,
 	UMCG_TT_SERVER	= 2,
@@ -32,11 +60,37 @@ struct umcg_task_data {
 	 */
 	u32 api_version;
 
+	/* NULL for core API tasks. Never changes. */
+	struct umcg_group		*group;
+
+	/*
+	 * If this is a server task, points to its assigned worker, if any;
+	 * if this is a worker task, points to its assigned server, if any.
+	 *
+	 * Protected by alloc_lock of the task owning this struct.
+	 *
+	 * Always either NULL, or the server and the worker point to each other.
+	 * Locking order: first lock the server, then the worker.
+	 *
+	 * Either the worker or the server should be the current task when
+	 * this field is changed, with the exception of sys_umcg_swap.
+	 */
+	struct task_struct __rcu	*peer;
+
+	/* Used in umcg_group.waiters. */
+	struct list_head		list;
+
+	/* Used by curr in umcg_on_block/wake to prevent nesting/recursion. */
+	bool				in_workqueue;
+
 	/*
 	 * Used by wait/wake routines to handle races. Written only by current.
 	 */
 	bool				in_wait;
 };
 
+void umcg_on_block(void);
+void umcg_on_wake(void);
+
 #endif  /* CONFIG_UMCG */
 #endif  /* _KERNEL_SCHED_UMCG_H */
diff --git a/mm/init-mm.c b/mm/init-mm.c
index 153162669f80..85e4a8ecfd91 100644
--- a/mm/init-mm.c
+++ b/mm/init-mm.c
@@ -36,6 +36,10 @@ struct mm_struct init_mm = {
 	.page_table_lock =  __SPIN_LOCK_UNLOCKED(init_mm.page_table_lock),
 	.arg_lock	=  __SPIN_LOCK_UNLOCKED(init_mm.arg_lock),
 	.mmlist		= LIST_HEAD_INIT(init_mm.mmlist),
+#ifdef CONFIG_UMCG
+	.umcg_lock	= __SPIN_LOCK_UNLOCKED(init_mm.umcg_lock),
+	.umcg_groups	= LIST_HEAD_INIT(init_mm.umcg_groups),
+#endif
 	.user_ns	= &init_user_ns,
 	.cpu_bitmap	= CPU_BITS_NONE,
 	INIT_MM_CONTEXT(init_mm)
-- 
2.31.1.818.g46aad6cb9e-goog




[Index of Archives]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Yosemite News]     [Linux Kernel]     [Linux SCSI]

  Powered by Linux