On Tue, May 07, 2024 at 02:22:12PM +0800, Yan Zhao wrote:
> diff --git a/drivers/iommu/iommufd/hw_pagetable.c b/drivers/iommu/iommufd/hw_pagetable.c
> index 33d142f8057d..e3099d732c5c 100644
> --- a/drivers/iommu/iommufd/hw_pagetable.c
> +++ b/drivers/iommu/iommufd/hw_pagetable.c
> @@ -14,12 +14,18 @@ void iommufd_hwpt_paging_destroy(struct iommufd_object *obj)
>  		container_of(obj, struct iommufd_hwpt_paging, common.obj);
>
>  	if (!list_empty(&hwpt_paging->hwpt_item)) {
> +		struct io_pagetable *iopt = &hwpt_paging->ioas->iopt;
>  		mutex_lock(&hwpt_paging->ioas->mutex);
>  		list_del(&hwpt_paging->hwpt_item);
>  		mutex_unlock(&hwpt_paging->ioas->mutex);
>
> -		iopt_table_remove_domain(&hwpt_paging->ioas->iopt,
> -					 hwpt_paging->common.domain);
> +		iopt_table_remove_domain(iopt, hwpt_paging->common.domain);
> +
> +		if (!hwpt_paging->enforce_cache_coherency) {
> +			down_write(&iopt->domains_rwsem);
> +			iopt->noncoherent_domain_cnt--;
> +			up_write(&iopt->domains_rwsem);

I think it would be nicer to put this in iopt_table_remove_domain()
since we already have the lock there anyhow. It would be OK to pass
in the hwpt. Same remark for the incr side.

> @@ -176,6 +182,12 @@ iommufd_hwpt_paging_alloc(struct iommufd_ctx *ictx, struct iommufd_ioas *ioas,
>  			goto out_abort;
>  	}
>
> +	if (!hwpt_paging->enforce_cache_coherency) {
> +		down_write(&ioas->iopt.domains_rwsem);
> +		ioas->iopt.noncoherent_domain_cnt++;
> +		up_write(&ioas->iopt.domains_rwsem);
> +	}
> +
>  	rc = iopt_table_add_domain(&ioas->iopt, hwpt->domain);

iopt_table_add_domain() also already takes the required locks.

>  	if (rc)
>  		goto out_detach;
> @@ -183,6 +195,9 @@ iommufd_hwpt_paging_alloc(struct iommufd_ctx *ictx, struct iommufd_ioas *ioas,
>  	return hwpt_paging;
>
>  out_detach:
> +	down_write(&ioas->iopt.domains_rwsem);
> +	ioas->iopt.noncoherent_domain_cnt--;
> +	up_write(&ioas->iopt.domains_rwsem);

And then you don't need this error unwind.

> diff --git a/drivers/iommu/iommufd/io_pagetable.h b/drivers/iommu/iommufd/io_pagetable.h
> index 0ec3509b7e33..557da8fb83d9 100644
> --- a/drivers/iommu/iommufd/io_pagetable.h
> +++ b/drivers/iommu/iommufd/io_pagetable.h
> @@ -198,6 +198,11 @@ struct iopt_pages {
>  	void __user *uptr;
>  	bool writable:1;
>  	u8 account_mode;
> +	/*
> +	 * CPU cache flush is required before mapping the pages to or after
> +	 * unmapping it from a noncoherent domain
> +	 */
> +	bool cache_flush_required:1;

Move this up a line so it packs with the other bool bitfield.
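i.e. roughly this, only the field order changes:

	void __user *uptr;
	bool writable:1;
	/*
	 * CPU cache flush is required before mapping the pages to or after
	 * unmapping it from a noncoherent domain
	 */
	bool cache_flush_required:1;
	u8 account_mode;

And to spell out the counter accounting remarks above, something along
these lines is what I had in mind - an untested sketch, the changed
signatures, the lock ordering and the placeholder comments are just
illustration of the idea, not the existing code:

int iopt_table_add_domain(struct io_pagetable *iopt,
			  struct iommufd_hwpt_paging *hwpt_paging)
{
	int rc;

	down_write(&iopt->domains_rwsem);
	down_write(&iopt->iova_rwsem);

	/*
	 * The existing logic that attaches hwpt_paging->common.domain to
	 * iopt->domains stays here unchanged and sets rc.
	 */
	rc = 0;

	if (!rc && !hwpt_paging->enforce_cache_coherency)
		iopt->noncoherent_domain_cnt++;

	up_write(&iopt->iova_rwsem);
	up_write(&iopt->domains_rwsem);
	return rc;
}

void iopt_table_remove_domain(struct io_pagetable *iopt,
			      struct iommufd_hwpt_paging *hwpt_paging)
{
	down_write(&iopt->domains_rwsem);
	down_write(&iopt->iova_rwsem);

	/*
	 * The existing logic that removes hwpt_paging->common.domain from
	 * iopt->domains stays here unchanged.
	 */

	if (!hwpt_paging->enforce_cache_coherency)
		iopt->noncoherent_domain_cnt--;

	up_write(&iopt->iova_rwsem);
	up_write(&iopt->domains_rwsem);
}

Then iommufd_hwpt_paging_alloc()/iommufd_hwpt_paging_destroy() never
touch domains_rwsem directly and the out_detach unwind above goes
away.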
>  static void batch_clear(struct pfn_batch *batch)
>  {
>  	batch->total_pfns = 0;

> @@ -637,10 +648,18 @@ static void batch_unpin(struct pfn_batch *batch, struct iopt_pages *pages,
>  	while (npages) {
>  		size_t to_unpin = min_t(size_t, npages,
>  					batch->npfns[cur] - first_page_off);
> +		unsigned long pfn = batch->pfns[cur] + first_page_off;
> +
> +		/*
> +		 * Lazily flushing CPU caches when a page is about to be
> +		 * unpinned if the page was mapped into a noncoherent domain
> +		 */
> +		if (pages->cache_flush_required)
> +			arch_clean_nonsnoop_dma(pfn << PAGE_SHIFT,
> +						to_unpin << PAGE_SHIFT);
>
>  		unpin_user_page_range_dirty_lock(
> -			pfn_to_page(batch->pfns[cur] + first_page_off),
> -			to_unpin, pages->writable);
> +			pfn_to_page(pfn), to_unpin, pages->writable);
>  		iopt_pages_sub_npinned(pages, to_unpin);
>  		cur++;
>  		first_page_off = 0;

Makes sense.

> @@ -1358,10 +1377,17 @@ int iopt_area_fill_domain(struct iopt_area *area, struct iommu_domain *domain)
>  {
>  	unsigned long done_end_index;
>  	struct pfn_reader pfns;
> +	bool cache_flush_required;
>  	int rc;
>
>  	lockdep_assert_held(&area->pages->mutex);
>
> +	cache_flush_required = area->iopt->noncoherent_domain_cnt &&
> +			       !area->pages->cache_flush_required;
> +
> +	if (cache_flush_required)
> +		area->pages->cache_flush_required = true;
> +
>  	rc = pfn_reader_first(&pfns, area->pages, iopt_area_index(area),
>  			      iopt_area_last_index(area));
>  	if (rc)
> @@ -1369,6 +1395,9 @@ int iopt_area_fill_domain(struct iopt_area *area, struct iommu_domain *domain)
>
>  	while (!pfn_reader_done(&pfns)) {
>  		done_end_index = pfns.batch_start_index;
> +		if (cache_flush_required)
> +			iopt_cache_flush_pfn_batch(&pfns.batch);
> +

This is a bit unfortunate: it means we are going to flush for every
domain, even though it is not required. I don't see any easy way out
of that :(

Jason