On Sat, Nov 23, 2019 at 12:07:47PM +0100, Jakub Sitnicki wrote: [ ... ] > @@ -370,6 +378,11 @@ static inline void sk_psock_restore_proto(struct sock *sk, > sk->sk_prot = psock->sk_proto; > psock->sk_proto = NULL; > } > + > + if (psock->icsk_af_ops) { > + icsk->icsk_af_ops = psock->icsk_af_ops; > + psock->icsk_af_ops = NULL; > + } > } [ ... ] > +static struct sock *tcp_bpf_syn_recv_sock(const struct sock *sk, > + struct sk_buff *skb, > + struct request_sock *req, > + struct dst_entry *dst, > + struct request_sock *req_unhash, > + bool *own_req) > +{ > + const struct inet_connection_sock_af_ops *ops; > + void (*write_space)(struct sock *sk); > + struct sk_psock *psock; > + struct proto *proto; > + struct sock *child; > + > + rcu_read_lock(); > + psock = sk_psock(sk); > + if (likely(psock)) { > + proto = psock->sk_proto; > + write_space = psock->saved_write_space; > + ops = psock->icsk_af_ops; It is not immediately clear to me what ensures ops is not NULL here. It is likely I missed something. A short comment would be very useful here. > + } else { > + ops = inet_csk(sk)->icsk_af_ops; > + } > + rcu_read_unlock(); > + > + child = ops->syn_recv_sock(sk, skb, req, dst, req_unhash, own_req); > + > + /* Child must not inherit psock or its ops. */ > + if (child && psock) { > + rcu_assign_sk_user_data(child, NULL); > + child->sk_prot = proto; > + child->sk_write_space = write_space; > + > + /* v4-mapped sockets don't inherit parent ops. Don't restore. 
*/ > + if (inet_csk(child)->icsk_af_ops == inet_csk(sk)->icsk_af_ops) > + inet_csk(child)->icsk_af_ops = ops; > + } > + return child; > +} > + > enum { > TCP_BPF_IPV4, > TCP_BPF_IPV6, > @@ -597,6 +642,7 @@ enum { > static struct proto *tcpv6_prot_saved __read_mostly; > static DEFINE_SPINLOCK(tcpv6_prot_lock); > static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS]; > +static struct inet_connection_sock_af_ops tcp_bpf_af_ops[TCP_BPF_NUM_PROTS]; > > static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS], > struct proto *base) > @@ -612,13 +658,23 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS], > prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage; > } > > -static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops) > +static void tcp_bpf_rebuild_af_ops(struct inet_connection_sock_af_ops *ops, > + const struct inet_connection_sock_af_ops *base) > +{ > + *ops = *base; > + ops->syn_recv_sock = tcp_bpf_syn_recv_sock; > +} > + > +static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops, > + const struct inet_connection_sock_af_ops *af_ops) > { > if (sk->sk_family == AF_INET6 && > unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) { > spin_lock_bh(&tcpv6_prot_lock); > if (likely(ops != tcpv6_prot_saved)) { > tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops); > + tcp_bpf_rebuild_af_ops(&tcp_bpf_af_ops[TCP_BPF_IPV6], > + af_ops); > smp_store_release(&tcpv6_prot_saved, ops); > } > spin_unlock_bh(&tcpv6_prot_lock); > @@ -628,6 +684,8 @@ static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops) > static int __init tcp_bpf_v4_build_proto(void) > { > tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot); > + tcp_bpf_rebuild_af_ops(&tcp_bpf_af_ops[TCP_BPF_IPV4], &ipv4_specific); > + > return 0; > } > core_initcall(tcp_bpf_v4_build_proto); > @@ -637,7 +695,8 @@ static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock) > int family = 
sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4; > int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE; > > - sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]); > + sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config], > + &tcp_bpf_af_ops[family]); > } > > static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock) > @@ -677,6 +736,7 @@ void tcp_bpf_reinit(struct sock *sk) > > int tcp_bpf_init(struct sock *sk) > { > + struct inet_connection_sock *icsk = inet_csk(sk); > struct proto *ops = READ_ONCE(sk->sk_prot); > struct sk_psock *psock; > > @@ -689,7 +749,7 @@ int tcp_bpf_init(struct sock *sk) > rcu_read_unlock(); > return -EINVAL; > } > - tcp_bpf_check_v6_needs_rebuild(sk, ops); > + tcp_bpf_check_v6_needs_rebuild(sk, ops, icsk->icsk_af_ops); > tcp_bpf_update_sk_prot(sk, psock); > rcu_read_unlock(); > return 0; > -- > 2.20.1 >