The first task to pin any pages becomes the dma owner, and thereafter is
the only task allowed to pin.  This prevents an application from exceeding
the initial task's RLIMIT_MEMLOCK by forking and pinning in the children.

Signed-off-by: Steve Sistare <steven.sistare@xxxxxxxxxx>
---
 drivers/vfio/vfio_iommu_type1.c | 64 +++++++++++++++++++++++++++++++++--------
 1 file changed, 52 insertions(+), 12 deletions(-)

diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index 02c6ea3..4429794 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -75,6 +75,7 @@ struct vfio_iommu {
 	bool			nesting;
 	bool			dirty_page_tracking;
 	struct list_head	emulated_iommu_groups;
+	struct task_struct	*task;
 };
 
 struct vfio_domain {
@@ -93,9 +94,9 @@ struct vfio_dma {
 	int			prot;		/* IOMMU_READ/WRITE */
 	bool			iommu_mapped;
 	bool			lock_cap;	/* capable(CAP_IPC_LOCK) */
-	struct task_struct	*task;
 	struct rb_root		pfn_list;	/* Ex-user pinned pfn list */
 	unsigned long		*bitmap;
+	struct vfio_iommu	*iommu;		/* back pointer */
 };
 
 struct vfio_batch {
@@ -408,19 +409,29 @@ static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
 
 static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
 {
+	struct task_struct *task = dma->iommu->task;
+	bool kthread = !current->mm;
 	struct mm_struct *mm;
 	int ret;
 
 	if (!npage)
 		return 0;
 
-	mm = async ? get_task_mm(dma->task) : dma->task->mm;
+	/* This is enforced at higher levels, so if it bites, it is a bug. */
+
+	if (!kthread && current->group_leader != task) {
+		WARN_ONCE(1, "%s: caller is pid %d, owner is pid %d\n",
+			  __func__, current->group_leader->pid, task->pid);
+		return -EPERM;
+	}
+
+	mm = async ? get_task_mm(task) : task->mm;
 	if (!mm)
 		return -ESRCH; /* process exited */
 
 	ret = mmap_write_lock_killable(mm);
 	if (!ret) {
-		ret = __account_locked_vm(mm, abs(npage), npage > 0, dma->task,
+		ret = __account_locked_vm(mm, abs(npage), npage > 0, task,
 					  dma->lock_cap);
 		mmap_write_unlock(mm);
 	}
@@ -609,6 +620,9 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 	if (!mm)
 		return -ENODEV;
 
+	if (dma->iommu->task != current->group_leader)
+		return -EPERM;
+
 	if (batch->size) {
 		/* Leftover pages in batch from an earlier call. */
 		*pfn_base = page_to_pfn(batch->pages[batch->offset]);
@@ -730,11 +744,12 @@ static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
 static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
 				  unsigned long *pfn_base, bool do_accounting)
 {
+	struct task_struct *task = dma->iommu->task;
 	struct page *pages[1];
 	struct mm_struct *mm;
 	int ret;
 
-	mm = get_task_mm(dma->task);
+	mm = get_task_mm(task);
 	if (!mm)
 		return -ENODEV;
 
@@ -751,8 +766,8 @@ static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
 			if (ret == -ENOMEM)
 				pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
 					"(%ld) exceeded\n", __func__,
-					dma->task->comm, task_pid_nr(dma->task),
-					task_rlimit(dma->task, RLIMIT_MEMLOCK));
+					task->comm, task_pid_nr(task),
+					task_rlimit(task, RLIMIT_MEMLOCK));
 		}
 	}
 
@@ -784,6 +799,7 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
 				      int npage, int prot,
 				      struct page **pages)
 {
+	bool kthread = !current->mm;
 	struct vfio_iommu *iommu = iommu_data;
 	struct vfio_iommu_group *group;
 	int i, j, ret;
@@ -807,6 +823,11 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
 		goto pin_done;
 	}
 
+	if (!kthread && iommu->task != current->group_leader) {
+		ret = -EPERM;
+		goto pin_done;
+	}
+
 	/*
 	 * If iommu capable domain exist in the container then all pages are
 	 * already pinned and accounted. Accounting should be done if there is no
@@ -1097,7 +1118,6 @@ static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
 	WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list));
 	vfio_unmap_unpin(iommu, dma, true);
 	vfio_unlink_dma(iommu, dma);
-	put_task_struct(dma->task);
 	vfio_dma_bitmap_free(dma);
 	kfree(dma);
 	iommu->dma_avail++;
@@ -1247,6 +1267,16 @@ static void vfio_notify_dma_unmap(struct vfio_iommu *iommu,
 	mutex_lock(&iommu->lock);
 }
 
+static void vfio_iommu_set_task(struct vfio_iommu *iommu,
+				struct task_struct *task)
+{
+	if (iommu->task)
+		put_task_struct(iommu->task);
+	if (task)
+		get_task_struct(task);
+	iommu->task = task;
+}
+
 static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
 			     struct vfio_iommu_type1_dma_unmap *unmap,
 			     struct vfio_bitmap *bitmap)
@@ -1362,6 +1392,9 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
 	}
 
 unlock:
+	if (RB_EMPTY_ROOT(&iommu->dma_list))
+		vfio_iommu_set_task(iommu, NULL);
+
 	mutex_unlock(&iommu->lock);
 
 	/* Report how much was unmapped */
@@ -1537,6 +1570,7 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu,
 	}
 
 	iommu->dma_avail--;
+	dma->iommu = iommu;
 	dma->iova = iova;
 	dma->vaddr = vaddr;
 	dma->prot = prot;
@@ -1566,8 +1600,8 @@
 	 * externally mapped. Therefore track CAP_IPC_LOCK in vfio_dma at the
 	 * time of calling MAP_DMA.
 	 */
-	get_task_struct(current->group_leader);
-	dma->task = current->group_leader;
+	if (!iommu->task)
+		vfio_iommu_set_task(iommu, current->group_leader);
 	dma->lock_cap = capable(CAP_IPC_LOCK);
 	dma->pfn_list = RB_ROOT;
 
@@ -2528,6 +2562,8 @@ static void vfio_iommu_type1_release(void *iommu_data)
 
 	vfio_iommu_iova_free(&iommu->iova_list);
 
+	vfio_iommu_set_task(iommu, NULL);
+
 	kfree(iommu);
 }
 
@@ -2963,6 +2999,7 @@ static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
 	struct vfio_dma *dma;
 	bool kthread = current->mm == NULL;
 	size_t offset;
+	int ret = -EFAULT;
 
 	*copied = 0;
 
@@ -2974,15 +3011,18 @@ static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
 	    !(dma->prot & IOMMU_READ))
 		return -EPERM;
 
-	mm = get_task_mm(dma->task);
+	mm = get_task_mm(iommu->task);
 	if (!mm)
 		return -EPERM;
 
 	if (kthread)
 		kthread_use_mm(mm);
-	else if (current->mm != mm)
+	else if (current->mm != mm) {
+		/* Must use matching mm for vaddr translation. */
+		ret = -EPERM;
 		goto out;
+	}
 
 	offset = user_iova - dma->iova;
 
 	if (count > dma->size - offset)
@@ -3011,7 +3051,7 @@ static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
 		kthread_unuse_mm(mm);
 out:
 	mmput(mm);
-	return *copied ? 0 : -EFAULT;
+	return *copied ? 0 : ret;
 }
 
 static int vfio_iommu_type1_dma_rw(void *iommu_data, dma_addr_t user_iova,
-- 
1.8.3.1
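
For context, the userspace-visible effect of the owner check can be sketched
as follows. This is a hypothetical demo, not part of the patch; container_fd
is assumed to be an already-configured type1 container with an IOMMU-backed
domain, so VFIO_IOMMU_MAP_DMA pins pages (and hits the owner check)
immediately. The parent's first map makes it the dma owner; a forked child's
map then fails with EPERM rather than being charged against a second
RLIMIT_MEMLOCK.

#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/wait.h>
#include <linux/vfio.h>

static int map_one_page(int container_fd, void *buf, uint64_t iova)
{
	struct vfio_iommu_type1_dma_map map = {
		.argsz = sizeof(map),
		.flags = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE,
		.vaddr = (uint64_t)(uintptr_t)buf,
		.iova  = iova,
		.size  = 4096,
	};

	return ioctl(container_fd, VFIO_IOMMU_MAP_DMA, &map) ? -errno : 0;
}

static void demo(int container_fd)
{
	void *buf = mmap(NULL, 8192, PROT_READ | PROT_WRITE,
			 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);

	if (buf == MAP_FAILED)
		return;

	/* First map: this task becomes the dma owner. */
	if (map_one_page(container_fd, buf, 0x100000) != 0)
		return;

	if (fork() == 0) {
		/* Child has a different group_leader, so with this patch
		 * the pin path returns -EPERM instead of letting the child
		 * pin against its own RLIMIT_MEMLOCK quota.
		 */
		int ret = map_one_page(container_fd,
				       (char *)buf + 4096, 0x200000);

		fprintf(stderr, "child map: %s\n",
			ret ? strerror(-ret) : "ok (unexpected)");
		_exit(0);
	}
	wait(NULL);
}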