The fault handler will need to find an mm given its PASID. This is the reason we have an IDR for storing address spaces, so hook it up. Signed-off-by: Jean-Philippe Brucker <jean-philippe.brucker@xxxxxxx> --- drivers/iommu/iommu-sva.c | 26 ++++++++++++++++++++++++++ include/linux/iommu.h | 7 +++++++ 2 files changed, 33 insertions(+) diff --git a/drivers/iommu/iommu-sva.c b/drivers/iommu/iommu-sva.c index 5ff8967cb213..ee86f00ee1b9 100644 --- a/drivers/iommu/iommu-sva.c +++ b/drivers/iommu/iommu-sva.c @@ -636,6 +636,32 @@ void iommu_sva_unbind_device_all(struct device *dev) } EXPORT_SYMBOL_GPL(iommu_sva_unbind_device_all); +/** + * iommu_sva_find() - Find the mm associated with the given PASID + * @pasid: Process Address Space ID assigned to the mm + * + * Context: Takes iommu_sva_lock, so the caller must not already hold it. + * Return: the referenced mm (drop it with mmput()), or NULL if not found. + */ +struct mm_struct *iommu_sva_find(int pasid) +{ + struct io_mm *io_mm; + struct mm_struct *mm = NULL; + + spin_lock(&iommu_sva_lock); + io_mm = idr_find(&iommu_pasid_idr, pasid); + if (io_mm && io_mm_get_locked(io_mm)) { + if (mmget_not_zero(io_mm->mm)) + mm = io_mm->mm; + + io_mm_put_locked(io_mm); + } + spin_unlock(&iommu_sva_lock); + + return mm; +} +EXPORT_SYMBOL_GPL(iommu_sva_find); + /** * iommu_sva_init_device() - Initialize Shared Virtual Addressing for a device * @dev: the device diff --git a/include/linux/iommu.h b/include/linux/iommu.h index 429f3dc37a35..a457650b80de 100644 --- a/include/linux/iommu.h +++ b/include/linux/iommu.h @@ -987,6 +987,8 @@ extern int __iommu_sva_bind_device(struct device *dev, struct mm_struct *mm, void *drvdata); extern int __iommu_sva_unbind_device(struct device *dev, int pasid); extern void iommu_sva_unbind_device_all(struct device *dev); +extern struct mm_struct *iommu_sva_find(int pasid); + #else /* CONFIG_IOMMU_SVA */ static inline int iommu_sva_init_device(struct device *dev, unsigned long features, @@ -1016,6 +1018,11 @@ static inline 
int __iommu_sva_unbind_device(struct device *dev, int pasid) static inline void iommu_sva_unbind_device_all(struct device *dev) { } + +static inline struct mm_struct *iommu_sva_find(int pasid) +{ + return NULL; +} #endif /* CONFIG_IOMMU_SVA */ #endif /* __LINUX_IOMMU_H */ -- 2.18.0