Currently msg is queued in ingress_msg of the target psock on ingress
redirect, without incrementing rcv_nxt. The size that the user can read
includes the data in receive_queue and ingress_msg. So we introduce the
sk_msg_queue_len() helper to get the data length in ingress_msg.

Note that the msg_len does not include the data length of msg from
receive_queue via SK_PASS, as they increment rcv_nxt when received.

Signed-off-by: Pengcheng Yang <yangpc@xxxxxxxxxx>
---
 include/linux/skmsg.h | 26 ++++++++++++++++++++++++--
 net/core/skmsg.c      | 10 +++++++++-
 2 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index c1637515a8a4..423a5c28c606 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -47,6 +47,7 @@ struct sk_msg {
 	u32 apply_bytes;
 	u32 cork_bytes;
 	u32 flags;
+	bool ingress_self;
 	struct sk_buff *skb;
 	struct sock *sk_redir;
 	struct sock *sk;
@@ -82,6 +83,7 @@ struct sk_psock {
 	u32 apply_bytes;
 	u32 cork_bytes;
 	u32 eval;
+	u32 msg_len;
 	bool redir_ingress; /* undefined if sk_redir is null */
 	struct sk_msg *cork;
 	struct sk_psock_progs progs;
@@ -311,9 +313,11 @@ static inline void sk_psock_queue_msg(struct sk_psock *psock,
 			       struct sk_msg *msg)
 {
 	spin_lock_bh(&psock->ingress_lock);
-	if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
+	if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
 		list_add_tail(&msg->list, &psock->ingress_msg);
-	else {
+		if (!msg->ingress_self)
+			WRITE_ONCE(psock->msg_len, psock->msg_len + msg->sg.size);
+	} else {
 		sk_msg_free(psock->sk, msg);
 		kfree(msg);
 	}
@@ -368,6 +372,24 @@ static inline void kfree_sk_msg(struct sk_msg *msg)
 	kfree(msg);
 }
 
+static inline void sk_msg_queue_consumed(struct sk_psock *psock, u32 len)
+{
+	WRITE_ONCE(psock->msg_len, psock->msg_len - len);
+}
+
+static inline u32 sk_msg_queue_len(const struct sock *sk)
+{
+	struct sk_psock *psock;
+	u32 len = 0;
+
+	rcu_read_lock();
+	psock = sk_psock(sk);
+	if (psock)
+		len = READ_ONCE(psock->msg_len);
+	rcu_read_unlock();
+	return len;
+}
+
 static inline void sk_psock_report_error(struct sk_psock *psock, int err)
 {
 	struct sock *sk = psock->sk;
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index 6c31eefbd777..f46732a8ddc2 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -415,7 +415,7 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
 	struct iov_iter *iter = &msg->msg_iter;
 	int peek = flags & MSG_PEEK;
 	struct sk_msg *msg_rx;
-	int i, copied = 0;
+	int i, copied = 0, msg_copied = 0;
 
 	msg_rx = sk_psock_peek_msg(psock);
 	while (copied != len) {
@@ -441,6 +441,8 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
 			}
 
 			copied += copy;
+			if (!msg_rx->ingress_self)
+				msg_copied += copy;
 			if (likely(!peek)) {
 				sge->offset += copy;
 				sge->length -= copy;
@@ -481,6 +483,8 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
 		msg_rx = sk_psock_peek_msg(psock);
 	}
 out:
+	if (likely(!peek) && msg_copied)
+		sk_msg_queue_consumed(psock, msg_copied);
 	return copied;
 }
 EXPORT_SYMBOL_GPL(sk_msg_recvmsg);
@@ -602,6 +606,7 @@ static int sk_psock_skb_ingress_self(struct sk_psock *psock, struct sk_buff *skb
 
 	if (unlikely(!msg))
 		return -EAGAIN;
+	msg->ingress_self = true;
 	skb_set_owner_r(skb, sk);
 	err = sk_psock_skb_ingress_enqueue(skb, off, len, psock, sk, msg);
 	if (err < 0)
@@ -771,9 +776,12 @@ static void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
 
 	list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
 		list_del(&msg->list);
+		if (!msg->ingress_self)
+			sk_msg_queue_consumed(psock, msg->sg.size);
 		sk_msg_free(psock->sk, msg);
 		kfree(msg);
 	}
+	WARN_ON_ONCE(READ_ONCE(psock->msg_len) != 0);
 }
 
 static void __sk_psock_zap_ingress(struct sk_psock *psock)
-- 
2.38.1