This avoids passing struct kvm * and struct iommufd_ctx * through multiple functions. vfio_device_open() becomes a locked helper: it no longer takes device->dev_set->lock itself, so callers must hold that lock. Signed-off-by: Yi Liu <yi.l.liu@xxxxxxxxx> --- drivers/vfio/group.c | 34 +++++++++++++++++++++++++--------- drivers/vfio/vfio.h | 10 +++++----- drivers/vfio/vfio_main.c | 40 ++++++++++++++++++++++++---------------- 3 files changed, 54 insertions(+), 30 deletions(-) diff --git a/drivers/vfio/group.c b/drivers/vfio/group.c index d83cf069d290..7200304663e5 100644 --- a/drivers/vfio/group.c +++ b/drivers/vfio/group.c @@ -154,33 +154,49 @@ static int vfio_group_ioctl_set_container(struct vfio_group *group, return ret; } -static int vfio_device_group_open(struct vfio_device *device) +static int vfio_device_group_open(struct vfio_device_file *df) { + struct vfio_device *device = df->device; int ret; mutex_lock(&device->group->group_lock); if (!vfio_group_has_iommu(device->group)) { ret = -EINVAL; - goto out_unlock; + goto err_unlock_group; } + mutex_lock(&device->dev_set->lock); /* * Here we pass the KVM pointer with the group under the lock. If the * device driver will use it, it must obtain a reference and release it * during close_device. 
*/ - ret = vfio_device_open(device, device->group->iommufd, - device->group->kvm); + df->kvm = device->group->kvm; + df->iommufd = device->group->iommufd; + + ret = vfio_device_open(df); + if (ret) + goto err_unlock_device; + mutex_unlock(&device->dev_set->lock); -out_unlock: + mutex_unlock(&device->group->group_lock); + return 0; + +err_unlock_device: + df->kvm = NULL; + df->iommufd = NULL; + mutex_unlock(&device->dev_set->lock); +err_unlock_group: mutex_unlock(&device->group->group_lock); return ret; } -void vfio_device_group_close(struct vfio_device *device) +void vfio_device_group_close(struct vfio_device_file *df) { + struct vfio_device *device = df->device; + mutex_lock(&device->group->group_lock); - vfio_device_close(device, device->group->iommufd); + vfio_device_close(df); mutex_unlock(&device->group->group_lock); } @@ -196,7 +212,7 @@ static struct file *vfio_device_open_file(struct vfio_device *device) goto err_out; } - ret = vfio_device_group_open(device); + ret = vfio_device_group_open(df); if (ret) goto err_free; @@ -228,7 +244,7 @@ static struct file *vfio_device_open_file(struct vfio_device *device) return filep; err_close_device: - vfio_device_group_close(device); + vfio_device_group_close(df); err_free: kfree(df); err_out: diff --git a/drivers/vfio/vfio.h b/drivers/vfio/vfio.h index 53af6e3ea214..3d8ba165146c 100644 --- a/drivers/vfio/vfio.h +++ b/drivers/vfio/vfio.h @@ -19,14 +19,14 @@ struct vfio_container; struct vfio_device_file { struct vfio_device *device; struct kvm *kvm; + struct iommufd_ctx *iommufd; }; void vfio_device_put_registration(struct vfio_device *device); bool vfio_device_try_get_registration(struct vfio_device *device); -int vfio_device_open(struct vfio_device *device, - struct iommufd_ctx *iommufd, struct kvm *kvm); -void vfio_device_close(struct vfio_device *device, - struct iommufd_ctx *iommufd); +int vfio_device_open(struct vfio_device_file *df); +void vfio_device_close(struct vfio_device_file *device); + struct 
vfio_device_file * vfio_allocate_device_file(struct vfio_device *device); @@ -90,7 +90,7 @@ void vfio_device_group_register(struct vfio_device *device); void vfio_device_group_unregister(struct vfio_device *device); int vfio_device_group_use_iommu(struct vfio_device *device); void vfio_device_group_unuse_iommu(struct vfio_device *device); -void vfio_device_group_close(struct vfio_device *device); +void vfio_device_group_close(struct vfio_device_file *df); struct vfio_group *vfio_group_from_file(struct file *file); bool vfio_group_enforced_coherent(struct vfio_group *group); void vfio_group_set_kvm(struct vfio_group *group, struct kvm *kvm); diff --git a/drivers/vfio/vfio_main.c b/drivers/vfio/vfio_main.c index dc08d5dd62cc..3df71bd9cd1e 100644 --- a/drivers/vfio/vfio_main.c +++ b/drivers/vfio/vfio_main.c @@ -358,9 +358,11 @@ vfio_allocate_device_file(struct vfio_device *device) return df; } -static int vfio_device_first_open(struct vfio_device *device, - struct iommufd_ctx *iommufd, struct kvm *kvm) +static int vfio_device_first_open(struct vfio_device_file *df) { + struct vfio_device *device = df->device; + struct iommufd_ctx *iommufd = df->iommufd; + struct kvm *kvm = df->kvm; int ret; lockdep_assert_held(&device->dev_set->lock); @@ -394,9 +396,11 @@ static int vfio_device_first_open(struct vfio_device *device, return ret; } -static void vfio_device_last_close(struct vfio_device *device, - struct iommufd_ctx *iommufd) +static void vfio_device_last_close(struct vfio_device_file *df) { + struct vfio_device *device = df->device; + struct iommufd_ctx *iommufd = df->iommufd; + lockdep_assert_held(&device->dev_set->lock); if (device->ops->close_device) @@ -409,30 +413,34 @@ static void vfio_device_last_close(struct vfio_device *device, module_put(device->dev->driver->owner); } -int vfio_device_open(struct vfio_device *device, - struct iommufd_ctx *iommufd, struct kvm *kvm) +int vfio_device_open(struct vfio_device_file *df) { - int ret = 0; + struct vfio_device *device 
= df->device; + + lockdep_assert_held(&device->dev_set->lock); - mutex_lock(&device->dev_set->lock); device->open_count++; if (device->open_count == 1) { - ret = vfio_device_first_open(device, iommufd, kvm); - if (ret) + int ret; + + ret = vfio_device_first_open(df); + if (ret) { device->open_count--; + return ret; + } } - mutex_unlock(&device->dev_set->lock); - return ret; + return 0; } -void vfio_device_close(struct vfio_device *device, - struct iommufd_ctx *iommufd) +void vfio_device_close(struct vfio_device_file *df) { + struct vfio_device *device = df->device; + mutex_lock(&device->dev_set->lock); vfio_assert_device_open(device); if (device->open_count == 1) - vfio_device_last_close(device, iommufd); + vfio_device_last_close(df); device->open_count--; mutex_unlock(&device->dev_set->lock); } @@ -478,7 +486,7 @@ static int vfio_device_fops_release(struct inode *inode, struct file *filep) struct vfio_device_file *df = filep->private_data; struct vfio_device *device = df->device; - vfio_device_group_close(device); + vfio_device_group_close(df); vfio_device_put_registration(device); -- 2.34.1