[nft PATCH v5 5/6] py: extract flags helper functions for set_debug()/get_debug()

These helpers will be reused for nft_ctx_input_set_flags() and
nft_ctx_input_get_flags().
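
As a rough illustration, not part of the patch, the two helpers can be
exercised outside the Nftables class. The sketch below copies their logic
into module-level functions. The flag table is illustrative: its names and
bit values are assumed to mirror Nftables.debug_flags. DEBUG_FLAGS,
flags_from_numeric() and flags_to_numeric() are hypothetical names.

```python
# Illustrative flag table; names and bit values are assumed to mirror
# Nftables.debug_flags and are not taken from this patch.
DEBUG_FLAGS = {
    "scanner": 0x1,
    "parser": 0x2,
    "eval": 0x4,
    "netlink": 0x8,
    "mnl": 0x10,
    "proto-ctx": 0x20,
    "segtree": 0x40,
}


def flags_from_numeric(flags_dict, val):
    """Decode a bitmask into flag names; unknown leftover bits are appended as an int."""
    names = []
    for name, bit in flags_dict.items():
        if val & bit:
            names.append(name)
            val &= ~bit  # clear the bit so only unknown bits remain at the end
    if val:
        names.append(val)
    return names


def flags_to_numeric(flags_dict, values):
    """Encode a flag name, an int, or an iterable of either into a 32-bit mask."""
    if isinstance(values, (str, int)):
        values = (values,)
    val = 0
    for v in values:
        if isinstance(v, str):
            bit = flags_dict.get(v)
            if bit is None:
                raise ValueError("Invalid argument")  # unknown flag name
        elif isinstance(v, int):
            if v < 0 or v > 0xFFFFFFFF:
                raise ValueError("Invalid argument")  # does not fit in 32 bits
            bit = v
        else:
            raise TypeError("Not a valid flag")
        val |= bit
    return val


# A round trip through both helpers, as set_debug()/get_debug() would do.
mask = flags_to_numeric(DEBUG_FLAGS, ["parser", "netlink", 0x100])
print(hex(mask))                              # 0x10a
print(flags_from_numeric(DEBUG_FLAGS, mask))  # ['parser', 'netlink', 256]
```

The raw bit 0x100 has no name in the table, so it survives the round trip
as a plain integer, which is how get_debug() reports unknown bits.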

This changes behavior in two ways:

- when passing an unrecognized string (e.g. `ctx.set_debug('foo')` or
  `ctx.set_debug(['foo'])`), a ValueError is now raised instead of a
  KeyError.

- when passing an out-of-range integer (negative, or larger than
  0xFFFFFFFF), a ValueError is now raised. Previously the integer was
  truncated to 32 bits.
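
A minimal sketch of the difference, assuming a two-entry flag table and
hypothetical helpers old_lookup(), new_lookup() and raised(). The old code
indexed the dict directly; new_lookup() mirrors _flags_to_numeric() for a
single value. The ctypes line only illustrates how a fixed-width C integer
silently drops high bits; it assumes, rather than shows, the exact
conversion the old binding path performed.

```python
import ctypes

DEBUG_FLAGS = {"scanner": 0x1, "parser": 0x2}  # illustrative subset


def old_lookup(name):
    # Pre-patch behaviour: plain indexing, so an unknown name raises KeyError.
    return DEBUG_FLAGS[name]


def new_lookup(v):
    # Post-patch behaviour, modelled on _flags_to_numeric() for one value.
    if isinstance(v, str):
        bit = DEBUG_FLAGS.get(v)
        if bit is None:
            raise ValueError("Invalid argument")
        return bit
    if isinstance(v, int):
        if v < 0 or v > 0xFFFFFFFF:
            raise ValueError("Invalid argument")
        return v
    raise TypeError("Not a valid flag")


def raised(fn, arg):
    """Return the name of the exception fn(arg) raises, or None."""
    try:
        fn(arg)
    except Exception as exc:
        return type(exc).__name__
    return None


print(raised(old_lookup, "foo"))    # KeyError
print(raised(new_lookup, "foo"))    # ValueError
print(raised(new_lookup, 1 << 32))  # ValueError
print(raised(new_lookup, 1.5))      # TypeError

# ctypes fixed-width integers do no overflow checking, so high bits are lost:
print(hex(ctypes.c_uint32((1 << 32) | 0x2).value))  # 0x2
```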

Changing the exception type is an API change, but most likely nobody
catches a KeyError to find out whether a flag is supported, especially
since such a check is better performed via `'foo' in ctx.debug_flags`.
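
A small sketch of that membership check. It uses a stand-in dict in place
of ctx.debug_flags and a hypothetical supported_only() helper that drops
names the table does not know, so nothing needs to be caught.

```python
# Illustrative stand-in for ctx.debug_flags; the real dict lives on Nftables.
DEBUG_FLAGS = {"scanner": 0x1, "parser": 0x2, "netlink": 0x8}


def supported_only(requested, flags_dict=DEBUG_FLAGS):
    """Keep only the requested flag names that the table knows (hypothetical helper)."""
    return [name for name in requested if name in flags_dict]


print("netlink" in DEBUG_FLAGS)                      # True
print("foo" in DEBUG_FLAGS)                          # False
print(supported_only(["parser", "foo", "netlink"]))  # ['parser', 'netlink']
```

With a real context this would presumably be used as
`ctx.set_debug(supported_only(wanted, ctx.debug_flags))`.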

For values of any other type, a TypeError is raised, as before.

Signed-off-by: Thomas Haller <thaller@xxxxxxxxxx>
Reviewed-by: Phil Sutter <phil@xxxxxx>
---
 py/src/nftables.py | 52 +++++++++++++++++++++++++++-------------------
 1 file changed, 31 insertions(+), 21 deletions(-)

diff --git a/py/src/nftables.py b/py/src/nftables.py
index b1186781ab5c..95c65cde69c4 100644
--- a/py/src/nftables.py
+++ b/py/src/nftables.py
@@ -156,6 +156,35 @@ class Nftables:
             self.nft_ctx_free(self.__ctx)
             self.__ctx = None
 
+    def _flags_from_numeric(self, flags_dict, val):
+        names = []
+        for n, v in flags_dict.items():
+            if val & v:
+                names.append(n)
+                val &= ~v
+        if val:
+            names.append(val)
+        return names
+
+    def _flags_to_numeric(self, flags_dict, values):
+        if isinstance(values, (str, int)):
+            values = (values,)
+
+        val = 0
+        for v in values:
+            if isinstance(v, str):
+                v = flags_dict.get(v)
+                if v is None:
+                    raise ValueError("Invalid argument")
+            elif isinstance(v, int):
+                if v < 0 or v > 0xFFFFFFFF:
+                    raise ValueError("Invalid argument")
+            else:
+                raise TypeError("Not a valid flag")
+            val |= v
+
+        return val
+
     def __get_output_flag(self, name):
         flag = self.output_flags[name]
         return (self.nft_ctx_output_get_flags(self.__ctx) & flag) != 0
@@ -375,16 +404,7 @@ class Nftables:
         Returns a set of flag names. See set_debug() for details.
         """
         val = self.nft_ctx_output_get_debug(self.__ctx)
-
-        names = []
-        for n,v in self.debug_flags.items():
-            if val & v:
-                names.append(n)
-                val &= ~v
-        if val:
-            names.append(val)
-
-        return names
+        return self._flags_from_numeric(self.debug_flags, val)
 
     def set_debug(self, values):
         """Set debug output flags.
@@ -406,19 +426,9 @@ class Nftables:
         Returns a set of previously active debug flags, as returned by
         get_debug() method.
         """
+        val = self._flags_to_numeric(self.debug_flags, values)
         old = self.get_debug()
-
-        if type(values) in [str, int]:
-            values = [values]
-
-        val = 0
-        for v in values:
-            if type(v) is str:
-                v = self.debug_flags[v]
-            val |= v
-
         self.nft_ctx_output_set_debug(self.__ctx, val)
-
         return old
 
     def cmd(self, cmdline):
-- 
2.41.0
