It is possible (via shutdown()) for TCP sockets to go through the TCP_CLOSE state via tcp_disconnect() without calling into the close callback. This would allow a kTLS-enabled socket to exist outside of the ESTABLISHED state, which is not supported. Solve this the same way we solved the sock{map|hash} case, by adding an unhash hook to tear down the TLS state. Tested with the bpf and net selftests, and ran the syzkaller reproducers for the issues listed below. Fixes: d91c3e17f75f2 ("net/tls: Only attach to sockets in ESTABLISHED state") Reported-by: Eric Dumazet <edumazet@xxxxxxxxxx> Reported-by: syzbot+4207c7f3a443366d8aa2@xxxxxxxxxxxxxxxxxxxxxxxxx Reported-by: syzbot+06537213db7ba2745c4a@xxxxxxxxxxxxxxxxxxxxxxxxx Signed-off-by: John Fastabend <john.fastabend@xxxxxxxxx> --- include/net/tls.h | 2 ++ net/tls/tls_main.c | 50 +++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/include/net/tls.h b/include/net/tls.h index 6fe1f5c96f4a..935d65606bb3 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -264,6 +264,8 @@ struct tls_context { bool in_tcp_sendpages; bool pending_open_record_frags; + struct proto *sk_proto; + int (*push_pending_record)(struct sock *sk, int flags); void (*sk_write_space)(struct sock *sk); diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 51cb19e24dd9..e1750634a53a 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -251,11 +251,16 @@ static void tls_write_space(struct sock *sk) ctx->sk_write_space(sk); } -static void tls_ctx_free(struct tls_context *ctx) +static void tls_ctx_free(struct sock *sk, struct tls_context *ctx) { + struct inet_connection_sock *icsk = inet_csk(sk); + if (!ctx) return; + sk->sk_prot = ctx->sk_proto; + icsk->icsk_ulp_data = NULL; + memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send)); memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv)); kfree(ctx); @@ -287,23 +292,49 @@ static void tls_sk_proto_cleanup(struct sock *sk, #endif } +static void 
tls_sk_proto_unhash(struct sock *sk) +{ + struct tls_context *ctx = tls_get_ctx(sk); + void (*sk_proto_unhash)(struct sock *sk); + long timeo = sock_sndtimeo(sk, 0); + + if (unlikely(!ctx)) { + if (sk->sk_prot->unhash) + sk->sk_prot->unhash(sk); + return; + } + + sk->sk_prot = ctx->sk_proto; + sk_proto_unhash = ctx->sk_proto->unhash; /* base op saved by tls_init() */ + tls_sk_proto_cleanup(sk, ctx, timeo); /* NOTE(review): timeo permits blocking; confirm unhash is never reached from atomic context */ + if (ctx->rx_conf == TLS_SW) + tls_sw_release_strp_rx(ctx); + tls_ctx_free(sk, ctx); + if (sk_proto_unhash) + sk_proto_unhash(sk); +} + static void tls_sk_proto_close(struct sock *sk, long timeout) { struct tls_context *ctx = tls_get_ctx(sk); long timeo = sock_sndtimeo(sk, 0); void (*sk_proto_close)(struct sock *sk, long timeout); - bool free_ctx = false; + + if (unlikely(!ctx)) { + if (sk->sk_prot->close) + sk->sk_prot->close(sk, timeout); + return; + } lock_sock(sk); + sk->sk_prot = ctx->sk_proto; /* NOTE(review): confirm sk_proto is also set on the TLS_HW_RECORD setup path */ sk_proto_close = ctx->sk_proto_close; if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD) goto skip_tx_cleanup; - if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) { - free_ctx = true; + if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) goto skip_tx_cleanup; - } tls_sk_proto_cleanup(sk, ctx, timeo); @@ -311,11 +342,12 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) release_sock(sk); if (ctx->rx_conf == TLS_SW) tls_sw_release_strp_rx(ctx); - sk_proto_close(sk, timeout); if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW && ctx->tx_conf != TLS_HW_RECORD && ctx->rx_conf != TLS_HW_RECORD) - tls_ctx_free(ctx); + tls_ctx_free(sk, ctx); + if (sk_proto_close) + sk_proto_close(sk, timeout); } static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval, @@ -733,16 +765,19 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg; prot[TLS_SW][TLS_BASE].sendpage = tls_sw_sendpage; + prot[TLS_SW][TLS_BASE].unhash = tls_sk_proto_unhash; 
prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE]; prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg; prot[TLS_BASE][TLS_SW].stream_memory_read = tls_sw_stream_read; prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close; + prot[TLS_BASE][TLS_SW].unhash = tls_sk_proto_unhash; prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE]; prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg; prot[TLS_SW][TLS_SW].stream_memory_read = tls_sw_stream_read; prot[TLS_SW][TLS_SW].close = tls_sk_proto_close; + prot[TLS_SW][TLS_SW].unhash = tls_sk_proto_unhash; #ifdef CONFIG_TLS_DEVICE prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; @@ -793,6 +828,7 @@ static int tls_init(struct sock *sk) tls_build_proto(sk); ctx->tx_conf = TLS_BASE; ctx->rx_conf = TLS_BASE; + ctx->sk_proto = sk->sk_prot; update_sk_prot(sk, ctx); out: return rc;