This patch introduces a new helper function, tnum_scast(), which
sign-extends a tnum from a smaller integer size to the full 64-bit BPF
register range.

Given a tnum (v, m) and a target size of S bytes, with
value_mask = (1ULL << (S * 8)) - 1, this is achieved as follows:

  1) Mask the value and mask to S bytes:
     val = v & value_mask, msk = m & value_mask
  2) If the sign bit (bit S*8-1) is unknown (msk has that bit set):
     - The extended bits [63, S*8] become unknown
       (mask |= ~value_mask).
     - The extended bits of the value stay 0. Because the mask now
       marks them as unknown, they can be all 0s (sign 0) or all 1s
       (sign 1).
  3) If the sign bit is known:
     - The upper bits of the value follow the sign extension of val.
     - The upper bits of the mask follow the sign extension of msk,
       which means they are all 0, because msk's sign bit is 0.

a) When the sign bit is known:

Assume a tnum with value = 0xFF, mask = 0x00, size = 1, which
corresponds to an 8-bit subregister with the known value 0xFF. We
extract the sign bit position, compute the value mask, apply it to the
lower bits, and check the sign bit at that position:

  s = size * 8 - 1;                      // 1 * 8 - 1 = 7
  value_mask = (1ULL << (s + 1)) - 1;    // (1 << (7 + 1)) - 1 = 0xFF
  new_value = a.value & value_mask;      // 0xFF & 0xFF = 0xFF
  new_mask = a.mask & value_mask;        // 0x00 & 0xFF = 0x00
  sign_bit_unknown = (a.mask >> s) & 1;  // (0x00 >> 7) & 1 = 0, known
  sign_bit_value = (a.value >> s) & 1;   // (0xFF >> 7) & 1 = 1

Because the sign bit is known to be 1, we sign-extend with 1s above bit
7, so all upper bits [63,8] become 1: new_value in 64 bits is
0xFFFFFFFFFFFFFFFF, and new_mask for those bits is 0, since we know for
sure that they are all 1. So after the tnum_scast() call, the tnum is
(0xFFFFFFFFFFFFFFFF, 0x0000000000000000), which corresponds to the
64-bit value -1.

b) When the sign bit is unknown:

Assume a tnum with value = 0x7F, mask = 0x80, size = 1. In this case,
the lower 7 bits [6,0] are known to be all 1s, as given by
value = 0x7F, or b(0111 1111). Bit 7 is unknown (mask = 0x80), so it
could be 0 or 1. This means the subregister is either 0x7F (+127) or
0xFF (-1), the two values that differ only in bit 7.

Following the same operations as in the previous example, we get s = 7
and value_mask = 0xFF.
Then:

  new_value = a.value & value_mask;      // 0x7F & 0xFF = 0x7F
  new_mask = a.mask & value_mask;        // 0x80 & 0xFF = 0x80
  sign_bit_unknown = (a.mask >> s) & 1;  // (0x80 >> 7) & 1 = 1, unknown
  // the sign bit is unknown, so treat the upper bits [63,8] as unknown
  new_mask |= ~value_mask;

This leads to a new tnum with value = 0x7F, mask = 0xFFFFFFFFFFFFFF80.
The lower 7 bits are known to be 1s, bit 7 is unknown (so the low byte
is 0x7F or 0xFF), and the upper 56 bits are fully unknown. In 64-bit
form, the tnum covers both possible sign-extended results:

  0x000000000000007F (+127) if the sign bit and all higher bits are 0
  0xFFFFFFFFFFFFFFFF (-1)   if the sign bit and all higher bits are 1

Since a tnum cannot express that bits [63,7] must all be equal, it also
admits other values whose lower 7 bits are all 1s (e.g. 0xFF), but it is
the tightest tnum containing both results.

Signed-off-by: Dimitar Kanaliev <dimitar.kanaliev@xxxxxxxxxxxxxx>
---
 include/linux/tnum.h |  3 +++
 kernel/bpf/tnum.c    | 29 +++++++++++++++++++++++++++++
 2 files changed, 32 insertions(+)

diff --git a/include/linux/tnum.h b/include/linux/tnum.h
index 3c13240077b8..6933db04c9ee 100644
--- a/include/linux/tnum.h
+++ b/include/linux/tnum.h
@@ -55,6 +55,9 @@ struct tnum tnum_intersect(struct tnum a, struct tnum b);
 /* Return @a with all but the lowest @size bytes cleared */
 struct tnum tnum_cast(struct tnum a, u8 size);
 
+/* Return @a sign-extended from @size bytes */
+struct tnum tnum_scast(struct tnum a, u8 size);
+
 /* Returns true if @a is a known constant */
 static inline bool tnum_is_const(struct tnum a)
 {
diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c
index 9dbc31b25e3d..cb29dbc793d4 100644
--- a/kernel/bpf/tnum.c
+++ b/kernel/bpf/tnum.c
@@ -157,6 +157,35 @@ struct tnum tnum_cast(struct tnum a, u8 size)
 	return a;
 }
 
+struct tnum tnum_scast(struct tnum a, u8 size)
+{
+	u64 s = size * 8 - 1;
+	u64 sign_mask;
+	u64 value_mask;
+	u64 new_value, new_mask;
+	u64 sign_bit_unknown, sign_bit_value;
+	u64 mask;
+
+	if (size >= 8) {
+		return a;
+	}
+
+	sign_mask = 1ULL << s;
+	value_mask = (1ULL << (s + 1)) - 1;
+
+	new_value = a.value & value_mask;
+	new_mask = a.mask & value_mask;
+
+	sign_bit_unknown = (a.mask >> s) & 1;
+	sign_bit_value = (a.value >> s) & 1;
+
+	mask = ~value_mask;
+	new_mask |= mask & (0 - sign_bit_unknown);
+	new_value |= mask & (0 - ((sign_bit_unknown ^ 1) & sign_bit_value));
+
+	return TNUM(new_value, new_mask);
+}
+
 bool tnum_is_aligned(struct tnum a, u64 size)
 {
 	if (!size)
-- 
2.43.0