From: "D. Wythe" <alibuda@xxxxxxxxxxxxxxxxx>

Report SMC sockets that were created through the inet SMC protocols
(smc_inet_prot / smc_inet6_prot) via SOCK_DIAG. Such sockets carry their
own addressing, and they are found by walking the inet listening and
established hash tables of those protos, matching on sk_prot. Sockets
found in the listening hash are reported as SMC_LISTEN. A separate
resume position (inet_pos) is kept per address family so that multi-part
dumps continue where they stopped.

---
 net/smc/smc_diag.c | 164 +++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 146 insertions(+), 18 deletions(-)

diff --git a/net/smc/smc_diag.c b/net/smc/smc_diag.c
index 59a18ec..20532e1 100644
--- a/net/smc/smc_diag.c
+++ b/net/smc/smc_diag.c
@@ -22,9 +22,11 @@
 #include "smc.h"
 #include "smc_core.h"
 #include "smc_ism.h"
+#include "smc_inet.h"
 
 struct smc_diag_dump_ctx {
 	int pos[2];
+	int inet_pos[2];
 };
 
 static struct smc_diag_dump_ctx *smc_dump_context(struct netlink_callback *cb)
@@ -35,24 +37,46 @@ static struct smc_diag_dump_ctx *smc_dump_context(struct netlink_callback *cb)
 static void smc_diag_msg_common_fill(struct smc_diag_msg *r, struct sock *sk)
 {
 	struct smc_sock *smc = smc_sk(sk);
+	struct sock *clcsk;
+	bool is_v4, is_v6;
 
 	memset(r, 0, sizeof(*r));
 	r->diag_family = sk->sk_family;
 	sock_diag_save_cookie(sk, r->id.idiag_cookie);
-	if (!smc->clcsock)
-		return;
-	r->id.idiag_sport = htons(smc->clcsock->sk->sk_num);
-	r->id.idiag_dport = smc->clcsock->sk->sk_dport;
-	r->id.idiag_if = smc->clcsock->sk->sk_bound_dev_if;
-	if (sk->sk_protocol == SMCPROTO_SMC) {
-		r->id.idiag_src[0] = smc->clcsock->sk->sk_rcv_saddr;
-		r->id.idiag_dst[0] = smc->clcsock->sk->sk_daddr;
+
+	/* r is zeroed above, so returning early still hands user space a
+	 * fully initialized message. An inet SMC sock holds its own
+	 * addressing; an AF_SMC sock takes it from its CLC socket, if any.
+	 */
+	if (smc_sock_is_inet_sock(sk))
+		clcsk = sk;
+	else if (smc->clcsock)
+		clcsk = smc->clcsock->sk;
+	else
+		return;
+
+	r->id.idiag_sport = htons(clcsk->sk_num);
+	r->id.idiag_dport = clcsk->sk_dport;
+	r->id.idiag_if = clcsk->sk_bound_dev_if;
+
+	is_v4 = smc_sock_is_inet_sock(sk) ? clcsk->sk_family == AF_INET :
+		sk->sk_protocol == SMCPROTO_SMC;
 #if IS_ENABLED(CONFIG_IPV6)
-	} else if (sk->sk_protocol == SMCPROTO_SMC6) {
-		memcpy(&r->id.idiag_src, &smc->clcsock->sk->sk_v6_rcv_saddr,
-		       sizeof(smc->clcsock->sk->sk_v6_rcv_saddr));
-		memcpy(&r->id.idiag_dst, &smc->clcsock->sk->sk_v6_daddr,
-		       sizeof(smc->clcsock->sk->sk_v6_daddr));
+	is_v6 = smc_sock_is_inet_sock(sk) ? clcsk->sk_family == AF_INET6 :
+		sk->sk_protocol == SMCPROTO_SMC6;
+#else
+	is_v6 = false;
+#endif
+
+	if (is_v4) {
+		r->id.idiag_src[0] = clcsk->sk_rcv_saddr;
+		r->id.idiag_dst[0] = clcsk->sk_daddr;
+#if IS_ENABLED(CONFIG_IPV6)
+	} else if (is_v6) {
+		memcpy(&r->id.idiag_src, &clcsk->sk_v6_rcv_saddr,
+		       sizeof(clcsk->sk_v6_rcv_saddr));
+		memcpy(&r->id.idiag_dst, &clcsk->sk_v6_daddr,
+		       sizeof(clcsk->sk_v6_daddr));
 #endif
 	}
 }
@@ -72,7 +96,7 @@ static int smc_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
 static int __smc_diag_dump(struct sock *sk, struct sk_buff *skb,
 			   struct netlink_callback *cb,
 			   const struct smc_diag_req *req,
-			   struct nlattr *bc)
+			   struct nlattr *bc, bool is_listen)
 {
 	struct smc_sock *smc = smc_sk(sk);
 	struct smc_diag_fallback fallback;
@@ -88,6 +112,9 @@ static int __smc_diag_dump(struct sock *sk, struct sk_buff *skb,
 	r = nlmsg_data(nlh);
 	smc_diag_msg_common_fill(r, sk);
-	r->diag_state = smc_sk_state(sk);
+	/* Callers set is_listen for sockets taken from the inet listening
+	 * hash; those are reported as SMC_LISTEN.
+	 */
+	r->diag_state = is_listen ? SMC_LISTEN : smc_sk_state(sk);
 	if (smc->use_fallback)
 		r->diag_mode = SMC_DIAG_MODE_FALLBACK_TCP;
 	else if (smc_conn_lgr_valid(&smc->conn) && smc->conn.lgr->is_smcd)
@@ -193,6 +220,89 @@ static int __smc_diag_dump(struct sock *sk, struct sk_buff *skb,
 	return -EMSGSIZE;
 }
 
+/* Dump the SMC sockets of type @p_type found in the inet hash tables of
+ * @hashinfo: listening sockets first, then established-hash sockets
+ * (TIME_WAIT and NEW_SYN_RECV entries are skipped). cb_ctx->inet_pos[p_type]
+ * holds the number of matching sockets already dumped, so a dump that
+ * stops early (e.g. on a full skb) resumes at the socket that did not fit.
+ */
+static int smc_diag_dump_inet_proto(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
+				    struct netlink_callback *cb, int p_type)
+{
+	struct smc_diag_dump_ctx *cb_ctx = smc_dump_context(cb);
+	struct net *net = sock_net(skb->sk);
+	int snum = cb_ctx->inet_pos[p_type];
+	struct nlattr *bc = NULL;
+	int rc = 0, num = 0, i;
+	struct proto *target_proto;
+	struct sock *sk;
+
+#if IS_ENABLED(CONFIG_IPV6)
+	target_proto = (p_type == SMCPROTO_SMC6) ? &smc_inet6_prot : &smc_inet_prot;
+#else
+	target_proto = &smc_inet_prot;
+#endif
+
+	/* lhash2_mask is a mask, so the table has lhash2_mask + 1 buckets. */
+	for (i = 0; i <= hashinfo->lhash2_mask; i++) {
+		struct inet_listen_hashbucket *ilb;
+		struct hlist_nulls_node *node;
+
+		ilb = &hashinfo->lhash2[i];
+		spin_lock(&ilb->lock);
+		sk_nulls_for_each(sk, node, &ilb->nulls_head) {
+			if (!net_eq(sock_net(sk), net))
+				continue;
+			if (sk->sk_prot != target_proto)
+				continue;
+			if (num < snum)
+				goto next_ls;
+			rc = __smc_diag_dump(sk, skb, cb, nlmsg_data(cb->nlh), bc, true);
+			if (rc < 0) {
+				spin_unlock(&ilb->lock);
+				goto out;
+			}
+next_ls:
+			num++;
+		}
+		spin_unlock(&ilb->lock);
+	}
+
+	for (i = 0; i <= hashinfo->ehash_mask; i++) {
+		struct inet_ehash_bucket *head = &hashinfo->ehash[i];
+		spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
+		struct hlist_nulls_node *node;
+
+		if (hlist_nulls_empty(&head->chain))
+			continue;
+
+		spin_lock_bh(lock);
+		sk_nulls_for_each(sk, node, &head->chain) {
+			if (!net_eq(sock_net(sk), net))
+				continue;
+			if (sk->sk_state == TCP_TIME_WAIT)
+				continue;
+			if (sk->sk_state == TCP_NEW_SYN_RECV)
+				continue;
+			if (sk->sk_prot != target_proto)
+				continue;
+			if (num < snum)
+				goto next;
+			rc = __smc_diag_dump(sk, skb, cb, nlmsg_data(cb->nlh), bc, false);
+			if (rc < 0) {
+				spin_unlock_bh(lock);
+				goto out;
+			}
+next:
+			num++;
+		}
+		spin_unlock_bh(lock);
+	}
+out:
+	cb_ctx->inet_pos[p_type] = num;
+	return rc;
+}
+
 static int smc_diag_dump_proto(struct proto *prot, struct sk_buff *skb,
 			       struct netlink_callback *cb, int p_type)
 {
@@ -214,7 +324,7 @@ static int smc_diag_dump_proto(struct proto *prot, struct sk_buff *skb,
 			continue;
 		if (num < snum)
 			goto next;
-		rc = __smc_diag_dump(sk, skb, cb, nlmsg_data(cb->nlh), bc);
+		rc = __smc_diag_dump(sk, skb, cb, nlmsg_data(cb->nlh), bc, false);
 		if (rc < 0)
 			goto out;
 next:
@@ -232,8 +342,26 @@ static int smc_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
 	int rc = 0;
 
 	rc = smc_diag_dump_proto(&smc_proto, skb, cb, SMCPROTO_SMC);
-	if (!rc)
-		smc_diag_dump_proto(&smc_proto6, skb, cb, SMCPROTO_SMC6);
+	if (rc)
+		goto out;
+
+#if IS_ENABLED(CONFIG_IPV6)
+	rc = smc_diag_dump_proto(&smc_proto6, skb, cb, SMCPROTO_SMC6);
+	if (rc)
+		goto out;
+#endif
+
+	rc = smc_diag_dump_inet_proto(smc_inet_prot.h.hashinfo, skb, cb, SMCPROTO_SMC);
+	if (rc)
+		goto out;
+
+#if IS_ENABLED(CONFIG_IPV6)
+	rc = smc_diag_dump_inet_proto(smc_inet6_prot.h.hashinfo, skb, cb, SMCPROTO_SMC6);
+	if (rc)
+		goto out;
+#endif
+	return 0;
+out:
 	return skb->len;
 }
 
-- 
1.8.3.1