From: Cong Wang <cong.wang@xxxxxxxxxxxxx>

Although we have sk_psock_get(), it assumes the psock retrieved from
sk_user_data is for sockmap, which is not sufficient if we call it
outside of sockmap, for example, from reuseport_array.

Fortunately, sock_map_psock_get_checked() is stricter and checks for
sock_map_close before using the psock. So we can refactor it, rename it
to sk_psock_get_checked() and move it to skmsg.h, so that it can be
safely called outside of sockmap.

When CONFIG_BPF_SYSCALL is disabled, sock_map_close() does not exist
and sockmap never installs a psock. In that case sk_psock_get_checked()
rejects any non-NULL sk_user_data with -EBUSY instead of treating it as
a psock.

Cc: John Fastabend <john.fastabend@xxxxxxxxx>
Cc: Daniel Borkmann <daniel@xxxxxxxxxxxxx>
Cc: Jakub Sitnicki <jakub@xxxxxxxxxxxxxx>
Cc: Lorenz Bauer <lmb@xxxxxxxxxxxxxx>
Signed-off-by: Cong Wang <cong.wang@xxxxxxxxxxxxx>
---
 include/linux/skmsg.h | 34 ++++++++++++++++++++++++++++++++++
 net/core/sock_map.c   | 22 +---------------------
 2 files changed, 35 insertions(+), 21 deletions(-)

diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index 14ab0c0bc924..8f577739fc36 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -452,6 +452,40 @@ static inline struct sk_psock *sk_psock_get(struct sock *sk)
 	return psock;
 }
 
+/**
+ * sk_psock_get_checked - take a reference on a sockmap-owned psock
+ * @sk: socket whose sk_user_data is inspected
+ *
+ * Unlike sk_psock_get(), this also checks that @sk is handled by sockmap
+ * (sk->sk_prot->close == sock_map_close), so it is safe to call from
+ * outside of sockmap.
+ *
+ * Return: NULL if @sk has no sk_user_data, ERR_PTR(-EBUSY) if sk_user_data
+ * is not a sockmap psock or its refcount already dropped to zero, otherwise
+ * the psock with an extra reference held (pair with sk_psock_put()).
+ */
+static inline struct sk_psock *sk_psock_get_checked(struct sock *sk)
+{
+	struct sk_psock *psock;
+
+	rcu_read_lock();
+	psock = sk_psock(sk);
+	if (psock) {
+#if defined(CONFIG_BPF_SYSCALL)
+		/* Only trust sk_user_data when sockmap owns the socket. */
+		bool owned = sk->sk_prot->close == sock_map_close;
+#else
+		/* No sockmap without BPF_SYSCALL: sk_user_data is not a psock. */
+		bool owned = false;
+#endif
+
+		if (!owned || !refcount_inc_not_zero(&psock->refcnt))
+			psock = ERR_PTR(-EBUSY);
+	}
+	rcu_read_unlock();
+	return psock;
+}
+
 void sk_psock_drop(struct sock *sk, struct sk_psock *psock);
 
 static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index e252b8ec2b85..6612bb0b95b5 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -191,26 +191,6 @@ static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
 	return sk->sk_prot->psock_update_sk_prot(sk, psock, false);
 }
 
-static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
-{
-	struct sk_psock *psock;
-
-	rcu_read_lock();
-	psock = sk_psock(sk);
-	if (psock) {
-		if (sk->sk_prot->close != sock_map_close) {
-			psock = ERR_PTR(-EBUSY);
-			goto out;
-		}
-
-		if (!refcount_inc_not_zero(&psock->refcnt))
-			psock = ERR_PTR(-EBUSY);
-	}
-out:
-	rcu_read_unlock();
-	return psock;
-}
-
 static int sock_map_link(struct bpf_map *map, struct sock *sk)
 {
 	struct sk_psock_progs *progs = sock_map_progs(map);
@@ -255,7 +235,7 @@ static int sock_map_link(struct bpf_map *map, struct sock *sk)
 		}
 	}
 
-	psock = sock_map_psock_get_checked(sk);
+	psock = sk_psock_get_checked(sk);
 	if (IS_ERR(psock)) {
 		ret = PTR_ERR(psock);
 		goto out_progs;
-- 
2.30.2