Conversion is straightforward: mmap_sem is used within the same function context most of the time, and we already have vmf updated. No changes in semantics.

Signed-off-by: Davidlohr Bueso <dbueso@xxxxxxx>
---
 include/linux/mm.h     |  8 +++---
 mm/filemap.c           |  8 +++---
 mm/frame_vector.c      |  4 +--
 mm/gup.c               | 21 +++++++--------
 mm/hmm.c               |  3 ++-
 mm/khugepaged.c        | 54 +++++++++++++++++++++------------------
 mm/ksm.c               | 42 +++++++++++++++++-------------
 mm/madvise.c           | 36 ++++++++++++++------------
 mm/memcontrol.c        | 10 +++++---
 mm/memory.c            | 10 +++++---
 mm/mempolicy.c         | 25 ++++++++++--------
 mm/migrate.c           | 10 +++++---
 mm/mincore.c           |  6 +++--
 mm/mlock.c             | 20 +++++++++------
 mm/mmap.c              | 69 ++++++++++++++++++++++++++++----------------------
 mm/mmu_notifier.c      |  9 ++++---
 mm/mprotect.c          | 15 ++++++-----
 mm/mremap.c            |  9 ++++---
 mm/msync.c             |  9 ++++---
 mm/nommu.c             | 25 ++++++++++--------
 mm/oom_kill.c          |  5 ++--
 mm/process_vm_access.c |  4 +--
 mm/shmem.c             |  2 +-
 mm/swapfile.c          |  5 ++--
 mm/userfaultfd.c       | 21 ++++++++-------
 mm/util.c              | 10 +++++---
 26 files changed, 252 insertions(+), 188 deletions(-)

diff --git a/include/linux/mm.h b/include/linux/mm.h index 044e428b1905..8bf3e2542047 100644 --- a/include/linux/mm.h +++ b/include/linux/mm.h @@ -1459,6 +1459,7 @@ void unmap_vmas(struct mmu_gather *tlb, struct vm_area_struct *start_vma, * right now." 1 means "skip the current vma." * @mm: mm_struct representing the target process of page table walk * @vma: vma currently walked (NULL if walking outside vmas) + * @mmrange: mm address space range locking * @private: private data for callbacks' usage * * (see the comment on walk_page_range() for more details) @@ -2358,8 +2359,8 @@ static inline int check_data_rlimit(unsigned long rlim, return 0; } -extern int mm_take_all_locks(struct mm_struct *mm); -extern void mm_drop_all_locks(struct mm_struct *mm); +extern int mm_take_all_locks(struct mm_struct *mm, struct range_lock *mmrange); +extern void mm_drop_all_locks(struct mm_struct *mm, struct range_lock *mmrange); extern void set_mm_exe_file(struct mm_struct *mm, struct file *new_exe_file); extern struct file *get_mm_exe_file(struct mm_struct *mm); @@ -2389,7 +2390,8 @@ extern unsigned long do_mmap(struct file *file, unsigned long addr, vm_flags_t vm_flags, unsigned long pgoff, unsigned long *populate, struct list_head *uf); extern int __do_munmap(struct mm_struct *, unsigned long, size_t, - struct list_head *uf, bool downgrade); + struct list_head *uf, bool downgrade, + struct range_lock *); extern int do_munmap(struct mm_struct *, unsigned long, size_t, struct list_head *uf); diff --git a/mm/filemap.c b/mm/filemap.c index 959022841bab..71f0d8a18f40 100644 --- a/mm/filemap.c +++ b/mm/filemap.c @@ -1388,7 +1388,7 @@ int __lock_page_or_retry(struct page *page, struct mm_struct *mm, if (flags & FAULT_FLAG_RETRY_NOWAIT) return 0; - up_read(&mm->mmap_sem); + mm_read_unlock(mm, mmrange); if (flags & FAULT_FLAG_KILLABLE) wait_on_page_locked_killable(page); else @@ -1400,7 +1400,7 @@ int __lock_page_or_retry(struct page *page, struct mm_struct *mm, ret = __lock_page_killable(page); if (ret) { - up_read(&mm->mmap_sem); + mm_read_unlock(mm, mmrange); return 0; } } else @@ -2317,7 +2317,7 @@ static struct file *maybe_unlock_mmap_for_io(struct vm_fault *vmf, if ((flags & (FAULT_FLAG_ALLOW_RETRY | FAULT_FLAG_RETRY_NOWAIT)) == FAULT_FLAG_ALLOW_RETRY) { fpin = get_file(vmf->vma->vm_file); - up_read(&vmf->vma->vm_mm->mmap_sem); + mm_read_unlock(vmf->vma->vm_mm, vmf->lockrange); } return fpin; } @@ -2357,7 +2357,7 @@ static int
lock_page_maybe_drop_mmap(struct vm_fault *vmf, struct page *page, * mmap_sem here and return 0 if we don't have a fpin. */ if (*fpin == NULL) - up_read(&vmf->vma->vm_mm->mmap_sem); + mm_read_unlock(vmf->vma->vm_mm, vmf->lockrange); return 0; } } else diff --git a/mm/frame_vector.c b/mm/frame_vector.c index 4e1a577cbb79..ef33d21b3f39 100644 --- a/mm/frame_vector.c +++ b/mm/frame_vector.c @@ -47,7 +47,7 @@ int get_vaddr_frames(unsigned long start, unsigned int nr_frames, if (WARN_ON_ONCE(nr_frames > vec->nr_allocated)) nr_frames = vec->nr_allocated; - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); locked = 1; vma = find_vma_intersection(mm, start, start + 1); if (!vma) { @@ -102,7 +102,7 @@ int get_vaddr_frames(unsigned long start, unsigned int nr_frames, } while (vma && vma->vm_flags & (VM_IO | VM_PFNMAP)); out: if (locked) - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); if (!ret) ret = -EFAULT; if (ret > 0) diff --git a/mm/gup.c b/mm/gup.c index cf8fa037ce27..70b546a01682 100644 --- a/mm/gup.c +++ b/mm/gup.c @@ -990,7 +990,7 @@ int fixup_user_fault(struct task_struct *tsk, struct mm_struct *mm, } if (ret & VM_FAULT_RETRY) { - down_read(&mm->mmap_sem); + mm_read_lock(mm, mmrange); if (!(fault_flags & FAULT_FLAG_TRIED)) { *unlocked = true; fault_flags &= ~FAULT_FLAG_ALLOW_RETRY; @@ -1077,7 +1077,7 @@ static __always_inline long __get_user_pages_locked(struct task_struct *tsk, */ *locked = 1; lock_dropped = true; - down_read(&mm->mmap_sem); + mm_read_lock(mm, mmrange); ret = __get_user_pages(tsk, mm, start, 1, flags | FOLL_TRIED, pages, NULL, NULL, NULL); if (ret != 1) { @@ -1098,7 +1098,7 @@ static __always_inline long __get_user_pages_locked(struct task_struct *tsk, * We must let the caller know we temporarily dropped the lock * and so the critical section protected by it was lost. 
*/ - up_read(&mm->mmap_sem); + mm_read_unlock(mm, mmrange); *locked = 0; } return pages_done; @@ -1176,11 +1176,11 @@ long get_user_pages_unlocked(unsigned long start, unsigned long nr_pages, if (WARN_ON_ONCE(gup_flags & FOLL_LONGTERM)) return -EINVAL; - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); ret = __get_user_pages_locked(current, mm, start, nr_pages, pages, NULL, &locked, gup_flags | FOLL_TOUCH, &mmrange); if (locked) - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); return ret; } EXPORT_SYMBOL(get_user_pages_unlocked); @@ -1543,7 +1543,7 @@ long populate_vma_page_range(struct vm_area_struct *vma, VM_BUG_ON(end & ~PAGE_MASK); VM_BUG_ON_VMA(start < vma->vm_start, vma); VM_BUG_ON_VMA(end > vma->vm_end, vma); - VM_BUG_ON_MM(!rwsem_is_locked(&mm->mmap_sem), mm); + VM_BUG_ON_MM(!mm_is_locked(mm, mmrange), mm); gup_flags = FOLL_TOUCH | FOLL_POPULATE | FOLL_MLOCK; if (vma->vm_flags & VM_LOCKONFAULT) @@ -1596,7 +1596,7 @@ int __mm_populate(unsigned long start, unsigned long len, int ignore_errors) */ if (!locked) { locked = 1; - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); vma = find_vma(mm, nstart); } else if (nstart >= vma->vm_end) vma = vma->vm_next; @@ -1628,7 +1628,7 @@ int __mm_populate(unsigned long start, unsigned long len, int ignore_errors) ret = 0; } if (locked) - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); return ret; /* 0 or negative error code */ } @@ -2189,17 +2189,18 @@ static int __gup_longterm_unlocked(unsigned long start, int nr_pages, unsigned int gup_flags, struct page **pages) { int ret; + DEFINE_RANGE_LOCK_FULL(mmrange); /* * FIXME: FOLL_LONGTERM does not work with * get_user_pages_unlocked() (see comments in that function) */ if (gup_flags & FOLL_LONGTERM) { - down_read(¤t->mm->mmap_sem); + mm_read_lock(current->mm, &mmrange); ret = __gup_longterm_locked(current, current->mm, start, nr_pages, pages, NULL, gup_flags); - up_read(¤t->mm->mmap_sem); + mm_read_unlock(current->mm, &mmrange); } else { ret = get_user_pages_unlocked(start, nr_pages, pages, gup_flags); diff --git a/mm/hmm.c b/mm/hmm.c index 723109ac6bdc..a79a07f7ccc1 100644 --- a/mm/hmm.c +++ b/mm/hmm.c @@ -1118,7 +1118,8 @@ long hmm_range_fault(struct hmm_range *range, bool block) do { /* If range is no longer valid force retry. */ if (!range->valid) { - up_read(&hmm->mm->mmap_sem); + /*** BROKEN mmrange, we don't care about hmm (for now) */ + mm_read_unlock(hmm->mm, NULL); return -EAGAIN; } diff --git a/mm/khugepaged.c b/mm/khugepaged.c index 3eefcb8f797d..13d8e29f4674 100644 --- a/mm/khugepaged.c +++ b/mm/khugepaged.c @@ -488,6 +488,8 @@ void __khugepaged_exit(struct mm_struct *mm) free_mm_slot(mm_slot); mmdrop(mm); } else if (mm_slot) { + DEFINE_RANGE_LOCK_FULL(mmrange); + /* * This is required to serialize against * khugepaged_test_exit() (which is guaranteed to run @@ -496,8 +498,8 @@ void __khugepaged_exit(struct mm_struct *mm) * khugepaged has finished working on the pagetables * under the mmap_sem. 
*/ - down_write(&mm->mmap_sem); - up_write(&mm->mmap_sem); + mm_write_lock(mm, &mmrange); + mm_write_unlock(mm, &mmrange); } } @@ -908,7 +910,7 @@ static bool __collapse_huge_page_swapin(struct mm_struct *mm, /* do_swap_page returns VM_FAULT_RETRY with released mmap_sem */ if (ret & VM_FAULT_RETRY) { - down_read(&mm->mmap_sem); + mm_read_lock(mm, mmrange); if (hugepage_vma_revalidate(mm, address, &vmf.vma)) { /* vma is no longer available, don't continue to swapin */ trace_mm_collapse_huge_page_swapin(mm, swapped_in, referenced, 0); @@ -961,7 +963,7 @@ static void collapse_huge_page(struct mm_struct *mm, * sync compaction, and we do not need to hold the mmap_sem during * that. We will recheck the vma after taking it again in write mode. */ - up_read(&mm->mmap_sem); + mm_read_unlock(mm, mmrange); new_page = khugepaged_alloc_page(hpage, gfp, node); if (!new_page) { result = SCAN_ALLOC_HUGE_PAGE_FAIL; @@ -973,11 +975,11 @@ static void collapse_huge_page(struct mm_struct *mm, goto out_nolock; } - down_read(&mm->mmap_sem); + mm_read_lock(mm, mmrange); result = hugepage_vma_revalidate(mm, address, &vma); if (result) { mem_cgroup_cancel_charge(new_page, memcg, true); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, mmrange); goto out_nolock; } @@ -985,7 +987,7 @@ static void collapse_huge_page(struct mm_struct *mm, if (!pmd) { result = SCAN_PMD_NULL; mem_cgroup_cancel_charge(new_page, memcg, true); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, mmrange); goto out_nolock; } @@ -997,17 +999,17 @@ static void collapse_huge_page(struct mm_struct *mm, if (!__collapse_huge_page_swapin(mm, vma, address, pmd, referenced, mmrange)) { mem_cgroup_cancel_charge(new_page, memcg, true); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, mmrange); goto out_nolock; } - up_read(&mm->mmap_sem); + mm_read_unlock(mm, mmrange); /* * Prevent all access to pagetables with the exception of * gup_fast later handled by the ptep_clear_flush and the VM * handled by the anon_vma lock + PG_lock. */ - down_write(&mm->mmap_sem); + mm_write_lock(mm, mmrange); result = hugepage_vma_revalidate(mm, address, &vma); if (result) goto out; @@ -1091,7 +1093,7 @@ static void collapse_huge_page(struct mm_struct *mm, khugepaged_pages_collapsed++; result = SCAN_SUCCEED; out_up_write: - up_write(&mm->mmap_sem); + mm_write_unlock(mm, mmrange); out_nolock: trace_mm_collapse_huge_page(mm, isolated, result); return; @@ -1250,7 +1252,8 @@ static void collect_mm_slot(struct mm_slot *mm_slot) } #if defined(CONFIG_SHMEM) && defined(CONFIG_TRANSPARENT_HUGE_PAGECACHE) -static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff) +static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff, + struct range_lock *mmrange) { struct vm_area_struct *vma; unsigned long addr; @@ -1275,12 +1278,12 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff) * re-fault. Not ideal, but it's more important to not disturb * the system too much. 
*/ - if (down_write_trylock(&vma->vm_mm->mmap_sem)) { + if (mm_write_trylock(vma->vm_mm, mmrange)) { spinlock_t *ptl = pmd_lock(vma->vm_mm, pmd); /* assume page table is clear */ _pmd = pmdp_collapse_flush(vma, addr, pmd); spin_unlock(ptl); - up_write(&vma->vm_mm->mmap_sem); + mm_write_unlock(vma->vm_mm, mmrange); mm_dec_nr_ptes(vma->vm_mm); pte_free(vma->vm_mm, pmd_pgtable(_pmd)); } @@ -1307,8 +1310,9 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff) * + unlock and free huge page; */ static void collapse_shmem(struct mm_struct *mm, - struct address_space *mapping, pgoff_t start, - struct page **hpage, int node) + struct address_space *mapping, pgoff_t start, + struct page **hpage, int node, + struct range_lock *mmrange) { gfp_t gfp; struct page *new_page; @@ -1515,7 +1519,7 @@ static void collapse_shmem(struct mm_struct *mm, /* * Remove pte page tables, so we can re-fault the page as huge. */ - retract_page_tables(mapping, start); + retract_page_tables(mapping, start, mmrange); *hpage = NULL; khugepaged_pages_collapsed++; @@ -1566,8 +1570,9 @@ static void collapse_shmem(struct mm_struct *mm, } static void khugepaged_scan_shmem(struct mm_struct *mm, - struct address_space *mapping, - pgoff_t start, struct page **hpage) + struct address_space *mapping, + pgoff_t start, struct page **hpage, + struct range_lock *mmrange) { struct page *page = NULL; XA_STATE(xas, &mapping->i_pages, start); @@ -1633,7 +1638,8 @@ static void khugepaged_scan_shmem(struct mm_struct *mm, result = SCAN_EXCEED_NONE_PTE; } else { node = khugepaged_find_target_node(); - collapse_shmem(mm, mapping, start, hpage, node); + collapse_shmem(mm, mapping, start, hpage, + node, mmrange); } } @@ -1678,7 +1684,7 @@ static unsigned int khugepaged_scan_mm_slot(unsigned int pages, * the next mm on the list. 
*/ vma = NULL; - if (unlikely(!down_read_trylock(&mm->mmap_sem))) + if (unlikely(!mm_read_trylock(mm, &mmrange))) goto breakouterloop_mmap_sem; if (likely(!khugepaged_test_exit(mm))) vma = find_vma(mm, khugepaged_scan.address); @@ -1723,10 +1729,10 @@ static unsigned int khugepaged_scan_mm_slot(unsigned int pages, if (!shmem_huge_enabled(vma)) goto skip; file = get_file(vma->vm_file); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); ret = 1; khugepaged_scan_shmem(mm, file->f_mapping, - pgoff, hpage); + pgoff, hpage, &mmrange); fput(file); } else { ret = khugepaged_scan_pmd(mm, vma, @@ -1744,7 +1750,7 @@ static unsigned int khugepaged_scan_mm_slot(unsigned int pages, } } breakouterloop: - up_read(&mm->mmap_sem); /* exit_mmap will destroy ptes after this */ + mm_read_unlock(mm, &mmrange); /* exit_mmap will destroy ptes after this */ breakouterloop_mmap_sem: spin_lock(&khugepaged_mm_lock); diff --git a/mm/ksm.c b/mm/ksm.c index ccc9737311eb..7f9826ea7dba 100644 --- a/mm/ksm.c +++ b/mm/ksm.c @@ -537,6 +537,7 @@ static void break_cow(struct rmap_item *rmap_item) struct mm_struct *mm = rmap_item->mm; unsigned long addr = rmap_item->address; struct vm_area_struct *vma; + DEFINE_RANGE_LOCK_FULL(mmrange); /* * It is not an accident that whenever we want to break COW @@ -544,11 +545,11 @@ static void break_cow(struct rmap_item *rmap_item) */ put_anon_vma(rmap_item->anon_vma); - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); vma = find_mergeable_vma(mm, addr); if (vma) break_ksm(vma, addr); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); } static struct page *get_mergeable_page(struct rmap_item *rmap_item) @@ -557,8 +558,9 @@ static struct page *get_mergeable_page(struct rmap_item *rmap_item) unsigned long addr = rmap_item->address; struct vm_area_struct *vma; struct page *page; + DEFINE_RANGE_LOCK_FULL(mmrange); - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); vma = find_mergeable_vma(mm, addr); if (!vma) goto out; @@ -574,7 +576,7 @@ static struct page *get_mergeable_page(struct rmap_item *rmap_item) out: page = NULL; } - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); return page; } @@ -969,6 +971,7 @@ static int unmerge_and_remove_all_rmap_items(void) struct mm_struct *mm; struct vm_area_struct *vma; int err = 0; + DEFINE_RANGE_LOCK_FULL(mmrange); spin_lock(&ksm_mmlist_lock); ksm_scan.mm_slot = list_entry(ksm_mm_head.mm_list.next, @@ -978,7 +981,7 @@ static int unmerge_and_remove_all_rmap_items(void) for (mm_slot = ksm_scan.mm_slot; mm_slot != &ksm_mm_head; mm_slot = ksm_scan.mm_slot) { mm = mm_slot->mm; - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); for (vma = mm->mmap; vma; vma = vma->vm_next) { if (ksm_test_exit(mm)) break; @@ -991,7 +994,7 @@ static int unmerge_and_remove_all_rmap_items(void) } remove_trailing_rmap_items(mm_slot, &mm_slot->rmap_list); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); spin_lock(&ksm_mmlist_lock); ksm_scan.mm_slot = list_entry(mm_slot->mm_list.next, @@ -1014,7 +1017,7 @@ static int unmerge_and_remove_all_rmap_items(void) return 0; error: - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); spin_lock(&ksm_mmlist_lock); ksm_scan.mm_slot = &ksm_mm_head; spin_unlock(&ksm_mmlist_lock); @@ -1299,8 +1302,9 @@ static int try_to_merge_with_ksm_page(struct rmap_item *rmap_item, struct mm_struct *mm = rmap_item->mm; struct vm_area_struct *vma; int err = -EFAULT; + DEFINE_RANGE_LOCK_FULL(mmrange); - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); vma = find_mergeable_vma(mm, rmap_item->address); 
if (!vma) goto out; @@ -1316,7 +1320,7 @@ static int try_to_merge_with_ksm_page(struct rmap_item *rmap_item, rmap_item->anon_vma = vma->anon_vma; get_anon_vma(vma->anon_vma); out: - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); return err; } @@ -2129,12 +2133,13 @@ static void cmp_and_merge_page(struct page *page, struct rmap_item *rmap_item) */ if (ksm_use_zero_pages && (checksum == zero_checksum)) { struct vm_area_struct *vma; + DEFINE_RANGE_LOCK_FULL(mmrange); - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); vma = find_mergeable_vma(mm, rmap_item->address); err = try_to_merge_one_page(vma, page, ZERO_PAGE(rmap_item->address)); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); /* * In case of failure, the page was not really empty, so we * need to continue. Otherwise we're done. @@ -2240,6 +2245,7 @@ static struct rmap_item *scan_get_next_rmap_item(struct page **page) struct vm_area_struct *vma; struct rmap_item *rmap_item; int nid; + DEFINE_RANGE_LOCK_FULL(mmrange); if (list_empty(&ksm_mm_head.mm_list)) return NULL; @@ -2297,7 +2303,7 @@ static struct rmap_item *scan_get_next_rmap_item(struct page **page) } mm = slot->mm; - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); if (ksm_test_exit(mm)) vma = NULL; else @@ -2331,7 +2337,7 @@ static struct rmap_item *scan_get_next_rmap_item(struct page **page) ksm_scan.address += PAGE_SIZE; } else put_page(*page); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); return rmap_item; } put_page(*page); @@ -2369,10 +2375,10 @@ static struct rmap_item *scan_get_next_rmap_item(struct page **page) free_mm_slot(slot); clear_bit(MMF_VM_MERGEABLE, &mm->flags); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); mmdrop(mm); } else { - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); /* * up_read(&mm->mmap_sem) first because after * spin_unlock(&ksm_mmlist_lock) run, the "mm" may @@ -2571,8 +2577,10 @@ void __ksm_exit(struct mm_struct *mm) clear_bit(MMF_VM_MERGEABLE, &mm->flags); mmdrop(mm); } else if (mm_slot) { - down_write(&mm->mmap_sem); - up_write(&mm->mmap_sem); + DEFINE_RANGE_LOCK_FULL(mmrange); + + mm_write_lock(mm, &mmrange); + mm_write_unlock(mm, &mmrange); } } diff --git a/mm/madvise.c b/mm/madvise.c index 628022e674a7..78a3f86d9c52 100644 --- a/mm/madvise.c +++ b/mm/madvise.c @@ -516,16 +516,16 @@ static long madvise_dontneed_single_vma(struct vm_area_struct *vma, static long madvise_dontneed_free(struct vm_area_struct *vma, struct vm_area_struct **prev, unsigned long start, unsigned long end, - int behavior) + int behavior, struct range_lock *mmrange) { *prev = vma; if (!can_madv_dontneed_vma(vma)) return -EINVAL; - if (!userfaultfd_remove(vma, start, end)) { + if (!userfaultfd_remove(vma, start, end, mmrange)) { *prev = NULL; /* mmap_sem has been dropped, prev is stale */ - down_read(¤t->mm->mmap_sem); + mm_read_lock(current->mm, mmrange); vma = find_vma(current->mm, start); if (!vma) return -ENOMEM; @@ -574,8 +574,9 @@ static long madvise_dontneed_free(struct vm_area_struct *vma, * This is effectively punching a hole into the middle of a file. */ static long madvise_remove(struct vm_area_struct *vma, - struct vm_area_struct **prev, - unsigned long start, unsigned long end) + struct vm_area_struct **prev, + unsigned long start, unsigned long end, + struct range_lock *mmrange) { loff_t offset; int error; @@ -605,15 +606,15 @@ static long madvise_remove(struct vm_area_struct *vma, * mmap_sem. 
*/ get_file(f); - if (userfaultfd_remove(vma, start, end)) { + if (userfaultfd_remove(vma, start, end, mmrange)) { /* mmap_sem was not released by userfaultfd_remove() */ - up_read(¤t->mm->mmap_sem); + mm_read_unlock(current->mm, mmrange); } error = vfs_fallocate(f, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, offset, end - start); fput(f); - down_read(¤t->mm->mmap_sem); + mm_read_lock(current->mm, mmrange); return error; } @@ -688,16 +689,18 @@ static int madvise_inject_error(int behavior, static long madvise_vma(struct vm_area_struct *vma, struct vm_area_struct **prev, - unsigned long start, unsigned long end, int behavior) + unsigned long start, unsigned long end, int behavior, + struct range_lock *mmrange) { switch (behavior) { case MADV_REMOVE: - return madvise_remove(vma, prev, start, end); + return madvise_remove(vma, prev, start, end, mmrange); case MADV_WILLNEED: return madvise_willneed(vma, prev, start, end); case MADV_FREE: case MADV_DONTNEED: - return madvise_dontneed_free(vma, prev, start, end, behavior); + return madvise_dontneed_free(vma, prev, start, end, + behavior, mmrange); default: return madvise_behavior(vma, prev, start, end, behavior); } @@ -809,6 +812,7 @@ SYSCALL_DEFINE3(madvise, unsigned long, start, size_t, len_in, int, behavior) int write; size_t len; struct blk_plug plug; + DEFINE_RANGE_LOCK_FULL(mmrange); if (!madvise_behavior_valid(behavior)) return error; @@ -836,10 +840,10 @@ SYSCALL_DEFINE3(madvise, unsigned long, start, size_t, len_in, int, behavior) write = madvise_need_mmap_write(behavior); if (write) { - if (down_write_killable(¤t->mm->mmap_sem)) + if (mm_write_lock_killable(current->mm, &mmrange)) return -EINTR; } else { - down_read(¤t->mm->mmap_sem); + mm_read_lock(current->mm, &mmrange); } /* @@ -872,7 +876,7 @@ SYSCALL_DEFINE3(madvise, unsigned long, start, size_t, len_in, int, behavior) tmp = end; /* Here vma->vm_start <= start < tmp <= (end|vma->vm_end). */ - error = madvise_vma(vma, &prev, start, tmp, behavior); + error = madvise_vma(vma, &prev, start, tmp, behavior, &mmrange); if (error) goto out; start = tmp; @@ -889,9 +893,9 @@ SYSCALL_DEFINE3(madvise, unsigned long, start, size_t, len_in, int, behavior) out: blk_finish_plug(&plug); if (write) - up_write(¤t->mm->mmap_sem); + mm_write_unlock(current->mm, &mmrange); else - up_read(¤t->mm->mmap_sem); + mm_read_unlock(current->mm, &mmrange); return error; } diff --git a/mm/memcontrol.c b/mm/memcontrol.c index 2535e54e7989..c822cea99570 100644 --- a/mm/memcontrol.c +++ b/mm/memcontrol.c @@ -5139,10 +5139,11 @@ static unsigned long mem_cgroup_count_precharge(struct mm_struct *mm) .pmd_entry = mem_cgroup_count_precharge_pte_range, .mm = mm, }; - down_read(&mm->mmap_sem); + DEFINE_RANGE_LOCK_FULL(mmrange); + mm_read_lock(mm, &mmrange); walk_page_range(0, mm->highest_vm_end, &mem_cgroup_count_precharge_walk); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); precharge = mc.precharge; mc.precharge = 0; @@ -5412,6 +5413,7 @@ static int mem_cgroup_move_charge_pte_range(pmd_t *pmd, static void mem_cgroup_move_charge(void) { + DEFINE_RANGE_LOCK_FULL(mmrange); struct mm_walk mem_cgroup_move_charge_walk = { .pmd_entry = mem_cgroup_move_charge_pte_range, .mm = mc.mm, @@ -5426,7 +5428,7 @@ static void mem_cgroup_move_charge(void) atomic_inc(&mc.from->moving_account); synchronize_rcu(); retry: - if (unlikely(!down_read_trylock(&mc.mm->mmap_sem))) { + if (unlikely(!mm_read_trylock(mc.mm, &mmrange))) { /* * Someone who are holding the mmap_sem might be waiting in * waitq. 
So we cancel all extra charges, wake up all waiters, @@ -5444,7 +5446,7 @@ static void mem_cgroup_move_charge(void) */ walk_page_range(0, mc.mm->highest_vm_end, &mem_cgroup_move_charge_walk); - up_read(&mc.mm->mmap_sem); + mm_read_unlock(mc.mm, &mmrange); atomic_dec(&mc.from->moving_account); } diff --git a/mm/memory.c b/mm/memory.c index 73971f859035..8a5f52978893 100644 --- a/mm/memory.c +++ b/mm/memory.c @@ -4347,8 +4347,9 @@ int __access_remote_vm(struct task_struct *tsk, struct mm_struct *mm, struct vm_area_struct *vma; void *old_buf = buf; int write = gup_flags & FOLL_WRITE; + DEFINE_RANGE_LOCK_FULL(mmrange); - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); /* ignore errors, just check how much was successfully transferred */ while (len) { int bytes, ret, offset; @@ -4397,7 +4398,7 @@ int __access_remote_vm(struct task_struct *tsk, struct mm_struct *mm, buf += bytes; addr += bytes; } - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); return buf - old_buf; } @@ -4450,11 +4451,12 @@ void print_vma_addr(char *prefix, unsigned long ip) { struct mm_struct *mm = current->mm; struct vm_area_struct *vma; + DEFINE_RANGE_LOCK_FULL(mmrange); /* * we might be running from an atomic context so we cannot sleep */ - if (!down_read_trylock(&mm->mmap_sem)) + if (!mm_read_trylock(mm, &mmrange)) return; vma = find_vma(mm, ip); @@ -4473,7 +4475,7 @@ void print_vma_addr(char *prefix, unsigned long ip) free_page((unsigned long)buf); } } - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); } #if defined(CONFIG_PROVE_LOCKING) || defined(CONFIG_DEBUG_ATOMIC_SLEEP) diff --git a/mm/mempolicy.c b/mm/mempolicy.c index 975793cc1d71..8bf8861e0c73 100644 --- a/mm/mempolicy.c +++ b/mm/mempolicy.c @@ -378,11 +378,12 @@ void mpol_rebind_task(struct task_struct *tsk, const nodemask_t *new) void mpol_rebind_mm(struct mm_struct *mm, nodemask_t *new) { struct vm_area_struct *vma; + DEFINE_RANGE_LOCK_FULL(mmrange); - down_write(&mm->mmap_sem); + mm_write_lock(mm, &mmrange); for (vma = mm->mmap; vma; vma = vma->vm_next) mpol_rebind_policy(vma->vm_policy, new); - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); } static const struct mempolicy_operations mpol_ops[MPOL_MAX] = { @@ -837,7 +838,7 @@ static int lookup_node(struct mm_struct *mm, unsigned long addr, put_page(p); } if (locked) - up_read(&mm->mmap_sem); + mm_read_unlock(mm, mmrange); return err; } @@ -871,10 +872,10 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask, * vma/shared policy at addr is NULL. We * want to return MPOL_DEFAULT in this case. 
*/ - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); vma = find_vma_intersection(mm, addr, addr+1); if (!vma) { - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); return -EFAULT; } if (vma->vm_ops && vma->vm_ops->get_policy) @@ -933,7 +934,7 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask, out: mpol_cond_put(pol); if (vma) - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); if (pol_refcount) mpol_put(pol_refcount); return err; @@ -1026,12 +1027,13 @@ int do_migrate_pages(struct mm_struct *mm, const nodemask_t *from, int busy = 0; int err; nodemask_t tmp; + DEFINE_RANGE_LOCK_FULL(mmrange); err = migrate_prep(); if (err) return err; - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); /* * Find a 'source' bit set in 'tmp' whose corresponding 'dest' @@ -1112,7 +1114,7 @@ int do_migrate_pages(struct mm_struct *mm, const nodemask_t *from, if (err < 0) break; } - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); if (err < 0) return err; return busy; @@ -1186,6 +1188,7 @@ static long do_mbind(unsigned long start, unsigned long len, unsigned long end; int err; LIST_HEAD(pagelist); + DEFINE_RANGE_LOCK_FULL(mmrange); if (flags & ~(unsigned long)MPOL_MF_VALID) return -EINVAL; @@ -1233,12 +1236,12 @@ static long do_mbind(unsigned long start, unsigned long len, { NODEMASK_SCRATCH(scratch); if (scratch) { - down_write(&mm->mmap_sem); + mm_write_lock(mm, &mmrange); task_lock(current); err = mpol_set_nodemask(new, nmask, scratch); task_unlock(current); if (err) - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); } else err = -ENOMEM; NODEMASK_SCRATCH_FREE(scratch); @@ -1267,7 +1270,7 @@ static long do_mbind(unsigned long start, unsigned long len, } else putback_movable_pages(&pagelist); - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); mpol_out: mpol_put(new); return err; diff --git a/mm/migrate.c b/mm/migrate.c index f2ecc2855a12..3a268b316e4e 100644 --- a/mm/migrate.c +++ b/mm/migrate.c @@ -1531,8 +1531,9 @@ static int add_page_for_migration(struct mm_struct *mm, unsigned long addr, struct page *page; unsigned int follflags; int err; + DEFINE_RANGE_LOCK_FULL(mmrange); - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); err = -EFAULT; vma = find_vma(mm, addr); if (!vma || addr < vma->vm_start || !vma_migratable(vma)) @@ -1585,7 +1586,7 @@ static int add_page_for_migration(struct mm_struct *mm, unsigned long addr, */ put_page(page); out: - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); return err; } @@ -1686,8 +1687,9 @@ static void do_pages_stat_array(struct mm_struct *mm, unsigned long nr_pages, const void __user **pages, int *status) { unsigned long i; + DEFINE_RANGE_LOCK_FULL(mmrange); - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); for (i = 0; i < nr_pages; i++) { unsigned long addr = (unsigned long)(*pages); @@ -1714,7 +1716,7 @@ static void do_pages_stat_array(struct mm_struct *mm, unsigned long nr_pages, status++; } - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); } /* diff --git a/mm/mincore.c b/mm/mincore.c index c3f058bd0faf..c1d3a9cd2ba3 100644 --- a/mm/mincore.c +++ b/mm/mincore.c @@ -270,13 +270,15 @@ SYSCALL_DEFINE3(mincore, unsigned long, start, size_t, len, retval = 0; while (pages) { + DEFINE_RANGE_LOCK_FULL(mmrange); + /* * Do at most PAGE_SIZE entries per iteration, due to * the temporary buffer size. 
*/ - down_read(¤t->mm->mmap_sem); + mm_read_lock(current->mm, &mmrange); retval = do_mincore(start, min(pages, PAGE_SIZE), tmp); - up_read(¤t->mm->mmap_sem); + mm_read_unlock(current->mm, &mmrange); if (retval <= 0) break; diff --git a/mm/mlock.c b/mm/mlock.c index e492a155c51a..c5b5dbd92a3a 100644 --- a/mm/mlock.c +++ b/mm/mlock.c @@ -670,6 +670,7 @@ static int count_mm_mlocked_page_nr(struct mm_struct *mm, static __must_check int do_mlock(unsigned long start, size_t len, vm_flags_t flags) { + DEFINE_RANGE_LOCK_FULL(mmrange); unsigned long locked; unsigned long lock_limit; int error = -ENOMEM; @@ -684,7 +685,7 @@ static __must_check int do_mlock(unsigned long start, size_t len, vm_flags_t fla lock_limit >>= PAGE_SHIFT; locked = len >> PAGE_SHIFT; - if (down_write_killable(¤t->mm->mmap_sem)) + if (mm_write_lock_killable(current->mm, &mmrange)) return -EINTR; locked += atomic64_read(¤t->mm->locked_vm); @@ -703,7 +704,7 @@ static __must_check int do_mlock(unsigned long start, size_t len, vm_flags_t fla if ((locked <= lock_limit) || capable(CAP_IPC_LOCK)) error = apply_vma_lock_flags(start, len, flags); - up_write(¤t->mm->mmap_sem); + mm_write_unlock(current->mm, &mmrange); if (error) return error; @@ -733,15 +734,16 @@ SYSCALL_DEFINE3(mlock2, unsigned long, start, size_t, len, int, flags) SYSCALL_DEFINE2(munlock, unsigned long, start, size_t, len) { + DEFINE_RANGE_LOCK_FULL(mmrange); int ret; len = PAGE_ALIGN(len + (offset_in_page(start))); start &= PAGE_MASK; - if (down_write_killable(¤t->mm->mmap_sem)) + if (mm_write_lock_killable(current->mm, &mmrange)) return -EINTR; ret = apply_vma_lock_flags(start, len, 0); - up_write(¤t->mm->mmap_sem); + mm_write_unlock(current->mm, &mmrange); return ret; } @@ -794,6 +796,7 @@ static int apply_mlockall_flags(int flags) SYSCALL_DEFINE1(mlockall, int, flags) { + DEFINE_RANGE_LOCK_FULL(mmrange); unsigned long lock_limit; int ret; @@ -806,14 +809,14 @@ SYSCALL_DEFINE1(mlockall, int, flags) lock_limit = rlimit(RLIMIT_MEMLOCK); lock_limit >>= PAGE_SHIFT; - if (down_write_killable(¤t->mm->mmap_sem)) + if (mm_write_lock_killable(current->mm, &mmrange)) return -EINTR; ret = -ENOMEM; if (!(flags & MCL_CURRENT) || (current->mm->total_vm <= lock_limit) || capable(CAP_IPC_LOCK)) ret = apply_mlockall_flags(flags); - up_write(¤t->mm->mmap_sem); + mm_write_unlock(current->mm, &mmrange); if (!ret && (flags & MCL_CURRENT)) mm_populate(0, TASK_SIZE); @@ -822,12 +825,13 @@ SYSCALL_DEFINE1(mlockall, int, flags) SYSCALL_DEFINE0(munlockall) { + DEFINE_RANGE_LOCK_FULL(mmrange); int ret; - if (down_write_killable(¤t->mm->mmap_sem)) + if (mm_write_lock_killable(current->mm, &mmrange)) return -EINTR; ret = apply_mlockall_flags(0); - up_write(¤t->mm->mmap_sem); + mm_write_unlock(current->mm, &mmrange); return ret; } diff --git a/mm/mmap.c b/mm/mmap.c index a03ded49f9eb..2eecdeb5fcd6 100644 --- a/mm/mmap.c +++ b/mm/mmap.c @@ -198,9 +198,10 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) unsigned long min_brk; bool populate; bool downgraded = false; + DEFINE_RANGE_LOCK_FULL(mmrange); LIST_HEAD(uf); - if (down_write_killable(&mm->mmap_sem)) + if (mm_write_lock_killable(mm, &mmrange)) return -EINTR; origbrk = mm->brk; @@ -251,7 +252,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) * mm->brk will be restored from origbrk. 
*/ mm->brk = brk; - ret = __do_munmap(mm, newbrk, oldbrk-newbrk, &uf, true); + ret = __do_munmap(mm, newbrk, oldbrk-newbrk, &uf, true, &mmrange); if (ret < 0) { mm->brk = origbrk; goto out; @@ -274,9 +275,9 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) success: populate = newbrk > oldbrk && (mm->def_flags & VM_LOCKED) != 0; if (downgraded) - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); else - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); userfaultfd_unmap_complete(mm, &uf); if (populate) mm_populate(oldbrk, newbrk - oldbrk); @@ -284,7 +285,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) out: retval = origbrk; - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); return retval; } @@ -2726,7 +2727,8 @@ int split_vma(struct mm_struct *mm, struct vm_area_struct *vma, * Jeremy Fitzhardinge <jeremy@xxxxxxxx> */ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len, - struct list_head *uf, bool downgrade) + struct list_head *uf, bool downgrade, + struct range_lock *mmrange) { unsigned long end; struct vm_area_struct *vma, *prev, *last; @@ -2824,7 +2826,7 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len, detach_vmas_to_be_unmapped(mm, vma, prev, end); if (downgrade) - downgrade_write(&mm->mmap_sem); + mm_downgrade_write(mm, mmrange); unmap_region(mm, vma, prev, start, end); @@ -2837,7 +2839,7 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len, int do_munmap(struct mm_struct *mm, unsigned long start, size_t len, struct list_head *uf) { - return __do_munmap(mm, start, len, uf, false); + return __do_munmap(mm, start, len, uf, false, NULL); } static int __vm_munmap(unsigned long start, size_t len, bool downgrade) @@ -2845,21 +2847,22 @@ static int __vm_munmap(unsigned long start, size_t len, bool downgrade) int ret; struct mm_struct *mm = current->mm; LIST_HEAD(uf); + DEFINE_RANGE_LOCK_FULL(mmrange); - if (down_write_killable(&mm->mmap_sem)) + if (mm_write_lock_killable(mm, &mmrange)) return -EINTR; - ret = __do_munmap(mm, start, len, &uf, downgrade); + ret = __do_munmap(mm, start, len, &uf, downgrade, &mmrange); /* * Returning 1 indicates mmap_sem is downgraded. * But 1 is not legal return value of vm_munmap() and munmap(), reset * it to 0 before return. 
*/ if (ret == 1) { - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); ret = 0; } else - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); userfaultfd_unmap_complete(mm, &uf); return ret; @@ -2884,6 +2887,7 @@ SYSCALL_DEFINE2(munmap, unsigned long, addr, size_t, len) SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size, unsigned long, prot, unsigned long, pgoff, unsigned long, flags) { + DEFINE_RANGE_LOCK_FULL(mmrange); struct mm_struct *mm = current->mm; struct vm_area_struct *vma; @@ -2906,7 +2910,7 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size, if (pgoff + (size >> PAGE_SHIFT) < pgoff) return ret; - if (down_write_killable(&mm->mmap_sem)) + if (mm_write_lock_killable(mm, &mmrange)) return -EINTR; vma = find_vma(mm, start); @@ -2969,7 +2973,7 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size, prot, flags, pgoff, &populate, NULL); fput(file); out: - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); if (populate) mm_populate(ret, populate); if (!IS_ERR_VALUE(ret)) @@ -3056,6 +3060,7 @@ static int do_brk_flags(unsigned long addr, unsigned long len, unsigned long fla int vm_brk_flags(unsigned long addr, unsigned long request, unsigned long flags) { + DEFINE_RANGE_LOCK_FULL(mmrange); struct mm_struct *mm = current->mm; unsigned long len; int ret; @@ -3068,12 +3073,12 @@ int vm_brk_flags(unsigned long addr, unsigned long request, unsigned long flags) if (!len) return 0; - if (down_write_killable(&mm->mmap_sem)) + if (mm_write_lock_killable(mm, &mmrange)) return -EINTR; ret = do_brk_flags(addr, len, flags, &uf); populate = ((mm->def_flags & VM_LOCKED) != 0); - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); userfaultfd_unmap_complete(mm, &uf); if (populate && !ret) mm_populate(addr, len); @@ -3098,6 +3103,8 @@ void exit_mmap(struct mm_struct *mm) mmu_notifier_release(mm); if (unlikely(mm_is_oom_victim(mm))) { + DEFINE_RANGE_LOCK_FULL(mmrange); + /* * Manually reap the mm to free as much memory as possible. * Then, as the oom reaper does, set MMF_OOM_SKIP to disregard @@ -3117,8 +3124,8 @@ void exit_mmap(struct mm_struct *mm) (void)__oom_reap_task_mm(mm); set_bit(MMF_OOM_SKIP, &mm->flags); - down_write(&mm->mmap_sem); - up_write(&mm->mmap_sem); + mm_write_lock(mm, &mmrange); + mm_write_unlock(mm, &mmrange); } if (atomic64_read(&mm->locked_vm)) { @@ -3459,14 +3466,15 @@ int install_special_mapping(struct mm_struct *mm, static DEFINE_MUTEX(mm_all_locks_mutex); -static void vm_lock_anon_vma(struct mm_struct *mm, struct anon_vma *anon_vma) +static void vm_lock_anon_vma(struct mm_struct *mm, struct anon_vma *anon_vma, + struct range_lock *mmrange) { if (!test_bit(0, (unsigned long *) &anon_vma->root->rb_root.rb_root.rb_node)) { /* * The LSB of head.next can't change from under us * because we hold the mm_all_locks_mutex. */ - down_write(&mm->mmap_sem); + mm_write_lock(mm, mmrange); /* * We can safely modify head.next after taking the * anon_vma->root->rwsem. 
If some other vma in this mm shares @@ -3482,7 +3490,8 @@ static void vm_lock_anon_vma(struct mm_struct *mm, struct anon_vma *anon_vma) } } -static void vm_lock_mapping(struct mm_struct *mm, struct address_space *mapping) +static void vm_lock_mapping(struct mm_struct *mm, struct address_space *mapping, + struct range_lock *mmrange) { if (!test_bit(AS_MM_ALL_LOCKS, &mapping->flags)) { /* @@ -3496,7 +3505,7 @@ static void vm_lock_mapping(struct mm_struct *mm, struct address_space *mapping) */ if (test_and_set_bit(AS_MM_ALL_LOCKS, &mapping->flags)) BUG(); - down_write(&mm->mmap_sem); + mm_write_lock(mm, mmrange); } } @@ -3537,12 +3546,12 @@ static void vm_lock_mapping(struct mm_struct *mm, struct address_space *mapping) * * mm_take_all_locks() can fail if it's interrupted by signals. */ -int mm_take_all_locks(struct mm_struct *mm) +int mm_take_all_locks(struct mm_struct *mm, struct range_lock *mmrange) { struct vm_area_struct *vma; struct anon_vma_chain *avc; - BUG_ON(down_read_trylock(&mm->mmap_sem)); + BUG_ON(mm_read_trylock(mm, mmrange)); mutex_lock(&mm_all_locks_mutex); @@ -3551,7 +3560,7 @@ int mm_take_all_locks(struct mm_struct *mm) goto out_unlock; if (vma->vm_file && vma->vm_file->f_mapping && is_vm_hugetlb_page(vma)) - vm_lock_mapping(mm, vma->vm_file->f_mapping); + vm_lock_mapping(mm, vma->vm_file->f_mapping, mmrange); } for (vma = mm->mmap; vma; vma = vma->vm_next) { @@ -3559,7 +3568,7 @@ int mm_take_all_locks(struct mm_struct *mm) goto out_unlock; if (vma->vm_file && vma->vm_file->f_mapping && !is_vm_hugetlb_page(vma)) - vm_lock_mapping(mm, vma->vm_file->f_mapping); + vm_lock_mapping(mm, vma->vm_file->f_mapping, mmrange); } for (vma = mm->mmap; vma; vma = vma->vm_next) { @@ -3567,13 +3576,13 @@ int mm_take_all_locks(struct mm_struct *mm) goto out_unlock; if (vma->anon_vma) list_for_each_entry(avc, &vma->anon_vma_chain, same_vma) - vm_lock_anon_vma(mm, avc->anon_vma); + vm_lock_anon_vma(mm, avc->anon_vma, mmrange); } return 0; out_unlock: - mm_drop_all_locks(mm); + mm_drop_all_locks(mm, mmrange); return -EINTR; } @@ -3617,12 +3626,12 @@ static void vm_unlock_mapping(struct address_space *mapping) * The mmap_sem cannot be released by the caller until * mm_drop_all_locks() returns. 
*/ -void mm_drop_all_locks(struct mm_struct *mm) +void mm_drop_all_locks(struct mm_struct *mm, struct range_lock *mmrange) { struct vm_area_struct *vma; struct anon_vma_chain *avc; - BUG_ON(down_read_trylock(&mm->mmap_sem)); + BUG_ON(mm_read_trylock(mm, mmrange)); BUG_ON(!mutex_is_locked(&mm_all_locks_mutex)); for (vma = mm->mmap; vma; vma = vma->vm_next) { diff --git a/mm/mmu_notifier.c b/mm/mmu_notifier.c index ee36068077b6..028eaed031e1 100644 --- a/mm/mmu_notifier.c +++ b/mm/mmu_notifier.c @@ -244,6 +244,7 @@ static int do_mmu_notifier_register(struct mmu_notifier *mn, { struct mmu_notifier_mm *mmu_notifier_mm; int ret; + DEFINE_RANGE_LOCK_FULL(mmrange); BUG_ON(atomic_read(&mm->mm_users) <= 0); @@ -253,8 +254,8 @@ static int do_mmu_notifier_register(struct mmu_notifier *mn, goto out; if (take_mmap_sem) - down_write(&mm->mmap_sem); - ret = mm_take_all_locks(mm); + mm_write_lock(mm, &mmrange); + ret = mm_take_all_locks(mm, &mmrange); if (unlikely(ret)) goto out_clean; @@ -279,10 +280,10 @@ static int do_mmu_notifier_register(struct mmu_notifier *mn, hlist_add_head(&mn->hlist, &mm->mmu_notifier_mm->list); spin_unlock(&mm->mmu_notifier_mm->lock); - mm_drop_all_locks(mm); + mm_drop_all_locks(mm, &mmrange); out_clean: if (take_mmap_sem) - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); kfree(mmu_notifier_mm); out: BUG_ON(atomic_read(&mm->mm_users) <= 0); diff --git a/mm/mprotect.c b/mm/mprotect.c index 36c517c6a5b1..443b033f240c 100644 --- a/mm/mprotect.c +++ b/mm/mprotect.c @@ -458,6 +458,7 @@ mprotect_fixup(struct vm_area_struct *vma, struct vm_area_struct **pprev, static int do_mprotect_pkey(unsigned long start, size_t len, unsigned long prot, int pkey) { + DEFINE_RANGE_LOCK_FULL(mmrange); unsigned long nstart, end, tmp, reqprot; struct vm_area_struct *vma, *prev; int error = -EINVAL; @@ -482,7 +483,7 @@ static int do_mprotect_pkey(unsigned long start, size_t len, reqprot = prot; - if (down_write_killable(¤t->mm->mmap_sem)) + if (mm_write_lock_killable(current->mm, &mmrange)) return -EINTR; /* @@ -572,7 +573,7 @@ static int do_mprotect_pkey(unsigned long start, size_t len, prot = reqprot; } out: - up_write(¤t->mm->mmap_sem); + mm_write_unlock(current->mm, &mmrange); return error; } @@ -594,6 +595,7 @@ SYSCALL_DEFINE2(pkey_alloc, unsigned long, flags, unsigned long, init_val) { int pkey; int ret; + DEFINE_RANGE_LOCK_FULL(mmrange); /* No flags supported yet. 
*/ if (flags) @@ -602,7 +604,7 @@ SYSCALL_DEFINE2(pkey_alloc, unsigned long, flags, unsigned long, init_val) if (init_val & ~PKEY_ACCESS_MASK) return -EINVAL; - down_write(¤t->mm->mmap_sem); + mm_write_lock(current->mm, &mmrange); pkey = mm_pkey_alloc(current->mm); ret = -ENOSPC; @@ -616,17 +618,18 @@ SYSCALL_DEFINE2(pkey_alloc, unsigned long, flags, unsigned long, init_val) } ret = pkey; out: - up_write(¤t->mm->mmap_sem); + mm_write_unlock(current->mm, &mmrange); return ret; } SYSCALL_DEFINE1(pkey_free, int, pkey) { int ret; + DEFINE_RANGE_LOCK_FULL(mmrange); - down_write(¤t->mm->mmap_sem); + mm_write_lock(current->mm, &mmrange); ret = mm_pkey_free(current->mm, pkey); - up_write(¤t->mm->mmap_sem); + mm_write_unlock(current->mm, &mmrange); /* * We could provie warnings or errors if any VMA still diff --git a/mm/mremap.c b/mm/mremap.c index 37b5b2ad91be..9009210aea97 100644 --- a/mm/mremap.c +++ b/mm/mremap.c @@ -603,6 +603,7 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len, bool locked = false; bool downgraded = false; struct vm_userfaultfd_ctx uf = NULL_VM_UFFD_CTX; + DEFINE_RANGE_LOCK_FULL(mmrange); LIST_HEAD(uf_unmap_early); LIST_HEAD(uf_unmap); @@ -626,7 +627,7 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len, if (!new_len) return ret; - if (down_write_killable(¤t->mm->mmap_sem)) + if (mm_write_lock_killable(current->mm, &mmrange)) return -EINTR; if (flags & MREMAP_FIXED) { @@ -645,7 +646,7 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len, int retval; retval = __do_munmap(mm, addr+new_len, old_len - new_len, - &uf_unmap, true); + &uf_unmap, true, &mmrange); if (retval < 0 && old_len != new_len) { ret = retval; goto out; @@ -717,9 +718,9 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len, locked = 0; } if (downgraded) - up_read(¤t->mm->mmap_sem); + mm_read_unlock(current->mm, &mmrange); else - up_write(¤t->mm->mmap_sem); + mm_write_unlock(current->mm, &mmrange); if (locked && new_len > old_len) mm_populate(new_addr + old_len, new_len - old_len); userfaultfd_unmap_complete(mm, &uf_unmap_early); diff --git a/mm/msync.c b/mm/msync.c index ef30a429623a..2524b4708e78 100644 --- a/mm/msync.c +++ b/mm/msync.c @@ -36,6 +36,7 @@ SYSCALL_DEFINE3(msync, unsigned long, start, size_t, len, int, flags) struct vm_area_struct *vma; int unmapped_error = 0; int error = -EINVAL; + DEFINE_RANGE_LOCK_FULL(mmrange); if (flags & ~(MS_ASYNC | MS_INVALIDATE | MS_SYNC)) goto out; @@ -55,7 +56,7 @@ SYSCALL_DEFINE3(msync, unsigned long, start, size_t, len, int, flags) * If the interval [start,end) covers some unmapped address ranges, * just ignore them, but return -ENOMEM at the end. */ - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); vma = find_vma(mm, start); for (;;) { struct file *file; @@ -86,12 +87,12 @@ SYSCALL_DEFINE3(msync, unsigned long, start, size_t, len, int, flags) if ((flags & MS_SYNC) && file && (vma->vm_flags & VM_SHARED)) { get_file(file); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); error = vfs_fsync_range(file, fstart, fend, 1); fput(file); if (error || start >= end) goto out; - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); vma = find_vma(mm, start); } else { if (start >= end) { @@ -102,7 +103,7 @@ SYSCALL_DEFINE3(msync, unsigned long, start, size_t, len, int, flags) } } out_unlock: - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); out: return error ? 
: unmapped_error; } diff --git a/mm/nommu.c b/mm/nommu.c index b492fd1fcf9f..b454b0004fd2 100644 --- a/mm/nommu.c +++ b/mm/nommu.c @@ -183,10 +183,11 @@ static long __get_user_pages_unlocked(struct task_struct *tsk, unsigned int gup_flags) { long ret; - down_read(&mm->mmap_sem); + DEFINE_RANGE_LOCK_FULL(mmrange); + mm_read_lock(mm, &mmrange); ret = __get_user_pages(tsk, mm, start, nr_pages, gup_flags, pages, NULL, NULL); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); return ret; } @@ -249,12 +250,13 @@ void *vmalloc_user(unsigned long size) ret = __vmalloc(size, GFP_KERNEL | __GFP_ZERO, PAGE_KERNEL); if (ret) { struct vm_area_struct *vma; + DEFINE_RANGE_LOCK_FULL(mmrange); - down_write(¤t->mm->mmap_sem); + mm_write_lock(current->mm, &mmrange); vma = find_vma(current->mm, (unsigned long)ret); if (vma) vma->vm_flags |= VM_USERMAP; - up_write(¤t->mm->mmap_sem); + mm_write_unlock(current->mm, &mmrange); } return ret; @@ -1627,10 +1629,11 @@ int vm_munmap(unsigned long addr, size_t len) { struct mm_struct *mm = current->mm; int ret; + DEFINE_RANGE_LOCK_FULL(mmrange); - down_write(&mm->mmap_sem); + mm_write_lock(mm, &mmrange); ret = do_munmap(mm, addr, len, NULL); - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); return ret; } EXPORT_SYMBOL(vm_munmap); @@ -1716,10 +1719,11 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len, unsigned long, new_addr) { unsigned long ret; + DEFINE_RANGE_LOCK_FULL(mmrange); - down_write(¤t->mm->mmap_sem); + mm_write_lock(current->mm, &mmrange); ret = do_mremap(addr, old_len, new_len, flags, new_addr); - up_write(¤t->mm->mmap_sem); + mm_write_unlock(current->mm, &mmrange); return ret; } @@ -1790,8 +1794,9 @@ int __access_remote_vm(struct task_struct *tsk, struct mm_struct *mm, { struct vm_area_struct *vma; int write = gup_flags & FOLL_WRITE; + DEFINE_RANGE_LOCK_FULL(mmrange); - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); /* the access must start within one of the target process's mappings */ vma = find_vma(mm, addr); @@ -1813,7 +1818,7 @@ int __access_remote_vm(struct task_struct *tsk, struct mm_struct *mm, len = 0; } - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); return len; } diff --git a/mm/oom_kill.c b/mm/oom_kill.c index 539c91d0b26a..a8e3e6279718 100644 --- a/mm/oom_kill.c +++ b/mm/oom_kill.c @@ -558,8 +558,9 @@ bool __oom_reap_task_mm(struct mm_struct *mm) static bool oom_reap_task_mm(struct task_struct *tsk, struct mm_struct *mm) { bool ret = true; + DEFINE_RANGE_LOCK_FULL(mmrange); - if (!down_read_trylock(&mm->mmap_sem)) { + if (!mm_read_trylock(mm, &mmrange)) { trace_skip_task_reaping(tsk->pid); return false; } @@ -590,7 +591,7 @@ static bool oom_reap_task_mm(struct task_struct *tsk, struct mm_struct *mm) out_finish: trace_finish_task_reaping(tsk->pid); out_unlock: - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); return ret; } diff --git a/mm/process_vm_access.c b/mm/process_vm_access.c index ff6772b86195..aaccb8972f83 100644 --- a/mm/process_vm_access.c +++ b/mm/process_vm_access.c @@ -110,12 +110,12 @@ static int process_vm_rw_single_vec(unsigned long addr, * access remotely because task/mm might not * current/current->mm */ - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); pages = get_user_pages_remote(task, mm, pa, pages, flags, process_pages, NULL, &locked, &mmrange); if (locked) - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); if (pages <= 0) return -EFAULT; diff --git a/mm/shmem.c b/mm/shmem.c index 1bb3b8dc8bb2..bae06efb293d 100644 --- a/mm/shmem.c +++ 
b/mm/shmem.c @@ -2012,7 +2012,7 @@ static vm_fault_t shmem_fault(struct vm_fault *vmf) if ((vmf->flags & FAULT_FLAG_ALLOW_RETRY) && !(vmf->flags & FAULT_FLAG_RETRY_NOWAIT)) { /* It's polite to up mmap_sem if we can */ - up_read(&vma->vm_mm->mmap_sem); + mm_read_unlock(vma->vm_mm, vmf->lockrange); ret = VM_FAULT_RETRY; } diff --git a/mm/swapfile.c b/mm/swapfile.c index be36f6fe2f8c..dabe7d5391d1 100644 --- a/mm/swapfile.c +++ b/mm/swapfile.c @@ -1972,8 +1972,9 @@ static int unuse_mm(struct mm_struct *mm, unsigned int type, { struct vm_area_struct *vma; int ret = 0; + DEFINE_RANGE_LOCK_FULL(mmrange); - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); for (vma = mm->mmap; vma; vma = vma->vm_next) { if (vma->anon_vma) { ret = unuse_vma(vma, type, frontswap, @@ -1983,7 +1984,7 @@ static int unuse_mm(struct mm_struct *mm, unsigned int type, } cond_resched(); } - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); return ret; } diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c index 9932d5755e4c..06daedcd06e6 100644 --- a/mm/userfaultfd.c +++ b/mm/userfaultfd.c @@ -177,7 +177,8 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm, unsigned long dst_start, unsigned long src_start, unsigned long len, - bool zeropage) + bool zeropage, + struct range_lock *mmrange) { int vm_alloc_shared = dst_vma->vm_flags & VM_SHARED; int vm_shared = dst_vma->vm_flags & VM_SHARED; @@ -199,7 +200,7 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm, * feature is not supported. */ if (zeropage) { - up_read(&dst_mm->mmap_sem); + mm_read_unlock(dst_mm, mmrange); return -EINVAL; } @@ -297,7 +298,7 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm, cond_resched(); if (unlikely(err == -ENOENT)) { - up_read(&dst_mm->mmap_sem); + mm_read_unlock(dst_mm, mmrange); BUG_ON(!page); err = copy_huge_page_from_user(page, @@ -307,7 +308,7 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm, err = -EFAULT; goto out; } - down_read(&dst_mm->mmap_sem); + mm_read_lock(dst_mm, mmrange); dst_vma = NULL; goto retry; @@ -327,7 +328,7 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm, } out_unlock: - up_read(&dst_mm->mmap_sem); + mm_read_unlock(dst_mm, mmrange); out: if (page) { /* @@ -445,6 +446,7 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm, unsigned long src_addr, dst_addr; long copied; struct page *page; + DEFINE_RANGE_LOCK_FULL(mmrange); /* * Sanitize the command parameters: @@ -461,7 +463,7 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm, copied = 0; page = NULL; retry: - down_read(&dst_mm->mmap_sem); + mm_read_lock(dst_mm, &mmrange); /* * If memory mappings are changing because of non-cooperative @@ -506,7 +508,8 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm, */ if (is_vm_hugetlb_page(dst_vma)) return __mcopy_atomic_hugetlb(dst_mm, dst_vma, dst_start, - src_start, len, zeropage); + src_start, len, zeropage, + &mmrange); if (!vma_is_anonymous(dst_vma) && !vma_is_shmem(dst_vma)) goto out_unlock; @@ -562,7 +565,7 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm, if (unlikely(err == -ENOENT)) { void *page_kaddr; - up_read(&dst_mm->mmap_sem); + mm_read_unlock(dst_mm, &mmrange); BUG_ON(!page); page_kaddr = kmap(page); @@ -591,7 +594,7 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm, } out_unlock: - up_read(&dst_mm->mmap_sem); + mm_read_unlock(dst_mm, 
&mmrange); out: if (page) put_page(page); diff --git a/mm/util.c b/mm/util.c index e2e4f8c3fa12..c410c17ddea7 100644 --- a/mm/util.c +++ b/mm/util.c @@ -350,6 +350,7 @@ unsigned long vm_mmap_pgoff(struct file *file, unsigned long addr, unsigned long len, unsigned long prot, unsigned long flag, unsigned long pgoff) { + DEFINE_RANGE_LOCK_FULL(mmrange); unsigned long ret; struct mm_struct *mm = current->mm; unsigned long populate; @@ -357,11 +358,11 @@ unsigned long vm_mmap_pgoff(struct file *file, unsigned long addr, ret = security_mmap_file(file, prot, flag); if (!ret) { - if (down_write_killable(&mm->mmap_sem)) + if (mm_write_lock_killable(mm, &mmrange)) return -EINTR; ret = do_mmap_pgoff(file, addr, len, prot, flag, pgoff, &populate, &uf); - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); userfaultfd_unmap_complete(mm, &uf); if (populate) mm_populate(ret, populate); @@ -711,18 +712,19 @@ int get_cmdline(struct task_struct *task, char *buffer, int buflen) int res = 0; unsigned int len; struct mm_struct *mm = get_task_mm(task); + DEFINE_RANGE_LOCK_FULL(mmrange); unsigned long arg_start, arg_end, env_start, env_end; if (!mm) goto out; if (!mm->arg_end) goto out_mm; /* Shh! No looking before we're done */ - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); arg_start = mm->arg_start; arg_end = mm->arg_end; env_start = mm->env_start; env_end = mm->env_end; - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); len = arg_end - arg_start; -- 2.16.4
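
For reference while reviewing the conversion above: the mm locking wrappers and DEFINE_RANGE_LOCK_FULL() are not introduced by this patch (they come from earlier in the series). The sketch below only restates the shape the call sites here assume; the prototypes are inferred from how the helpers are used in this diff, the return types are guesses based on how their results are tested, and none of it should be read as the series' actual implementation.

    /*
     * Assumed wrapper API (inferred from the call sites in this patch,
     * not copied from the series):
     */
    void mm_read_lock(struct mm_struct *mm, struct range_lock *mmrange);
    bool mm_read_trylock(struct mm_struct *mm, struct range_lock *mmrange);
    void mm_read_unlock(struct mm_struct *mm, struct range_lock *mmrange);
    void mm_write_lock(struct mm_struct *mm, struct range_lock *mmrange);
    int  mm_write_lock_killable(struct mm_struct *mm, struct range_lock *mmrange);
    bool mm_write_trylock(struct mm_struct *mm, struct range_lock *mmrange);
    void mm_write_unlock(struct mm_struct *mm, struct range_lock *mmrange);
    void mm_downgrade_write(struct mm_struct *mm, struct range_lock *mmrange);
    bool mm_is_locked(struct mm_struct *mm, struct range_lock *mmrange);

    /*
     * Typical shape of the conversion, as done for instance in the
     * mlock(2) paths above; a full range spans the whole address space,
     * which is why there is no change in semantics:
     */
    DEFINE_RANGE_LOCK_FULL(mmrange);

    if (mm_write_lock_killable(current->mm, &mmrange))  /* was down_write_killable(&mm->mmap_sem) */
        return -EINTR;
    /* ... modify the address space ... */
    mm_write_unlock(current->mm, &mmrange);             /* was up_write(&mm->mmap_sem) */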