On Thu, Jul 1, 2021 at 7:17 PM Kaiyang Zhao <zhao776@xxxxxxxxxx> wrote:
>
> In our research work [https://dl.acm.org/doi/10.1145/3447786.3456258], we
> have identified a method that, for large applications (i.e., a few hundred
> MB and larger), can significantly speed up the fork system call. Currently,
> the time the fork system call takes to complete is proportional to the
> size of a process's allocated memory, and our design speeds up fork by up
> to 270x at 50GB in our experiments.
>
> The design is that instead of copying the entire page table tree during
> the fork invocation, we make the child and the parent process share the
> same set of last-level page tables, which are reference counted. To
> preserve the copy-on-write semantics, we disable write permission in the
> PMD entries during fork and copy PTE tables as needed in the page fault
> handler.

Does the application have a way to choose between the default fork() and
the new on-demand fork()?

>
> We tested a prototype with large workloads that call fork to take
> snapshots, such as fuzzers (e.g., AFL), and it yielded over 2x the
> execution throughput for AFL. The patch is a prototype for x86 only, does
> not support huge pages or swapping, and is meant to demonstrate the
> potential performance gains to fork. Applications can opt in via a use_odf
> switch in procfs.
>
> On a side note, an approach that shares page tables was proposed by Dave
> McCracken [http://lkml.iu.edu/hypermail/linux/kernel/0508.3/1623.html,
> https://www.kernel.org/doc/ols/2006/ols2006v2-pages-125-130.pdf], but never
> made it into the kernel. We believe that with the increasing memory
> consumption of modern applications and modern use cases of fork, such as
> snapshotting, the shared page table approach in the context of fork is
> worth exploring.
>
> Please let us know your level of interest in this or comments on the
> general design. Thank you.
>
> Signed-off-by: Kaiyang Zhao <zhao776@xxxxxxxxxx>
> ---
> arch/x86/include/asm/pgtable.h | 19 +-
> fs/proc/base.c | 74 ++++++
> include/linux/mm.h | 11 +
> include/linux/mm_types.h | 2 +
> include/linux/pgtable.h | 11 +
> include/linux/sched/coredump.h | 5 +-
> kernel/fork.c | 7 +-
> mm/gup.c | 61 ++++-
> mm/memory.c | 401 +++++++++++++++++++++++++++++++--
> mm/mmap.c | 91 +++++++-
> mm/mprotect.c | 6 +
> 11 files changed, 668 insertions(+), 20 deletions(-)
>
> diff --git a/arch/x86/include/asm/pgtable.h b/arch/x86/include/asm/pgtable.h
> index b6c97b8f59ec..0fda05a5c7a1 100644
> --- a/arch/x86/include/asm/pgtable.h
> +++ b/arch/x86/include/asm/pgtable.h
> @@ -410,6 +410,16 @@ static inline pmd_t pmd_clear_flags(pmd_t pmd, pmdval_t clear)
> return native_make_pmd(v & ~clear);
> }
>
> +static inline pmd_t pmd_mknonpresent(pmd_t pmd)
> +{
> + return pmd_clear_flags(pmd, _PAGE_PRESENT);
> +}
> +
> +static inline pmd_t pmd_mkpresent(pmd_t pmd)
> +{
> + return pmd_set_flags(pmd, _PAGE_PRESENT);
> +}
> +
> #ifdef CONFIG_HAVE_ARCH_USERFAULTFD_WP
> static inline int pmd_uffd_wp(pmd_t pmd)
> {
> @@ -798,6 +808,11 @@ static inline int pmd_present(pmd_t pmd)
> return pmd_flags(pmd) & (_PAGE_PRESENT | _PAGE_PROTNONE | _PAGE_PSE);
> }
>
> +static inline int pmd_iswrite(pmd_t pmd)
> +{
> + return pmd_flags(pmd) & (_PAGE_RW);
> +}
> +
> #ifdef CONFIG_NUMA_BALANCING
> /*
> * These work without NUMA balancing but the kernel does not care.
See the > @@ -833,7 +848,7 @@ static inline unsigned long pmd_page_vaddr(pmd_t pmd) > * Currently stuck as a macro due to indirect forward reference to > * linux/mmzone.h's __section_mem_map_addr() definition: > */ > -#define pmd_page(pmd) pfn_to_page(pmd_pfn(pmd)) > +#define pmd_page(pmd) pfn_to_page(pmd_pfn(pmd_mkpresent(pmd))) > > /* > * Conversion functions: convert a page and protection to a page entry, > @@ -846,7 +861,7 @@ static inline unsigned long pmd_page_vaddr(pmd_t pmd) > > static inline int pmd_bad(pmd_t pmd) > { > - return (pmd_flags(pmd) & ~_PAGE_USER) != _KERNPG_TABLE; > + return ((pmd_flags(pmd) & ~(_PAGE_USER)) | (_PAGE_RW | _PAGE_PRESENT)) != _KERNPG_TABLE; > } > > static inline unsigned long pages_to_mb(unsigned long npg) > diff --git a/fs/proc/base.c b/fs/proc/base.c > index e5b5f7709d48..936f33594539 100644 > --- a/fs/proc/base.c > +++ b/fs/proc/base.c > @@ -2935,6 +2935,79 @@ static const struct file_operations proc_coredump_filter_operations = { > }; > #endif > > +static ssize_t proc_use_odf_read(struct file *file, char __user *buf, > + size_t count, loff_t *ppos) > +{ > + struct task_struct *task = get_proc_task(file_inode(file)); > + struct mm_struct *mm; > + char buffer[PROC_NUMBUF]; > + size_t len; > + int ret; > + > + if (!task) > + return -ESRCH; > + > + ret = 0; > + mm = get_task_mm(task); > + if (mm) { > + len = snprintf(buffer, sizeof(buffer), "%lu\n", > + ((mm->flags & MMF_USE_ODF_MASK) >> MMF_USE_ODF)); > + mmput(mm); > + ret = simple_read_from_buffer(buf, count, ppos, buffer, len); > + } > + > + put_task_struct(task); > + > + return ret; > +} > + > +static ssize_t proc_use_odf_write(struct file *file, > + const char __user *buf, > + size_t count, > + loff_t *ppos) > +{ > + struct task_struct *task; > + struct mm_struct *mm; > + unsigned int val; > + int ret; > + > + ret = kstrtouint_from_user(buf, count, 0, &val); > + if (ret < 0) > + return ret; > + > + ret = -ESRCH; > + task = get_proc_task(file_inode(file)); > + if (!task) > + goto out_no_task; > + > + mm = get_task_mm(task); > + if (!mm) > + goto out_no_mm; > + ret = 0; > + > + if (val == 1) { > + set_bit(MMF_USE_ODF, &mm->flags); > + } else if (val == 0) { > + clear_bit(MMF_USE_ODF, &mm->flags); > + } else { > + //ignore > + } > + > + mmput(mm); > + out_no_mm: > + put_task_struct(task); > + out_no_task: > + if (ret < 0) > + return ret; > + return count; > +} > + > +static const struct file_operations proc_use_odf_operations = { > + .read = proc_use_odf_read, > + .write = proc_use_odf_write, > + .llseek = generic_file_llseek, > +}; > + > #ifdef CONFIG_TASK_IO_ACCOUNTING > static int do_io_accounting(struct task_struct *task, struct seq_file *m, int whole) > { > @@ -3253,6 +3326,7 @@ static const struct pid_entry tgid_base_stuff[] = { > #ifdef CONFIG_ELF_CORE > REG("coredump_filter", S_IRUGO|S_IWUSR, proc_coredump_filter_operations), > #endif > + REG("use_odf", S_IRUGO|S_IWUSR, proc_use_odf_operations), > #ifdef CONFIG_TASK_IO_ACCOUNTING > ONE("io", S_IRUSR, proc_tgid_io_accounting), > #endif > diff --git a/include/linux/mm.h b/include/linux/mm.h > index 57453dba41b9..a30eca9e236a 100644 > --- a/include/linux/mm.h > +++ b/include/linux/mm.h > @@ -664,6 +664,7 @@ static inline void vma_init(struct vm_area_struct *vma, struct mm_struct *mm) > memset(vma, 0, sizeof(*vma)); > vma->vm_mm = mm; > vma->vm_ops = &dummy_vm_ops; > + vma->pte_table_counter_pending = true; > INIT_LIST_HEAD(&vma->anon_vma_chain); > } > > @@ -2250,6 +2251,9 @@ static inline bool pgtable_pte_page_ctor(struct page *page) > return 
false; > __SetPageTable(page); > inc_lruvec_page_state(page, NR_PAGETABLE); > + > + atomic64_set(&(page->pte_table_refcount), 0); > + > return true; > } > > @@ -2276,6 +2280,8 @@ static inline void pgtable_pte_page_dtor(struct page *page) > > #define pte_alloc(mm, pmd) (unlikely(pmd_none(*(pmd))) && __pte_alloc(mm, pmd)) > > +#define tfork_pte_alloc(mm, pmd) (__tfork_pte_alloc(mm, pmd)) > + > #define pte_alloc_map(mm, pmd, address) \ > (pte_alloc(mm, pmd) ? NULL : pte_offset_map(pmd, address)) > > @@ -2283,6 +2289,10 @@ static inline void pgtable_pte_page_dtor(struct page *page) > (pte_alloc(mm, pmd) ? \ > NULL : pte_offset_map_lock(mm, pmd, address, ptlp)) > > +#define tfork_pte_alloc_map_lock(mm, pmd, address, ptlp) \ > + (tfork_pte_alloc(mm, pmd) ? \ > + NULL : pte_offset_map_lock(mm, pmd, address, ptlp)) > + > #define pte_alloc_kernel(pmd, address) \ > ((unlikely(pmd_none(*(pmd))) && __pte_alloc_kernel(pmd))? \ > NULL: pte_offset_kernel(pmd, address)) > @@ -2616,6 +2626,7 @@ extern int do_madvise(struct mm_struct *mm, unsigned long start, size_t len_in, > #ifdef CONFIG_MMU > extern int __mm_populate(unsigned long addr, unsigned long len, > int ignore_errors); > +extern int __mm_populate_nolock(unsigned long addr, unsigned long len, int ignore_errors); > static inline void mm_populate(unsigned long addr, unsigned long len) > { > /* Ignore errors */ > diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h > index f37abb2d222e..e06c677ce279 100644 > --- a/include/linux/mm_types.h > +++ b/include/linux/mm_types.h > @@ -158,6 +158,7 @@ struct page { > union { > struct mm_struct *pt_mm; /* x86 pgds only */ > atomic_t pt_frag_refcount; /* powerpc */ > + atomic64_t pte_table_refcount; > }; > #if USE_SPLIT_PTE_PTLOCKS > #if ALLOC_SPLIT_PTLOCKS > @@ -379,6 +380,7 @@ struct vm_area_struct { > struct mempolicy *vm_policy; /* NUMA policy for the VMA */ > #endif > struct vm_userfaultfd_ctx vm_userfaultfd_ctx; > + bool pte_table_counter_pending; > } __randomize_layout; > > struct core_thread { > diff --git a/include/linux/pgtable.h b/include/linux/pgtable.h > index d147480cdefc..6afd77ff82e6 100644 > --- a/include/linux/pgtable.h > +++ b/include/linux/pgtable.h > @@ -90,6 +90,11 @@ static inline pte_t *pte_offset_kernel(pmd_t *pmd, unsigned long address) > return (pte_t *)pmd_page_vaddr(*pmd) + pte_index(address); > } > #define pte_offset_kernel pte_offset_kernel > +static inline pte_t *tfork_pte_offset_kernel(pmd_t pmd_val, unsigned long address) > +{ > + return (pte_t *)pmd_page_vaddr(pmd_val) + pte_index(address); > +} > +#define tfork_pte_offset_kernel tfork_pte_offset_kernel > #endif > > #if defined(CONFIG_HIGHPTE) > @@ -782,6 +787,12 @@ static inline void arch_swap_restore(swp_entry_t entry, struct page *page) > }) > #endif > > +#define pte_table_start(addr) \ > +(addr & PMD_MASK) > + > +#define pte_table_end(addr) \ > +(((addr) + PMD_SIZE) & PMD_MASK) > + > /* > * When walking page tables, we usually want to skip any p?d_none entries; > * and any p?d_bad entries - reporting the error before resetting to none. 
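
One small nit on the two helpers added to include/linux/pgtable.h above:
pte_table_start() does not parenthesize its argument, so an expression
argument (e.g. pte_table_start(vma->vm_start + off)) would bind to & with
the wrong precedence. Since they are new macros anyway, maybe (sketch):

#define pte_table_start(addr) \
	((addr) & PMD_MASK)

#define pte_table_end(addr) \
	(((addr) + PMD_SIZE) & PMD_MASK)

Only the extra parentheses in pte_table_start() are new; pte_table_end()
already has them. Static inlines would avoid the problem entirely.
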
> diff --git a/include/linux/sched/coredump.h b/include/linux/sched/coredump.h > index 4d9e3a656875..8f6e50bc04ab 100644 > --- a/include/linux/sched/coredump.h > +++ b/include/linux/sched/coredump.h > @@ -83,7 +83,10 @@ static inline int get_dumpable(struct mm_struct *mm) > #define MMF_HAS_PINNED 28 /* FOLL_PIN has run, never cleared */ > #define MMF_DISABLE_THP_MASK (1 << MMF_DISABLE_THP) > > +#define MMF_USE_ODF 29 > +#define MMF_USE_ODF_MASK (1 << MMF_USE_ODF) > + > #define MMF_INIT_MASK (MMF_DUMPABLE_MASK | MMF_DUMP_FILTER_MASK |\ > - MMF_DISABLE_THP_MASK) > + MMF_DISABLE_THP_MASK | MMF_USE_ODF_MASK) > > #endif /* _LINUX_SCHED_COREDUMP_H */ > diff --git a/kernel/fork.c b/kernel/fork.c > index d738aae40f9e..4f21ea4f4f38 100644 > --- a/kernel/fork.c > +++ b/kernel/fork.c > @@ -594,8 +594,13 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm, > rb_parent = &tmp->vm_rb; > > mm->map_count++; > - if (!(tmp->vm_flags & VM_WIPEONFORK)) > + if (!(tmp->vm_flags & VM_WIPEONFORK)) { > retval = copy_page_range(tmp, mpnt); > + if (oldmm->flags & MMF_USE_ODF_MASK) { > + tmp->pte_table_counter_pending = false; // reference of the shared PTE table by the new VMA is counted in copy_pmd_range_tfork > + mpnt->pte_table_counter_pending = false; // don't double count when forking again > + } > + } > > if (tmp->vm_ops && tmp->vm_ops->open) > tmp->vm_ops->open(tmp); > diff --git a/mm/gup.c b/mm/gup.c > index 42b8b1fa6521..5768f339b0ff 100644 > --- a/mm/gup.c > +++ b/mm/gup.c > @@ -1489,8 +1489,11 @@ long populate_vma_page_range(struct vm_area_struct *vma, > * to break COW, except for shared mappings because these don't COW > * and we would not want to dirty them for nothing. > */ > - if ((vma->vm_flags & (VM_WRITE | VM_SHARED)) == VM_WRITE) > - gup_flags |= FOLL_WRITE; > + if ((vma->vm_flags & (VM_WRITE | VM_SHARED)) == VM_WRITE) { > + if (!(mm->flags & MMF_USE_ODF_MASK)) { //for ODF processes, only allocate page tables > + gup_flags |= FOLL_WRITE; > + } > + } > > /* > * We want mlock to succeed for regions that have any permissions > @@ -1669,6 +1672,60 @@ static long __get_user_pages_locked(struct mm_struct *mm, unsigned long start, > } > #endif /* !CONFIG_MMU */ > > +int __mm_populate_nolock(unsigned long start, unsigned long len, int ignore_errors) > +{ > + struct mm_struct *mm = current->mm; > + unsigned long end, nstart, nend; > + struct vm_area_struct *vma = NULL; > + int locked = 0; > + long ret = 0; > + > + end = start + len; > + > + for (nstart = start; nstart < end; nstart = nend) { > + /* > + * We want to fault in pages for [nstart; end) address range. > + * Find first corresponding VMA. > + */ > + if (!locked) { > + locked = 1; > + //down_read(&mm->mmap_sem); > + vma = find_vma(mm, nstart); > + } else if (nstart >= vma->vm_end) > + vma = vma->vm_next; > + if (!vma || vma->vm_start >= end) > + break; > + /* > + * Set [nstart; nend) to intersection of desired address > + * range with the first VMA. Also, skip undesirable VMA types. > + */ > + nend = min(end, vma->vm_end); > + if (vma->vm_flags & (VM_IO | VM_PFNMAP)) > + continue; > + if (nstart < vma->vm_start) > + nstart = vma->vm_start; > + /* > + * Now fault in a range of pages. populate_vma_page_range() > + * double checks the vma flags, so that it won't mlock pages > + * if the vma was already munlocked. 
> + */ > + ret = populate_vma_page_range(vma, nstart, nend, &locked); > + if (ret < 0) { > + if (ignore_errors) { > + ret = 0; > + continue; /* continue at next VMA */ > + } > + break; > + } > + nend = nstart + ret * PAGE_SIZE; > + ret = 0; > + } > + /*if (locked) > + up_read(&mm->mmap_sem); > + */ > + return ret; /* 0 or negative error code */ > +} > + > /** > * get_dump_page() - pin user page in memory while writing it to core dump > * @addr: user address > diff --git a/mm/memory.c b/mm/memory.c > index db86558791f1..2b28766e4213 100644 > --- a/mm/memory.c > +++ b/mm/memory.c > @@ -83,6 +83,9 @@ > #include <asm/tlb.h> > #include <asm/tlbflush.h> > > +static bool tfork_one_pte_table(struct mm_struct *, pmd_t *, unsigned long, unsigned long); > +static inline void init_rss_vec(int *rss); > +static inline void add_mm_rss_vec(struct mm_struct *mm, int *rss); > #include "pgalloc-track.h" > #include "internal.h" > > @@ -227,7 +230,16 @@ static void free_pte_range(struct mmu_gather *tlb, pmd_t *pmd, > unsigned long addr) > { > pgtable_t token = pmd_pgtable(*pmd); > + long counter; > pmd_clear(pmd); > + counter = atomic64_read(&(token->pte_table_refcount)); > + if (counter > 0) { > + //the pte table can only be shared in this case > +#ifdef CONFIG_DEBUG_VM > + printk("free_pte_range: addr=%lx, counter=%ld, not freeing table", addr, counter); > +#endif > + return; //pte table is still in use > + } > pte_free_tlb(tlb, token, addr); > mm_dec_nr_ptes(tlb->mm); > } > @@ -433,6 +445,118 @@ void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *vma, > } > } > > +// frees every page described by the pte table > +void zap_one_pte_table(pmd_t pmd_val, unsigned long addr, struct mm_struct *mm) > +{ > + int rss[NR_MM_COUNTERS]; > + pte_t *pte; > + unsigned long end; > + > + init_rss_vec(rss); > + addr = pte_table_start(addr); > + end = pte_table_end(addr); > + pte = tfork_pte_offset_kernel(pmd_val, addr); > + do { > + pte_t ptent = *pte; > + > + if (pte_none(ptent)) > + continue; > + > + if (pte_present(ptent)) { > + struct page *page; > + > + if (pte_special(ptent)) { //known special pte: vvar VMA, which has just one page shared system-wide. 
Shouldn't matter > + continue; > + } > + page = vm_normal_page(NULL, addr, ptent); //kyz : vma is not important > + if (unlikely(!page)) > + continue; > + rss[mm_counter(page)]--; > +#ifdef CONFIG_DEBUG_VM > + // printk("zap_one_pte_table: addr=%lx, end=%lx, (before) mapcount=%d, refcount=%d\n", addr, end, page_mapcount(page), page_ref_count(page)); > +#endif > + page_remove_rmap(page, false); > + put_page(page); > + } > + } while (pte++, addr += PAGE_SIZE, addr != end); > + > + add_mm_rss_vec(mm, rss); > +} > + > +/* pmd lock should be held > + * returns 1 if the table becomes unused > + */ > +int dereference_pte_table(pmd_t pmd_val, bool free_table, struct mm_struct *mm, unsigned long addr) > +{ > + struct page *table_page; > + > + table_page = pmd_page(pmd_val); > + > + if (atomic64_dec_and_test(&(table_page->pte_table_refcount))) { > +#ifdef CONFIG_DEBUG_VM > + printk("dereference_pte_table: addr=%lx, free_table=%d, pte table reached end of life\n", addr, free_table); > +#endif > + > + zap_one_pte_table(pmd_val, addr, mm); > + if (free_table) { > + pgtable_pte_page_dtor(table_page); > + __free_page(table_page); > + mm_dec_nr_ptes(mm); > + } > + return 1; > + } else { > +#ifdef CONFIG_DEBUG_VM > + printk("dereference_pte_table: addr=%lx, (after) pte_table_count=%lld\n", addr, atomic64_read(&(table_page->pte_table_refcount))); > +#endif > + } > + return 0; > +} > + > +int dereference_pte_table_multiple(pmd_t pmd_val, bool free_table, struct mm_struct *mm, unsigned long addr, int num) > +{ > + struct page *table_page; > + int count_after; > + > + table_page = pmd_page(pmd_val); > + count_after = atomic64_sub_return(num, &(table_page->pte_table_refcount)); > + if (count_after <= 0) { > +#ifdef CONFIG_DEBUG_VM > + printk("dereference_pte_table_multiple: addr=%lx, free_table=%d, num=%d, after count=%d, table reached end of life\n", addr, free_table, num, count_after); > +#endif > + > + zap_one_pte_table(pmd_val, addr, mm); > + if (free_table) { > + pgtable_pte_page_dtor(table_page); > + __free_page(table_page); > + mm_dec_nr_ptes(mm); > + } > + return 1; > + } else { > +#ifdef CONFIG_DEBUG_VM > + printk("dereference_pte_table_multiple: addr=%lx, num=%d, (after) count=%lld\n", addr, num, atomic64_read(&(table_page->pte_table_refcount))); > +#endif > + } > + return 0; > +} > + > +int __tfork_pte_alloc(struct mm_struct *mm, pmd_t *pmd) > +{ > + pgtable_t new = pte_alloc_one(mm); > + > + if (!new) > + return -ENOMEM; > + smp_wmb(); /* Could be smp_wmb__xxx(before|after)_spin_lock */ > + > + mm_inc_nr_ptes(mm); > + //kyz: won't check if the pte table already exists > + pmd_populate(mm, pmd, new); > + new = NULL; > + if (new) > + pte_free(mm, new); > + return 0; > +} > + > + > int __pte_alloc(struct mm_struct *mm, pmd_t *pmd) > { > spinlock_t *ptl; > @@ -928,6 +1052,45 @@ copy_present_page(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma > return 0; > } > > +static inline unsigned long > +copy_one_pte_tfork(struct mm_struct *dst_mm, > + pte_t *dst_pte, pte_t *src_pte, struct vm_area_struct *vma, > + unsigned long addr, int *rss) > +{ > + unsigned long vm_flags = vma->vm_flags; > + pte_t pte = *src_pte; > + struct page *page; > + > + /* > + * If it's a COW mapping > + * only protect in the child (the faulting process) > + */ > + if (is_cow_mapping(vm_flags) && pte_write(pte)) { > + pte = pte_wrprotect(pte); > + } > + > + /* > + * If it's a shared mapping, mark it clean in > + * the child > + */ > + if (vm_flags & VM_SHARED) > + pte = pte_mkclean(pte); > + pte = pte_mkold(pte); > + > + 
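
The wrprotect-only-in-the-faulting-process logic above is the part of the
COW semantics I would most want a test for. A userspace smoke test against
an ODF-enabled process, to check that child writes never leak into the
parent, could look like this (sketch, not part of the patch, and it assumes
use_odf was already set to 1 for the parent):

	#include <assert.h>
	#include <string.h>
	#include <sys/mman.h>
	#include <sys/wait.h>
	#include <unistd.h>

	int main(void)
	{
		size_t len = 64UL << 20;	/* large enough to span many pte tables */
		char *buf = mmap(NULL, len, PROT_READ | PROT_WRITE,
				 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);

		assert(buf != MAP_FAILED);
		memset(buf, 0xaa, len);		/* populate in the parent */

		if (fork() == 0) {
			memset(buf, 0x55, len);	/* child writes go through the ODF copy path */
			_exit(0);
		}
		wait(NULL);

		for (size_t i = 0; i < len; i += 4096)
			assert(buf[i] == (char)0xaa);	/* parent must be unchanged */
		return 0;
	}

Something similar with the parent writing first would exercise the
pmd-level fault path in the other direction.
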
page = vm_normal_page(vma, addr, pte); > + if (page) { > + get_page(page); > + page_dup_rmap(page, false); > + rss[mm_counter(page)]++; > +#ifdef CONFIG_DEBUG_VM > +// printk("copy_one_pte_tfork: addr=%lx, (after) mapcount=%d, refcount=%d\n", addr, page_mapcount(page), page_ref_count(page)); > +#endif > + } > + > + set_pte_at(dst_mm, addr, dst_pte, pte); > + return 0; > +} > + > /* > * Copy one pte. Returns 0 if succeeded, or -EAGAIN if one preallocated page > * is required to copy this pte. > @@ -999,6 +1162,59 @@ page_copy_prealloc(struct mm_struct *src_mm, struct vm_area_struct *vma, > return new_page; > } > > +static int copy_pte_range_tfork(struct mm_struct *dst_mm, > + pmd_t *dst_pmd, pmd_t src_pmd_val, struct vm_area_struct *vma, > + unsigned long addr, unsigned long end) > +{ > + pte_t *orig_src_pte, *orig_dst_pte; > + pte_t *src_pte, *dst_pte; > + spinlock_t *dst_ptl; > + int rss[NR_MM_COUNTERS]; > + swp_entry_t entry = (swp_entry_t){0}; > + struct page *dst_pte_page; > + > + init_rss_vec(rss); > + > + src_pte = tfork_pte_offset_kernel(src_pmd_val, addr); //src_pte points to the old table > + if (!pmd_iswrite(*dst_pmd)) { > + dst_pte = tfork_pte_alloc_map_lock(dst_mm, dst_pmd, addr, &dst_ptl); //dst_pte points to a new table > +#ifdef CONFIG_DEBUG_VM > + printk("copy_pte_range_tfork: allocated new table. addr=%lx, prev_table_page=%px, table_page=%px\n", addr, pmd_page(src_pmd_val), pmd_page(*dst_pmd)); > +#endif > + } else { > + dst_pte = pte_alloc_map_lock(dst_mm, dst_pmd, addr, &dst_ptl); > + } > + if (!dst_pte) > + return -ENOMEM; > + > + dst_pte_page = pmd_page(*dst_pmd); > + atomic64_inc(&(dst_pte_page->pte_table_refcount)); //kyz: associates the VMA with the new table > +#ifdef CONFIG_DEBUG_VM > + printk("copy_pte_range_tfork: addr = %lx, end = %lx, new pte table counter (after)=%lld\n", addr, end, atomic64_read(&(dst_pte_page->pte_table_refcount))); > +#endif > + > + orig_src_pte = src_pte; > + orig_dst_pte = dst_pte; > + arch_enter_lazy_mmu_mode(); > + > + do { > + if (pte_none(*src_pte)) { > + continue; > + } > + entry.val = copy_one_pte_tfork(dst_mm, dst_pte, src_pte, > + vma, addr, rss); > + if (entry.val) > + printk("kyz: failed copy_one_pte_tfork call\n"); > + } while (dst_pte++, src_pte++, addr += PAGE_SIZE, addr != end); > + > + arch_leave_lazy_mmu_mode(); > + pte_unmap(orig_src_pte); > + add_mm_rss_vec(dst_mm, rss); > + pte_unmap_unlock(orig_dst_pte, dst_ptl); > + > + return 0; > +} > + > static int > copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma, > pmd_t *dst_pmd, pmd_t *src_pmd, unsigned long addr, > @@ -1130,8 +1346,9 @@ copy_pmd_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma, > { > struct mm_struct *dst_mm = dst_vma->vm_mm; > struct mm_struct *src_mm = src_vma->vm_mm; > - pmd_t *src_pmd, *dst_pmd; > + pmd_t *src_pmd, *dst_pmd, src_pmd_value; > unsigned long next; > + struct page *table_page; > > dst_pmd = pmd_alloc(dst_mm, dst_pud, addr); > if (!dst_pmd) > @@ -1153,9 +1370,43 @@ copy_pmd_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma, > } > if (pmd_none_or_clear_bad(src_pmd)) > continue; > - if (copy_pte_range(dst_vma, src_vma, dst_pmd, src_pmd, > - addr, next)) > - return -ENOMEM; > + if (src_mm->flags & MMF_USE_ODF_MASK) { > +#ifdef CONFIG_DEBUG_VM > + printk("copy_pmd_range: vm_start=%lx, addr=%lx, vm_end=%lx, end=%lx\n", src_vma->vm_start, addr, src_vma->vm_end, end); > +#endif > + > + src_pmd_value = *src_pmd; > + //kyz: sets write-protect to the pmd entry if the vma is writable > 
+ if (src_vma->vm_flags & VM_WRITE) { > + src_pmd_value = pmd_wrprotect(src_pmd_value); > + set_pmd_at(src_mm, addr, src_pmd, src_pmd_value); > + } > + table_page = pmd_page(*src_pmd); > + if (src_vma->pte_table_counter_pending) { // kyz : the old VMA hasn't been counted in the PTE table, count it now > + atomic64_add(2, &(table_page->pte_table_refcount)); > +#ifdef CONFIG_DEBUG_VM > + printk("copy_pmd_range: addr=%lx, pte table counter (after counting old&new)=%lld\n", addr, atomic64_read(&(table_page->pte_table_refcount))); > +#endif > + } else { > + atomic64_inc(&(table_page->pte_table_refcount)); //increments the pte table counter > + if (atomic64_read(&(table_page->pte_table_refcount)) == 1) { //the VMA is old, but the pte table is new (created by a fault after the last odf call) > + atomic64_set(&(table_page->pte_table_refcount), 2); > +#ifdef CONFIG_DEBUG_VM > + printk("copy_pmd_range: addr=%lx, pte table counter (old VMA, new pte table)=%lld\n", addr, atomic64_read(&(table_page->pte_table_refcount))); > +#endif > + } > +#ifdef CONFIG_DEBUG_VM > + else { > + printk("copy_pmd_range: addr=%lx, pte table counter (after counting new)=%lld\n", addr, atomic64_read(&(table_page->pte_table_refcount))); > + } > +#endif > + } > + set_pmd_at(dst_mm, addr, dst_pmd, src_pmd_value); //shares the table with the child > + } else { > + if (copy_pte_range(dst_vma, src_vma, dst_pmd, src_pmd, > + addr, next)) > + return -ENOMEM; > + } > } while (dst_pmd++, src_pmd++, addr = next, addr != end); > return 0; > } > @@ -1240,9 +1491,10 @@ copy_page_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma) > * readonly mappings. The tradeoff is that copy_page_range is more > * efficient than faulting. > */ > - if (!(src_vma->vm_flags & (VM_HUGETLB | VM_PFNMAP | VM_MIXEDMAP)) && > +/* if (!(src_vma->vm_flags & (VM_HUGETLB | VM_PFNMAP | VM_MIXEDMAP)) && > !src_vma->anon_vma) > return 0; > +*/ > > if (is_vm_hugetlb_page(src_vma)) > return copy_hugetlb_page_range(dst_mm, src_mm, src_vma); > @@ -1304,7 +1556,7 @@ copy_page_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma) > static unsigned long zap_pte_range(struct mmu_gather *tlb, > struct vm_area_struct *vma, pmd_t *pmd, > unsigned long addr, unsigned long end, > - struct zap_details *details) > + struct zap_details *details, bool invalidate_pmd) > { > struct mm_struct *mm = tlb->mm; > int force_flush = 0; > @@ -1343,8 +1595,10 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb, > details->check_mapping != page_rmapping(page)) > continue; > } > - ptent = ptep_get_and_clear_full(mm, addr, pte, > - tlb->fullmm); > + if (!invalidate_pmd) { > + ptent = ptep_get_and_clear_full(mm, addr, pte, > + tlb->fullmm); > + } > tlb_remove_tlb_entry(tlb, pte, addr); > if (unlikely(!page)) > continue; > @@ -1358,8 +1612,12 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb, > likely(!(vma->vm_flags & VM_SEQ_READ))) > mark_page_accessed(page); > } > - rss[mm_counter(page)]--; > - page_remove_rmap(page, false); > + if (!invalidate_pmd) { > + rss[mm_counter(page)]--; > + page_remove_rmap(page, false); > + } else { > + continue; > + } > if (unlikely(page_mapcount(page) < 0)) > print_bad_pte(vma, addr, ptent, page); > if (unlikely(__tlb_remove_page(tlb, page))) { > @@ -1446,12 +1704,16 @@ static inline unsigned long zap_pmd_range(struct mmu_gather *tlb, > struct zap_details *details) > { > pmd_t *pmd; > - unsigned long next; > + unsigned long next, table_start, table_end; > + spinlock_t *ptl; > + struct page *table_page; > + bool 
got_new_table = false; > > pmd = pmd_offset(pud, addr); > do { > + ptl = pmd_lock(vma->vm_mm, pmd); > next = pmd_addr_end(addr, end); > - if (is_swap_pmd(*pmd) || pmd_trans_huge(*pmd) || pmd_devmap(*pmd)) { > + if (pmd_trans_huge(*pmd) || pmd_devmap(*pmd)) { > if (next - addr != HPAGE_PMD_SIZE) > __split_huge_pmd(vma, pmd, addr, false, NULL); > else if (zap_huge_pmd(tlb, vma, pmd, addr)) > @@ -1478,8 +1740,49 @@ static inline unsigned long zap_pmd_range(struct mmu_gather *tlb, > */ > if (pmd_none_or_trans_huge_or_clear_bad(pmd)) > goto next; > - next = zap_pte_range(tlb, vma, pmd, addr, next, details); > + //kyz: copy if the pte table is shared and VMA does not cover fully the 2MB region > + table_page = pmd_page(*pmd); > + table_start = pte_table_start(addr); > + > + if ((!pmd_iswrite(*pmd)) && (!vma->pte_table_counter_pending)) {//shared pte table. vma has gone through odf > + table_end = pte_table_end(addr); > + if (table_start < vma->vm_start || table_end > vma->vm_end) { > +#ifdef CONFIG_DEBUG_VM > + printk("%s: addr=%lx, end=%lx, table_start=%lx, table_end=%lx, copy then zap\n", __func__, addr, end, table_start, table_end); > +#endif > + if (dereference_pte_table(*pmd, false, vma->vm_mm, addr) != 1) { //dec the counter of the shared table. tfork_one_pte_table cannot find the current VMA (which is being unmapped) > + got_new_table = tfork_one_pte_table(vma->vm_mm, pmd, addr, vma->vm_end); > + if (got_new_table) { > + next = zap_pte_range(tlb, vma, pmd, addr, next, details, false); > + } else { > +#ifdef CONFIG_DEBUG_VM > + printk("zap_pmd_range: no more VMAs in this process are using the table, but there are other processes using it\n"); > +#endif > + pmd_clear(pmd); > + } > + } else { > +#ifdef CONFIG_DEBUG_VM > + printk("zap_pmd_range: the shared table is dead. NOT copying after all.\n"); > +#endif > + // the shared table will be freed by unmap_single_vma() > + } > + } else { > +#ifdef CONFIG_DEBUG_VM > + printk("%s: addr=%lx, end=%lx, table_start=%lx, table_end=%lx, zap while preserving pte entries\n", __func__, addr, end, table_start, table_end); > +#endif > + //kyz: shared and fully covered by the VMA, preserve the pte entries > + next = zap_pte_range(tlb, vma, pmd, addr, next, details, true); > + dereference_pte_table(*pmd, true, vma->vm_mm, addr); > + pmd_clear(pmd); > + } > + } else { > + next = zap_pte_range(tlb, vma, pmd, addr, next, details, false); > + if (!vma->pte_table_counter_pending) { > + atomic64_dec(&(table_page->pte_table_refcount)); > + } > + } > next: > + spin_unlock(ptl); > cond_resched(); > } while (pmd++, addr = next, addr != end); > > @@ -4476,6 +4779,66 @@ static vm_fault_t wp_huge_pud(struct vm_fault *vmf, pud_t orig_pud) > return VM_FAULT_FALLBACK; > } > > +/* kyz: Handles an entire pte-level page table, covering multiple VMAs (if they exist) > + * Returns true if a new table is put in place, false otherwise. 
> + * if exclude is not 0, the vma that covers addr to exclude will not be copied > + */ > +static bool tfork_one_pte_table(struct mm_struct *mm, pmd_t *dst_pmd, unsigned long addr, unsigned long exclude) > +{ > + unsigned long table_end, end, orig_addr; > + struct vm_area_struct *vma; > + pmd_t orig_pmd_val; > + bool copied = false; > + struct page *orig_pte_page; > + int num_vmas = 0; > + > + if (!pmd_none(*dst_pmd)) { > + orig_pmd_val = *dst_pmd; > + } else { > + BUG(); > + } > + > + //kyz: Starts from the beginning of the range covered by the table > + orig_addr = addr; > + table_end = pte_table_end(addr); > + addr = pte_table_start(addr); > +#ifdef CONFIG_DEBUG_VM > + orig_pte_page = pmd_page(orig_pmd_val); > + printk("tfork_one_pte_table: shared pte table counter=%lld, Covered Range: start=%lx, end=%lx\n", atomic64_read(&(orig_pte_page->pte_table_refcount)), addr, table_end); > +#endif > + do { > + vma = find_vma(mm, addr); > + if (!vma) { > + break; //inexplicable > + } > + if (vma->vm_start >= table_end) { > + break; > + } > + end = pmd_addr_end(addr, vma->vm_end); > + if (vma->pte_table_counter_pending) { //this vma is newly mapped (clean) and (fully/partly) described by this pte table > + addr = end; > + continue; > + } > + if (vma->vm_start > addr) { > + addr = vma->vm_start; > + } > + if (exclude > 0 && vma->vm_start <= orig_addr && vma->vm_end >= exclude) { > + addr = end; > + continue; > + } > +#ifdef CONFIG_DEBUG_VM > + printk("tfork_one_pte_table: vm_start=%lx, vm_end=%lx\n", vma->vm_start, vma->vm_end); > +#endif > + num_vmas++; > + copy_pte_range_tfork(mm, dst_pmd, orig_pmd_val, vma, addr, end); > + copied = true; > + addr = end; > + } while (addr < table_end); > + > + dereference_pte_table_multiple(orig_pmd_val, true, mm, orig_addr, num_vmas); > + return copied; > +} > + > /* > * These routines also need to handle stuff like marking pages dirty > * and/or accessed for architectures that don't do it in hardware (most > @@ -4610,6 +4973,7 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma, > pgd_t *pgd; > p4d_t *p4d; > vm_fault_t ret; > + spinlock_t *ptl; > > pgd = pgd_offset(mm, address); > p4d = p4d_alloc(mm, pgd, address); > @@ -4659,6 +5023,7 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma, > vmf.orig_pmd = *vmf.pmd; > > barrier(); > + /* > if (unlikely(is_swap_pmd(vmf.orig_pmd))) { > VM_BUG_ON(thp_migration_supported() && > !is_pmd_migration_entry(vmf.orig_pmd)); > @@ -4666,6 +5031,7 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma, > pmd_migration_entry_wait(mm, vmf.pmd); > return 0; > } > + */ > if (pmd_trans_huge(vmf.orig_pmd) || pmd_devmap(vmf.orig_pmd)) { > if (pmd_protnone(vmf.orig_pmd) && vma_is_accessible(vma)) > return do_huge_pmd_numa_page(&vmf); > @@ -4679,6 +5045,15 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma, > return 0; > } > } > + //kyz: checks if the pmd entry prohibits writes > + if ((!pmd_none(vmf.orig_pmd)) && (!pmd_iswrite(vmf.orig_pmd)) && (vma->vm_flags & VM_WRITE)) { > +#ifdef CONFIG_DEBUG_VM > + printk("__handle_mm_fault: PID=%d, addr=%lx\n", current->pid, address); > +#endif > + ptl = pmd_lock(mm, vmf.pmd); > + tfork_one_pte_table(mm, vmf.pmd, vmf.address, 0u); > + spin_unlock(ptl); > + } > } > > return handle_pte_fault(&vmf); > diff --git a/mm/mmap.c b/mm/mmap.c > index ca54d36d203a..308d86cfe544 100644 > --- a/mm/mmap.c > +++ b/mm/mmap.c > @@ -47,6 +47,7 @@ > #include <linux/pkeys.h> > #include <linux/oom.h> > #include <linux/sched/mm.h> > +#include <linux/pagewalk.h> > 
> #include <linux/uaccess.h> > #include <asm/cacheflush.h> > @@ -276,6 +277,9 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) > > success: > populate = newbrk > oldbrk && (mm->def_flags & VM_LOCKED) != 0; > + if (mm->flags & MMF_USE_ODF_MASK) { //for ODF > + populate = true; > + } > if (downgraded) > mmap_read_unlock(mm); > else > @@ -1115,6 +1119,50 @@ can_vma_merge_after(struct vm_area_struct *vma, unsigned long vm_flags, > return 0; > } > > +static int pgtable_counter_fixup_pmd_entry(pmd_t *pmd, unsigned long addr, > + unsigned long next, struct mm_walk *walk) > +{ > + struct page *table_page; > + > + table_page = pmd_page(*pmd); > + atomic64_inc(&(table_page->pte_table_refcount)); > + > +#ifdef CONFIG_DEBUG_VM > + printk("fixup inc: addr=%lx\n", addr); > +#endif > + > + walk->action = ACTION_CONTINUE; //skip pte level > + return 0; > +} > + > +static int pgtable_counter_fixup_test(unsigned long addr, unsigned long next, > + struct mm_walk *walk) > +{ > + return 0; > +} > + > +static const struct mm_walk_ops pgtable_counter_fixup_walk_ops = { > +.pmd_entry = pgtable_counter_fixup_pmd_entry, > +.test_walk = pgtable_counter_fixup_test > +}; > + > +int merge_vma_pgtable_counter_fixup(struct vm_area_struct *vma, unsigned long start, unsigned long end) > +{ > + if (vma->pte_table_counter_pending) { > + return 0; > + } else { > +#ifdef CONFIG_DEBUG_VM > + printk("merge fixup: vm_start=%lx, vm_end=%lx, inc start=%lx, inc end=%lx\n", vma->vm_start, vma->vm_end, start, end); > +#endif > + start = pte_table_end(start); > + end = pte_table_start(end); > + __mm_populate_nolock(start, end-start, 1); //popuate tables for extended address range so that we can increment counters > + walk_page_range(vma->vm_mm, start, end, &pgtable_counter_fixup_walk_ops, NULL); > + } > + > + return 0; > +} > + > /* > * Given a mapping request (addr,end,vm_flags,file,pgoff), figure out > * whether that can be merged with its predecessor or its successor. 
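
A minor comment on pgtable_counter_fixup_walk_ops above: as far as I can
tell from mm/pagewalk.c, walk_pmd_range() only descends to the pte level
when a .pte_entry callback is present, so the walk->action = ACTION_CONTINUE
should be redundant here, and the always-0 .test_walk only differs from
omitting it in that it also walks VM_PFNMAP vmas (which the default test
skips). Unless one of those is intentional, the ops could probably shrink
to something like (sketch, debug printk dropped):

static int pgtable_counter_fixup_pmd_entry(pmd_t *pmd, unsigned long addr,
					   unsigned long next,
					   struct mm_walk *walk)
{
	/* one more vma now references the pte table mapped by this pmd */
	atomic64_inc(&pmd_page(*pmd)->pte_table_refcount);
	return 0;
}

static const struct mm_walk_ops pgtable_counter_fixup_walk_ops = {
	.pmd_entry = pgtable_counter_fixup_pmd_entry,
};
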
> @@ -1215,6 +1263,9 @@ struct vm_area_struct *vma_merge(struct mm_struct *mm, > if (err) > return NULL; > khugepaged_enter_vma_merge(prev, vm_flags); > + > + merge_vma_pgtable_counter_fixup(prev, addr, end); > + > return prev; > } > > @@ -1242,6 +1293,9 @@ struct vm_area_struct *vma_merge(struct mm_struct *mm, > if (err) > return NULL; > khugepaged_enter_vma_merge(area, vm_flags); > + > + merge_vma_pgtable_counter_fixup(area, addr, end); > + > return area; > } > > @@ -1584,8 +1638,15 @@ unsigned long do_mmap(struct file *file, unsigned long addr, > addr = mmap_region(file, addr, len, vm_flags, pgoff, uf); > if (!IS_ERR_VALUE(addr) && > ((vm_flags & VM_LOCKED) || > - (flags & (MAP_POPULATE | MAP_NONBLOCK)) == MAP_POPULATE)) > + (flags & (MAP_POPULATE | MAP_NONBLOCK)) == MAP_POPULATE || > + (mm->flags & MMF_USE_ODF_MASK))) { > +#ifdef CONFIG_DEBUG_VM > + if (mm->flags & MMF_USE_ODF_MASK) { > + printk("mmap: force populate, addr=%lx, len=%lx\n", addr, len); > + } > +#endif > *populate = len; > + } > return addr; > } > > @@ -2799,6 +2860,31 @@ int split_vma(struct mm_struct *mm, struct vm_area_struct *vma, > return __split_vma(mm, vma, addr, new_below); > } > > +/* left and right vma after the split, address of split */ > +int split_vma_pgtable_counter_fixup(struct vm_area_struct *lvma, struct vm_area_struct *rvma, bool orig_pending_flag) > +{ > + if (orig_pending_flag) { > + return 0; //the new vma will have pending flag as true by default, just as the old vma > + } else { > +#ifdef CONFIG_DEBUG_VM > + printk("split fixup: set vma flag to false, rvma_start=%lx\n", rvma->vm_start); > +#endif > + lvma->pte_table_counter_pending = false; > + rvma->pte_table_counter_pending = false; > + > + if (pte_table_start(rvma->vm_start) == rvma->vm_start) { //the split was right at the pte table boundary > + return 0; //the only case where we don't increment pte table counter > + } else { > +#ifdef CONFIG_DEBUG_VM > + printk("split fixup: rvma_start=%lx\n", rvma->vm_start); > +#endif > + walk_page_range(rvma->vm_mm, pte_table_start(rvma->vm_start), pte_table_end(rvma->vm_start), &pgtable_counter_fixup_walk_ops, NULL); > + } > + } > + > + return 0; > +} > + > static inline void > unlock_range(struct vm_area_struct *start, unsigned long limit) > { > @@ -2869,6 +2955,8 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len, > if (error) > return error; > prev = vma; > + > + split_vma_pgtable_counter_fixup(prev, prev->vm_next, prev->pte_table_counter_pending); > } > > /* Does it split the last one? 
*/ > @@ -2877,6 +2965,7 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len, > int error = __split_vma(mm, last, end, 1); > if (error) > return error; > + split_vma_pgtable_counter_fixup(last->vm_prev, last, last->pte_table_counter_pending); > } > vma = vma_next(mm, prev); > > diff --git a/mm/mprotect.c b/mm/mprotect.c > index 4cb240fd9936..d396b1d38fab 100644 > --- a/mm/mprotect.c > +++ b/mm/mprotect.c > @@ -445,6 +445,8 @@ static const struct mm_walk_ops prot_none_walk_ops = { > .test_walk = prot_none_test, > }; > > +int split_vma_pgtable_counter_fixup(struct vm_area_struct *lvma, struct vm_area_struct *rvma, bool orig_pending_flag); > + > int > mprotect_fixup(struct vm_area_struct *vma, struct vm_area_struct **pprev, > unsigned long start, unsigned long end, unsigned long newflags) > @@ -517,12 +519,16 @@ mprotect_fixup(struct vm_area_struct *vma, struct vm_area_struct **pprev, > error = split_vma(mm, vma, start, 1); > if (error) > goto fail; > + > + split_vma_pgtable_counter_fixup(vma->vm_prev, vma, vma->pte_table_counter_pending); > } > > if (end != vma->vm_end) { > error = split_vma(mm, vma, end, 0); > if (error) > goto fail; > + > + split_vma_pgtable_counter_fixup(vma, vma->vm_next, vma->pte_table_counter_pending); > } > > success: > -- > 2.30.2 > >
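
To follow up on my own question above: from proc_use_odf_write() and the
MMF_INIT_MASK change it looks like the switch is per-mm, opt-in, inherited
across fork, and everyone else keeps the existing copy_page_range() path,
so an application would do something like (sketch, assuming the use_odf
name from this patch):

	#include <fcntl.h>
	#include <unistd.h>

	static void enable_odf(void)
	{
		int fd = open("/proc/self/use_odf", O_WRONLY);

		if (fd >= 0) {
			write(fd, "1", 1);	/* later fork()s share pte tables */
			close(fd);
		}
	}

If that is the intended interface, it would be good to document it in the
changelog. One thing that does not look opt-in, though, is that the
!anon_vma early return in copy_page_range() and the is_swap_pmd() handling
in __handle_mm_fault() are commented out for all processes, not just ODF
ones; is that deliberate for the prototype? Overall the numbers look very
interesting; I would also be curious how the design behaves once THP and
swapping support are added back.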