The new helpers _flags_from_numeric() and _flags_to_numeric() will be re-used for nft_ctx_input_set_flags() and nft_ctx_input_get_flags(). There are changes in behavior here: - When passing an unrecognized string (e.g. `ctx.set_debug('foo')` or `ctx.set_debug(['foo'])`), a ValueError is now raised instead of a KeyError. - When passing an out-of-range integer, a ValueError is now raised. Previously, the integer was truncated to 32 bits. Changing the exception is an API change, but most likely nobody will care or try to catch a KeyError to find out whether a flag is supported, especially since such a check is better performed via `'foo' in ctx.debug_flags`. In other cases, a TypeError is raised as before. Signed-off-by: Thomas Haller <thaller@xxxxxxxxxx> Reviewed-by: Phil Sutter <phil@xxxxxx> --- py/src/nftables.py | 52 +++++++++++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/py/src/nftables.py b/py/src/nftables.py index b1186781ab5c..95c65cde69c4 100644 --- a/py/src/nftables.py +++ b/py/src/nftables.py @@ -156,6 +156,35 @@ class Nftables: self.nft_ctx_free(self.__ctx) self.__ctx = None + def _flags_from_numeric(self, flags_dict, val): + names = [] + for n, v in flags_dict.items(): + if val & v: + names.append(n) + val &= ~v + if val: + names.append(val) + return names + + def _flags_to_numeric(self, flags_dict, values): + if isinstance(values, (str, int)): + values = (values,) + + val = 0 + for v in values: + if isinstance(v, str): + v = flags_dict.get(v) + if v is None: + raise ValueError("Invalid argument") + elif isinstance(v, int): + if v < 0 or v > 0xFFFFFFFF: + raise ValueError("Invalid argument") + else: + raise TypeError("Not a valid flag") + val |= v + + return val + def __get_output_flag(self, name): flag = self.output_flags[name] return (self.nft_ctx_output_get_flags(self.__ctx) & flag) != 0 @@ -375,16 +404,7 @@ class Nftables: Returns a set of flag names. See set_debug() for details.
""" val = self.nft_ctx_output_get_debug(self.__ctx) - - names = [] - for n,v in self.debug_flags.items(): - if val & v: - names.append(n) - val &= ~v - if val: - names.append(val) - - return names + return self._flags_from_numeric(self.debug_flags, val) def set_debug(self, values): """Set debug output flags. @@ -406,19 +426,9 @@ class Nftables: Returns a set of previously active debug flags, as returned by get_debug() method. """ + val = self._flags_to_numeric(self.debug_flags, values) old = self.get_debug() - - if type(values) in [str, int]: - values = [values] - - val = 0 - for v in values: - if type(v) is str: - v = self.debug_flags[v] - val |= v - self.nft_ctx_output_set_debug(self.__ctx, val) - return old def cmd(self, cmdline): -- 2.41.0