We've already had one bug with ->poll_refs overflows, so add overflow
checks for it in a similar way as we do for request refs. For that,
reserve the sign bit so underflows don't set IO_POLL_CANCEL_FLAG, which
lets us catch them.

Signed-off-by: Pavel Begunkov <asml.silence@xxxxxxxxx>
---
 fs/io_uring.c | 28 +++++++++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/fs/io_uring.c b/fs/io_uring.c
index 245610494c3e..594ed8bc4585 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -5803,8 +5803,13 @@ struct io_poll_table {
 	int error;
 };

-#define IO_POLL_CANCEL_FLAG	BIT(31)
-#define IO_POLL_REF_MASK	GENMASK(30, 0)
+/* keep the sign bit unused to improve overflow detection */
+#define IO_POLL_CANCEL_FLAG	BIT(30)
+#define IO_POLL_REF_MASK	GENMASK(29, 0)
+
+/* 2^16 is chosen arbitrarily, it would be funky to have more than that */
+#define io_poll_ref_check_overflow(refs)	((unsigned int)(refs) >= 65536u)
+#define io_poll_ref_check_underflow(refs)	((int)(refs) < 0)

 /*
  * If refs part of ->poll_refs (see IO_POLL_REF_MASK) is 0, it's free. We can
@@ -5814,7 +5819,18 @@ struct io_poll_table {
  */
 static inline bool io_poll_get_ownership(struct io_kiocb *req)
 {
-	return !(atomic_fetch_inc(&req->poll_refs) & IO_POLL_REF_MASK);
+	int ret = atomic_fetch_inc(&req->poll_refs) & IO_POLL_REF_MASK;
+
+	WARN_ON_ONCE(io_poll_ref_check_overflow(ret));
+	return !ret;
+}
+
+static inline int io_poll_put_ownership(struct io_kiocb *req, int nr)
+{
+	int ret = atomic_sub_return(nr, &req->poll_refs);
+
+	WARN_ON_ONCE(io_poll_ref_check_underflow(ret));
+	return ret;
 }

 static void io_poll_mark_cancelled(struct io_kiocb *req)
@@ -5956,7 +5972,7 @@ static int io_poll_check_events(struct io_kiocb *req)
 		 * Release all references, retry if someone tried to restart
 		 * task_work while we were executing it.
 		 */
-	} while (atomic_sub_return(v & IO_POLL_REF_MASK, &req->poll_refs));
+	} while (io_poll_put_ownership(req, v & IO_POLL_REF_MASK));

 	return 1;
 }
@@ -6157,7 +6173,6 @@ static int __io_arm_poll_handler(struct io_kiocb *req,
 				 struct io_poll_table *ipt, __poll_t mask)
 {
 	struct io_ring_ctx *ctx = req->ctx;
-	int v;

 	INIT_HLIST_NODE(&req->hash_node);
 	io_init_poll_iocb(poll, mask, io_poll_wake);
@@ -6204,8 +6219,7 @@ static int __io_arm_poll_handler(struct io_kiocb *req,
 	 * Release ownership. If someone tried to queue a tw while it was
 	 * locked, kick it off for them.
 	 */
-	v = atomic_dec_return(&req->poll_refs);
-	if (unlikely(v & IO_POLL_REF_MASK))
+	if (unlikely(io_poll_put_ownership(req, 1) & IO_POLL_REF_MASK))
 		__io_poll_execute(req, 0, poll->events);
 	return 0;
 }
-- 
2.35.1
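
P.S. Purely as an illustration of the bit layout the patch relies on, here is a
standalone userspace sketch (not kernel code; CANCEL_FLAG/REF_MASK, the helper
names and the test in main() are my own stand-ins that mirror the macros above):

#include <stdio.h>
#include <stdatomic.h>

#define CANCEL_FLAG	(1u << 30)		/* stand-in for IO_POLL_CANCEL_FLAG, i.e. BIT(30) */
#define REF_MASK	((1u << 30) - 1)	/* stand-in for IO_POLL_REF_MASK, i.e. GENMASK(29, 0) */

/* same shape of checks as the patch's macros, just outside the kernel */
static int check_overflow(unsigned int refs)	{ return refs >= 65536u; }
static int check_underflow(int refs)		{ return refs < 0; }

int main(void)
{
	atomic_int poll_refs = 1;	/* one ownership reference held */

	/*
	 * Dropping more refs than were taken drives the counter negative
	 * instead of wrapping into CANCEL_FLAG, so it stays detectable.
	 */
	int after_put = atomic_fetch_sub(&poll_refs, 2) - 2;
	printf("underflow caught: %d\n", check_underflow(after_put));

	/*
	 * A leak that keeps taking references trips the arbitrary 2^16 bound
	 * long before it could reach bit 30 and corrupt the cancel flag.
	 */
	atomic_store(&poll_refs, 70000);
	printf("overflow caught: %d\n",
	       check_overflow(atomic_load(&poll_refs) & REF_MASK));
	return 0;
}

Both printf calls report 1, showing that keeping the sign bit clear makes
underflows, and the 2^16 bound makes runaway overflows, observable well before
they could reach the cancel bit.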