Move the driver init and destruction code into two logically paired functions. There is a subtle ordering dependency in how the group's domains are freed: the current code does the kobject_put() on the group which will hopefully trigger the free of the domains before the module_put() that protects the domain->ops. Reorganize this to be explicit and documented. The domains are cleaned up by iommu_deinit_device() if it is the last device to be deinit'd from the group. This must be done in a specific order - after ops->release_device() and before the module_put(). Make it very clear and obvious by putting the order directly in one function. Leave WARN_ON's in case the refcounting gets messed up somehow. This also moves the module_put() and dev_iommu_free() under the group->mutex to keep the code simple. Building paired functions like this helps ensure that error cleanup flows in __iommu_probe_device() are correct because they share the same code that handles the normal flow. These details become relevant as following patches add more error unwind into __iommu_probe_device(), and ultimately a following series adds fine-grained locking to __iommu_probe_device(). 
Reviewed-by: Kevin Tian <kevin.tian@xxxxxxxxx> Reviewed-by: Lu Baolu <baolu.lu@xxxxxxxxxxxxxxx> Signed-off-by: Jason Gunthorpe <jgg@xxxxxxxxxx> --- drivers/iommu/iommu.c | 191 +++++++++++++++++++++++++----------------- 1 file changed, 112 insertions(+), 79 deletions(-) diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c index 456c2d2934896c..7e8f5edcff2145 100644 --- a/drivers/iommu/iommu.c +++ b/drivers/iommu/iommu.c @@ -332,10 +332,99 @@ static u32 dev_iommu_get_max_pasids(struct device *dev) return min_t(u32, max_pasids, dev->iommu->iommu_dev->max_pasids); } +/* + * Init the dev->iommu and dev->iommu_group in the struct device and get the + * driver probed + */ +static int iommu_init_device(struct device *dev, const struct iommu_ops *ops) +{ + struct iommu_device *iommu_dev; + struct iommu_group *group; + int ret; + + if (!dev_iommu_get(dev)) + return -ENOMEM; + + if (!try_module_get(ops->owner)) { + ret = -EINVAL; + goto err_free; + } + + iommu_dev = ops->probe_device(dev); + if (IS_ERR(iommu_dev)) { + ret = PTR_ERR(iommu_dev); + goto err_module_put; + } + + group = ops->device_group(dev); + if (WARN_ON_ONCE(group == NULL)) + group = ERR_PTR(-EINVAL); + if (IS_ERR(group)) { + ret = PTR_ERR(group); + goto err_release; + } + dev->iommu_group = group; + + dev->iommu->iommu_dev = iommu_dev; + dev->iommu->max_pasids = dev_iommu_get_max_pasids(dev); + if (ops->is_attach_deferred) + dev->iommu->attach_deferred = ops->is_attach_deferred(dev); + return 0; + +err_release: + if (ops->release_device) + ops->release_device(dev); +err_module_put: + module_put(ops->owner); +err_free: + dev_iommu_free(dev); + return ret; +} + +static void iommu_deinit_device(struct device *dev) +{ + struct iommu_group *group = dev->iommu_group; + const struct iommu_ops *ops = dev_iommu_ops(dev); + + lockdep_assert_held(&group->mutex); + + /* + * release_device() must stop using any attached domain on the device. 
+ * If there are still other devices in the group they are not effected + * by this callback. + * + * The IOMMU driver must set the device to either an identity or + * blocking translation and stop using any domain pointer, as it is + * going to be freed. + */ + if (ops->release_device) + ops->release_device(dev); + + /* + * If this is the last driver to use the group then we must free the + * domains before we do the module_put(). + */ + if (list_empty(&group->devices)) { + if (group->default_domain) { + iommu_domain_free(group->default_domain); + group->default_domain = NULL; + } + if (group->blocking_domain) { + iommu_domain_free(group->blocking_domain); + group->blocking_domain = NULL; + } + group->domain = NULL; + } + + /* Caller must put iommu_group */ + dev->iommu_group = NULL; + module_put(ops->owner); + dev_iommu_free(dev); +} + static int __iommu_probe_device(struct device *dev, struct list_head *group_list) { const struct iommu_ops *ops = dev->bus->iommu_ops; - struct iommu_device *iommu_dev; struct iommu_group *group; static DEFINE_MUTEX(iommu_probe_device_lock); int ret; @@ -357,62 +446,30 @@ static int __iommu_probe_device(struct device *dev, struct list_head *group_list goto out_unlock; } - if (!dev_iommu_get(dev)) { - ret = -ENOMEM; + ret = iommu_init_device(dev, ops); + if (ret) goto out_unlock; - } - - if (!try_module_get(ops->owner)) { - ret = -EINVAL; - goto err_free; - } - - iommu_dev = ops->probe_device(dev); - if (IS_ERR(iommu_dev)) { - ret = PTR_ERR(iommu_dev); - goto out_module_put; - } - - dev->iommu->iommu_dev = iommu_dev; - dev->iommu->max_pasids = dev_iommu_get_max_pasids(dev); - if (ops->is_attach_deferred) - dev->iommu->attach_deferred = ops->is_attach_deferred(dev); - - group = ops->device_group(dev); - if (WARN_ON_ONCE(group == NULL)) - group = ERR_PTR(-EINVAL); - if (IS_ERR(group)) { - ret = PTR_ERR(group); - goto out_release; - } + group = dev->iommu_group; ret = iommu_group_add_device(group, dev); + mutex_lock(&group->mutex); if 
(ret) goto err_put_group; - mutex_lock(&group->mutex); if (group_list && !group->default_domain && list_empty(&group->entry)) list_add_tail(&group->entry, group_list); mutex_unlock(&group->mutex); iommu_group_put(group); mutex_unlock(&iommu_probe_device_lock); - iommu_device_link(iommu_dev, dev); + iommu_device_link(dev->iommu->iommu_dev, dev); return 0; err_put_group: + iommu_deinit_device(dev); + mutex_unlock(&group->mutex); iommu_group_put(group); -out_release: - if (ops->release_device) - ops->release_device(dev); - -out_module_put: - module_put(ops->owner); - -err_free: - dev_iommu_free(dev); - out_unlock: mutex_unlock(&iommu_probe_device_lock); @@ -491,63 +548,45 @@ static void __iommu_group_free_device(struct iommu_group *group, kfree(grp_dev->name); kfree(grp_dev); - dev->iommu_group = NULL; } -/* - * Remove the iommu_group from the struct device. The attached group must be put - * by the caller after releaseing the group->mutex. - */ +/* Remove the iommu_group from the struct device. 
*/ static void __iommu_group_remove_device(struct device *dev) { struct iommu_group *group = dev->iommu_group; struct group_device *device; - lockdep_assert_held(&group->mutex); + mutex_lock(&group->mutex); for_each_group_device(group, device) { if (device->dev != dev) continue; list_del(&device->list); __iommu_group_free_device(group, device); - /* Caller must put iommu_group */ - return; + if (dev->iommu && dev->iommu->iommu_dev) + iommu_deinit_device(dev); + else + dev->iommu_group = NULL; + goto out; } WARN(true, "Corrupted iommu_group device_list"); +out: + mutex_unlock(&group->mutex); + + /* Pairs with the get in iommu_group_add_device() */ + iommu_group_put(group); } static void iommu_release_device(struct device *dev) { struct iommu_group *group = dev->iommu_group; - const struct iommu_ops *ops; if (!dev->iommu || !group) return; iommu_device_unlink(dev->iommu->iommu_dev, dev); - mutex_lock(&group->mutex); __iommu_group_remove_device(dev); - - /* - * release_device() must stop using any attached domain on the device. - * If there are still other devices in the group they are not effected - * by this callback. - * - * The IOMMU driver must set the device to either an identity or - * blocking translation and stop using any domain pointer, as it is - * going to be freed. 
- */ - ops = dev_iommu_ops(dev); - if (ops->release_device) - ops->release_device(dev); - mutex_unlock(&group->mutex); - - /* Pairs with the get in iommu_group_add_device() */ - iommu_group_put(group); - - module_put(ops->owner); - dev_iommu_free(dev); } static int __init iommu_set_def_domain_type(char *str) @@ -808,10 +847,9 @@ static void iommu_group_release(struct kobject *kobj) ida_free(&iommu_group_ida, group->id); - if (group->default_domain) - iommu_domain_free(group->default_domain); - if (group->blocking_domain) - iommu_domain_free(group->blocking_domain); + /* Domains are free'd by iommu_deinit_device() */ + WARN_ON(group->default_domain); + WARN_ON(group->blocking_domain); kfree(group->name); kfree(group); @@ -1109,12 +1147,7 @@ void iommu_group_remove_device(struct device *dev) dev_info(dev, "Removing from iommu group %d\n", group->id); - mutex_lock(&group->mutex); __iommu_group_remove_device(dev); - mutex_unlock(&group->mutex); - - /* Pairs with the get in iommu_group_add_device() */ - iommu_group_put(group); } EXPORT_SYMBOL_GPL(iommu_group_remove_device); -- 2.40.1