On Sun, Aug 11, 2024 at 07:21:53PM -0700, Cong Wang wrote: > From: Cong Wang <cong.wang@xxxxxxxxxxxxx> > > After a vsock socket has been added to a BPF sockmap, its prot->recvmsg > has been replaced with vsock_bpf_recvmsg(). Thus the following > recursion could happen: > > vsock_bpf_recvmsg() > -> __vsock_recvmsg() > -> vsock_connectible_recvmsg() > -> prot->recvmsg() > -> vsock_bpf_recvmsg() again > > We need to fix it by calling the original ->recvmsg() without any BPF > sockmap logic in __vsock_recvmsg(). > > Fixes: 634f1a7110b4 ("vsock: support sockmap") > Reported-by: syzbot+bdb4bd87b5e22058e2a4@xxxxxxxxxxxxxxxxxxxxxxxxx > Tested-by: syzbot+bdb4bd87b5e22058e2a4@xxxxxxxxxxxxxxxxxxxxxxxxx > Cc: Bobby Eshleman <bobby.eshleman@xxxxxxxxxxxxx> > Cc: Michael S. Tsirkin <mst@xxxxxxxxxx> > Cc: Stefano Garzarella <sgarzare@xxxxxxxxxx> > Signed-off-by: Cong Wang <cong.wang@xxxxxxxxxxxxx> Acked-by: Michael S. Tsirkin <mst@xxxxxxxxxx> > --- > include/net/af_vsock.h | 4 ++++ > net/vmw_vsock/af_vsock.c | 50 +++++++++++++++++++++++---------------- > net/vmw_vsock/vsock_bpf.c | 4 ++-- > 3 files changed, 35 insertions(+), 23 deletions(-) > > diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h > index 535701efc1e5..24d970f7a4fa 100644 > --- a/include/net/af_vsock.h > +++ b/include/net/af_vsock.h > @@ -230,8 +230,12 @@ struct vsock_tap { > int vsock_add_tap(struct vsock_tap *vt); > int vsock_remove_tap(struct vsock_tap *vt); > void vsock_deliver_tap(struct sk_buff *build_skb(void *opaque), void *opaque); > +int __vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, > + int flags); > int vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, > int flags); > +int __vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, > + size_t len, int flags); > int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, > size_t len, int flags); > > diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c > index 
4b040285aa78..0ff9b2dd86ba 100644 > --- a/net/vmw_vsock/af_vsock.c > +++ b/net/vmw_vsock/af_vsock.c > @@ -1270,25 +1270,28 @@ static int vsock_dgram_connect(struct socket *sock, > return err; > } > > +int __vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, > + size_t len, int flags) > +{ > + struct sock *sk = sock->sk; > + struct vsock_sock *vsk = vsock_sk(sk); > + > + return vsk->transport->dgram_dequeue(vsk, msg, len, flags); > +} > + > int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, > size_t len, int flags) > { > #ifdef CONFIG_BPF_SYSCALL > + struct sock *sk = sock->sk; > const struct proto *prot; > -#endif > - struct vsock_sock *vsk; > - struct sock *sk; > > - sk = sock->sk; > - vsk = vsock_sk(sk); > - > -#ifdef CONFIG_BPF_SYSCALL > prot = READ_ONCE(sk->sk_prot); > if (prot != &vsock_proto) > return prot->recvmsg(sk, msg, len, flags, NULL); > #endif > > - return vsk->transport->dgram_dequeue(vsk, msg, len, flags); > + return __vsock_dgram_recvmsg(sock, msg, len, flags); > } > EXPORT_SYMBOL_GPL(vsock_dgram_recvmsg); > > @@ -2174,15 +2177,12 @@ static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg, > } > > int > -vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, > - int flags) > +__vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, > + int flags) > { > struct sock *sk; > struct vsock_sock *vsk; > const struct vsock_transport *transport; > -#ifdef CONFIG_BPF_SYSCALL > - const struct proto *prot; > -#endif > int err; > > sk = sock->sk; > @@ -2233,14 +2233,6 @@ vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, > goto out; > } > > -#ifdef CONFIG_BPF_SYSCALL > - prot = READ_ONCE(sk->sk_prot); > - if (prot != &vsock_proto) { > - release_sock(sk); > - return prot->recvmsg(sk, msg, len, flags, NULL); > - } > -#endif > - > if (sk->sk_type == SOCK_STREAM) > err = __vsock_stream_recvmsg(sk, msg, len, flags); > else > @@ -2250,6 +2242,22 @@ 
vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, > release_sock(sk); > return err; > } > + > +int > +vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, > + int flags) > +{ > +#ifdef CONFIG_BPF_SYSCALL > + struct sock *sk = sock->sk; > + const struct proto *prot; > + > + prot = READ_ONCE(sk->sk_prot); > + if (prot != &vsock_proto) > + return prot->recvmsg(sk, msg, len, flags, NULL); > +#endif > + > + return __vsock_connectible_recvmsg(sock, msg, len, flags); > +} > EXPORT_SYMBOL_GPL(vsock_connectible_recvmsg); > > static int vsock_set_rcvlowat(struct sock *sk, int val) > diff --git a/net/vmw_vsock/vsock_bpf.c b/net/vmw_vsock/vsock_bpf.c > index a3c97546ab84..c42c5cc18f32 100644 > --- a/net/vmw_vsock/vsock_bpf.c > +++ b/net/vmw_vsock/vsock_bpf.c > @@ -64,9 +64,9 @@ static int __vsock_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int > int err; > > if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) > - err = vsock_connectible_recvmsg(sock, msg, len, flags); > + err = __vsock_connectible_recvmsg(sock, msg, len, flags); > else if (sk->sk_type == SOCK_DGRAM) > - err = vsock_dgram_recvmsg(sock, msg, len, flags); > + err = __vsock_dgram_recvmsg(sock, msg, len, flags); > else > err = -EPROTOTYPE; > > -- > 2.34.1