4.11-stable review patch. If anyone has any objections, please let me know. ------------------ From: Paul Moore <paul@xxxxxxxxxxxxxx> commit 48d0e023af9799cd7220335baf8e3ba61eeafbeb upstream. Cong Wang correctly pointed out that the RCU read locking of the auditd_connection struct was wrong; this patch corrects this by adopting a more traditional, and correct RCU locking model. This patch is heavily based on an earlier prototype by Cong Wang. Reported-by: Cong Wang <xiyou.wangcong@xxxxxxxxx> Signed-off-by: Cong Wang <xiyou.wangcong@xxxxxxxxx> Signed-off-by: Paul Moore <paul@xxxxxxxxxxxxxx> Signed-off-by: Greg Kroah-Hartman <gregkh@xxxxxxxxxxxxxxxxxxx> --- kernel/audit.c | 167 +++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 115 insertions(+), 52 deletions(-) --- a/kernel/audit.c +++ b/kernel/audit.c @@ -110,18 +110,19 @@ struct audit_net { * @pid: auditd PID * @portid: netlink portid * @net: the associated network namespace - * @lock: spinlock to protect write access + * @rcu: RCU head * * Description: * This struct is RCU protected; you must either hold the RCU lock for reading - * or the included spinlock for writing. + * or the associated spinlock for writing. */ static struct auditd_connection { int pid; u32 portid; struct net *net; - spinlock_t lock; -} auditd_conn; + struct rcu_head rcu; +} *auditd_conn = NULL; +static DEFINE_SPINLOCK(auditd_conn_lock); /* If audit_rate_limit is non-zero, limit the rate of sending audit records * to that number per second. This prevents DoS attacks, but results in @@ -223,15 +224,39 @@ struct audit_reply { int auditd_test_task(const struct task_struct *task) { int rc; + struct auditd_connection *ac; rcu_read_lock(); - rc = (auditd_conn.pid && task->tgid == auditd_conn.pid ? 1 : 0); + ac = rcu_dereference(auditd_conn); + rc = (ac && ac->pid == task->tgid ?
1 : 0); rcu_read_unlock(); return rc; } /** + * auditd_pid_vnr - Return the auditd PID relative to the namespace + * + * Description: + * Returns the PID in relation to the namespace, 0 on failure. + */ +static pid_t auditd_pid_vnr(void) +{ + pid_t pid; + const struct auditd_connection *ac; + + rcu_read_lock(); + ac = rcu_dereference(auditd_conn); + if (!ac) + pid = 0; + else + pid = ac->pid; + rcu_read_unlock(); + + return pid; +} + +/** * audit_get_sk - Return the audit socket for the given network namespace * @net: the destination network namespace * @@ -427,6 +452,23 @@ static int audit_set_failure(u32 state) } /** + * auditd_conn_free - RCU helper to release an auditd connection struct + * @rcu: RCU head + * + * Description: + * Drop any references inside the auditd connection tracking struct and free + * the memory. + */ +static void auditd_conn_free(struct rcu_head *rcu) +{ + struct auditd_connection *ac; + + ac = container_of(rcu, struct auditd_connection, rcu); + put_net(ac->net); + kfree(ac); +} + +/** * auditd_set - Set/Reset the auditd connection state * @pid: auditd PID * @portid: auditd netlink portid @@ -434,22 +476,33 @@ static int audit_set_failure(u32 state) * * Description: * This function will obtain and drop network namespace references as - * necessary. + * necessary. Returns zero on success, negative values on failure. 
*/ -static void auditd_set(int pid, u32 portid, struct net *net) +static int auditd_set(int pid, u32 portid, struct net *net) { unsigned long flags; + struct auditd_connection *ac_old, *ac_new; - spin_lock_irqsave(&auditd_conn.lock, flags); - auditd_conn.pid = pid; - auditd_conn.portid = portid; - if (auditd_conn.net) - put_net(auditd_conn.net); - if (net) - auditd_conn.net = get_net(net); - else - auditd_conn.net = NULL; - spin_unlock_irqrestore(&auditd_conn.lock, flags); + if (!pid || !net) + return -EINVAL; + + ac_new = kzalloc(sizeof(*ac_new), GFP_KERNEL); + if (!ac_new) + return -ENOMEM; + ac_new->pid = pid; + ac_new->portid = portid; + ac_new->net = get_net(net); + + spin_lock_irqsave(&auditd_conn_lock, flags); + ac_old = rcu_dereference_protected(auditd_conn, + lockdep_is_held(&auditd_conn_lock)); + rcu_assign_pointer(auditd_conn, ac_new); + spin_unlock_irqrestore(&auditd_conn_lock, flags); + + if (ac_old) + call_rcu(&ac_old->rcu, auditd_conn_free); + + return 0; } /** @@ -544,13 +597,19 @@ static void kauditd_retry_skb(struct sk_ */ static void auditd_reset(void) { + unsigned long flags; struct sk_buff *skb; + struct auditd_connection *ac_old; /* if it isn't already broken, break the connection */ - rcu_read_lock(); - if (auditd_conn.pid) - auditd_set(0, 0, NULL); - rcu_read_unlock(); + spin_lock_irqsave(&auditd_conn_lock, flags); + ac_old = rcu_dereference_protected(auditd_conn, + lockdep_is_held(&auditd_conn_lock)); + rcu_assign_pointer(auditd_conn, NULL); + spin_unlock_irqrestore(&auditd_conn_lock, flags); + + if (ac_old) + call_rcu(&ac_old->rcu, auditd_conn_free); /* flush all of the main and retry queues to the hold queue */ while ((skb = skb_dequeue(&audit_retry_queue))) @@ -576,6 +635,7 @@ static int auditd_send_unicast_skb(struc u32 portid; struct net *net; struct sock *sk; + struct auditd_connection *ac; /* NOTE: we can't call netlink_unicast while in the RCU section so * take a reference to the network namespace and grab local @@ -585,15 +645,15 
@@ static int auditd_send_unicast_skb(struc * section netlink_unicast() should safely return an error */ rcu_read_lock(); - if (!auditd_conn.pid) { + ac = rcu_dereference(auditd_conn); + if (!ac) { rcu_read_unlock(); rc = -ECONNREFUSED; goto err; } - net = auditd_conn.net; - get_net(net); + net = get_net(ac->net); sk = audit_get_sk(net); - portid = auditd_conn.portid; + portid = ac->portid; rcu_read_unlock(); rc = netlink_unicast(sk, skb, portid, 0); @@ -728,6 +788,7 @@ static int kauditd_thread(void *dummy) u32 portid = 0; struct net *net = NULL; struct sock *sk = NULL; + struct auditd_connection *ac; #define UNICAST_RETRIES 5 @@ -735,14 +796,14 @@ static int kauditd_thread(void *dummy) while (!kthread_should_stop()) { /* NOTE: see the lock comments in auditd_send_unicast_skb() */ rcu_read_lock(); - if (!auditd_conn.pid) { + ac = rcu_dereference(auditd_conn); + if (!ac) { rcu_read_unlock(); goto main_queue; } - net = auditd_conn.net; - get_net(net); + net = get_net(ac->net); sk = audit_get_sk(net); - portid = auditd_conn.portid; + portid = ac->portid; rcu_read_unlock(); /* attempt to flush the hold queue */ @@ -1102,9 +1163,7 @@ static int audit_receive_msg(struct sk_b memset(&s, 0, sizeof(s)); s.enabled = audit_enabled; s.failure = audit_failure; - rcu_read_lock(); - s.pid = auditd_conn.pid; - rcu_read_unlock(); + s.pid = auditd_pid_vnr(); s.rate_limit = audit_rate_limit; s.backlog_limit = audit_backlog_limit; s.lost = atomic_read(&audit_lost); @@ -1143,38 +1202,44 @@ static int audit_receive_msg(struct sk_b /* test the auditd connection */ audit_replace(requesting_pid); - rcu_read_lock(); - auditd_pid = auditd_conn.pid; + auditd_pid = auditd_pid_vnr(); /* only the current auditd can unregister itself */ if ((!new_pid) && (requesting_pid != auditd_pid)) { - rcu_read_unlock(); audit_log_config_change("audit_pid", new_pid, auditd_pid, 0); return -EACCES; } /* replacing a healthy auditd is not allowed */ if (auditd_pid && new_pid) { - rcu_read_unlock(); 
audit_log_config_change("audit_pid", new_pid, auditd_pid, 0); return -EEXIST; } - rcu_read_unlock(); - - if (audit_enabled != AUDIT_OFF) - audit_log_config_change("audit_pid", new_pid, - auditd_pid, 1); if (new_pid) { /* register a new auditd connection */ - auditd_set(new_pid, - NETLINK_CB(skb).portid, - sock_net(NETLINK_CB(skb).sk)); + err = auditd_set(new_pid, + NETLINK_CB(skb).portid, + sock_net(NETLINK_CB(skb).sk)); + if (audit_enabled != AUDIT_OFF) + audit_log_config_change("audit_pid", + new_pid, + auditd_pid, + err ? 0 : 1); + if (err) + return err; + /* try to process any backlog */ wake_up_interruptible(&kauditd_wait); - } else + } else { + if (audit_enabled != AUDIT_OFF) + audit_log_config_change("audit_pid", + new_pid, + auditd_pid, 1); + /* unregister the auditd connection */ auditd_reset(); + } } if (s.mask & AUDIT_STATUS_RATE_LIMIT) { err = audit_set_rate_limit(s.rate_limit); @@ -1447,10 +1512,11 @@ static void __net_exit audit_net_exit(st { struct audit_net *aunet = net_generic(net, audit_net_id); - rcu_read_lock(); - if (net == auditd_conn.net) - auditd_reset(); - rcu_read_unlock(); + /* NOTE: you would think that we would want to check the auditd + * connection and potentially reset it here if it lives in this + * namespace, but since the auditd connection tracking struct holds a + * reference to this namespace (see auditd_set()) we are only ever + * going to get here after that connection has been released */ netlink_kernel_release(aunet->sk); } @@ -1470,9 +1536,6 @@ static int __init audit_init(void) if (audit_initialized == AUDIT_DISABLED) return 0; - memset(&auditd_conn, 0, sizeof(auditd_conn)); - spin_lock_init(&auditd_conn.lock); - skb_queue_head_init(&audit_queue); skb_queue_head_init(&audit_retry_queue); skb_queue_head_init(&audit_hold_queue);