This series depends on the "deliver page faults to user space" patch set:
https://lore.kernel.org/linux-iommu/20230928042734.16134-1-baolu.lu@xxxxxxxxxxxxxxx/
Signed-off-by: Huang Jiaqing <jiaqing.huang@xxxxxxxxx>
---
drivers/iommu/io-pgfault.c | 104 ++++++++++++++++++++++++++++++++++++-
include/linux/iommu.h | 16 ++++++
2 files changed, 118 insertions(+), 2 deletions(-)
diff --git a/drivers/iommu/io-pgfault.c b/drivers/iommu/io-pgfault.c
index 1dbacc4fdf72..68e85dc6b1b6 100644
--- a/drivers/iommu/io-pgfault.c
+++ b/drivers/iommu/io-pgfault.c
@@ -7,6 +7,7 @@
#include <linux/iommu.h>
#include <linux/list.h>
+#include <linux/pci.h>
#include <linux/sched/mm.h>
#include <linux/slab.h>
#include <linux/workqueue.h>
@@ -392,6 +393,55 @@ int iopf_queue_discard_partial(struct iopf_queue *queue)
}
EXPORT_SYMBOL_GPL(iopf_queue_discard_partial);
+/*
+ * Link the fault parameter of @pdev into @queue's PCI lookup tree. Nodes are
+ * ordered by RB_NODE_CMP() applied to (bus number, devfn).
+ *
+ * The fault parameter obtained through iopf_get_dev_fault_param() is kept
+ * while the node is linked; it is put again here only on the -EEXIST path.
+ *
+ * Return: 0 on success, -ENODEV if iopf_get_dev_fault_param() returns NULL,
+ * -EEXIST if an entry comparing equal to @pdev is already in the tree.
+ */
+static int iopf_queue_pci_rbtree_insert(struct iopf_queue *queue, struct pci_dev *pdev)
+{
+	struct iommu_fault_param *fault_param = iopf_get_dev_fault_param(&pdev->dev);
+	struct rb_node **link, *parent = NULL;
+	int ret = 0;
+
+	if (!fault_param)
+		return -ENODEV;
+
+	down_write(&queue->pci_dev_sem);
+
+	/* Descend to the empty slot where the new node belongs. */
+	link = &queue->pci_dev_rbtree.rb_node;
+	while (*link) {
+		struct iommu_fault_param *cur;
+		struct pci_dev *cur_pdev;
+		s16 cmp;
+
+		parent = *link;
+		cur = container_of(parent, struct iommu_fault_param, node);
+		cur_pdev = to_pci_dev(cur->dev);
+		cmp = RB_NODE_CMP(pdev->bus->number, pdev->devfn,
+				  cur_pdev->bus->number, cur_pdev->devfn);
+		if (!cmp) {
+			ret = -EEXIST;
+			goto out_unlock;
+		}
+
+		link = cmp < 0 ? &parent->rb_left : &parent->rb_right;
+	}
+
+	rb_link_node(&fault_param->node, parent, link);
+	rb_insert_color(&fault_param->node, &queue->pci_dev_rbtree);
+
+out_unlock:
+	up_write(&queue->pci_dev_sem);
+	/* Nothing was linked on failure, so give the fault parameter back. */
+	if (ret)
+		iopf_put_dev_fault_param(fault_param);
+	return ret;
+}
+
+/*
+ * Unlink @iopf_param from @queue's PCI lookup tree and put the fault
+ * parameter that iopf_queue_pci_rbtree_insert() left held.
+ *
+ * The caller must have linked @iopf_param beforehand through a successful
+ * iopf_queue_pci_rbtree_insert() call.
+ */
+static void iopf_queue_pci_rbtree_remove(struct iopf_queue *queue, struct iommu_fault_param *iopf_param)
+{
+	struct rw_semaphore *sem = &queue->pci_dev_sem;
+	struct rb_root *root = &queue->pci_dev_rbtree;
+
+	down_write(sem);
+	rb_erase(&iopf_param->node, root);
+	up_write(sem);
+
+	/* Counterpart of the get done in iopf_queue_pci_rbtree_insert(). */
+	iopf_put_dev_fault_param(iopf_param);
+}
+
/**
* iopf_queue_add_device - Add producer to the fault queue
* @queue: IOPF queue
@@ -434,7 +484,13 @@ int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev)
mutex_unlock(&param->lock);
mutex_unlock(&queue->lock);
- return ret;
+ if (ret)
+ return ret;
+
+ if (dev_is_pci(dev))
+ return iopf_queue_pci_rbtree_insert(queue, to_pci_dev(dev));
+
+ return 0;
}
EXPORT_SYMBOL_GPL(iopf_queue_add_device);
@@ -486,7 +542,13 @@ int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
mutex_unlock(&param->lock);
mutex_unlock(&queue->lock);
- return ret;
+ if (ret)
+ return ret;
+
+ if (dev_is_pci(dev))
+ iopf_queue_pci_rbtree_remove(queue, fault_param);
+
+ return 0;
}
EXPORT_SYMBOL_GPL(iopf_queue_remove_device);
@@ -519,6 +581,9 @@ struct iopf_queue *iopf_queue_alloc(const char *name)
INIT_LIST_HEAD(&queue->devices);
mutex_init(&queue->lock);
+ queue->pci_dev_rbtree = RB_ROOT;
+ init_rwsem(&queue->pci_dev_sem);
+
return queue;
}
EXPORT_SYMBOL_GPL(iopf_queue_alloc);
@@ -544,3 +609,38 @@ void iopf_queue_free(struct iopf_queue *queue)
kfree(queue);
}
EXPORT_SYMBOL_GPL(iopf_queue_free);
+
+/**
+ * iopf_queue_find_pdev - Lookup pci device in iopf_queue rbtree
+ * @queue: IOPF queue
+ * @bus: bus number of pci device to lookup
+ * @devfn: devfn of pci device to lookup
+ *
+ * Walks @queue->pci_dev_rbtree while holding @queue->pci_dev_sem for read.
+ *
+ * NOTE(review): no reference on the returned device is taken here and the
+ * semaphore is dropped before returning, so the caller presumably has to
+ * guarantee the device stays valid while the pointer is used - confirm
+ * against the callers.
+ *
+ * Return: the pci device on success and NULL on not found.
+ */
+struct pci_dev *iopf_queue_find_pdev(struct iopf_queue *queue, u8 bus, u8 devfn)
+{
+	struct pci_dev *found = NULL;
+	struct rb_node *node;
+
+	down_read(&queue->pci_dev_sem);
+
+	node = queue->pci_dev_rbtree.rb_node;
+	while (node) {
+		/* Declared ahead of any statement to keep the kernel build warning-free. */
+		struct iommu_fault_param *data;
+		struct pci_dev *pdev;
+		s16 result;
+
+		data = container_of(node, struct iommu_fault_param, node);
+		pdev = to_pci_dev(data->dev);
+		result = RB_NODE_CMP(bus, devfn, pdev->bus->number, pdev->devfn);
+
+		if (result < 0) {
+			node = node->rb_left;
+		} else if (result > 0) {
+			node = node->rb_right;
+		} else {
+			/* Record the match explicitly instead of inferring it from node. */
+			found = pdev;
+			break;
+		}
+	}
+	up_read(&queue->pci_dev_sem);
+
+	return found;
+}
+EXPORT_SYMBOL_GPL(iopf_queue_find_pdev);
diff --git a/include/linux/iommu.h b/include/linux/iommu.h
index bcec7e91dfc4..b29bbb0d1843 100644
--- a/include/linux/iommu.h
+++ b/include/linux/iommu.h
@@ -136,11 +136,15 @@ struct iopf_group {
* @wq: the fault workqueue
* @devices: devices attached to this queue
* @lock: protects the device list
+ * @pci_dev_rbtree: rbtree of the PCI devices on this queue, keyed by (bus, devfn)
+ * @pci_dev_sem: protects @pci_dev_rbtree
*/
struct iopf_queue {
struct workqueue_struct *wq;
struct list_head devices;
struct mutex lock;
+ struct rb_root pci_dev_rbtree; /* links iommu_fault_param.node entries, ordered by RB_NODE_CMP() */
+ struct rw_semaphore pci_dev_sem; /* held for write on insert/remove, for read on lookup */
};
/* iommu fault flags */
@@ -483,6 +487,8 @@ struct iommu_device {
u32 max_pasids;
};
+/*
+ * Three-way comparison of two PCI device IDs built from (bus, devfn).
+ * Evaluates to -1, 0 or 1 (as s16) when the first ID is lower than, equal to
+ * or higher than the second one.
+ *
+ * The IDs are compared rather than subtracted: with 8-bit bus and devfn
+ * values the IDs span 16 bits, so their difference can reach +/-65535 and
+ * does not fit in an s16. Truncating it flips the sign for IDs more than
+ * 32767 apart and breaks the rbtree ordering.
+ *
+ * Each argument is evaluated twice; pass expressions without side effects.
+ */
+#define RB_NODE_CMP(bus1, devfn1, bus2, devfn2) \
+	((s16)((PCI_DEVID(bus1, devfn1) > PCI_DEVID(bus2, devfn2)) - \
+	       (PCI_DEVID(bus1, devfn1) < PCI_DEVID(bus2, devfn2))))
+
/**
* struct iommu_fault_param - per-device IOMMU fault data
* @lock: protect pending faults list
@@ -494,6 +500,7 @@ struct iommu_device {
* @partial: faults that are part of a Page Request Group for which the last
* request hasn't been submitted yet.
* @faults: holds the pending faults which needs response
+ * @node: rbtree node used to look up a PCI device by (bus, devfn)
*/
struct iommu_fault_param {
struct mutex lock;
@@ -505,6 +512,7 @@ struct iommu_fault_param {
struct list_head partial;
struct list_head faults;
+ struct rb_node node;
};
/**
@@ -1286,6 +1294,8 @@ int iopf_queue_discard_dev_pasid(struct device *dev, ioasid_t pasid);
struct iopf_queue *iopf_queue_alloc(const char *name);
void iopf_queue_free(struct iopf_queue *queue);
int iopf_queue_discard_partial(struct iopf_queue *queue);
+struct pci_dev *iopf_queue_find_pdev(struct iopf_queue *queue,
+ u8 bus, u8 devfn);
void iopf_free_group(struct iopf_group *group);
int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt);
int iommu_page_response(struct device *dev, struct iommu_page_response *msg);
@@ -1321,6 +1331,12 @@ static inline int iopf_queue_discard_partial(struct iopf_queue *queue)
return -ENODEV;
}
+/* Stub variant: performs no lookup and always returns NULL. */
+static inline struct pci_dev *
+iopf_queue_find_pdev(struct iopf_queue *queue, u8 bus, u8 devfn)
+{
+	return NULL;
+}
+
static inline void iopf_free_group(struct iopf_group *group)
{
}