This patch refactors the verifier's sign-extension logic for narrow register
values to use the new tnum_scast() helper. Instead of the manual range checks
in coerce_reg_to_size_sx() and coerce_subreg_to_size_sx(), the smin/smax and
umin/umax bounds are now derived directly from the tnum produced by
tnum_scast().

In the original code, coercion cases with an unknown sign bit fell back to
set_sext64_default_val()/set_sext32_default_val(), which clamp the signed
bounds to the worst case for the target size and discard all bit-level
knowledge by resetting var_off to tnum_unknown. With this change, bitwise
uncertainty is tracked through the cast: whenever the sign bit is known, the
signed bounds follow directly from the known bits (smin = value,
smax = value | mask), and only an unknown sign bit forces the worst-case
signed range for the target size.

For example, consider an 8-bit register with var_off = (value=0x7f;
mask=0x80), i.e. known lower 7 bits and an unknown sign bit. The original
code fell back to the defaults, losing the known bits entirely (var_off
became tnum_unknown and the unsigned range widened to [0, U64_MAX]). The
tnum_scast() implementation instead keeps
var_off = (value=0x7f; mask=0xffffffffffffff80), still tracking the seven
known low bits while only the sign-extension bits are unknown, and derives
smin = -128, smax = 127 and umin = 0x7f from it.

Signed-off-by: Dimitar Kanaliev <dimitar.kanaliev@xxxxxxxxxxxxxx>
---
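Note for reviewers (not for the commit log): below is a minimal user-space
sketch of the bound derivation above, applied to the 8-bit example from the
commit message. The scast8() helper is a hypothetical stand-in for
tnum_scast(), which is introduced earlier in this series and whose
implementation is not shown in this patch; struct tnum here is a simplified
local copy, not the kernel's.

#include <stdint.h>
#include <stdio.h>

struct tnum { uint64_t value; uint64_t mask; };

/* Hypothetical stand-in for tnum_scast(a, 1): sign-extend an 8-bit tnum
 * to 64 bits. A known sign bit is copied upward in value; an unknown
 * sign bit (set in mask) makes all upper bits unknown. */
static struct tnum scast8(struct tnum a)
{
        uint64_t ext = ~0ULL << 8;

        a.value &= 0xff;
        a.mask &= 0xff;
        if (a.mask & 0x80)
                a.mask |= ext;          /* unknown sign -> unknown upper bits */
        else if (a.value & 0x80)
                a.value |= ext;         /* known negative -> ones upward */
        return a;
}

int main(void)
{
        /* known low 7 bits, unknown sign bit: the commit message example */
        struct tnum t = { .value = 0x7f, .mask = 0x80 };
        uint64_t s = 7;
        int64_t smin, smax;

        t = scast8(t);

        if (t.mask & (1ULL << s)) {     /* sign bit unknown: worst case */
                smin = -(1LL << s);
                smax = (1LL << s) - 1;
        } else {                        /* sign bit known: bounds from tnum */
                smin = (int64_t)t.value;
                smax = (int64_t)(t.value | t.mask);
        }

        printf("var_off=(0x%llx; 0x%llx) smin=%lld smax=%lld umin=%llu umax=0x%llx\n",
               (unsigned long long)t.value, (unsigned long long)t.mask,
               (long long)smin, (long long)smax,
               (unsigned long long)t.value,
               (unsigned long long)(t.value | t.mask));
        return 0;
}

Compiled with a plain cc, this prints var_off=(0x7f; 0xffffffffffffff80),
smin=-128, smax=127, umin=127, umax=0xffffffffffffffff: the seven known low
bits survive the cast, whereas the old fallback would have dropped var_off
to tnum_unknown.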
 kernel/bpf/verifier.c | 124 +++++++++++++-----------------------------
 1 file changed, 39 insertions(+), 85 deletions(-)

diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 9971c03adfd5..a98dba42abc0 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -6661,61 +6661,35 @@ static void set_sext64_default_val(struct bpf_reg_state *reg, int size)
 
 static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size)
 {
-	s64 init_s64_max, init_s64_min, s64_max, s64_min, u64_cval;
-	u64 top_smax_value, top_smin_value;
-	u64 num_bits = size * 8;
+	u64 s = size * 8 - 1;
+	u64 sign_mask = 1ULL << s;
+	s64 smin_value, smax_value;
+	u64 umax_value;
 
-	if (tnum_is_const(reg->var_off)) {
-		u64_cval = reg->var_off.value;
-		if (size == 1)
-			reg->var_off = tnum_const((s8)u64_cval);
-		else if (size == 2)
-			reg->var_off = tnum_const((s16)u64_cval);
-		else
-			/* size == 4 */
-			reg->var_off = tnum_const((s32)u64_cval);
-
-		u64_cval = reg->var_off.value;
-		reg->smax_value = reg->smin_value = u64_cval;
-		reg->umax_value = reg->umin_value = u64_cval;
-		reg->s32_max_value = reg->s32_min_value = u64_cval;
-		reg->u32_max_value = reg->u32_min_value = u64_cval;
+	if (size >= 8)
 		return;
-	}
 
-	top_smax_value = ((u64)reg->smax_value >> num_bits) << num_bits;
-	top_smin_value = ((u64)reg->smin_value >> num_bits) << num_bits;
+	reg->var_off = tnum_scast(reg->var_off, size);
 
-	if (top_smax_value != top_smin_value)
-		goto out;
-
-	/* find the s64_min and s64_min after sign extension */
-	if (size == 1) {
-		init_s64_max = (s8)reg->smax_value;
-		init_s64_min = (s8)reg->smin_value;
-	} else if (size == 2) {
-		init_s64_max = (s16)reg->smax_value;
-		init_s64_min = (s16)reg->smin_value;
+	if (reg->var_off.mask & sign_mask) {
+		smin_value = -(1LL << s);
+		smax_value = (1LL << s) - 1;
 	} else {
-		init_s64_max = (s32)reg->smax_value;
-		init_s64_min = (s32)reg->smin_value;
+		smin_value = (s64)(reg->var_off.value);
+		smax_value = (s64)(reg->var_off.value | reg->var_off.mask);
 	}
-	s64_max = max(init_s64_max, init_s64_min);
-	s64_min = min(init_s64_max, init_s64_min);
+	reg->smin_value = smin_value;
+	reg->smax_value = smax_value;
 
-	/* both of s64_max/s64_min positive or negative */
-	if ((s64_max >= 0) == (s64_min >= 0)) {
-		reg->s32_min_value = reg->smin_value = s64_min;
-		reg->s32_max_value = reg->smax_value = s64_max;
-		reg->u32_min_value = reg->umin_value = s64_min;
-		reg->u32_max_value = reg->umax_value = s64_max;
-		reg->var_off = tnum_range(s64_min, s64_max);
-		return;
-	}
+	reg->umin_value = reg->var_off.value;
+	umax_value = reg->var_off.value | reg->var_off.mask;
+	reg->umax_value = umax_value;
 
-out:
-	set_sext64_default_val(reg, size);
+	reg->s32_min_value = (s32)smin_value;
+	reg->s32_max_value = (s32)smax_value;
+	reg->u32_min_value = (u32)reg->umin_value;
+	reg->u32_max_value = (u32)umax_value;
 }
 
 static void set_sext32_default_val(struct bpf_reg_state *reg, int size)
@@ -6735,52 +6709,32 @@ static void set_sext32_default_val(struct bpf_reg_state *reg, int size)
 
 static void coerce_subreg_to_size_sx(struct bpf_reg_state *reg, int size)
 {
-	s32 init_s32_max, init_s32_min, s32_max, s32_min, u32_val;
-	u32 top_smax_value, top_smin_value;
-	u32 num_bits = size * 8;
+	u32 s = size * 8 - 1;
+	u32 sign_mask = 1U << s;
+	s32 smin_value, smax_value;
+	u32 umax_value;
 
-	if (tnum_is_const(reg->var_off)) {
-		u32_val = reg->var_off.value;
-		if (size == 1)
-			reg->var_off = tnum_const((s8)u32_val);
-		else
-			reg->var_off = tnum_const((s16)u32_val);
-
-		u32_val = reg->var_off.value;
-		reg->s32_min_value = reg->s32_max_value = u32_val;
-		reg->u32_min_value = reg->u32_max_value = u32_val;
+	if (size >= 4)
 		return;
-	}
-
-	top_smax_value = ((u32)reg->s32_max_value >> num_bits) << num_bits;
-	top_smin_value = ((u32)reg->s32_min_value >> num_bits) << num_bits;
 
-	if (top_smax_value != top_smin_value)
-		goto out;
+	reg->var_off = tnum_scast(reg->var_off, size);
 
-	/* find the s32_min and s32_min after sign extension */
-	if (size == 1) {
-		init_s32_max = (s8)reg->s32_max_value;
-		init_s32_min = (s8)reg->s32_min_value;
+	if (reg->var_off.mask & sign_mask) {
+		smin_value = -(1 << s);
+		smax_value = (1 << s) - 1;
 	} else {
-		/* size == 2 */
-		init_s32_max = (s16)reg->s32_max_value;
-		init_s32_min = (s16)reg->s32_min_value;
-	}
-	s32_max = max(init_s32_max, init_s32_min);
-	s32_min = min(init_s32_max, init_s32_min);
-
-	if ((s32_min >= 0) == (s32_max >= 0)) {
-		reg->s32_min_value = s32_min;
-		reg->s32_max_value = s32_max;
-		reg->u32_min_value = (u32)s32_min;
-		reg->u32_max_value = (u32)s32_max;
-		reg->var_off = tnum_subreg(tnum_range(s32_min, s32_max));
-		return;
+		smin_value = (s32)(reg->var_off.value);
+		smax_value = (s32)(reg->var_off.value | reg->var_off.mask);
 	}
 
-out:
-	set_sext32_default_val(reg, size);
+	reg->s32_min_value = smin_value;
+	reg->s32_max_value = smax_value;
+
+	reg->u32_min_value = reg->var_off.value;
+	umax_value = reg->var_off.value | reg->var_off.mask;
+	reg->u32_max_value = umax_value;
+
+	reg->var_off = tnum_subreg(reg->var_off);
 }
 
 static bool bpf_map_is_rdonly(const struct bpf_map *map)
-- 
2.43.0