[PATCH bpf-next v1 12/15] net/netfilter: Add bpf_ct_kptr_get helper

Requesting some more feedback on whether this approach is OK before
refactoring the netfilter functions to share the code that takes the
reference and matches the tuple. It also probably needs work to allow
taking a reference to a struct net *, to save another lookup inside
this function.
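
A rough sketch of the intended usage from a BPF program follows; the
map layout, the __kptr_ref tag, and the extern declarations are my
assumptions based on the kptr map support this builds on, and are
illustrative rather than part of this patch:

  #include "vmlinux.h"
  #include <bpf/bpf_helpers.h>

  /* Referenced-kptr tag as introduced by the kptr series. */
  #define __kptr_ref __attribute__((btf_type_tag("kptr_ref")))

  struct map_value {
  	struct nf_conn __kptr_ref *ct;
  };

  struct {
  	__uint(type, BPF_MAP_TYPE_ARRAY);
  	__uint(max_entries, 1);
  	__type(key, int);
  	__type(value, struct map_value);
  } ct_map SEC(".maps");

  extern struct nf_conn *bpf_ct_kptr_get(struct nf_conn **ptr,
  					struct bpf_sock_tuple *bpf_tuple,
  					u32 tuple__sz, u8 protonum,
  					u8 direction) __ksym;
  extern void bpf_ct_release(struct nf_conn *ct) __ksym;

  SEC("xdp")
  int use_saved_ct(struct xdp_md *ctx)
  {
  	struct bpf_sock_tuple tup = {};
  	struct map_value *v;
  	struct nf_conn *ct;
  	int key = 0;

  	v = bpf_map_lookup_elem(&ct_map, &key);
  	if (!v)
  		return XDP_PASS;

  	/* Fill tup.ipv4 from the packet here; elided for brevity.
  	 * bpf_ct_kptr_get takes a verified reference to the kptr
  	 * stashed in the map, re-matching the tuple in case the
  	 * conntrack entry was recreated meanwhile.
  	 */
  	ct = bpf_ct_kptr_get(&v->ct, &tup, sizeof(tup.ipv4),
  			     IPPROTO_TCP, IP_CT_DIR_ORIGINAL);
  	if (!ct)
  		return XDP_PASS;
  	/* ... inspect ct ... */
  	bpf_ct_release(ct);
  	return XDP_PASS;
  }

  char _license[] SEC("license") = "GPL";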

Signed-off-by: Kumar Kartikeya Dwivedi <memxor@xxxxxxxxx>
---
 include/net/netfilter/nf_conntrack_core.h |  17 +++
 net/netfilter/nf_conntrack_bpf.c          | 132 +++++++++++++++++-----
 net/netfilter/nf_conntrack_core.c         |  17 ---
 3 files changed, 119 insertions(+), 47 deletions(-)

diff --git a/include/net/netfilter/nf_conntrack_core.h b/include/net/netfilter/nf_conntrack_core.h
index 13807ea94cd2..09389769dce3 100644
--- a/include/net/netfilter/nf_conntrack_core.h
+++ b/include/net/netfilter/nf_conntrack_core.h
@@ -51,6 +51,23 @@ nf_conntrack_find_get(struct net *net,
 
 int __nf_conntrack_confirm(struct sk_buff *skb);
 
+static inline bool
+nf_ct_key_equal(struct nf_conntrack_tuple_hash *h,
+		const struct nf_conntrack_tuple *tuple,
+		const struct nf_conntrack_zone *zone,
+		const struct net *net)
+{
+	struct nf_conn *ct = nf_ct_tuplehash_to_ctrack(h);
+
+	/* A conntrack can be recreated with the equal tuple,
+	 * so we need to check that the conntrack is confirmed
+	 */
+	return nf_ct_tuple_equal(tuple, &h->tuple) &&
+	       nf_ct_zone_equal(ct, zone, NF_CT_DIRECTION(h)) &&
+	       nf_ct_is_confirmed(ct) &&
+	       net_eq(net, nf_ct_net(ct));
+}
+
 /* Confirm a connection: returns NF_DROP if packet must be dropped. */
 static inline int nf_conntrack_confirm(struct sk_buff *skb)
 {
diff --git a/net/netfilter/nf_conntrack_bpf.c b/net/netfilter/nf_conntrack_bpf.c
index 8ad3f52579f3..26211a5ec0c4 100644
--- a/net/netfilter/nf_conntrack_bpf.c
+++ b/net/netfilter/nf_conntrack_bpf.c
@@ -52,6 +52,30 @@ enum {
 	NF_BPF_CT_OPTS_SZ = 12,
 };
 
+static int bpf_fill_nf_tuple(struct nf_conntrack_tuple *tuple,
+			     struct bpf_sock_tuple *bpf_tuple, u32 tuple_len)
+{
+	switch (tuple_len) {
+	case sizeof(bpf_tuple->ipv4):
+		tuple->src.l3num = AF_INET;
+		tuple->src.u3.ip = bpf_tuple->ipv4.saddr;
+		tuple->src.u.tcp.port = bpf_tuple->ipv4.sport;
+		tuple->dst.u3.ip = bpf_tuple->ipv4.daddr;
+		tuple->dst.u.tcp.port = bpf_tuple->ipv4.dport;
+		break;
+	case sizeof(bpf_tuple->ipv6):
+		tuple->src.l3num = AF_INET6;
+		memcpy(tuple->src.u3.ip6, bpf_tuple->ipv6.saddr, sizeof(bpf_tuple->ipv6.saddr));
+		tuple->src.u.tcp.port = bpf_tuple->ipv6.sport;
+		memcpy(tuple->dst.u3.ip6, bpf_tuple->ipv6.daddr, sizeof(bpf_tuple->ipv6.daddr));
+		tuple->dst.u.tcp.port = bpf_tuple->ipv6.dport;
+		break;
+	default:
+		return -EAFNOSUPPORT;
+	}
+	return 0;
+}
+
 static struct nf_conn *__bpf_nf_ct_lookup(struct net *net,
 					  struct bpf_sock_tuple *bpf_tuple,
 					  u32 tuple_len, u8 protonum,
@@ -59,6 +83,7 @@ static struct nf_conn *__bpf_nf_ct_lookup(struct net *net,
 {
 	struct nf_conntrack_tuple_hash *hash;
 	struct nf_conntrack_tuple tuple;
+	int ret;
 
 	if (unlikely(protonum != IPPROTO_TCP && protonum != IPPROTO_UDP))
 		return ERR_PTR(-EPROTO);
@@ -66,25 +91,9 @@ static struct nf_conn *__bpf_nf_ct_lookup(struct net *net,
 		return ERR_PTR(-EINVAL);
 
 	memset(&tuple, 0, sizeof(tuple));
-	switch (tuple_len) {
-	case sizeof(bpf_tuple->ipv4):
-		tuple.src.l3num = AF_INET;
-		tuple.src.u3.ip = bpf_tuple->ipv4.saddr;
-		tuple.src.u.tcp.port = bpf_tuple->ipv4.sport;
-		tuple.dst.u3.ip = bpf_tuple->ipv4.daddr;
-		tuple.dst.u.tcp.port = bpf_tuple->ipv4.dport;
-		break;
-	case sizeof(bpf_tuple->ipv6):
-		tuple.src.l3num = AF_INET6;
-		memcpy(tuple.src.u3.ip6, bpf_tuple->ipv6.saddr, sizeof(bpf_tuple->ipv6.saddr));
-		tuple.src.u.tcp.port = bpf_tuple->ipv6.sport;
-		memcpy(tuple.dst.u3.ip6, bpf_tuple->ipv6.daddr, sizeof(bpf_tuple->ipv6.daddr));
-		tuple.dst.u.tcp.port = bpf_tuple->ipv6.dport;
-		break;
-	default:
-		return ERR_PTR(-EAFNOSUPPORT);
-	}
-
+	ret = bpf_fill_nf_tuple(&tuple, bpf_tuple, tuple_len);
+	if (ret < 0)
+		return ERR_PTR(ret);
 	tuple.dst.protonum = protonum;
 
 	if (netns_id >= 0) {
@@ -208,50 +217,113 @@ void bpf_ct_release(struct nf_conn *nfct)
 	nf_ct_put(nfct);
 }
 
+/* TODO: just a PoC; this should reuse the lookup code in __nf_conntrack_find_get */
+struct nf_conn *bpf_ct_kptr_get(struct nf_conn **ptr, struct bpf_sock_tuple *bpf_tuple,
+				u32 tuple__sz, u8 protonum, u8 direction)
+{
+	struct nf_conntrack_tuple tuple;
+	struct nf_conn *nfct;
+	struct net *net;
+	u64 *nfct_p;
+	int ret;
+
+	WARN_ON_ONCE(!rcu_read_lock_held());
+
+	if ((protonum != IPPROTO_TCP && protonum != IPPROTO_UDP) ||
+	    (direction != IP_CT_DIR_ORIGINAL && direction != IP_CT_DIR_REPLY))
+		return NULL;
+
+	/* ptr is actually a pointer to a u64 holding the address, hence cast
+	 * the u64 load back to native pointer width.
+	 */
+	nfct_p = (u64 *)ptr;
+	nfct = (struct nf_conn *)READ_ONCE(*nfct_p);
+	if (!nfct || unlikely(!refcount_inc_not_zero(&nfct->ct_general.use)))
+		return NULL;
+
+	memset(&tuple, 0, sizeof(tuple));
+	ret = bpf_fill_nf_tuple(&tuple, bpf_tuple, tuple__sz);
+	if (ret < 0)
+		goto end;
+	tuple.dst.protonum = protonum;
+
+	/* XXX: Need to allow passing in a struct net *, or take a netns_id; this is nonsense */
+	net = nf_ct_net(nfct);
+	if (!nf_ct_key_equal(&nfct->tuplehash[direction], &tuple,
+			     &nf_ct_zone_dflt, net))
+		goto end;
+	return nfct;
+end:
+	nf_ct_put(nfct);
+	return NULL;
+}
+
 __diag_pop()
 
 BTF_SET_START(nf_ct_xdp_check_kfunc_ids)
 BTF_ID(func, bpf_xdp_ct_lookup)
+BTF_ID(func, bpf_ct_kptr_get)
 BTF_ID(func, bpf_ct_release)
 BTF_SET_END(nf_ct_xdp_check_kfunc_ids)
 
 BTF_SET_START(nf_ct_tc_check_kfunc_ids)
 BTF_ID(func, bpf_skb_ct_lookup)
+BTF_ID(func, bpf_ct_kptr_get)
 BTF_ID(func, bpf_ct_release)
 BTF_SET_END(nf_ct_tc_check_kfunc_ids)
 
 BTF_SET_START(nf_ct_acquire_kfunc_ids)
 BTF_ID(func, bpf_xdp_ct_lookup)
 BTF_ID(func, bpf_skb_ct_lookup)
+BTF_ID(func, bpf_ct_kptr_get)
 BTF_SET_END(nf_ct_acquire_kfunc_ids)
 
 BTF_SET_START(nf_ct_release_kfunc_ids)
 BTF_ID(func, bpf_ct_release)
 BTF_SET_END(nf_ct_release_kfunc_ids)
 
+BTF_SET_START(nf_ct_kptr_acquire_kfunc_ids)
+BTF_ID(func, bpf_ct_kptr_get)
+BTF_SET_END(nf_ct_kptr_acquire_kfunc_ids)
+
 /* Both sets are identical */
 #define nf_ct_ret_null_kfunc_ids nf_ct_acquire_kfunc_ids
 
 static const struct btf_kfunc_id_set nf_conntrack_xdp_kfunc_set = {
-	.owner        = THIS_MODULE,
-	.check_set    = &nf_ct_xdp_check_kfunc_ids,
-	.acquire_set  = &nf_ct_acquire_kfunc_ids,
-	.release_set  = &nf_ct_release_kfunc_ids,
-	.ret_null_set = &nf_ct_ret_null_kfunc_ids,
+	.owner            = THIS_MODULE,
+	.check_set        = &nf_ct_xdp_check_kfunc_ids,
+	.acquire_set      = &nf_ct_acquire_kfunc_ids,
+	.release_set      = &nf_ct_release_kfunc_ids,
+	.ret_null_set     = &nf_ct_ret_null_kfunc_ids,
+	.kptr_acquire_set = &nf_ct_kptr_acquire_kfunc_ids,
 };
 
 static const struct btf_kfunc_id_set nf_conntrack_tc_kfunc_set = {
-	.owner        = THIS_MODULE,
-	.check_set    = &nf_ct_tc_check_kfunc_ids,
-	.acquire_set  = &nf_ct_acquire_kfunc_ids,
-	.release_set  = &nf_ct_release_kfunc_ids,
-	.ret_null_set = &nf_ct_ret_null_kfunc_ids,
+	.owner            = THIS_MODULE,
+	.check_set        = &nf_ct_tc_check_kfunc_ids,
+	.acquire_set      = &nf_ct_acquire_kfunc_ids,
+	.release_set      = &nf_ct_release_kfunc_ids,
+	.ret_null_set     = &nf_ct_ret_null_kfunc_ids,
+	.kptr_acquire_set = &nf_ct_kptr_acquire_kfunc_ids,
 };
 
+BTF_ID_LIST(nf_conntrack_dtor_kfunc_ids)
+BTF_ID(struct, nf_conn)
+BTF_ID(func, bpf_ct_release)
+
 int register_nf_conntrack_bpf(void)
 {
+	const struct btf_id_dtor_kfunc nf_conntrack_dtor_kfunc[] = {
+		{
+			.btf_id       = nf_conntrack_dtor_kfunc_ids[0],
+			.kfunc_btf_id = nf_conntrack_dtor_kfunc_ids[1],
+		}
+	};
 	int ret;
 
-	ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_XDP, &nf_conntrack_xdp_kfunc_set);
+	ret = register_btf_id_dtor_kfuncs(nf_conntrack_dtor_kfunc,
+					  ARRAY_SIZE(nf_conntrack_dtor_kfunc),
+					  THIS_MODULE);
+	ret = ret ?: register_btf_kfunc_id_set(BPF_PROG_TYPE_XDP, &nf_conntrack_xdp_kfunc_set);
 	return ret ?: register_btf_kfunc_id_set(BPF_PROG_TYPE_SCHED_CLS, &nf_conntrack_tc_kfunc_set);
 }
diff --git a/net/netfilter/nf_conntrack_core.c b/net/netfilter/nf_conntrack_core.c
index 9b7f9c966f73..0aae98f60769 100644
--- a/net/netfilter/nf_conntrack_core.c
+++ b/net/netfilter/nf_conntrack_core.c
@@ -710,23 +710,6 @@ bool nf_ct_delete(struct nf_conn *ct, u32 portid, int report)
 }
 EXPORT_SYMBOL_GPL(nf_ct_delete);
 
-static inline bool
-nf_ct_key_equal(struct nf_conntrack_tuple_hash *h,
-		const struct nf_conntrack_tuple *tuple,
-		const struct nf_conntrack_zone *zone,
-		const struct net *net)
-{
-	struct nf_conn *ct = nf_ct_tuplehash_to_ctrack(h);
-
-	/* A conntrack can be recreated with the equal tuple,
-	 * so we need to check that the conntrack is confirmed
-	 */
-	return nf_ct_tuple_equal(tuple, &h->tuple) &&
-	       nf_ct_zone_equal(ct, zone, NF_CT_DIRECTION(h)) &&
-	       nf_ct_is_confirmed(ct) &&
-	       net_eq(net, nf_ct_net(ct));
-}
-
 static inline bool
 nf_ct_match(const struct nf_conn *ct1, const struct nf_conn *ct2)
 {
-- 
2.35.1