[PATCH 3/3] vhost: device IOTLB API

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



This patch tries to implement a device IOTLB for vhost. This could be
used in co-operation with a userspace (qemu) implementation of DMA
remapping.

The idea is simple: cache the translations in a software device IOTLB
(implemented as an interval tree) in vhost, and use the vhost_net
file descriptor for reporting IOTLB misses and for IOTLB
updates/invalidations. When vhost meets an IOTLB miss, the fault
address, size and access can be read from the file. After userspace
finishes the translation, it writes the translated address to the
vhost_net file to update the device IOTLB.

When the device IOTLB (VHOST_F_DEVICE_IOTLB) is enabled, all vq addresses
set by ioctl are treated as iova instead of virtual addresses, and they
can only be accessed through the IOTLB instead of by direct
userspace memory access. Before each round of vq processing, all vq
metadata is prefetched in the device IOTLB to make sure no translation
fault happens during vq processing.

In most cases, virtqueues are mapped contiguously even in virtual
address space. So the IOTLB translation for the virtqueue itself may be
a little bit slower. We can add a fast path on top of this patch.

Signed-off-by: Jason Wang <jasowang@xxxxxxxxxx>
---
 drivers/vhost/net.c        |  54 +++-
 drivers/vhost/vhost.c      | 627 ++++++++++++++++++++++++++++++++++++++++++---
 drivers/vhost/vhost.h      |  35 ++-
 include/uapi/linux/vhost.h |  28 ++
 4 files changed, 696 insertions(+), 48 deletions(-)

diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 7641543..7ceea39 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -61,7 +61,8 @@ MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
 enum {
 	VHOST_NET_FEATURES = VHOST_FEATURES |
 			 (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
-			 (1ULL << VIRTIO_NET_F_MRG_RXBUF)
+			 (1ULL << VIRTIO_NET_F_MRG_RXBUF) |
+			 (1ULL << VHOST_F_DEVICE_IOTLB)
 };
 
 enum {
@@ -334,7 +335,8 @@ static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
 {
 	unsigned long uninitialized_var(endtime);
 	int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
-				    out_num, in_num, NULL, NULL);
+				  out_num, in_num, NULL, NULL,
+				  VHOST_ACCESS_RO);
 
 	if (r == vq->num && vq->busyloop_timeout) {
 		preempt_disable();
@@ -344,7 +346,8 @@ static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
 			cpu_relax_lowlatency();
 		preempt_enable();
 		r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
-					out_num, in_num, NULL, NULL);
+				      out_num, in_num, NULL, NULL,
+				      VHOST_ACCESS_RO);
 	}
 
 	return r;
@@ -377,6 +380,9 @@ static void handle_tx(struct vhost_net *net)
 	if (!sock)
 		goto out;
 
+	if (!vq_iotlb_prefetch(vq))
+		goto out;
+
 	vhost_disable_notify(&net->dev, vq);
 
 	hdr_size = nvq->vhost_hlen;
@@ -564,7 +570,7 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
 		}
 		r = vhost_get_vq_desc(vq, vq->iov + seg,
 				      ARRAY_SIZE(vq->iov) - seg, &out,
-				      &in, log, log_num);
+				      &in, log, log_num, VHOST_ACCESS_WO);
 		if (unlikely(r < 0))
 			goto err;
 
@@ -638,6 +644,10 @@ static void handle_rx(struct vhost_net *net)
 	sock = vq->private_data;
 	if (!sock)
 		goto out;
+
+	if (!vq_iotlb_prefetch(vq))
+		goto out;
+
 	vhost_disable_notify(&net->dev, vq);
 	vhost_net_disable_vq(net, vq);
 
@@ -1087,6 +1097,11 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
 		mutex_unlock(&n->dev.mutex);
 		return -EFAULT;
 	}
+	if ((features & (1ULL << VHOST_F_DEVICE_IOTLB)) &&
+	    vhost_init_device_iotlb(&n->dev, true)) {
+		mutex_unlock(&n->dev.mutex); /* don't return with it held */
+		return -EFAULT;
+	}
 	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
 		mutex_lock(&n->vqs[i].vq.mutex);
 		n->vqs[i].vq.acked_features = features;
@@ -1169,9 +1184,40 @@ static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl,
 }
 #endif
 
+static ssize_t vhost_net_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
+{
+	struct file *file = iocb->ki_filp;
+	struct vhost_net *n = file->private_data;
+	struct vhost_dev *dev = &n->dev;
+	int noblock = file->f_flags & O_NONBLOCK;
+
+	return vhost_chr_read_iter(dev, to, noblock);
+}
+
+static ssize_t vhost_net_chr_write_iter(struct kiocb *iocb,
+					struct iov_iter *from)
+{
+	struct file *file = iocb->ki_filp;
+	struct vhost_net *n = file->private_data;
+	struct vhost_dev *dev = &n->dev;
+
+	return vhost_chr_write_iter(dev, from);
+}
+
+static unsigned int vhost_net_chr_poll(struct file *file, poll_table *wait)
+{
+	struct vhost_net *n = file->private_data;
+	struct vhost_dev *dev = &n->dev;
+
+	return vhost_chr_poll(file, dev, wait);
+}
+
 static const struct file_operations vhost_net_fops = {
 	.owner          = THIS_MODULE,
 	.release        = vhost_net_release,
+	.read_iter      = vhost_net_chr_read_iter,
+	.write_iter     = vhost_net_chr_write_iter,
+	.poll           = vhost_net_chr_poll,
 	.unlocked_ioctl = vhost_net_ioctl,
 #ifdef CONFIG_COMPAT
 	.compat_ioctl   = vhost_net_compat_ioctl,
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 166e779..46569fb 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -35,6 +35,10 @@ static ushort max_mem_regions = 64;
 module_param(max_mem_regions, ushort, 0444);
 MODULE_PARM_DESC(max_mem_regions,
 	"Maximum number of memory regions in memory map. (default: 64)");
+static int max_iotlb_entries = 2048;
+module_param(max_iotlb_entries, int, 0444);
+MODULE_PARM_DESC(max_iotlb_entries,
+	"Maximum number of iotlb entries. (default: 2048)");
 
 enum {
 	VHOST_MEMORY_F_LOG = 0x1,
@@ -309,6 +313,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vhost_disable_cross_endian(vq);
 	vq->busyloop_timeout = 0;
 	vq->umem = NULL;
+	vq->iotlb = NULL;
 }
 
 static int vhost_worker(void *data)
@@ -413,9 +418,14 @@ void vhost_dev_init(struct vhost_dev *dev,
 	dev->log_ctx = NULL;
 	dev->log_file = NULL;
 	dev->umem = NULL;
+	dev->iotlb = NULL;
 	dev->mm = NULL;
 	spin_lock_init(&dev->work_lock);
 	INIT_LIST_HEAD(&dev->work_list);
+	init_waitqueue_head(&dev->wait);
+	INIT_LIST_HEAD(&dev->read_list);
+	INIT_LIST_HEAD(&dev->pending_list);
+	spin_lock_init(&dev->iotlb_lock);
 	dev->worker = NULL;
 
 	for (i = 0; i < dev->nvqs; ++i) {
@@ -563,6 +573,15 @@ void vhost_dev_stop(struct vhost_dev *dev)
 }
 EXPORT_SYMBOL_GPL(vhost_dev_stop);
 
+static void vhost_umem_free(struct vhost_umem *umem,
+			    struct vhost_umem_node *node)
+{
+	vhost_umem_interval_tree_remove(node, &umem->umem_tree);
+	list_del(&node->link);
+	kfree(node);
+	umem->numem--;
+}
+
 static void vhost_umem_clean(struct vhost_umem *umem)
 {
 	struct vhost_umem_node *node, *tmp;
@@ -570,14 +589,31 @@ static void vhost_umem_clean(struct vhost_umem *umem)
 	if (!umem)
 		return;
 
-	list_for_each_entry_safe(node, tmp, &umem->umem_list, link) {
-		vhost_umem_interval_tree_remove(node, &umem->umem_tree);
-		list_del(&node->link);
-		kvfree(node);
-	}
+	list_for_each_entry_safe(node, tmp, &umem->umem_list, link)
+		vhost_umem_free(umem, node);
+
 	kvfree(umem);
 }
 
+static void vhost_clear_msg(struct vhost_dev *dev)
+{
+	struct vhost_msg_node *node, *n;
+
+	spin_lock(&dev->iotlb_lock);
+
+	list_for_each_entry_safe(node, n, &dev->read_list, node) {
+		list_del(&node->node);
+		kfree(node);
+	}
+
+	list_for_each_entry_safe(node, n, &dev->pending_list, node) {
+		list_del(&node->node);
+		kfree(node);
+	}
+
+	spin_unlock(&dev->iotlb_lock);
+}
+
 /* Caller should have device mutex if and only if locked is set */
 void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 {
@@ -606,6 +642,10 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 	/* No one will access memory at this point */
 	vhost_umem_clean(dev->umem);
 	dev->umem = NULL;
+	vhost_umem_clean(dev->iotlb);
+	dev->iotlb = NULL;
+	vhost_clear_msg(dev);
+	wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
 	WARN_ON(!list_empty(&dev->work_list));
 	if (dev->worker) {
 		kthread_stop(dev->worker);
@@ -681,28 +721,382 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
 	return 1;
 }
 
-#define vhost_put_user(vq, x, ptr)  __put_user(x, ptr)
+static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
+			  struct iovec iov[], int iov_size, int access);
 
 static int vhost_copy_to_user(struct vhost_virtqueue *vq, void *to,
 			      const void *from, unsigned size)
 {
-	return __copy_to_user(to, from, size);
-}
+	int ret;
 
-#define vhost_get_user(vq, x, ptr) __get_user(x, ptr)
+	if (!vq->iotlb)
+		return __copy_to_user(to, from, size);
+	else {
+		/* This function should be called after iotlb
+		 * prefetch, which means we're sure that all vq
+		 * could be access through iotlb. So -EAGAIN should
+		 * not happen in this case.
+		 */
+		/* TODO: more fast path */
+		struct iov_iter t;
+		ret = translate_desc(vq, (u64)to, size, vq->iotlb_iov,
+				     ARRAY_SIZE(vq->iotlb_iov),
+				     VHOST_ACCESS_WO);
+		if (ret < 0)
+			goto out;
+		iov_iter_init(&t, WRITE, vq->iotlb_iov, ret, size);
+		ret = copy_to_iter(from, size, &t);
+		if (ret == size)
+			ret = 0;
+	}
+out:
+	return ret;
+}
 
 static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
 				void *from, unsigned size)
 {
-	return __copy_from_user(to, from, size);
+	int ret;
+
+	if (!vq->iotlb)
+		return __copy_from_user(to, from, size);
+	else {
+		/* This function should be called after iotlb
+		 * prefetch, which means we're sure that vq
+		 * could be access through iotlb. So -EAGAIN should
+		 * not happen in this case.
+		 */
+		/* TODO: more fast path */
+		struct iov_iter f;
+		ret = translate_desc(vq, (u64)from, size, vq->iotlb_iov,
+				     ARRAY_SIZE(vq->iotlb_iov),
+				     VHOST_ACCESS_RO);
+		if (ret < 0) {
+			vq_err(vq, "IOTLB translation failure: uaddr "
+			       "0x%llx size 0x%llx\n",
+			       (unsigned long long) from,
+			       (unsigned long long) size);
+			goto out;
+		}
+		iov_iter_init(&f, READ, vq->iotlb_iov, ret, size);
+		ret = copy_from_iter(to, size, &f);
+		if (ret == size)
+			ret = 0;
+	}
+
+out:
+	return ret;
+}
+
+static void __user *__vhost_get_user(struct vhost_virtqueue *vq,
+				     void *addr, unsigned size)
+{
+	int ret;
+
+	/* This function should be called after iotlb
+	 * prefetch, which means we're sure that vq
+	 * could be access through iotlb. So -EAGAIN should
+	 * not happen in this case.
+	 */
+	/* TODO: more fast path */
+	ret = translate_desc(vq, (u64)addr, size, vq->iotlb_iov,
+			     ARRAY_SIZE(vq->iotlb_iov),
+			     VHOST_ACCESS_RO);
+	if (ret < 0) {
+		vq_err(vq, "IOTLB translation failure: uaddr "
+			"0x%llx size 0x%llx\n",
+			(unsigned long long) addr,
+			(unsigned long long) size);
+		return NULL;
+	}
+
+	if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
+		vq_err(vq, "Non atomic userspace memory access: uaddr "
+			"0x%llx size 0x%llx\n",
+			(unsigned long long) addr,
+			(unsigned long long) size);
+		return NULL;
+	}
+
+	return vq->iotlb_iov[0].iov_base;
+}
+
+#define vhost_put_user(vq, x, ptr) \
+({ \
+	int ret = -EFAULT; \
+	if (!vq->iotlb) { \
+		ret = __put_user(x, ptr); \
+	} else { \
+		__typeof__(ptr) to = \
+			(__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
+		if (to != NULL) \
+			ret = __put_user(x, to); \
+		else \
+			ret = -EFAULT;	\
+	} \
+	ret; \
+})
+
+#define vhost_get_user(vq, x, ptr) \
+({ \
+	int ret; \
+	if (!vq->iotlb) { \
+		ret = __get_user(x, ptr); \
+	} else { \
+		__typeof__(ptr) from = \
+			(__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
+		if (from != NULL) \
+			ret = __get_user(x, from); \
+		else \
+			ret = -EFAULT; \
+	} \
+	ret; \
+})
+
+static void vhost_dev_lock_vqs(struct vhost_dev *d)
+{
+	int i = 0;
+	for (i = 0; i < d->nvqs; ++i)
+		mutex_lock(&d->vqs[i]->mutex);
+}
+
+static void vhost_dev_unlock_vqs(struct vhost_dev *d)
+{
+	int i = 0;
+	for (i = 0; i < d->nvqs; ++i)
+		mutex_unlock(&d->vqs[i]->mutex);
+}
+
+static int vhost_new_umem_range(struct vhost_umem *umem,
+				u64 start, u64 size, u64 end,
+				u64 userspace_addr, int perm)
+{
+	struct vhost_umem_node *tmp, *node = kmalloc(sizeof(*node), GFP_ATOMIC);
+
+	if (!node)
+		return -ENOMEM;
+
+	if (umem->numem == max_iotlb_entries) {
+		tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link);
+		vhost_umem_free(umem, tmp);
+	}
+
+	node->start = start;
+	node->size = size;
+	node->last = end;
+	node->userspace_addr = userspace_addr;
+	node->perm = perm;
+	INIT_LIST_HEAD(&node->link);
+	list_add_tail(&node->link, &umem->umem_list);
+	vhost_umem_interval_tree_insert(node, &umem->umem_tree);
+	umem->numem++;
+
+	return 0;
+}
+
+static void vhost_del_umem_range(struct vhost_umem *umem,
+				 u64 start, u64 end)
+{
+	struct vhost_umem_node *node;
+
+	while ((node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
+							   start, end)))
+		vhost_umem_free(umem, node);
+}
+
+/* Kick each vq whose pending IOTLB miss is covered by the update in @msg. */
+static void vhost_iotlb_notify_vq(struct vhost_dev *d,
+				  struct vhost_iotlb_msg *msg)
+{
+	struct vhost_msg_node *node, *n;
+
+	spin_lock(&d->iotlb_lock);
+	list_for_each_entry_safe(node, n, &d->pending_list, node) {
+		struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
+		/* Range is [iova, iova + size - 1], last byte included. */
+		if (msg->iova <= vq_msg->iova &&
+		    msg->iova + msg->size - 1 >= vq_msg->iova &&
+		    vq_msg->type == VHOST_IOTLB_MISS) {
+			vhost_poll_queue(&node->vq->poll);
+			list_del(&node->node);
+			kfree(node);
+		}
+	}
+	spin_unlock(&d->iotlb_lock);
+}
+
+static int umem_access_ok(u64 uaddr, u64 size, int access)
+{
+	if ((access & VHOST_ACCESS_RO) &&
+	    !access_ok(VERIFY_READ, uaddr, size))
+		return -EFAULT;
+	if ((access & VHOST_ACCESS_WO) &&
+	    !access_ok(VERIFY_WRITE, uaddr, size))
+		return -EFAULT;
+	return 0;
+}
+
+/* Apply an UPDATE/INVALIDATE message while holding every vq mutex. */
+int vhost_process_iotlb_msg(struct vhost_dev *dev,
+			    struct vhost_iotlb_msg *msg)
+{
+	int ret = 0;
+
+	vhost_dev_lock_vqs(dev);
+	switch (msg->type) {
+	case VHOST_IOTLB_UPDATE:
+		if (!dev->iotlb ||
+		    umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
+			ret = -EFAULT;
+			break;
+		}
+		if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size,
+					 msg->iova + msg->size - 1,
+					 msg->uaddr, msg->perm)) {
+			ret = -ENOMEM;
+			break;
+		}
+		vhost_iotlb_notify_vq(dev, msg);
+		break;
+	case VHOST_IOTLB_INVALIDATE: /* dev->iotlb is NULL until enabled */
+		if (dev->iotlb)
+			vhost_del_umem_range(dev->iotlb, msg->iova,
+					     msg->iova + msg->size - 1);
+		else
+			ret = -EFAULT;
+		break;
+	default:
+		ret = -EINVAL;
+		break;
+	}
+	vhost_dev_unlock_vqs(dev);
+	return ret;
+}
+ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
+			     struct iov_iter *from)
+{
+	struct vhost_msg_node node;
+	unsigned size = sizeof(struct vhost_msg);
+	size_t ret;
+	int err;
+
+	if (iov_iter_count(from) < size)
+		return 0;
+	ret = copy_from_iter(&node.msg, size, from);
+	if (ret != size)
+		goto done;
+
+	switch (node.msg.type) {
+	case VHOST_IOTLB_MSG:
+		err = vhost_process_iotlb_msg(dev, &node.msg.iotlb);
+		if (err)
+			ret = err;
+		break;
+	default:
+		ret = -EINVAL;
+		break;
+	}
+
+done:
+	return ret;
+}
+EXPORT_SYMBOL(vhost_chr_write_iter);
+
+unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev,
+			    poll_table *wait)
+{
+	unsigned int mask = 0;
+
+	poll_wait(file, &dev->wait, wait);
+
+	if (!list_empty(&dev->read_list))
+		mask |= POLLIN | POLLRDNORM;
+
+	return mask;
+}
+EXPORT_SYMBOL(vhost_chr_poll);
+
+ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
+			    int noblock)
+{
+	DEFINE_WAIT(wait);
+	struct vhost_msg_node *node;
+	ssize_t ret = 0;
+	unsigned size = sizeof(struct vhost_msg);
+
+	if (iov_iter_count(to) < size)
+		return 0;
+
+	while (1) {
+		if (!noblock)
+			prepare_to_wait(&dev->wait, &wait,
+					TASK_INTERRUPTIBLE);
+
+		node = vhost_dequeue_msg(dev, &dev->read_list);
+		if (node)
+			break;
+		if (noblock) {
+			ret = -EAGAIN;
+			break;
+		}
+		if (signal_pending(current)) {
+			ret = -ERESTARTSYS;
+			break;
+		}
+		if (!dev->iotlb) {
+			ret = -EBADFD;
+			break;
+		}
+
+		schedule();
+	}
+
+	if (!noblock)
+		finish_wait(&dev->wait, &wait);
+
+	if (node) {
+		ret = copy_to_iter(&node->msg, size, to);
+
+		if (ret != size || node->msg.type != VHOST_IOTLB_MISS) {
+			kfree(node);
+			return ret;
+		}
+
+		vhost_enqueue_msg(dev, &dev->pending_list, node);
+	}
+
+	return ret;
+}
+EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
+
+static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
+{
+	struct vhost_dev *dev = vq->dev;
+	struct vhost_msg_node *node;
+	struct vhost_iotlb_msg *msg;
+
+	node = vhost_new_msg(vq, VHOST_IOTLB_MISS);
+	if (!node)
+		return -ENOMEM;
+
+	msg = &node->msg.iotlb;
+	msg->type = VHOST_IOTLB_MISS;
+	msg->iova = iova;
+	msg->perm = access;
+
+	vhost_enqueue_msg(dev, &dev->read_list, node);
+
+	return 0;
 }
 
 static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
 			struct vring_desc __user *desc,
 			struct vring_avail __user *avail,
 			struct vring_used __user *used)
+
 {
 	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
+
 	return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
 	       access_ok(VERIFY_READ, avail,
 			 sizeof *avail + num * sizeof *avail->ring + s) &&
@@ -710,6 +1104,54 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
 			sizeof *used + num * sizeof *used->ring + s);
 }
 
+static int iotlb_access_ok(struct vhost_virtqueue *vq,
+			   int access, u64 addr, u64 len)
+{
+	const struct vhost_umem_node *node;
+	struct vhost_umem *umem = vq->iotlb;
+	u64 s = 0, size;
+
+	while (len > s) {
+		node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
+							   addr,
+							   addr + len - 1);
+		if (node == NULL || node->start > addr) {
+			vhost_iotlb_miss(vq, addr, access);
+			return false;
+		} else if (!(node->perm & access)) {
+			/* Report the possible access violation by
+			 * request another translation from userspace.
+			 */
+			return false;
+		}
+
+		size = node->size - addr + node->start;
+		s += size;
+		addr += size;
+	}
+
+	return true;
+}
+
+int vq_iotlb_prefetch(struct vhost_virtqueue *vq)
+{
+	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
+	unsigned int num = vq->num;
+
+	if (!vq->iotlb)
+		return 1;
+
+	return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)vq->desc,
+			       num * sizeof *vq->desc) &&
+	       iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)vq->avail,
+			       sizeof *vq->avail +
+			       num * sizeof *vq->avail->ring + s) &&
+	       iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)vq->used,
+			       sizeof *vq->used +
+			       num * sizeof *vq->used->ring + s);
+}
+EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
+
 /* Can we log writes? */
 /* Caller should have device mutex but not vq mutex */
 int vhost_log_access_ok(struct vhost_dev *dev)
@@ -736,16 +1178,35 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
 /* Caller should have vq mutex and device mutex */
 int vhost_vq_access_ok(struct vhost_virtqueue *vq)
 {
+	if (vq->iotlb) {
+		/* When device IOTLB was used, the access validation
+		 * will be validated during prefetching.
+		 */
+		return 1;
+	}
 	return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) &&
 		vq_log_access_ok(vq, vq->log_base);
 }
 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
 
+static struct vhost_umem *vhost_umem_alloc(void)
+{
+	struct vhost_umem *umem = vhost_kvzalloc(sizeof(*umem));
+
+	if (!umem)
+		return NULL;
+
+	umem->umem_tree = RB_ROOT;
+	umem->numem = 0;
+	INIT_LIST_HEAD(&umem->umem_list);
+
+	return umem;
+}
+
 static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 {
 	struct vhost_memory mem, *newmem;
 	struct vhost_memory_region *region;
-	struct vhost_umem_node *node;
 	struct vhost_umem *newumem, *oldumem;
 	unsigned long size = offsetof(struct vhost_memory, regions);
 	int i;
@@ -767,28 +1228,23 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 		return -EFAULT;
 	}
 
-	newumem = vhost_kvzalloc(sizeof(*newumem));
+	newumem = vhost_umem_alloc();
 	if (!newumem) {
 		kvfree(newmem);
 		return -ENOMEM;
 	}
 
-	newumem->umem_tree = RB_ROOT;
-	INIT_LIST_HEAD(&newumem->umem_list);
-
 	for (region = newmem->regions;
 	     region < newmem->regions + mem.nregions;
 	     region++) {
-		node = vhost_kvzalloc(sizeof(*node));
-		if (!node)
+		if (vhost_new_umem_range(newumem,
+					 region->guest_phys_addr,
+					 region->memory_size,
+					 region->guest_phys_addr +
+					 region->memory_size - 1,
+					 region->userspace_addr,
+					 VHOST_ACCESS_RW))
 			goto err;
-		node->start = region->guest_phys_addr;
-		node->size = region->memory_size;
-		node->last = node->start + node->size - 1;
-		node->userspace_addr = region->userspace_addr;
-		INIT_LIST_HEAD(&node->link);
-		list_add_tail(&node->link, &newumem->umem_list);
-		vhost_umem_interval_tree_insert(node, &newumem->umem_tree);
 	}
 
 	if (!memory_access_ok(d, newumem, 0))
@@ -1032,6 +1488,30 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
 }
 EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
 
+int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
+{
+	struct vhost_umem *niotlb, *oiotlb;
+	int i;
+
+	niotlb = vhost_umem_alloc();
+	if (!niotlb)
+		return -ENOMEM;
+
+	oiotlb = d->iotlb;
+	d->iotlb = niotlb;
+
+	for (i = 0; i < d->nvqs; ++i) {
+		mutex_lock(&d->vqs[i]->mutex);
+		d->vqs[i]->iotlb = niotlb;
+		mutex_unlock(&d->vqs[i]->mutex);
+	}
+
+	vhost_umem_clean(oiotlb);
+
+	return 0;
+}
+EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
+
 /* Caller must have device mutex */
 long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
 {
@@ -1246,15 +1726,20 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
 	if (r)
 		goto err;
 	vq->signalled_used_valid = false;
-	if (!access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
+	if (!vq->iotlb &&
+	    !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
 		r = -EFAULT;
 		goto err;
 	}
 	r = vhost_get_user(vq, last_used_idx, &vq->used->idx);
-	if (r)
+	if (r) {
+		vq_err(vq, "Can't access used idx at 0x%llx\n",
+			(unsigned long long) &vq->used->idx);
 		goto err;
+	}
 	vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
 	return 0;
+
 err:
 	vq->is_le = is_le;
 	return r;
@@ -1262,10 +1747,11 @@ err:
 EXPORT_SYMBOL_GPL(vhost_vq_init_access);
 
 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
-			  struct iovec iov[], int iov_size)
+			  struct iovec iov[], int iov_size, int access)
 {
 	const struct vhost_umem_node *node;
-	struct vhost_umem *umem = vq->umem;
+	struct vhost_dev *dev = vq->dev;
+	struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem;
 	struct iovec *_iov;
 	u64 s = 0;
 	int ret = 0;
@@ -1276,12 +1762,21 @@ static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
 			ret = -ENOBUFS;
 			break;
 		}
+
 		node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
 							addr, addr + len - 1);
 		if (node == NULL || node->start > addr) {
-			ret = -EFAULT;
+			if (umem != dev->iotlb) {
+				ret = -EFAULT;
+				break;
+			}
+			ret = -EAGAIN;
+			break;
+		} else if (!(node->perm & access)) {
+			ret = -EPERM;
 			break;
 		}
+
 		_iov = iov + ret;
 		size = node->size - addr + node->start;
 		_iov->iov_len = min((u64)len - s, size);
@@ -1292,6 +1787,8 @@ static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
 		++ret;
 	}
 
+	if (ret == -EAGAIN)
+		vhost_iotlb_miss(vq, addr, access);
 	return ret;
 }
 
@@ -1320,7 +1817,7 @@ static int get_indirect(struct vhost_virtqueue *vq,
 			struct iovec iov[], unsigned int iov_size,
 			unsigned int *out_num, unsigned int *in_num,
 			struct vhost_log *log, unsigned int *log_num,
-			struct vring_desc *indirect)
+			struct vring_desc *indirect, int access)
 {
 	struct vring_desc desc;
 	unsigned int i = 0, count, found = 0;
@@ -1338,9 +1835,10 @@ static int get_indirect(struct vhost_virtqueue *vq,
 	}
 
 	ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect,
-			     UIO_MAXIOV);
+			     UIO_MAXIOV, VHOST_ACCESS_RO);
 	if (unlikely(ret < 0)) {
-		vq_err(vq, "Translation failure %d in indirect.\n", ret);
+		if (ret != -EAGAIN)
+			vq_err(vq, "Translation failure %d in indirect.\n", ret);
 		return ret;
 	}
 	iov_iter_init(&from, READ, vq->indirect, ret, len);
@@ -1380,10 +1878,11 @@ static int get_indirect(struct vhost_virtqueue *vq,
 
 		ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
 				     vhost32_to_cpu(vq, desc.len), iov + iov_count,
-				     iov_size - iov_count);
+				     iov_size - iov_count, access);
 		if (unlikely(ret < 0)) {
-			vq_err(vq, "Translation failure %d indirect idx %d\n",
-			       ret, i);
+			if (ret != -EAGAIN)
+				vq_err(vq, "Translation failure %d indirect idx %d\n",
+					ret, i);
 			return ret;
 		}
 		/* If this is an input descriptor, increment that count. */
@@ -1419,7 +1918,8 @@ static int get_indirect(struct vhost_virtqueue *vq,
 int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 		      struct iovec iov[], unsigned int iov_size,
 		      unsigned int *out_num, unsigned int *in_num,
-		      struct vhost_log *log, unsigned int *log_num)
+		      struct vhost_log *log, unsigned int *log_num,
+		      int access)
 {
 	struct vring_desc desc;
 	unsigned int i, head, found = 0;
@@ -1498,10 +1998,11 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT)) {
 			ret = get_indirect(vq, iov, iov_size,
 					   out_num, in_num,
-					   log, log_num, &desc);
+					   log, log_num, &desc, access);
 			if (unlikely(ret < 0)) {
-				vq_err(vq, "Failure detected "
-				       "in indirect descriptor at idx %d\n", i);
+				if (ret != -EAGAIN)
+					vq_err(vq, "Failure detected "
+						"in indirect descriptor at idx %d\n", i);
 				return ret;
 			}
 			continue;
@@ -1509,10 +2010,11 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 
 		ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
 				     vhost32_to_cpu(vq, desc.len), iov + iov_count,
-				     iov_size - iov_count);
+				     iov_size - iov_count, access);
 		if (unlikely(ret < 0)) {
-			vq_err(vq, "Translation failure %d descriptor idx %d\n",
-			       ret, i);
+			if (ret != -EAGAIN)
+				vq_err(vq, "Translation failure %d descriptor idx %d\n",
+					ret, i);
 			return ret;
 		}
 		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) {
@@ -1781,6 +2283,47 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 }
 EXPORT_SYMBOL_GPL(vhost_disable_notify);
 
+/* Create a zeroed message: vhost_chr_read_iter() copies it out whole. */
+struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
+{
+	struct vhost_msg_node *node = kzalloc(sizeof *node, GFP_KERNEL);
+	if (!node)
+		return NULL;
+	node->vq = vq;
+	node->msg.type = type;
+	return node;
+}
+EXPORT_SYMBOL_GPL(vhost_new_msg);
+
+void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
+		       struct vhost_msg_node *node)
+{
+	spin_lock(&dev->iotlb_lock);
+	list_add_tail(&node->node, head);
+	spin_unlock(&dev->iotlb_lock);
+
+	wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
+}
+EXPORT_SYMBOL_GPL(vhost_enqueue_msg);
+
+struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
+					 struct list_head *head)
+{
+	struct vhost_msg_node *node = NULL;
+
+	spin_lock(&dev->iotlb_lock);
+	if (!list_empty(head)) {
+		node = list_first_entry(head, struct vhost_msg_node,
+					node);
+		list_del(&node->node);
+	}
+	spin_unlock(&dev->iotlb_lock);
+
+	return node;
+}
+EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
+
+
 static int __init vhost_init(void)
 {
 	return 0;
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index b93b6a0..85c1d78 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -63,13 +63,15 @@ struct vhost_umem_node {
 	__u64 last;
 	__u64 size;
 	__u64 userspace_addr;
-	__u64 flags_padding;
+	__u32 perm;
+	__u32 flags_padding;
 	__u64 __subtree_last;
 };
 
 struct vhost_umem {
 	struct rb_root umem_tree;
 	struct list_head umem_list;
+	int numem;
 };
 
 /* The virtqueue structure describes a queue attached to a device. */
@@ -117,10 +119,12 @@ struct vhost_virtqueue {
 	u64 log_addr;
 
 	struct iovec iov[UIO_MAXIOV];
+	struct iovec iotlb_iov[64];
 	struct iovec *indirect;
 	struct vring_used_elem *heads;
 	/* Protected by virtqueue mutex. */
 	struct vhost_umem *umem;
+	struct vhost_umem *iotlb;
 	void *private_data;
 	u64 acked_features;
 	/* Log write descriptors */
@@ -137,6 +141,12 @@ struct vhost_virtqueue {
 	u32 busyloop_timeout;
 };
 
+struct vhost_msg_node {
+	struct vhost_msg msg;		/* payload exchanged with userspace */
+	struct vhost_virtqueue *vq;	/* vq that raised the message */
+	struct list_head node;		/* link in dev->read_list/pending_list */
+};
+
 struct vhost_dev {
 	struct mm_struct *mm;
 	struct mutex mutex;
@@ -148,6 +158,11 @@ struct vhost_dev {
 	struct list_head work_list;
 	struct task_struct *worker;
 	struct vhost_umem *umem;
+	struct vhost_umem *iotlb;
+	spinlock_t iotlb_lock;
+	struct list_head read_list;
+	struct list_head pending_list;
+	wait_queue_head_t wait;
 };
 
 void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs);
@@ -166,7 +181,8 @@ int vhost_log_access_ok(struct vhost_dev *);
 int vhost_get_vq_desc(struct vhost_virtqueue *,
 		      struct iovec iov[], unsigned int iov_count,
 		      unsigned int *out_num, unsigned int *in_num,
-		      struct vhost_log *log, unsigned int *log_num);
+		      struct vhost_log *log, unsigned int *log_num,
+		      int access);
 void vhost_discard_vq_desc(struct vhost_virtqueue *, int n);
 
 int vhost_vq_init_access(struct vhost_virtqueue *);
@@ -184,6 +200,21 @@ bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *);
 
 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
 		    unsigned int log_num, u64 len);
+int vq_iotlb_prefetch(struct vhost_virtqueue *vq);
+
+struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type);
+void vhost_enqueue_msg(struct vhost_dev *dev,
+		       struct list_head *head,
+		       struct vhost_msg_node *node);
+struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
+					 struct list_head *head);
+unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev,
+			    poll_table *wait);
+ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
+			    int noblock);
+ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
+			     struct iov_iter *from);
+int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled);
 
 #define vq_err(vq, fmt, ...) do {                                  \
 		pr_debug(pr_fmt(fmt), ##__VA_ARGS__);       \
diff --git a/include/uapi/linux/vhost.h b/include/uapi/linux/vhost.h
index 61a8777..8cb0a65 100644
--- a/include/uapi/linux/vhost.h
+++ b/include/uapi/linux/vhost.h
@@ -47,6 +47,32 @@ struct vhost_vring_addr {
 	__u64 log_guest_addr;
 };
 
+/* no alignment requirement */
+struct vhost_iotlb_msg {
+	__u64 iova;
+	__u64 size;
+	__u64 uaddr;
+#define VHOST_ACCESS_RO      0x1
+#define VHOST_ACCESS_WO      0x2
+#define VHOST_ACCESS_RW      0x3
+	__u8 perm;
+#define VHOST_IOTLB_MISS           1
+#define VHOST_IOTLB_UPDATE         2
+#define VHOST_IOTLB_INVALIDATE     3
+#define VHOST_IOTLB_ACCESS_FAIL    4
+	__u8 type;
+};
+
+#define VHOST_IOTLB_MSG 0x1
+
+struct vhost_msg {
+	int type;
+	union {
+		struct vhost_iotlb_msg iotlb;
+		__u8 padding[64];
+	};
+};
+
 struct vhost_memory_region {
 	__u64 guest_phys_addr;
 	__u64 memory_size; /* bytes */
@@ -146,6 +172,8 @@ struct vhost_memory {
 #define VHOST_F_LOG_ALL 26
 /* vhost-net should add virtio_net_hdr for RX, and strip for TX packets. */
 #define VHOST_NET_F_VIRTIO_NET_HDR 27
+/* Vhost have device IOTLB */
+#define VHOST_F_DEVICE_IOTLB 63
 
 /* VHOST_SCSI specific definitions */
 
-- 
1.8.3.1

_______________________________________________
Virtualization mailing list
Virtualization@xxxxxxxxxxxxxxxxxxxxxxxxxx
https://lists.linuxfoundation.org/mailman/listinfo/virtualization



[Index of Archives]     [KVM Development]     [Libvirt Development]     [Libvirt Users]     [CentOS Virtualization]     [Netdev]     [Ethernet Bridging]     [Linux Wireless]     [Kernel Newbies]     [Security]     [Linux for Hams]     [Netfilter]     [Bugtraq]     [Yosemite Forum]     [MIPS Linux]     [ARM Linux]     [Linux RAID]     [Linux Admin]     [Samba]

  Powered by Linux