This patch allows more than one stateful statement (for example, two) to be attached to a dynamic set or map update, e.g. nft rule x y add @y { ip daddr limit rate 1/second counter } A single statement is still sent to the kernel as NFTNL_EXPR_DYNSET_EXPR; two or more statements are sent through NFTNL_EXPR_DYNSET_EXPRESSIONS. Signed-off-by: Pablo Neira Ayuso <pablo@xxxxxxxxxxxxx> --- include/statement.h | 4 +-- src/evaluate.c | 24 +++++++------ src/netlink.c | 1 + src/netlink_delinearize.c | 74 +++++++++++++++++++++++++++++++-------- src/netlink_linearize.c | 41 +++++++++++++++++----- src/parser_bison.y | 25 +++++++++---- src/statement.c | 34 +++++++++++++----- 7 files changed, 155 insertions(+), 48 deletions(-) diff --git a/include/statement.h b/include/statement.h index f2fc6ade7734..7637a82e4e00 100644 --- a/include/statement.h +++ b/include/statement.h @@ -201,7 +201,7 @@ uint32_t fwd_stmt_type(const char *type); struct set_stmt { struct expr *set; struct expr *key; - struct stmt *stmt; + struct list_head stmt_list; enum nft_dynset_ops op; }; @@ -213,7 +213,7 @@ struct map_stmt { struct expr *set; struct expr *key; struct expr *data; - struct stmt *stmt; + struct list_head stmt_list; enum nft_dynset_ops op; }; diff --git a/src/evaluate.c b/src/evaluate.c index e776cd018051..03f060eb465a 100644 --- a/src/evaluate.c +++ b/src/evaluate.c @@ -3370,6 +3370,8 @@ static int stmt_evaluate_log(struct eval_ctx *ctx, struct stmt *stmt) static int stmt_evaluate_set(struct eval_ctx *ctx, struct stmt *stmt) { + struct stmt *this; + expr_set_context(&ctx->ectx, NULL, 0); if (expr_evaluate(ctx, &stmt->set.set) < 0) return -1; @@ -3389,12 +3391,12 @@ static int stmt_evaluate_set(struct eval_ctx *ctx, struct stmt *stmt) if (stmt->set.key->comment != NULL) return expr_error(ctx->msgs, stmt->set.key, "Key expression comments are not supported"); - if (stmt->set.stmt) { - if (stmt_evaluate(ctx, stmt->set.stmt) < 0) + list_for_each_entry(this, &stmt->set.stmt_list, list) { + if (stmt_evaluate(ctx, this) < 0) return -1; - if (!(stmt->set.stmt->flags & STMT_F_STATEFUL)) - return stmt_binary_error(ctx, stmt->set.stmt, stmt, - "meter statement must be stateful"); + if (!(this->flags 
& STMT_F_STATEFUL)) + return stmt_error(ctx, this, + "statement must be stateful"); } return 0; @@ -3402,6 +3404,8 @@ static int stmt_evaluate_set(struct eval_ctx *ctx, struct stmt *stmt) static int stmt_evaluate_map(struct eval_ctx *ctx, struct stmt *stmt) { + struct stmt *this; + expr_set_context(&ctx->ectx, NULL, 0); if (expr_evaluate(ctx, &stmt->map.set) < 0) return -1; @@ -3435,12 +3439,12 @@ static int stmt_evaluate_map(struct eval_ctx *ctx, struct stmt *stmt) return expr_error(ctx->msgs, stmt->map.data, "Data expression comments are not supported"); - if (stmt->map.stmt) { - if (stmt_evaluate(ctx, stmt->map.stmt) < 0) + list_for_each_entry(this, &stmt->map.stmt_list, list) { + if (stmt_evaluate(ctx, this) < 0) return -1; - if (!(stmt->map.stmt->flags & STMT_F_STATEFUL)) - return stmt_binary_error(ctx, stmt->map.stmt, stmt, - "meter statement must be stateful"); + if (!(this->flags & STMT_F_STATEFUL)) + return stmt_error(ctx, this, + "statement must be stateful"); } return 0; diff --git a/src/netlink.c b/src/netlink.c index 8098b9746c95..ab0290926eaf 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -1132,6 +1132,7 @@ key_end: key = bitmask_expr_to_binops(key); expr = set_elem_expr_alloc(&netlink_location, key); + if (nftnl_set_elem_is_set(nlse, NFTNL_SET_ELEM_TIMEOUT)) expr->timeout = nftnl_set_elem_get_u64(nlse, NFTNL_SET_ELEM_TIMEOUT); if (nftnl_set_elem_is_set(nlse, NFTNL_SET_ELEM_EXPIRATION)) diff --git a/src/netlink_delinearize.c b/src/netlink_delinearize.c index 8b06c4c0985f..731507228411 100644 --- a/src/netlink_delinearize.c +++ b/src/netlink_delinearize.c @@ -1489,17 +1489,47 @@ static void netlink_parse_queue(struct netlink_parse_ctx *ctx, ctx->stmt = stmt; } +struct dynset_parse_ctx { + struct netlink_parse_ctx *nlctx; + const struct location *loc; + struct list_head stmt_list; +}; + +static int dynset_parse_expressions(struct nftnl_expr *e, void *data) +{ + struct dynset_parse_ctx *dynset_parse_ctx = data; + struct netlink_parse_ctx *ctx = 
dynset_parse_ctx->nlctx; + const struct location *loc = dynset_parse_ctx->loc; + struct stmt *stmt; + + if (netlink_parse_expr(e, ctx) < 0 || !ctx->stmt) { + netlink_error(ctx, loc, "Could not parse dynset stmt"); + return -1; + } + stmt = ctx->stmt; + + list_add_tail(&stmt->list, &dynset_parse_ctx->stmt_list); + + return 0; +} + static void netlink_parse_dynset(struct netlink_parse_ctx *ctx, const struct location *loc, const struct nftnl_expr *nle) { + struct dynset_parse_ctx dynset_parse_ctx = { + .nlctx = ctx, + .loc = loc, + }; struct expr *expr, *expr_data = NULL; enum nft_registers sreg, sreg_data; + struct stmt *stmt, *dstmt, *next; const struct nftnl_expr *dnle; - struct stmt *stmt, *dstmt; struct set *set; const char *name; + init_list_head(&dynset_parse_ctx.stmt_list); + name = nftnl_expr_get_str(nle, NFTNL_EXPR_DYNSET_SET_NAME); set = set_lookup(ctx->table, name); if (set == NULL) @@ -1523,16 +1553,25 @@ static void netlink_parse_dynset(struct netlink_parse_ctx *ctx, expr = set_elem_expr_alloc(&expr->location, expr); expr->timeout = nftnl_expr_get_u64(nle, NFTNL_EXPR_DYNSET_TIMEOUT); - dstmt = NULL; - dnle = nftnl_expr_get(nle, NFTNL_EXPR_DYNSET_EXPR, NULL); - if (dnle != NULL) { - if (netlink_parse_expr(dnle, ctx) < 0) - goto out_err; - if (ctx->stmt == NULL) { - netlink_error(ctx, loc, "Could not parse dynset stmt"); - goto out_err; + if (nftnl_expr_is_set(nle, NFTNL_EXPR_DYNSET_EXPR)) { + dstmt = NULL; + dnle = nftnl_expr_get(nle, NFTNL_EXPR_DYNSET_EXPR, NULL); + if (dnle != NULL) { + if (netlink_parse_expr(dnle, ctx) < 0) + goto out_err; + if (ctx->stmt == NULL) { + netlink_error(ctx, loc, + "Could not parse dynset stmt"); + goto out_err; + } + dstmt = ctx->stmt; + list_add_tail(&dstmt->list, + &dynset_parse_ctx.stmt_list); } - dstmt = ctx->stmt; + } else if (nftnl_expr_is_set(nle, NFTNL_EXPR_DYNSET_EXPRESSIONS)) { + if (nftnl_expr_expr_foreach(nle, dynset_parse_expressions, + &dynset_parse_ctx) < 0) + goto out_err; } if (nftnl_expr_is_set(nle, 
NFTNL_EXPR_DYNSET_SREG_DATA)) { @@ -1546,27 +1585,34 @@ static void netlink_parse_dynset(struct netlink_parse_ctx *ctx, stmt->map.set = set_ref_expr_alloc(loc, set); stmt->map.key = expr; stmt->map.data = expr_data; - stmt->map.stmt = dstmt; stmt->map.op = nftnl_expr_get_u32(nle, NFTNL_EXPR_DYNSET_OP); + list_splice_tail(&dynset_parse_ctx.stmt_list, + &stmt->map.stmt_list); } else { - if (dstmt != NULL && set->flags & NFT_SET_ANONYMOUS) { + if (!list_empty(&dynset_parse_ctx.stmt_list) && + set->flags & NFT_SET_ANONYMOUS) { stmt = meter_stmt_alloc(loc); stmt->meter.set = set_ref_expr_alloc(loc, set); stmt->meter.key = expr; - stmt->meter.stmt = dstmt; + stmt->meter.stmt = list_first_entry(&dynset_parse_ctx.stmt_list, + struct stmt, list); stmt->meter.size = set->desc.size; } else { stmt = set_stmt_alloc(loc); stmt->set.set = set_ref_expr_alloc(loc, set); stmt->set.op = nftnl_expr_get_u32(nle, NFTNL_EXPR_DYNSET_OP); stmt->set.key = expr; - stmt->set.stmt = dstmt; + list_splice_tail(&dynset_parse_ctx.stmt_list, + &stmt->set.stmt_list); } } ctx->stmt = stmt; return; out_err: + list_for_each_entry_safe(dstmt, next, &dynset_parse_ctx.stmt_list, list) + stmt_free(dstmt); + xfree(expr); } diff --git a/src/netlink_linearize.c b/src/netlink_linearize.c index 05af8bb1b485..09d0c61cfcc0 100644 --- a/src/netlink_linearize.c +++ b/src/netlink_linearize.c @@ -1397,8 +1397,10 @@ static void netlink_gen_set_stmt(struct netlink_linearize_ctx *ctx, const struct stmt *stmt) { struct set *set = stmt->meter.set->set; - struct nftnl_expr *nle; enum nft_registers sreg_key; + struct nftnl_expr *nle; + int num_stmts = 0; + struct stmt *this; sreg_key = get_register(ctx, stmt->set.key->key); netlink_gen_expr(ctx, stmt->set.key->key, sreg_key); @@ -1414,9 +1416,20 @@ static void netlink_gen_set_stmt(struct netlink_linearize_ctx *ctx, nftnl_expr_set_u32(nle, NFTNL_EXPR_DYNSET_SET_ID, set->handle.set_id); nft_rule_add_expr(ctx, nle, &stmt->location); - if (stmt->set.stmt) - nftnl_expr_set(nle, 
NFTNL_EXPR_DYNSET_EXPR, - netlink_gen_stmt_stateful(stmt->set.stmt), 0); + list_for_each_entry(this, &stmt->set.stmt_list, list) + num_stmts++; + + if (num_stmts == 1) { + list_for_each_entry(this, &stmt->set.stmt_list, list) { + nftnl_expr_set(nle, NFTNL_EXPR_DYNSET_EXPR, + netlink_gen_stmt_stateful(this), 0); + } + } else if (num_stmts > 1) { + list_for_each_entry(this, &stmt->set.stmt_list, list) { + nftnl_expr_add_expr(nle, NFTNL_EXPR_DYNSET_EXPRESSIONS, + netlink_gen_stmt_stateful(this)); + } + } } static void netlink_gen_map_stmt(struct netlink_linearize_ctx *ctx, @@ -1426,6 +1439,8 @@ static void netlink_gen_map_stmt(struct netlink_linearize_ctx *ctx, enum nft_registers sreg_data; enum nft_registers sreg_key; struct nftnl_expr *nle; + int num_stmts = 0; + struct stmt *this; sreg_key = get_register(ctx, stmt->map.key); netlink_gen_expr(ctx, stmt->map.key, sreg_key); @@ -1443,12 +1458,22 @@ static void netlink_gen_map_stmt(struct netlink_linearize_ctx *ctx, nftnl_expr_set_u32(nle, NFTNL_EXPR_DYNSET_OP, stmt->map.op); nftnl_expr_set_str(nle, NFTNL_EXPR_DYNSET_SET_NAME, set->handle.set.name); nftnl_expr_set_u32(nle, NFTNL_EXPR_DYNSET_SET_ID, set->handle.set_id); + nft_rule_add_expr(ctx, nle, &stmt->location); - if (stmt->map.stmt) - nftnl_expr_set(nle, NFTNL_EXPR_DYNSET_EXPR, - netlink_gen_stmt_stateful(stmt->map.stmt), 0); + list_for_each_entry(this, &stmt->map.stmt_list, list) + num_stmts++; - nft_rule_add_expr(ctx, nle, &stmt->location); + if (num_stmts == 1) { + list_for_each_entry(this, &stmt->map.stmt_list, list) { + nftnl_expr_set(nle, NFTNL_EXPR_DYNSET_EXPR, + netlink_gen_stmt_stateful(this), 0); + } + } else if (num_stmts > 1) { + list_for_each_entry(this, &stmt->map.stmt_list, list) { + nftnl_expr_add_expr(nle, NFTNL_EXPR_DYNSET_EXPRESSIONS, + netlink_gen_stmt_stateful(this)); + } + } } static void netlink_gen_meter_stmt(struct netlink_linearize_ctx *ctx, diff --git a/src/parser_bison.y b/src/parser_bison.y index 08aadaa32a86..2582ca1d3a0c 100644 --- 
a/src/parser_bison.y +++ b/src/parser_bison.y @@ -624,8 +624,8 @@ int nft_lex(void *, void *, void *); %type <obj> obj_block_alloc counter_block quota_block ct_helper_block ct_timeout_block ct_expect_block limit_block secmark_block synproxy_block %destructor { obj_free($$); } obj_block_alloc -%type <list> stmt_list -%destructor { stmt_list_free($$); xfree($$); } stmt_list +%type <list> stmt_list stateful_stmt_list +%destructor { stmt_list_free($$); xfree($$); } stmt_list stateful_stmt_list %type <stmt> stmt match_stmt verdict_stmt %destructor { stmt_free($$); } stmt match_stmt verdict_stmt %type <stmt> counter_stmt counter_stmt_alloc stateful_stmt @@ -2656,6 +2656,19 @@ stmt_list : stmt } ; +stateful_stmt_list : stateful_stmt + { + $$ = xmalloc(sizeof(*$$)); + init_list_head($$); + list_add_tail(&$1->list, $$); + } + | stateful_stmt_list stateful_stmt + { + $$ = $1; + list_add_tail(&$2->list, $1); + } + ; + stateful_stmt : counter_stmt | limit_stmt | quota_stmt @@ -3675,13 +3688,13 @@ set_stmt : SET set_stmt_op set_elem_expr_stmt set_ref_expr $$->set.key = $4; $$->set.set = $2; } - | set_stmt_op set_ref_expr '{' set_elem_expr_stmt stateful_stmt '}' + | set_stmt_op set_ref_expr '{' set_elem_expr_stmt stateful_stmt_list '}' { $$ = set_stmt_alloc(&@$); $$->set.op = $1; $$->set.key = $4; $$->set.set = $2; - $$->set.stmt = $5; + list_splice_tail($5, &$$->set.stmt_list); } ; @@ -3698,14 +3711,14 @@ map_stmt : set_stmt_op set_ref_expr '{' set_elem_expr_stmt COLON set_elem_expr_ $$->map.data = $6; $$->map.set = $2; } - | set_stmt_op set_ref_expr '{' set_elem_expr_stmt stateful_stmt COLON set_elem_expr_stmt '}' + | set_stmt_op set_ref_expr '{' set_elem_expr_stmt stateful_stmt_list COLON set_elem_expr_stmt '}' { $$ = map_stmt_alloc(&@$); $$->map.op = $1; $$->map.key = $4; $$->map.data = $7; - $$->map.stmt = $5; $$->map.set = $2; + list_splice_tail($5, &$$->map.stmt_list); } ; diff --git a/src/statement.c b/src/statement.c index 6fe8e9d9beb4..39020857ae9c 100644 --- 
a/src/statement.c +++ b/src/statement.c @@ -732,15 +732,16 @@ const char * const set_stmt_op_names[] = { static void set_stmt_print(const struct stmt *stmt, struct output_ctx *octx) { unsigned int flags = octx->flags; + struct stmt *this; nft_print(octx, "%s ", set_stmt_op_names[stmt->set.op]); expr_print(stmt->set.set, octx); nft_print(octx, " { "); expr_print(stmt->set.key, octx); - if (stmt->set.stmt) { + list_for_each_entry(this, &stmt->set.stmt_list, list) { nft_print(octx, " "); octx->flags |= NFT_CTX_OUTPUT_STATELESS; - stmt_print(stmt->set.stmt, octx); + stmt_print(this, octx); octx->flags = flags; } nft_print(octx, " }"); @@ -748,9 +749,12 @@ static void set_stmt_print(const struct stmt *stmt, struct output_ctx *octx) static void set_stmt_destroy(struct stmt *stmt) { + struct stmt *this, *next; + expr_free(stmt->set.key); expr_free(stmt->set.set); - stmt_free(stmt->set.stmt); + list_for_each_entry_safe(this, next, &stmt->set.stmt_list, list) + stmt_free(this); } static const struct stmt_ops set_stmt_ops = { @@ -763,21 +767,27 @@ static const struct stmt_ops set_stmt_ops = { struct stmt *set_stmt_alloc(const struct location *loc) { - return stmt_alloc(loc, &set_stmt_ops); + struct stmt *stmt; + + stmt = stmt_alloc(loc, &set_stmt_ops); + init_list_head(&stmt->set.stmt_list); + + return stmt; } static void map_stmt_print(const struct stmt *stmt, struct output_ctx *octx) { unsigned int flags = octx->flags; + struct stmt *this; nft_print(octx, "%s ", set_stmt_op_names[stmt->map.op]); expr_print(stmt->map.set, octx); nft_print(octx, " { "); expr_print(stmt->map.key, octx); - if (stmt->map.stmt) { + list_for_each_entry(this, &stmt->map.stmt_list, list) { nft_print(octx, " "); octx->flags |= NFT_CTX_OUTPUT_STATELESS; - stmt_print(stmt->map.stmt, octx); + stmt_print(this, octx); octx->flags = flags; } nft_print(octx, " : "); @@ -787,10 +797,13 @@ static void map_stmt_print(const struct stmt *stmt, struct output_ctx *octx) static void map_stmt_destroy(struct stmt 
*stmt) { + struct stmt *this, *next; + expr_free(stmt->map.key); expr_free(stmt->map.data); expr_free(stmt->map.set); - stmt_free(stmt->map.stmt); + list_for_each_entry_safe(this, next, &stmt->map.stmt_list, list) + stmt_free(this); } static const struct stmt_ops map_stmt_ops = { @@ -802,7 +815,12 @@ static const struct stmt_ops map_stmt_ops = { struct stmt *map_stmt_alloc(const struct location *loc) { - return stmt_alloc(loc, &map_stmt_ops); + struct stmt *stmt; + + stmt = stmt_alloc(loc, &map_stmt_ops); + init_list_head(&stmt->map.stmt_list); + + return stmt; } static void dup_stmt_print(const struct stmt *stmt, struct output_ctx *octx) -- 2.20.1