On Thu, Feb 04, 2021 at 03:48:29PM -0800, Yonghong Song wrote:
> The bpf_for_each_map_elem() helper is introduced which
> iterates all map elements with a callback function. The
> helper signature looks like
> long bpf_for_each_map_elem(map, callback_fn, callback_ctx, flags)
> and for each map element, the callback_fn will be called. For example,
> like hashmap, the callback signature may look like
> long callback_fn(map, key, val, callback_ctx)
>
> There are two known use cases for this. One is from upstream ([1]) where
> a for_each_map_elem helper may help implement a timeout mechanism
> in a more generic way. Another is from our internal discussion
> for a firewall use case where a map contains all the rules. The packet
> data can be compared to all these rules to decide allow or deny
> the packet.
>
> For array maps, users can already use a bounded loop to traverse
> elements. Using this helper can avoid using bounded loop. For other
> type of maps (e.g., hash maps) where bounded loop is hard or
> impossible to use, this helper provides a convenient way to
> operate on all elements.
>
> For callback_fn, besides map and map element, a callback_ctx,
> allocated on caller stack, is also passed to the callback
> function. This callback_ctx argument can provide additional
> input and allow to write to caller stack for output.

The approach and implementation look great! A few ideas below:

> +static int check_map_elem_callback(struct bpf_verifier_env *env, int *insn_idx)
> +{
> +        struct bpf_verifier_state *state = env->cur_state;
> +        struct bpf_prog_aux *aux = env->prog->aux;
> +        struct bpf_func_state *caller, *callee;
> +        struct bpf_map *map;
> +        int err, subprog;
> +
> +        if (state->curframe + 1 >= MAX_CALL_FRAMES) {
> +                verbose(env, "the call stack of %d frames is too deep\n",
> +                        state->curframe + 2);
> +                return -E2BIG;
> +        }
> +
> +        caller = state->frame[state->curframe];
> +        if (state->frame[state->curframe + 1]) {
> +                verbose(env, "verifier bug. Frame %d already allocated\n",
> +                        state->curframe + 1);
> +                return -EFAULT;
> +        }
> +
> +        caller->with_callback_fn = true;
> +
> +        callee = kzalloc(sizeof(*callee), GFP_KERNEL);
> +        if (!callee)
> +                return -ENOMEM;
> +        state->frame[state->curframe + 1] = callee;
> +
> +        /* callee cannot access r0, r6 - r9 for reading and has to write
> +         * into its own stack before reading from it.
> +         * callee can read/write into caller's stack
> +         */
> +        init_func_state(env, callee,
> +                        /* remember the callsite, it will be used by bpf_exit */
> +                        *insn_idx /* callsite */,
> +                        state->curframe + 1 /* frameno within this callchain */,
> +                        subprog /* subprog number within this prog */);
> +
> +        /* Transfer references to the callee */
> +        err = transfer_reference_state(callee, caller);
> +        if (err)
> +                return err;
> +
> +        subprog = caller->regs[BPF_REG_2].subprog;
> +        if (aux->func_info && aux->func_info_aux[subprog].linkage != BTF_FUNC_STATIC) {
> +                verbose(env, "callback function R2 not static\n");
> +                return -EINVAL;
> +        }
> +
> +        map = caller->regs[BPF_REG_1].map_ptr;

Take a look at the "for (i = 0; i < 5; i++) err = check_func_arg(...)" loop
and at record_func_map(). It stores the map pointer into map_ptr_state and
makes sure it's unique, so that the program doesn't try to pass two different
maps into the same 'call insn'. It can make this function a bit more generic.
There would be no need to hard-code regs[BPF_REG_1]; the code would take it
from map_ptr_state.
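For example, just a rough sketch (assuming this logic ends up in or near
check_helper_call() where the bpf_call_arg_meta is available; the exact
plumbing of 'meta' into this function is up to you):

        /* take the map recorded by check_func_arg()/record_func_map()
         * instead of hard-coding R1
         */
        map = meta->map_ptr;
        if (!map) {
                verbose(env, "kernel subsystem misconfigured verifier\n");
                return -EINVAL;
        }
        if (!map->ops->map_set_for_each_callback_args ||
            !map->ops->map_for_each_callback) {
                verbose(env, "callback function not allowed for this map\n");
                return -ENOTSUPP;
        }

The rest can stay the same as in your patch; the only difference is where
the map pointer comes from.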
Also it will help later with optimizing
return map->ops->map_for_each_callback(map, callback_fn, callback_ctx, flags);
since the map pointer will be the same, the optimization (that is applied to
other map operations) can be applied to this callback as well.

The regs[BPF_REG_2] handling can be generalized a bit as well. I think the
linkage != BTF_FUNC_STATIC check can be moved to the early check_ld_imm()
phase. While here, the check_func_arg() loop can look for the PTR_TO_FUNC
type and remember the subprog into meta (just like map_ptr_state; a rough
sketch follows the next quoted hunk) and ... continues below.

> +        if (!map->ops->map_set_for_each_callback_args ||
> +            !map->ops->map_for_each_callback) {
> +                verbose(env, "callback function not allowed for map R1\n");
> +                return -ENOTSUPP;
> +        }
> +
> +        /* the following is only for hashmap, different maps
> +         * can have different callback signatures.
> +         */
> +        err = map->ops->map_set_for_each_callback_args(env, caller, callee);
> +        if (err)
> +                return err;
> +
> +        clear_caller_saved_regs(env, caller->regs);
> +
> +        /* only increment it after check_reg_arg() finished */
> +        state->curframe++;
> +
> +        /* and go analyze first insn of the callee */
> +        *insn_idx = env->subprog_info[subprog].start - 1;
> +
> +        if (env->log.level & BPF_LOG_LEVEL) {
> +                verbose(env, "caller:\n");
> +                print_verifier_state(env, caller);
> +                verbose(env, "callee:\n");
> +                print_verifier_state(env, callee);
> +        }
> +        return 0;
> +}
> +
>  static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx)
>  {
>          struct bpf_verifier_state *state = env->cur_state;
>          struct bpf_func_state *caller, *callee;
>          struct bpf_reg_state *r0;
> -        int err;
> +        int i, err;
>
>          callee = state->frame[state->curframe];
>          r0 = &callee->regs[BPF_REG_0];
> @@ -4955,7 +5090,17 @@ static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx)
>          state->curframe--;
>          caller = state->frame[state->curframe];
>          /* return to the caller whatever r0 had in the callee */
> -        caller->regs[BPF_REG_0] = *r0;
> +        if (caller->with_callback_fn) {
> +                /* reset caller saved regs, the helper calling callback_fn
> +                 * has RET_INTEGER return types.
> +                 */
> +                for (i = 0; i < CALLER_SAVED_REGS; i++)
> +                        mark_reg_not_init(env, caller->regs, caller_saved[i]);
> +                caller->regs[BPF_REG_0].subreg_def = DEF_NOT_SUBREG;
> +                mark_reg_unknown(env, caller->regs, BPF_REG_0);

This part can stay in check_helper_call().

> +        } else {
> +                caller->regs[BPF_REG_0] = *r0;
> +        }
>
>          /* Transfer references to the caller */
>          err = transfer_reference_state(caller, callee);
> @@ -5091,7 +5236,8 @@ static int check_reference_leak(struct bpf_verifier_env *env)
>          return state->acquired_refs ? -EINVAL : 0;
>  }
>
> -static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn_idx)
> +static int check_helper_call(struct bpf_verifier_env *env, int func_id, int *insn_idx,
> +                             bool map_elem_callback)
>  {
>          const struct bpf_func_proto *fn = NULL;
>          struct bpf_reg_state *regs;
> @@ -5151,11 +5297,11 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
>                  return err;
>          }
>
> -        err = record_func_map(env, &meta, func_id, insn_idx);
> +        err = record_func_map(env, &meta, func_id, *insn_idx);
>          if (err)
>                  return err;
>
> -        err = record_func_key(env, &meta, func_id, insn_idx);
> +        err = record_func_key(env, &meta, func_id, *insn_idx);
>          if (err)
>                  return err;
>
> @@ -5163,7 +5309,7 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
>           * is inferred from register state.
>           */
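To make the earlier check_func_arg() suggestion concrete, the handling could
look roughly like this (ARG_PTR_TO_FUNC and meta->subprogno are illustrative
names for a new arg type and a new bpf_call_arg_meta field, mirroring how
meta->map_ptr is recorded; not tested):

        } else if (arg_type == ARG_PTR_TO_FUNC) {
                /* check_ld_imm() already set reg->type = PTR_TO_FUNC and
                 * reg->subprog for the BPF_PSEUDO_FUNC ld_imm64
                 */
                if (reg->type != PTR_TO_FUNC) {
                        verbose(env, "R%d is not a callback function\n", regno);
                        return -EACCES;
                }
                /* remember the subprog like record_func_map() remembers the map */
                meta->subprogno = reg->subprog;
        }

Then the callback setup code can use meta->subprogno instead of reading
caller->regs[BPF_REG_2].subprog directly.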
>          for (i = 0; i < meta.access_size; i++) {
> -                err = check_mem_access(env, insn_idx, meta.regno, i, BPF_B,
> +                err = check_mem_access(env, *insn_idx, meta.regno, i, BPF_B,
>                                         BPF_WRITE, -1, false);
>                  if (err)
>                          return err;
> @@ -5195,6 +5341,11 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
>                  return -EINVAL;
>          }
>
> +        if (map_elem_callback) {
> +                env->prog->aux->with_callback_fn = true;
> +                return check_map_elem_callback(env, insn_idx);

Instead of returning early here, the check_func_arg() loop can look for the
PTR_TO_FUNC type. Then allocate the new callee state, do
map_set_for_each_callback_args() here, and proceed further.

> +        }
> +
>          /* reset caller saved regs */
>          for (i = 0; i < CALLER_SAVED_REGS; i++) {
>                  mark_reg_not_init(env, regs, caller_saved[i]);

Instead of doing this loop in prepare_func_exit(), this code can just proceed
here and clear the caller regs. This loop can stay as-is; the transfer of
caller->callee would have happened already. Then there are a few lines here
that the diff didn't show: they do regs[BPF_REG_0].subreg_def = DEF_NOT_SUBREG
and mark_reg_unknown(). No need to do them in prepare_func_exit(). This
function can proceed further, reusing this caller regs clearing loop and the
r0 marking. Then, before returning from check_helper_call(), it will do what
you have in check_map_elem_callback() and it will adjust *insn_idx. At this
point the caller would have regs cleared and r0=undef, and the callee would
have regs set up the way the map_set_for_each_callback_args callback meant to
do it. The only thing prepare_func_exit() would need to do is make sure that
the assignment caller->regs[BPF_REG_0] = *r0 doesn't happen, since the
caller's r0 was already set to undef. To achieve that I think it would be a
bit cleaner to mark the callee state instead of the caller state. So instead
of caller->with_callback_fn=true maybe callee->in_callback_fn=true ?

> @@ -5306,7 +5457,7 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
>                  /* For release_reference() */
>                  regs[BPF_REG_0].ref_obj_id = meta.ref_obj_id;
>          } else if (is_acquire_function(func_id, meta.map_ptr)) {
> -                int id = acquire_reference_state(env, insn_idx);
> +                int id = acquire_reference_state(env, *insn_idx);
>
>                  if (id < 0)
>                          return id;
> @@ -5448,6 +5599,14 @@ static int retrieve_ptr_limit(const struct bpf_reg_state *ptr_reg,
>                  else
>                          *ptr_limit = -off;
>                  return 0;
> +        case PTR_TO_MAP_KEY:
> +                if (mask_to_left) {
> +                        *ptr_limit = ptr_reg->umax_value + ptr_reg->off;
> +                } else {
> +                        off = ptr_reg->smin_value + ptr_reg->off;
> +                        *ptr_limit = ptr_reg->map_ptr->key_size - off;
> +                }
> +                return 0;
>          case PTR_TO_MAP_VALUE:
>                  if (mask_to_left) {
>                          *ptr_limit = ptr_reg->umax_value + ptr_reg->off;
> @@ -5614,6 +5773,7 @@ static int adjust_ptr_min_max_vals(struct bpf_verifier_env *env,
>                  verbose(env, "R%d pointer arithmetic on %s prohibited\n",
>                          dst, reg_type_str[ptr_reg->type]);
>                  return -EACCES;
> +        case PTR_TO_MAP_KEY:
>          case PTR_TO_MAP_VALUE:
>                  if (!env->allow_ptr_leaks && !known && (smin_val < 0) != (smax_val < 0)) {
>                          verbose(env, "R%d has unknown scalar with mixed signed bounds, pointer arithmetic with it prohibited for !root\n",
> @@ -7818,6 +7978,12 @@ static int check_ld_imm(struct bpf_verifier_env *env, struct bpf_insn *insn)
>                  return 0;
>          }
>
> +        if (insn->src_reg == BPF_PSEUDO_FUNC) {
> +                dst_reg->type = PTR_TO_FUNC;
> +                dst_reg->subprog = insn[1].imm;

Like here, the check for linkage == BTF_FUNC_STATIC can happen?
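Something along these lines (a sketch only; it assumes insn[1].imm already
carries the subprog index at this point, as in this patch, and it reuses the
same condition the patch has in check_map_elem_callback()):

        if (insn->src_reg == BPF_PSEUDO_FUNC) {
                u32 subprog = insn[1].imm;

                /* reject non-static callbacks once, at ld_imm64 time,
                 * instead of at every helper call site
                 */
                if (env->prog->aux->func_info &&
                    env->prog->aux->func_info_aux[subprog].linkage != BTF_FUNC_STATIC) {
                        verbose(env, "callback function not static\n");
                        return -EINVAL;
                }

                dst_reg->type = PTR_TO_FUNC;
                dst_reg->subprog = subprog;
                return 0;
        }

That would also let check_map_elem_callback() (or check_helper_call()) drop
the func_info_aux lookup entirely.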
> +                return 0;
> +        }
> +
>          map = env->used_maps[aux->map_index];
>          mark_reg_known_zero(env, regs, insn->dst_reg);
>          dst_reg->map_ptr = map;
> @@ -8195,9 +8361,23 @@ static int visit_insn(int t, int insn_cnt, struct bpf_verifier_env *env)
>
>          /* All non-branch instructions have a single fall-through edge. */
>          if (BPF_CLASS(insns[t].code) != BPF_JMP &&
> -            BPF_CLASS(insns[t].code) != BPF_JMP32)
> +            BPF_CLASS(insns[t].code) != BPF_JMP32 &&
> +            !bpf_pseudo_func(insns + t))
>                  return push_insn(t, t + 1, FALLTHROUGH, env, false);
>
> +        if (bpf_pseudo_func(insns + t)) {
> +                ret = push_insn(t, t + 1, FALLTHROUGH, env, false);
> +                if (ret)
> +                        return ret;
> +
> +                if (t + 1 < insn_cnt)
> +                        init_explored_state(env, t + 1);
> +                init_explored_state(env, t);
> +                ret = push_insn(t, t + insns[t].imm + 1, BRANCH,
> +                                env, false);
> +                return ret;
> +        }
> +
>          switch (BPF_OP(insns[t].code)) {
>          case BPF_EXIT:
>                  return DONE_EXPLORING;
> @@ -8819,6 +8999,7 @@ static bool regsafe(struct bpf_reg_state *rold, struct bpf_reg_state *rcur,
>                   */
>                  return false;
>          }
> +        case PTR_TO_MAP_KEY:
>          case PTR_TO_MAP_VALUE:
>                  /* If the new min/max/var_off satisfy the old ones and
>                   * everything else matches, we are OK.
> @@ -9646,6 +9827,8 @@ static int do_check(struct bpf_verifier_env *env)
>
>                          env->jmps_processed++;
>                          if (opcode == BPF_CALL) {
> +                                bool map_elem_callback;
> +
>                                  if (BPF_SRC(insn->code) != BPF_K ||
>                                      insn->off != 0 ||
>                                      (insn->src_reg != BPF_REG_0 &&
>                                       insn->src_reg != BPF_PSEUDO_CALL)) {
>                                          verbose(env, "function calls are not allowed while holding a lock\n");
>                                          return -EINVAL;
>                                  }
> +                                map_elem_callback = insn->src_reg != BPF_PSEUDO_CALL &&
> +                                                    insn->imm == BPF_FUNC_for_each_map_elem;
>                                  if (insn->src_reg == BPF_PSEUDO_CALL)
>                                          err = check_func_call(env, insn, &env->insn_idx);
>                                  else
> -                                        err = check_helper_call(env, insn->imm, env->insn_idx);
> +                                        err = check_helper_call(env, insn->imm, &env->insn_idx,
> +                                                                map_elem_callback);

Then hopefully this extra 'map_elem_callback' boolean won't be needed; only
the change from env->insn_idx into &env->insn_idx. In that sense
check_helper_call() will become a superset of check_func_call(). Maybe some
code between them can be shared too. Beyond the bpf_for_each_map_elem()
helper, other helpers might use PTR_TO_FUNC. I hope that with this approach
all of them will be handled a bit more generically.

>                                  if (err)
>                                          return err;
> -
>                          } else if (opcode == BPF_JA) {
>                                  if (BPF_SRC(insn->code) != BPF_K ||
>                                      insn->imm != 0 ||
> @@ -10090,6 +10275,12 @@ static int resolve_pseudo_ldimm64(struct bpf_verifier_env *env)
>                          goto next_insn;
>                  }
>
> +                if (insn[0].src_reg == BPF_PSEUDO_FUNC) {
> +                        aux = &env->insn_aux_data[i];
> +                        aux->ptr_type = PTR_TO_FUNC;
> +                        goto next_insn;
> +                }
> +
>                  /* In final convert_pseudo_ld_imm64() step, this is
>                   * converted into regular 64-bit imm load insn.
>                   */
> @@ -10222,9 +10413,13 @@ static void convert_pseudo_ld_imm64(struct bpf_verifier_env *env)
>          int insn_cnt = env->prog->len;
>          int i;
>
> -        for (i = 0; i < insn_cnt; i++, insn++)
> -                if (insn->code == (BPF_LD | BPF_IMM | BPF_DW))
> -                        insn->src_reg = 0;
> +        for (i = 0; i < insn_cnt; i++, insn++) {
> +                if (insn->code != (BPF_LD | BPF_IMM | BPF_DW))
> +                        continue;
> +                if (insn->src_reg == BPF_PSEUDO_FUNC)
> +                        continue;
> +                insn->src_reg = 0;
> +        }
>  }
>
>  /* single env->prog->insni[off] instruction was replaced with the range
> @@ -10846,6 +11041,12 @@ static int jit_subprogs(struct bpf_verifier_env *env)
>                  return 0;
>
>          for (i = 0, insn = prog->insnsi; i < prog->len; i++, insn++) {
> +                if (bpf_pseudo_func(insn)) {
> +                        env->insn_aux_data[i].call_imm = insn->imm;
> +                        /* subprog is encoded in insn[1].imm */
> +                        continue;
> +                }
> +
>                  if (!bpf_pseudo_call(insn))
>                          continue;
>                  /* Upon error here we cannot fall back to interpreter but
> @@ -10975,6 +11176,12 @@ static int jit_subprogs(struct bpf_verifier_env *env)
>          for (i = 0; i < env->subprog_cnt; i++) {
>                  insn = func[i]->insnsi;
>                  for (j = 0; j < func[i]->len; j++, insn++) {
> +                        if (bpf_pseudo_func(insn)) {
> +                                subprog = insn[1].imm;
> +                                insn[0].imm = (u32)(long)func[subprog]->bpf_func;
> +                                insn[1].imm = ((u64)(long)func[subprog]->bpf_func) >> 32;
> +                                continue;
> +                        }
>                          if (!bpf_pseudo_call(insn))
>                                  continue;
>                          subprog = insn->off;
> @@ -11020,6 +11227,11 @@ static int jit_subprogs(struct bpf_verifier_env *env)
>           * later look the same as if they were interpreted only.
>           */
>          for (i = 0, insn = prog->insnsi; i < prog->len; i++, insn++) {
> +                if (bpf_pseudo_func(insn)) {
> +                        insn[0].imm = env->insn_aux_data[i].call_imm;
> +                        insn[1].imm = find_subprog(env, i + insn[0].imm + 1);
> +                        continue;
> +                }
>                  if (!bpf_pseudo_call(insn))
>                          continue;
>                  insn->off = env->insn_aux_data[i].call_imm;
> @@ -11083,6 +11295,13 @@ static int fixup_call_args(struct bpf_verifier_env *env)
>                  verbose(env, "tail_calls are not allowed in non-JITed programs with bpf-to-bpf calls\n");
>                  return -EINVAL;
>          }
> +        if (env->subprog_cnt > 1 && env->prog->aux->with_callback_fn) {

Does this bool really need to be part of 'aux'? There is a loop below that
does if (!bpf_pseudo_call(...)) to fix up insns for the interpreter. Maybe
add an if (bpf_pseudo_func()) { return "callbacks are not allowed in
non-JITed" } check to the loop below as well? It's a trade-off between
memory and a few extra insns.

> +                /* When JIT fails the progs with callback calls
> +                 * have to be rejected, since interpreter doesn't support them yet.
> +                 */
> +                verbose(env, "callbacks are not allowed in non-JITed programs\n");
> +                return -EINVAL;
> +        }
>          for (i = 0; i < prog->len; i++, insn++) {
>                  if (!bpf_pseudo_call(insn))
>                          continue;

To this loop, that is. Thanks!