Like skb->_sk_redir, bundle the sock redirect pointer and the ingress bit into a single psock->_sk_redir field so that they are always set and cleared together. This replaces the separate redir_ingress flag, whose value was undefined whenever sk_redir was NULL. Suggested-by: Jakub Sitnicki <jakub@xxxxxxxxxxxxxx> Link: https://lore.kernel.org/bpf/87cz97cnz8.fsf@xxxxxxxxxxxxxx Signed-off-by: Pengcheng Yang <yangpc@xxxxxxxxxx> --- include/linux/skmsg.h | 30 ++++++++++++++++++++++++++++-- net/core/skmsg.c | 18 ++++++++++-------- net/ipv4/tcp_bpf.c | 13 +++++++------ net/tls/tls_sw.c | 11 ++++++----- 4 files changed, 51 insertions(+), 21 deletions(-) diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index c1637515a8a4..ae021f511f46 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -78,11 +78,10 @@ struct sk_psock_work_state { struct sk_psock { struct sock *sk; - struct sock *sk_redir; + unsigned long _sk_redir; u32 apply_bytes; u32 cork_bytes; u32 eval; - bool redir_ingress; /* undefined if sk_redir is null */ struct sk_msg *cork; struct sk_psock_progs progs; #if IS_ENABLED(CONFIG_BPF_STREAM_PARSER) @@ -283,6 +282,33 @@ static inline struct sk_psock *sk_psock(const struct sock *sk) SK_USER_DATA_PSOCK); } +static inline bool sk_psock_ingress(const struct sk_psock *psock) +{ + unsigned long sk_redir = psock->_sk_redir; + + return sk_redir & BPF_F_INGRESS; +} + +static inline void sk_psock_set_redir(struct sk_psock *psock, struct sock *sk_redir, + bool ingress) +{ + psock->_sk_redir = (unsigned long)sk_redir; + if (ingress) + psock->_sk_redir |= BPF_F_INGRESS; +} + +static inline struct sock *sk_psock_get_redir(struct sk_psock *psock) +{ + unsigned long sk_redir = psock->_sk_redir; + + return (struct sock *)(sk_redir & ~(BPF_F_INGRESS)); +} + +static inline void sk_psock_clear_redir(struct sk_psock *psock) +{ + psock->_sk_redir = 0; +} + static inline void sk_psock_set_state(struct sk_psock *psock, enum sk_psock_state_bits bit) { diff --git a/net/core/skmsg.c b/net/core/skmsg.c index 6c31eefbd777..d994621f1f95 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -811,6 +811,7 @@ static void 
sk_psock_destroy(struct work_struct *work) { struct sk_psock *psock = container_of(to_rcu_work(work), struct sk_psock, rwork); + struct sock *sk_redir = sk_psock_get_redir(psock); /* No sk_callback_lock since already detached. */ sk_psock_done_strp(psock); @@ -824,8 +825,8 @@ static void sk_psock_destroy(struct work_struct *work) sk_psock_link_destroy(psock); sk_psock_cork_free(psock); - if (psock->sk_redir) - sock_put(psock->sk_redir); + if (sk_redir) + sock_put(sk_redir); sock_put(psock->sk); kfree(psock); } @@ -865,6 +866,7 @@ int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock, struct sk_msg *msg) { struct bpf_prog *prog; + struct sock *sk_redir; int ret; rcu_read_lock(); @@ -880,17 +882,17 @@ int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock, ret = sk_psock_map_verd(ret, msg->sk_redir); psock->apply_bytes = msg->apply_bytes; if (ret == __SK_REDIRECT) { - if (psock->sk_redir) { - sock_put(psock->sk_redir); - psock->sk_redir = NULL; + sk_redir = sk_psock_get_redir(psock); + if (sk_redir) { + sock_put(sk_redir); + sk_psock_clear_redir(psock); } if (!msg->sk_redir) { ret = __SK_DROP; goto out; } - psock->redir_ingress = sk_msg_to_ingress(msg); - psock->sk_redir = msg->sk_redir; - sock_hold(psock->sk_redir); + sk_psock_set_redir(psock, msg->sk_redir, sk_msg_to_ingress(msg)); + sock_hold(msg->sk_redir); } out: rcu_read_unlock(); diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c index 53b0d62fd2c2..b3c847dc87dc 100644 --- a/net/ipv4/tcp_bpf.c +++ b/net/ipv4/tcp_bpf.c @@ -427,14 +427,14 @@ static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock, sk_msg_apply_bytes(psock, tosend); break; case __SK_REDIRECT: - redir_ingress = psock->redir_ingress; - sk_redir = psock->sk_redir; + redir_ingress = sk_psock_ingress(psock); + sk_redir = sk_psock_get_redir(psock); sk_msg_apply_bytes(psock, tosend); if (!psock->apply_bytes) { /* Clean up before releasing the sock lock. 
*/ eval = psock->eval; psock->eval = __SK_NONE; - psock->sk_redir = NULL; + sk_psock_clear_redir(psock); } if (psock->cork) { cork = true; @@ -476,9 +476,10 @@ static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock, if (likely(!ret)) { if (!psock->apply_bytes) { psock->eval = __SK_NONE; - if (psock->sk_redir) { - sock_put(psock->sk_redir); - psock->sk_redir = NULL; + sk_redir = sk_psock_get_redir(psock); + if (sk_redir) { + sock_put(sk_redir); + sk_psock_clear_redir(psock); } } if (msg && diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index e9d1e83a859d..c91cd07c1285 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -854,8 +854,8 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, } break; case __SK_REDIRECT: - redir_ingress = psock->redir_ingress; - sk_redir = psock->sk_redir; + redir_ingress = sk_psock_ingress(psock); + sk_redir = sk_psock_get_redir(psock); memcpy(&msg_redir, msg, sizeof(*msg)); if (msg->apply_bytes < send) msg->apply_bytes = 0; @@ -898,9 +898,10 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, } if (reset_eval) { psock->eval = __SK_NONE; - if (psock->sk_redir) { - sock_put(psock->sk_redir); - psock->sk_redir = NULL; + sk_redir = sk_psock_get_redir(psock); + if (sk_redir) { + sock_put(sk_redir); + sk_psock_clear_redir(psock); } } if (rec) -- 2.38.1