+
+ size -= sz;
+ pfn++;
+ addr += sz;
+ offset = 0;
+ }
+ } else if (dir == DMA_TO_DEVICE) {
+ memcpy(addr, phys_to_virt(orig), size);
+ } else {
+ memcpy(phys_to_virt(orig), addr, size);
+ }
+}
+
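+ /*
+  * Look up the page backing the page-aligned @iova through the iotlb
+  * and return it with an extra reference held, or NULL if no mapping
+  * covers that page.
+  */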
+static struct page *
+vduse_domain_get_mapping_page(struct vduse_iova_domain *domain, u64 iova)
+{
+ u64 start = iova & PAGE_MASK;
+ u64 last = start + PAGE_SIZE - 1;
+ struct vhost_iotlb_map *map;
+ struct page *page = NULL;
+
+ spin_lock(&domain->iotlb_lock);
+ map = vhost_iotlb_itree_first(domain->iotlb, start, last);
+ if (!map)
+ goto out;
+
+ page = pfn_to_page((map->addr + iova - map->start) >> PAGE_SHIFT);
+ get_page(page);
+out:
+ spin_unlock(&domain->iotlb_lock);
+
+ return page;
+}
+
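+ /*
+  * Install a bounce page for the page containing @iova. If another
+  * caller raced us and already installed one, reuse it; otherwise
+  * populate the new page from every overlapping CPU-to-device mapping
+  * so the device sees the original buffer contents. Returns the page
+  * with an extra reference, or NULL if no mapping covers @iova.
+  */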
+static struct page *
+vduse_domain_alloc_bounce_page(struct vduse_iova_domain *domain, u64 iova)
+{
+ u64 start = iova & PAGE_MASK;
+ u64 last = start + PAGE_SIZE - 1;
+ struct vhost_iotlb_map *map;
+ struct page *page = NULL, *new_page = alloc_page(GFP_KERNEL);
+
+ if (!new_page)
+ return NULL;
+
+ spin_lock(&domain->iotlb_lock);
+ if (!vhost_iotlb_itree_first(domain->iotlb, start, last)) {
+ __free_page(new_page);
+ goto out;
+ }
+ page = vduse_domain_get_bounce_page(domain, iova);
+ if (page) {
+ get_page(page);
+ __free_page(new_page);
+ goto out;
+ }
+ vduse_domain_set_bounce_page(domain, iova, new_page);
+ get_page(new_page);
+ page = new_page;
+
+ for (map = vhost_iotlb_itree_first(domain->iotlb, start, last); map;
+ map = vhost_iotlb_itree_next(map, start, last)) {
+ unsigned int src_offset = 0, dst_offset = 0;
+ phys_addr_t src;
+ void *dst;
+ size_t sz;
+
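+ /* Device-write-only mappings have no CPU data to copy in */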
+ if (perm_to_dir(map->perm) == DMA_FROM_DEVICE)
+ continue;
+
+ if (start > map->start)
+ src_offset = start - map->start;
+ else
+ dst_offset = map->start - start;
+
+ src = map->addr + src_offset;
+ dst = page_address(page) + dst_offset;
+ sz = min_t(size_t, map->size - src_offset,
+ PAGE_SIZE - dst_offset);
+ do_bounce(src, dst, sz, DMA_TO_DEVICE);
+ }
+out:
+ spin_unlock(&domain->iotlb_lock);
+
+ return page;
+}
+
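+ /*
+  * Drop every bounce page in [@iova, @iova + @size); all mappings in
+  * the range must already have been removed from the iotlb.
+  */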
+static void
+vduse_domain_free_bounce_pages(struct vduse_iova_domain *domain,
+ u64 iova, size_t size)
+{
+ struct page *page;
+
+ spin_lock(&domain->iotlb_lock);
+ if (WARN_ON(vhost_iotlb_itree_first(domain->iotlb, iova,
+ iova + size - 1)))
+ goto out;
+
+ while (size > 0) {
+ page = vduse_domain_get_bounce_page(domain, iova);
+ if (page) {
+ vduse_domain_set_bounce_page(domain, iova, NULL);
+ __free_page(page);
+ }
+ size -= PAGE_SIZE;
+ iova += PAGE_SIZE;
+ }
+out:
+ spin_unlock(&domain->iotlb_lock);
+}
+
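+ /*
+  * Copy @size bytes between the original buffer at @orig and the
+  * bounce pages backing @iova, page by page, in the direction @dir.
+  * A missing bounce page is only a bug for device-to-CPU copies.
+  */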
+static void vduse_domain_bounce(struct vduse_iova_domain *domain,
+ dma_addr_t iova, phys_addr_t orig,
+ size_t size, enum dma_data_direction dir)
+{
+ unsigned int offset = offset_in_page(iova);
+
+ while (size) {
+ struct page *p = vduse_domain_get_bounce_page(domain, iova);
+ size_t sz = min_t(size_t, PAGE_SIZE - offset, size);
+
+ WARN_ON(!p && dir == DMA_FROM_DEVICE);
+
+ if (p)
+ do_bounce(orig, page_address(p) + offset, sz, dir);
+
+ size -= sz;
+ orig += sz;
+ iova += sz;
+ offset = 0;
+ }
+}
+
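+ /*
+  * Allocate an IOVA range covering @size bytes below @limit. Small
+  * allocations are rounded up to a power of two, presumably so that
+  * ranges recycled through the iova rcache (which is indexed by
+  * allocation order) are always large enough for their size class.
+  * Returns 0 on failure.
+  */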
+static dma_addr_t vduse_domain_alloc_iova(struct iova_domain *iovad,
+ unsigned long size, unsigned long limit)
+{
+ unsigned long shift = iova_shift(iovad);
+ unsigned long iova_len = iova_align(iovad, size) >> shift;
+ unsigned long iova_pfn;
+
+ if (iova_len < (1 << (IOVA_RANGE_CACHE_MAX_SIZE - 1)))
+ iova_len = roundup_pow_of_two(iova_len);
+ iova_pfn = alloc_iova_fast(iovad, iova_len, limit >> shift, true);
+
+ /*
+  * Cast before shifting: iova_pfn is an unsigned long and would
+  * otherwise truncate IOVAs above 4 GiB on 32-bit builds.
+  */
+ return (dma_addr_t)iova_pfn << shift;
+}
+
+static void vduse_domain_free_iova(struct iova_domain *iovad,
+ dma_addr_t iova, size_t size)
+{
+ unsigned long shift = iova_shift(iovad);
+ unsigned long iova_len = iova_align(iovad, size) >> shift;
+
+ free_iova_fast(iovad, iova >> shift, iova_len);
+}
+
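+ /*
+  * Streaming DMA map: carve an IOVA out of the bounce region, record
+  * the IOVA-to-physical translation in the iotlb, and for
+  * CPU-to-device transfers copy the caller's data into the bounce
+  * pages up front.
+  */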
+dma_addr_t vduse_domain_map_page(struct vduse_iova_domain *domain,
+ struct page *page, unsigned long offset,
+ size_t size, enum dma_data_direction dir,
+ unsigned long attrs)
+{
+ struct iova_domain *iovad = &domain->stream_iovad;
+ unsigned long limit = domain->bounce_size - 1;
+ phys_addr_t pa = page_to_phys(page) + offset;
+ dma_addr_t iova = vduse_domain_alloc_iova(iovad, size, limit);
+ int ret;
+
+ if (!iova)
+ return DMA_MAPPING_ERROR;
+
+ spin_lock(&domain->iotlb_lock);
+ ret = vhost_iotlb_add_range(domain->iotlb, (u64)iova,
+ (u64)iova + size - 1,
+ pa, dir_to_perm(dir));
+ spin_unlock(&domain->iotlb_lock);
+ if (ret) {
+ vduse_domain_free_iova(iovad, iova, size);
+ return DMA_MAPPING_ERROR;
+ }
+ if (dir == DMA_TO_DEVICE || dir == DMA_BIDIRECTIONAL)
+ vduse_domain_bounce(domain, iova, pa, size, DMA_TO_DEVICE);
+
+ return iova;
+}
+
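+ /*
+  * Streaming DMA unmap: drop the iotlb entry, copy the bounce data
+  * back to the original buffer for device-to-CPU transfers, then
+  * release the IOVA range.
+  */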
+void vduse_domain_unmap_page(struct vduse_iova_domain *domain,
+ dma_addr_t dma_addr, size_t size,
+ enum dma_data_direction dir, unsigned long attrs)
+{
+ struct iova_domain *iovad = &domain->stream_iovad;
+ struct vhost_iotlb_map *map;
+ phys_addr_t pa;
+
+ spin_lock(&domain->iotlb_lock);
+ map = vhost_iotlb_itree_first(domain->iotlb, (u64)dma_addr,
+ (u64)dma_addr + size - 1);
+ if (WARN_ON(!map)) {
+ spin_unlock(&domain->iotlb_lock);
+ return;
+ }
+ pa = map->addr;
+ vhost_iotlb_map_free(domain->iotlb, map);
+ spin_unlock(&domain->iotlb_lock);
+
+ if (dir == DMA_FROM_DEVICE || dir == DMA_BIDIRECTIONAL)
+ vduse_domain_bounce(domain, dma_addr, pa,
+ size, DMA_FROM_DEVICE);
+
+ vduse_domain_free_iova(iovad, dma_addr, size);
+}
+
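+ /*
+  * Coherent allocation: back an IOVA range from the consistent domain
+  * directly with freshly allocated pages and record a read-write
+  * iotlb entry for it, so no bouncing is ever needed.
+  */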
+void *vduse_domain_alloc_coherent(struct vduse_iova_domain *domain,
+ size_t size, dma_addr_t *dma_addr,
+ gfp_t flag, unsigned long attrs)
+{
+ struct iova_domain *iovad = &domain->consistent_iovad;
+ unsigned long limit = domain->iova_limit;
+ dma_addr_t iova = vduse_domain_alloc_iova(iovad, size, limit);
+ void *orig = alloc_pages_exact(size, flag);
+ int ret;
+
+ if (!iova || !orig)
+ goto err;
+
+ spin_lock(&domain->iotlb_lock);
+ ret = vhost_iotlb_add_range(domain->iotlb, (u64)iova,
+ (u64)iova + size - 1,
+ virt_to_phys(orig), VHOST_MAP_RW);
+ spin_unlock(&domain->iotlb_lock);
+ if (ret)
+ goto err;
+
+ *dma_addr = iova;
+
+ return orig;
+err:
+ *dma_addr = DMA_MAPPING_ERROR;
+ if (orig)
+ free_pages_exact(orig, size);
+ if (iova)
+ vduse_domain_free_iova(iovad, iova, size);
+
+ return NULL;
+}
+
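+ /*
+  * Undo vduse_domain_alloc_coherent(): remove the iotlb entry, then
+  * release the IOVA range and the backing pages.
+  */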
+void vduse_domain_free_coherent(struct vduse_iova_domain *domain, size_t size,
+ void *vaddr, dma_addr_t dma_addr,
+ unsigned long attrs)
+{
+ struct iova_domain *iovad = &domain->consistent_iovad;
+ struct vhost_iotlb_map *map;
+ phys_addr_t pa;
+
+ spin_lock(&domain->iotlb_lock);
+ map = vhost_iotlb_itree_first(domain->iotlb, (u64)dma_addr,
+ (u64)dma_addr + size - 1);
+ if (WARN_ON(!map)) {
+ spin_unlock(&domain->iotlb_lock);
+ return;
+ }
+ pa = map->addr;
+ vhost_iotlb_map_free(domain->iotlb, map);
+ spin_unlock(&domain->iotlb_lock);
+
+ vduse_domain_free_iova(iovad, dma_addr, size);
+ free_pages_exact(phys_to_virt(pa), size);
+}
+
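+ /*
+  * Page fault handler for userspace mappings of the domain: fault in
+  * a bounce page (allocated on demand) for the bounce region, or the
+  * iotlb-backed page for everything above it.
+  */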
+static vm_fault_t vduse_domain_mmap_fault(struct vm_fault *vmf)
+{
+ struct vduse_iova_domain *domain = vmf->vma->vm_private_data;
+ unsigned long iova = vmf->pgoff << PAGE_SHIFT;
+ struct page *page;
+
+ if (!domain)
+ return VM_FAULT_SIGBUS;
+
+ if (iova < domain->bounce_size)
+ page = vduse_domain_alloc_bounce_page(domain, iova);
+ else
+ page = vduse_domain_get_mapping_page(domain, iova);
+
+ if (!page)
+ return VM_FAULT_SIGBUS;
+
+ vmf->page = page;
+
+ return 0;
+}
+
+static const struct vm_operations_struct vduse_domain_mmap_ops = {
+ .fault = vduse_domain_mmap_fault,
+};
+
+static int vduse_domain_mmap(struct file *file, struct vm_area_struct *vma)
+{
+ struct vduse_iova_domain *domain = file->private_data;
+
+ vma->vm_flags |= VM_DONTDUMP | VM_DONTEXPAND;
+ vma->vm_private_data = domain;
+ vma->vm_ops = &vduse_domain_mmap_ops;
+
+ return 0;
+}
+
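+ /*
+  * Called once the last reference to the domain file is gone; frees
+  * the bounce pages, both IOVA domains and the iotlb itself.
+  */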
+static int vduse_domain_release(struct inode *inode, struct file *file)
+{
+ struct vduse_iova_domain *domain = file->private_data;
+
+ vduse_domain_free_bounce_pages(domain, 0, domain->bounce_size);
+ put_iova_domain(&domain->stream_iovad);
+ put_iova_domain(&domain->consistent_iovad);
+ vhost_iotlb_free(domain->iotlb);
+ vfree(domain->bounce_pages);
+ kfree(domain);
+
+ return 0;
+}
+
+static const struct file_operations vduse_domain_fops = {
+ .mmap = vduse_domain_mmap,
+ .release = vduse_domain_release,
+};
+
+void vduse_domain_destroy(struct vduse_iova_domain *domain)
+{
+ fput(domain->file);
+}
+
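+ /*
+  * Create a domain with a bounce region of @bounce_size bytes at the
+  * bottom of the IOVA space (served by stream_iovad) and a coherent
+  * region above it (served by consistent_iovad), so @iova_limit must
+  * exceed @bounce_size. The domain lives as long as its backing
+  * anon inode file.
+  */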
+struct vduse_iova_domain *
+vduse_domain_create(unsigned long iova_limit, size_t bounce_size)
+{
+ struct vduse_iova_domain *domain;
+ struct file *file;
+ unsigned long bounce_pfns = PAGE_ALIGN(bounce_size) >> PAGE_SHIFT;
+
+ if (iova_limit <= bounce_size)
+ return NULL;
+
+ domain = kzalloc(sizeof(*domain), GFP_KERNEL);
+ if (!domain)
+ return NULL;
+
+ domain->iotlb = vhost_iotlb_alloc(0, 0);
+ if (!domain->iotlb)
+ goto err_iotlb;
+
+ domain->iova_limit = iova_limit;
+ domain->bounce_size = PAGE_ALIGN(bounce_size);
+ domain->bounce_pages = vzalloc(bounce_pfns * sizeof(struct page *));
+ if (!domain->bounce_pages)
+ goto err_page;
+
+ file = anon_inode_getfile("[vduse-domain]", &vduse_domain_fops,
+ domain, O_RDWR);
+ if (IS_ERR(file))
+ goto err_file;
+
+ domain->file = file;
+ spin_lock_init(&domain->iotlb_lock);
+ init_iova_domain(&domain->stream_iovad,
+ IOVA_ALLOC_SIZE, IOVA_START_PFN);
+ init_iova_domain(&domain->consistent_iovad,
+ PAGE_SIZE, bounce_pfns);
+
+ return domain;
+err_file:
+ vfree(domain->bounce_pages);
+err_page:
+ vhost_iotlb_free(domain->iotlb);
+err_iotlb:
+ kfree(domain);
+ return NULL;
+}
+
+int vduse_domain_init(void)
+{
+ return iova_cache_get();
+}
+
+void vduse_domain_exit(void)
+{
+ iova_cache_put();
+}
diff --git a/drivers/vdpa/vdpa_user/iova_domain.h b/drivers/vdpa/vdpa_user/iova_domain.h
new file mode 100644
index 000000000000..9c85d8346626
--- /dev/null
+++ b/drivers/vdpa/vdpa_user/iova_domain.h
@@ -0,0 +1,61 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+/*
+ * MMU-based IOMMU implementation
+ *
+ * Copyright (C) 2020 Bytedance Inc. and/or its affiliates. All rights reserved.
+ *
+ * Author: Xie Yongji <xieyongji@xxxxxxxxxxxxx>
+ *
+ */
+
+#ifndef _VDUSE_IOVA_DOMAIN_H
+#define _VDUSE_IOVA_DOMAIN_H
+
+#include <linux/iova.h>
+#include <linux/dma-mapping.h>
+#include <linux/vhost_iotlb.h>
+
+struct vduse_iova_domain {
+ struct iova_domain stream_iovad;
+ struct iova_domain consistent_iovad;
+ struct page **bounce_pages;
+ size_t bounce_size;
+ unsigned long iova_limit;
+ struct vhost_iotlb *iotlb;