From: "Daniel P. Berrange" <berrange@xxxxxxxxxx>

Make virNetSASLContext and virNetSASLSession use virObject APIs
for reference counting.

The type-specific virNetSASL{Context,Session}{Ref,Free} functions are
removed; callers now use virObjectRef/virObjectUnref, and the cleanup
code moves into file-local Dispose callbacks registered with the
virClass of each type.

virNetSocketSetSASLSession now takes a reference on the new session
before releasing the old one, so passing in the session the socket
already holds can never drop its last reference and leave a dangling
pointer behind.

Signed-off-by: Daniel P. Berrange <berrange@xxxxxxxxxx>
---
 daemon/remote.c              |    8 ++--
 src/remote/remote_driver.c   |    4 +-
 src/rpc/virnetclient.c       |    7 ++-
 src/rpc/virnetsaslcontext.c  |  106 ++++++++++++++++++------------------------
 src/rpc/virnetsaslcontext.h  |    8 +---
 src/rpc/virnetserverclient.c |    7 ++-
 src/rpc/virnetsocket.c       |    9 ++--
 7 files changed, 63 insertions(+), 86 deletions(-)

diff --git a/daemon/remote.c b/daemon/remote.c
index d25717c..832307e 100644
--- a/daemon/remote.c
+++ b/daemon/remote.c
@@ -2325,7 +2325,7 @@ authfail:
     PROBE(RPC_SERVER_CLIENT_AUTH_FAIL,
           "client=%p auth=%d",
           client, REMOTE_AUTH_SASL);
-    virNetSASLSessionFree(sasl);
+    virObjectUnref(sasl);
     virMutexUnlock(&priv->lock);
     return -1;
 }
@@ -2369,7 +2369,7 @@ remoteSASLFinish(virNetServerClientPtr client)
           "client=%p auth=%d identity=%s",
           client, REMOTE_AUTH_SASL, identity);
 
-    virNetSASLSessionFree(priv->sasl);
+    virObjectUnref(priv->sasl);
     priv->sasl = NULL;
 
     return 0;
@@ -2467,7 +2467,7 @@ authdeny:
     goto error;
 
 error:
-    virNetSASLSessionFree(priv->sasl);
+    virObjectUnref(priv->sasl);
     priv->sasl = NULL;
     virResetLastError();
     virReportError(VIR_ERR_AUTH_FAILED, "%s",
@@ -2565,7 +2565,7 @@ authdeny:
     goto error;
 
 error:
-    virNetSASLSessionFree(priv->sasl);
+    virObjectUnref(priv->sasl);
     priv->sasl = NULL;
     virResetLastError();
     virReportError(VIR_ERR_AUTH_FAILED, "%s",
diff --git a/src/remote/remote_driver.c b/src/remote/remote_driver.c
index 91d337f..bfea919 100644
--- a/src/remote/remote_driver.c
+++ b/src/remote/remote_driver.c
@@ -3409,8 +3409,8 @@ remoteAuthSASL (virConnectPtr conn, struct private_data *priv,
 
     remoteAuthInteractStateClear(&state, true);
     VIR_FREE(saslcb);
-    virNetSASLSessionFree(sasl);
-    virNetSASLContextFree(saslCtxt);
+    virObjectUnref(sasl);
+    virObjectUnref(saslCtxt);
 
     return ret;
 }
diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c
index 3a1b831..e611370 100644
--- a/src/rpc/virnetclient.c
+++ b/src/rpc/virnetclient.c
@@ -497,7 +497,7 @@ void virNetClientFree(virNetClientPtr client)
     virNetSocketFree(client->sock);
     virObjectUnref(client->tls);
 #if HAVE_SASL
-    virNetSASLSessionFree(client->sasl);
+    virObjectUnref(client->sasl);
 #endif
     virNetClientUnlock(client);
     virMutexDestroy(&client->lock);
@@ -532,7 +532,7 @@ virNetClientCloseLocked(virNetClientPtr client)
     virObjectUnref(client->tls);
     client->tls = NULL;
 #if HAVE_SASL
-    virNetSASLSessionFree(client->sasl);
+    virObjectUnref(client->sasl);
     client->sasl = NULL;
 #endif
     ka = client->keepalive;
@@ -604,8 +604,7 @@ void virNetClientSetSASLSession(virNetClientPtr client,
                                 virNetSASLSessionPtr sasl)
 {
     virNetClientLock(client);
-    client->sasl = sasl;
-    virNetSASLSessionRef(sasl);
+    client->sasl = virObjectRef(sasl);
     virNetSocketSetSASLSession(client->sock, client->sasl);
     virNetClientUnlock(client);
 }
diff --git a/src/rpc/virnetsaslcontext.c b/src/rpc/virnetsaslcontext.c
index af6e237..2feb9a9 100644
--- a/src/rpc/virnetsaslcontext.c
+++ b/src/rpc/virnetsaslcontext.c
@@ -33,24 +33,52 @@
 #define VIR_FROM_THIS VIR_FROM_RPC
 
 struct _virNetSASLContext {
+    virObject object;
+
     virMutex lock;
     const char *const*usernameWhitelist;
-    int refs;
 };
 
 struct _virNetSASLSession {
+    virObject object;
+
     virMutex lock;
     sasl_conn_t *conn;
-    int refs;
     size_t maxbufsize;
 };
 
 
+static virClassPtr virNetSASLContextClass;
+static virClassPtr virNetSASLSessionClass;
+static void virNetSASLContextDispose(void *obj);
+static void virNetSASLSessionDispose(void *obj);
+
+static int virNetSASLContextOnceInit(void)
+{
+    if (!(virNetSASLContextClass = virClassNew("virNetSASLContext",
+                                               sizeof(virNetSASLContext),
+                                               virNetSASLContextDispose)))
+        return -1;
+
+    if (!(virNetSASLSessionClass = virClassNew("virNetSASLSession",
+                                               sizeof(virNetSASLSession),
+                                               virNetSASLSessionDispose)))
+        return -1;
+
+    return 0;
+}
+
+VIR_ONCE_GLOBAL_INIT(virNetSASLContext)
+
+
 virNetSASLContextPtr virNetSASLContextNewClient(void)
 {
     virNetSASLContextPtr ctxt;
     int err;
 
+    if (virNetSASLContextInitialize() < 0)
+        return NULL;
+
     err = sasl_client_init(NULL);
     if (err != SASL_OK) {
         virReportError(VIR_ERR_AUTH_FAILED,
@@ -59,10 +87,8 @@ virNetSASLContextPtr virNetSASLContextNewClient(void)
         return NULL;
     }
 
-    if (VIR_ALLOC(ctxt) < 0) {
-        virReportOOMError();
+    if (!(ctxt = virObjectNew(virNetSASLContextClass)))
         return NULL;
-    }
 
     if (virMutexInit(&ctxt->lock) < 0) {
         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
@@ -71,8 +97,6 @@ virNetSASLContextPtr virNetSASLContextNewClient(void)
         return NULL;
     }
 
-    ctxt->refs = 1;
-
     return ctxt;
 }
 
@@ -81,6 +105,9 @@ virNetSASLContextPtr virNetSASLContextNewServer(const char *const*usernameWhitel
     virNetSASLContextPtr ctxt;
     int err;
 
+    if (virNetSASLContextInitialize() < 0)
+        return NULL;
+
     err = sasl_server_init(NULL, "libvirt");
     if (err != SASL_OK) {
         virReportError(VIR_ERR_AUTH_FAILED,
@@ -89,10 +116,8 @@ virNetSASLContextPtr virNetSASLContextNewServer(const char *const*usernameWhitel
         return NULL;
     }
 
-    if (VIR_ALLOC(ctxt) < 0) {
-        virReportOOMError();
+    if (!(ctxt = virObjectNew(virNetSASLContextClass)))
         return NULL;
-    }
 
     if (virMutexInit(&ctxt->lock) < 0) {
         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
@@ -102,7 +127,6 @@ virNetSASLContextPtr virNetSASLContextNewServer(const char *const*usernameWhitel
     }
 
     ctxt->usernameWhitelist = usernameWhitelist;
-    ctxt->refs = 1;
 
     return ctxt;
 }
@@ -152,28 +176,11 @@ cleanup:
 }
 
 
-void virNetSASLContextRef(virNetSASLContextPtr ctxt)
-{
-    virMutexLock(&ctxt->lock);
-    ctxt->refs++;
-    virMutexUnlock(&ctxt->lock);
-}
-
-void virNetSASLContextFree(virNetSASLContextPtr ctxt)
+static void virNetSASLContextDispose(void *obj)
 {
-    if (!ctxt)
-        return;
-
-    virMutexLock(&ctxt->lock);
-    ctxt->refs--;
-    if (ctxt->refs > 0) {
-        virMutexUnlock(&ctxt->lock);
-        return;
-    }
+    virNetSASLContextPtr ctxt = obj;
 
-    virMutexUnlock(&ctxt->lock);
     virMutexDestroy(&ctxt->lock);
-    VIR_FREE(ctxt);
 }
 
 virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIBUTE_UNUSED,
@@ -186,10 +193,8 @@ virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIB
     virNetSASLSessionPtr sasl = NULL;
     int err;
 
-    if (VIR_ALLOC(sasl) < 0) {
-        virReportOOMError();
-        goto cleanup;
-    }
+    if (!(sasl = virObjectNew(virNetSASLSessionClass)))
+        return NULL;
 
     if (virMutexInit(&sasl->lock) < 0) {
         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
@@ -198,7 +203,6 @@ virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIB
         return NULL;
     }
 
-    sasl->refs = 1;
     /* Arbitrary size for amount of data we can encode in a single block */
     sasl->maxbufsize = 1 << 16;
 
@@ -219,7 +223,7 @@ virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIB
     return sasl;
 
 cleanup:
-    virNetSASLSessionFree(sasl);
+    virObjectUnref(sasl);
     return NULL;
 }
 
@@ -231,10 +235,8 @@ virNetSASLSessionPtr virNetSASLSessionNewServer(virNetSASLContextPtr ctxt ATTRIB
     virNetSASLSessionPtr sasl = NULL;
     int err;
 
-    if (VIR_ALLOC(sasl) < 0) {
-        virReportOOMError();
-        goto cleanup;
-    }
+    if (!(sasl = virObjectNew(virNetSASLSessionClass)))
+        return NULL;
 
     if (virMutexInit(&sasl->lock) < 0) {
         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
@@ -243,7 +245,6 @@ virNetSASLSessionPtr virNetSASLSessionNewServer(virNetSASLContextPtr ctxt ATTRIB
         return NULL;
     }
 
-    sasl->refs = 1;
     /* Arbitrary size for amount of data we can encode in a single block */
     sasl->maxbufsize = 1 << 16;
 
@@ -265,17 +266,10 @@ virNetSASLSessionPtr virNetSASLSessionNewServer(virNetSASLContextPtr ctxt ATTRIB
     return sasl;
 
 cleanup:
-    virNetSASLSessionFree(sasl);
+    virObjectUnref(sasl);
     return NULL;
 }
 
-void virNetSASLSessionRef(virNetSASLSessionPtr sasl)
-{
-    virMutexLock(&sasl->lock);
-    sasl->refs++;
-    virMutexUnlock(&sasl->lock);
-}
-
 int virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl,
                                 int ssf)
 {
@@ -712,22 +706,12 @@ cleanup:
     return ret;
 }
 
-void virNetSASLSessionFree(virNetSASLSessionPtr sasl)
+static void virNetSASLSessionDispose(void *obj)
 {
-    if (!sasl)
-        return;
-
-    virMutexLock(&sasl->lock);
-    sasl->refs--;
-    if (sasl->refs > 0) {
-        virMutexUnlock(&sasl->lock);
-        return;
-    }
+    virNetSASLSessionPtr sasl = obj;
 
     if (sasl->conn)
         sasl_dispose(&sasl->conn);
 
-    virMutexUnlock(&sasl->lock);
     virMutexDestroy(&sasl->lock);
-    VIR_FREE(sasl);
 }
diff --git a/src/rpc/virnetsaslcontext.h b/src/rpc/virnetsaslcontext.h
index 914c45c..8e322d8 100644
--- a/src/rpc/virnetsaslcontext.h
+++ b/src/rpc/virnetsaslcontext.h
@@ -24,6 +24,7 @@
 # include <sasl/sasl.h>
 
 # include "internal.h"
+# include "virobject.h"
 
 typedef struct _virNetSASLContext virNetSASLContext;
 typedef virNetSASLContext *virNetSASLContextPtr;
@@ -43,9 +44,6 @@ virNetSASLContextPtr virNetSASLContextNewServer(const char *const*usernameWhitel
 int virNetSASLContextCheckIdentity(virNetSASLContextPtr ctxt,
                                    const char *identity);
 
-void virNetSASLContextRef(virNetSASLContextPtr sasl);
-void virNetSASLContextFree(virNetSASLContextPtr sasl);
-
 virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt,
                                                 const char *service,
                                                 const char *hostname,
@@ -59,8 +57,6 @@ virNetSASLSessionPtr virNetSASLSessionNewServer(virNetSASLContextPtr ctxt,
 
 char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl);
 
-void virNetSASLSessionRef(virNetSASLSessionPtr sasl);
-
 int virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl,
                                 int ssf);
 
@@ -114,6 +110,4 @@ ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl,
                                 const char **output,
                                 size_t *outputlen);
 
-void virNetSASLSessionFree(virNetSASLSessionPtr sasl);
-
 #endif /* __VIR_NET_CLIENT_SASL_CONTEXT_H__ */
diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c
index c419e74..471cca0 100644
--- a/src/rpc/virnetserverclient.c
+++ b/src/rpc/virnetserverclient.c
@@ -474,8 +474,7 @@ void virNetServerClientSetSASLSession(virNetServerClientPtr client,
      * operation do we switch to SASL mode
      */
     virNetServerClientLock(client);
-    client->sasl = sasl;
-    virNetSASLSessionRef(sasl);
+    client->sasl = virObjectRef(sasl);
     virNetServerClientUnlock(client);
 }
 #endif
@@ -591,7 +590,7 @@ void virNetServerClientFree(virNetServerClientPtr client)
     VIR_FREE(client->identity);
 
 #if HAVE_SASL
-    virNetSASLSessionFree(client->sasl);
+    virObjectUnref(client->sasl);
 #endif
     if (client->sockTimer > 0)
         virEventRemoveTimeout(client->sockTimer);
@@ -1009,7 +1008,7 @@ virNetServerClientDispatchWrite(virNetServerClientPtr client)
              */
             if (client->sasl) {
                 virNetSocketSetSASLSession(client->sock, client->sasl);
-                virNetSASLSessionFree(client->sasl);
+                virObjectUnref(client->sasl);
                 client->sasl = NULL;
             }
 #endif
diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index bca78b5..b6bb211 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -750,7 +750,7 @@ void virNetSocketFree(virNetSocketPtr sock)
     virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL);
     virObjectUnref(sock->tlsSession);
 #if HAVE_SASL
-    virNetSASLSessionFree(sock->saslSession);
+    virObjectUnref(sock->saslSession);
 #endif
 
     VIR_FORCE_CLOSE(sock->fd);
@@ -924,9 +924,10 @@ void virNetSocketSetSASLSession(virNetSocketPtr sock,
                                 virNetSASLSessionPtr sess)
 {
     virMutexLock(&sock->lock);
-    virNetSASLSessionFree(sock->saslSession);
-    sock->saslSession = sess;
-    virNetSASLSessionRef(sess);
+    /* Ref the new session first: sess may be the one we already hold */
+    virObjectRef(sess);
+    virObjectUnref(sock->saslSession);
+    sock->saslSession = sess;
     virMutexUnlock(&sock->lock);
 }
 #endif
-- 
1.7.10.4

--
libvir-list mailing list
libvir-list@xxxxxxxxxx
https://www.redhat.com/mailman/listinfo/libvir-list