The per-device fault data structure stores information about faults occurring
on a device. Its lifetime spans from IOPF enablement to disablement. Multiple
paths, including IOPF reporting, handling, and responding, may access it
concurrently.

Previously, a mutex protected the fault data against use-after-free. But this
is not performance friendly, as the IOPF handling paths are performance
critical.

Refine this with a refcount-based approach. The fault data pointer is looked
up inside an RCU read-side critical section, where a reference is taken; the
pointer is returned to the caller only if it is valid and the reference was
successfully obtained. The fault data is freed with kfree_rcu(), which
guarantees that the memory is not released until all in-flight RCU read-side
critical sections have completed.

Suggested-by: Jason Gunthorpe <jgg@xxxxxxxxxx>
Signed-off-by: Lu Baolu <baolu.lu@xxxxxxxxxxxxxxx>
Tested-by: Yan Zhao <yan.y.zhao@xxxxxxxxx>
---
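For reference, below is a minimal, self-contained sketch of the RCU +
refcount lifetime pattern this change adopts. The foo object and the
foo_get()/foo_put()/foo_install()/foo_remove() names are hypothetical,
for illustration only; the sketch also uses the canonical __rcu and
rcu_dereference() annotations, whereas the patch reads the plain pointer
inside the RCU section.

#include <linux/rcupdate.h>
#include <linux/refcount.h>
#include <linux/slab.h>

struct foo {
	refcount_t users;	/* lifetime reference count */
	struct rcu_head rcu;	/* for deferred free via kfree_rcu() */
	/* ... payload ... */
};

/* Published pointer; readers dereference it under rcu_read_lock(). */
static struct foo __rcu *foo_ptr;

/*
 * Reader side: lock-free lookup plus reference. Returns NULL if the
 * object is absent or already dying.
 */
static struct foo *foo_get(void)
{
	struct foo *p;

	rcu_read_lock();
	p = rcu_dereference(foo_ptr);
	/*
	 * refcount_inc_not_zero() fails once the last reference has been
	 * dropped, so a reader can never resurrect a dying object.
	 */
	if (p && !refcount_inc_not_zero(&p->users))
		p = NULL;
	rcu_read_unlock();

	return p;
}

/*
 * Drop a reference; the final put frees the object only after all
 * current RCU readers have left their critical sections.
 */
static void foo_put(struct foo *p)
{
	if (refcount_dec_and_test(&p->users))
		kfree_rcu(p, rcu);
}

/* Writer side: publish a new object holding the initial reference. */
static int foo_install(void)
{
	struct foo *p = kzalloc(sizeof(*p), GFP_KERNEL);

	if (!p)
		return -ENOMEM;

	refcount_set(&p->users, 1);
	rcu_assign_pointer(foo_ptr, p);
	return 0;
}

/* Writer side: unpublish and drop the initial reference. */
static void foo_remove(void)
{
	struct foo *p = rcu_dereference_protected(foo_ptr, true);

	rcu_assign_pointer(foo_ptr, NULL);
	foo_put(p);	/* freed once the last reader's put runs */
}

Readers thus pay rcu_read_lock() plus a single atomic operation instead
of contending on a per-device mutex in the page fault hot path; the
object cannot disappear between the lookup and the refcount bump because
kfree_rcu() waits out every read-side critical section.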
 include/linux/iommu.h      |  4 ++
 drivers/iommu/io-pgfault.c | 81 +++++++++++++++++++++++++-------------
 2 files changed, 57 insertions(+), 28 deletions(-)

diff --git a/include/linux/iommu.h b/include/linux/iommu.h
index 63df77cc0b61..8020bb44a64a 100644
--- a/include/linux/iommu.h
+++ b/include/linux/iommu.h
@@ -597,6 +597,8 @@ struct iommu_device {
 /**
  * struct iommu_fault_param - per-device IOMMU fault data
  * @lock: protect pending faults list
+ * @users: user counter to manage the lifetime of the data
+ * @rcu: rcu head for kfree_rcu()
  * @dev: the device that owns this param
  * @queue: IOPF queue
  * @queue_list: index into queue->devices
@@ -606,6 +608,8 @@ struct iommu_device {
  */
 struct iommu_fault_param {
 	struct mutex lock;
+	refcount_t users;
+	struct rcu_head rcu;
 	struct device *dev;
 	struct iopf_queue *queue;
diff --git a/drivers/iommu/io-pgfault.c b/drivers/iommu/io-pgfault.c
index 9439eaf54928..2ace32c6d13b 100644
--- a/drivers/iommu/io-pgfault.c
+++ b/drivers/iommu/io-pgfault.c
@@ -26,6 +26,44 @@ void iopf_free_group(struct iopf_group *group)
 }
 EXPORT_SYMBOL_GPL(iopf_free_group);
 
+/*
+ * Return the fault parameter of a device if it exists. Otherwise, return NULL.
+ * On a successful return, the caller takes a reference of this parameter and
+ * should put it after use by calling iopf_put_dev_fault_param().
+ */
+static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev)
+{
+	struct dev_iommu *param = dev->iommu;
+	struct iommu_fault_param *fault_param;
+
+	if (!param)
+		return NULL;
+
+	rcu_read_lock();
+	fault_param = param->fault_param;
+	if (fault_param && !refcount_inc_not_zero(&fault_param->users))
+		fault_param = NULL;
+	rcu_read_unlock();
+
+	return fault_param;
+}
+
+/* Caller must hold a reference of the fault parameter. */
+static void iopf_put_dev_fault_param(struct iommu_fault_param *fault_param)
+{
+	struct dev_iommu *param = fault_param->dev->iommu;
+
+	rcu_read_lock();
+	if (!refcount_dec_and_test(&fault_param->users)) {
+		rcu_read_unlock();
+		return;
+	}
+	rcu_read_unlock();
+
+	param->fault_param = NULL;
+	kfree_rcu(fault_param, rcu);
+}
+
 /**
  * iommu_handle_iopf - IO Page Fault handler
  * @fault: fault event
@@ -167,15 +205,11 @@ int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt)
 {
 	struct iommu_fault_param *fault_param;
 	struct iopf_fault *evt_pending = NULL;
-	struct dev_iommu *param = dev->iommu;
 	int ret = 0;
 
-	mutex_lock(&param->lock);
-	fault_param = param->fault_param;
-	if (!fault_param) {
-		mutex_unlock(&param->lock);
+	fault_param = iopf_get_dev_fault_param(dev);
+	if (!fault_param)
 		return -EINVAL;
-	}
 
 	mutex_lock(&fault_param->lock);
 	if (evt->fault.type == IOMMU_FAULT_PAGE_REQ &&
@@ -196,7 +230,7 @@ int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt)
 	}
 done_unlock:
 	mutex_unlock(&fault_param->lock);
-	mutex_unlock(&param->lock);
+	iopf_put_dev_fault_param(fault_param);
 
 	return ret;
 }
@@ -209,7 +243,6 @@ int iommu_page_response(struct device *dev,
 	int ret = -EINVAL;
 	struct iopf_fault *evt;
 	struct iommu_fault_page_request *prm;
-	struct dev_iommu *param = dev->iommu;
 	struct iommu_fault_param *fault_param;
 	const struct iommu_ops *ops = dev_iommu_ops(dev);
 	bool has_pasid = msg->flags & IOMMU_PAGE_RESP_PASID_VALID;
@@ -217,12 +250,9 @@ int iommu_page_response(struct device *dev,
 	if (!ops->page_response)
 		return -ENODEV;
 
-	mutex_lock(&param->lock);
-	fault_param = param->fault_param;
-	if (!fault_param) {
-		mutex_unlock(&param->lock);
+	fault_param = iopf_get_dev_fault_param(dev);
+	if (!fault_param)
 		return -EINVAL;
-	}
 
 	/* Only send response if there is a fault report pending */
 	mutex_lock(&fault_param->lock);
@@ -263,7 +293,8 @@ int iommu_page_response(struct device *dev,
 
 done_unlock:
 	mutex_unlock(&fault_param->lock);
-	mutex_unlock(&param->lock);
+	iopf_put_dev_fault_param(fault_param);
+
 	return ret;
 }
 EXPORT_SYMBOL_GPL(iommu_page_response);
@@ -282,22 +313,15 @@ EXPORT_SYMBOL_GPL(iommu_page_response);
  */
 int iopf_queue_flush_dev(struct device *dev)
 {
-	int ret = 0;
-	struct iommu_fault_param *iopf_param;
-	struct dev_iommu *param = dev->iommu;
+	struct iommu_fault_param *iopf_param = iopf_get_dev_fault_param(dev);
 
-	if (!param)
+	if (!iopf_param)
 		return -ENODEV;
 
-	mutex_lock(&param->lock);
-	iopf_param = param->fault_param;
-	if (iopf_param)
-		flush_workqueue(iopf_param->queue->wq);
-	else
-		ret = -ENODEV;
-	mutex_unlock(&param->lock);
+	flush_workqueue(iopf_param->queue->wq);
+	iopf_put_dev_fault_param(iopf_param);
 
-	return ret;
+	return 0;
 }
 EXPORT_SYMBOL_GPL(iopf_queue_flush_dev);
@@ -389,6 +413,8 @@ int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev)
 	INIT_LIST_HEAD(&fault_param->faults);
 	INIT_LIST_HEAD(&fault_param->partial);
 	fault_param->dev = dev;
+	refcount_set(&fault_param->users, 1);
+	init_rcu_head(&fault_param->rcu);
 	list_add(&fault_param->queue_list, &queue->devices);
 	fault_param->queue = queue;
@@ -441,8 +467,7 @@ int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
 	list_for_each_entry_safe(iopf, next, &fault_param->partial, list)
 		kfree(iopf);
 
-	param->fault_param = NULL;
-	kfree(fault_param);
+	iopf_put_dev_fault_param(fault_param);
 unlock:
 	mutex_unlock(&param->lock);
 	mutex_unlock(&queue->lock);
-- 
2.34.1