First, pass the wp_copy variable into hugetlb_mcopy_atomic_pte() throughout the stack. Then, when UFFDIO_COPY_MODE_WP is set with UFFDIO_COPY and the destination VMA is writable, write-protect the new huge PTE and apply the UFFD_WP bit to it. Introduce huge_pte_mkuffd_wp() for this. Note that, similar to how we've handled shmem, we should keep setting the dirty bit even if UFFDIO_COPY_MODE_WP is provided, so that the core mm knows this page contains valid data and will never drop it. Signed-off-by: Peter Xu <peterx@xxxxxxxxxx> --- include/asm-generic/hugetlb.h | 5 +++++ include/linux/hugetlb.h | 6 ++++-- mm/hugetlb.c | 9 +++++++-- mm/userfaultfd.c | 12 ++++++++---- 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/include/asm-generic/hugetlb.h b/include/asm-generic/hugetlb.h index 8e1e6244a89d..548212eccbd6 100644 --- a/include/asm-generic/hugetlb.h +++ b/include/asm-generic/hugetlb.h @@ -27,6 +27,11 @@ static inline pte_t huge_pte_mkdirty(pte_t pte) return pte_mkdirty(pte); } +static inline pte_t huge_pte_mkuffd_wp(pte_t pte) +{ + return pte_mkuffd_wp(pte); +} + static inline pte_t huge_pte_modify(pte_t pte, pgprot_t newprot) { return pte_modify(pte, newprot); diff --git a/include/linux/hugetlb.h b/include/linux/hugetlb.h index ebca2ef02212..bd061f7eedcb 100644 --- a/include/linux/hugetlb.h +++ b/include/linux/hugetlb.h @@ -138,7 +138,8 @@ int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm, pte_t *dst_pte, struct vm_area_struct *dst_vma, unsigned long dst_addr, unsigned long src_addr, - struct page **pagep); + struct page **pagep, + bool wp_copy); int hugetlb_reserve_pages(struct inode *inode, long from, long to, struct vm_area_struct *vma, vm_flags_t vm_flags); @@ -313,7 +314,8 @@ static inline int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm, struct vm_area_struct *dst_vma, unsigned long dst_addr, unsigned long src_addr, - struct page **pagep) + struct page **pagep, + bool wp_copy) { BUG(); return 0; diff --git a/mm/hugetlb.c b/mm/hugetlb.c index dcbbba53bd10..563b8f70537f 100644 --- a/mm/hugetlb.c +++ b/mm/hugetlb.c @@
-4624,7 +4624,8 @@ int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm, struct vm_area_struct *dst_vma, unsigned long dst_addr, unsigned long src_addr, - struct page **pagep) + struct page **pagep, + bool wp_copy) { struct address_space *mapping; pgoff_t idx; @@ -4717,8 +4718,12 @@ int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm, } _dst_pte = make_huge_pte(dst_vma, page, dst_vma->vm_flags & VM_WRITE); - if (dst_vma->vm_flags & VM_WRITE) + if (dst_vma->vm_flags & VM_WRITE) { _dst_pte = huge_pte_mkdirty(_dst_pte); + if (wp_copy) + _dst_pte = huge_pte_mkuffd_wp( + huge_pte_wrprotect(_dst_pte)); + } _dst_pte = pte_mkyoung(_dst_pte); set_huge_pte_at(dst_mm, dst_addr, dst_pte, _dst_pte); diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c index 6d4b3b7c7f9f..b00e5e6b8b8b 100644 --- a/mm/userfaultfd.c +++ b/mm/userfaultfd.c @@ -207,7 +207,8 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm, unsigned long dst_start, unsigned long src_start, unsigned long len, - bool zeropage) + bool zeropage, + bool wp_copy) { int vm_alloc_shared = dst_vma->vm_flags & VM_SHARED; int vm_shared = dst_vma->vm_flags & VM_SHARED; @@ -306,7 +307,8 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm, } err = hugetlb_mcopy_atomic_pte(dst_mm, dst_pte, dst_vma, - dst_addr, src_addr, &page); + dst_addr, src_addr, &page, + wp_copy); mutex_unlock(&hugetlb_fault_mutex_table[hash]); i_mmap_unlock_read(mapping); @@ -408,7 +410,8 @@ extern ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm, unsigned long dst_start, unsigned long src_start, unsigned long len, - bool zeropage); + bool zeropage, + bool wp_copy); #endif /* CONFIG_HUGETLB_PAGE */ static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm, @@ -527,7 +530,8 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm, */ if (is_vm_hugetlb_page(dst_vma)) return __mcopy_atomic_hugetlb(dst_mm, dst_vma, dst_start, - src_start, len, zeropage); + 
src_start, len, zeropage, + wp_copy); if (!vma_is_anonymous(dst_vma) && !vma_is_shmem(dst_vma)) goto out_unlock; -- 2.26.2