[PATCH v2 3/5] VSOCK: support receive mergeable rx buffer in guest

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



The guest can now receive mergeable rx buffers: it merges
the scattered rx buffers of a packet into one big buffer
and then copies it to user space.

In addition, this patch replaces buf in struct
virtio_vsock_pkt with an iovec-style array (struct kvec),
keeping tx and rx consistent. The only difference is that
tx still uses a single segment of physically contiguous
memory.

Signed-off-by: Yiwen Jiang <jiangyiwen@xxxxxxxxxx>
---
 drivers/vhost/vsock.c                   |  31 +++++++---
 include/linux/virtio_vsock.h            |   6 +-
 net/vmw_vsock/virtio_transport.c        | 105 ++++++++++++++++++++++++++++----
 net/vmw_vsock/virtio_transport_common.c |  59 ++++++++++++++----
 4 files changed, 166 insertions(+), 35 deletions(-)

diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index dc52b0f..c7ab0dd 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -179,6 +179,8 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
 		size_t nbytes;
 		size_t len;
 		s16 headcount;
+		size_t remain_len;
+		int i;

 		spin_lock_bh(&vsock->send_pkt_list_lock);
 		if (list_empty(&vsock->send_pkt_list)) {
@@ -221,11 +223,19 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
 			break;
 		}

-		nbytes = copy_to_iter(pkt->buf, pkt->len, &iov_iter);
-		if (nbytes != pkt->len) {
-			virtio_transport_free_pkt(pkt);
-			vq_err(vq, "Faulted on copying pkt buf\n");
-			break;
+		remain_len = pkt->len;
+		for (i = 0; i < pkt->nr_vecs; i++) {
+			int tmp_len;
+
+			tmp_len = min(remain_len, pkt->vec[i].iov_len);
+			nbytes = copy_to_iter(pkt->vec[i].iov_base, tmp_len, &iov_iter);
+			if (nbytes != tmp_len) {
+				virtio_transport_free_pkt(pkt);
+				vq_err(vq, "Faulted on copying pkt buf\n");
+				break;
+			}
+
+			remain_len -= tmp_len;
 		}

 		vhost_add_used_n(vq, vq->heads, headcount);
@@ -341,6 +351,7 @@ static void vhost_transport_send_pkt_work(struct vhost_work *work)
 	struct iov_iter iov_iter;
 	size_t nbytes;
 	size_t len;
+	void *buf;

 	if (in != 0) {
 		vq_err(vq, "Expected 0 input buffers, got %u\n", in);
@@ -375,13 +386,17 @@ static void vhost_transport_send_pkt_work(struct vhost_work *work)
 		return NULL;
 	}

-	pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
-	if (!pkt->buf) {
+	buf = kmalloc(pkt->len, GFP_KERNEL);
+	if (!buf) {
 		kfree(pkt);
 		return NULL;
 	}

-	nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
+	pkt->vec[0].iov_base = buf;
+	pkt->vec[0].iov_len = pkt->len;
+	pkt->nr_vecs = 1;
+
+	nbytes = copy_from_iter(buf, pkt->len, &iov_iter);
 	if (nbytes != pkt->len) {
 		vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
 		       pkt->len, nbytes);
diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
index da9e1fe..734eeed 100644
--- a/include/linux/virtio_vsock.h
+++ b/include/linux/virtio_vsock.h
@@ -13,6 +13,8 @@
 #define VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE	(1024 * 4)
 #define VIRTIO_VSOCK_MAX_BUF_SIZE		0xFFFFFFFFUL
 #define VIRTIO_VSOCK_MAX_PKT_BUF_SIZE		(1024 * 64)
+/* virtio_vsock_pkt + max_pkt_len(default MAX_PKT_BUF_SIZE) */
+#define VIRTIO_VSOCK_MAX_VEC_NUM ((VIRTIO_VSOCK_MAX_PKT_BUF_SIZE / PAGE_SIZE) + 1)

 /* Virtio-vsock feature */
 #define VIRTIO_VSOCK_F_MRG_RXBUF 0 /* Host can merge receive buffers. */
@@ -55,10 +57,12 @@ struct virtio_vsock_pkt {
 	struct list_head list;
 	/* socket refcnt not held, only use for cancellation */
 	struct vsock_sock *vsk;
-	void *buf;
+	struct kvec vec[VIRTIO_VSOCK_MAX_VEC_NUM];
+	int nr_vecs;
 	u32 len;
 	u32 off;
 	bool reply;
+	bool mergeable;
 };

 struct virtio_vsock_pkt_info {
diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
index c4a465c..148b58a 100644
--- a/net/vmw_vsock/virtio_transport.c
+++ b/net/vmw_vsock/virtio_transport.c
@@ -155,8 +155,10 @@ static int virtio_transport_send_pkt_loopback(struct virtio_vsock *vsock,

 		sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
 		sgs[out_sg++] = &hdr;
-		if (pkt->buf) {
-			sg_init_one(&buf, pkt->buf, pkt->len);
+		if (pkt->len) {
+			/* Currently only support a segment of memory in tx */
+			BUG_ON(pkt->vec[0].iov_len != pkt->len);
+			sg_init_one(&buf, pkt->vec[0].iov_base, pkt->vec[0].iov_len);
 			sgs[out_sg++] = &buf;
 		}

@@ -304,23 +306,28 @@ static int fill_old_rx_buff(struct virtqueue *vq)
 	struct virtio_vsock_pkt *pkt;
 	struct scatterlist hdr, buf, *sgs[2];
 	int ret;
+	void *pkt_buf;

 	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
 	if (!pkt)
 		return -ENOMEM;

-	pkt->buf = kmalloc(buf_len, GFP_KERNEL);
-	if (!pkt->buf) {
+	pkt_buf = kmalloc(buf_len, GFP_KERNEL);
+	if (!pkt_buf) {
 		virtio_transport_free_pkt(pkt);
 		return -ENOMEM;
 	}

+	pkt->vec[0].iov_base = pkt_buf;
+	pkt->vec[0].iov_len = buf_len;
+	pkt->nr_vecs = 1;
+
 	pkt->len = buf_len;

 	sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
 	sgs[0] = &hdr;

-	sg_init_one(&buf, pkt->buf, buf_len);
+	sg_init_one(&buf, pkt->vec[0].iov_base, buf_len);
 	sgs[1] = &buf;
 	ret = virtqueue_add_sgs(vq, sgs, 0, 2, pkt, GFP_KERNEL);
 	if (ret)
@@ -388,11 +395,78 @@ static bool virtio_transport_more_replies(struct virtio_vsock *vsock)
 	return val < virtqueue_get_vring_size(vq);
 }

+static struct virtio_vsock_pkt *receive_mergeable(struct virtqueue *vq,
+		struct virtio_vsock *vsock, unsigned int *total_len)
+{
+	struct virtio_vsock_pkt *pkt;
+	u16 num_buf;
+	void *buf;
+	unsigned int len;
+	size_t vsock_hlen = sizeof(struct virtio_vsock_pkt);
+
+	buf = virtqueue_get_buf(vq, &len);
+	if (!buf)
+		return NULL;
+
+	*total_len = len;
+	vsock->rx_buf_nr--;
+
+	if (unlikely(len < vsock_hlen)) {
+		put_page(virt_to_head_page(buf));
+		return NULL;
+	}
+
+	pkt = buf;
+	num_buf = le16_to_cpu(pkt->mrg_rxbuf_hdr.num_buffers);
+	if (!num_buf || num_buf > VIRTIO_VSOCK_MAX_VEC_NUM) {
+		put_page(virt_to_head_page(buf));
+		return NULL;
+	}
+
+	/* Initialize pkt residual structure */
+	memset(&pkt->work, 0, vsock_hlen - sizeof(struct virtio_vsock_hdr) -
+			sizeof(struct virtio_vsock_mrg_rxbuf_hdr));
+
+	pkt->mergeable = true;
+	pkt->len = le32_to_cpu(pkt->hdr.len);
+	if (!pkt->len)
+		return pkt;
+
+	len -= vsock_hlen;
+	if (len) {
+		pkt->vec[pkt->nr_vecs].iov_base = buf + vsock_hlen;
+		pkt->vec[pkt->nr_vecs].iov_len = len;
+		/* Shared page with pkt, so get page in advance */
+		get_page(virt_to_head_page(buf));
+		pkt->nr_vecs++;
+	}
+
+	while (--num_buf) {
+		buf = virtqueue_get_buf(vq, &len);
+		if (!buf)
+			goto err;
+
+		*total_len += len;
+		vsock->rx_buf_nr--;
+
+		pkt->vec[pkt->nr_vecs].iov_base = buf;
+		pkt->vec[pkt->nr_vecs].iov_len = len;
+		pkt->nr_vecs++;
+	}
+
+	return pkt;
+err:
+	virtio_transport_free_pkt(pkt);
+	return NULL;
+}
+
 static void virtio_transport_rx_work(struct work_struct *work)
 {
 	struct virtio_vsock *vsock =
 		container_of(work, struct virtio_vsock, rx_work);
 	struct virtqueue *vq;
+	size_t vsock_hlen = vsock->mergeable ? sizeof(struct virtio_vsock_pkt) :
+			sizeof(struct virtio_vsock_hdr);

 	vq = vsock->vqs[VSOCK_VQ_RX];

@@ -412,21 +486,26 @@ static void virtio_transport_rx_work(struct work_struct *work)
 				goto out;
 			}

-			pkt = virtqueue_get_buf(vq, &len);
-			if (!pkt) {
-				break;
-			}
+			if (likely(vsock->mergeable)) {
+				pkt = receive_mergeable(vq, vsock, &len);
+				if (!pkt)
+					break;
+			} else {
+				pkt = virtqueue_get_buf(vq, &len);
+				if (!pkt)
+					break;

-			vsock->rx_buf_nr--;
+				vsock->rx_buf_nr--;
+			}

 			/* Drop short/long packets */
-			if (unlikely(len < sizeof(pkt->hdr) ||
-				     len > sizeof(pkt->hdr) + pkt->len)) {
+			if (unlikely(len < vsock_hlen ||
+				     len > vsock_hlen + pkt->len)) {
 				virtio_transport_free_pkt(pkt);
 				continue;
 			}

-			pkt->len = len - sizeof(pkt->hdr);
+			pkt->len = len - vsock_hlen;
 			virtio_transport_deliver_tap_pkt(pkt);
 			virtio_transport_recv_pkt(pkt);
 		}
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index 3ae3a33..123a8b6 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -44,6 +44,7 @@ static const struct virtio_transport *virtio_transport_get_ops(void)
 {
 	struct virtio_vsock_pkt *pkt;
 	int err;
+	void *buf = NULL;

 	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
 	if (!pkt)
@@ -62,12 +63,16 @@ static const struct virtio_transport *virtio_transport_get_ops(void)
 	pkt->vsk		= info->vsk;

 	if (info->msg && len > 0) {
-		pkt->buf = kmalloc(len, GFP_KERNEL);
-		if (!pkt->buf)
+		buf = kmalloc(len, GFP_KERNEL);
+		if (!buf)
 			goto out_pkt;
-		err = memcpy_from_msg(pkt->buf, info->msg, len);
+		err = memcpy_from_msg(buf, info->msg, len);
 		if (err)
 			goto out;
+
+		pkt->vec[0].iov_base = buf;
+		pkt->vec[0].iov_len = len;
+		pkt->nr_vecs = 1;
 	}

 	trace_virtio_transport_alloc_pkt(src_cid, src_port,
@@ -80,7 +85,7 @@ static const struct virtio_transport *virtio_transport_get_ops(void)
 	return pkt;

 out:
-	kfree(pkt->buf);
+	kfree(buf);
 out_pkt:
 	kfree(pkt);
 	return NULL;
@@ -92,6 +97,7 @@ static struct sk_buff *virtio_transport_build_skb(void *opaque)
 	struct virtio_vsock_pkt *pkt = opaque;
 	struct af_vsockmon_hdr *hdr;
 	struct sk_buff *skb;
+	int i;

 	skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + pkt->len,
 			GFP_ATOMIC);
@@ -134,7 +140,8 @@ static struct sk_buff *virtio_transport_build_skb(void *opaque)
 	skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));

 	if (pkt->len) {
-		skb_put_data(skb, pkt->buf, pkt->len);
+		for (i = 0; i < pkt->nr_vecs; i++)
+			skb_put_data(skb, pkt->vec[i].iov_base, pkt->vec[i].iov_len);
 	}

 	return skb;
@@ -260,6 +267,9 @@ static int virtio_transport_send_credit_update(struct vsock_sock *vsk,

 	spin_lock_bh(&vvs->rx_lock);
 	while (total < len && !list_empty(&vvs->rx_queue)) {
+		size_t copy_bytes, last_vec_total = 0, vec_off;
+		int i;
+
 		pkt = list_first_entry(&vvs->rx_queue,
 				       struct virtio_vsock_pkt, list);

@@ -272,14 +282,28 @@ static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
 		 */
 		spin_unlock_bh(&vvs->rx_lock);

-		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
-		if (err)
-			goto out;
+		for (i = 0; i < pkt->nr_vecs; i++) {
+			if (pkt->off > last_vec_total + pkt->vec[i].iov_len) {
+				last_vec_total += pkt->vec[i].iov_len;
+				continue;
+			}
+
+			vec_off = pkt->off - last_vec_total;
+			copy_bytes = min(pkt->vec[i].iov_len - vec_off, bytes);
+			err = memcpy_to_msg(msg, pkt->vec[i].iov_base + vec_off,
+					copy_bytes);
+			if (err)
+				goto out;
+
+			bytes -= copy_bytes;
+			pkt->off += copy_bytes;
+			total += copy_bytes;
+			last_vec_total += pkt->vec[i].iov_len;
+			if (!bytes)
+				break;
+		}

 		spin_lock_bh(&vvs->rx_lock);
-
-		total += bytes;
-		pkt->off += bytes;
 		if (pkt->off == pkt->len) {
 			virtio_transport_dec_rx_pkt(vvs, pkt);
 			list_del(&pkt->list);
@@ -1050,8 +1074,17 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)

 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
 {
-	kfree(pkt->buf);
-	kfree(pkt);
+	int i;
+
+	if (pkt->mergeable) {
+		for (i = 0; i < pkt->nr_vecs; i++)
+			put_page(virt_to_head_page(pkt->vec[i].iov_base));
+		put_page(virt_to_head_page((void *)pkt));
+	} else {
+		for (i = 0; i < pkt->nr_vecs; i++)
+			kfree(pkt->vec[i].iov_base);
+		kfree(pkt);
+	}
 }
 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);

-- 
1.8.3.1





[Index of Archives]     [KVM ARM]     [KVM ia64]     [KVM ppc]     [Virtualization Tools]     [Spice Development]     [Libvirt]     [Libvirt Users]     [Linux USB Devel]     [Linux Audio Users]     [Yosemite Questions]     [Linux Kernel]     [Linux SCSI]     [XFree86]

  Powered by Linux