[RFC PATCH 5/9] ntsync: Introduce NTSYNC_IOC_WAIT_ANY.

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



This corresponds to part of the functionality of the NT syscall
NtWaitForMultipleObjects(). Specifically, it implements the behaviour where
the third argument (wait_any) is TRUE, and it does not handle alertable waits.
The wait-all case (wait_any is FALSE) and alertable waits have been split out
into separate patches to ease review.

Signed-off-by: Elizabeth Figura <zfigura@xxxxxxxxxxxxxxx>
---
 drivers/misc/ntsync.c       | 229 ++++++++++++++++++++++++++++++++++++
 include/uapi/linux/ntsync.h |  13 ++
 2 files changed, 242 insertions(+)

diff --git a/drivers/misc/ntsync.c b/drivers/misc/ntsync.c
index d1c91c2a4f1a..2e8d3c2d51a4 100644
--- a/drivers/misc/ntsync.c
+++ b/drivers/misc/ntsync.c
@@ -23,6 +23,8 @@ struct ntsync_obj {
 	struct kref refcount;
 	spinlock_t lock;
 
+	struct list_head any_waiters;
+
 	enum ntsync_type type;
 
 	/* The following fields are protected by the object lock. */
@@ -34,6 +36,28 @@ struct ntsync_obj {
 	} u;
 };
 
+struct ntsync_q_entry {
+	struct list_head node;	/* linked into obj->any_waiters while waiting */
+	struct ntsync_q *q;	/* the wait this entry belongs to */
+	struct ntsync_obj *obj;	/* object being waited on */
+	__u32 index;		/* position of this entry in q->entries[] */
+};
+
+struct ntsync_q {
+	struct task_struct *task;	/* the waiting task; woken on signal */
+	__u32 owner;			/* copied from ntsync_wait_args.owner */
+
+	/*
+	 * Protected via atomic_cmpxchg(). Only the thread that wins the
+	 * compare-and-swap may actually change object states and wake this
+	 * task.
+	 */
+	atomic_t signaled;		/* -1 until claimed, then entry index */
+
+	__u32 count;			/* number of elements in entries[] */
+	struct ntsync_q_entry entries[];
+};
+
 struct ntsync_device {
 	struct xarray objects;
 };
@@ -109,6 +133,26 @@ static void init_obj(struct ntsync_obj *obj)
 {
 	kref_init(&obj->refcount);
 	spin_lock_init(&obj->lock);
+	INIT_LIST_HEAD(&obj->any_waiters);
+}
+
+static void try_wake_any_sem(struct ntsync_obj *sem)
+{
+	struct ntsync_q_entry *entry;
+
+	lockdep_assert_held(&sem->lock);	/* caller must hold sem->lock */
+
+	list_for_each_entry(entry, &sem->any_waiters, node) {
+		struct ntsync_q *q = entry->q;
+
+		if (!sem->u.sem.count)	/* nothing left to hand out */
+			break;
+
+		if (atomic_cmpxchg(&q->signaled, -1, entry->index) == -1) {
+			sem->u.sem.count--;	/* this waiter takes one count */
+			wake_up_process(q->task);
+		}
+	}
+}
 
 static int ntsync_create_sem(struct ntsync_device *dev, void __user *argp)
@@ -194,6 +238,8 @@ static int ntsync_put_sem(struct ntsync_device *dev, void __user *argp)
 
 	prev_count = sem->u.sem.count;
 	ret = put_sem_state(sem, args.count);
+	if (!ret)
+		try_wake_any_sem(sem);
 
 	spin_unlock(&sem->lock);
 
@@ -205,6 +251,187 @@ static int ntsync_put_sem(struct ntsync_device *dev, void __user *argp)
 	return ret;
 }
 
+static int ntsync_schedule(const struct ntsync_q *q, ktime_t *timeout)
+{
+	int ret = 0;
+
+	do {
+		if (signal_pending(current)) {
+			ret = -ERESTARTSYS;
+			break;
+		}
+
+		set_current_state(TASK_INTERRUPTIBLE);	/* before testing signaled */
+		if (atomic_read(&q->signaled) != -1) {	/* already claimed */
+			ret = 0;
+			break;
+		}
+		ret = schedule_hrtimeout(timeout, HRTIMER_MODE_ABS);
+	} while (ret < 0);	/* negative: woken before expiry, so re-check */
+	__set_current_state(TASK_RUNNING);
+
+	return ret;
+}
+
+/*
+ * Allocate and initialize the ntsync_q structure, but do not queue us yet.
+ * Also convert the user-supplied timespec, if any, to a ktime_t value.
+ */
+static int setup_wait(struct ntsync_device *dev,
+		      const struct ntsync_wait_args *args,
+		      ktime_t *ret_timeout, struct ntsync_q **ret_q)
+{
+	const __u32 count = args->count;
+	struct ntsync_q *q;
+	ktime_t timeout = 0;
+	__u32 *ids;
+	__u32 i, j;
+
+	if (!args->owner || args->pad)	/* owner must be set, pad clear */
+		return -EINVAL;
+
+	if (args->count > NTSYNC_MAX_WAIT_COUNT)
+		return -EINVAL;
+
+	if (args->timeout) {	/* zero means no timeout was supplied */
+		struct timespec64 to;
+
+		if (get_timespec64(&to, u64_to_user_ptr(args->timeout)))
+			return -EFAULT;
+		if (!timespec64_valid(&to))
+			return -EINVAL;
+
+		timeout = timespec64_to_ns(&to);
+	}
+
+	ids = kmalloc_array(count, sizeof(*ids), GFP_KERNEL);
+	if (!ids)
+		return -ENOMEM;
+	if (copy_from_user(ids, u64_to_user_ptr(args->objs),
+			   array_size(count, sizeof(*ids)))) {
+		kfree(ids);
+		return -EFAULT;
+	}
+
+	q = kmalloc(struct_size(q, entries, count), GFP_KERNEL);
+	if (!q) {
+		kfree(ids);
+		return -ENOMEM;
+	}
+	q->task = current;
+	q->owner = args->owner;
+	atomic_set(&q->signaled, -1);	/* -1 marks "not yet signaled" */
+	q->count = count;
+
+	for (i = 0; i < count; i++) {
+		struct ntsync_q_entry *entry = &q->entries[i];
+		struct ntsync_obj *obj = get_obj(dev, ids[i]);
+
+		if (!obj)	/* get_obj() failed for this id */
+			goto err;
+
+		entry->obj = obj;
+		entry->q = q;
+		entry->index = i;
+	}
+
+	kfree(ids);	/* ids were only needed to look the objects up */
+
+	*ret_q = q;
+	*ret_timeout = timeout;
+	return 0;
+
+err:
+	for (j = 0; j < i; j++)	/* undo get_obj() for entries [0, i) only */
+		put_obj(q->entries[j].obj);
+	kfree(ids);
+	kfree(q);
+	return -EINVAL;
+}
+
+static void try_wake_any_obj(struct ntsync_obj *obj)
+{
+	switch (obj->type) {	/* any type not listed below is a no-op */
+	case NTSYNC_TYPE_SEM:
+		try_wake_any_sem(obj);
+		break;
+	}
+}
+
+static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
+{
+	struct ntsync_wait_args args;
+	struct ntsync_q *q;
+	ktime_t timeout;
+	int signaled;
+	__u32 i;
+	int ret;
+
+	if (copy_from_user(&args, argp, sizeof(args)))
+		return -EFAULT;
+
+	ret = setup_wait(dev, &args, &timeout, &q);
+	if (ret < 0)
+		return ret;
+
+	/* queue ourselves on the any_waiters list of every object */
+
+	for (i = 0; i < args.count; i++) {
+		struct ntsync_q_entry *entry = &q->entries[i];
+		struct ntsync_obj *obj = entry->obj;
+
+		spin_lock(&obj->lock);
+		list_add_tail(&entry->node, &obj->any_waiters);
+		spin_unlock(&obj->lock);
+	}
+
+	/* try each object in order, stopping once one has signaled us */
+
+	for (i = 0; i < args.count; i++) {
+		struct ntsync_obj *obj = q->entries[i].obj;
+
+		if (atomic_read(&q->signaled) != -1)
+			break;
+
+		spin_lock(&obj->lock);
+		try_wake_any_obj(obj);
+		spin_unlock(&obj->lock);
+	}
+
+	/* sleep until signaled, timed out, or interrupted by a signal */
+
+	ret = ntsync_schedule(q, args.timeout ? &timeout : NULL);
+
+	/* and finally, unqueue and put_obj() every object we looked up */
+
+	for (i = 0; i < args.count; i++) {
+		struct ntsync_q_entry *entry = &q->entries[i];
+		struct ntsync_obj *obj = entry->obj;
+
+		spin_lock(&obj->lock);
+		list_del(&entry->node);
+		spin_unlock(&obj->lock);
+
+		put_obj(obj);
+	}
+
+	signaled = atomic_read(&q->signaled);	/* winning index, or -1 */
+	if (signaled != -1) {
+		struct ntsync_wait_args __user *user_args = argp;
+
+		/* even if we caught a signal, we need to communicate success */
+		ret = 0;
+
+		if (put_user(signaled, &user_args->index))
+			ret = -EFAULT;
+	} else if (!ret) {	/* not signaled and no error: timer expired */
+		ret = -ETIMEDOUT;
+	}
+
+	kfree(q);
+	return ret;
+}
+
 static long ntsync_char_ioctl(struct file *file, unsigned int cmd,
 			      unsigned long parm)
 {
@@ -218,6 +445,8 @@ static long ntsync_char_ioctl(struct file *file, unsigned int cmd,
 		return ntsync_delete(dev, argp);
 	case NTSYNC_IOC_PUT_SEM:
 		return ntsync_put_sem(dev, argp);
+	case NTSYNC_IOC_WAIT_ANY:
+		return ntsync_wait_any(dev, argp);
 	default:
 		return -ENOIOCTLCMD;
 	}
diff --git a/include/uapi/linux/ntsync.h b/include/uapi/linux/ntsync.h
index 8c610d65f8ef..10f07da7864e 100644
--- a/include/uapi/linux/ntsync.h
+++ b/include/uapi/linux/ntsync.h
@@ -16,6 +16,17 @@ struct ntsync_sem_args {
 	__u32 max;
 };
 
+struct ntsync_wait_args {
+	__u64 timeout;	/* user pointer to the timeout timespec, or 0 for none */
+	__u64 objs;	/* user pointer to an array of count __u32 object ids */
+	__u32 count;	/* number of ids in objs; <= NTSYNC_MAX_WAIT_COUNT */
+	__u32 owner;	/* must be nonzero */
+	__u32 index;	/* output: index of the object that was signaled */
+	__u32 pad;	/* must be zero */
+};
+
+#define NTSYNC_MAX_WAIT_COUNT 64	/* largest accepted wait_args.count */
+
 #define NTSYNC_IOC_BASE 0xf7
 
 #define NTSYNC_IOC_CREATE_SEM		_IOWR(NTSYNC_IOC_BASE, 0, \
@@ -23,5 +34,7 @@ struct ntsync_sem_args {
 #define NTSYNC_IOC_DELETE		_IOW (NTSYNC_IOC_BASE, 1, __u32)
 #define NTSYNC_IOC_PUT_SEM		_IOWR(NTSYNC_IOC_BASE, 2, \
 					      struct ntsync_sem_args)
+#define NTSYNC_IOC_WAIT_ANY		_IOWR(NTSYNC_IOC_BASE, 3, \
+					      struct ntsync_wait_args)
 
 #endif
-- 
2.43.0





[Index of Archives]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Yosemite News]     [Linux Kernel]     [Linux SCSI]

  Powered by Linux