diff --git a/drivers/vfio/group.c b/drivers/vfio/group.c
index d8ef098c1f74a6..3a69839c65ff75 100644
--- a/drivers/vfio/group.c
+++ b/drivers/vfio/group.c
@@ -476,8 +476,8 @@ void vfio_device_remove_group(struct vfio_device *device)
put_device(&group->dev);
}
-struct vfio_group *vfio_noiommu_group_alloc(struct device *dev,
- enum vfio_group_type type)
+static struct vfio_group *vfio_noiommu_group_alloc(struct device *dev,
+ enum vfio_group_type type)
{
struct iommu_group *iommu_group;
struct vfio_group *group;
@@ -526,7 +526,7 @@ static bool vfio_group_has_device(struct vfio_group *group, struct device *dev)
return false;
}
-struct vfio_group *vfio_group_find_or_alloc(struct device *dev)
+static struct vfio_group *vfio_group_find_or_alloc(struct device *dev)
{
struct iommu_group *iommu_group;
struct vfio_group *group;
@@ -577,6 +577,29 @@ struct vfio_group *vfio_group_find_or_alloc(struct device *dev)
return group;
}
+/*
+ * Give @device its vfio_group: VFIO_IOMMU goes through
+ * vfio_group_find_or_alloc(), every other @type through
+ * vfio_noiommu_group_alloc(). On success the returned group reference is
+ * stored in device->group and 0 is returned; on failure device->group is not
+ * modified and the helper's negative errno is returned.
+ */
+int vfio_device_set_group(struct vfio_device *device, enum vfio_group_type type)
+{
+ struct vfio_group *group;
+
+ if (type == VFIO_IOMMU)
+ group = vfio_group_find_or_alloc(device->dev);
+ else
+ group = vfio_noiommu_group_alloc(device->dev, type);
+ if (IS_ERR(group))
+ return PTR_ERR(group);
+
+ /* Our reference on group is moved to the device */
+ device->group = group;
+ return 0;
+}
+
void vfio_device_group_register(struct vfio_device *device)
{
mutex_lock(&device->group->device_lock);
@@ -632,8 +648,15 @@ void vfio_device_group_unuse_iommu(struct vfio_device *device)
mutex_unlock(&device->group->group_lock);
}
-struct kvm *vfio_group_get_kvm(struct vfio_group *group)
+/*
+ * Takes group->group_lock. If the group has no kvm the lock is dropped again
+ * before returning; otherwise group->kvm is returned with the lock still held
+ * and the caller must release it with vfio_device_put_group_kvm().
+ */
+struct kvm *vfio_device_get_group_kvm(struct vfio_device *device)
{
+ struct vfio_group *group = device->group;
+
mutex_lock(&group->group_lock);
if (!group->kvm) {
mutex_unlock(&group->group_lock);
@@ -643,24 +661,13 @@ struct kvm *vfio_group_get_kvm(struct vfio_group *group)
return group->kvm;
}
-void vfio_group_put_kvm(struct vfio_group *group)
-{
- mutex_unlock(&group->group_lock);
-}
-
-void vfio_device_group_finalize_open(struct vfio_device *device)
+/*
+ * Drops the group_lock left held by a vfio_device_get_group_kvm() call that
+ * found a kvm. The unlock is unconditional, so this must not be called when
+ * that call found no kvm (the lock was already dropped there).
+ */
+void vfio_device_put_group_kvm(struct vfio_device *device)
{
- mutex_lock(&device->group->group_lock);
- if (device->group->container)
- vfio_device_container_register(device);
- mutex_unlock(&device->group->group_lock);
-}
-
-void vfio_device_group_abort_open(struct vfio_device *device)
-{
- mutex_lock(&device->group->group_lock);
- if (device->group->container)
- vfio_device_container_unregister(device);
mutex_unlock(&device->group->group_lock);
}
@@ -779,9 +781,9 @@ bool vfio_file_has_dev(struct file *file, struct vfio_device *device)
}
EXPORT_SYMBOL_GPL(vfio_file_has_dev);
-bool vfio_group_has_container(struct vfio_group *group)
+bool vfio_device_has_container(struct vfio_device *device)
{
- return group->container;
+ return device->group->container;
}
static char *vfio_devnode(struct device *dev, umode_t *mode)
diff --git a/drivers/vfio/vfio.h b/drivers/vfio/vfio.h
index 670c9c5a55f1fc..e69bfcefee400e 100644
--- a/drivers/vfio/vfio.h
+++ b/drivers/vfio/vfio.h
@@ -70,19 +70,16 @@ struct vfio_group {
struct iommufd_ctx *iommufd;
};
+int vfio_device_set_group(struct vfio_device *device,
+ enum vfio_group_type type);
void vfio_device_remove_group(struct vfio_device *device);
-struct vfio_group *vfio_noiommu_group_alloc(struct device *dev,
- enum vfio_group_type type);
-struct vfio_group *vfio_group_find_or_alloc(struct device *dev);
void vfio_device_group_register(struct vfio_device *device);
void vfio_device_group_unregister(struct vfio_device *device);
int vfio_device_group_use_iommu(struct vfio_device *device);
void vfio_device_group_unuse_iommu(struct vfio_device *device);
-struct kvm *vfio_group_get_kvm(struct vfio_group *group);
-void vfio_group_put_kvm(struct vfio_group *group);
-void vfio_device_group_finalize_open(struct vfio_device *device);
-void vfio_device_group_abort_open(struct vfio_device *device);
-bool vfio_group_has_container(struct vfio_group *group);
+struct kvm *vfio_device_get_group_kvm(struct vfio_device *device);
+void vfio_device_put_group_kvm(struct vfio_device *device);
+bool vfio_device_has_container(struct vfio_device *device);
int __init vfio_group_init(void);
void vfio_group_cleanup(void);
@@ -142,12 +139,12 @@ int vfio_container_attach_group(struct vfio_container *container,
void vfio_group_detach_container(struct vfio_group *group);
void vfio_device_container_register(struct vfio_device *device);
void vfio_device_container_unregister(struct vfio_device *device);
-int vfio_group_container_pin_pages(struct vfio_group *group,
+int vfio_device_container_pin_pages(struct vfio_device *device,
dma_addr_t iova, int npage,
int prot, struct page **pages);
-void vfio_group_container_unpin_pages(struct vfio_group *group,
+void vfio_device_container_unpin_pages(struct vfio_device *device,
dma_addr_t iova, int npage);
-int vfio_group_container_dma_rw(struct vfio_group *group,
+int vfio_device_container_dma_rw(struct vfio_device *device,
dma_addr_t iova, void *data,
size_t len, bool write);
@@ -187,21 +184,21 @@ static inline void vfio_device_container_unregister(struct vfio_device *device)
{
}
-static inline int vfio_group_container_pin_pages(struct vfio_group *group,
- dma_addr_t iova, int npage,
- int prot, struct page **pages)
+static inline int vfio_device_container_pin_pages(struct vfio_device *device,
+ dma_addr_t iova, int npage,
+ int prot, struct page **pages)
{
return -EOPNOTSUPP;
}
-static inline void vfio_group_container_unpin_pages(struct vfio_group *group,
- dma_addr_t iova, int npage)
+static inline void vfio_device_container_unpin_pages(struct vfio_device *device,
+ dma_addr_t iova, int npage)
{
}
-static inline int vfio_group_container_dma_rw(struct vfio_group *group,
- dma_addr_t iova, void *data,
- size_t len, bool write)
+static inline int vfio_device_container_dma_rw(struct vfio_device *device,
+ dma_addr_t iova, void *data,
+ size_t len, bool write)
{
return -EOPNOTSUPP;
}
diff --git a/drivers/vfio/vfio_main.c b/drivers/vfio/vfio_main.c
index a7b966b4f3fc86..3108e92a5cb20b 100644
--- a/drivers/vfio/vfio_main.c
+++ b/drivers/vfio/vfio_main.c
@@ -260,17 +260,10 @@ void vfio_free_device(struct vfio_device *device)
EXPORT_SYMBOL_GPL(vfio_free_device);
static int __vfio_register_dev(struct vfio_device *device,
- struct vfio_group *group)
+ enum vfio_group_type type)
{
int ret;
- /*
- * In all cases group is the output of one of the group allocation
- * functions and we have group->drivers incremented for us.
- */
- if (IS_ERR(group))
- return PTR_ERR(group);
-
if (WARN_ON(device->ops->bind_iommufd &&
(!device->ops->unbind_iommufd ||
!device->ops->attach_ioas)))
@@ -283,16 +276,19 @@ static int __vfio_register_dev(struct vfio_device *device,
if (!device->dev_set)
vfio_assign_device_set(device, device);
- /* Our reference on group is moved to the device */
- device->group = group;
-
ret = dev_set_name(&device->device, "vfio%d", device->index);
if (ret)
- goto err_out;
+ return ret;
- ret = device_add(&device->device);
+ ret = vfio_device_set_group(device, type);
if (ret)
- goto err_out;
+ return ret;
+
+ ret = device_add(&device->device);
+ if (ret) {
+ vfio_device_remove_group(device);
+ return ret;
+ }
/* Refcounting can't start until the driver calls register */
refcount_set(&device->refcount, 1);
@@ -300,15 +296,11 @@ static int __vfio_register_dev(struct vfio_device *device,
vfio_device_group_register(device);
return 0;
-err_out:
- vfio_device_remove_group(device);
- return ret;
}
int vfio_register_group_dev(struct vfio_device *device)
{
- return __vfio_register_dev(device,
- vfio_group_find_or_alloc(device->dev));
+ return __vfio_register_dev(device, VFIO_IOMMU);
}
EXPORT_SYMBOL_GPL(vfio_register_group_dev);
@@ -318,8 +311,7 @@ EXPORT_SYMBOL_GPL(vfio_register_group_dev);
*/
int vfio_register_emulated_iommu_dev(struct vfio_device *device)
{
- return __vfio_register_dev(device,
- vfio_noiommu_group_alloc(device->dev, VFIO_EMULATED_IOMMU));
+ return __vfio_register_dev(device, VFIO_EMULATED_IOMMU);
}
EXPORT_SYMBOL_GPL(vfio_register_emulated_iommu_dev);
@@ -386,7 +378,7 @@ static int vfio_device_first_open(struct vfio_device *device)
if (ret)
goto err_module_put;
- kvm = vfio_group_get_kvm(device->group);
+ kvm = vfio_device_get_group_kvm(device);
if (!kvm) {
ret = -EINVAL;
goto err_unuse_iommu;
@@ -398,12 +390,12 @@ static int vfio_device_first_open(struct vfio_device *device)
if (ret)
goto err_container;
}
- vfio_group_put_kvm(device->group);
+ vfio_device_put_group_kvm(device);
return 0;
err_container:
device->kvm = NULL;
- vfio_group_put_kvm(device->group);
+ vfio_device_put_group_kvm(device);
err_unuse_iommu:
vfio_device_group_unuse_iommu(device);
err_module_put:
@@ -1199,8 +1191,8 @@ int vfio_pin_pages(struct vfio_device *device, dma_addr_t iova,
/* group->container cannot change while a vfio device is open */
if (!pages || !npage || WARN_ON(!vfio_assert_device_open(device)))
return -EINVAL;
- if (vfio_group_has_container(device->group))
- return vfio_group_container_pin_pages(device->group, iova,
+ if (vfio_device_has_container(device))
+ return vfio_device_container_pin_pages(device, iova,
npage, prot, pages);
if (device->iommufd_access) {
int ret;
@@ -1237,8 +1229,8 @@ void vfio_unpin_pages(struct vfio_device *device, dma_addr_t iova, int npage)
if (WARN_ON(!vfio_assert_device_open(device)))
return;
- if (vfio_group_has_container(device->group)) {
- vfio_group_container_unpin_pages(device->group, iova,
+ if (vfio_device_has_container(device)) {
+ vfio_device_container_unpin_pages(device, iova,
npage);
return;
}
@@ -1276,9 +1268,9 @@ int vfio_dma_rw(struct vfio_device *device, dma_addr_t iova, void *data,
if (!data || len <= 0 || !vfio_assert_device_open(device))
return -EINVAL;
- if (vfio_group_has_container(device->group))
- return vfio_group_container_dma_rw(device->group, iova,
- data, len, write);
+ if (vfio_device_has_container(device))
+ return vfio_device_container_dma_rw(device, iova, data, len,
+ write);
if (device->iommufd_access) {
unsigned int flags = 0;