On Fri, Dec 25, 2020 at 01:25:28AM -0800, Nadav Amit wrote:

> The scenario that happens in selftests/vm/userfaultfd is as follows:
>
> cpu0                            cpu1                    cpu2
> ----                            ----                    ----
>                                                         [ Writable PTE
>                                                           cached in TLB ]
> userfaultfd_writeprotect()
> [ write-*unprotect* ]
> mwriteprotect_range()
> mmap_read_lock()
> change_protection()
>
> change_protection_range()
> ...
> change_pte_range()
> [ *clear* “write”-bit ]
> [ defer TLB flushes ]
>                                 [ page-fault ]
>                                 ...
>                                 wp_page_copy()
>                                  cow_user_page()
>                                   [ copy page ]
>                                                         [ write to old
>                                                           page ]
>                                 ...
>                                  set_pte_at_notify()

Yuck!

Isn't this all rather similar to the problem that resulted in the
tlb_flush_pending mess?

I still think that's all fundamentally buggered, the much saner solution
(IMO) would've been to make things wait for the pending flush, instead
of doing a local flush and fudging things like we do now.

Then the above could be fixed by having wp_page_copy() wait for the
pending invalidate (although a more fine-grained pending state would be
awesome).

The below probably doesn't compile and will probably cause massive
header fail at the very least, but does show the general idea.


diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index 07d9acb5b19c..0210547ac424 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -649,7 +649,8 @@ static inline void dec_tlb_flush_pending(struct mm_struct *mm)
 	 *
 	 * Therefore we must rely on tlb_flush_*() to guarantee order.
 	 */
-	atomic_dec(&mm->tlb_flush_pending);
+	if (atomic_dec_and_test(&mm->tlb_flush_pending))
+		wake_up_var(&mm->tlb_flush_pending);
 }
 
 static inline bool mm_tlb_flush_pending(struct mm_struct *mm)
@@ -677,6 +678,12 @@ static inline bool mm_tlb_flush_nested(struct mm_struct *mm)
 	return atomic_read(&mm->tlb_flush_pending) > 1;
 }
 
+static inline void wait_tlb_flush_pending(struct mm_struct *mm)
+{
+	wait_var_event(&mm->tlb_flush_pending,
+		       atomic_read(&mm->tlb_flush_pending) == 0);
+}
+
 struct vm_fault;
 
 /**
diff --git a/mm/memory.c b/mm/memory.c
index feff48e1465a..3c36bca2972a 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -3087,6 +3087,8 @@ static vm_fault_t do_wp_page(struct vm_fault *vmf)
 {
 	struct vm_area_struct *vma = vmf->vma;
 
+	wait_tlb_flush_pending(vma->vm_mm);
+
 	if (userfaultfd_pte_wp(vma, *vmf->pte)) {
 		pte_unmap_unlock(vmf->pte, vmf->ptl);
 		return handle_userfault(vmf, VM_UFFD_WP);
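

For anyone who wants to poke at the wait-for-pending idea outside the
kernel, here is a rough userspace sketch of the same pattern. It is not
kernel code and only an analogy: struct flush_state and all function
names below are made up for illustration, and a pthread mutex plus
condition variable stands in for wait_var_event()/wake_up_var(). The
"flush" itself is reduced to setting a flag.

```c
#include <assert.h>
#include <pthread.h>

/* Hypothetical userspace stand-in for mm->tlb_flush_pending. */
struct flush_state {
	pthread_mutex_t lock;
	pthread_cond_t idle;	/* broadcast when pending drops to 0 */
	int pending;		/* number of deferred flushes in flight */
};

static void flush_state_init(struct flush_state *fs)
{
	pthread_mutex_init(&fs->lock, NULL);
	pthread_cond_init(&fs->idle, NULL);
	fs->pending = 0;
}

/* Analogue of inc_tlb_flush_pending(): announce a deferred flush. */
static void inc_flush_pending(struct flush_state *fs)
{
	pthread_mutex_lock(&fs->lock);
	fs->pending++;
	pthread_mutex_unlock(&fs->lock);
}

/* Analogue of the patched dec_tlb_flush_pending(): wake on last dec. */
static void dec_flush_pending(struct flush_state *fs)
{
	pthread_mutex_lock(&fs->lock);
	assert(fs->pending > 0);
	if (--fs->pending == 0)
		pthread_cond_broadcast(&fs->idle);
	pthread_mutex_unlock(&fs->lock);
}

/* Analogue of wait_tlb_flush_pending(): block until nothing is pending. */
static void wait_flush_pending(struct flush_state *fs)
{
	pthread_mutex_lock(&fs->lock);
	while (fs->pending != 0)
		pthread_cond_wait(&fs->idle, &fs->lock);
	pthread_mutex_unlock(&fs->lock);
}

struct demo_ctx {
	struct flush_state fs;
	int flushed;		/* set by the flusher before it decrements */
};

/* Plays the role of the CPU that deferred its flush. */
static void *flusher(void *arg)
{
	struct demo_ctx *ctx = arg;

	ctx->flushed = 1;	/* stand-in for the actual flush */
	dec_flush_pending(&ctx->fs);
	return NULL;
}

/* Returns 1 if the waiter only proceeds after the flush happened. */
static int demo_wait_sees_flush(void)
{
	struct demo_ctx ctx = { .flushed = 0 };
	pthread_t t;
	int seen;

	flush_state_init(&ctx.fs);
	inc_flush_pending(&ctx.fs);	/* pending before the flusher runs */
	if (pthread_create(&t, NULL, flusher, &ctx) != 0)
		return -1;

	wait_flush_pending(&ctx.fs);	/* the "fault path" waits here */
	seen = ctx.flushed;
	pthread_join(t, NULL);
	return seen;
}
```

The waiter reads the flag only after it has observed pending == 0 under
the lock, so it cannot run ahead of the flush. That ordering is what the
do_wp_page() hunk is after. Note the sketch blocks in a context where
sleeping is fine, which says nothing about whether the fault path can
sleep at that point.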