__vfio_register_dev() has a bit of code to sanity check that an (existing) group is not corrupted by containing two copies of the same struct device. This should be impossible. It then has some complicated error unwind to undo the group creation. Instead, check that the existing group is sane at the same time we locate it. If a bug is found, there is nothing to unwind; simply fail the allocation. Note that the error returned for this impossible case changes from -EBUSY to -EINVAL. Signed-off-by: Jason Gunthorpe <jgg@xxxxxxxxxx> --- drivers/vfio/vfio_main.c | 79 ++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 43 deletions(-) diff --git a/drivers/vfio/vfio_main.c b/drivers/vfio/vfio_main.c index 4ab13808b536e1..ba8b6bed12c7e7 100644 --- a/drivers/vfio/vfio_main.c +++ b/drivers/vfio/vfio_main.c @@ -306,15 +306,15 @@ static void vfio_container_put(struct vfio_container *container) * Group objects - create, release, get, put, search */ static struct vfio_group * -__vfio_group_get_from_iommu(struct iommu_group *iommu_group) +vfio_group_find_from_iommu(struct iommu_group *iommu_group) { struct vfio_group *group; + lockdep_assert_held(&vfio.group_lock); + list_for_each_entry(group, &vfio.group_list, vfio_next) { - if (group->iommu_group == iommu_group) { - vfio_group_get(group); + if (group->iommu_group == iommu_group) return group; - } } return NULL; } @@ -365,11 +365,27 @@ static struct vfio_group *vfio_group_alloc(struct iommu_group *iommu_group, return group; } +static bool vfio_group_has_device(struct vfio_group *group, struct device *dev) +{ + struct vfio_device *device; + + mutex_lock(&group->device_lock); + list_for_each_entry(device, &group->device_list, group_next) { + if (device->dev == dev) { + mutex_unlock(&group->device_lock); + return true; + } + } + mutex_unlock(&group->device_lock); + return false; +} + /* * Return a struct vfio_group * for the given iommu_group. If no vfio_group * already exists then create a new one.
*/ -static struct vfio_group *vfio_get_group(struct iommu_group *iommu_group, +static struct vfio_group *vfio_get_group(struct device *dev, + struct iommu_group *iommu_group, enum vfio_group_type type) { struct vfio_group *group; @@ -378,13 +394,20 @@ static struct vfio_group *vfio_get_group(struct iommu_group *iommu_group, mutex_lock(&vfio.group_lock); - ret = __vfio_group_get_from_iommu(iommu_group); - if (ret) - goto err_unlock; + ret = vfio_group_find_from_iommu(iommu_group); + if (ret) { + if (WARN_ON(vfio_group_has_device(ret, dev))) { + ret = ERR_PTR(-EINVAL); + goto out_unlock; + } + /* Found an existing group */ + vfio_group_get(ret); + goto out_unlock; + } group = ret = vfio_group_alloc(iommu_group, type); if (IS_ERR(ret)) - goto err_unlock; + goto out_unlock; err = dev_set_name(&group->dev, "%s%d", group->type == VFIO_NO_IOMMU ? "noiommu-" : "", @@ -397,7 +420,7 @@ static struct vfio_group *vfio_get_group(struct iommu_group *iommu_group, err = cdev_device_add(&group->cdev, &group->dev); if (err) { ret = ERR_PTR(err); - goto err_unlock; + goto out_unlock; } list_add(&group->vfio_next, &vfio.group_list); @@ -407,7 +430,7 @@ static struct vfio_group *vfio_get_group(struct iommu_group *iommu_group, err_put: put_device(&group->dev); -err_unlock: +out_unlock: mutex_unlock(&vfio.group_lock); return ret; } @@ -454,22 +477,6 @@ static bool vfio_device_try_get(struct vfio_device *device) return refcount_inc_not_zero(&device->refcount); } -static struct vfio_device *vfio_group_get_device(struct vfio_group *group, - struct device *dev) -{ - struct vfio_device *device; - - mutex_lock(&group->device_lock); - list_for_each_entry(device, &group->device_list, group_next) { - if (device->dev == dev && vfio_device_try_get(device)) { - mutex_unlock(&group->device_lock); - return device; - } - } - mutex_unlock(&group->device_lock); - return NULL; -} - /* * VFIO driver API */ @@ -506,7 +513,7 @@ static struct vfio_group *vfio_noiommu_group_alloc(struct device *dev, if (ret) 
goto out_put_group; - group = vfio_get_group(iommu_group, type); + group = vfio_get_group(dev, iommu_group, type); if (IS_ERR(group)) { ret = PTR_ERR(group); goto out_remove_device; @@ -556,7 +563,7 @@ static struct vfio_group *vfio_group_find_or_alloc(struct device *dev) return ERR_PTR(-EINVAL); } - group = vfio_get_group(iommu_group, VFIO_IOMMU); + group = vfio_get_group(dev, iommu_group, VFIO_IOMMU); /* The vfio_group holds a reference to the iommu_group */ iommu_group_put(iommu_group); @@ -566,8 +573,6 @@ static struct vfio_group *vfio_group_find_or_alloc(struct device *dev) static int __vfio_register_dev(struct vfio_device *device, struct vfio_group *group) { - struct vfio_device *existing_device; - if (IS_ERR(group)) return PTR_ERR(group); @@ -578,18 +583,6 @@ static int __vfio_register_dev(struct vfio_device *device, if (!device->dev_set) vfio_assign_device_set(device, device); - existing_device = vfio_group_get_device(group, device->dev); - if (existing_device) { - dev_WARN(device->dev, "Device already exists on group %d\n", - iommu_group_id(group->iommu_group)); - vfio_device_put(existing_device); - if (group->type == VFIO_NO_IOMMU || - group->type == VFIO_EMULATED_IOMMU) - iommu_group_remove_device(device->dev); - vfio_group_put(group); - return -EBUSY; - } - /* Our reference on group is moved to the device */ device->group = group; -- 2.37.3