This patch uses a BPF prog to bypass the default select_bad_process method and select a victim memcg when global OOM is invoked. Specifically, we iterate over root_mem_cgroup's children and select the root of the next iteration through __bpf_run_oom_policy(). This repeats until we finally find a leaf memcg in the last layer. Then we use oom_evaluate_task() to find a victim task in the selected memcg. If there is no suitable process to be killed in the memcg, we fall back to the default method. Suggested-by: Abel Wu <wuyun.abel@xxxxxxxxxxxxx> Signed-off-by: Chuyi Zhou <zhouchuyi@xxxxxxxxxxxxx> --- include/linux/memcontrol.h | 6 +++++ mm/memcontrol.c | 50 ++++++++++++++++++++++++++++++++++++++ mm/oom_kill.c | 17 +++++++++++++ 3 files changed, 73 insertions(+) diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h index 5818af8eca5a..7fedc2521c8b 100644 --- a/include/linux/memcontrol.h +++ b/include/linux/memcontrol.h @@ -1155,6 +1155,7 @@ unsigned long mem_cgroup_soft_limit_reclaim(pg_data_t *pgdat, int order, gfp_t gfp_mask, unsigned long *total_scanned); +struct mem_cgroup *select_victim_memcg(void); #else /* CONFIG_MEMCG */ #define MEM_CGROUP_ID_SHIFT 0 @@ -1588,6 +1589,11 @@ unsigned long mem_cgroup_soft_limit_reclaim(pg_data_t *pgdat, int order, { return 0; } + +static inline struct mem_cgroup *select_victim_memcg(void) +{ + return NULL; +} #endif /* CONFIG_MEMCG */ static inline void __inc_lruvec_kmem_state(void *p, enum node_stat_item idx) diff --git a/mm/memcontrol.c b/mm/memcontrol.c index e8ca4bdcb03c..c6b42635f1af 100644 --- a/mm/memcontrol.c +++ b/mm/memcontrol.c @@ -64,6 +64,7 @@ #include <linux/psi.h> #include <linux/seq_buf.h> #include <linux/sched/isolation.h> +#include <linux/bpf_oom.h> #include "internal.h" #include <net/sock.h> #include <net/ip.h> @@ -2638,6 +2639,55 @@ void mem_cgroup_handle_over_high(void) css_put(&memcg->css); } +struct mem_cgroup *select_victim_memcg(void) +{ + struct cgroup_subsys_state *pos, *parent, *victim; + struct mem_cgroup 
*victim_memcg; + + parent = &root_mem_cgroup->css; + victim_memcg = NULL; + + if (!cgroup_subsys_on_dfl(memory_cgrp_subsys)) + return NULL; + + rcu_read_lock(); + while (parent) { + struct cgroup_subsys_state *chosen = NULL; + struct mem_cgroup *pos_mem, *chosen_mem; + u64 chosen_id, pos_id; + int cmp_ret; + + victim = parent; + + list_for_each_entry_rcu(pos, &parent->children, sibling) { + pos_id = cgroup_id(pos->cgroup); + if (!chosen) + goto chose; + + cmp_ret = __bpf_run_oom_policy(chosen_id, pos_id); + if (cmp_ret == BPF_OOM_CMP_GREATER) + continue; + if (cmp_ret == BPF_OOM_CMP_EQUAL) { + pos_mem = mem_cgroup_from_css(pos); + chosen_mem = mem_cgroup_from_css(chosen); + if (page_counter_read(&pos_mem->memory) <= + page_counter_read(&chosen_mem->memory)) + continue; + } +chose: + chosen = pos; + chosen_id = pos_id; + } + parent = chosen; + } + + if (victim && css_tryget(victim)) + victim_memcg = mem_cgroup_from_css(victim); + rcu_read_unlock(); + + return victim_memcg; +} + static int try_charge_memcg(struct mem_cgroup *memcg, gfp_t gfp_mask, unsigned int nr_pages) { diff --git a/mm/oom_kill.c b/mm/oom_kill.c index 01af8adaa16c..b88c8c7d4ee4 100644 --- a/mm/oom_kill.c +++ b/mm/oom_kill.c @@ -361,6 +361,19 @@ static int oom_evaluate_task(struct task_struct *task, void *arg) return 1; } +static bool bpf_select_bad_process(struct oom_control *oc) +{ + struct mem_cgroup *victim_memcg; + + victim_memcg = select_victim_memcg(); + if (victim_memcg) { + mem_cgroup_scan_tasks(victim_memcg, oom_evaluate_task, oc); + css_put(&victim_memcg->css); + } + + return !!oc->chosen; +} + /* * Simple selection loop. We choose the process with the highest number of * 'points'. In case scan was aborted, oc->chosen is set to -1. 
@@ -372,6 +385,9 @@ static void select_bad_process(struct oom_control *oc) if (is_memcg_oom(oc)) mem_cgroup_scan_tasks(oc->memcg, oom_evaluate_task, oc); else { + if (bpf_oom_policy_enabled() && bpf_select_bad_process(oc)) + return; + struct task_struct *p; rcu_read_lock(); @@ -1426,3 +1442,4 @@ bool bpf_oom_policy_enabled(void) rcu_read_unlock(); return !empty; } + -- 2.20.1