diff --git a/mm/gup.c b/mm/gup.c
index f3fc1f08d90c..40aa1c937212 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -2380,8 +2380,28 @@ static void __maybe_unused undo_dev_pagemap(int *nr, int nr_start,
}
#ifdef CONFIG_ARCH_HAS_PTE_SPECIAL
-static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
- unsigned int flags, struct page **pages, int *nr)
+/*
+ * Fast-gup relies on pte change detection to avoid concurrent pgtable
+ * operations.
+ *
+ * To pin the page, fast-gup needs to do the following, in order:
+ * (1) pin the page (by prefetching pte), then (2) check the pte is unchanged.
+ *
+ * For all other pgtable operations whose updates can race with fast-gup,
+ * we need to (1) clear the pte, then (2) check whether the page is
+ * pinned.
+ *
+ * Above will work for all pte-level operations, including THP split.
+ *
+ * For THP collapse, it's a bit more complicated because fast-gup may be
+ * walking a pgtable page that is being freed (pte is still valid but pmd
+ * can be cleared already). To avoid a race in that case, we need to
+ * also check pmd here to make sure pmd doesn't change (corresponds to
+ * pmdp_collapse_flush() in the THP collapse code path).
+ */
+static int gup_pte_range(pmd_t pmd, pmd_t *pmdp, unsigned long addr,
+ unsigned long end, unsigned int flags,
+ struct page **pages, int *nr)
{
struct dev_pagemap *pgmap = NULL;
int nr_start = *nr, ret = 0;
@@ -2423,7 +2443,8 @@ static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
goto pte_unmap;
}
- if (unlikely(pte_val(pte) != pte_val(*ptep))) {
+ if (unlikely(pmd_val(pmd) != pmd_val(*pmdp)) ||
+ unlikely(pte_val(pte) != pte_val(*ptep))) {
gup_put_folio(folio, 1, flags);
goto pte_unmap;
}
@@ -2470,8 +2491,9 @@ static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
* get_user_pages_fast_only implementation that can pin pages. Thus it's still
* useful to have gup_huge_pmd even if we can't operate on ptes.
*/
-static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
- unsigned int flags, struct page **pages, int *nr)
+static int gup_pte_range(pmd_t pmd, pmd_t *pmdp, unsigned long addr,
+ unsigned long end, unsigned int flags,
+ struct page **pages, int *nr)
{
return 0;
}
@@ -2791,7 +2813,7 @@ static int gup_pmd_range(pud_t *pudp, pud_t pud, unsigned long addr, unsigned lo
if (!gup_huge_pd(__hugepd(pmd_val(pmd)), addr,
PMD_SHIFT, next, flags, pages, nr))
return 0;
- } else if (!gup_pte_range(pmd, addr, next, flags, pages, nr))
+ } else if (!gup_pte_range(pmd, pmdp, addr, next, flags, pages, nr))
return 0;
} while (pmdp++, addr = next, addr != end);
diff --git a/mm/khugepaged.c b/mm/khugepaged.c
index 2d74cf01f694..518b49095db3 100644
--- a/mm/khugepaged.c
+++ b/mm/khugepaged.c
@@ -1049,10 +1049,12 @@ static int collapse_huge_page(struct mm_struct *mm, unsigned long address,
pmd_ptl = pmd_lock(mm, pmd); /* probably unnecessary */
/*
- * After this gup_fast can't run anymore. This also removes
- * any huge TLB entry from the CPU so we won't allow
- * huge and small TLB entries for the same virtual address
- * to avoid the risk of CPU bugs in that area.
+ * This removes any huge TLB entry from the CPU so we won't allow
+ * huge and small TLB entries for the same virtual address to
+ * avoid the risk of CPU bugs in that area.
+ *
+ * Parallel fast GUP is fine since fast GUP will back off when
+ * it detects that the PMD has changed.
*/
_pmd = pmdp_collapse_flush(vma, address, pmd);
spin_unlock(pmd_ptl);