This patch first adds a per-netns global, l3mdev_accept, to decide whether packets from an l3mdev are accepted when an SCTP socket is not bound to any interface. It is set to 1 by default to avoid any possible compatibility issue, and the next patch will introduce a sysctl to allow changing it. Then, similar to inet/udp_sk_bound_dev_eq(), sctp_sk_bound_dev_eq() is added to check whether either dif or sdif is equal to sk_bound_dev_if, or, if sk_bound_dev_if is not set, whether sdif is 0 or l3mdev_accept is 1. This function is used to match an association or an endpoint, and is called by sctp_addrs_lookup_transport() and sctp_endpoint_is_match(). All functions that need updating are: sctp_rcv(): asoc: __sctp_rcv_lookup() __sctp_lookup_association() -> sctp_addrs_lookup_transport() __sctp_rcv_lookup_harder() __sctp_rcv_init_lookup() __sctp_lookup_association() -> sctp_addrs_lookup_transport() __sctp_rcv_walk_lookup() __sctp_rcv_asconf_lookup() __sctp_lookup_association() -> sctp_addrs_lookup_transport() ep: __sctp_rcv_lookup_endpoint() -> sctp_endpoint_is_match() sctp_connect(): sctp_endpoint_is_peeled_off() __sctp_lookup_association() sctp_has_association() sctp_lookup_association() __sctp_lookup_association() -> sctp_addrs_lookup_transport() sctp_diag_dump_one(): sctp_transport_lookup_process() -> sctp_addrs_lookup_transport() Signed-off-by: Xin Long <lucien.xin@xxxxxxxxx> --- include/net/netns/sctp.h | 4 ++ include/net/sctp/sctp.h | 6 ++- include/net/sctp/structs.h | 8 +-- net/sctp/diag.c | 3 +- net/sctp/endpointola.c | 13 +++-- net/sctp/input.c | 108 ++++++++++++++++++++----------------- net/sctp/protocol.c | 4 ++ net/sctp/socket.c | 4 +- 8 files changed, 89 insertions(+), 61 deletions(-) diff --git a/include/net/netns/sctp.h b/include/net/netns/sctp.h index a681147aecd8..7eff3d981b89 100644 --- a/include/net/netns/sctp.h +++ b/include/net/netns/sctp.h @@ -175,6 +175,10 @@ struct netns_sctp { /* Threshold for autoclose timeout, in seconds. 
*/ unsigned long max_autoclose; + +#ifdef CONFIG_NET_L3_MASTER_DEV + int l3mdev_accept; +#endif }; #endif /* __NETNS_SCTP_H__ */ diff --git a/include/net/sctp/sctp.h b/include/net/sctp/sctp.h index a04999ee99b0..e499fd753fc4 100644 --- a/include/net/sctp/sctp.h +++ b/include/net/sctp/sctp.h @@ -114,7 +114,7 @@ struct sctp_transport *sctp_transport_get_idx(struct net *net, struct rhashtable_iter *iter, int pos); int sctp_transport_lookup_process(sctp_callback_t cb, struct net *net, const union sctp_addr *laddr, - const union sctp_addr *paddr, void *p); + const union sctp_addr *paddr, void *p, int dif); int sctp_transport_traverse_process(sctp_callback_t cb, sctp_callback_t cb_done, struct net *net, int *pos, void *p); int sctp_for_each_endpoint(int (*cb)(struct sctp_endpoint *, void *), void *p); @@ -162,10 +162,12 @@ void sctp_unhash_transport(struct sctp_transport *t); struct sctp_transport *sctp_addrs_lookup_transport( struct net *net, const union sctp_addr *laddr, - const union sctp_addr *paddr); + const union sctp_addr *paddr, + int dif, int sdif); struct sctp_transport *sctp_epaddr_lookup_transport( const struct sctp_endpoint *ep, const union sctp_addr *paddr); +bool sctp_sk_bound_dev_eq(struct net *net, int bound_dev_if, int dif, int sdif); /* * sctp/proc.c diff --git a/include/net/sctp/structs.h b/include/net/sctp/structs.h index 7b4884c63b26..afa3781e3ca2 100644 --- a/include/net/sctp/structs.h +++ b/include/net/sctp/structs.h @@ -1379,10 +1379,12 @@ struct sctp_association *sctp_endpoint_lookup_assoc( struct sctp_transport **); bool sctp_endpoint_is_peeled_off(struct sctp_endpoint *ep, const union sctp_addr *paddr); -struct sctp_endpoint *sctp_endpoint_is_match(struct sctp_endpoint *, - struct net *, const union sctp_addr *); +struct sctp_endpoint *sctp_endpoint_is_match(struct sctp_endpoint *ep, + struct net *net, + const union sctp_addr *laddr, + int dif, int sdif); bool sctp_has_association(struct net *net, const union sctp_addr *laddr, - const union 
sctp_addr *paddr); + const union sctp_addr *paddr, int dif, int sdif); int sctp_verify_init(struct net *net, const struct sctp_endpoint *ep, const struct sctp_association *asoc, diff --git a/net/sctp/diag.c b/net/sctp/diag.c index d9c6d8f30f09..a557009e9832 100644 --- a/net/sctp/diag.c +++ b/net/sctp/diag.c @@ -426,6 +426,7 @@ static int sctp_diag_dump_one(struct netlink_callback *cb, struct net *net = sock_net(skb->sk); const struct nlmsghdr *nlh = cb->nlh; union sctp_addr laddr, paddr; + int dif = req->id.idiag_if; struct sctp_comm_param commp = { .skb = skb, .r = req, @@ -454,7 +455,7 @@ static int sctp_diag_dump_one(struct netlink_callback *cb, } return sctp_transport_lookup_process(sctp_sock_dump_one, - net, &laddr, &paddr, &commp); + net, &laddr, &paddr, &commp, dif); } static void sctp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, diff --git a/net/sctp/endpointola.c b/net/sctp/endpointola.c index efffde7f2328..7e77b450697c 100644 --- a/net/sctp/endpointola.c +++ b/net/sctp/endpointola.c @@ -246,12 +246,15 @@ void sctp_endpoint_put(struct sctp_endpoint *ep) /* Is this the endpoint we are looking for? 
*/ struct sctp_endpoint *sctp_endpoint_is_match(struct sctp_endpoint *ep, struct net *net, - const union sctp_addr *laddr) + const union sctp_addr *laddr, + int dif, int sdif) { + int bound_dev_if = READ_ONCE(ep->base.sk->sk_bound_dev_if); struct sctp_endpoint *retval = NULL; - if ((htons(ep->base.bind_addr.port) == laddr->v4.sin_port) && - net_eq(ep->base.net, net)) { + if (net_eq(ep->base.net, net) && + sctp_sk_bound_dev_eq(net, bound_dev_if, dif, sdif) && + (htons(ep->base.bind_addr.port) == laddr->v4.sin_port)) { if (sctp_bind_addr_match(&ep->base.bind_addr, laddr, sctp_sk(ep->base.sk))) retval = ep; @@ -298,6 +301,7 @@ struct sctp_association *sctp_endpoint_lookup_assoc( bool sctp_endpoint_is_peeled_off(struct sctp_endpoint *ep, const union sctp_addr *paddr) { + int bound_dev_if = READ_ONCE(ep->base.sk->sk_bound_dev_if); struct sctp_sockaddr_entry *addr; struct net *net = ep->base.net; struct sctp_bind_addr *bp; @@ -307,7 +311,8 @@ bool sctp_endpoint_is_peeled_off(struct sctp_endpoint *ep, * so the address_list can not change. 
*/ list_for_each_entry(addr, &bp->address_list, list) { - if (sctp_has_association(net, &addr->a, paddr)) + if (sctp_has_association(net, &addr->a, paddr, + bound_dev_if, bound_dev_if)) return true; } diff --git a/net/sctp/input.c b/net/sctp/input.c index 4f43afa8678f..bf70371301ff 100644 --- a/net/sctp/input.c +++ b/net/sctp/input.c @@ -50,16 +50,19 @@ static struct sctp_association *__sctp_rcv_lookup(struct net *net, struct sk_buff *skb, const union sctp_addr *paddr, const union sctp_addr *laddr, - struct sctp_transport **transportp); + struct sctp_transport **transportp, + int dif, int sdif); static struct sctp_endpoint *__sctp_rcv_lookup_endpoint( struct net *net, struct sk_buff *skb, const union sctp_addr *laddr, - const union sctp_addr *daddr); + const union sctp_addr *daddr, + int dif, int sdif); static struct sctp_association *__sctp_lookup_association( struct net *net, const union sctp_addr *local, const union sctp_addr *peer, - struct sctp_transport **pt); + struct sctp_transport **pt, + int dif, int sdif); static int sctp_add_backlog(struct sock *sk, struct sk_buff *skb); @@ -92,11 +95,11 @@ int sctp_rcv(struct sk_buff *skb) struct sctp_chunk *chunk; union sctp_addr src; union sctp_addr dest; - int bound_dev_if; int family; struct sctp_af *af; struct net *net = dev_net(skb->dev); bool is_gso = skb_is_gso(skb) && skb_is_gso_sctp(skb); + int dif, sdif; if (skb->pkt_type != PACKET_HOST) goto discard_it; @@ -141,6 +144,8 @@ int sctp_rcv(struct sk_buff *skb) /* Initialize local addresses for lookups. */ af->from_skb(&src, skb, 1); af->from_skb(&dest, skb, 0); + dif = af->skb_iif(skb); + sdif = af->skb_sdif(skb); /* If the packet is to or from a non-unicast address, * silently discard the packet. 
@@ -157,35 +162,15 @@ int sctp_rcv(struct sk_buff *skb) !af->addr_valid(&dest, NULL, skb)) goto discard_it; - asoc = __sctp_rcv_lookup(net, skb, &src, &dest, &transport); + asoc = __sctp_rcv_lookup(net, skb, &src, &dest, &transport, dif, sdif); if (!asoc) - ep = __sctp_rcv_lookup_endpoint(net, skb, &dest, &src); + ep = __sctp_rcv_lookup_endpoint(net, skb, &dest, &src, dif, sdif); /* Retrieve the common input handling substructure. */ rcvr = asoc ? &asoc->base : &ep->base; sk = rcvr->sk; - /* - * If a frame arrives on an interface and the receiving socket is - * bound to another interface, via SO_BINDTODEVICE, treat it as OOTB - */ - bound_dev_if = READ_ONCE(sk->sk_bound_dev_if); - if (bound_dev_if && (bound_dev_if != af->skb_iif(skb))) { - if (transport) { - sctp_transport_put(transport); - asoc = NULL; - transport = NULL; - } else { - sctp_endpoint_put(ep); - ep = NULL; - } - sk = net->sctp.ctl_sock; - ep = sctp_sk(sk)->ep; - sctp_endpoint_hold(ep); - rcvr = &ep->base; - } - /* * RFC 2960, 8.4 - Handle "Out of the blue" Packets. * An SCTP packet is called an "out of the blue" (OOTB) @@ -485,6 +470,8 @@ struct sock *sctp_err_lookup(struct net *net, int family, struct sk_buff *skb, struct sctp_association *asoc; struct sctp_transport *transport = NULL; __u32 vtag = ntohl(sctphdr->vtag); + int sdif = inet_sdif(skb); + int dif = inet_iif(skb); *app = NULL; *tpp = NULL; @@ -500,7 +487,7 @@ struct sock *sctp_err_lookup(struct net *net, int family, struct sk_buff *skb, /* Look for an association that matches the incoming ICMP error * packet. 
*/ - asoc = __sctp_lookup_association(net, &saddr, &daddr, &transport); + asoc = __sctp_lookup_association(net, &saddr, &daddr, &transport, dif, sdif); if (!asoc) return NULL; @@ -850,7 +837,8 @@ static inline __u32 sctp_hashfn(const struct net *net, __be16 lport, static struct sctp_endpoint *__sctp_rcv_lookup_endpoint( struct net *net, struct sk_buff *skb, const union sctp_addr *laddr, - const union sctp_addr *paddr) + const union sctp_addr *paddr, + int dif, int sdif) { struct sctp_hashbucket *head; struct sctp_endpoint *ep; @@ -863,7 +851,7 @@ static struct sctp_endpoint *__sctp_rcv_lookup_endpoint( head = &sctp_ep_hashtable[hash]; read_lock(&head->lock); sctp_for_each_hentry(ep, &head->chain) { - if (sctp_endpoint_is_match(ep, net, laddr)) + if (sctp_endpoint_is_match(ep, net, laddr, dif, sdif)) goto hit; } @@ -990,14 +978,26 @@ void sctp_unhash_transport(struct sctp_transport *t) sctp_hash_params); } +bool sctp_sk_bound_dev_eq(struct net *net, int bound_dev_if, int dif, int sdif) +{ + bool l3mdev_accept = true; + +#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV) + l3mdev_accept = !!READ_ONCE(net->sctp.l3mdev_accept); +#endif + return inet_bound_dev_eq(l3mdev_accept, bound_dev_if, dif, sdif); +} + /* return a transport with holding it */ struct sctp_transport *sctp_addrs_lookup_transport( struct net *net, const union sctp_addr *laddr, - const union sctp_addr *paddr) + const union sctp_addr *paddr, + int dif, int sdif) { struct rhlist_head *tmp, *list; struct sctp_transport *t; + int bound_dev_if; struct sctp_hash_cmp_arg arg = { .paddr = paddr, .net = net, @@ -1011,7 +1011,9 @@ struct sctp_transport *sctp_addrs_lookup_transport( if (!sctp_transport_hold(t)) continue; - if (sctp_bind_addr_match(&t->asoc->base.bind_addr, + bound_dev_if = READ_ONCE(t->asoc->base.sk->sk_bound_dev_if); + if (sctp_sk_bound_dev_eq(net, bound_dev_if, dif, sdif) && + sctp_bind_addr_match(&t->asoc->base.bind_addr, laddr, sctp_sk(t->asoc->base.sk))) return t; sctp_transport_put(t); @@ -1048,12 
+1050,13 @@ static struct sctp_association *__sctp_lookup_association( struct net *net, const union sctp_addr *local, const union sctp_addr *peer, - struct sctp_transport **pt) + struct sctp_transport **pt, + int dif, int sdif) { struct sctp_transport *t; struct sctp_association *asoc = NULL; - t = sctp_addrs_lookup_transport(net, local, peer); + t = sctp_addrs_lookup_transport(net, local, peer, dif, sdif); if (!t) goto out; @@ -1069,12 +1072,13 @@ static struct sctp_association *sctp_lookup_association(struct net *net, const union sctp_addr *laddr, const union sctp_addr *paddr, - struct sctp_transport **transportp) + struct sctp_transport **transportp, + int dif, int sdif) { struct sctp_association *asoc; rcu_read_lock(); - asoc = __sctp_lookup_association(net, laddr, paddr, transportp); + asoc = __sctp_lookup_association(net, laddr, paddr, transportp, dif, sdif); rcu_read_unlock(); return asoc; @@ -1083,11 +1087,12 @@ struct sctp_association *sctp_lookup_association(struct net *net, /* Is there an association matching the given local and peer addresses? 
*/ bool sctp_has_association(struct net *net, const union sctp_addr *laddr, - const union sctp_addr *paddr) + const union sctp_addr *paddr, + int dif, int sdif) { struct sctp_transport *transport; - if (sctp_lookup_association(net, laddr, paddr, &transport)) { + if (sctp_lookup_association(net, laddr, paddr, &transport, dif, sdif)) { sctp_transport_put(transport); return true; } @@ -1115,7 +1120,8 @@ bool sctp_has_association(struct net *net, */ static struct sctp_association *__sctp_rcv_init_lookup(struct net *net, struct sk_buff *skb, - const union sctp_addr *laddr, struct sctp_transport **transportp) + const union sctp_addr *laddr, struct sctp_transport **transportp, + int dif, int sdif) { struct sctp_association *asoc; union sctp_addr addr; @@ -1154,7 +1160,7 @@ static struct sctp_association *__sctp_rcv_init_lookup(struct net *net, if (!af->from_addr_param(paddr, params.addr, sh->source, 0)) continue; - asoc = __sctp_lookup_association(net, laddr, paddr, transportp); + asoc = __sctp_lookup_association(net, laddr, paddr, transportp, dif, sdif); if (asoc) return asoc; } @@ -1181,7 +1187,8 @@ static struct sctp_association *__sctp_rcv_asconf_lookup( struct sctp_chunkhdr *ch, const union sctp_addr *laddr, __be16 peer_port, - struct sctp_transport **transportp) + struct sctp_transport **transportp, + int dif, int sdif) { struct sctp_addip_chunk *asconf = (struct sctp_addip_chunk *)ch; struct sctp_af *af; @@ -1201,7 +1208,7 @@ static struct sctp_association *__sctp_rcv_asconf_lookup( if (!af->from_addr_param(&paddr, param, peer_port, 0)) return NULL; - return __sctp_lookup_association(net, laddr, &paddr, transportp); + return __sctp_lookup_association(net, laddr, &paddr, transportp, dif, sdif); } @@ -1217,7 +1224,8 @@ static struct sctp_association *__sctp_rcv_asconf_lookup( static struct sctp_association *__sctp_rcv_walk_lookup(struct net *net, struct sk_buff *skb, const union sctp_addr *laddr, - struct sctp_transport **transportp) + struct sctp_transport 
**transportp, + int dif, int sdif) { struct sctp_association *asoc = NULL; struct sctp_chunkhdr *ch; @@ -1260,7 +1268,7 @@ static struct sctp_association *__sctp_rcv_walk_lookup(struct net *net, asoc = __sctp_rcv_asconf_lookup( net, ch, laddr, sctp_hdr(skb)->source, - transportp); + transportp, dif, sdif); break; default: break; @@ -1285,7 +1293,8 @@ static struct sctp_association *__sctp_rcv_walk_lookup(struct net *net, static struct sctp_association *__sctp_rcv_lookup_harder(struct net *net, struct sk_buff *skb, const union sctp_addr *laddr, - struct sctp_transport **transportp) + struct sctp_transport **transportp, + int dif, int sdif) { struct sctp_chunkhdr *ch; @@ -1309,9 +1318,9 @@ static struct sctp_association *__sctp_rcv_lookup_harder(struct net *net, /* If this is INIT/INIT-ACK look inside the chunk too. */ if (ch->type == SCTP_CID_INIT || ch->type == SCTP_CID_INIT_ACK) - return __sctp_rcv_init_lookup(net, skb, laddr, transportp); + return __sctp_rcv_init_lookup(net, skb, laddr, transportp, dif, sdif); - return __sctp_rcv_walk_lookup(net, skb, laddr, transportp); + return __sctp_rcv_walk_lookup(net, skb, laddr, transportp, dif, sdif); } /* Lookup an association for an inbound skb. */ @@ -1319,11 +1328,12 @@ static struct sctp_association *__sctp_rcv_lookup(struct net *net, struct sk_buff *skb, const union sctp_addr *paddr, const union sctp_addr *laddr, - struct sctp_transport **transportp) + struct sctp_transport **transportp, + int dif, int sdif) { struct sctp_association *asoc; - asoc = __sctp_lookup_association(net, laddr, paddr, transportp); + asoc = __sctp_lookup_association(net, laddr, paddr, transportp, dif, sdif); if (asoc) goto out; @@ -1331,7 +1341,7 @@ static struct sctp_association *__sctp_rcv_lookup(struct net *net, * SCTP Implementors Guide, 2.18 Handling of address * parameters within the INIT or INIT-ACK. 
*/ - asoc = __sctp_rcv_lookup_harder(net, skb, laddr, transportp); + asoc = __sctp_rcv_lookup_harder(net, skb, laddr, transportp, dif, sdif); if (asoc) goto out; diff --git a/net/sctp/protocol.c b/net/sctp/protocol.c index a18cf0471a8d..909a89a1cff4 100644 --- a/net/sctp/protocol.c +++ b/net/sctp/protocol.c @@ -1394,6 +1394,10 @@ static int __net_init sctp_defaults_init(struct net *net) /* Initialize maximum autoclose timeout. */ net->sctp.max_autoclose = INT_MAX / HZ; +#ifdef CONFIG_NET_L3_MASTER_DEV + net->sctp.l3mdev_accept = 1; +#endif + status = sctp_sysctl_net_register(net); if (status) goto err_sysctl_register; diff --git a/net/sctp/socket.c b/net/sctp/socket.c index 4306164238ef..5acbdf0d38f3 100644 --- a/net/sctp/socket.c +++ b/net/sctp/socket.c @@ -5315,14 +5315,14 @@ EXPORT_SYMBOL_GPL(sctp_for_each_endpoint); int sctp_transport_lookup_process(sctp_callback_t cb, struct net *net, const union sctp_addr *laddr, - const union sctp_addr *paddr, void *p) + const union sctp_addr *paddr, void *p, int dif) { struct sctp_transport *transport; struct sctp_endpoint *ep; int err = -ENOENT; rcu_read_lock(); - transport = sctp_addrs_lookup_transport(net, laddr, paddr); + transport = sctp_addrs_lookup_transport(net, laddr, paddr, dif, dif); if (!transport) { rcu_read_unlock(); return err; -- 2.31.1