On 11/21/23 10:42 AM, Kuniyuki Iwashima wrote:
diff --git a/include/net/inet6_hashtables.h b/include/net/inet6_hashtables.h
index 533a7337865a..9a67f47a5e64 100644
--- a/include/net/inet6_hashtables.h
+++ b/include/net/inet6_hashtables.h
@@ -116,9 +116,23 @@ struct sock *inet6_steal_sock(struct net *net, struct sk_buff *skb, int doff,
if (!sk)
return NULL;
- if (!prefetched || !sk_fullsock(sk))
+ if (!prefetched)
return sk;
+ if (sk->sk_state == TCP_NEW_SYN_RECV) {
+#if IS_ENABLED(CONFIG_SYN_COOKIE)
+ if (inet_reqsk(sk)->syncookie) {
+ *refcounted = false;
+ skb->sk = sk;
+ skb->destructor = sock_pfree;
Instead of re-initializing skb->sk and skb->destructor here, can skb_steal_sock()
avoid resetting them to NULL in the first place and return the rsk_listener
instead? Btw, can inet_reqsk(sk)->rsk_listener be set to NULL after this
point?
Besides, this is essentially assigning the incoming request to a listening sk.
Does it need to call inet6_lookup_reuseport() a few lines below, to avoid skipping
the bpf reuseport selection that was fixed in commit 9c02bec95954 ("bpf, net:
Support SO_REUSEPORT sockets with bpf_sk_assign")?
+ return inet_reqsk(sk)->rsk_listener;
+ }
+#endif
+ return sk;
+ } else if (sk->sk_state == TCP_TIME_WAIT) {
+ return sk;
+ }
+
if (sk->sk_protocol == IPPROTO_TCP) {
if (sk->sk_state != TCP_LISTEN)
return sk;
diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
index 3ecfeadbfa06..36609656a047 100644
--- a/include/net/inet_hashtables.h
+++ b/include/net/inet_hashtables.h
@@ -462,9 +462,23 @@ struct sock *inet_steal_sock(struct net *net, struct sk_buff *skb, int doff,
if (!sk)
return NULL;
- if (!prefetched || !sk_fullsock(sk))
+ if (!prefetched)
return sk;
+ if (sk->sk_state == TCP_NEW_SYN_RECV) {
+#if IS_ENABLED(CONFIG_SYN_COOKIE)
+ if (inet_reqsk(sk)->syncookie) {
+ *refcounted = false;
+ skb->sk = sk;
+ skb->destructor = sock_pfree;
+ return inet_reqsk(sk)->rsk_listener;
+ }
+#endif
+ return sk;
+ } else if (sk->sk_state == TCP_TIME_WAIT) {
+ return sk;
+ }
+
if (sk->sk_protocol == IPPROTO_TCP) {
if (sk->sk_state != TCP_LISTEN)