>-----Original Message----- >From: Liu, Yi L <yi.l.liu@xxxxxxxxx> >Subject: [PATCH v2 02/12] iommu: Introduce a replace API for device pasid > >Provide a high-level API to allow replacements of one domain with >another for specific pasid of a device. This is similar to >iommu_group_replace_domain() and it is expected to be used only by >IOMMUFD. > >Signed-off-by: Lu Baolu <baolu.lu@xxxxxxxxxxxxxxx> >Signed-off-by: Yi Liu <yi.l.liu@xxxxxxxxx> >--- > drivers/iommu/iommu-priv.h | 3 ++ > drivers/iommu/iommu.c | 92 >+++++++++++++++++++++++++++++++++++--- > 2 files changed, 89 insertions(+), 6 deletions(-) > >diff --git a/drivers/iommu/iommu-priv.h b/drivers/iommu/iommu-priv.h >index 5f731d994803..0949c02cee93 100644 >--- a/drivers/iommu/iommu-priv.h >+++ b/drivers/iommu/iommu-priv.h >@@ -20,6 +20,9 @@ static inline const struct iommu_ops >*dev_iommu_ops(struct device *dev) > int iommu_group_replace_domain(struct iommu_group *group, > struct iommu_domain *new_domain); > >+int iommu_replace_device_pasid(struct iommu_domain *domain, >+ struct device *dev, ioasid_t pasid); >+ > int iommu_device_register_bus(struct iommu_device *iommu, > const struct iommu_ops *ops, > const struct bus_type *bus, >diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c >index 701b02a118db..343683e646e0 100644 >--- a/drivers/iommu/iommu.c >+++ b/drivers/iommu/iommu.c >@@ -3315,14 +3315,15 @@ bool >iommu_group_dma_owner_claimed(struct iommu_group *group) > EXPORT_SYMBOL_GPL(iommu_group_dma_owner_claimed); > > static int __iommu_set_group_pasid(struct iommu_domain *domain, >- struct iommu_group *group, ioasid_t pasid) >+ struct iommu_group *group, ioasid_t pasid, >+ struct iommu_domain *old) > { > struct group_device *device, *last_gdev; > int ret; > > for_each_group_device(group, device) { > ret = domain->ops->set_dev_pasid(domain, device->dev, >- pasid, NULL); >+ pasid, old); > if (ret) > goto err_revert; > } >@@ -3332,11 +3333,34 @@ static int __iommu_set_group_pasid(struct >iommu_domain 
*domain, > err_revert: > last_gdev = device; > for_each_group_device(group, device) { >- const struct iommu_ops *ops = dev_iommu_ops(device- >>dev); >+ /* >+ * If no old domain, just undo all the devices/pasid that >+ * have attached to the new domain. >+ */ >+ if (!old) { >+ const struct iommu_ops *ops = >+ dev_iommu_ops(device->dev); >+ >+ if (device == last_gdev) Maybe this check can be moved to the beginning of the for loop, >+ break; >+ ops = dev_iommu_ops(device->dev); >+ ops->remove_dev_pasid(device->dev, pasid, domain); >+ continue; >+ } > >- if (device == last_gdev) >+ /* >+ * Rollback the devices/pasid that have attached to the new >+ * domain. And it is a driver bug to fail attaching with a >+ * previously good domain. >+ */ >+ if (device == last_gdev) { then this check can be removed. >+ WARN_ON(old->ops->set_dev_pasid(old, device- >>dev, >+ pasid, NULL)); Is this call necessary? last_gdev is the first device that failed to attach to the new domain, so it should still be attached to the old domain and need no rollback. Thanks Zhenzhong > break; >- ops->remove_dev_pasid(device->dev, pasid, domain); >+ } >+ >+ WARN_ON(old->ops->set_dev_pasid(old, device->dev, >+ pasid, domain)); > } > return ret; > } >@@ -3395,7 +3419,7 @@ int iommu_attach_device_pasid(struct >iommu_domain *domain, > goto out_unlock; > } > >- ret = __iommu_set_group_pasid(domain, group, pasid); >+ ret = __iommu_set_group_pasid(domain, group, pasid, NULL); > if (ret) > xa_erase(&group->pasid_array, pasid); > out_unlock: >@@ -3404,6 +3428,62 @@ int iommu_attach_device_pasid(struct >iommu_domain *domain, > } > EXPORT_SYMBOL_GPL(iommu_attach_device_pasid); > >+/** >+ * iommu_replace_device_pasid - replace the domain that a pasid is >attached to >+ * @domain: new IOMMU domain to replace with >+ * @dev: the physical device >+ * @pasid: pasid that will be attached to the new domain >+ * >+ * This API allows the pasid to switch domains. Return 0 on success, or an >+ * error. The pasid will roll back to use the old domain if failure.
The >+ * caller could call iommu_detach_device_pasid() before free the old >domain >+ * in order to avoid use-after-free case. >+ */ >+int iommu_replace_device_pasid(struct iommu_domain *domain, >+ struct device *dev, ioasid_t pasid) >+{ >+ /* Caller must be a probed driver on dev */ >+ struct iommu_group *group = dev->iommu_group; >+ void *curr; >+ int ret; >+ >+ if (!domain) >+ return -EINVAL; >+ >+ if (!domain->ops->set_dev_pasid) >+ return -EOPNOTSUPP; >+ >+ if (!group) >+ return -ENODEV; >+ >+ if (!dev_has_iommu(dev) || dev_iommu_ops(dev) != domain- >>owner) >+ return -EINVAL; >+ >+ mutex_lock(&group->mutex); >+ curr = xa_store(&group->pasid_array, pasid, domain, GFP_KERNEL); >+ if (!curr) { >+ xa_erase(&group->pasid_array, pasid); >+ ret = -EINVAL; >+ goto out_unlock; >+ } >+ >+ ret = xa_err(curr); >+ if (ret) >+ goto out_unlock; >+ >+ if (curr == domain) >+ goto out_unlock; >+ >+ ret = __iommu_set_group_pasid(domain, group, pasid, curr); >+ if (ret) >+ WARN_ON(xa_err(xa_store(&group->pasid_array, pasid, >+ curr, GFP_KERNEL))); >+out_unlock: >+ mutex_unlock(&group->mutex); >+ return ret; >+} >+EXPORT_SYMBOL_NS_GPL(iommu_replace_device_pasid, >IOMMUFD_INTERNAL); >+ > /* > * iommu_detach_device_pasid() - Detach the domain from pasid of device > * @domain: the iommu domain. >-- >2.34.1