Currently, when attaching a domain to a device or its PASID, the domain is stored within the iommu group. It can be retrieved for use during the window between attachment and detachment. With new features being introduced, there's a need to store more information than just a domain pointer. This information essentially represents the association between a domain and a device. For example, the SVA code already has a custom struct iommu_sva which represents a bond between an SVA domain and a PASID of a device. Looking forward, IOMMUFD needs a place to store the iommufd_device pointer in the core, so that the device object ID can be quickly retrieved in the critical fault handling path. Introduce a domain attachment handle that explicitly represents the attachment relationship between a domain and a device or its PASID. Caller-specific data fields can be added later to store additional information beyond a domain pointer, depending on the specific use case. Co-developed-by: Jason Gunthorpe <jgg@xxxxxxxxxx> Signed-off-by: Jason Gunthorpe <jgg@xxxxxxxxxx> Signed-off-by: Lu Baolu <baolu.lu@xxxxxxxxxxxxxxx> --- drivers/iommu/iommu-priv.h | 6 ++ drivers/iommu/iommu.c | 156 ++++++++++++++++++++++++++++++++----- 2 files changed, 142 insertions(+), 20 deletions(-) diff --git a/drivers/iommu/iommu-priv.h b/drivers/iommu/iommu-priv.h index 5f731d994803..da1addaa1a31 100644 --- a/drivers/iommu/iommu-priv.h +++ b/drivers/iommu/iommu-priv.h @@ -28,4 +28,10 @@ void iommu_device_unregister_bus(struct iommu_device *iommu, const struct bus_type *bus, struct notifier_block *nb); +struct iommu_attach_handle { + struct iommu_domain *domain; +}; + +struct iommu_attach_handle * +iommu_attach_handle_get(struct iommu_group *group, ioasid_t pasid, unsigned int type); #endif /* __LINUX_IOMMU_PRIV_H */ diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c index a95a483def2d..0cdd58e69e20 100644 --- a/drivers/iommu/iommu.c +++ b/drivers/iommu/iommu.c @@ -2039,6 +2039,89 @@ void 
iommu_domain_free(struct iommu_domain *domain) } EXPORT_SYMBOL_GPL(iommu_domain_free); +/* Add an attach handle to the group's pasid array. */ +static struct iommu_attach_handle * +iommu_attach_handle_set(struct iommu_domain *domain, + struct iommu_group *group, ioasid_t pasid) +{ + struct iommu_attach_handle *handle; + void *curr; + + handle = kzalloc(sizeof(*handle), GFP_KERNEL); + if (!handle) + return ERR_PTR(-ENOMEM); + + handle->domain = domain; + curr = xa_cmpxchg(&group->pasid_array, pasid, NULL, handle, GFP_KERNEL); + if (curr) { + kfree(handle); + return xa_err(curr) ? curr : ERR_PTR(-EBUSY); + } + + return handle; +} + +/* Remove the attach handle stored in group's pasid array. */ +static void iommu_attach_handle_remove(struct iommu_group *group, ioasid_t pasid) +{ + struct iommu_attach_handle *handle; + + handle = xa_erase(&group->pasid_array, pasid); + kfree(handle); +} + +static struct iommu_attach_handle * +iommu_attach_handle_replace(struct iommu_domain *domain, + struct iommu_group *group, ioasid_t pasid) +{ + struct iommu_attach_handle *handle, *curr; + + handle = kzalloc(sizeof(*handle), GFP_KERNEL); + if (!handle) + return ERR_PTR(-ENOMEM); + + handle->domain = domain; + curr = xa_store(&group->pasid_array, pasid, handle, GFP_KERNEL); + if (xa_err(curr)) { + kfree(handle); + return curr; + } + kfree(curr); + + return handle; +} + +/* + * iommu_attach_handle_get - Return the attach handle + * @group: the iommu group that domain was attached to + * @pasid: the pasid within the group + * @type: matched domain type, 0 for any match + * + * Return handle or ERR_PTR(-ENOENT) on none, ERR_PTR(-EBUSY) on mismatch. + * + * Return the attach handle to the caller. The life cycle of an iommu attach + * handle is from the time when the domain is attached to the time when the + * domain is detached. Callers are required to synchronize the call of + * iommu_attach_handle_get() with domain attachment and detachment. 
The attach + * handle can only be used during its life cycle. + */ +struct iommu_attach_handle * +iommu_attach_handle_get(struct iommu_group *group, ioasid_t pasid, unsigned int type) +{ + struct iommu_attach_handle *handle; + + xa_lock(&group->pasid_array); + handle = xa_load(&group->pasid_array, pasid); + if (!handle) + handle = ERR_PTR(-ENOENT); + else if (type && handle->domain->type != type) + handle = ERR_PTR(-EBUSY); + xa_unlock(&group->pasid_array); + + return handle; +} +EXPORT_SYMBOL_NS_GPL(iommu_attach_handle_get, IOMMUFD_INTERNAL); + /* * Put the group's domain back to the appropriate core-owned domain - either the * standard kernel-mode DMA configuration or an all-DMA-blocked domain. @@ -2187,12 +2270,25 @@ static int __iommu_attach_group(struct iommu_domain *domain, */ int iommu_attach_group(struct iommu_domain *domain, struct iommu_group *group) { + struct iommu_attach_handle *handle; int ret; mutex_lock(&group->mutex); + handle = iommu_attach_handle_set(domain, group, IOMMU_NO_PASID); + if (IS_ERR(handle)) { + ret = PTR_ERR(handle); + goto out_unlock; + } ret = __iommu_attach_group(domain, group); + if (ret) + goto out_remove_handle; mutex_unlock(&group->mutex); + return 0; +out_remove_handle: + iommu_attach_handle_remove(group, IOMMU_NO_PASID); +out_unlock: + mutex_unlock(&group->mutex); return ret; } EXPORT_SYMBOL_GPL(iommu_attach_group); @@ -2211,13 +2307,33 @@ EXPORT_SYMBOL_GPL(iommu_attach_group); int iommu_group_replace_domain(struct iommu_group *group, struct iommu_domain *new_domain) { + struct iommu_domain *old_domain = group->domain; + struct iommu_attach_handle *handle; int ret; if (!new_domain) return -EINVAL; + if (new_domain == old_domain) + return 0; + mutex_lock(&group->mutex); ret = __iommu_group_set_domain(group, new_domain); + if (ret) + goto out_unlock; + + handle = iommu_attach_handle_replace(new_domain, group, IOMMU_NO_PASID); + if (IS_ERR(handle)) { + ret = PTR_ERR(handle); + goto out_old_domain; + } + 
mutex_unlock(&group->mutex); + + return 0; + +out_old_domain: + __iommu_group_set_domain(group, old_domain); +out_unlock: mutex_unlock(&group->mutex); return ret; } @@ -2352,6 +2468,7 @@ void iommu_detach_group(struct iommu_domain *domain, struct iommu_group *group) { mutex_lock(&group->mutex); __iommu_group_set_core_domain(group); + iommu_attach_handle_remove(group, IOMMU_NO_PASID); mutex_unlock(&group->mutex); } EXPORT_SYMBOL_GPL(iommu_detach_group); @@ -3354,8 +3471,8 @@ int iommu_attach_device_pasid(struct iommu_domain *domain, { /* Caller must be a probed driver on dev */ struct iommu_group *group = dev->iommu_group; + struct iommu_attach_handle *handle; struct group_device *device; - void *curr; int ret; if (!domain->ops->set_dev_pasid) @@ -3376,17 +3493,22 @@ int iommu_attach_device_pasid(struct iommu_domain *domain, } } - curr = xa_cmpxchg(&group->pasid_array, pasid, NULL, domain, GFP_KERNEL); - if (curr) { - ret = xa_err(curr) ? : -EBUSY; + handle = iommu_attach_handle_set(domain, group, pasid); + if (IS_ERR(handle)) { + ret = PTR_ERR(handle); goto out_unlock; } ret = __iommu_set_group_pasid(domain, group, pasid); - if (ret) { - __iommu_remove_group_pasid(group, pasid); - xa_erase(&group->pasid_array, pasid); - } + if (ret) + goto out_put_handle; + mutex_unlock(&group->mutex); + + return 0; + +out_put_handle: + __iommu_remove_group_pasid(group, pasid); + iommu_attach_handle_remove(group, pasid); out_unlock: mutex_unlock(&group->mutex); return ret; @@ -3410,7 +3532,7 @@ void iommu_detach_device_pasid(struct iommu_domain *domain, struct device *dev, mutex_lock(&group->mutex); __iommu_remove_group_pasid(group, pasid); - WARN_ON(xa_erase(&group->pasid_array, pasid) != domain); + iommu_attach_handle_remove(group, pasid); mutex_unlock(&group->mutex); } EXPORT_SYMBOL_GPL(iommu_detach_device_pasid); @@ -3433,20 +3555,14 @@ struct iommu_domain *iommu_get_domain_for_dev_pasid(struct device *dev, ioasid_t pasid, unsigned int type) { - /* Caller must be a probed 
driver on dev */ struct iommu_group *group = dev->iommu_group; - struct iommu_domain *domain; + struct iommu_attach_handle *handle; - if (!group) - return NULL; + handle = iommu_attach_handle_get(group, pasid, type); + if (IS_ERR(handle)) + return (void *)handle; - xa_lock(&group->pasid_array); - domain = xa_load(&group->pasid_array, pasid); - if (type && domain && domain->type != type) - domain = ERR_PTR(-EBUSY); - xa_unlock(&group->pasid_array); - - return domain; + return handle->domain; } EXPORT_SYMBOL_GPL(iommu_get_domain_for_dev_pasid); -- 2.34.1