On Thu, 22 Sep 2022 16:36:14 -0300 Jason Gunthorpe <jgg@xxxxxxxxxx> wrote: > On Thu, Sep 22, 2022 at 01:10:50PM -0600, Alex Williamson wrote: > > The semantics of vfio_get_group() are also rather strange, 'return a > > vfio_group for this iommu_group, but make sure it doesn't include this > > device' :-\ Thanks, > > I think of it as "return a group and also do sanity checks that the > returned group has not been corrupted" > > I don't like the name of this function but couldn't figure a better > one. It is something like "find or create a group for a device which > we know doesn't already have a group" Well, we don't really need to have this behavior; we could choose to implement the first two patches with the caller holding the group_lock. Only one of the callers needs the duplicate test; no-iommu creates its own iommu_group and therefore cannot have an existing device. I think patches 1 & 2 would look like below*, with patch 3 simply moving the change from vfio_group_get() to refcount_inc() into the equivalent place in vfio_group_find_or_alloc().
Thanks, Alex *only compile tested --- drivers/vfio/vfio_main.c | 33 ++++++++++----------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/drivers/vfio/vfio_main.c b/drivers/vfio/vfio_main.c index f9d10dbcf3e6..aa33944cb759 100644 --- a/drivers/vfio/vfio_main.c +++ b/drivers/vfio/vfio_main.c @@ -310,10 +310,12 @@ static void vfio_container_put(struct vfio_container *container) * Group objects - create, release, get, put, search */ static struct vfio_group * -__vfio_group_get_from_iommu(struct iommu_group *iommu_group) +vfio_group_get_from_iommu(struct iommu_group *iommu_group) { struct vfio_group *group; + lockdep_assert_held(&vfio.group_lock); + list_for_each_entry(group, &vfio.group_list, vfio_next) { if (group->iommu_group == iommu_group) { vfio_group_get(group); @@ -323,17 +325,6 @@ __vfio_group_get_from_iommu(struct iommu_group *iommu_group) return NULL; } -static struct vfio_group * -vfio_group_get_from_iommu(struct iommu_group *iommu_group) -{ - struct vfio_group *group; - - mutex_lock(&vfio.group_lock); - group = __vfio_group_get_from_iommu(iommu_group); - mutex_unlock(&vfio.group_lock); - return group; -} - static void vfio_group_release(struct device *dev) { struct vfio_group *group = container_of(dev, struct vfio_group, dev); @@ -387,6 +378,8 @@ static struct vfio_group *vfio_create_group(struct iommu_group *iommu_group, struct vfio_group *ret; int err; + lockdep_assert_held(&vfio.group_lock); + group = vfio_group_alloc(iommu_group, type); if (IS_ERR(group)) return group; @@ -399,26 +392,16 @@ static struct vfio_group *vfio_create_group(struct iommu_group *iommu_group, goto err_put; } - mutex_lock(&vfio.group_lock); - - /* Did we race creating this group? 
*/ - ret = __vfio_group_get_from_iommu(iommu_group); - if (ret) - goto err_unlock; - err = cdev_device_add(&group->cdev, &group->dev); if (err) { ret = ERR_PTR(err); - goto err_unlock; + goto err_put; } list_add(&group->vfio_next, &vfio.group_list); - mutex_unlock(&vfio.group_lock); return group; -err_unlock: - mutex_unlock(&vfio.group_lock); err_put: put_device(&group->dev); return ret; @@ -609,7 +592,9 @@ static struct vfio_group *vfio_noiommu_group_alloc(struct device *dev, if (ret) goto out_put_group; + mutex_lock(&vfio.group_lock); group = vfio_create_group(iommu_group, type); + mutex_unlock(&vfio.group_lock); if (IS_ERR(group)) { ret = PTR_ERR(group); goto out_remove_device; @@ -659,9 +644,11 @@ static struct vfio_group *vfio_group_find_or_alloc(struct device *dev) return ERR_PTR(-EINVAL); } + mutex_lock(&vfio.group_lock); group = vfio_group_get_from_iommu(iommu_group); if (!group) group = vfio_create_group(iommu_group, VFIO_IOMMU); + mutex_unlock(&vfio.group_lock); /* The vfio_group holds a reference to the iommu_group */ iommu_group_put(iommu_group); --- drivers/vfio/vfio_main.c | 58 ++++++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 33 deletions(-) diff --git a/drivers/vfio/vfio_main.c b/drivers/vfio/vfio_main.c index aa33944cb759..4692493d386a 100644 --- a/drivers/vfio/vfio_main.c +++ b/drivers/vfio/vfio_main.c @@ -310,17 +310,15 @@ static void vfio_container_put(struct vfio_container *container) * Group objects - create, release, get, put, search */ static struct vfio_group * -vfio_group_get_from_iommu(struct iommu_group *iommu_group) +vfio_group_find_from_iommu(struct iommu_group *iommu_group) { struct vfio_group *group; lockdep_assert_held(&vfio.group_lock); list_for_each_entry(group, &vfio.group_list, vfio_next) { - if (group->iommu_group == iommu_group) { - vfio_group_get(group); + if (group->iommu_group == iommu_group) return group; - } } return NULL; } @@ -449,23 +447,6 @@ static bool 
vfio_device_try_get_registration(struct vfio_device *device) return refcount_inc_not_zero(&device->refcount); } -static struct vfio_device *vfio_group_get_device(struct vfio_group *group, - struct device *dev) -{ - struct vfio_device *device; - - mutex_lock(&group->device_lock); - list_for_each_entry(device, &group->device_list, group_next) { - if (device->dev == dev && - vfio_device_try_get_registration(device)) { - mutex_unlock(&group->device_lock); - return device; - } - } - mutex_unlock(&group->device_lock); - return NULL; -} - /* * VFIO driver API */ @@ -609,6 +590,21 @@ static struct vfio_group *vfio_noiommu_group_alloc(struct device *dev, return ERR_PTR(ret); } +static bool vfio_group_has_device(struct vfio_group *group, struct device *dev) +{ + struct vfio_device *device; + + mutex_lock(&group->device_lock); + list_for_each_entry(device, &group->device_list, group_next) { + if (device->dev == dev) { + mutex_unlock(&group->device_lock); + return true; + } + } + mutex_unlock(&group->device_lock); + return false; +} + static struct vfio_group *vfio_group_find_or_alloc(struct device *dev) { struct iommu_group *iommu_group; @@ -645,9 +641,15 @@ static struct vfio_group *vfio_group_find_or_alloc(struct device *dev) } mutex_lock(&vfio.group_lock); - group = vfio_group_get_from_iommu(iommu_group); - if (!group) + group = vfio_group_find_from_iommu(iommu_group); + if (group) { + if (WARN_ON(vfio_group_has_device(group, dev))) + group = ERR_PTR(-EINVAL); + else + vfio_group_get(group); + } else { group = vfio_create_group(iommu_group, VFIO_IOMMU); + } mutex_unlock(&vfio.group_lock); /* The vfio_group holds a reference to the iommu_group */ @@ -658,7 +660,6 @@ static struct vfio_group *vfio_group_find_or_alloc(struct device *dev) static int __vfio_register_dev(struct vfio_device *device, struct vfio_group *group) { - struct vfio_device *existing_device; int ret; if (IS_ERR(group)) @@ -671,15 +672,6 @@ static int __vfio_register_dev(struct vfio_device *device, if 
(!device->dev_set) vfio_assign_device_set(device, device); - existing_device = vfio_group_get_device(group, device->dev); - if (existing_device) { - dev_WARN(device->dev, "Device already exists on group %d\n", - iommu_group_id(group->iommu_group)); - vfio_device_put_registration(existing_device); - ret = -EBUSY; - goto err_out; - } - /* Our reference on group is moved to the device */ device->group = group;