On Thu, Nov 24, 2022 at 04:26:59AM -0800, Yi Liu wrote:

> +	kvm = vfio_device_get_group_kvm(device);
> +	if (!kvm) {
> +		ret = -EINVAL;
> +		goto err_unuse_iommu;
> +	}

A null kvm is not an error.

And looking at this along with the following cdev patch, I think this
organization is cleaner. Make it so the callers of vfio_device_open()
handle most of the group/device differences. We already have different
call chains.

Keep the iommufd code in vfio_main.c's functions.

diff --git a/drivers/vfio/group.c b/drivers/vfio/group.c
index 3a69839c65ff75..9b511055150cec 100644
--- a/drivers/vfio/group.c
+++ b/drivers/vfio/group.c
@@ -609,61 +609,32 @@ void vfio_device_group_unregister(struct vfio_device *device)
 
 int vfio_device_group_use_iommu(struct vfio_device *device)
 {
+	struct vfio_group *group = device->group;
 	int ret = 0;
 
-	/*
-	 * Here we pass the KVM pointer with the group under the lock. If the
-	 * device driver will use it, it must obtain a reference and release it
-	 * during close_device.
-	 */
-	mutex_lock(&device->group->group_lock);
-	if (!vfio_group_has_iommu(device->group)) {
-		ret = -EINVAL;
-		goto out_unlock;
-	}
+	lockdep_assert_held(&group->group_lock);
 
-	if (device->group->container) {
-		ret = vfio_group_use_container(device->group);
-		if (ret)
-			goto out_unlock;
-		vfio_device_container_register(device);
-	} else if (device->group->iommufd) {
-		ret = vfio_iommufd_bind(device, device->group->iommufd);
-	}
+	if (WARN_ON(!group->container))
+		return -EINVAL;
 
-out_unlock:
-	mutex_unlock(&device->group->group_lock);
-	return ret;
+	ret = vfio_group_use_container(group);
+	if (ret)
+		return ret;
+	vfio_device_container_register(device);
+	return 0;
 }
 
 void vfio_device_group_unuse_iommu(struct vfio_device *device)
-{
-	mutex_lock(&device->group->group_lock);
-	if (device->group->container) {
-		vfio_device_container_unregister(device);
-		vfio_group_unuse_container(device->group);
-	} else if (device->group->iommufd) {
-		vfio_iommufd_unbind(device);
-	}
-	mutex_unlock(&device->group->group_lock);
-}
-
-struct kvm *vfio_device_get_group_kvm(struct vfio_device *device)
 {
 	struct vfio_group *group = device->group;
 
-	mutex_lock(&group->group_lock);
-	if (!group->kvm) {
-		mutex_unlock(&group->group_lock);
-		return NULL;
-	}
-	/* group_lock is released in the vfio_group_put_kvm() */
-	return group->kvm;
-}
+	lockdep_assert_held(&group->group_lock);
 
-void vfio_device_put_group_kvm(struct vfio_device *device)
-{
-	mutex_unlock(&device->group->group_lock);
+	if (WARN_ON(!group->container))
+		return;
+
+	vfio_device_container_unregister(device);
+	vfio_group_unuse_container(group);
 }
 
 /**
diff --git a/drivers/vfio/vfio_main.c b/drivers/vfio/vfio_main.c
index 3108e92a5cb20b..f9386a34d584e2 100644
--- a/drivers/vfio/vfio_main.c
+++ b/drivers/vfio/vfio_main.c
@@ -364,9 +364,9 @@ static bool vfio_assert_device_open(struct vfio_device *device)
 	return !WARN_ON_ONCE(!READ_ONCE(device->open_count));
 }
 
-static int vfio_device_first_open(struct vfio_device *device)
+static int vfio_device_first_open(struct vfio_device *device,
+				  struct iommufd_ctx *iommufd, struct kvm *kvm)
 {
-	struct kvm *kvm;
 	int ret;
 
 	lockdep_assert_held(&device->dev_set->lock);
@@ -374,54 +374,56 @@ static int vfio_device_first_open(struct vfio_device *device)
 	if (!try_module_get(device->dev->driver->owner))
 		return -ENODEV;
 
-	ret = vfio_device_group_use_iommu(device);
+	if (iommufd)
+		ret = vfio_iommufd_bind(device, iommufd);
+	else
+		ret = vfio_device_group_use_iommu(device);
 	if (ret)
 		goto err_module_put;
 
-	kvm = vfio_device_get_group_kvm(device);
-	if (!kvm) {
-		ret = -EINVAL;
-		goto err_unuse_iommu;
-	}
-
 	device->kvm = kvm;
 	if (device->ops->open_device) {
 		ret = device->ops->open_device(device);
 		if (ret)
-			goto err_container;
+			goto err_unuse_iommu;
 	}
-	vfio_device_put_group_kvm(device);
 	return 0;
 
-err_container:
-	device->kvm = NULL;
-	vfio_device_put_group_kvm(device);
 err_unuse_iommu:
-	vfio_device_group_unuse_iommu(device);
+	device->kvm = NULL;
+	if (iommufd)
+		vfio_iommufd_unbind(device);
+	else
+		vfio_device_group_unuse_iommu(device);
 err_module_put:
 	module_put(device->dev->driver->owner);
 	return ret;
 }
 
-static void vfio_device_last_close(struct vfio_device *device)
+static void vfio_device_last_close(struct vfio_device *device,
+				   struct iommufd_ctx *iommufd)
 {
 	lockdep_assert_held(&device->dev_set->lock);
 
 	if (device->ops->close_device)
 		device->ops->close_device(device);
 	device->kvm = NULL;
-	vfio_device_group_unuse_iommu(device);
+	if (iommufd)
+		vfio_iommufd_unbind(device);
+	else
+		vfio_device_group_unuse_iommu(device);
 	module_put(device->dev->driver->owner);
 }
 
-static int vfio_device_open(struct vfio_device *device)
+static int vfio_device_open(struct vfio_device *device,
+			    struct iommufd_ctx *iommufd, struct kvm *kvm)
 {
 	int ret = 0;
 
 	mutex_lock(&device->dev_set->lock);
 	device->open_count++;
 	if (device->open_count == 1) {
-		ret = vfio_device_first_open(device);
+		ret = vfio_device_first_open(device, iommufd, kvm);
 		if (ret)
 			device->open_count--;
 	}
@@ -430,22 +432,53 @@ static int vfio_device_open(struct vfio_device *device)
 	return ret;
 }
 
-static void vfio_device_close(struct vfio_device *device)
+static void vfio_device_close(struct vfio_device *device,
+			      struct iommufd_ctx *iommufd)
 {
 	mutex_lock(&device->dev_set->lock);
 	vfio_assert_device_open(device);
 	if (device->open_count == 1)
-		vfio_device_last_close(device);
+		vfio_device_last_close(device, iommufd);
 	device->open_count--;
 	mutex_unlock(&device->dev_set->lock);
 }
 
+static int vfio_device_group_open(struct vfio_device *device)
+{
+	int ret;
+
+	mutex_lock(&device->group->group_lock);
+	if (!vfio_group_has_iommu(device->group)) {
+		ret = -EINVAL;
+		goto out_unlock;
+	}
+
+	/*
+	 * Here we pass the KVM pointer with the group under the lock. If the
+	 * device driver will use it, it must obtain a reference and release it
+	 * during close_device.
+	 */
+	ret = vfio_device_open(device, device->group->iommufd,
+			       device->group->kvm);
+
+out_unlock:
+	mutex_unlock(&device->group->group_lock);
+	return ret;
+}
+
+void vfio_device_group_close(struct vfio_device *device)
+{
+	mutex_lock(&device->group->group_lock);
+	vfio_device_close(device, device->group->iommufd);
+	mutex_unlock(&device->group->group_lock);
+}
+
 struct file *vfio_device_open_file(struct vfio_device *device)
 {
 	struct file *filep;
 	int ret;
 
-	ret = vfio_device_open(device);
+	ret = vfio_device_group_open(device);
 	if (ret)
 		goto err_out;
 
@@ -474,7 +507,7 @@ struct file *vfio_device_open_file(struct vfio_device *device)
 	return filep;
 
 err_close_device:
-	vfio_device_close(device);
+	vfio_device_group_close(device);
 err_out:
 	return ERR_PTR(ret);
 }
@@ -519,7 +552,7 @@ static int vfio_device_fops_release(struct inode *inode, struct file *filep)
 {
 	struct vfio_device *device = filep->private_data;
 
-	vfio_device_close(device);
+	vfio_device_group_close(device);
 
 	vfio_device_put_registration(device);
 