> +int vfio_assign_device_set(struct vfio_device *device, void *set_id)
> +{
> +	struct vfio_device_set *alloc_dev_set = NULL;
> +	struct vfio_device_set *dev_set;
> +
> +	if (WARN_ON(!set_id))
> +		return -EINVAL;
> +
> +	/*
> +	 * Atomically acquire a singleton object in the xarray for this set_id
> +	 */
> +again:
> +	xa_lock(&vfio_device_set_xa);
> +	if (alloc_dev_set) {
> +		dev_set = __xa_cmpxchg(&vfio_device_set_xa,
> +				       (unsigned long)set_id, NULL,
> +				       alloc_dev_set, GFP_KERNEL);
> +		if (xa_is_err(dev_set)) {
> +			xa_unlock(&vfio_device_set_xa);
> +			kfree(alloc_dev_set);
> +			return xa_err(dev_set);
> +		}
> +		if (!dev_set)
> +			dev_set = alloc_dev_set;
> +	} else {
> +		dev_set = xa_load(&vfio_device_set_xa, (unsigned long)set_id);
> +	}
> +
> +	if (dev_set) {
> +		dev_set->device_count++;
> +		xa_unlock(&vfio_device_set_xa);
> +		device->dev_set = dev_set;
> +		if (dev_set != alloc_dev_set)
> +			kfree(alloc_dev_set);
> +		return 0;
> +	}
> +	xa_unlock(&vfio_device_set_xa);
> +
> +	if (WARN_ON(alloc_dev_set))
> +		return -EINVAL;
> +
> +	alloc_dev_set = kzalloc(sizeof(*alloc_dev_set), GFP_KERNEL);
> +	if (!alloc_dev_set)
> +		return -ENOMEM;
> +	mutex_init(&alloc_dev_set->lock);
> +	alloc_dev_set->set_id = set_id;
> +	goto again;
> +}
> +EXPORT_SYMBOL_GPL(vfio_assign_device_set);

This looks unnecessarily complicated.
We can just try to load first and then store it under the same lock, e.g.:

int vfio_assign_device_set(struct vfio_device *device, void *set_id)
{
	unsigned long idx = (unsigned long)set_id;
	struct vfio_device_set *set, *new;

	if (WARN_ON(!set_id))
		return -EINVAL;

	xa_lock(&vfio_device_set_xa);
	set = xa_load(&vfio_device_set_xa, idx);
	if (set)
		goto found;
	xa_unlock(&vfio_device_set_xa);

	new = kzalloc(sizeof(*new), GFP_KERNEL);
	if (!new)
		return -ENOMEM;
	mutex_init(&new->lock);
	new->set_id = set_id;

	xa_lock(&vfio_device_set_xa);
	/*
	 * Use a cmpxchg rather than xa_load() + __xa_store(): with GFP_KERNEL
	 * the store may drop the lock to allocate, so only install 'new' if
	 * the slot is still empty when the store actually happens.
	 */
	set = __xa_cmpxchg(&vfio_device_set_xa, idx, NULL, new, GFP_KERNEL);
	if (!set) {
		set = new;
		goto found;
	}
	kfree(new);
	if (xa_is_err(set)) {
		xa_unlock(&vfio_device_set_xa);
		return xa_err(set);
	}

found:
	set->device_count++;
	xa_unlock(&vfio_device_set_xa);
	device->dev_set = set;
	return 0;
}

> +static void vfio_release_device_set(struct vfio_device *device)
> +{
> +	struct vfio_device_set *dev_set = device->dev_set;
> +
> +	if (!dev_set)
> +		return;
> +
> +	xa_lock(&vfio_device_set_xa);
> +	dev_set->device_count--;
> +	if (!dev_set->device_count) {

Nit, but I'd find

	if (!--dev_set->device_count) {

easier to follow, as it clearly documents the dec_and_test pattern.