[PATCH 6.1 080/313] l2tp: convert l2tp_tunnel_list to idr

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



From: Cong Wang <cong.wang@xxxxxxxxxxxxx>

[ Upstream commit c4d48a58f32c5972174a1d01c33b296fe378cce0 ]

l2tp uses l2tp_tunnel_list to track all registered tunnels and
to allocate tunnel IDs. IDR can do the same job.

More importantly, with IDR we can hold the ID before a successful
registration, so we don't need to worry about late error handling;
it is not easy to roll back socket changes.

This is a preparation for the following fix.

Cc: Tetsuo Handa <penguin-kernel@xxxxxxxxxxxxxxxxxxx>
Cc: Guillaume Nault <gnault@xxxxxxxxxx>
Cc: Jakub Sitnicki <jakub@xxxxxxxxxxxxxx>
Cc: Eric Dumazet <edumazet@xxxxxxxxxx>
Cc: Tom Parkin <tparkin@xxxxxxxxxxx>
Signed-off-by: Cong Wang <cong.wang@xxxxxxxxxxxxx>
Reviewed-by: Guillaume Nault <gnault@xxxxxxxxxx>
Signed-off-by: David S. Miller <davem@xxxxxxxxxxxxx>
Stable-dep-of: 0b2c59720e65 ("l2tp: close all race conditions in l2tp_tunnel_register()")
Signed-off-by: Sasha Levin <sashal@xxxxxxxxxx>
---
 net/l2tp/l2tp_core.c | 85 ++++++++++++++++++++++----------------------
 1 file changed, 42 insertions(+), 43 deletions(-)

diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c
index 9a1415fe3fa7..e9c0ce0b7972 100644
--- a/net/l2tp/l2tp_core.c
+++ b/net/l2tp/l2tp_core.c
@@ -104,9 +104,9 @@ static struct workqueue_struct *l2tp_wq;
 /* per-net private data for this module */
 static unsigned int l2tp_net_id;
 struct l2tp_net {
-	struct list_head l2tp_tunnel_list;
-	/* Lock for write access to l2tp_tunnel_list */
-	spinlock_t l2tp_tunnel_list_lock;
+	/* Lock for write access to l2tp_tunnel_idr */
+	spinlock_t l2tp_tunnel_idr_lock;
+	struct idr l2tp_tunnel_idr;
 	struct hlist_head l2tp_session_hlist[L2TP_HASH_SIZE_2];
 	/* Lock for write access to l2tp_session_hlist */
 	spinlock_t l2tp_session_hlist_lock;
@@ -208,13 +208,10 @@ struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id)
 	struct l2tp_tunnel *tunnel;
 
 	rcu_read_lock_bh();
-	list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
-		if (tunnel->tunnel_id == tunnel_id &&
-		    refcount_inc_not_zero(&tunnel->ref_count)) {
-			rcu_read_unlock_bh();
-
-			return tunnel;
-		}
+	tunnel = idr_find(&pn->l2tp_tunnel_idr, tunnel_id);
+	if (tunnel && refcount_inc_not_zero(&tunnel->ref_count)) {
+		rcu_read_unlock_bh();
+		return tunnel;
 	}
 	rcu_read_unlock_bh();
 
@@ -224,13 +221,14 @@ EXPORT_SYMBOL_GPL(l2tp_tunnel_get);
 
 struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth)
 {
-	const struct l2tp_net *pn = l2tp_pernet(net);
+	struct l2tp_net *pn = l2tp_pernet(net);
+	unsigned long tunnel_id, tmp;
 	struct l2tp_tunnel *tunnel;
 	int count = 0;
 
 	rcu_read_lock_bh();
-	list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
-		if (++count > nth &&
+	idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) {
+		if (tunnel && ++count > nth &&
 		    refcount_inc_not_zero(&tunnel->ref_count)) {
 			rcu_read_unlock_bh();
 			return tunnel;
@@ -1227,6 +1225,15 @@ static void l2tp_udp_encap_destroy(struct sock *sk)
 		l2tp_tunnel_delete(tunnel);
 }
 
+static void l2tp_tunnel_remove(struct net *net, struct l2tp_tunnel *tunnel)
+{
+	struct l2tp_net *pn = l2tp_pernet(net);
+
+	spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
+	idr_remove(&pn->l2tp_tunnel_idr, tunnel->tunnel_id);
+	spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
+}
+
 /* Workqueue tunnel deletion function */
 static void l2tp_tunnel_del_work(struct work_struct *work)
 {
@@ -1234,7 +1241,6 @@ static void l2tp_tunnel_del_work(struct work_struct *work)
 						  del_work);
 	struct sock *sk = tunnel->sock;
 	struct socket *sock = sk->sk_socket;
-	struct l2tp_net *pn;
 
 	l2tp_tunnel_closeall(tunnel);
 
@@ -1248,12 +1254,7 @@ static void l2tp_tunnel_del_work(struct work_struct *work)
 		}
 	}
 
-	/* Remove the tunnel struct from the tunnel list */
-	pn = l2tp_pernet(tunnel->l2tp_net);
-	spin_lock_bh(&pn->l2tp_tunnel_list_lock);
-	list_del_rcu(&tunnel->list);
-	spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
-
+	l2tp_tunnel_remove(tunnel->l2tp_net, tunnel);
 	/* drop initial ref */
 	l2tp_tunnel_dec_refcount(tunnel);
 
@@ -1455,12 +1456,19 @@ static int l2tp_validate_socket(const struct sock *sk, const struct net *net,
 int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
 			 struct l2tp_tunnel_cfg *cfg)
 {
-	struct l2tp_tunnel *tunnel_walk;
-	struct l2tp_net *pn;
+	struct l2tp_net *pn = l2tp_pernet(net);
+	u32 tunnel_id = tunnel->tunnel_id;
 	struct socket *sock;
 	struct sock *sk;
 	int ret;
 
+	spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
+	ret = idr_alloc_u32(&pn->l2tp_tunnel_idr, NULL, &tunnel_id, tunnel_id,
+			    GFP_ATOMIC);
+	spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
+	if (ret)
+		return ret == -ENOSPC ? -EEXIST : ret;
+
 	if (tunnel->fd < 0) {
 		ret = l2tp_tunnel_sock_create(net, tunnel->tunnel_id,
 					      tunnel->peer_tunnel_id, cfg,
@@ -1481,23 +1489,13 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
 	rcu_assign_sk_user_data(sk, tunnel);
 	write_unlock_bh(&sk->sk_callback_lock);
 
-	tunnel->l2tp_net = net;
-	pn = l2tp_pernet(net);
-
 	sock_hold(sk);
 	tunnel->sock = sk;
+	tunnel->l2tp_net = net;
 
-	spin_lock_bh(&pn->l2tp_tunnel_list_lock);
-	list_for_each_entry(tunnel_walk, &pn->l2tp_tunnel_list, list) {
-		if (tunnel_walk->tunnel_id == tunnel->tunnel_id) {
-			spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
-			sock_put(sk);
-			ret = -EEXIST;
-			goto err_sock;
-		}
-	}
-	list_add_rcu(&tunnel->list, &pn->l2tp_tunnel_list);
-	spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
+	spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
+	idr_replace(&pn->l2tp_tunnel_idr, tunnel, tunnel->tunnel_id);
+	spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
 
 	if (tunnel->encap == L2TP_ENCAPTYPE_UDP) {
 		struct udp_tunnel_sock_cfg udp_cfg = {
@@ -1523,9 +1521,6 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
 
 	return 0;
 
-err_sock:
-	write_lock_bh(&sk->sk_callback_lock);
-	rcu_assign_sk_user_data(sk, NULL);
 err_inval_sock:
 	write_unlock_bh(&sk->sk_callback_lock);
 
@@ -1534,6 +1529,7 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
 	else
 		sockfd_put(sock);
 err:
+	l2tp_tunnel_remove(net, tunnel);
 	return ret;
 }
 EXPORT_SYMBOL_GPL(l2tp_tunnel_register);
@@ -1647,8 +1643,8 @@ static __net_init int l2tp_init_net(struct net *net)
 	struct l2tp_net *pn = net_generic(net, l2tp_net_id);
 	int hash;
 
-	INIT_LIST_HEAD(&pn->l2tp_tunnel_list);
-	spin_lock_init(&pn->l2tp_tunnel_list_lock);
+	idr_init(&pn->l2tp_tunnel_idr);
+	spin_lock_init(&pn->l2tp_tunnel_idr_lock);
 
 	for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++)
 		INIT_HLIST_HEAD(&pn->l2tp_session_hlist[hash]);
@@ -1662,11 +1658,13 @@ static __net_exit void l2tp_exit_net(struct net *net)
 {
 	struct l2tp_net *pn = l2tp_pernet(net);
 	struct l2tp_tunnel *tunnel = NULL;
+	unsigned long tunnel_id, tmp;
 	int hash;
 
 	rcu_read_lock_bh();
-	list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
-		l2tp_tunnel_delete(tunnel);
+	idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) {
+		if (tunnel)
+			l2tp_tunnel_delete(tunnel);
 	}
 	rcu_read_unlock_bh();
 
@@ -1676,6 +1674,7 @@ static __net_exit void l2tp_exit_net(struct net *net)
 
 	for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++)
 		WARN_ON_ONCE(!hlist_empty(&pn->l2tp_session_hlist[hash]));
+	idr_destroy(&pn->l2tp_tunnel_idr);
 }
 
 static struct pernet_operations l2tp_net_ops = {
-- 
2.39.0






[Index of Archives]     [Linux Kernel]     [Kernel Development Newbies]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Yosemite Hiking]     [Linux Kernel]     [Linux SCSI]

  Powered by Linux