On Tue, 9 Mar 2021 12:26:07 -0700
Alex Williamson <alex.williamson@xxxxxxxxxx> wrote:

> On Tue, 9 Mar 2021 13:47:39 -0500
> Peter Xu <peterx@xxxxxxxxxx> wrote:
> 
> > On Tue, Mar 09, 2021 at 12:40:04PM -0400, Jason Gunthorpe wrote:
> > > On Tue, Mar 09, 2021 at 08:29:51AM -0700, Alex Williamson wrote:
> > > > On Tue, 9 Mar 2021 08:46:09 -0400
> > > > Jason Gunthorpe <jgg@xxxxxxxxxx> wrote:
> > > > 
> > > > > On Tue, Mar 09, 2021 at 03:49:09AM +0000, Zengtao (B) wrote:
> > > > > > Hi guys:
> > > > > > 
> > > > > > Thanks for the helpful comments, after rethinking the issue, I have proposed
> > > > > > the following change:
> > > > > > 1. follow_pte instead of follow_pfn.
> > > > > 
> > > > > Still no on follow_pfn, you don't need it once you use vmf_insert_pfn
> > > > 
> > > > vmf_insert_pfn() only solves the BUG_ON, follow_pte() is being used
> > > > here to determine whether the translation is already present to avoid
> > > > both duplicate work in inserting the translation and allocating a
> > > > duplicate vma tracking structure.
> > > 
> > > Oh.. Doing something stateful in fault is not nice at all
> > > 
> > > I would rather see __vfio_pci_add_vma() search the vma_list for dups
> > > than call follow_pfn/pte..
> > 
> > It seems to me that searching vma list is still the simplest way to fix the
> > problem for the current code base. I see io_remap_pfn_range() is also used in
> > the new series - maybe that'll need to be moved to where PCI_COMMAND_MEMORY got
> > turned on/off in the new series (I just noticed remap_pfn_range modifies vma
> > flags..), as you suggested in the other email.
> > 
> In the new series, I think the fault handler becomes (untested):
> 
> static vm_fault_t vfio_pci_mmap_fault(struct vm_fault *vmf)
> {
> 	struct vm_area_struct *vma = vmf->vma;
> 	struct vfio_pci_device *vdev = vma->vm_private_data;
> 	unsigned long base_pfn, pgoff;
> 	vm_fault_t ret = VM_FAULT_SIGBUS;
> 
> 	if (vfio_pci_bar_vma_to_pfn(vma, &base_pfn))
> 		return ret;
> 
> 	pgoff = (vmf->address - vma->vm_start) >> PAGE_SHIFT;
> 
> 	down_read(&vdev->memory_lock);
> 
> 	if (__vfio_pci_memory_enabled(vdev))
> 		ret = vmf_insert_pfn(vma, vmf->address, pgoff + base_pfn);
> 
> 	up_read(&vdev->memory_lock);
> 
> 	return ret;
> }

And I think this is what we end up with for the current code base:

diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c
index 65e7e6b44578..2f247ab18c66 100644
--- a/drivers/vfio/pci/vfio_pci.c
+++ b/drivers/vfio/pci/vfio_pci.c
@@ -1568,19 +1568,24 @@ void vfio_pci_memory_unlock_and_restore(struct vfio_pci_device *vdev, u16 cmd)
 }
 
 /* Caller holds vma_lock */
-static int __vfio_pci_add_vma(struct vfio_pci_device *vdev,
-			      struct vm_area_struct *vma)
+struct vfio_pci_mmap_vma *__vfio_pci_add_vma(struct vfio_pci_device *vdev,
+					     struct vm_area_struct *vma)
 {
 	struct vfio_pci_mmap_vma *mmap_vma;
 
+	list_for_each_entry(mmap_vma, &vdev->vma_list, vma_next) {
+		if (mmap_vma->vma == vma)
+			return ERR_PTR(-EEXIST);
+	}
+
 	mmap_vma = kmalloc(sizeof(*mmap_vma), GFP_KERNEL);
 	if (!mmap_vma)
-		return -ENOMEM;
+		return ERR_PTR(-ENOMEM);
 
 	mmap_vma->vma = vma;
 	list_add(&mmap_vma->vma_next, &vdev->vma_list);
 
-	return 0;
+	return mmap_vma;
 }
 
 /*
@@ -1612,30 +1617,39 @@ static vm_fault_t vfio_pci_mmap_fault(struct vm_fault *vmf)
 {
 	struct vm_area_struct *vma = vmf->vma;
 	struct vfio_pci_device *vdev = vma->vm_private_data;
-	vm_fault_t ret = VM_FAULT_NOPAGE;
+	struct vfio_pci_mmap_vma *mmap_vma;
+	unsigned long vaddr, pfn;
+	vm_fault_t ret;
 
 	mutex_lock(&vdev->vma_lock);
 	down_read(&vdev->memory_lock);
 
 	if (!__vfio_pci_memory_enabled(vdev)) {
 		ret = VM_FAULT_SIGBUS;
-		mutex_unlock(&vdev->vma_lock);
 		goto up_out;
 	}
 
-	if (__vfio_pci_add_vma(vdev, vma)) {
-		ret = VM_FAULT_OOM;
-		mutex_unlock(&vdev->vma_lock);
+	mmap_vma = __vfio_pci_add_vma(vdev, vma);
+	if (IS_ERR(mmap_vma)) {
+		/* A concurrent fault might have already inserted the page */
+		ret = (PTR_ERR(mmap_vma) == -EEXIST) ? VM_FAULT_NOPAGE :
+						       VM_FAULT_OOM;
 		goto up_out;
 	}
 
-	mutex_unlock(&vdev->vma_lock);
-
-	if (io_remap_pfn_range(vma, vma->vm_start, vma->vm_pgoff,
-			       vma->vm_end - vma->vm_start, vma->vm_page_prot))
-		ret = VM_FAULT_SIGBUS;
-
+	for (vaddr = vma->vm_start, pfn = vma->vm_pgoff;
+	     vaddr < vma->vm_end; vaddr += PAGE_SIZE, pfn++) {
+		ret = vmf_insert_pfn(vma, vaddr, pfn);
+		if (ret != VM_FAULT_NOPAGE) {
+			zap_vma_ptes(vma, vma->vm_start,
+				     vma->vm_end - vma->vm_start);
+			list_del(&mmap_vma->vma_next);
+			kfree(mmap_vma);
+			break;
+		}
+	}
 up_out:
+	mutex_unlock(&vdev->vma_lock);
 	up_read(&vdev->memory_lock);
 	return ret;
 }