When the guest receives mergeable rx buffers, it can merge the scattered rx buffers into one big buffer and then copy it to user space. In addition, this patch replaces buf in struct virtio_vsock_pkt with an iovec array (struct kvec), keeping tx and rx consistent. The only difference is that tx still uses a single segment of physically contiguous memory. Signed-off-by: Yiwen Jiang <jiangyiwen@xxxxxxxxxx> --- drivers/vhost/vsock.c | 31 +++++++--- include/linux/virtio_vsock.h | 6 +- net/vmw_vsock/virtio_transport.c | 105 ++++++++++++++++++++++++++++---- net/vmw_vsock/virtio_transport_common.c | 59 ++++++++++++++---- 4 files changed, 166 insertions(+), 35 deletions(-) diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c index dc52b0f..c7ab0dd 100644 --- a/drivers/vhost/vsock.c +++ b/drivers/vhost/vsock.c @@ -179,6 +179,8 @@ static int get_rx_bufs(struct vhost_virtqueue *vq, size_t nbytes; size_t len; s16 headcount; + size_t remain_len; + int i; spin_lock_bh(&vsock->send_pkt_list_lock); if (list_empty(&vsock->send_pkt_list)) { @@ -221,11 +223,19 @@ static int get_rx_bufs(struct vhost_virtqueue *vq, break; } - nbytes = copy_to_iter(pkt->buf, pkt->len, &iov_iter); - if (nbytes != pkt->len) { - virtio_transport_free_pkt(pkt); - vq_err(vq, "Faulted on copying pkt buf\n"); - break; + remain_len = pkt->len; + for (i = 0; i < pkt->nr_vecs; i++) { + int tmp_len; + + tmp_len = min(remain_len, pkt->vec[i].iov_len); + nbytes = copy_to_iter(pkt->vec[i].iov_base, tmp_len, &iov_iter); + if (nbytes != tmp_len) { + virtio_transport_free_pkt(pkt); + vq_err(vq, "Faulted on copying pkt buf\n"); + break; + } + + remain_len -= tmp_len; } vhost_add_used_n(vq, vq->heads, headcount); @@ -341,6 +351,7 @@ static void vhost_transport_send_pkt_work(struct vhost_work *work) struct iov_iter iov_iter; size_t nbytes; size_t len; + void *buf; if (in != 0) { vq_err(vq, "Expected 0 input buffers, got %u\n", in); @@ -375,13 +386,17 @@ static void vhost_transport_send_pkt_work(struct vhost_work *work) return NULL; } - pkt->buf = 
kmalloc(pkt->len, GFP_KERNEL); - if (!pkt->buf) { + buf = kmalloc(pkt->len, GFP_KERNEL); + if (!buf) { kfree(pkt); return NULL; } - nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter); + pkt->vec[0].iov_base = buf; + pkt->vec[0].iov_len = pkt->len; + pkt->nr_vecs = 1; + + nbytes = copy_from_iter(buf, pkt->len, &iov_iter); if (nbytes != pkt->len) { vq_err(vq, "Expected %u byte payload, got %zu bytes\n", pkt->len, nbytes); diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h index da9e1fe..734eeed 100644 --- a/include/linux/virtio_vsock.h +++ b/include/linux/virtio_vsock.h @@ -13,6 +13,8 @@ #define VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE (1024 * 4) #define VIRTIO_VSOCK_MAX_BUF_SIZE 0xFFFFFFFFUL #define VIRTIO_VSOCK_MAX_PKT_BUF_SIZE (1024 * 64) +/* virtio_vsock_pkt + max_pkt_len(default MAX_PKT_BUF_SIZE) */ +#define VIRTIO_VSOCK_MAX_VEC_NUM ((VIRTIO_VSOCK_MAX_PKT_BUF_SIZE / PAGE_SIZE) + 1) /* Virtio-vsock feature */ #define VIRTIO_VSOCK_F_MRG_RXBUF 0 /* Host can merge receive buffers. 
*/ @@ -55,10 +57,12 @@ struct virtio_vsock_pkt { struct list_head list; /* socket refcnt not held, only use for cancellation */ struct vsock_sock *vsk; - void *buf; + struct kvec vec[VIRTIO_VSOCK_MAX_VEC_NUM]; + int nr_vecs; u32 len; u32 off; bool reply; + bool mergeable; }; struct virtio_vsock_pkt_info { diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c index c4a465c..148b58a 100644 --- a/net/vmw_vsock/virtio_transport.c +++ b/net/vmw_vsock/virtio_transport.c @@ -155,8 +155,10 @@ static int virtio_transport_send_pkt_loopback(struct virtio_vsock *vsock, sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr)); sgs[out_sg++] = &hdr; - if (pkt->buf) { - sg_init_one(&buf, pkt->buf, pkt->len); + if (pkt->len) { + /* Currently only support a segment of memory in tx */ + BUG_ON(pkt->vec[0].iov_len != pkt->len); + sg_init_one(&buf, pkt->vec[0].iov_base, pkt->vec[0].iov_len); sgs[out_sg++] = &buf; } @@ -304,23 +306,28 @@ static int fill_old_rx_buff(struct virtqueue *vq) struct virtio_vsock_pkt *pkt; struct scatterlist hdr, buf, *sgs[2]; int ret; + void *pkt_buf; pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); if (!pkt) return -ENOMEM; - pkt->buf = kmalloc(buf_len, GFP_KERNEL); - if (!pkt->buf) { + pkt_buf = kmalloc(buf_len, GFP_KERNEL); + if (!pkt_buf) { virtio_transport_free_pkt(pkt); return -ENOMEM; } + pkt->vec[0].iov_base = pkt_buf; + pkt->vec[0].iov_len = buf_len; + pkt->nr_vecs = 1; + pkt->len = buf_len; sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr)); sgs[0] = &hdr; - sg_init_one(&buf, pkt->buf, buf_len); + sg_init_one(&buf, pkt->vec[0].iov_base, buf_len); sgs[1] = &buf; ret = virtqueue_add_sgs(vq, sgs, 0, 2, pkt, GFP_KERNEL); if (ret) @@ -388,11 +395,78 @@ static bool virtio_transport_more_replies(struct virtio_vsock *vsock) return val < virtqueue_get_vring_size(vq); } +static struct virtio_vsock_pkt *receive_mergeable(struct virtqueue *vq, + struct virtio_vsock *vsock, unsigned int *total_len) +{ + struct virtio_vsock_pkt *pkt; + u16 num_buf; + void 
*buf; + unsigned int len; + size_t vsock_hlen = sizeof(struct virtio_vsock_pkt); + + buf = virtqueue_get_buf(vq, &len); + if (!buf) + return NULL; + + *total_len = len; + vsock->rx_buf_nr--; + + if (unlikely(len < vsock_hlen)) { + put_page(virt_to_head_page(buf)); + return NULL; + } + + pkt = buf; + num_buf = le16_to_cpu(pkt->mrg_rxbuf_hdr.num_buffers); + if (!num_buf || num_buf > VIRTIO_VSOCK_MAX_VEC_NUM) { + put_page(virt_to_head_page(buf)); + return NULL; + } + + /* Initialize pkt residual structure */ + memset(&pkt->work, 0, vsock_hlen - sizeof(struct virtio_vsock_hdr) - + sizeof(struct virtio_vsock_mrg_rxbuf_hdr)); + + pkt->mergeable = true; + pkt->len = le32_to_cpu(pkt->hdr.len); + if (!pkt->len) + return pkt; + + len -= vsock_hlen; + if (len) { + pkt->vec[pkt->nr_vecs].iov_base = buf + vsock_hlen; + pkt->vec[pkt->nr_vecs].iov_len = len; + /* Shared page with pkt, so get page in advance */ + get_page(virt_to_head_page(buf)); + pkt->nr_vecs++; + } + + while (--num_buf) { + buf = virtqueue_get_buf(vq, &len); + if (!buf) + goto err; + + *total_len += len; + vsock->rx_buf_nr--; + + pkt->vec[pkt->nr_vecs].iov_base = buf; + pkt->vec[pkt->nr_vecs].iov_len = len; + pkt->nr_vecs++; + } + + return pkt; +err: + virtio_transport_free_pkt(pkt); + return NULL; +} + static void virtio_transport_rx_work(struct work_struct *work) { struct virtio_vsock *vsock = container_of(work, struct virtio_vsock, rx_work); struct virtqueue *vq; + size_t vsock_hlen = vsock->mergeable ? 
sizeof(struct virtio_vsock_pkt) : + sizeof(struct virtio_vsock_hdr); vq = vsock->vqs[VSOCK_VQ_RX]; @@ -412,21 +486,26 @@ static void virtio_transport_rx_work(struct work_struct *work) goto out; } - pkt = virtqueue_get_buf(vq, &len); - if (!pkt) { - break; - } + if (likely(vsock->mergeable)) { + pkt = receive_mergeable(vq, vsock, &len); + if (!pkt) + break; + } else { + pkt = virtqueue_get_buf(vq, &len); + if (!pkt) + break; - vsock->rx_buf_nr--; + vsock->rx_buf_nr--; + } /* Drop short/long packets */ - if (unlikely(len < sizeof(pkt->hdr) || - len > sizeof(pkt->hdr) + pkt->len)) { + if (unlikely(len < vsock_hlen || + len > vsock_hlen + pkt->len)) { virtio_transport_free_pkt(pkt); continue; } - pkt->len = len - sizeof(pkt->hdr); + pkt->len = len - vsock_hlen; virtio_transport_deliver_tap_pkt(pkt); virtio_transport_recv_pkt(pkt); } diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index 3ae3a33..123a8b6 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -44,6 +44,7 @@ static const struct virtio_transport *virtio_transport_get_ops(void) { struct virtio_vsock_pkt *pkt; int err; + void *buf = NULL; pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); if (!pkt) @@ -62,12 +63,16 @@ static const struct virtio_transport *virtio_transport_get_ops(void) pkt->vsk = info->vsk; if (info->msg && len > 0) { - pkt->buf = kmalloc(len, GFP_KERNEL); - if (!pkt->buf) + buf = kmalloc(len, GFP_KERNEL); + if (!buf) goto out_pkt; - err = memcpy_from_msg(pkt->buf, info->msg, len); + err = memcpy_from_msg(buf, info->msg, len); if (err) goto out; + + pkt->vec[0].iov_base = buf; + pkt->vec[0].iov_len = len; + pkt->nr_vecs = 1; } trace_virtio_transport_alloc_pkt(src_cid, src_port, @@ -80,7 +85,7 @@ static const struct virtio_transport *virtio_transport_get_ops(void) return pkt; out: - kfree(pkt->buf); + kfree(buf); out_pkt: kfree(pkt); return NULL; @@ -92,6 +97,7 @@ static struct sk_buff 
*virtio_transport_build_skb(void *opaque) struct virtio_vsock_pkt *pkt = opaque; struct af_vsockmon_hdr *hdr; struct sk_buff *skb; + int i; skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + pkt->len, GFP_ATOMIC); @@ -134,7 +140,8 @@ static struct sk_buff *virtio_transport_build_skb(void *opaque) skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr)); if (pkt->len) { - skb_put_data(skb, pkt->buf, pkt->len); + for (i = 0; i < pkt->nr_vecs; i++) + skb_put_data(skb, pkt->vec[i].iov_base, pkt->vec[i].iov_len); } return skb; @@ -260,6 +267,9 @@ static int virtio_transport_send_credit_update(struct vsock_sock *vsk, spin_lock_bh(&vvs->rx_lock); while (total < len && !list_empty(&vvs->rx_queue)) { + size_t copy_bytes, last_vec_total = 0, vec_off; + int i; + pkt = list_first_entry(&vvs->rx_queue, struct virtio_vsock_pkt, list); @@ -272,14 +282,28 @@ static int virtio_transport_send_credit_update(struct vsock_sock *vsk, */ spin_unlock_bh(&vvs->rx_lock); - err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes); - if (err) - goto out; + for (i = 0; i < pkt->nr_vecs; i++) { + if (pkt->off > last_vec_total + pkt->vec[i].iov_len) { + last_vec_total += pkt->vec[i].iov_len; + continue; + } + + vec_off = pkt->off - last_vec_total; + copy_bytes = min(pkt->vec[i].iov_len - vec_off, bytes); + err = memcpy_to_msg(msg, pkt->vec[i].iov_base + vec_off, + copy_bytes); + if (err) + goto out; + + bytes -= copy_bytes; + pkt->off += copy_bytes; + total += copy_bytes; + last_vec_total += pkt->vec[i].iov_len; + if (!bytes) + break; + } spin_lock_bh(&vvs->rx_lock); - - total += bytes; - pkt->off += bytes; if (pkt->off == pkt->len) { virtio_transport_dec_rx_pkt(vvs, pkt); list_del(&pkt->list); @@ -1050,8 +1074,17 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt) { - kfree(pkt->buf); - kfree(pkt); + int i; + + if (pkt->mergeable) { + for (i = 0; i < pkt->nr_vecs; i++) + put_page(virt_to_head_page(pkt->vec[i].iov_base)); + 
put_page(virt_to_head_page((void *)pkt)); + } else { + for (i = 0; i < pkt->nr_vecs; i++) + kfree(pkt->vec[i].iov_base); + kfree(pkt); + } } EXPORT_SYMBOL_GPL(virtio_transport_free_pkt); -- 1.8.3.1