On Mon, 2024-03-25 at 19:21 -0700, Yonghong Song wrote: [...] > diff --git a/net/core/sock_map.c b/net/core/sock_map.c > index 27d733c0f65e..dafc9aa6e192 100644 > --- a/net/core/sock_map.c > +++ b/net/core/sock_map.c [...] > @@ -1488,21 +1492,90 @@ static int sock_map_prog_lookup(struct bpf_map *map, struct bpf_prog ***pprog, > return 0; > } > > +static int sock_map_link_lookup(struct bpf_map *map, struct bpf_link ***plink, > + struct bpf_link *link, bool skip_check, u32 which) > +{ > + struct sk_psock_progs *progs = sock_map_progs(map); > + > + switch (which) { > + case BPF_SK_MSG_VERDICT: > + if (!skip_check && > + ((!link && progs->msg_parser_link) || > + (link && link != progs->msg_parser_link))) > + return -EBUSY; These checks seem a bit repetitive; maybe factor them out into a single check at the end of the function, after the switch has set *plink? E.g.: if (!skip_check && ((!link && **plink) || (link && link != **plink))) return -EBUSY; Or inline these checks at the call sites of sock_map_link_lookup()? I tried this on top of this patch in [1] and all tests seem to pass. [1] https://gist.github.com/eddyz87/38d832b3f1fc74120598d3480bc16ae1 > + *plink = &progs->msg_parser_link; > + break; > +#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER) > + case BPF_SK_SKB_STREAM_PARSER: > + if (!skip_check && > + ((!link && progs->stream_parser_link) || > + (link && link != progs->stream_parser_link))) > + return -EBUSY; > + *plink = &progs->stream_parser_link; > + break; > +#endif > + case BPF_SK_SKB_STREAM_VERDICT: > + if (!skip_check && > + ((!link && progs->stream_verdict_link) || > + (link && link != progs->stream_verdict_link))) > + return -EBUSY; > + *plink = &progs->stream_verdict_link; > + break; > + case BPF_SK_SKB_VERDICT: > + if (!skip_check && > + ((!link && progs->skb_verdict_link) || > + (link && link != progs->skb_verdict_link))) > + return -EBUSY; > + *plink = &progs->skb_verdict_link; > + break; > + default: > + return -EOPNOTSUPP; > + } > + > + return 0; > +} [...] 
> +/* Handle the following two cases: > + * case 1: link != NULL, prog != NULL, old != NULL > + * case 2: link != NULL, prog != NULL, old == NULL > + */ > +static int sock_map_link_update_prog(struct bpf_link *link, > + struct bpf_prog *prog, > + struct bpf_prog *old) > +{ > + const struct sockmap_link *sockmap_link = get_sockmap_link(link); > + struct bpf_prog **pprog; > + struct bpf_link **plink; > + int ret = 0; > + > + mutex_lock(&sockmap_prog_update_mutex); > + > + /* If old prog not NULL, ensure old prog the same as link->prog. */ > + if (old && link->prog != old) { > + ret = -EINVAL; > + goto out; > + } > + /* Ensure link->prog has the same type/attach_type as the new prog. */ > + if (link->prog->type != prog->type || > + link->prog->expected_attach_type != prog->expected_attach_type) { > + ret = -EINVAL; > + goto out; > + } > + > + ret = sock_map_prog_lookup(sockmap_link->map, &pprog, > + sockmap_link->attach_type); > + if (ret) > + goto out; > + > + /* Ensure the same link between the one in map and the passed-in. */ > + ret = sock_map_link_lookup(sockmap_link->map, &plink, link, false, > + sockmap_link->attach_type); > + if (ret) > + goto out; > + > + if (old) > + return psock_replace_prog(pprog, prog, old); Should this be 'ret = psock_replace_prog(pprog, prog, old); goto out;' instead? As written, this early return leaves sockmap_prog_update_mutex locked, since the only mutex_unlock() is under 'out:'. It also skips the bpf_prog_inc(prog) that the psock_set_prog() path gets at 'out:' on success. > + > + psock_set_prog(pprog, prog); > + > +out: > + if (!ret) > + bpf_prog_inc(prog); > + mutex_unlock(&sockmap_prog_update_mutex); > + return ret; > +} [...]