On Thu, Aug 05, 2021 at 07:13 AM CEST, Jiang Wang wrote: [...] > --- a/net/unix/af_unix.c > +++ b/net/unix/af_unix.c > @@ -791,17 +791,35 @@ static void unix_close(struct sock *sk, long timeout) > */ > } > > -struct proto unix_proto = { > - .name = "UNIX", > +static void unix_unhash(struct sock *sk) > +{ > + /* Nothing to do here, unix socket does not need a ->unhash(). > + * This is merely for sockmap. > + */ > +} > + > +struct proto unix_dgram_proto = { > + .name = "UNIX-DGRAM", > + .owner = THIS_MODULE, > + .obj_size = sizeof(struct unix_sock), > + .close = unix_close, > +#ifdef CONFIG_BPF_SYSCALL > + .psock_update_sk_prot = unix_dgram_bpf_update_proto, > +#endif > +}; > + > +struct proto unix_stream_proto = { > + .name = "UNIX-STREAM", > .owner = THIS_MODULE, > .obj_size = sizeof(struct unix_sock), > .close = unix_close, > + .unhash = unix_unhash, > #ifdef CONFIG_BPF_SYSCALL > - .psock_update_sk_prot = unix_bpf_update_proto, > + .psock_update_sk_prot = unix_stream_bpf_update_proto, > #endif > }; > > -static struct sock *unix_create1(struct net *net, struct socket *sock, int kern) > +static struct sock *unix_create1(struct net *net, struct socket *sock, int kern, int type) > { > struct sock *sk = NULL; > struct unix_sock *u; > @@ -810,7 +828,11 @@ static struct sock *unix_create1(struct net *net, struct socket *sock, int kern) > if (atomic_long_read(&unix_nr_socks) > 2 * get_max_files()) > goto out; > > - sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_proto, kern); > + if (type == SOCK_STREAM) > + sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_stream_proto, kern); > + else /*dgram and seqpacket */ > + sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_dgram_proto, kern); > + > if (!sk) > goto out; > > @@ -872,7 +894,7 @@ static int unix_create(struct net *net, struct socket *sock, int protocol, > return -ESOCKTNOSUPPORT; > } > > - return unix_create1(net, sock, kern) ? 0 : -ENOMEM; > + return unix_create1(net, sock, kern, sock->type) ? 
0 : -ENOMEM; > } > > static int unix_release(struct socket *sock) > @@ -1286,7 +1308,7 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr, > err = -ENOMEM; > > /* create new sock for complete connection */ > - newsk = unix_create1(sock_net(sk), NULL, 0); > + newsk = unix_create1(sock_net(sk), NULL, 0, sock->type); > if (newsk == NULL) > goto out; > > @@ -2261,7 +2283,7 @@ static int unix_dgram_recvmsg(struct socket *sock, struct msghdr *msg, size_t si > struct sock *sk = sock->sk; > > #ifdef CONFIG_BPF_SYSCALL > - if (sk->sk_prot != &unix_proto) > + if (READ_ONCE(sk->sk_prot) != &unix_dgram_proto) > return sk->sk_prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT, > flags & ~MSG_DONTWAIT, NULL); Notice that sk->sk_prot is read twice here, and the value it holds can change between the two reads (for example, when the socket is removed from a sockmap). So we want to load it just once, e.g. "const struct proto *prot = READ_ONCE(sk->sk_prot);", and then use prot for both the comparison and the ->recvmsg call. Otherwise, it seems possible that sk->sk_prot->recvmsg gets called after sk->sk_prot has already been switched back to &unix_dgram_proto, whose ->recvmsg is NULL. > #endif > @@ -2580,6 +2602,20 @@ static int unix_stream_read_actor(struct sk_buff *skb, > return ret ?: chunk; > } > > +int __unix_stream_recvmsg(struct sock *sk, struct msghdr *msg, > + size_t size, int flags) > +{ > + struct unix_stream_read_state state = { > + .recv_actor = unix_stream_read_actor, > + .socket = sk->sk_socket, > + .msg = msg, > + .size = size, > + .flags = flags > + }; > + > + return unix_stream_read_generic(&state, true); > +} > + > static int unix_stream_recvmsg(struct socket *sock, struct msghdr *msg, > size_t size, int flags) > { > @@ -2591,6 +2627,12 @@ static int unix_stream_recvmsg(struct socket *sock, struct msghdr *msg, > .flags = flags > }; > > +#ifdef CONFIG_BPF_SYSCALL > + struct sock *sk = sock->sk; > + if (sk->sk_prot != &unix_stream_proto) > + return sk->sk_prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT, > + flags & ~MSG_DONTWAIT, NULL); This has the same problem: sk->sk_prot should be loaded once with READ_ONCE(), and that single value should be used for both the comparison and the ->recvmsg call.
> +#endif > return unix_stream_read_generic(&state, true); > } > > @@ -2652,6 +2694,7 @@ static int unix_shutdown(struct socket *sock, int mode) > > int peer_mode = 0; > > + other->sk_prot->unhash(other); Here as well: other->sk_prot can change concurrently, so it should be read with READ_ONCE(). > if (mode&RCV_SHUTDOWN) > peer_mode |= SEND_SHUTDOWN; > if (mode&SEND_SHUTDOWN) [...]