Add a hook in the fork and exec path to link mm_struct.
Reuse the mm_slot infrastructure to aid insertion and lookup of mm_struct.

CC: linux-fsdevel@xxxxxxxxxxxxxxx

Suggested-by: Bharata B Rao <bharata@xxxxxxx>
Signed-off-by: Raghavendra K T <raghavendra.kt@xxxxxxx>
---
 fs/exec.c                |   4 ++
 include/linux/kmmscand.h |  36 ++++++++++++++
 kernel/fork.c            |   4 ++
 mm/kmmscand.c            | 107 ++++++++++++++++++++++++++++++++++++++-
 4 files changed, 150 insertions(+), 1 deletion(-)
 create mode 100644 include/linux/kmmscand.h

diff --git a/fs/exec.c b/fs/exec.c
index 98cb7ba9983c..bd72107b2ab1 100644
--- a/fs/exec.c
+++ b/fs/exec.c
@@ -68,6 +68,7 @@
 #include <linux/user_events.h>
 #include <linux/rseq.h>
 #include <linux/ksm.h>
+#include <linux/kmmscand.h>
 
 #include <linux/uaccess.h>
 #include <asm/mmu_context.h>
@@ -274,6 +275,8 @@ static int __bprm_mm_init(struct linux_binprm *bprm)
 	if (err)
 		goto err_ksm;
 
+	kmmscand_execve(mm);
+
 	/*
 	 * Place the stack at the largest stack address the architecture
 	 * supports. Later, we'll move this to an appropriate place. We don't
@@ -296,6 +299,7 @@ static int __bprm_mm_init(struct linux_binprm *bprm)
 	return 0;
 err:
 	ksm_exit(mm);
+	kmmscand_exit(mm);
 err_ksm:
 	mmap_write_unlock(mm);
 err_free:
diff --git a/include/linux/kmmscand.h b/include/linux/kmmscand.h
new file mode 100644
index 000000000000..b120c65ee8c6
--- /dev/null
+++ b/include/linux/kmmscand.h
@@ -0,0 +1,36 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+#ifndef _LINUX_KMMSCAND_H_
+#define _LINUX_KMMSCAND_H_
+
+/* Forward declaration keeps this header self-contained. */
+struct mm_struct;
+
+#ifdef CONFIG_KMMSCAND
+extern void __kmmscand_enter(struct mm_struct *mm);
+extern void __kmmscand_exit(struct mm_struct *mm);
+
+/* Called from __bprm_mm_init() to start tracking the new exec mm. */
+static inline void kmmscand_execve(struct mm_struct *mm)
+{
+	__kmmscand_enter(mm);
+}
+
+/* Called from dup_mmap() to start tracking the child mm; @oldmm is unused. */
+static inline void kmmscand_fork(struct mm_struct *mm, struct mm_struct *oldmm)
+{
+	__kmmscand_enter(mm);
+}
+
+/* Called from __mmput() and the __bprm_mm_init() error path to untrack @mm. */
+static inline void kmmscand_exit(struct mm_struct *mm)
+{
+	__kmmscand_exit(mm);
+}
+#else /* !CONFIG_KMMSCAND */
+static inline void __kmmscand_enter(struct mm_struct *mm) {}
+static inline void __kmmscand_exit(struct mm_struct *mm) {}
+static inline void kmmscand_execve(struct mm_struct *mm) {}
+static inline void kmmscand_fork(struct mm_struct *mm, struct mm_struct *oldmm) {}
+static inline void kmmscand_exit(struct mm_struct *mm) {}
+#endif
+#endif /* _LINUX_KMMSCAND_H_ */
diff --git a/kernel/fork.c b/kernel/fork.c
index 1450b461d196..812b0032592e 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -85,6 +85,7 @@
 #include <linux/user-return-notifier.h>
 #include <linux/oom.h>
 #include <linux/khugepaged.h>
+#include <linux/kmmscand.h>
 #include <linux/signalfd.h>
 #include <linux/uprobes.h>
 #include <linux/aio.h>
@@ -659,6 +660,8 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,
 	mm->exec_vm = oldmm->exec_vm;
 	mm->stack_vm = oldmm->stack_vm;
 
+	kmmscand_fork(mm, oldmm);
+
 	/* Use __mt_dup() to efficiently build an identical maple tree. */
 	retval = __mt_dup(&oldmm->mm_mt, &mm->mm_mt, GFP_KERNEL);
 	if (unlikely(retval))
@@ -1350,6 +1353,7 @@ static inline void __mmput(struct mm_struct *mm)
 	exit_aio(mm);
 	ksm_exit(mm);
 	khugepaged_exit(mm); /* must run before exit_mmap */
+	kmmscand_exit(mm);
 	exit_mmap(mm);
 	mm_put_huge_zero_folio(mm);
 	set_mm_exe_file(mm, NULL);
diff --git a/mm/kmmscand.c b/mm/kmmscand.c
index 23cf5638fe10..957128d4e425 100644
--- a/mm/kmmscand.c
+++ b/mm/kmmscand.c
@@ -7,13 +7,14 @@
 #include <linux/swap.h>
 #include <linux/mm_inline.h>
 #include <linux/kthread.h>
+#include <linux/kmmscand.h>
 #include <linux/string.h>
 #include <linux/delay.h>
 #include <linux/cleanup.h>
 
 #include <asm/pgalloc.h>
 #include "internal.h"
-
+#include "mm_slot.h"
 
 static struct task_struct *kmmscand_thread __read_mostly;
 static DEFINE_MUTEX(kmmscand_mutex);
@@ -30,10 +31,21 @@ static bool need_wakeup;
 
 static unsigned long kmmscand_sleep_expire;
 
+static DEFINE_SPINLOCK(kmmscand_mm_lock);
 static DECLARE_WAIT_QUEUE_HEAD(kmmscand_wait);
 
+#define KMMSCAND_SLOT_HASH_BITS 10
+static DEFINE_READ_MOSTLY_HASHTABLE(kmmscand_slots_hash, KMMSCAND_SLOT_HASH_BITS);
+
+static struct kmem_cache *kmmscand_slot_cache __read_mostly;
+
+struct kmmscand_mm_slot {
+	struct mm_slot slot;
+};
+
 struct kmmscand_scan {
 	struct list_head mm_head;
+	struct kmmscand_mm_slot *mm_slot;
 };
 
 struct kmmscand_scan kmmscand_scan = {
@@ -76,6 +88,12 @@ static void kmmscand_migrate_folio(void)
 {
 }
 
+/* Return non-zero once @mm has no users left (mm_users == 0). */
+static inline int kmmscand_test_exit(struct mm_struct *mm)
+{
+	return atomic_read(&mm->mm_users) == 0;
+}
+
 static unsigned long kmmscand_scan_mm_slot(void)
 {
 	/* placeholder for scanning */
@@ -123,6 +141,85 @@ static int kmmscand(void *none)
 	return 0;
 }
 
+/* Release the slab cache backing struct kmmscand_mm_slot. */
+static inline void kmmscand_destroy(void)
+{
+	kmem_cache_destroy(kmmscand_slot_cache);
+}
+
+/*
+ * Start tracking @mm: allocate a slot, insert it into the slot hash and at
+ * the tail of the scan list, then pin the mm_struct with mmgrab(). Waiters
+ * on kmmscand_wait are woken if the scan list was empty. If the slot
+ * allocation fails, @mm is silently left untracked.
+ */
+void __kmmscand_enter(struct mm_struct *mm)
+{
+	struct kmmscand_mm_slot *kmmscand_slot;
+	struct mm_slot *slot;
+	int wakeup;
+
+	/* __kmmscand_exit() must not run from under us */
+	VM_BUG_ON_MM(kmmscand_test_exit(mm), mm);
+
+	kmmscand_slot = mm_slot_alloc(kmmscand_slot_cache);
+
+	if (!kmmscand_slot)
+		return;
+
+	slot = &kmmscand_slot->slot;
+
+	spin_lock(&kmmscand_mm_lock);
+	mm_slot_insert(kmmscand_slots_hash, mm, slot);
+
+	wakeup = list_empty(&kmmscand_scan.mm_head);
+	list_add_tail(&slot->mm_node, &kmmscand_scan.mm_head);
+	spin_unlock(&kmmscand_mm_lock);
+
+	mmgrab(mm);
+	if (wakeup)
+		wake_up_interruptible(&kmmscand_wait);
+}
+
+/*
+ * Stop tracking @mm. A slot that kmmscand_scan.mm_slot does not point at is
+ * unlinked and freed here, and the reference taken by __kmmscand_enter() is
+ * dropped. Otherwise the slot stays linked and mmap_lock is only cycled.
+ *
+ * NOTE(review): __bprm_mm_init()'s error path calls this just before its
+ * mmap_write_unlock(); re-taking mmap_lock below would self-deadlock if the
+ * slot could be the scanner's current one at that point - confirm it cannot.
+ */
+void __kmmscand_exit(struct mm_struct *mm)
+{
+	struct kmmscand_mm_slot *mm_slot;
+	struct mm_slot *slot;
+	int free = 0;
+
+	spin_lock(&kmmscand_mm_lock);
+	slot = mm_slot_lookup(kmmscand_slots_hash, mm);
+	mm_slot = mm_slot_entry(slot, struct kmmscand_mm_slot, slot);
+	if (mm_slot && kmmscand_scan.mm_slot != mm_slot) {
+		hash_del(&slot->hash);
+		list_del(&slot->mm_node);
+		free = 1;
+	}
+
+	spin_unlock(&kmmscand_mm_lock);
+
+	if (free) {
+		mm_slot_free(kmmscand_slot_cache, mm_slot);
+		mmdrop(mm);
+	} else if (mm_slot) {
+		/*
+		 * Presumably waits for a scan in progress on this mm to
+		 * finish; the scanner is still a placeholder - TODO confirm.
+		 */
+		mmap_write_lock(mm);
+		mmap_write_unlock(mm);
+	}
+}
+
 static int start_kmmscand(void)
 {
 	int err = 0;
@@ -168,6 +265,13 @@ static int __init kmmscand_init(void)
 {
 	int err;
 
+	kmmscand_slot_cache = KMEM_CACHE(kmmscand_mm_slot, 0);
+
+	if (!kmmscand_slot_cache) {
+		pr_err("kmmscand: kmem_cache error\n");
+		return -ENOMEM;
+	}
+
 	err = start_kmmscand();
 	if (err)
 		goto err_kmmscand;
@@ -176,6 +280,7 @@ static int __init kmmscand_init(void)
 
 err_kmmscand:
 	stop_kmmscand();
+	kmmscand_destroy();
 
 	return err;
 }
-- 
2.39.3