Signed-off-by: Tim Wiederhake <twiederh@xxxxxxxxxx> --- src/util/virnetlink.c | 101 ++++++++++++++++-------------------------- 1 file changed, 39 insertions(+), 62 deletions(-) diff --git a/src/util/virnetlink.c b/src/util/virnetlink.c index 3216765492..f15bb68b02 100644 --- a/src/util/virnetlink.c +++ b/src/util/virnetlink.c @@ -799,18 +799,6 @@ virNetlinkGetErrorCode(struct nlmsghdr *resp, unsigned int recvbuflen) } -static void -virNetlinkEventServerLock(virNetlinkEventSrvPrivate *driver) -{ - virMutexLock(&driver->lock); -} - -static void -virNetlinkEventServerUnlock(virNetlinkEventSrvPrivate *driver) -{ - virMutexUnlock(&driver->lock); -} - /** * virNetlinkEventRemoveClientPrimitive: * @@ -857,6 +845,7 @@ virNetlinkEventCallback(int watch, int length; bool handled = false; g_autofree struct nlmsghdr *msg = NULL; + VIR_LOCK_GUARD lock = { NULL }; length = nl_recv(srv->netlinknh, &peer, (unsigned char **)&msg, &creds); @@ -869,7 +858,7 @@ virNetlinkEventCallback(int watch, return; } - virNetlinkEventServerLock(srv); + lock = virLockGuardLock(&srv->lock); VIR_DEBUG("dispatching to max %d clients, called from event watch %d", (int)srv->handlesCount, watch); @@ -886,8 +875,6 @@ virNetlinkEventCallback(int watch, if (!handled) VIR_DEBUG("event not handled."); - - virNetlinkEventServerUnlock(srv); } /** @@ -916,20 +903,20 @@ virNetlinkEventServiceStop(unsigned int protocol) if (!server[protocol]) return 0; - virNetlinkEventServerLock(srv); - nl_close(srv->netlinknh); - virNetlinkFree(srv->netlinknh); - virEventRemoveHandle(srv->eventwatch); + VIR_WITH_MUTEX_LOCK_GUARD(&srv->lock) { + nl_close(srv->netlinknh); + virNetlinkFree(srv->netlinknh); + virEventRemoveHandle(srv->eventwatch); - /* free any remaining clients on the list */ - for (i = 0; i < srv->handlesCount; i++) { - if (srv->handles[i].deleted == VIR_NETLINK_HANDLE_VALID) - virNetlinkEventRemoveClientPrimitive(i, protocol); - } + /* free any remaining clients on the list */ + for (i = 0; i < srv->handlesCount; i++) { + if (srv->handles[i].deleted == VIR_NETLINK_HANDLE_VALID) + virNetlinkEventRemoveClientPrimitive(i, protocol); + } - server[protocol] = NULL; - VIR_FREE(srv->handles); - virNetlinkEventServerUnlock(srv); + server[protocol] = NULL; + VIR_FREE(srv->handles); + } virMutexDestroy(&srv->lock); VIR_FREE(srv); @@ -1014,9 +1001,9 @@ int virNetlinkEventServiceLocalPid(unsigned int protocol) int virNetlinkEventServiceStart(unsigned int protocol, unsigned int groups) { - virNetlinkEventSrvPrivate *srv; + g_autofree virNetlinkEventSrvPrivate *srv = NULL; + VIR_LOCK_GUARD lock = { NULL }; int fd; - int ret = -1; if (protocol >= MAX_LINKS) { virReportSystemError(EINVAL, @@ -1031,34 +1018,32 @@ virNetlinkEventServiceStart(unsigned int protocol, unsigned int groups) srv = g_new0(virNetlinkEventSrvPrivate, 1); - if (virMutexInit(&srv->lock) < 0) { - VIR_FREE(srv); + if (virMutexInit(&srv->lock) < 0) return -1; - } - virNetlinkEventServerLock(srv); + lock = virLockGuardLock(&srv->lock); /* Allocate a new socket and get fd */ if (!(srv->netlinknh = virNetlinkCreateSocket(protocol))) - goto error_locked; + goto error; fd = nl_socket_get_fd(srv->netlinknh); if (fd < 0) { virReportSystemError(errno, "%s", _("cannot get netlink socket fd")); - goto error_server; + goto error; } if (groups && nl_socket_add_membership(srv->netlinknh, groups) < 0) { virReportSystemError(errno, "%s", _("cannot add netlink membership")); - goto error_server; + goto error; } if (nl_socket_set_nonblocking(srv->netlinknh)) { virReportSystemError(errno, "%s", _("cannot set netlink socket nonblocking")); - goto error_server; + goto error; } if ((srv->eventwatch = virEventAddHandle(fd, @@ -1067,27 +1052,24 @@ virNetlinkEventServiceStart(unsigned int protocol, unsigned int groups) srv, NULL)) < 0) { virReportError(VIR_ERR_INTERNAL_ERROR, "%s", _("Failed to add netlink event handle watch")); - goto error_server; + goto error; } srv->netlinkfd = fd; VIR_DEBUG("netlink event listener on fd: %i running", fd); - ret = 0; - server[protocol] = srv; + server[protocol] = g_steal_pointer(&srv); + return 0; - error_server: - if (ret < 0) { + error: + if (srv->netlinknh) { nl_close(srv->netlinknh); virNetlinkFree(srv->netlinknh); } - error_locked: - virNetlinkEventServerUnlock(srv); - if (ret < 0) { - virMutexDestroy(&srv->lock); - VIR_FREE(srv); - } - return ret; + + virLockGuardUnlock(&lock); + virMutexDestroy(&srv->lock); + return -1; } /** @@ -1114,8 +1096,9 @@ virNetlinkEventAddClient(virNetlinkEventHandleCallback handleCB, unsigned int protocol) { size_t i; - int r, ret = -1; + int r; virNetlinkEventSrvPrivate *srv = NULL; + VIR_LOCK_GUARD lock = { NULL }; if (protocol >= MAX_LINKS) return -EINVAL; @@ -1128,7 +1111,7 @@ virNetlinkEventAddClient(virNetlinkEventHandleCallback handleCB, return -1; } - virNetlinkEventServerLock(srv); + lock = virLockGuardLock(&srv->lock); VIR_DEBUG("adding client: %d.", nextWatch); @@ -1163,10 +1146,7 @@ virNetlinkEventAddClient(virNetlinkEventHandleCallback handleCB, VIR_DEBUG("added client to loop slot: %d. with macaddr ptr=%p", r, macaddr); - ret = nextWatch++; - - virNetlinkEventServerUnlock(srv); - return ret; + return nextWatch++; } /** @@ -1187,8 +1167,8 @@ virNetlinkEventRemoveClient(int watch, const virMacAddr *macaddr, unsigned int protocol) { size_t i; - int ret = -1; virNetlinkEventSrvPrivate *srv = NULL; + VIR_LOCK_GUARD lock = { NULL }; if (protocol >= MAX_LINKS) return -EINVAL; @@ -1202,7 +1182,7 @@ virNetlinkEventRemoveClient(int watch, const virMacAddr *macaddr, return -1; } - virNetlinkEventServerLock(srv); + lock = virLockGuardLock(&srv->lock); for (i = 0; i < srv->handlesCount; i++) { if (srv->handles[i].deleted != VIR_NETLINK_HANDLE_VALID) @@ -1215,15 +1195,12 @@ virNetlinkEventRemoveClient(int watch, const virMacAddr *macaddr, VIR_DEBUG("removed client: %d by %s.", srv->handles[i].watch, watch ? "index" : "mac"); virNetlinkEventRemoveClientPrimitive(i, protocol); - ret = 0; - goto cleanup; + return 0; } } VIR_DEBUG("no client found to remove."); - cleanup: - virNetlinkEventServerUnlock(srv); - return ret; + return -1; } #else -- 2.31.1