On Tue, Oct 29, 2024 at 04:09:41PM -0300, Jason Gunthorpe wrote:
> On Fri, Oct 25, 2024 at 04:50:34PM -0700, Nicolin Chen wrote:
> > @@ -497,17 +497,35 @@ int iommufd_hwpt_invalidate(struct iommufd_ucmd *ucmd)
> >  		goto out;
> >  	}
> >  
> > -	hwpt = iommufd_get_hwpt_nested(ucmd, cmd->hwpt_id);
> > -	if (IS_ERR(hwpt)) {
> > -		rc = PTR_ERR(hwpt);
> > +	pt_obj = iommufd_get_object(ucmd->ictx, cmd->hwpt_id, IOMMUFD_OBJ_ANY);
> > +	if (IS_ERR(pt_obj)) {
> > +		rc = PTR_ERR(pt_obj);
> >  		goto out;
> >  	}
> > +	if (pt_obj->type == IOMMUFD_OBJ_HWPT_NESTED) {
> > +		struct iommufd_hw_pagetable *hwpt =
> > +			container_of(pt_obj, struct iommufd_hw_pagetable, obj);
> > +
> > +		rc = hwpt->domain->ops->cache_invalidate_user(hwpt->domain,
> > +							      &data_array);
> > +	} else if (pt_obj->type == IOMMUFD_OBJ_VIOMMU) {
> > +		struct iommufd_viommu *viommu =
> > +			container_of(pt_obj, struct iommufd_viommu, obj);
> > +
> > +		if (!viommu->ops || !viommu->ops->cache_invalidate) {
> > +			rc = -EOPNOTSUPP;
> > +			goto out_put_pt;
> > +		}
> > +		rc = viommu->ops->cache_invalidate(viommu, &data_array);
> > +	} else {
> > +		rc = -EINVAL;
> > +		goto out_put_pt;
> > +	}
>
> Given the test in iommufd_viommu_alloc_hwpt_nested() is:
>
> 	if (WARN_ON_ONCE(hwpt->domain->type != IOMMU_DOMAIN_NESTED ||
> 			 (!viommu->ops->cache_invalidate &&
> 			  !hwpt->domain->ops->cache_invalidate_user)))
>
> 	{
>
> We will crash if the user passes a viommu allocated domain as
> IOMMUFD_OBJ_HWPT_NESTED since the above doesn't check it.

Ah, that was missed.

> I suggest we put the required if (ops..) -EOPNOTSUPP above and remove
> the ops->cache_invalidate checks from both WARN_ONs.

Ack.
I will add a check that hwpt->domain->ops and its cache_invalidate_user
callback are present, mirroring the existing viommu->ops check:
---------------------------------------------------------------------
	if (pt_obj->type == IOMMUFD_OBJ_HWPT_NESTED) {
		struct iommufd_hw_pagetable *hwpt =
			container_of(pt_obj, struct iommufd_hw_pagetable, obj);

		if (!hwpt->domain->ops ||
		    !hwpt->domain->ops->cache_invalidate_user) {
			rc = -EOPNOTSUPP;
			goto out_put_pt;
		}
		rc = hwpt->domain->ops->cache_invalidate_user(hwpt->domain,
							      &data_array);
	} else if (pt_obj->type == IOMMUFD_OBJ_VIOMMU) {
		struct iommufd_viommu *viommu =
			container_of(pt_obj, struct iommufd_viommu, obj);

		if (!viommu->ops || !viommu->ops->cache_invalidate) {
			rc = -EOPNOTSUPP;
			goto out_put_pt;
		}
		rc = viommu->ops->cache_invalidate(viommu, &data_array);
	} else {
---------------------------------------------------------------------

Thanks
Nicolin