With the start <= end restriction lifted, wrange32 gains the ability to
track the s32 range as well. The example provided in the previous patch
shows that wrange32 can now track {0xffffffff, 0, 1}, which is in fact
just a plain s32 range {-1, 0, 1}.

This patch adds helpers to extract the smin and smax from wrange32,
along with a wrange32_swrapping() helper that checks whether a wrange32
wraps in the s32 range.

Additional z3Py checks are added to make sure that the smin/smax
reasoning is correct as well.

Signed-off-by: Shung-Hsi Yu <shung-hsi.yu@xxxxxxxx>
---
 include/linux/wrange.h                       | 19 ++++++
 tools/testing/selftests/bpf/formal/wrange.py | 67 +++++++++++++++++++-
 2 files changed, 85 insertions(+), 1 deletion(-)

diff --git a/include/linux/wrange.h b/include/linux/wrange.h
index f51e674d1f18..876e260017fe 100644
--- a/include/linux/wrange.h
+++ b/include/linux/wrange.h
@@ -29,4 +29,23 @@ static inline u32 wrange32_umax(struct wrange32 a) {
 	return a.end;
 }
 
+static inline bool wrange32_swrapping(struct wrange32 a) {
+	return (s32)a.end < (s32)a.start;
+}
+
+/* Helper functions that will be required later */
+static inline s32 wrange32_smin(struct wrange32 a) {
+	if (wrange32_swrapping(a))
+		return S32_MIN;
+	else
+		return a.start;
+}
+
+static inline s32 wrange32_smax(struct wrange32 a) {
+	if (wrange32_swrapping(a))
+		return S32_MAX;
+	else
+		return a.end;
+}
+
 #endif /* _LINUX_WRANGE_H */
diff --git a/tools/testing/selftests/bpf/formal/wrange.py b/tools/testing/selftests/bpf/formal/wrange.py
index a2b1b083d291..825d79c6570f 100755
--- a/tools/testing/selftests/bpf/formal/wrange.py
+++ b/tools/testing/selftests/bpf/formal/wrange.py
@@ -37,6 +37,19 @@ class Wrange(abc.ABC):
     def umax(self):
         return If(self.uwrapping, BitVecVal(2**self.SIZE - 1, bv=self.SIZE), self.end)
 
+    @property
+    def swrapping(self):
+        # signed comparison, (s32)end < (s32)start
+        return self.end < self.start
+
+    @property
+    def smin(self):
+        return If(self.swrapping, BitVecVal(1 << (self.SIZE - 1), bv=self.SIZE), self.start)
+
+    @property
+    def smax(self):
+        return If(self.swrapping, BitVecVal((2**self.SIZE - 1) >> 1, bv=self.SIZE), self.end)
+
     # Not used in wrange.c, but helps with checking later
     def contains(self, val: BitVecRef):
         assert(val.size() == self.SIZE)
@@ -79,6 +92,14 @@ def main():
     prove(
         w1.umax == BitVecVal32(1),
     )
+    print('\nChecking w1.smin is 1')
+    prove(
+        w1.smin == BitVecVal32(1),
+    )
+    print('\nChecking w1.smax is 1')
+    prove(
+        w1.smax == BitVecVal32(1),
+    )
     print('\nChecking that w1 contains 1')
     prove(
         w1.contains(BitVecVal32(1)),
@@ -102,6 +123,14 @@ def main():
     prove(
         w2.umax == BitVecVal32(2**32 - 1),
     )
+    print('\nChecking w2.smin is -2147483648/0x80000000')
+    prove(
+        w2.smin == BitVecVal32(0x80000000),
+    )
+    print('\nChecking w2.smax is 2147483647/0x7fffffff')
+    prove(
+        w2.smax == BitVecVal32(0x7fffffff),
+    )
     print('\nChecking that w2 contains 2**32 - 1')
     prove(
         w2.contains(BitVecVal32(2**32 - 1)),
@@ -136,6 +165,14 @@ def main():
     prove(
         w3.umax == BitVecVal32(2**32 - 1),
     )
+    print('\nChecking w3.smin is -2147483648/0x80000000')
+    prove(
+        w3.smin == BitVecVal32(0x80000000),
+    )
+    print('\nChecking w3.smax is 2147483647/0x7fffffff')
+    prove(
+        w3.smax == BitVecVal32(0x7fffffff),
+    )
     print('\nChecking that w3 contains 0')
     prove(
         w3.contains(BitVecVal32(0)),
@@ -163,6 +200,14 @@ def main():
     prove(
         w4.umax == BitVecVal32(2**32 - 1),
     )
+    print('\nChecking w4.smin is -1')
+    prove(
+        w4.smin == BitVecVal32(-1),
+    )
+    print('\nChecking w4.smax is 1')
+    prove(
+        w4.smax == BitVecVal32(1),
+    )
     print('\nChecking that w4 contains 0')
     prove(
         w4.contains(BitVecVal32(0)),
@@ -176,7 +221,7 @@ def main():
         w4.contains(x) == Or(x == BitVecVal32(2**32-1), x == BitVecVal32(0), x == BitVecVal32(1)),
     )
 
-    # General checks for umin/umax
+    # General checks for umin/umax/smin/smax
     w = Wrange32('w') # Given a Wrange32 called w
     x = BitVec32('x') # And an 32-bit integer x (redeclared for clarity)
     print(f'\nGiven any possible Wrange32 called w, and any possible 32-bit integer called x')
@@ -200,6 +245,26 @@ def main():
             ULE(x, w.umax),
         )
     )
+    print('\nChecking if w.contains(x) == True, then w.smin <= (s32)x is also true')
+    prove(
+        Implies(
+            And(
+                w.wellformed(),
+                w.contains(x),
+            ),
+            w.smin <= x,
+        )
+    )
+    print('\nChecking if w.contains(x) == True, then (s32)x <= w.smax is also true')
+    prove(
+        Implies(
+            And(
+                w.wellformed(),
+                w.contains(x),
+            ),
+            x <= w.smax,
+        )
+    )
 
 if __name__ == '__main__':
     main()
-- 
2.42.0