Currently, when a hvsock socket is closed, the socket is shut down and a
RST is sent immediately. There is no wait for the FIN packet to arrive
from the other end. This can lead to data loss, since the connection is
terminated abruptly. This manifests easily with a fast guest hvsock
writer and a much slower host hvsock reader. Essentially, hvsock does not
follow the proper STREAM (TCP) closing handshake.

The fix adds support for the delayed close of hvsock, in line with other
socket providers such as virtio. While closing, the socket waits, up to a
fixed timeout, for the FIN packet to arrive from the other end. On
timeout, it terminates the connection (i.e. a RST).

Signed-off-by: Sunil Muthuswamy <sunilmut@xxxxxxxxxxxxx>
---
 net/vmw_vsock/hyperv_transport.c | 137 ++++++++++++++++++++++++++++++---------
 1 file changed, 105 insertions(+), 32 deletions(-)

diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
index a827547..62b986d 100644
--- a/net/vmw_vsock/hyperv_transport.c
+++ b/net/vmw_vsock/hyperv_transport.c
@@ -35,6 +35,9 @@
 /* The MTU is 16KB per the host side's design */
 #define HVS_MTU_SIZE (1024 * 16)
 
+/* How long to wait for graceful shutdown of a connection */
+#define HVS_CLOSE_TIMEOUT (8 * HZ)
+
 struct vmpipe_proto_header {
 	u32 pkt_type;
 	u32 data_size;
@@ -305,19 +308,38 @@ static void hvs_channel_cb(void *ctx)
 		sk->sk_write_space(sk);
 }
 
-static void hvs_close_connection(struct vmbus_channel *chan)
+/* Mark both directions as shut down by the peer. Caller holds the socket
+ * lock. If a delayed close is pending, also remove the socket and drop the
+ * reference taken for it -- unless @cancel_timeout is set and the pending
+ * work could not be cancelled.
+ */
+static void hvs_do_close_lock_held(struct vsock_sock *vsk,
+				   bool cancel_timeout)
 {
-	struct sock *sk = get_per_channel_state(chan);
-	struct vsock_sock *vsk = vsock_sk(sk);
-
-	lock_sock(sk);
+	struct sock *sk = sk_vsock(vsk);
 
-	sk->sk_state = TCP_CLOSE;
 	sock_set_flag(sk, SOCK_DONE);
-	vsk->peer_shutdown |= SEND_SHUTDOWN | RCV_SHUTDOWN;
-
+	vsk->peer_shutdown = SHUTDOWN_MASK;
+	if (vsock_stream_has_data(vsk) <= 0)
+		sk->sk_state = TCP_CLOSING;
 	sk->sk_state_change(sk);
+	if (vsk->close_work_scheduled &&
+	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
+		vsk->close_work_scheduled = false;
+		vsock_remove_sock(vsk);
+
+		/* Release the reference taken while scheduling the timeout */
+		sock_put(sk);
+	}
+}
+
+/* Equivalent of a RST */
+static void hvs_close_connection(struct vmbus_channel *chan)
+{
+	struct sock *sk = get_per_channel_state(chan);
 
+	lock_sock(sk);
+	hvs_do_close_lock_held(vsock_sk(sk), true);
 	release_sock(sk);
 }
 
@@ -452,50 +474,96 @@ static int hvs_connect(struct vsock_sock *vsk)
 	return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id);
 }
 
+/* Send our FIN (a zero-length packet) to the host, at most once.
+ * Caller holds the socket lock.
+ */
+static void hvs_shutdown_lock_held(struct hvsock *hvs)
+{
+	struct vmpipe_proto_header hdr;
+
+	if (hvs->fin_sent || !hvs->chan)
+		return;
+
+	/* It can't fail: see hvs_channel_writable_bytes(). */
+	(void)hvs_send_data(hvs->chan, (struct hvs_send_buf *)&hdr, 0);
+	hvs->fin_sent = true;
+}
+
 static int hvs_shutdown(struct vsock_sock *vsk, int mode)
 {
 	struct sock *sk = sk_vsock(vsk);
-	struct vmpipe_proto_header hdr;
-	struct hvs_send_buf *send_buf;
-	struct hvsock *hvs;
 
 	if (!(mode & SEND_SHUTDOWN))
 		return 0;
 
 	lock_sock(sk);
+	hvs_shutdown_lock_held(vsk->trans);
+	release_sock(sk);
+	return 0;
+}
 
-	hvs = vsk->trans;
-	if (hvs->fin_sent)
-		goto out;
-
-	send_buf = (struct hvs_send_buf *)&hdr;
+/* Delayed close: runs HVS_CLOSE_TIMEOUT after release unless cancelled
+ * because the peer closed the connection first; then closes it locally.
+ */
+static void hvs_close_timeout(struct work_struct *work)
+{
+	struct vsock_sock *vsk =
+		container_of(work, struct vsock_sock, close_work.work);
+	struct sock *sk = sk_vsock(vsk);
 
-	/* It can't fail: see hvs_channel_writable_bytes(). */
-	(void)hvs_send_data(hvs->chan, send_buf, 0);
+	sock_hold(sk);
+	lock_sock(sk);
+	if (!sock_flag(sk, SOCK_DONE))
+		hvs_do_close_lock_held(vsk, false);
 
-	hvs->fin_sent = true;
-out:
+	/* If hvs_close_connection() ran after this work had started, its
+	 * cancel_delayed_work() failed and it left the socket removal and
+	 * the reference drop to us.
+	 */
+	if (vsk->close_work_scheduled) {
+		vsk->close_work_scheduled = false;
+		vsock_remove_sock(vsk);
+		sock_put(sk);
+	}
 	release_sock(sk);
-	return 0;
+	sock_put(sk);
 }
 
-static void hvs_release(struct vsock_sock *vsk)
+/* Returns true if the caller may remove the socket right away; false if
+ * removal has been deferred to the delayed close.
+ */
+static bool hvs_close_lock_held(struct vsock_sock *vsk)
 {
 	struct sock *sk = sk_vsock(vsk);
-	struct hvsock *hvs = vsk->trans;
-	struct vmbus_channel *chan;
 
-	lock_sock(sk);
+	if (!(sk->sk_state == TCP_ESTABLISHED ||
+	      sk->sk_state == TCP_CLOSING))
+		return true;
 
-	sk->sk_state = TCP_CLOSING;
-	vsock_remove_sock(vsk);
+	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
+		hvs_shutdown_lock_held(vsk->trans);
 
-	release_sock(sk);
+	if (sock_flag(sk, SOCK_DONE))
+		return true;
 
-	chan = hvs->chan;
-	if (chan)
-		hvs_shutdown(vsk, RCV_SHUTDOWN | SEND_SHUTDOWN);
+	/* This reference will be dropped by the delayed close routine */
+	sock_hold(sk);
+	INIT_DELAYED_WORK(&vsk->close_work, hvs_close_timeout);
+	vsk->close_work_scheduled = true;
+	schedule_delayed_work(&vsk->close_work, HVS_CLOSE_TIMEOUT);
+	return false;
+}
 
+static void hvs_release(struct vsock_sock *vsk)
+{
+	struct sock *sk = sk_vsock(vsk);
+	bool remove_sock;
+
+	lock_sock(sk);
+	remove_sock = hvs_close_lock_held(vsk);
+	release_sock(sk);
+	if (remove_sock)
+		vsock_remove_sock(vsk);
 }
 
 static void hvs_destruct(struct vsock_sock *vsk)
@@ -543,8 +611,13 @@ static int hvs_update_recv_data(struct hvsock *hvs)
 	if (payload_len > HVS_MTU_SIZE)
 		return -EIO;
 
-	if (payload_len == 0)
+	/* A zero-length packet means the peer shut down its send side */
+	if (payload_len == 0) {
+		struct sock *sk = sk_vsock(hvs->vsk);
+
 		hvs->vsk->peer_shutdown |= SEND_SHUTDOWN;
+		sk->sk_state_change(sk);
+	}
 
 	hvs->recv_data_len = payload_len;
 	hvs->recv_data_off = 0;
-- 
2.7.4