Try STARTTLS with the RPC server peer as soon as a transport
connection is established.

If the client's transport security policy mandates TLS, a transport
that cannot perform a handshake now fails the RPC instead of quietly
sending it in cleartext. A failed handshake drops the connection and
restarts the task at call_bind after a delay, so that reconnecting
leads to a fresh handshake attempt.

Signed-off-by: Chuck Lever <chuck.lever@xxxxxxxxxx>
---
 include/linux/sunrpc/clnt.h  |    1 -
 include/linux/sunrpc/sched.h |    1 +
 net/sunrpc/clnt.c            |   88 ++++++++++++++++++++++++++++++++++++++++++---
 3 files changed, 85 insertions(+), 5 deletions(-)

diff --git a/include/linux/sunrpc/clnt.h b/include/linux/sunrpc/clnt.h
index 15fd84e4c321..e10a19d136ca 100644
--- a/include/linux/sunrpc/clnt.h
+++ b/include/linux/sunrpc/clnt.h
@@ -209,7 +209,6 @@ int		rpc_call_sync(struct rpc_clnt *clnt,
 			      unsigned int flags);
 struct rpc_task *rpc_call_null(struct rpc_clnt *clnt, struct rpc_cred *cred,
 			       int flags);
-void		rpc_starttls_async(struct rpc_task *task);
 int		rpc_restart_call_prepare(struct rpc_task *);
 int		rpc_restart_call(struct rpc_task *);
 void		rpc_setbufsize(struct rpc_clnt *, unsigned int, unsigned int);
diff --git a/include/linux/sunrpc/sched.h b/include/linux/sunrpc/sched.h
index f8c09638fa69..0d1ae89a2339 100644
--- a/include/linux/sunrpc/sched.h
+++ b/include/linux/sunrpc/sched.h
@@ -139,6 +139,7 @@ struct rpc_task_setup {
 #define RPC_IS_ASYNC(t)		((t)->tk_flags & RPC_TASK_ASYNC)
 #define RPC_IS_SWAPPER(t)	((t)->tk_flags & RPC_TASK_SWAPPER)
 #define RPC_IS_CORK(t)		((t)->tk_flags & RPC_TASK_CORK)
+#define RPC_IS_TLSPROBE(t)	((t)->tk_flags & RPC_TASK_TLSCRED)
 #define RPC_IS_SOFT(t)		((t)->tk_flags & (RPC_TASK_SOFT|RPC_TASK_TIMEOUT))
 #define RPC_IS_SOFTCONN(t)	((t)->tk_flags & RPC_TASK_SOFTCONN)
 #define RPC_WAS_SENT(t)		((t)->tk_flags & RPC_TASK_SENT)
diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c
index e9a6622dba68..0506971410f7 100644
--- a/net/sunrpc/clnt.c
+++ b/net/sunrpc/clnt.c
@@ -70,6 +70,8 @@ static void	call_refresh(struct rpc_task *task);
 static void	call_refreshresult(struct rpc_task *task);
 static void	call_connect(struct rpc_task *task);
 static void	call_connect_status(struct rpc_task *task);
+static void	call_start_tls(struct rpc_task *task);
+static void	call_tls_status(struct rpc_task *task);
 
 static int	rpc_encode_header(struct rpc_task *task,
 				  struct xdr_stream *xdr);
@@ -77,6 +79,7 @@ static int	rpc_decode_header(struct rpc_task *task,
 				  struct xdr_stream *xdr);
 static int	rpc_ping(struct rpc_clnt *clnt);
 static int	rpc_starttls_sync(struct rpc_clnt *clnt);
+static void	rpc_starttls_async(struct rpc_task *task);
 static void	rpc_check_timeout(struct rpc_task *task);
 
 static void rpc_register_client(struct rpc_clnt *clnt)
@@ -2163,7 +2166,7 @@ call_connect_status(struct rpc_task *task)
 	rpc_call_rpcerror(task, status);
 	return;
 out_next:
-	task->tk_action = call_transmit;
+	task->tk_action = call_start_tls;
 	return;
 out_retry:
 	/* Check for timeouts before looping back to call_bind */
@@ -2171,6 +2174,82 @@ call_connect_status(struct rpc_task *task)
 	rpc_check_timeout(task);
 }
 
+/*
+ * Start TLS on a newly connected transport, if policy requires it
+ *
+ * When the client's transport security policy calls for TLS, hand
+ * this task to rpc_starttls_async() and evaluate the outcome in
+ * call_tls_status(). TLS probe tasks (RPC_TASK_TLSCRED) skip this
+ * step: rpc_starttls_async() is what creates them, so sending them
+ * through here again would spawn yet another probe.
+ */
+static void
+call_start_tls(struct rpc_task *task)
+{
+	struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt;
+	struct rpc_clnt *clnt = task->tk_client;
+
+	task->tk_action = call_transmit;
+	if (RPC_IS_TLSPROBE(task))
+		return;
+
+	switch (clnt->cl_xprtsec_policy) {
+	case RPC_XPRTSEC_TLS:
+	case RPC_XPRTSEC_MTLS:
+		/*
+		 * The policy mandates TLS. A transport that cannot
+		 * perform a handshake must not silently fall back to
+		 * sending the request in cleartext.
+		 */
+		if (!xprt->ops->tls_handshake_async) {
+			rpc_call_rpcerror(task, -EPROTONOSUPPORT);
+			return;
+		}
+		task->tk_action = call_tls_status;
+		rpc_starttls_async(task);
+		break;
+	default:
+		break;
+	}
+}
+
+/*
+ * Evaluate the result of a TLS handshake
+ *
+ * On success, go on to transmit the request. On failure, drop the
+ * connection and restart at call_bind, so that reconnecting leads
+ * back through call_start_tls() for a fresh handshake. A mandatory
+ * TLS policy also backs off before the retry.
+ */
+static void
+call_tls_status(struct rpc_task *task)
+{
+	struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt;
+	struct rpc_clnt *clnt = task->tk_client;
+
+	task->tk_action = call_transmit;
+	if (!task->tk_status)
+		return;
+
+	xprt_force_disconnect(xprt);
+
+	/*
+	 * The connection was just torn down, so transmitting on it
+	 * cannot succeed; reconnect (and re-handshake) instead.
+	 */
+	task->tk_action = call_bind;
+	switch (clnt->cl_xprtsec_policy) {
+	case RPC_XPRTSEC_TLS:
+	case RPC_XPRTSEC_MTLS:
+		rpc_delay(task, 5*HZ /* arbitrary */);
+		break;
+	default:
+		break;
+	}
+
+	rpc_check_timeout(task);
+}
+
 /*
  * 5.	Transmit the RPC request, and wait for reply
  */
@@ -2355,7 +2434,7 @@ call_status(struct rpc_task *task)
 	struct rpc_clnt	*clnt = task->tk_client;
 	int		status;
 
-	if (!task->tk_msg.rpc_proc->p_proc)
+	if (!task->tk_msg.rpc_proc->p_proc && !RPC_IS_TLSPROBE(task))
 		trace_xprt_ping(task->tk_xprt, task->tk_status);
 
 	status = task->tk_status;
@@ -2663,6 +2742,8 @@ rpc_decode_header(struct rpc_task *task, struct xdr_stream *xdr)
 
 out_msg_denied:
 	error = -EACCES;
+	if (RPC_IS_TLSPROBE(task))
+		goto out_err;
 	p = xdr_inline_decode(xdr, sizeof(*p));
 	if (!p)
 		goto out_unparsable;
@@ -2865,7 +2946,7 @@ static const struct rpc_call_ops rpc_ops_probe_tls = {
  * @task: an RPC task waiting for a TLS session
  *
  */
-void rpc_starttls_async(struct rpc_task *task)
+static void rpc_starttls_async(struct rpc_task *task)
 {
 	struct rpc_xprt *xprt = xprt_get(task->tk_xprt);
 
@@ -2885,7 +2966,6 @@ void rpc_starttls_async(struct rpc_task *task)
 				RPC_TASK_TLSCRED | RPC_TASK_SWAPPER | RPC_TASK_CORK,
 				&rpc_ops_probe_tls, xprt));
 }
-EXPORT_SYMBOL_GPL(rpc_starttls_async);
 
 struct rpc_cb_add_xprt_calldata {
 	struct rpc_xprt_switch *xps;