On Wed, Jul 24, 2024 at 01:32 PM +02, Michal Luczaj wrote: > Rewrite function to have (unneeded) socket descriptors automatically > close()d when leaving the scope. Make sure the "ownership" of fds is > correctly passed via take_fd(); i.e. descriptor returned to caller will > remain valid. > > Suggested-by: Jakub Sitnicki <jakub@xxxxxxxxxxxxxx> > Signed-off-by: Michal Luczaj <mhal@xxxxxxx> > --- > .../selftests/bpf/prog_tests/sockmap_helpers.h | 57 ++++++++++++---------- > 1 file changed, 32 insertions(+), 25 deletions(-) > > diff --git a/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h b/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h > index ead8ea4fd0da..2e0f9fe459be 100644 > --- a/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h > +++ b/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h > @@ -182,6 +182,21 @@ > __ret; \ > }) > > +#define take_fd(fd) \ > + ({ \ > + __auto_type __val = (fd); \ > + fd = -EBADF; \ > + __val; \ > + }) This should probably operate on a pointer to fd to avoid side effects, like the __get_and_null macro in include/linux/cleanup.h does. As written, the macro expands its fd argument twice (once to read it, once to reset it), so an argument with side effects would be evaluated twice. take_fd is effectively __get_and_null(fd, -EBADF). 
> + > +static inline void close_fd(int *fd) > +{ > + if (*fd >= 0) > + xclose(*fd); > +} > + > +#define __close_fd __attribute__((cleanup(close_fd))) > + > static inline int poll_connect(int fd, unsigned int timeout_sec) > { > struct timeval timeout = { .tv_sec = timeout_sec }; > @@ -369,9 +384,10 @@ static inline int socket_loopback(int family, int sotype) > > static inline int create_pair(int family, int sotype, int *p0, int *p1) > { > + __close_fd int s, c = -1, p = -1; > struct sockaddr_storage addr; > socklen_t len = sizeof(addr); > - int s, c, p, err; > + int err; > > s = socket_loopback(family, sotype); > if (s < 0) > @@ -379,25 +395,23 @@ static inline int create_pair(int family, int sotype, int *p0, int *p1) > > err = xgetsockname(s, sockaddr(&addr), &len); > if (err) > - goto close_s; > + return err; > > c = xsocket(family, sotype, 0); > - if (c < 0) { > - err = c; > - goto close_s; > - } > + if (c < 0) > + return c; > > err = connect(c, sockaddr(&addr), len); > if (err) { > if (errno != EINPROGRESS) { > FAIL_ERRNO("connect"); > - goto close_c; > + return err; > } > > err = poll_connect(c, IO_TIMEOUT_SEC); > if (err) { > FAIL_ERRNO("poll_connect"); > - goto close_c; > + return err; > } > } > > @@ -405,36 +419,29 @@ static inline int create_pair(int family, int sotype, int *p0, int *p1) > case SOCK_DGRAM: > err = xgetsockname(c, sockaddr(&addr), &len); > if (err) > - goto close_c; > + return err; > > err = xconnect(s, sockaddr(&addr), len); > - if (!err) { > - *p0 = s; > - *p1 = c; > + if (err) > return err; > - } > + > + *p0 = take_fd(s); > break; > case SOCK_STREAM: > case SOCK_SEQPACKET: > p = xaccept_nonblock(s, NULL, NULL); > - if (p >= 0) { > - *p0 = p; > - *p1 = c; > - goto close_s; > - } > + if (p < 0) > + return p; > > - err = p; > + *p0 = take_fd(p); > break; > default: > FAIL("Unsupported socket type %#x", sotype); > - err = -EOPNOTSUPP; > + return -EOPNOTSUPP; > } > > -close_c: > - close(c); > -close_s: > - close(s); > - return err; > + *p1 = 
take_fd(c); > + return 0; > } > > static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1, This turned out nice and readable, IMHO.