Store the number of consumed bytes instead of the number of remaining
bytes; this simplifies the logic quite a bit.

The match semantics are unchanged: without --invert, a packet matches
as long as it fits in the remaining quota (including a packet that uses
it up exactly), and once the quota is exhausted every further packet is
rejected. Because the consumed counter may legitimately grow beyond the
quota after exhaustion, checkentry now clamps a larger value to the
quota instead of rejecting the rule with -ERANGE, so that a counter
value read back from an exhausted rule can be fed back in.

Cc: Chenbo Feng <fengc@xxxxxxxxxx>
Cc: Maciej Żenczykowski <maze@xxxxxxxxxx>
Signed-off-by: Pablo Neira Ayuso <pablo@xxxxxxxxxxxxx>
---
Sending this before the merge window closes, while it is still possible
to change the semantics.

 include/uapi/linux/netfilter/xt_quota.h |  4 ++--
 net/netfilter/xt_quota.c                | 29 +++++++++--------------------
 2 files changed, 11 insertions(+), 22 deletions(-)

diff --git a/include/uapi/linux/netfilter/xt_quota.h b/include/uapi/linux/netfilter/xt_quota.h
index d72fd52adbba..8abe7e6261c9 100644
--- a/include/uapi/linux/netfilter/xt_quota.h
+++ b/include/uapi/linux/netfilter/xt_quota.h
@@ -16,9 +16,9 @@ struct xt_quota_info {
 	__u32 pad;
 	__aligned_u64 quota;
 #ifdef __KERNEL__
-	atomic64_t counter;
+	atomic64_t consumed;
 #else
-	__aligned_u64 remain;
+	__aligned_u64 consumed;
 #endif
 };
 
diff --git a/net/netfilter/xt_quota.c b/net/netfilter/xt_quota.c
index fceae245eb03..83ce440e4b6e 100644
--- a/net/netfilter/xt_quota.c
+++ b/net/netfilter/xt_quota.c
@@ -17,27 +17,18 @@ MODULE_DESCRIPTION("Xtables: countdown quota match");
 MODULE_ALIAS("ipt_quota");
 MODULE_ALIAS("ip6t_quota");
 
+static inline bool xt_overquota(struct xt_quota_info *q,
+				const struct sk_buff *skb)
+{
+	return atomic64_add_return(skb->len, &q->consumed) > q->quota;
+}
+
 static bool
 quota_mt(const struct sk_buff *skb, struct xt_action_param *par)
 {
 	struct xt_quota_info *q = (void *)par->matchinfo;
-	u64 current_count = atomic64_read(&q->counter);
-	bool ret = q->flags & XT_QUOTA_INVERT;
-	u64 old_count, new_count;
 
-	do {
-		if (current_count == 1)
-			return ret;
-		if (current_count <= skb->len) {
-			atomic64_set(&q->counter, 1);
-			return ret;
-		}
-		old_count = current_count;
-		new_count = current_count - skb->len;
-		current_count = atomic64_cmpxchg(&q->counter, old_count,
-						 new_count);
-	} while (current_count != old_count);
-	return !ret;
+	return !xt_overquota(q, skb) ^ !!(q->flags & XT_QUOTA_INVERT);
 }
 
 static int quota_mt_check(const struct xt_mtchk_param *par)
@@ -48,11 +39,9 @@ static int quota_mt_check(const struct xt_mtchk_param *par)
 	if (q->flags & ~XT_QUOTA_MASK)
 		return -EINVAL;
 
-	if (atomic64_read(&q->counter) > q->quota + 1)
-		return -ERANGE;
+	if (atomic64_read(&q->consumed) > q->quota)
+		atomic64_set(&q->consumed, q->quota);
 
-	if (atomic64_read(&q->counter) == 0)
-		atomic64_set(&q->counter, q->quota + 1);
 	return 0;
 }
 
-- 
2.11.0