From: Johannes Berg <johannes.berg@xxxxxxxxx> Using a pointer to a struct indicating the min/max values, extend the ability to do range validation for arbitrary values. Small values in the s16 range can be kept in the policy directly. Signed-off-by: Johannes Berg <johannes.berg@xxxxxxxxx> --- v2: - fix WARN_ON condition as spotted by Jakub --- include/net/netlink.h | 45 +++++++++++++++++ lib/nlattr.c | 113 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 137 insertions(+), 21 deletions(-) diff --git a/include/net/netlink.h b/include/net/netlink.h index 671b29d170a8..94a7df4ab122 100644 --- a/include/net/netlink.h +++ b/include/net/netlink.h @@ -189,11 +189,20 @@ enum { #define NLA_TYPE_MAX (__NLA_TYPE_MAX - 1) +struct netlink_range_validation { + u64 min, max; +}; + +struct netlink_range_validation_signed { + s64 min, max; +}; + enum nla_policy_validation { NLA_VALIDATE_NONE, NLA_VALIDATE_RANGE, NLA_VALIDATE_MIN, NLA_VALIDATE_MAX, + NLA_VALIDATE_RANGE_PTR, NLA_VALIDATE_FUNCTION, }; @@ -271,6 +280,22 @@ enum nla_policy_validation { * of s16 - do that as usual in the code instead. * Use the NLA_POLICY_MIN(), NLA_POLICY_MAX() and * NLA_POLICY_RANGE() macros. + * NLA_U8, + * NLA_U16, + * NLA_U32, + * NLA_U64 If the validation_type field instead is set to + * NLA_VALIDATE_RANGE_PTR, `range' must be a pointer + * to a struct netlink_range_validation that indicates + * the min/max values. + * Use NLA_POLICY_FULL_RANGE(). + * NLA_S8, + * NLA_S16, + * NLA_S32, + * NLA_S64 If the validation_type field instead is set to + * NLA_VALIDATE_RANGE_PTR, `range_signed' must be a + * pointer to a struct netlink_range_validation_signed + * that indicates the min/max values. + * Use NLA_POLICY_FULL_RANGE_SIGNED(). * All other Unused - but note that it's a union * * Meaning of `validate' field, use via NLA_POLICY_VALIDATE_FN: @@ -296,6 +321,8 @@ struct nla_policy { const u32 bitfield32_valid; const char *reject_message; const struct nla_policy *nested_policy; + struct netlink_range_validation *range; + struct netlink_range_validation_signed *range_signed; struct { s16 min, max; }; @@ -342,6 +369,12 @@ struct nla_policy { { .type = NLA_BITFIELD32, .bitfield32_valid = valid } #define __NLA_ENSURE(condition) BUILD_BUG_ON_ZERO(!(condition)) +#define NLA_ENSURE_UINT_TYPE(tp) \ + (__NLA_ENSURE(tp == NLA_U8 || tp == NLA_U16 || \ + tp == NLA_U32 || tp == NLA_U64) + tp) +#define NLA_ENSURE_SINT_TYPE(tp) \ + (__NLA_ENSURE(tp == NLA_S8 || tp == NLA_S16 || \ + tp == NLA_S32 || tp == NLA_S64) + tp) #define NLA_ENSURE_INT_TYPE(tp) \ (__NLA_ENSURE(tp == NLA_S8 || tp == NLA_U8 || \ tp == NLA_S16 || tp == NLA_U16 || \ @@ -360,6 +393,18 @@ struct nla_policy { .max = _max \ } +#define NLA_POLICY_FULL_RANGE(tp, _range) { \ + .type = NLA_ENSURE_UINT_TYPE(tp), \ + .validation_type = NLA_VALIDATE_RANGE_PTR, \ + .range = _range, \ +} + +#define NLA_POLICY_FULL_RANGE_SIGNED(tp, _range) { \ + .type = NLA_ENSURE_SINT_TYPE(tp), \ + .validation_type = NLA_VALIDATE_RANGE_PTR, \ + .range_signed = _range, \ +} + #define NLA_POLICY_MIN(tp, _min) { \ .type = NLA_ENSURE_INT_TYPE(tp), \ .validation_type = NLA_VALIDATE_MIN, \ diff --git a/lib/nlattr.c b/lib/nlattr.c index 7f7ebd89caa4..a8beb173f558 100644 --- a/lib/nlattr.c +++ b/lib/nlattr.c @@ -111,17 +111,34 @@ static int nla_validate_array(const struct nlattr *head, int len, int maxtype, return 0; } -static int nla_validate_int_range(const struct nla_policy *pt, - const struct nlattr *nla, - struct netlink_ext_ack *extack) +static int nla_validate_int_range_unsigned(const struct nla_policy *pt, + const struct nlattr *nla, + struct netlink_ext_ack *extack) { - bool validate_min, validate_max; - s64 value; + struct netlink_range_validation _range = { + .min = 0, + .max = U64_MAX, + }, *range = &_range; + u64 value; - validate_min = pt->validation_type == NLA_VALIDATE_RANGE || - pt->validation_type == NLA_VALIDATE_MIN; - validate_max = pt->validation_type == NLA_VALIDATE_RANGE || - pt->validation_type == NLA_VALIDATE_MAX; + WARN_ON_ONCE(pt->validation_type != NLA_VALIDATE_RANGE_PTR && + (pt->min < 0 || pt->max < 0)); + + switch (pt->validation_type) { + case NLA_VALIDATE_RANGE: + range->min = pt->min; + range->max = pt->max; + break; + case NLA_VALIDATE_RANGE_PTR: + range = pt->range; + break; + case NLA_VALIDATE_MIN: + range->min = pt->min; + break; + case NLA_VALIDATE_MAX: + range->max = pt->max; + break; + } switch (pt->type) { case NLA_U8: @@ -133,6 +150,49 @@ static int nla_validate_int_range(const struct nla_policy *pt, case NLA_U32: value = nla_get_u32(nla); break; + case NLA_U64: + value = nla_get_u64(nla); + break; + default: + return -EINVAL; + } + + if (value < range->min || value > range->max) { + NL_SET_ERR_MSG_ATTR(extack, nla, + "integer out of range"); + return -ERANGE; + } + + return 0; +} + +static int nla_validate_int_range_signed(const struct nla_policy *pt, + const struct nlattr *nla, + struct netlink_ext_ack *extack) +{ + struct netlink_range_validation_signed _range = { + .min = S64_MIN, + .max = S64_MAX, + }, *range = &_range; + s64 value; + + switch (pt->validation_type) { + case NLA_VALIDATE_RANGE: + range->min = pt->min; + range->max = pt->max; + break; + case NLA_VALIDATE_RANGE_PTR: + range = pt->range_signed; + break; + case NLA_VALIDATE_MIN: + range->min = pt->min; + break; + case NLA_VALIDATE_MAX: + range->max = pt->max; + break; + } + + switch (pt->type) { case NLA_S8: value = nla_get_s8(nla); break; @@ -145,22 +205,11 @@ static int nla_validate_int_range(const struct nla_policy *pt, case NLA_S64: value = nla_get_s64(nla); break; - case NLA_U64: - /* treat this one specially, since it may not fit into s64 */ - if ((validate_min && nla_get_u64(nla) < pt->min) || - (validate_max && nla_get_u64(nla) > pt->max)) { - NL_SET_ERR_MSG_ATTR(extack, nla, - "integer out of range"); - return -ERANGE; - } - return 0; default: - WARN_ON(1); return -EINVAL; } - if ((validate_min && value < pt->min) || - (validate_max && value > pt->max)) { + if (value < range->min || value > range->max) { NL_SET_ERR_MSG_ATTR(extack, nla, "integer out of range"); return -ERANGE; @@ -169,6 +218,27 @@ static int nla_validate_int_range(const struct nla_policy *pt, return 0; } +static int nla_validate_int_range(const struct nla_policy *pt, + const struct nlattr *nla, + struct netlink_ext_ack *extack) +{ + switch (pt->type) { + case NLA_U8: + case NLA_U16: + case NLA_U32: + case NLA_U64: + return nla_validate_int_range_unsigned(pt, nla, extack); + case NLA_S8: + case NLA_S16: + case NLA_S32: + case NLA_S64: + return nla_validate_int_range_signed(pt, nla, extack); + default: + WARN_ON(1); + return -EINVAL; + } +} + static int validate_nla(const struct nlattr *nla, int maxtype, const struct nla_policy *policy, unsigned int validate, struct netlink_ext_ack *extack, unsigned int depth) @@ -348,6 +418,7 @@ static int validate_nla(const struct nlattr *nla, int maxtype, case NLA_VALIDATE_NONE: /* nothing to do */ break; + case NLA_VALIDATE_RANGE_PTR: case NLA_VALIDATE_RANGE: case NLA_VALIDATE_MIN: case NLA_VALIDATE_MAX: -- 2.25.1