Add a per-fault timer that drains a pending fault in case user space fails to read or respond to it. Because the per-fault iommufd data may be accessed in two different contexts (user space reading/responding and the timer expiring), add a reference counter to each iommufd fault data and free the data only after all references have been released. The page fault response timeout value is device-specific and indicates how long the bus/device will wait for a response to a page fault request. The timeout value is added to the per-device fault cookie. Ideally, it should be calculated according to the platform configuration (PCI, ACPI, device tree, etc.). This patch defines a default value of 1 second in case no platform opt-in is available. This default value is a rough estimate and is subject to change according to real use cases. Signed-off-by: Lu Baolu <baolu.lu@xxxxxxxxxxxxxxx> --- drivers/iommu/iommufd/iommufd_private.h | 8 +++ drivers/iommu/iommufd/device.c | 3 + drivers/iommu/iommufd/hw_pagetable.c | 80 +++++++++++++++++++++++-- 3 files changed, 87 insertions(+), 4 deletions(-) diff --git a/drivers/iommu/iommufd/iommufd_private.h b/drivers/iommu/iommufd/iommufd_private.h index 0985e83a611f..f5b8a53044c4 100644 --- a/drivers/iommu/iommufd/iommufd_private.h +++ b/drivers/iommu/iommufd/iommufd_private.h @@ -249,9 +249,12 @@ struct hw_pgtable_fault { struct iommufd_fault { struct device *dev; ioasid_t pasid; + struct iommufd_hw_pagetable *hwpt; struct iommu_hwpt_pgfault fault; /* List head at hw_pgtable_fault:deliver or response */ struct list_head item; + struct timer_list timer; + refcount_t users; }; /* @@ -336,6 +339,11 @@ struct iommufd_device { struct iommufd_fault_cookie { struct iommufd_device *idev; + /* + * The maximum number of milliseconds that a device will wait for a + * response to a page fault request. 
+ */ + unsigned long timeout; }; static inline struct iommufd_device * diff --git a/drivers/iommu/iommufd/device.c b/drivers/iommu/iommufd/device.c index 3408f1fc3e9f..6ad46638f4e1 100644 --- a/drivers/iommu/iommufd/device.c +++ b/drivers/iommu/iommufd/device.c @@ -374,6 +374,8 @@ static int iommufd_group_setup_msi(struct iommufd_group *igroup, return 0; } +#define IOMMUFD_DEFAULT_IOPF_TIMEOUT 1000 + static int iommufd_device_set_fault_cookie(struct iommufd_hw_pagetable *hwpt, struct iommufd_device *idev, ioasid_t pasid) @@ -387,6 +389,7 @@ static int iommufd_device_set_fault_cookie(struct iommufd_hw_pagetable *hwpt, if (!fcookie) return -ENOMEM; fcookie->idev = idev; + fcookie->timeout = IOMMUFD_DEFAULT_IOPF_TIMEOUT; curr = iommu_set_device_fault_cookie(idev->dev, pasid, fcookie); if (IS_ERR(curr)) { diff --git a/drivers/iommu/iommufd/hw_pagetable.c b/drivers/iommu/iommufd/hw_pagetable.c index c1f3ebdce796..8c441fd72e1f 100644 --- a/drivers/iommu/iommufd/hw_pagetable.c +++ b/drivers/iommu/iommufd/hw_pagetable.c @@ -6,6 +6,7 @@ #include <linux/eventfd.h> #include <linux/file.h> #include <linux/anon_inodes.h> +#include <linux/timer.h> #include <uapi/linux/iommufd.h> #include "../iommu-priv.h" @@ -396,6 +397,60 @@ static void iommufd_compose_fault_message(struct iommu_fault *fault, hwpt_fault->private_data[1] = fault->prm.private_data[1]; } +static void drain_iopf_fault(struct iommufd_fault *ifault) +{ + struct iommu_page_response resp = { + .version = IOMMU_PAGE_RESP_VERSION_1, + .pasid = ifault->fault.pasid, + .grpid = ifault->fault.grpid, + .code = IOMMU_PAGE_RESP_FAILURE, + }; + + if (!(ifault->fault.flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE)) + return; + + if ((ifault->fault.flags & IOMMU_FAULT_PAGE_REQUEST_PASID_VALID) && + (ifault->fault.flags & IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID)) + resp.flags = IOMMU_PAGE_RESP_PASID_VALID; + + iommu_page_response(ifault->dev, &resp); +} + +static void iommufd_put_fault(struct iommufd_fault *ifault) +{ + if (!ifault) + 
return; + + if (refcount_dec_and_test(&ifault->users)) + kfree(ifault); +} + +static int iommufd_fault_timer_teardown(struct iommufd_fault *ifault) +{ + int rc; + + rc = timer_delete(&ifault->timer); + if (rc) + iommufd_put_fault(ifault); + + return rc; +} + +static void iopf_timer_func(struct timer_list *t) +{ + struct iommufd_fault *ifault = from_timer(ifault, t, timer); + struct hw_pgtable_fault *fault = ifault->hwpt->fault; + + mutex_lock(&fault->mutex); + if (!list_empty(&ifault->item)) { + list_del_init(&ifault->item); + drain_iopf_fault(ifault); + } + mutex_unlock(&fault->mutex); + + iommufd_put_fault(ifault); +} + static enum iommu_page_response_code iommufd_hw_pagetable_iopf_handler(struct iommu_fault *fault, struct device *dev, void *data) @@ -416,6 +471,10 @@ iommufd_hw_pagetable_iopf_handler(struct iommu_fault *fault, iommufd_compose_fault_message(fault, &ifault->fault, cookie->idev->obj.id); ifault->dev = dev; ifault->pasid = fault->prm.pasid; + ifault->hwpt = hwpt; + refcount_set(&ifault->users, 2); + timer_setup(&ifault->timer, iopf_timer_func, 0); + mod_timer(&ifault->timer, jiffies + msecs_to_jiffies(cookie->timeout)); mutex_lock(&hwpt->fault->mutex); list_add_tail(&ifault->item, &hwpt->fault->deliver); @@ -443,10 +502,12 @@ static ssize_t hwpt_fault_fops_read(struct file *filep, char __user *buf, break; done += fault_size; list_del_init(&ifault->item); - if (ifault->fault.flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE) + if (ifault->fault.flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE) { list_add_tail(&ifault->item, &fault->response); - else - kfree(ifault); + } else { + iommufd_fault_timer_teardown(ifault); + iommufd_put_fault(ifault); + } } mutex_unlock(&fault->mutex); @@ -526,6 +587,7 @@ int iommufd_hwpt_page_response(struct iommufd_ucmd *ucmd) { struct iommu_hwpt_page_response *cmd = ucmd->cmd; struct iommu_page_response resp = {}; + struct iommufd_fault *ifault = NULL; struct iommufd_fault *curr, *next; struct iommufd_hw_pagetable *hwpt; struct 
iommufd_device *idev; @@ -547,6 +609,7 @@ int iommufd_hwpt_page_response(struct iommufd_ucmd *ucmd) if (curr->dev != idev->dev || curr->fault.grpid != cmd->grpid) continue; + ifault = curr; if ((cmd->flags & IOMMU_PGFAULT_FLAGS_PASID_VALID) && cmd->pasid != curr->fault.pasid) break; @@ -555,6 +618,15 @@ int iommufd_hwpt_page_response(struct iommufd_ucmd *ucmd) !(cmd->flags & IOMMU_PGFAULT_FLAGS_PASID_VALID)) break; + /* + * The timer has expired if it was not pending. Leave the + * response to the timer function. + */ + if (!iommufd_fault_timer_teardown(curr)) { + rc = -ETIMEDOUT; + break; + } + resp.version = IOMMU_PAGE_RESP_VERSION_1; resp.pasid = cmd->pasid; resp.grpid = cmd->grpid; @@ -564,11 +636,11 @@ int iommufd_hwpt_page_response(struct iommufd_ucmd *ucmd) rc = iommu_page_response(idev->dev, &resp); list_del_init(&curr->item); - kfree(curr); break; } mutex_unlock(&hwpt->fault->mutex); + iommufd_put_fault(ifault); iommufd_put_object(&idev->obj); out_put_hwpt: iommufd_put_object(&hwpt->obj); -- 2.34.1