Signed-off-by: Tim Wiederhake <twiederh@xxxxxxxxxx> --- src/rpc/virnetserverclient.c | 432 +++++++++++++++-------------------- 1 file changed, 186 insertions(+), 246 deletions(-) diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c index 7d5c0965b8..da9956f2b4 100644 --- a/src/rpc/virnetserverclient.c +++ b/src/rpc/virnetserverclient.c @@ -234,14 +234,12 @@ int virNetServerClientAddFilter(virNetServerClient *client, virNetServerClientFilterFunc func, void *opaque) { + VIR_LOCK_GUARD lock = virObjectLockGuard(client); virNetServerClientFilter *filter; virNetServerClientFilter **place; - int ret; filter = g_new0(virNetServerClientFilter, 1); - virObjectLock(client); - filter->id = client->nextFilterID++; filter->func = func; filter->opaque = opaque; @@ -251,21 +249,16 @@ int virNetServerClientAddFilter(virNetServerClient *client, place = &(*place)->next; *place = filter; - ret = filter->id; - - virObjectUnlock(client); - - return ret; + return filter->id; } void virNetServerClientRemoveFilter(virNetServerClient *client, int filterID) { + VIR_LOCK_GUARD lock = virObjectLockGuard(client); virNetServerClientFilter *tmp; virNetServerClientFilter *prev; - virObjectLock(client); - prev = NULL; tmp = client->filters; while (tmp) { @@ -281,8 +274,6 @@ void virNetServerClientRemoveFilter(virNetServerClient *client, prev = tmp; tmp = tmp->next; } - - virObjectUnlock(client); } @@ -322,19 +313,19 @@ virNetServerClientCheckAccess(virNetServerClient *client) static void virNetServerClientDispatchMessage(virNetServerClient *client, virNetMessage *msg) { - virObjectLock(client); - if (!client->dispatchFunc) { - virNetMessageFree(msg); - client->wantClose = true; - virObjectUnlock(client); - } else { - virObjectUnlock(client); - /* Accessing 'client' is safe, because virNetServerClientSetDispatcher - * only permits setting 'dispatchFunc' once, so if non-NULL, it will - * never change again - */ - client->dispatchFunc(client, msg, client->dispatchOpaque); + VIR_WITH_OBJECT_LOCK_GUARD(client) { + if (!client->dispatchFunc) { + virNetMessageFree(msg); + client->wantClose = true; + return; + } } + + /* Accessing 'client' is safe, because virNetServerClientSetDispatcher + * only permits setting 'dispatchFunc' once, so if non-NULL, it will + * never change again + */ + client->dispatchFunc(client, msg, client->dispatchOpaque); } @@ -343,13 +334,14 @@ static void virNetServerClientSockTimerFunc(int timer, { virNetServerClient *client = opaque; virNetMessage *msg = NULL; - virObjectLock(client); - virEventUpdateTimeout(timer, -1); - /* Although client->rx != NULL when this timer is enabled, it might have - * changed since the client was unlocked in the meantime. */ - if (client->rx) - msg = virNetServerClientDispatchRead(client); - virObjectUnlock(client); + + VIR_WITH_OBJECT_LOCK_GUARD(client) { + virEventUpdateTimeout(timer, -1); + /* Although client->rx != NULL when this timer is enabled, it might have + * changed since the client was unlocked in the meantime. */ + if (client->rx) + msg = virNetServerClientDispatchRead(client); + } if (msg) virNetServerClientDispatchMessage(client, msg); @@ -587,53 +579,45 @@ virJSONValue *virNetServerClientPreExecRestart(virNetServerClient *client) g_autoptr(virJSONValue) object = virJSONValueNewObject(); g_autoptr(virJSONValue) sock = NULL; g_autoptr(virJSONValue) priv = NULL; - - virObjectLock(client); + VIR_LOCK_GUARD lock = virObjectLockGuard(client); if (virJSONValueObjectAppendNumberUlong(object, "id", client->id) < 0) - goto error; + return NULL; if (virJSONValueObjectAppendNumberInt(object, "auth", client->auth) < 0) - goto error; + return NULL; if (virJSONValueObjectAppendBoolean(object, "auth_pending", client->auth_pending) < 0) - goto error; + return NULL; if (virJSONValueObjectAppendBoolean(object, "readonly", client->readonly) < 0) - goto error; + return NULL; if (virJSONValueObjectAppendNumberUint(object, "nrequests_max", client->nrequests_max) < 0) - goto error; + return NULL; if (client->conn_time && virJSONValueObjectAppendNumberLong(object, "conn_time", client->conn_time) < 0) - goto error; + return NULL; if (!(sock = virNetSocketPreExecRestart(client->sock))) - goto error; + return NULL; if (virJSONValueObjectAppend(object, "sock", &sock) < 0) - goto error; + return NULL; if (!(priv = client->privateDataPreExecRestart(client, client->privateData))) - goto error; + return NULL; if (virJSONValueObjectAppend(object, "privateData", &priv) < 0) - goto error; + return NULL; - virObjectUnlock(client); return g_steal_pointer(&object); - - error: - virObjectUnlock(client); - return NULL; } int virNetServerClientGetAuth(virNetServerClient *client) { - int auth; - virObjectLock(client); - auth = client->auth; - virObjectUnlock(client); - return auth; + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + + return client->auth; } @@ -647,11 +631,9 @@ virNetServerClientSetAuthLocked(virNetServerClient *client, bool virNetServerClientGetReadonly(virNetServerClient *client) { - bool readonly; - virObjectLock(client); - readonly = client->readonly; - virObjectUnlock(client); - return readonly; + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + + return client->readonly; } @@ -659,9 +641,9 @@ void virNetServerClientSetReadonly(virNetServerClient *client, bool readonly) { - virObjectLock(client); + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + client->readonly = readonly; - virObjectUnlock(client); } @@ -677,52 +659,48 @@ long long virNetServerClientGetTimestamp(virNetServerClient *client) bool virNetServerClientHasTLSSession(virNetServerClient *client) { - bool has; - virObjectLock(client); - has = client->tls ? true : false; - virObjectUnlock(client); - return has; + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + + return !!client->tls; } virNetTLSSession *virNetServerClientGetTLSSession(virNetServerClient *client) { - virNetTLSSession *tls; - virObjectLock(client); - tls = client->tls; - virObjectUnlock(client); - return tls; + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + + return client->tls; } int virNetServerClientGetTLSKeySize(virNetServerClient *client) { - int size = 0; - virObjectLock(client); - if (client->tls) - size = virNetTLSSessionGetKeySize(client->tls); - virObjectUnlock(client); - return size; + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + + if (!client->tls) + return 0; + + return virNetTLSSessionGetKeySize(client->tls); } int virNetServerClientGetFD(virNetServerClient *client) { - int fd = -1; - virObjectLock(client); - if (client->sock) - fd = virNetSocketGetFD(client->sock); - virObjectUnlock(client); - return fd; + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + + if (!client->sock) + return -1; + + return virNetSocketGetFD(client->sock); } bool virNetServerClientIsLocal(virNetServerClient *client) { - bool local = false; - virObjectLock(client); - if (client->sock) - local = virNetSocketIsLocal(client->sock); - virObjectUnlock(client); - return local; + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + + if (!client->sock) + return false; + + return virNetSocketIsLocal(client->sock); } @@ -730,14 +708,12 @@ int virNetServerClientGetUNIXIdentity(virNetServerClient *client, uid_t *uid, gid_t *gid, pid_t *pid, unsigned long long *timestamp) { - int ret = -1; - virObjectLock(client); - if (client->sock) - ret = virNetSocketGetUNIXIdentity(client->sock, - uid, gid, pid, - timestamp); - virObjectUnlock(client); - return ret; + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + + if (!client->sock) + return -1; + + return virNetSocketGetUNIXIdentity(client->sock, uid, gid, pid, timestamp); } @@ -806,56 +782,60 @@ virNetServerClientCreateIdentity(virNetServerClient *client) virIdentity *virNetServerClientGetIdentity(virNetServerClient *client) { - virIdentity *ret = NULL; - virObjectLock(client); + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + if (!client->identity) client->identity = virNetServerClientCreateIdentity(client); - if (client->identity) - ret = g_object_ref(client->identity); - virObjectUnlock(client); - return ret; + + if (!client->identity) + return NULL; + + return g_object_ref(client->identity); } void virNetServerClientSetIdentity(virNetServerClient *client, virIdentity *identity) { - virObjectLock(client); + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + g_clear_object(&client->identity); client->identity = identity; if (client->identity) g_object_ref(client->identity); - virObjectUnlock(client); } int virNetServerClientGetSELinuxContext(virNetServerClient *client, char **context) { - int ret = 0; + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + *context = NULL; - virObjectLock(client); - if (client->sock) - ret = virNetSocketGetSELinuxContext(client->sock, context); - virObjectUnlock(client); - return ret; + + if (!client->sock) + return 0; + + return virNetSocketGetSELinuxContext(client->sock, context); } bool virNetServerClientIsSecure(virNetServerClient *client) { - bool secure = false; - virObjectLock(client); + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + if (client->tls) - secure = true; + return true; + #if WITH_SASL if (client->sasl) - secure = true; + return true; #endif + if (client->sock && virNetSocketIsLocal(client->sock)) - secure = true; - virObjectUnlock(client); - return secure; + return true; + + return false; } @@ -863,53 +843,47 @@ bool virNetServerClientIsSecure(virNetServerClient *client) void virNetServerClientSetSASLSession(virNetServerClient *client, virNetSASLSession *sasl) { + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + /* We don't set the sasl session on the socket here * because we need to send out the auth confirmation * in the clear. Only once we complete the next 'tx' * operation do we switch to SASL mode */ - virObjectLock(client); client->sasl = virObjectRef(sasl); - virObjectUnlock(client); } virNetSASLSession *virNetServerClientGetSASLSession(virNetServerClient *client) { - virNetSASLSession *sasl; - virObjectLock(client); - sasl = client->sasl; - virObjectUnlock(client); - return sasl; + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + + return client->sasl; } bool virNetServerClientHasSASLSession(virNetServerClient *client) { - bool has = false; - virObjectLock(client); - has = !!client->sasl; - virObjectUnlock(client); - return has; + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + + return !!client->sasl; } #endif void *virNetServerClientGetPrivateData(virNetServerClient *client) { - void *data; - virObjectLock(client); - data = client->privateData; - virObjectUnlock(client); - return data; + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + + return client->privateData; } void virNetServerClientSetCloseHook(virNetServerClient *client, virNetServerClientCloseFunc cf) { - virObjectLock(client); + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + client->privateDataCloseFunc = cf; - virObjectUnlock(client); } @@ -917,7 +891,8 @@ void virNetServerClientSetDispatcher(virNetServerClient *client, virNetServerClientDispatchFunc func, void *opaque) { - virObjectLock(client); + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + /* Only set dispatcher if not already set, to avoid race * with dispatch code that runs without locks held */ @@ -925,7 +900,6 @@ void virNetServerClientSetDispatcher(virNetServerClient *client, client->dispatchFunc = func; client->dispatchOpaque = opaque; } - virObjectUnlock(client); } @@ -1042,9 +1016,9 @@ virNetServerClientCloseLocked(virNetServerClient *client) void virNetServerClientClose(virNetServerClient *client) { - virObjectLock(client); + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + virNetServerClientCloseLocked(client); - virObjectUnlock(client); } @@ -1057,16 +1031,16 @@ virNetServerClientIsClosedLocked(virNetServerClient *client) void virNetServerClientDelayedClose(virNetServerClient *client) { - virObjectLock(client); + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + client->delayedClose = true; - virObjectUnlock(client); } void virNetServerClientImmediateClose(virNetServerClient *client) { - virObjectLock(client); + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + client->wantClose = true; - virObjectUnlock(client); } @@ -1079,49 +1053,46 @@ virNetServerClientWantCloseLocked(virNetServerClient *client) int virNetServerClientInit(virNetServerClient *client) { - virObjectLock(client); + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + int ret = -1; if (!client->tlsCtxt) { /* Plain socket, so prepare to read first message */ if (virNetServerClientRegisterEvent(client) < 0) goto error; - } else { - int ret; + return 0; + } - if (!(client->tls = virNetTLSSessionNew(client->tlsCtxt, - NULL))) - goto error; + if (!(client->tls = virNetTLSSessionNew(client->tlsCtxt, NULL))) + goto error; - virNetSocketSetTLSSession(client->sock, - client->tls); + virNetSocketSetTLSSession(client->sock, client->tls); - /* Begin the TLS handshake. */ - virObjectLock(client->tlsCtxt); + /* Begin the TLS handshake. */ + VIR_WITH_OBJECT_LOCK_GUARD(client->tlsCtxt) { ret = virNetTLSSessionHandshake(client->tls); - virObjectUnlock(client->tlsCtxt); - if (ret == 0) { - /* Unlikely, but ... Next step is to check the certificate. */ - if (virNetServerClientCheckAccess(client) < 0) - goto error; - - /* Handshake & cert check OK, so prepare to read first message */ - if (virNetServerClientRegisterEvent(client) < 0) - goto error; - } else if (ret > 0) { - /* Most likely, need to do more handshake data */ - if (virNetServerClientRegisterEvent(client) < 0) - goto error; - } else { + } + + if (ret == 0) { + /* Unlikely, but ... Next step is to check the certificate. */ + if (virNetServerClientCheckAccess(client) < 0) goto error; - } + + /* Handshake & cert check OK, so prepare to read first message */ + if (virNetServerClientRegisterEvent(client) < 0) + goto error; + } else if (ret > 0) { + /* Most likely, need to do more handshake data */ + if (virNetServerClientRegisterEvent(client) < 0) + goto error; + } else { + goto error; } - virObjectUnlock(client); return 0; error: client->wantClose = true; - virObjectUnlock(client); return -1; } @@ -1406,11 +1377,13 @@ virNetServerClientDispatchWrite(virNetServerClient *client) static void virNetServerClientDispatchHandshake(virNetServerClient *client) { - int ret; + int ret = -1; + /* Continue the handshake. */ - virObjectLock(client->tlsCtxt); - ret = virNetTLSSessionHandshake(client->tls); - virObjectUnlock(client->tlsCtxt); + VIR_WITH_OBJECT_LOCK_GUARD(client->tlsCtxt) { + ret = virNetTLSSessionHandshake(client->tls); + } + if (ret == 0) { /* Finished. Next step is to check the certificate. */ if (virNetServerClientCheckAccess(client) < 0) @@ -1435,36 +1408,29 @@ virNetServerClientDispatchEvent(virNetSocket *sock, int events, void *opaque) virNetServerClient *client = opaque; virNetMessage *msg = NULL; - virObjectLock(client); - - if (client->sock != sock) { - virNetSocketRemoveIOCallback(sock); - virObjectUnlock(client); - return; - } - - if (events & (VIR_EVENT_HANDLE_WRITABLE | - VIR_EVENT_HANDLE_READABLE)) { - if (client->tls && - virNetTLSSessionGetHandshakeStatus(client->tls) != - VIR_NET_TLS_HANDSHAKE_COMPLETE) { - virNetServerClientDispatchHandshake(client); - } else { - if (events & VIR_EVENT_HANDLE_WRITABLE) - virNetServerClientDispatchWrite(client); - if (events & VIR_EVENT_HANDLE_READABLE && - client->rx) - msg = virNetServerClientDispatchRead(client); + VIR_WITH_OBJECT_LOCK_GUARD(client) { + if (client->sock != sock) { + virNetSocketRemoveIOCallback(sock); + return; } - } - /* NB, will get HANGUP + READABLE at same time upon - * disconnect */ - if (events & (VIR_EVENT_HANDLE_ERROR | - VIR_EVENT_HANDLE_HANGUP)) - client->wantClose = true; + if (events & (VIR_EVENT_HANDLE_WRITABLE | VIR_EVENT_HANDLE_READABLE)) { + if (client->tls && + virNetTLSSessionGetHandshakeStatus(client->tls) != + VIR_NET_TLS_HANDSHAKE_COMPLETE) { + virNetServerClientDispatchHandshake(client); + } else { + if (events & VIR_EVENT_HANDLE_WRITABLE) + virNetServerClientDispatchWrite(client); + if ((events & VIR_EVENT_HANDLE_READABLE) && client->rx) + msg = virNetServerClientDispatchRead(client); + } + } - virObjectUnlock(client); + /* NB, will get HANGUP + READABLE at same time upon disconnect */ + if (events & (VIR_EVENT_HANDLE_ERROR | VIR_EVENT_HANDLE_HANGUP)) + client->wantClose = true; + } if (msg) virNetServerClientDispatchMessage(client, msg); @@ -1499,24 +1465,18 @@ virNetServerClientSendMessageLocked(virNetServerClient *client, int virNetServerClientSendMessage(virNetServerClient *client, virNetMessage *msg) { - int ret; - - virObjectLock(client); - ret = virNetServerClientSendMessageLocked(client, msg); - virObjectUnlock(client); + VIR_LOCK_GUARD lock = virObjectLockGuard(client); - return ret; + return virNetServerClientSendMessageLocked(client, msg); } bool virNetServerClientIsAuthenticated(virNetServerClient *client) { - bool authenticated; - virObjectLock(client); - authenticated = virNetServerClientAuthMethodImpliesAuthenticated(client->auth); - virObjectUnlock(client); - return authenticated; + VIR_LOCK_GUARD lock = virObjectLockGuard(client); + + return virNetServerClientAuthMethodImpliesAuthenticated(client->auth); } @@ -1556,57 +1516,44 @@ virNetServerClientInitKeepAlive(virNetServerClient *client, int interval, unsigned int count) { + VIR_LOCK_GUARD lock = virObjectLockGuard(client); virKeepAlive *ka; - int ret = -1; - - virObjectLock(client); if (!(ka = virKeepAliveNew(interval, count, client, virNetServerClientKeepAliveSendCB, virNetServerClientKeepAliveDeadCB, virObjectFreeCallback))) - goto cleanup; + return -1; + /* keepalive object has a reference to client */ virObjectRef(client); client->keepalive = ka; - ret = 0; - cleanup: - virObjectUnlock(client); - - return ret; + return 0; } int virNetServerClientStartKeepAlive(virNetServerClient *client) { - int ret = -1; - - virObjectLock(client); + VIR_LOCK_GUARD lock = virObjectLockGuard(client); /* The connection might have been closed before we got here and thus the * keepalive object could have been removed too. */ if (!client->keepalive) { - virReportError(VIR_ERR_INTERNAL_ERROR, "%s", - _("connection not open")); - goto cleanup; + virReportError(VIR_ERR_INTERNAL_ERROR, "%s", _("connection not open")); + return -1; } - ret = virKeepAliveStart(client->keepalive, 0, 0); - - cleanup: - virObjectUnlock(client); - return ret; + return virKeepAliveStart(client->keepalive, 0, 0); } int virNetServerClientGetTransport(virNetServerClient *client) { + VIR_LOCK_GUARD lock = virObjectLockGuard(client); int ret = -1; - virObjectLock(client); - if (client->sock && virNetSocketIsLocal(client->sock)) ret = VIR_CLIENT_TRANS_UNIX; else @@ -1615,8 +1562,6 @@ virNetServerClientGetTransport(virNetServerClient *client) if (client->tls) ret = VIR_CLIENT_TRANS_TLS; - virObjectUnlock(client); - return ret; } @@ -1625,16 +1570,15 @@ virNetServerClientGetInfo(virNetServerClient *client, bool *readonly, char **sock_addr, virIdentity **identity) { - int ret = -1; + VIR_LOCK_GUARD lock = virObjectLockGuard(client); const char *addr; - virObjectLock(client); *readonly = client->readonly; if (!(addr = virNetServerClientRemoteAddrStringURI(client))) { virReportError(VIR_ERR_INTERNAL_ERROR, "%s", _("No network socket associated with client")); - goto cleanup; + return -1; } *sock_addr = g_strdup(addr); @@ -1642,15 +1586,11 @@ virNetServerClientGetInfo(virNetServerClient *client, if (!client->identity) { virReportError(VIR_ERR_INTERNAL_ERROR, "%s", _("No identity information available for client")); - goto cleanup; + return -1; } *identity = g_object_ref(client->identity); - - ret = 0; - cleanup: - virObjectUnlock(client); - return ret; + return 0; } -- 2.31.1