When a domain is attached to a group that contains mediated devices,
the domain should be attached to the parent device of each mdev
instead. Add helpers for attaching a domain to, and detaching it from,
a group, whether the group holds a PCI physical device or mediated
devices derived from a PCI physical device.

If attaching fails part-way through an mdev group, detach the domain
from the parents that were already attached, because the callers treat
a failed attach as "nothing attached" and may free the domain.

Cc: Ashok Raj <ashok.raj@xxxxxxxxx>
Cc: Jacob Pan <jacob.jun.pan@xxxxxxxxxxxxxxx>
Cc: Kevin Tian <kevin.tian@xxxxxxxxx>
Cc: Liu Yi L <yi.l.liu@xxxxxxxxx>
Signed-off-by: Lu Baolu <baolu.lu@xxxxxxxxxxxxxxx>
---
Note: the blob hashes on the "index" line below were not regenerated
for this revision of the patch.

 drivers/vfio/vfio_iommu_type1.c | 142 +++++++++++++++++++++++++++++++--
 1 file changed, 135 insertions(+), 7 deletions(-)

diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index d9fd3188615d..89e2e6123223 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -91,6 +91,9 @@ struct vfio_dma {
 struct vfio_group {
 	struct iommu_group	*iommu_group;
 	struct list_head	next;
+	bool			attach_parent;	/* An mdev group with domain
+						 * attached to parent
+						 */
 };
 
 /*
@@ -1327,6 +1330,131 @@ static bool vfio_iommu_has_sw_msi(struct iommu_group *group, phys_addr_t *base)
 	return ret;
 }
 
+/*
+ * Set (or, if @domain is NULL, clear) the aux domain of mdev @dev.
+ * mdev_set_domain() is resolved with symbol_get() rather than called
+ * directly, presumably so this driver has no hard dependency on the
+ * mdev module. Returns its result, or -EINVAL if it is unavailable.
+ */
+static int vfio_mdev_set_aux_domain(struct device *dev,
+				    struct iommu_domain *domain)
+{
+	int (*fn)(struct device *dev, void *domain);
+	int ret;
+
+	fn = symbol_get(mdev_set_domain);
+	if (fn) {
+		ret = fn(dev, domain);
+		symbol_put(mdev_set_domain);
+
+		return ret;
+	}
+
+	return -EINVAL;
+}
+
+/*
+ * State for attaching one domain to the parent of every mdev in a
+ * group; @count is the number of devices attached so far, so that a
+ * partial failure can be unwound.
+ */
+struct vfio_aux_attach {
+	struct iommu_domain	*domain;
+	int			count;
+};
+
+/*
+ * iommu_group_for_each_dev() callback: set aux->domain as the aux
+ * domain of mdev @dev and attach it to the mdev's parent device. If the
+ * attach fails, the mdev's domain is cleared again.
+ */
+static int vfio_attach_aux_domain(struct device *dev, void *data)
+{
+	struct vfio_aux_attach *aux = data;
+	int ret;
+
+	ret = vfio_mdev_set_aux_domain(dev, aux->domain);
+	if (ret)
+		return ret;
+
+	ret = iommu_attach_device(aux->domain, dev->parent);
+	if (ret) {
+		vfio_mdev_set_aux_domain(dev, NULL);
+		return ret;
+	}
+
+	aux->count++;
+	return 0;
+}
+
+/*
+ * iommu_group_for_each_dev() callback: clear the aux domain of mdev
+ * @dev and detach the domain passed in @data from the mdev's parent.
+ */
+static int vfio_detach_aux_domain(struct device *dev, void *data)
+{
+	struct iommu_domain *domain = data;
+
+	vfio_mdev_set_aux_domain(dev, NULL);
+	iommu_detach_device(domain, dev->parent);
+
+	return 0;
+}
+
+/*
+ * iommu_group_for_each_dev() callback used after a failed
+ * vfio_attach_aux_domain() walk: detach the first aux->count devices,
+ * then return non-zero to stop the walk. NOTE: assumes both walks see
+ * the group's devices in the same order, i.e. the group is unchanged.
+ */
+static int vfio_unwind_aux_domain(struct device *dev, void *data)
+{
+	struct vfio_aux_attach *aux = data;
+
+	if (aux->count <= 0)
+		return 1;
+
+	vfio_detach_aux_domain(dev, aux->domain);
+	aux->count--;
+
+	return 0;
+}
+
+/*
+ * Attach @domain to @group. For an mdev group (attach_parent set) the
+ * domain is attached to each mdev's parent instead; if that fails
+ * part-way, the parents already attached are detached again, because
+ * callers treat an error as "nothing attached" and may free the domain.
+ */
+static int vfio_iommu_attach_group(struct vfio_domain *domain,
+				   struct vfio_group *group)
+{
+	struct vfio_aux_attach aux = { .domain = domain->domain };
+	int ret;
+
+	if (!group->attach_parent)
+		return iommu_attach_group(domain->domain, group->iommu_group);
+
+	ret = iommu_group_for_each_dev(group->iommu_group, &aux,
+				       vfio_attach_aux_domain);
+	if (ret)
+		iommu_group_for_each_dev(group->iommu_group, &aux,
+					 vfio_unwind_aux_domain);
+
+	return ret;
+}
+
+/* Detach @domain from @group; the inverse of vfio_iommu_attach_group(). */
+static void vfio_iommu_detach_group(struct vfio_domain *domain,
+				    struct vfio_group *group)
+{
+	if (group->attach_parent)
+		iommu_group_for_each_dev(group->iommu_group, domain->domain,
+					 vfio_detach_aux_domain);
+	else
+		iommu_detach_group(domain->domain, group->iommu_group);
+}
+
 static int vfio_iommu_type1_attach_group(void *iommu_data,
 					 struct iommu_group *iommu_group)
 {
@@ -1402,7 +1530,7 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
 			goto out_domain;
 	}
 
-	ret = iommu_attach_group(domain->domain, iommu_group);
+	ret = vfio_iommu_attach_group(domain, group);
 	if (ret)
 		goto out_domain;
 
@@ -1434,8 +1562,8 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
 	list_for_each_entry(d, &iommu->domain_list, next) {
 		if (d->domain->ops == domain->domain->ops &&
 		    d->prot == domain->prot) {
-			iommu_detach_group(domain->domain, iommu_group);
-			if (!iommu_attach_group(d->domain, iommu_group)) {
+			vfio_iommu_detach_group(domain, group);
+			if (!vfio_iommu_attach_group(d, group)) {
 				list_add(&group->next, &d->group_list);
 				iommu_domain_free(domain->domain);
 				kfree(domain);
@@ -1443,7 +1571,7 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
 				return 0;
 			}
 
-			ret = iommu_attach_group(domain->domain, iommu_group);
+			ret = vfio_iommu_attach_group(domain, group);
 			if (ret)
 				goto out_domain;
 		}
@@ -1469,7 +1597,7 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
 	return 0;
 
 out_detach:
-	iommu_detach_group(domain->domain, iommu_group);
+	vfio_iommu_detach_group(domain, group);
 out_domain:
 	iommu_domain_free(domain->domain);
 out_free:
@@ -1560,7 +1688,7 @@ static void vfio_iommu_type1_detach_group(void *iommu_data,
 		if (!group)
 			continue;
 
-		iommu_detach_group(domain->domain, iommu_group);
+		vfio_iommu_detach_group(domain, group);
 		list_del(&group->next);
 		kfree(group);
 		/*
@@ -1625,7 +1753,7 @@ static void vfio_release_domain(struct vfio_domain *domain, bool external)
 	list_for_each_entry_safe(group, group_tmp,
 				 &domain->group_list, next) {
 		if (!external)
-			iommu_detach_group(domain->domain, group->iommu_group);
+			vfio_iommu_detach_group(domain, group);
 		list_del(&group->next);
 		kfree(group);
 	}
-- 
2.17.1