On Wed, Dec 20, 2023 at 09:23:32AM +0800, Lu Baolu wrote:
> /**
> - * iommu_handle_iopf - IO Page Fault handler
> - * @fault: fault event
> - * @iopf_param: the fault parameter of the device.
> + * iommu_report_device_fault() - Report fault event to device driver
> + * @dev: the device
> + * @evt: fault event data
> *
> - * Add a fault to the device workqueue, to be handled by mm.
> + * Called by IOMMU drivers when a fault is detected, typically in a threaded IRQ
> + * handler. When this function fails and the fault is recoverable, it is the
> + * caller's responsibility to complete the fault.

This patch seems OK for what it does, so:

Reviewed-by: Jason Gunthorpe <jgg@xxxxxxxxxx>

However, this seems like a strange design; surely this function should
just call ops->page_response() when it can't enqueue the fault?

It is much cleaner that way, so maybe you can take this into a
following patch (along with the driver fixes to accommodate it).

(And perhaps iommu_report_device_fault() should return void too.)

Also, iopf_group_response() should return void (another patch!);
nothing can do anything with the failure. This implies that
ops->page_response() must also return void, which is consistent with
what the drivers do: the failure paths are all integrity validations
of the fault and should be WARN_ON'd, not return codes.

diff --git a/drivers/iommu/io-pgfault.c b/drivers/iommu/io-pgfault.c
index 7d11b74e4048e2..2715e24fd64234 100644
--- a/drivers/iommu/io-pgfault.c
+++ b/drivers/iommu/io-pgfault.c
@@ -39,7 +39,7 @@ static void iopf_put_dev_fault_param(struct iommu_fault_param *fault_param)
         kfree_rcu(fault_param, rcu);
 }
 
-void iopf_free_group(struct iopf_group *group)
+static void __iopf_free_group(struct iopf_group *group)
 {
         struct iopf_fault *iopf, *next;
 
@@ -50,6 +50,11 @@ void iopf_free_group(struct iopf_group *group)
 
         /* Pair with iommu_report_device_fault(). */
         iopf_put_dev_fault_param(group->fault_param);
+}
+
+void iopf_free_group(struct iopf_group *group)
+{
+        __iopf_free_group(group);
         kfree(group);
 }
 EXPORT_SYMBOL_GPL(iopf_free_group);
@@ -97,14 +102,49 @@ static int report_partial_fault(struct iommu_fault_param *fault_param,
         return 0;
 }
 
+static struct iopf_group *iopf_group_alloc(struct iommu_fault_param *iopf_param,
+                                           struct iopf_fault *evt,
+                                           struct iopf_group *abort_group)
+{
+        struct iopf_fault *iopf, *next;
+        struct iopf_group *group;
+
+        group = kzalloc(sizeof(*group), GFP_KERNEL);
+        if (!group) {
+                /*
+                 * We always need to construct the group as we need it to abort
+                 * the request at the driver if it can't be handled.
+                 */
+                group = abort_group;
+        }
+
+        group->fault_param = iopf_param;
+        group->last_fault.fault = evt->fault;
+        INIT_LIST_HEAD(&group->faults);
+        INIT_LIST_HEAD(&group->pending_node);
+        list_add(&group->last_fault.list, &group->faults);
+
+        /* See if we have partial faults for this group */
+        mutex_lock(&iopf_param->lock);
+        list_for_each_entry_safe(iopf, next, &iopf_param->partial, list) {
+                if (iopf->fault.prm.grpid == evt->fault.prm.grpid)
+                        /* Insert *before* the last fault */
+                        list_move(&iopf->list, &group->faults);
+        }
+        list_add(&group->pending_node, &iopf_param->faults);
+        mutex_unlock(&iopf_param->lock);
+
+        return group;
+}
+
 /**
  * iommu_report_device_fault() - Report fault event to device driver
  * @dev: the device
  * @evt: fault event data
  *
  * Called by IOMMU drivers when a fault is detected, typically in a threaded IRQ
- * handler. When this function fails and the fault is recoverable, it is the
- * caller's responsibility to complete the fault.
+ * handler. If this function fails then ops->page_response() was called to
+ * complete evt if required.
  *
  * This module doesn't handle PCI PASID Stop Marker; IOMMU drivers must discard
  * them before reporting faults. A PASID Stop Marker (LRW = 0b100) doesn't
@@ -143,22 +183,24 @@ int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt)
 {
         struct iommu_fault *fault = &evt->fault;
         struct iommu_fault_param *iopf_param;
-        struct iopf_fault *iopf, *next;
-        struct iommu_domain *domain;
+        struct iopf_group abort_group;
         struct iopf_group *group;
         int ret;
 
+/*
+ remove this too, it is pointless. The driver should only invoke this
+ function on page_req faults.
         if (fault->type != IOMMU_FAULT_PAGE_REQ)
                 return -EOPNOTSUPP;
+*/
 
         iopf_param = iopf_get_dev_fault_param(dev);
-        if (!iopf_param)
+        if (WARN_ON(!iopf_param))
                 return -ENODEV;
 
         if (!(fault->prm.flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE)) {
                 ret = report_partial_fault(iopf_param, fault);
                 iopf_put_dev_fault_param(iopf_param);
-
+                /* A request that is not the last does not need to be ack'd */
                 return ret;
         }
 
@@ -170,56 +212,34 @@ int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt)
          * will send a response to the hardware. We need to clean up before
          * leaving, otherwise partial faults will be stuck.
          */
-        domain = get_domain_for_iopf(dev, fault);
-        if (!domain) {
-                ret = -EINVAL;
-                goto cleanup_partial;
-        }
-
-        group = kzalloc(sizeof(*group), GFP_KERNEL);
-        if (!group) {
+        group = iopf_group_alloc(iopf_param, evt, &abort_group);
+        if (group == &abort_group) {
                 ret = -ENOMEM;
-                goto cleanup_partial;
+                goto err_abort;
         }
 
-        group->fault_param = iopf_param;
-        group->last_fault.fault = *fault;
-        INIT_LIST_HEAD(&group->faults);
-        INIT_LIST_HEAD(&group->pending_node);
-        group->domain = domain;
-        list_add(&group->last_fault.list, &group->faults);
-
-        /* See if we have partial faults for this group */
-        mutex_lock(&iopf_param->lock);
-        list_for_each_entry_safe(iopf, next, &iopf_param->partial, list) {
-                if (iopf->fault.prm.grpid == fault->prm.grpid)
-                        /* Insert *before* the last fault */
-                        list_move(&iopf->list, &group->faults);
+        group->domain = get_domain_for_iopf(dev, fault);
+        if (!group->domain) {
+                ret = -EINVAL;
+                goto err_abort;
         }
-        list_add(&group->pending_node, &iopf_param->faults);
-        mutex_unlock(&iopf_param->lock);
 
-        ret = domain->iopf_handler(group);
-        if (ret) {
-                mutex_lock(&iopf_param->lock);
-                list_del_init(&group->pending_node);
-                mutex_unlock(&iopf_param->lock);
+        /*
+         * On success iopf_handler must call iopf_group_response() and
+         * iopf_free_group()
+         */
+        ret = group->domain->iopf_handler(group);
+        if (ret)
+                goto err_abort;
+        return 0;
+
+err_abort:
+        iopf_group_response(group,
+                            IOMMU_PAGE_RESP_FAILURE); //?? right code?
+        if (group == &abort_group)
+                __iopf_free_group(group);
+        else
                 iopf_free_group(group);
-        }
-
-        return ret;
-
-cleanup_partial:
-        mutex_lock(&iopf_param->lock);
-        list_for_each_entry_safe(iopf, next, &iopf_param->partial, list) {
-                if (iopf->fault.prm.grpid == fault->prm.grpid) {
-                        list_del(&iopf->list);
-                        kfree(iopf);
-                }
-        }
-        mutex_unlock(&iopf_param->lock);
-        iopf_put_dev_fault_param(iopf_param);
-        return ret;
 }
 EXPORT_SYMBOL_GPL(iommu_report_device_fault);
@@ -262,7 +282,7 @@ EXPORT_SYMBOL_GPL(iopf_queue_flush_dev);
  *
  * Return 0 on success and <0 on error.
  */
-int iopf_group_response(struct iopf_group *group,
+void iopf_group_response(struct iopf_group *group,
                         enum iommu_page_response_code status)
 {
         struct iommu_fault_param *fault_param = group->fault_param;
@@ -400,9 +420,9 @@ EXPORT_SYMBOL_GPL(iopf_queue_add_device);
  */
 void iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
 {
-        struct iopf_fault *iopf, *next;
+        struct iopf_fault *partial_iopf;
+        struct iopf_fault *next;
         struct iopf_group *group, *temp;
-        struct iommu_page_response resp;
         struct dev_iommu *param = dev->iommu;
         struct iommu_fault_param *fault_param;
         const struct iommu_ops *ops = dev_iommu_ops(dev);
@@ -416,15 +436,16 @@ void iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
                 goto unlock;
 
         mutex_lock(&fault_param->lock);
-        list_for_each_entry_safe(iopf, next, &fault_param->partial, list)
-                kfree(iopf);
+        list_for_each_entry_safe(partial_iopf, next, &fault_param->partial, list)
+                kfree(partial_iopf);
 
         list_for_each_entry_safe(group, temp, &fault_param->faults, pending_node) {
-                memset(&resp, 0, sizeof(struct iommu_page_response));
-                iopf = &group->last_fault;
-                resp.pasid = iopf->fault.prm.pasid;
-                resp.grpid = iopf->fault.prm.grpid;
-                resp.code = IOMMU_PAGE_RESP_INVALID;
+                struct iopf_fault *iopf = &group->last_fault;
+                struct iommu_page_response resp = {
+                        .pasid = iopf->fault.prm.pasid,
+                        .grpid = iopf->fault.prm.grpid,
+                        .code = IOMMU_PAGE_RESP_INVALID
+                };
 
                 if (iopf->fault.prm.flags & IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID)
                         resp.flags = IOMMU_PAGE_RESP_PASID_VALID;
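
To show the shape of the driver-side change being suggested: once
ops->page_response() is void, a driver's implementation reduces to
WARN_ON'd integrity checks plus the actual hardware write. A minimal
sketch only, assuming the op keeps its (dev, evt, msg) arguments;
my_iommu_page_response() and my_iommu_send_resp() are made-up names,
not from any in-tree driver, and the specific checks are just
examples of the integrity-validation class of failure:

static void my_iommu_page_response(struct device *dev, struct iopf_fault *evt,
                                   struct iommu_page_response *msg)
{
        /*
         * A malformed fault/response pair is a kernel bug, not a runtime
         * condition the caller can act on, so it is WARN_ON'd rather than
         * propagated as an error code.
         */
        if (WARN_ON(!(evt->fault.prm.flags &
                      IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE)))
                return;

        if (WARN_ON((evt->fault.prm.flags &
                     IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID) &&
                    !(msg->flags & IOMMU_PAGE_RESP_PASID_VALID)))
                return;

        /* Build and queue the response descriptor to the hardware */
        my_iommu_send_resp(dev, msg);
}

Nothing in that path can usefully fail back to the core, which is what
lets iopf_group_response() and the err_abort path in
iommu_report_device_fault() above become self-contained.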