On Tue, Nov 14, 2023 at 12:42 PM Pengcheng Yang <yangpc@xxxxxxxxxx> wrote:
>
> Currently we cannot get the data length in ingress_msg,
> we introduce sk_msg_queue_len() to do this.
>
> Signed-off-by: Pengcheng Yang <yangpc@xxxxxxxxxx>
> ---
>  include/linux/skmsg.h | 24 ++++++++++++++++++++++--
>  net/core/skmsg.c      |  4 ++++
>  2 files changed, 26 insertions(+), 2 deletions(-)
>
> diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
> index c1637515a8a4..3023a573859d 100644
> --- a/include/linux/skmsg.h
> +++ b/include/linux/skmsg.h
> @@ -82,6 +82,7 @@ struct sk_psock {
>  	u32 apply_bytes;
>  	u32 cork_bytes;
>  	u32 eval;
> +	u32 msg_len;
>  	bool redir_ingress; /* undefined if sk_redir is null */
>  	struct sk_msg *cork;
>  	struct sk_psock_progs progs;
> @@ -131,6 +132,11 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
>  		   int len, int flags);
>  bool sk_msg_is_readable(struct sock *sk);
>
> +static inline void sk_msg_queue_consumed(struct sk_psock *psock, u32 len)
> +{
> +	psock->msg_len -= len;
> +}
> +
>  static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
>  {
>  	WARN_ON(i == msg->sg.end && bytes);
> @@ -311,9 +317,10 @@ static inline void sk_psock_queue_msg(struct sk_psock *psock,
>  				      struct sk_msg *msg)
>  {
>  	spin_lock_bh(&psock->ingress_lock);
> -	if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
> +	if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
>  		list_add_tail(&msg->list, &psock->ingress_msg);
> -	else {
> +		psock->msg_len += msg->sg.size;
> +	} else {
>  		sk_msg_free(psock->sk, msg);
>  		kfree(msg);
>  	}
> @@ -368,6 +375,19 @@ static inline void kfree_sk_msg(struct sk_msg *msg)
>  	kfree(msg);
>  }
>
> +static inline u32 sk_msg_queue_len(struct sock *sk)

The parameter should be a const pointer: const struct sock *sk

> +{
> +	struct sk_psock *psock;
> +	u32 len = 0;
> +
> +	rcu_read_lock();
> +	psock = sk_psock(sk);
> +	if (psock)
> +		len = psock->msg_len;

This read is racy against concurrent writers.
You must use READ_ONCE() here, and WRITE_ONCE() on the write side
(everywhere psock->msg_len is updated).

> +	rcu_read_unlock();
> +	return len;
> +}
> +
>  static inline void sk_psock_report_error(struct sk_psock *psock, int err)
>  {
>  	struct sock *sk = psock->sk;
> diff --git a/net/core/skmsg.c b/net/core/skmsg.c
> index 6c31eefbd777..b3de17e99b67 100644
> --- a/net/core/skmsg.c
> +++ b/net/core/skmsg.c
> @@ -481,6 +481,8 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
>  		msg_rx = sk_psock_peek_msg(psock);
>  	}
>  out:
> +	if (likely(!peek) && copied > 0)
> +		sk_msg_queue_consumed(psock, copied);
>  	return copied;
>  }
>  EXPORT_SYMBOL_GPL(sk_msg_recvmsg);
> @@ -771,9 +773,11 @@ static void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
>
>  	list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
>  		list_del(&msg->list);
> +		sk_msg_queue_consumed(psock, msg->sg.size);
>  		sk_msg_free(psock->sk, msg);
>  		kfree(msg);
>  	}
> +	WARN_ON_ONCE(psock->msg_len != 0);
>  }
>
>  static void __sk_psock_zap_ingress(struct sk_psock *psock)
> --
> 2.38.1
>