A later patch refactors bpf_setsockopt(SOL_SOCKET) to reuse sock_setsockopt(), to avoid code duplication and code drift between the two implementations. The current sock_setsockopt() takes a sock ptr as its argument, and the very first thing it does is get the sk ptr back via 'sk = sock->sk'. However, bpf_setsockopt() can be called when the sk does not have a userspace owner, meaning sk->sk_socket is NULL; for example, when a passive tcp connection has just been established. Thus, bpf_setsockopt() cannot call sock_setsockopt(sk->sk_socket), or it would pass a NULL sock ptr. All existing callers have both the sock->sk and the sk->sk_socket pointers. Thus, this patch changes sock_setsockopt() to take a sk ptr instead of a sock ptr. bpf_setsockopt() only allows optnames that do not require a sock ptr, so a NULL sk->sk_socket is not dereferenced on that path. Signed-off-by: Martin KaFai Lau <kafai@xxxxxx> --- drivers/nvme/host/tcp.c | 2 +- fs/ksmbd/transport_tcp.c | 2 +- include/net/sock.h | 2 +- net/core/sock.c | 4 ++-- net/mptcp/sockopt.c | 12 ++++++------ net/socket.c | 2 +- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/drivers/nvme/host/tcp.c b/drivers/nvme/host/tcp.c index 7a9e6ffa2342..60e14cc39e49 100644 --- a/drivers/nvme/host/tcp.c +++ b/drivers/nvme/host/tcp.c @@ -1555,7 +1555,7 @@ static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, char *iface = nctrl->opts->host_iface; sockptr_t optval = KERNEL_SOCKPTR(iface); - ret = sock_setsockopt(queue->sock, SOL_SOCKET, SO_BINDTODEVICE, + ret = sock_setsockopt(queue->sock->sk, SOL_SOCKET, SO_BINDTODEVICE, optval, strlen(iface)); if (ret) { dev_err(nctrl->device, diff --git a/fs/ksmbd/transport_tcp.c b/fs/ksmbd/transport_tcp.c index 143bba4e4db8..982eed2dd575 100644 --- a/fs/ksmbd/transport_tcp.c +++ b/fs/ksmbd/transport_tcp.c @@ -420,7 +420,7 @@ static int create_socket(struct interface *iface) ksmbd_tcp_nodelay(ksmbd_socket); ksmbd_tcp_reuseaddr(ksmbd_socket); - ret = sock_setsockopt(ksmbd_socket, + ret = sock_setsockopt(ksmbd_socket->sk, SOL_SOCKET, 
SO_BINDTODEVICE, KERNEL_SOCKPTR(iface->name), diff --git a/include/net/sock.h b/include/net/sock.h index f7ad1a7705e9..9e2539dcc293 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -1795,7 +1795,7 @@ void sock_pfree(struct sk_buff *skb); #define sock_edemux sock_efree #endif -int sock_setsockopt(struct socket *sock, int level, int op, +int sock_setsockopt(struct sock *sk, int level, int op, sockptr_t optval, unsigned int optlen); int sock_getsockopt(struct socket *sock, int level, int op, diff --git a/net/core/sock.c b/net/core/sock.c index 4cb957d934a2..18bb4f269cf1 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -1041,12 +1041,12 @@ static int sock_reserve_memory(struct sock *sk, int bytes) * at the socket level. Everything here is generic. */ -int sock_setsockopt(struct socket *sock, int level, int optname, +int sock_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, unsigned int optlen) { struct so_timestamping timestamping; + struct socket *sock = sk->sk_socket; struct sock_txtime sk_txtime; - struct sock *sk = sock->sk; int val; int valbool; struct linger ling; diff --git a/net/mptcp/sockopt.c b/net/mptcp/sockopt.c index 423d3826ca1e..5684499b4d39 100644 --- a/net/mptcp/sockopt.c +++ b/net/mptcp/sockopt.c @@ -124,7 +124,7 @@ static int mptcp_sol_socket_intval(struct mptcp_sock *msk, int optname, int val) struct sock *sk = (struct sock *)msk; int ret; - ret = sock_setsockopt(sk->sk_socket, SOL_SOCKET, optname, + ret = sock_setsockopt(sk, SOL_SOCKET, optname, optval, sizeof(val)); if (ret) return ret; @@ -149,7 +149,7 @@ static int mptcp_setsockopt_sol_socket_tstamp(struct mptcp_sock *msk, int optnam struct sock *sk = (struct sock *)msk; int ret; - ret = sock_setsockopt(sk->sk_socket, SOL_SOCKET, optname, + ret = sock_setsockopt(sk, SOL_SOCKET, optname, optval, sizeof(val)); if (ret) return ret; @@ -225,7 +225,7 @@ static int mptcp_setsockopt_sol_socket_timestamping(struct mptcp_sock *msk, return -EINVAL; } - ret = 
sock_setsockopt(sk->sk_socket, SOL_SOCKET, optname, + ret = sock_setsockopt(sk, SOL_SOCKET, optname, KERNEL_SOCKPTR(&timestamping), sizeof(timestamping)); if (ret) @@ -262,7 +262,7 @@ static int mptcp_setsockopt_sol_socket_linger(struct mptcp_sock *msk, sockptr_t return -EFAULT; kopt = KERNEL_SOCKPTR(&ling); - ret = sock_setsockopt(sk->sk_socket, SOL_SOCKET, SO_LINGER, kopt, sizeof(ling)); + ret = sock_setsockopt(sk, SOL_SOCKET, SO_LINGER, kopt, sizeof(ling)); if (ret) return ret; @@ -306,7 +306,7 @@ static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname, return -EINVAL; } - ret = sock_setsockopt(ssock, SOL_SOCKET, optname, optval, optlen); + ret = sock_setsockopt(ssock->sk, SOL_SOCKET, optname, optval, optlen); if (ret == 0) { if (optname == SO_REUSEPORT) sk->sk_reuseport = ssock->sk->sk_reuseport; @@ -349,7 +349,7 @@ static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname, case SO_PREFER_BUSY_POLL: case SO_BUSY_POLL_BUDGET: /* No need to copy: only relevant for msk */ - return sock_setsockopt(sk->sk_socket, SOL_SOCKET, optname, optval, optlen); + return sock_setsockopt(sk, SOL_SOCKET, optname, optval, optlen); case SO_NO_CHECK: case SO_DONTROUTE: case SO_BROADCAST: diff --git a/net/socket.c b/net/socket.c index b6bd4cf44d3f..c6911d613ae2 100644 --- a/net/socket.c +++ b/net/socket.c @@ -2245,7 +2245,7 @@ int __sys_setsockopt(int fd, int level, int optname, char __user *user_optval, if (kernel_optval) optval = KERNEL_SOCKPTR(kernel_optval); if (level == SOL_SOCKET && !sock_use_custom_sol_socket(sock)) - err = sock_setsockopt(sock, level, optname, optval, optlen); + err = sock_setsockopt(sock->sk, level, optname, optval, optlen); else if (unlikely(!sock->ops->setsockopt)) err = -EOPNOTSUPP; else -- 2.30.2