Cc: Andrea Arcangeli <aarcange@xxxxxxxxxx>
Cc: James Bottomley <James.Bottomley@xxxxxxxxxxxxxxxxxxxxx>
Cc: Christoph Hellwig <hch@xxxxxxxxxxxxx>
Cc: David Miller <davem@xxxxxxxxxxxxx>
Cc: Jerome Glisse <jglisse@xxxxxxxxxx>
Cc: Jason Gunthorpe <jgg@xxxxxxxxxxxx>
Cc: linux-mm@xxxxxxxxx
Cc: linux-arm-kernel@xxxxxxxxxxxxxxxxxxx
Cc: linux-parisc@xxxxxxxxxxxxxxx
Signed-off-by: Jason Wang <jasowang@xxxxxxxxxx>
Signed-off-by: Michael S. Tsirkin <mst@xxxxxxxxxx>
---
drivers/vhost/vhost.c | 551 +++++++++++++++++++++++++++++++++++++++++-
drivers/vhost/vhost.h | 41 ++++
2 files changed, 589 insertions(+), 3 deletions(-)
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 791562e03fe0..f98155f28f02 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -298,6 +298,182 @@ static void vhost_vq_meta_reset(struct vhost_dev *d)
__vhost_vq_meta_reset(d->vqs[i]);
}
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
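+/* Free the metadata of a direct mapping. The pinned pages themselves
+ * were already released right after pinning in vhost_map_prefetch(),
+ * so only the page array and the map structure are freed here.
+ */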
+static void vhost_map_unprefetch(struct vhost_map *map)
+{
+ kfree(map->pages);
+ kfree(map);
+}
+
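+/* Dirty the pages backing a writable ring area (currently only the
+ * used ring) before its mapping is dropped.
+ */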
+static void vhost_set_map_dirty(struct vhost_virtqueue *vq,
+ struct vhost_map *map, int index)
+{
+ struct vhost_uaddr *uaddr = &vq->uaddrs[index];
+ int i;
+
+ if (uaddr->write) {
+ for (i = 0; i < map->npages; i++)
+ set_page_dirty(map->pages[i]);
+ }
+}
+
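+/* Tear down all direct mappings of a virtqueue: detach them under
+ * mmu_lock, then dirty and free them outside the lock.
+ */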
+static void vhost_uninit_vq_maps(struct vhost_virtqueue *vq)
+{
+ struct vhost_map *map[VHOST_NUM_ADDRS];
+ int i;
+
+ spin_lock(&vq->mmu_lock);
+ for (i = 0; i < VHOST_NUM_ADDRS; i++) {
+ map[i] = vq->maps[i];
+ if (map[i]) {
+ vhost_set_map_dirty(vq, map[i], i);
+ vq->maps[i] = NULL;
+ }
+ }
+ spin_unlock(&vq->mmu_lock);
+
+ /* No need for synchronization since we are serialized with the
+ * memory accessors (e.g. the vq mutex is held).
+ */
+
+ for (i = 0; i < VHOST_NUM_ADDRS; i++)
+ if (map[i])
+ vhost_map_unprefetch(map[i]);
+}
+
+static void vhost_reset_vq_maps(struct vhost_virtqueue *vq)
+{
+ int i;
+
+ vhost_uninit_vq_maps(vq);
+ for (i = 0; i < VHOST_NUM_ADDRS; i++)
+ vq->uaddrs[i].size = 0;
+}
+
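+/* Return true if the invalidated range [start, end] intersects the
+ * ring area described by uaddr; an unused (zero-sized) area never
+ * overlaps.
+ */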
+static bool vhost_map_range_overlap(struct vhost_uaddr *uaddr,
+ unsigned long start,
+ unsigned long end)
+{
+ if (unlikely(!uaddr->size))
+ return false;
+
+ return !(end < uaddr->uaddr || start > uaddr->uaddr - 1 + uaddr->size);
+}
+
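+/* Accessors of the direct mappings take mmu_lock so that they cannot
+ * race with the MMU notifier callbacks tearing a mapping down.
+ */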
+static inline void vhost_vq_access_map_begin(struct vhost_virtqueue *vq)
+{
+ spin_lock(&vq->mmu_lock);
+}
+
+static inline void vhost_vq_access_map_end(struct vhost_virtqueue *vq)
+{
+ spin_unlock(&vq->mmu_lock);
+}
+
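+/* invalidate_range_start handling for one ring area: drop any mapping
+ * that overlaps the invalidated range and bump invalidate_count so
+ * that vhost_map_prefetch() refuses to rebuild the mapping until the
+ * matching invalidate_range_end has run.
+ */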
+static int vhost_invalidate_vq_start(struct vhost_virtqueue *vq,
+ int index,
+ unsigned long start,
+ unsigned long end,
+ bool blockable)
+{
+ struct vhost_uaddr *uaddr = &vq->uaddrs[index];
+ struct vhost_map *map;
+
+ if (!vhost_map_range_overlap(uaddr, start, end))
+ return 0;
+ else if (!blockable)
+ return -EAGAIN;
+
+ spin_lock(&vq->mmu_lock);
+ ++vq->invalidate_count;
+
+ map = vq->maps[index];
+ if (map)
+ vq->maps[index] = NULL;
+ spin_unlock(&vq->mmu_lock);
+
+ if (map) {
+ vhost_set_map_dirty(vq, map, index);
+ vhost_map_unprefetch(map);
+ }
+
+ return 0;
+}
+
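+/* invalidate_range_end handling for one ring area: allow prefetching
+ * again once the invalidation covering this area has finished.
+ */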
+static void vhost_invalidate_vq_end(struct vhost_virtqueue *vq,
+ int index,
+ unsigned long start,
+ unsigned long end)
+{
+ if (!vhost_map_range_overlap(&vq->uaddrs[index], start, end))
+ return;
+
+ spin_lock(&vq->mmu_lock);
+ --vq->invalidate_count;
+ spin_unlock(&vq->mmu_lock);
+}
+
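+/* MMU notifier entry points: forward the invalidated range to every
+ * ring area of every virtqueue of the device.
+ */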
+static int vhost_invalidate_range_start(struct mmu_notifier *mn,
+ const struct mmu_notifier_range *range)
+{
+ struct vhost_dev *dev = container_of(mn, struct vhost_dev,
+ mmu_notifier);
+ bool blockable = mmu_notifier_range_blockable(range);
+ int i, j, ret;
+
+ for (i = 0; i < dev->nvqs; i++) {
+ struct vhost_virtqueue *vq = dev->vqs[i];
+
+ for (j = 0; j < VHOST_NUM_ADDRS; j++) {
+ ret = vhost_invalidate_vq_start(vq, j,
+ range->start,
+ range->end, blockable);
+ if (ret)
+ return ret;
+ }
+ }
+
+ return 0;
+}
+
+static void vhost_invalidate_range_end(struct mmu_notifier *mn,
+ const struct mmu_notifier_range *range)
+{
+ struct vhost_dev *dev = container_of(mn, struct vhost_dev,
+ mmu_notifier);
+ int i, j;
+
+ for (i = 0; i < dev->nvqs; i++) {
+ struct vhost_virtqueue *vq = dev->vqs[i];
+
+ for (j = 0; j < VHOST_NUM_ADDRS; j++)
+ vhost_invalidate_vq_end(vq, j,
+ range->start,
+ range->end);
+ }
+}
+
+static const struct mmu_notifier_ops vhost_mmu_notifier_ops = {
+ .invalidate_range_start = vhost_invalidate_range_start,
+ .invalidate_range_end = vhost_invalidate_range_end,
+};
+
+static void vhost_init_maps(struct vhost_dev *dev)
+{
+ struct vhost_virtqueue *vq;
+ int i, j;
+
+ dev->mmu_notifier.ops = &vhost_mmu_notifier_ops;
+
+ for (i = 0; i < dev->nvqs; ++i) {
+ vq = dev->vqs[i];
+ for (j = 0; j < VHOST_NUM_ADDRS; j++)
+ vq->maps[j] = NULL;
+ }
+}
+#endif
+
static void vhost_vq_reset(struct vhost_dev *dev,
struct vhost_virtqueue *vq)
{
@@ -326,7 +502,11 @@ static void vhost_vq_reset(struct vhost_dev *dev,
vq->busyloop_timeout = 0;
vq->umem = NULL;
vq->iotlb = NULL;
+ vq->invalidate_count = 0;
__vhost_vq_meta_reset(vq);
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ vhost_reset_vq_maps(vq);
+#endif
}
static int vhost_worker(void *data)
@@ -471,12 +651,15 @@ void vhost_dev_init(struct vhost_dev *dev,
dev->iov_limit = iov_limit;
dev->weight = weight;
dev->byte_weight = byte_weight;
+ dev->has_notifier = false;
init_llist_head(&dev->work_list);
init_waitqueue_head(&dev->wait);
INIT_LIST_HEAD(&dev->read_list);
INIT_LIST_HEAD(&dev->pending_list);
spin_lock_init(&dev->iotlb_lock);
-
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ vhost_init_maps(dev);
+#endif
for (i = 0; i < dev->nvqs; ++i) {
vq = dev->vqs[i];
@@ -485,6 +668,7 @@ void vhost_dev_init(struct vhost_dev *dev,
vq->heads = NULL;
vq->dev = dev;
mutex_init(&vq->mutex);
+ spin_lock_init(&vq->mmu_lock);
vhost_vq_reset(dev, vq);
if (vq->handle_kick)
vhost_poll_init(&vq->poll, vq->handle_kick,
@@ -564,7 +748,19 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
if (err)
goto err_cgroup;
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
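+ /* Register the MMU notifier only after the rest of the owner setup
+ * has succeeded; from this point on invalidations of dev->mm are
+ * forwarded to the virtqueue mappings.
+ */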
+ err = mmu_notifier_register(&dev->mmu_notifier, dev->mm);
+ if (err)
+ goto err_mmu_notifier;
+#endif
+ dev->has_notifier = true;
+
return 0;
+
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+err_mmu_notifier:
+ vhost_dev_free_iovecs(dev);
+#endif
err_cgroup:
kthread_stop(worker);
dev->worker = NULL;
@@ -655,6 +851,107 @@ static void vhost_clear_msg(struct vhost_dev *dev)
spin_unlock(&dev->iotlb_lock);
}
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
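+/* Record the userspace address, size and writability of one ring area
+ * so that MMU notifier invalidations can be matched against it.
+ */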
+static void vhost_setup_uaddr(struct vhost_virtqueue *vq,
+ int index, unsigned long uaddr,
+ size_t size, bool write)
+{
+ struct vhost_uaddr *addr = &vq->uaddrs[index];
+
+ addr->uaddr = uaddr;
+ addr->size = size;
+ addr->write = write;
+}
+
+static void vhost_setup_vq_uaddr(struct vhost_virtqueue *vq)
+{
+ vhost_setup_uaddr(vq, VHOST_ADDR_DESC,
+ (unsigned long)vq->desc,
+ vhost_get_desc_size(vq, vq->num),
+ false);
+ vhost_setup_uaddr(vq, VHOST_ADDR_AVAIL,
+ (unsigned long)vq->avail,
+ vhost_get_avail_size(vq, vq->num),
+ false);
+ vhost_setup_uaddr(vq, VHOST_ADDR_USED,
+ (unsigned long)vq->used,
+ vhost_get_used_size(vq, vq->num),
+ true);
+}
+
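+/* Try to pin the user pages backing one ring area and install a
+ * direct kernel mapping for it under mmu_lock. This fails when an
+ * invalidation is in flight, when a page lives in highmem, or when
+ * the pages are not virtually contiguous in the kernel; the metadata
+ * accessors then simply keep using the regular uaccess helpers.
+ */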
+static int vhost_map_prefetch(struct vhost_virtqueue *vq,
+ int index)
+{
+ struct vhost_map *map;
+ struct vhost_uaddr *uaddr = &vq->uaddrs[index];
+ struct page **pages;
+ int npages = DIV_ROUND_UP(uaddr->size, PAGE_SIZE);
+ int npinned;
+ void *vaddr, *v;
+ int err;
+ int i;
+
+ spin_lock(&vq->mmu_lock);
+
+ err = -EFAULT;
+ if (vq->invalidate_count)
+ goto err;
+
+ err = -ENOMEM;
+ map = kmalloc(sizeof(*map), GFP_ATOMIC);
+ if (!map)
+ goto err;
+
+ pages = kmalloc_array(npages, sizeof(struct page *), GFP_ATOMIC);
+ if (!pages)
+ goto err_pages;
+
+ err = -EFAULT;
+ npinned = __get_user_pages_fast(uaddr->uaddr, npages,
+ uaddr->write, pages);
+ if (npinned > 0)
+ release_pages(pages, npinned);
+ if (npinned != npages)
+ goto err_gup;
+
+ for (i = 0; i < npinned; i++)
+ if (PageHighMem(pages[i]))
+ goto err_gup;
+
+ vaddr = v = page_address(pages[0]);
+
+ /* For simplicity, fall back to the userspace address if the kernel
+ * VA is not contiguous.
+ */
+ for (i = 1; i < npinned; i++) {
+ v += PAGE_SIZE;
+ if (v != page_address(pages[i]))
+ goto err_gup;
+ }
+
+ map->addr = vaddr + (uaddr->uaddr & (PAGE_SIZE - 1));
+ map->npages = npages;
+ map->pages = pages;
+
+ vq->maps[index] = map;
+ /* No need for a synchronize_rcu(). This function should be
+ * called by dev->worker so we are serialized with all
+ * readers.
+ */
+ spin_unlock(&vq->mmu_lock);
+
+ return 0;
+
+err_gup:
+ kfree(pages);
+err_pages:
+ kfree(map);
+err:
+ spin_unlock(&vq->mmu_lock);
+ return err;
+}
+#endif
+
void vhost_dev_cleanup(struct vhost_dev *dev)
{
int i;
@@ -684,8 +981,20 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
kthread_stop(dev->worker);
dev->worker = NULL;
}
- if (dev->mm)
+ if (dev->mm) {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ if (dev->has_notifier) {
+ mmu_notifier_unregister(&dev->mmu_notifier,
+ dev->mm);
+ dev->has_notifier = false;
+ }
+#endif
mmput(dev->mm);
+ }
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
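+ /* The MMU notifier (if any) has been unregistered above, so the
+ * remaining mappings can be dropped without racing against
+ * invalidations.
+ */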
+ for (i = 0; i < dev->nvqs; i++)
+ vhost_uninit_vq_maps(dev->vqs[i]);
+#endif
dev->mm = NULL;
}
EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
@@ -914,6 +1223,26 @@ static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
static inline int vhost_put_avail_event(struct vhost_virtqueue *vq)
{
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ struct vhost_map *map;
+ struct vring_used *used;
+
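+ /* Fast path: if a direct mapping of the used ring is installed,
+ * write the avail event through it; otherwise fall through to the
+ * regular vhost_put_user() below. The same pattern repeats in the
+ * other metadata accessors.
+ */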
+ if (!vq->iotlb) {
+ vhost_vq_access_map_begin(vq);
+
+ map = vq->maps[VHOST_ADDR_USED];
+ if (likely(map)) {
+ used = map->addr;
+ *((__virtio16 *)&used->ring[vq->num]) =
+ cpu_to_vhost16(vq, vq->avail_idx);
+ vhost_vq_access_map_end(vq);
+ return 0;
+ }
+
+ vhost_vq_access_map_end(vq);
+ }
+#endif
+
return vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
vhost_avail_event(vq));
}
@@ -922,6 +1251,27 @@ static inline int vhost_put_used(struct vhost_virtqueue *vq,
struct vring_used_elem *head, int idx,
int count)
{
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ struct vhost_map *map;
+ struct vring_used *used;
+ size_t size;
+
+ if (!vq->iotlb) {
+ vhost_vq_access_map_begin(vq);
+
+ map = vq->maps[VHOST_ADDR_USED];
+ if (likely(map)) {
+ used = map->addr;
+ size = count * sizeof(*head);
+ memcpy(used->ring + idx, head, size);
+ vhost_vq_access_map_end(vq);
+ return 0;
+ }
+
+ vhost_vq_access_map_end(vq);
+ }
+#endif
+
return vhost_copy_to_user(vq, vq->used->ring + idx, head,
count * sizeof(*head));
}
@@ -929,6 +1279,25 @@ static inline int vhost_put_used(struct vhost_virtqueue *vq,
static inline int vhost_put_used_flags(struct vhost_virtqueue *vq)
{
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ struct vhost_map *map;
+ struct vring_used *used;
+
+ if (!vq->iotlb) {
+ vhost_vq_access_map_begin(vq);
+
+ map = vq->maps[VHOST_ADDR_USED];
+ if (likely(map)) {
+ used = map->addr;
+ used->flags = cpu_to_vhost16(vq, vq->used_flags);
+ vhost_vq_access_map_end(vq);
+ return 0;
+ }
+
+ vhost_vq_access_map_end(vq);
+ }
+#endif
+
return vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
&vq->used->flags);
}
@@ -936,6 +1305,25 @@ static inline int vhost_put_used_flags(struct vhost_virtqueue *vq)
static inline int vhost_put_used_idx(struct vhost_virtqueue *vq)
{
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ struct vhost_map *map;
+ struct vring_used *used;
+
+ if (!vq->iotlb) {
+ vhost_vq_access_map_begin(vq);
+
+ map = vq->maps[VHOST_ADDR_USED];
+ if (likely(map)) {
+ used = map->addr;
+ used->idx = cpu_to_vhost16(vq, vq->last_used_idx);
+ vhost_vq_access_map_end(vq);
+ return 0;
+ }
+
+ vhost_vq_access_map_end(vq);
+ }
+#endif
+
return vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
&vq->used->idx);
}
@@ -981,12 +1369,50 @@ static void vhost_dev_unlock_vqs(struct vhost_dev *d)
static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq,
__virtio16 *idx)
{
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ struct vhost_map *map;
+ struct vring_avail *avail;
+
+ if (!vq->iotlb) {
+ vhost_vq_access_map_begin(vq);
+
+ map = vq->maps[VHOST_ADDR_AVAIL];
+ if (likely(map)) {
+ avail = map->addr;
+ *idx = avail->idx;