On Thu, Dec 07, 2023 at 02:43:08PM +0800, Lu Baolu wrote: > +/* > + * Return the fault parameter of a device if it exists. Otherwise, return NULL. > + * On a successful return, the caller takes a reference of this parameter and > + * should put it after use by calling iopf_put_dev_fault_param(). > + */ > +static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev) > +{ > + struct dev_iommu *param = dev->iommu; > + struct iommu_fault_param *fault_param; > + > + if (!param) > + return NULL; Is it actually possible to call this function on a device that does not have an iommu driver probed? I'd be surprised by that; maybe this should be a WARN_ON(). > + > + rcu_read_lock(); > + fault_param = param->fault_param; The RCU handling is not right; it should look like this: diff --git a/drivers/iommu/io-pgfault.c b/drivers/iommu/io-pgfault.c index 2ace32c6d13bf3..0258f79c8ddf98 100644 --- a/drivers/iommu/io-pgfault.c +++ b/drivers/iommu/io-pgfault.c @@ -40,7 +40,7 @@ static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev) return NULL; rcu_read_lock(); - fault_param = param->fault_param; + fault_param = rcu_dereference(param->fault_param); if (fault_param && !refcount_inc_not_zero(&fault_param->users)) fault_param = NULL; rcu_read_unlock(); @@ -51,17 +51,8 @@ static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev) /* Caller must hold a reference of the fault parameter.
*/ static void iopf_put_dev_fault_param(struct iommu_fault_param *fault_param) { - struct dev_iommu *param = fault_param->dev->iommu; - - rcu_read_lock(); - if (!refcount_dec_and_test(&fault_param->users)) { - rcu_read_unlock(); - return; - } - rcu_read_unlock(); - - param->fault_param = NULL; - kfree_rcu(fault_param, rcu); + if (refcount_dec_and_test(&fault_param->users)) + kfree_rcu(fault_param, rcu); } /** @@ -174,7 +165,7 @@ static int iommu_handle_iopf(struct iommu_fault *fault, } mutex_unlock(&iopf_param->lock); - ret = domain->iopf_handler(group); + ret = domain->iopf_handler(iopf_param, group); mutex_lock(&iopf_param->lock); if (ret) iopf_free_group(group); @@ -398,7 +389,8 @@ int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev) mutex_lock(&queue->lock); mutex_lock(&param->lock); - if (param->fault_param) { + if (rcu_dereference_check(param->fault_param, + lockdep_is_held(&param->lock))) { ret = -EBUSY; goto done_unlock; } @@ -418,7 +410,7 @@ int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev) list_add(&fault_param->queue_list, &queue->devices); fault_param->queue = queue; - param->fault_param = fault_param; + rcu_assign_pointer(param->fault_param, fault_param); done_unlock: mutex_unlock(&param->lock); @@ -442,10 +434,12 @@ int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev) int ret = 0; struct iopf_fault *iopf, *next; struct dev_iommu *param = dev->iommu; - struct iommu_fault_param *fault_param = param->fault_param; + struct iommu_fault_param *fault_param; mutex_lock(&queue->lock); mutex_lock(&param->lock); + fault_param = rcu_dereference_check(param->fault_param, + lockdep_is_held(&param->lock)); if (!fault_param) { ret = -ENODEV; goto unlock; @@ -467,7 +461,10 @@ int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev) list_for_each_entry_safe(iopf, next, &fault_param->partial, list) kfree(iopf); - iopf_put_dev_fault_param(fault_param); + /* dec the ref owned by iopf_queue_add_device() */ +
rcu_assign_pointer(param->fault_param, NULL); + if (refcount_dec_and_test(&fault_param->users)) + kfree_rcu(fault_param, rcu); unlock: mutex_unlock(&param->lock); mutex_unlock(&queue->lock); diff --git a/drivers/iommu/iommu-sva.c b/drivers/iommu/iommu-sva.c index 325d1810e133a1..63c1a233a7e91f 100644 --- a/drivers/iommu/iommu-sva.c +++ b/drivers/iommu/iommu-sva.c @@ -232,10 +232,9 @@ static void iommu_sva_handle_iopf(struct work_struct *work) iopf_free_group(group); } -static int iommu_sva_iopf_handler(struct iopf_group *group) +static int iommu_sva_iopf_handler(struct iommu_fault_param *fault_param, + struct iopf_group *group) { - struct iommu_fault_param *fault_param = group->dev->iommu->fault_param; - INIT_WORK(&group->work, iommu_sva_handle_iopf); if (!queue_work(fault_param->queue->wq, &group->work)) return -EBUSY; diff --git a/include/linux/iommu.h b/include/linux/iommu.h index 8020bb44a64ab1..e16fa9811d5023 100644 --- a/include/linux/iommu.h +++ b/include/linux/iommu.h @@ -41,6 +41,7 @@ struct iommu_dirty_ops; struct notifier_block; struct iommu_sva; struct iommu_dma_cookie; +struct iommu_fault_param; #define IOMMU_FAULT_PERM_READ (1 << 0) /* read */ #define IOMMU_FAULT_PERM_WRITE (1 << 1) /* write */ @@ -210,7 +211,8 @@ struct iommu_domain { unsigned long pgsize_bitmap; /* Bitmap of page sizes in use */ struct iommu_domain_geometry geometry; struct iommu_dma_cookie *iova_cookie; - int (*iopf_handler)(struct iopf_group *group); + int (*iopf_handler)(struct iommu_fault_param *fault_param, + struct iopf_group *group); void *fault_data; union { struct { @@ -637,7 +639,7 @@ struct iommu_fault_param { */ struct dev_iommu { struct mutex lock; - struct iommu_fault_param *fault_param; + struct iommu_fault_param __rcu *fault_param; struct iommu_fwspec *fwspec; struct iommu_device *iommu_dev; void *priv;