From: "Daniel P. Berrange" <berrange@xxxxxxxxxx> Remove the need for a virNetSocket object to be protected by locks from the object using it, by introducing its own native locking and reference counting * src/rpc/virnetsocket.c: Add locking & reference counting --- src/rpc/virnetsocket.c | 147 +++++++++++++++++++++++++++++++++++++++--------- 1 files changed, 120 insertions(+), 27 deletions(-) diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c index 7ea1ab7..8dd4d3a 100644 --- a/src/rpc/virnetsocket.c +++ b/src/rpc/virnetsocket.c @@ -40,6 +40,7 @@ #include "logging.h" #include "files.h" #include "event.h" +#include "threads.h" #define VIR_FROM_THIS VIR_FROM_RPC @@ -49,6 +50,9 @@ struct _virNetSocket { + virMutex lock; + int refs; + int fd; int watch; pid_t pid; @@ -122,6 +126,14 @@ static virNetSocketPtr virNetSocketNew(virSocketAddrPtr localAddr, return NULL; } + if (virMutexInit(&sock->lock) < 0) { + virReportSystemError(errno, "%s", + _("Unable to initialize mutex")); + VIR_FREE(sock); + return NULL; + } + sock->refs = 1; + if (localAddr) sock->localAddr = *localAddr; if (remoteAddr) @@ -627,6 +639,13 @@ void virNetSocketFree(virNetSocketPtr sock) if (!sock) return; + virMutexLock(&sock->lock); + sock->refs--; + if (sock->refs > 0) { + virMutexUnlock(&sock->lock); + return; + } + VIR_DEBUG("sock=%p fd=%d", sock, sock->fd); if (sock->watch > 0) { virEventRemoveHandle(sock->watch); @@ -657,27 +676,41 @@ void virNetSocketFree(virNetSocketPtr sock) VIR_FREE(sock->localAddrStr); VIR_FREE(sock->remoteAddrStr); + virMutexUnlock(&sock->lock); + virMutexDestroy(&sock->lock); + VIR_FREE(sock); } int virNetSocketGetFD(virNetSocketPtr sock) { - return sock->fd; + int fd; + virMutexLock(&sock->lock); + fd = sock->fd; + virMutexUnlock(&sock->lock); + return fd; } bool virNetSocketIsLocal(virNetSocketPtr sock) { + bool isLocal = false; + virMutexLock(&sock->lock); if (sock->localAddr.data.sa.sa_family == AF_UNIX) - return true; - return false; + isLocal = true; + virMutexUnlock(&sock->lock); + return isLocal; } int virNetSocketGetPort(virNetSocketPtr sock) { - return virSocketGetPort(&sock->localAddr); + int port; + virMutexLock(&sock->lock); + port = virSocketGetPort(&sock->localAddr); + virMutexUnlock(&sock->lock); + return port; } @@ -688,15 +721,19 @@ int virNetSocketGetLocalIdentity(virNetSocketPtr sock, { struct ucred cr; unsigned int cr_len = sizeof (cr); + virMutexLock(&sock->lock); if (getsockopt(sock->fd, SOL_SOCKET, SO_PEERCRED, &cr, &cr_len) < 0) { virReportSystemError(errno, "%s", _("Failed to get client socket identity")); + virMutexUnlock(&sock->lock); return -1; } *pid = cr.pid; *uid = cr.uid; + + virMutexUnlock(&sock->lock); return 0; } #else @@ -715,7 +752,11 @@ int virNetSocketGetLocalIdentity(virNetSocketPtr sock ATTRIBUTE_UNUSED, int virNetSocketSetBlocking(virNetSocketPtr sock, bool blocking) { - return virSetBlocking(sock->fd, blocking); + int ret; + virMutexLock(&sock->lock); + ret = virSetBlocking(sock->fd, blocking); + virMutexUnlock(&sock->lock); + return ret; } @@ -751,6 +792,7 @@ static ssize_t virNetSocketTLSSessionRead(char *buf, void virNetSocketSetTLSSession(virNetSocketPtr sock, virNetTLSSessionPtr sess) { + virMutexLock(&sock->lock); virNetTLSSessionFree(sock->tlsSession); sock->tlsSession = sess; virNetTLSSessionSetIOCallbacks(sess, @@ -758,6 +800,7 @@ void virNetSocketSetTLSSession(virNetSocketPtr sock, virNetSocketTLSSessionRead, sock); virNetTLSSessionRef(sess); + virMutexUnlock(&sock->lock); } @@ -765,20 +808,25 @@ void virNetSocketSetTLSSession(virNetSocketPtr sock, void virNetSocketSetSASLSession(virNetSocketPtr sock, virNetSASLSessionPtr sess) { + virMutexLock(&sock->lock); virNetSASLSessionFree(sock->saslSession); sock->saslSession = sess; virNetSASLSessionRef(sess); + virMutexUnlock(&sock->lock); } #endif bool virNetSocketHasCachedData(virNetSocketPtr sock ATTRIBUTE_UNUSED) { + bool hasCached = false; + virMutexLock(&sock->lock); #if HAVE_SASL if (sock->saslDecoded) - return true; + hasCached = true; #endif - return false; + virMutexUnlock(&sock->lock); + return hasCached; } @@ -965,39 +1013,54 @@ static ssize_t virNetSocketWriteSASL(virNetSocketPtr sock, const char *buf, size ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len) { + ssize_t ret; + virMutexLock(&sock->lock); #if HAVE_SASL if (sock->saslSession) - return virNetSocketReadSASL(sock, buf, len); + ret = virNetSocketReadSASL(sock, buf, len); else #endif - return virNetSocketReadWire(sock, buf, len); + ret = virNetSocketReadWire(sock, buf, len); + virMutexUnlock(&sock->lock); + return ret; } ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len) { + ssize_t ret; + + virMutexLock(&sock->lock); #if HAVE_SASL if (sock->saslSession) - return virNetSocketWriteSASL(sock, buf, len); + ret = virNetSocketWriteSASL(sock, buf, len); else #endif - return virNetSocketWriteWire(sock, buf, len); + ret = virNetSocketWriteWire(sock, buf, len); + virMutexUnlock(&sock->lock); + return ret; } int virNetSocketListen(virNetSocketPtr sock) { + virMutexLock(&sock->lock); if (listen(sock->fd, 30) < 0) { virReportSystemError(errno, "%s", _("Unable to listen on socket")); + virMutexUnlock(&sock->lock); return -1; } + virMutexUnlock(&sock->lock); return 0; } int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr *clientsock) { - int fd; + int fd = -1; virSocketAddr localAddr; virSocketAddr remoteAddr; + int ret = -1; + + virMutexLock(&sock->lock); *clientsock = NULL; @@ -1007,30 +1070,35 @@ int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr *clientsock) remoteAddr.len = sizeof(remoteAddr.data.stor); if ((fd = accept(sock->fd, &remoteAddr.data.sa, &remoteAddr.len)) < 0) { if (errno == ECONNABORTED || - errno == EAGAIN) - return 0; + errno == EAGAIN) { + ret = 0; + goto cleanup; + } virReportSystemError(errno, "%s", _("Unable to accept client")); - return -1; + goto cleanup; } localAddr.len = sizeof(localAddr.data); if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) { virReportSystemError(errno, "%s", _("Unable to get local socket name")); - VIR_FORCE_CLOSE(fd); - return -1; + goto cleanup; } if (!(*clientsock = virNetSocketNew(&localAddr, &remoteAddr, true, - fd, -1, 0))) { - VIR_FORCE_CLOSE(fd); - return -1; - } + fd, -1, 0))) + goto cleanup; - return 0; + fd = -1; + ret = 0; + +cleanup: + VIR_FORCE_CLOSE(fd); + virMutexUnlock(&sock->lock); + return ret; } @@ -1040,52 +1108,77 @@ static void virNetSocketEventHandle(int watch ATTRIBUTE_UNUSED, void *opaque) { virNetSocketPtr sock = opaque; + virNetSocketIOFunc func; + void *eopaque; - sock->func(sock, events, sock->opaque); + virMutexLock(&sock->lock); + func = sock->func; + eopaque = sock->opaque; + virMutexUnlock(&sock->lock); + + if (func) + func(sock, events, eopaque); } + int virNetSocketAddIOCallback(virNetSocketPtr sock, int events, virNetSocketIOFunc func, void *opaque) { + int ret = -1; + + virMutexLock(&sock->lock); if (sock->watch > 0) { VIR_DEBUG("Watch already registered on socket %p", sock); - return -1; + goto cleanup; } + sock->refs++; if ((sock->watch = virEventAddHandle(sock->fd, events, virNetSocketEventHandle, sock, NULL)) < 0) { VIR_DEBUG("Failed to register watch on socket %p", sock); - return -1; + goto cleanup; } sock->func = func; sock->opaque = opaque; - return 0; + ret = 0; + +cleanup: + virMutexUnlock(&sock->lock); + return ret; } void virNetSocketUpdateIOCallback(virNetSocketPtr sock, int events) { + virMutexLock(&sock->lock); if (sock->watch <= 0) { VIR_DEBUG("Watch not registered on socket %p", sock); + virMutexUnlock(&sock->lock); return; } virEventUpdateHandle(sock->watch, events); + + virMutexUnlock(&sock->lock); } void virNetSocketRemoveIOCallback(virNetSocketPtr sock) { + virMutexLock(&sock->lock); + if (sock->watch <= 0) { VIR_DEBUG("Watch not registered on socket %p", sock); + virMutexUnlock(&sock->lock); return; } virEventRemoveHandle(sock->watch); - sock->watch = 0; + + virMutexUnlock(&sock->lock); } -- 1.7.6 -- libvir-list mailing list libvir-list@xxxxxxxxxx https://www.redhat.com/mailman/listinfo/libvir-list