Currently there is no way to get the total length of the data queued in ingress_msg. Introduce sk_msg_queue_len(), backed by a new psock->msg_len counter, to report it. Signed-off-by: Pengcheng Yang <yangpc@xxxxxxxxxx> --- include/linux/skmsg.h | 24 ++++++++++++++++++++++-- net/core/skmsg.c | 4 ++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index c1637515a8a4..3023a573859d 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -82,6 +82,7 @@ struct sk_psock { u32 apply_bytes; u32 cork_bytes; u32 eval; + u32 msg_len; bool redir_ingress; /* undefined if sk_redir is null */ struct sk_msg *cork; struct sk_psock_progs progs; @@ -131,6 +132,11 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg, int len, int flags); bool sk_msg_is_readable(struct sock *sk); +static inline void sk_msg_queue_consumed(struct sk_psock *psock, u32 len) +{ + psock->msg_len -= len; +} + static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes) { WARN_ON(i == msg->sg.end && bytes); @@ -311,9 +317,10 @@ static inline void sk_psock_queue_msg(struct sk_psock *psock, struct sk_msg *msg) { spin_lock_bh(&psock->ingress_lock); - if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) + if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) { list_add_tail(&msg->list, &psock->ingress_msg); - else { + psock->msg_len += msg->sg.size; + } else { sk_msg_free(psock->sk, msg); kfree(msg); } @@ -368,6 +375,19 @@ static inline void kfree_sk_msg(struct sk_msg *msg) kfree(msg); } +static inline u32 sk_msg_queue_len(struct sock *sk) +{ + struct sk_psock *psock; + u32 len = 0; + + rcu_read_lock(); + psock = sk_psock(sk); + if (psock) + len = psock->msg_len; + rcu_read_unlock(); + return len; +} + static inline void sk_psock_report_error(struct sk_psock *psock, int err) { struct sock *sk = psock->sk; diff --git a/net/core/skmsg.c b/net/core/skmsg.c index 6c31eefbd777..b3de17e99b67 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -481,6 +481,8 @@ int 
sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg, msg_rx = sk_psock_peek_msg(psock); } out: + if (likely(!peek) && copied > 0) + sk_msg_queue_consumed(psock, copied); return copied; } EXPORT_SYMBOL_GPL(sk_msg_recvmsg); @@ -771,9 +773,11 @@ static void __sk_psock_purge_ingress_msg(struct sk_psock *psock) list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) { list_del(&msg->list); + sk_msg_queue_consumed(psock, msg->sg.size); sk_msg_free(psock->sk, msg); kfree(msg); } + WARN_ON_ONCE(psock->msg_len != 0); } static void __sk_psock_zap_ingress(struct sk_psock *psock) -- 2.38.1