req->result = mask;
req->io_task_work.func = func;
@@ -5265,7 +5264,10 @@ static bool io_poll_rewait(struct io_kiocb *req, struct io_poll_iocb *poll)
spin_lock(&ctx->completion_lock);
if (!req->result && !READ_ONCE(poll->canceled)) {
- add_wait_queue(poll->head, &poll->wait);
+ if (req->opcode == IORING_OP_POLL_ADD)
+ WRITE_ONCE(poll->active, true);
+ else
+ add_wait_queue(poll->head, &poll->wait);
return true;
}
@@ -5331,6 +5333,26 @@ static bool __io_poll_complete(struct io_kiocb *req, __poll_t mask)
return !(flags & IORING_CQE_F_MORE);
}
+static bool __io_poll_remove_one(struct io_kiocb *req,
+ struct io_poll_iocb *poll, bool do_cancel)
+ __must_hold(&req->ctx->completion_lock)
+{
+ bool do_complete = false;
+
+ if (!poll->head)
+ return false;
+ spin_lock_irq(&poll->head->lock);
+ if (do_cancel)
+ WRITE_ONCE(poll->canceled, true);
+ if (!list_empty(&poll->wait.entry)) {
+ list_del_init(&poll->wait.entry);
+ do_complete = true;
+ }
+ spin_unlock_irq(&poll->head->lock);
+ hash_del(&req->hash_node);
+ return do_complete;
+}
+
static void io_poll_task_func(struct io_kiocb *req, bool *locked)
{
struct io_ring_ctx *ctx = req->ctx;
@@ -5348,11 +5370,12 @@ static void io_poll_task_func(struct io_kiocb *req, bool *locked)
done = __io_poll_complete(req, req->result);
if (done) {
io_poll_remove_double(req);
+ __io_poll_remove_one(req, io_poll_get_single(req), true);
hash_del(&req->hash_node);
req->poll.done = true;
} else {
req->result = 0;
- add_wait_queue(req->poll.head, &req->poll.wait);
+ WRITE_ONCE(req->poll.active, true);
}
io_commit_cqring(ctx);
spin_unlock(&ctx->completion_lock);
@@ -5407,6 +5430,7 @@ static void io_init_poll_iocb(struct io_poll_iocb *poll, __poll_t events,
poll->head = NULL;
poll->done = false;
poll->canceled = false;
+ poll->active = true;
#define IO_POLL_UNMASK (EPOLLERR|EPOLLHUP|EPOLLNVAL|EPOLLRDHUP)
/* mask in events that we always want/need */
poll->events = events | IO_POLL_UNMASK;
@@ -5513,6 +5537,7 @@ static int io_async_wake(struct wait_queue_entry *wait, unsigned mode, int sync,
if (mask && !(mask & poll->events))
return 0;
+ list_del_init(&poll->wait.entry);
return __io_async_wake(req, poll, mask, io_async_task_func);
}
@@ -5623,26 +5648,6 @@ static int io_arm_poll_handler(struct io_kiocb *req)
return IO_APOLL_OK;
}
-static bool __io_poll_remove_one(struct io_kiocb *req,
- struct io_poll_iocb *poll, bool do_cancel)
- __must_hold(&req->ctx->completion_lock)
-{
- bool do_complete = false;
-
- if (!poll->head)
- return false;
- spin_lock_irq(&poll->head->lock);
- if (do_cancel)
- WRITE_ONCE(poll->canceled, true);
- if (!list_empty(&poll->wait.entry)) {
- list_del_init(&poll->wait.entry);
- do_complete = true;
- }
- spin_unlock_irq(&poll->head->lock);
- hash_del(&req->hash_node);
- return do_complete;
-}
-
static bool io_poll_remove_one(struct io_kiocb *req)
__must_hold(&req->ctx->completion_lock)
{
@@ -5779,6 +5784,10 @@ static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync,
if (mask && !(mask & poll->events))
return 0;
+ if (!READ_ONCE(poll->active))
+ return 0;
+ WRITE_ONCE(poll->active, false);
+
return __io_async_wake(req, poll, mask, io_poll_task_func);
}