From: Cong Wang <cong.wang@xxxxxxxxxxxxx> Yucong noticed we can't poll() sockets in sockmap even when they are the destination sockets of redirections. This is because we never poll any psock queues in ->poll(). We cannot overwrite ->poll() as it is in struct proto_ops, not in struct proto. So introduce sk_msg_poll() to poll the psock ingress_msg queue and let sockets which support sockmap invoke it directly. Reported-by: Yucong Sun <sunyucong@xxxxxxxxx> Cc: John Fastabend <john.fastabend@xxxxxxxxx> Cc: Daniel Borkmann <daniel@xxxxxxxxxxxxx> Cc: Jakub Sitnicki <jakub@xxxxxxxxxxxxxx> Cc: Lorenz Bauer <lmb@xxxxxxxxxxxxxx> Signed-off-by: Cong Wang <cong.wang@xxxxxxxxxxxxx> --- include/linux/skmsg.h | 6 ++++++ net/core/skmsg.c | 15 +++++++++++++++ net/ipv4/tcp.c | 2 ++ net/ipv4/udp.c | 2 ++ net/unix/af_unix.c | 5 +++++ 5 files changed, 30 insertions(+) diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index d47097f2c8c0..163b0cc1703a 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -128,6 +128,7 @@ int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from, struct sk_msg *msg, u32 bytes); int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg, int len, int flags); +__poll_t sk_msg_poll(struct sock *sk); static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes) { @@ -562,5 +563,10 @@ static inline void skb_bpf_redirect_clear(struct sk_buff *skb) { skb->_sk_redir = 0; } +#else +static inline __poll_t sk_msg_poll(struct sock *sk) +{ + return 0; +} #endif /* CONFIG_NET_SOCK_MSG */ #endif /* _LINUX_SKMSG_H */ diff --git a/net/core/skmsg.c b/net/core/skmsg.c index 2d6249b28928..8e6d7ea43eca 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -474,6 +474,21 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg, } EXPORT_SYMBOL_GPL(sk_msg_recvmsg); +__poll_t sk_msg_poll(struct sock *sk) +{ + struct sk_psock *psock; + __poll_t mask = 0; + + psock = 
sk_psock_get_checked(sk); + if (IS_ERR_OR_NULL(psock)) + return 0; + if (!sk_psock_queue_empty(psock)) + mask |= EPOLLIN | EPOLLRDNORM; + sk_psock_put(sk, psock); + return mask; +} +EXPORT_SYMBOL_GPL(sk_msg_poll); + static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk, struct sk_buff *skb) { diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index e8b48df73c85..2eb1a87ba056 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -280,6 +280,7 @@ #include <linux/uaccess.h> #include <asm/ioctls.h> #include <net/busy_poll.h> +#include <linux/skmsg.h> /* Track pending CMSGs. */ enum { @@ -563,6 +564,7 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait) if (tcp_stream_is_readable(sk, target)) mask |= EPOLLIN | EPOLLRDNORM; + mask |= sk_msg_poll(sk); if (!(sk->sk_shutdown & SEND_SHUTDOWN)) { if (__sk_stream_is_writeable(sk, 1)) { diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 8851c9463b4b..fbc989d27388 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -97,6 +97,7 @@ #include <linux/skbuff.h> #include <linux/proc_fs.h> #include <linux/seq_file.h> +#include <linux/skmsg.h> #include <net/net_namespace.h> #include <net/icmp.h> #include <net/inet_hashtables.h> @@ -2866,6 +2867,7 @@ __poll_t udp_poll(struct file *file, struct socket *sock, poll_table *wait) !(sk->sk_shutdown & RCV_SHUTDOWN) && first_packet_length(sk) == -1) mask &= ~(EPOLLIN | EPOLLRDNORM); + mask |= sk_msg_poll(sk); return mask; } diff --git a/net/unix/af_unix.c b/net/unix/af_unix.c index 92345c9bb60c..5d705541d082 100644 --- a/net/unix/af_unix.c +++ b/net/unix/af_unix.c @@ -114,6 +114,7 @@ #include <linux/freezer.h> #include <linux/file.h> #include <linux/btf_ids.h> +#include <linux/skmsg.h> #include "scm.h" @@ -3015,6 +3016,8 @@ static __poll_t unix_poll(struct file *file, struct socket *sock, poll_table *wa if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) mask |= EPOLLIN | EPOLLRDNORM; + mask |= sk_msg_poll(sk); + /* Connection-based need to check for termination 
and startup */ if ((sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) && sk->sk_state == TCP_CLOSE) @@ -3054,6 +3057,8 @@ static __poll_t unix_dgram_poll(struct file *file, struct socket *sock, if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) mask |= EPOLLIN | EPOLLRDNORM; + mask |= sk_msg_poll(sk); + /* Connection-based need to check for termination and startup */ if (sk->sk_type == SOCK_SEQPACKET) { if (sk->sk_state == TCP_CLOSE) -- 2.30.2