From: "Daniel P. Berrange" <berrange@xxxxxxxxxx> Make virNetTLSContext and virNetTLSSession use the virObject APIs for reference counting Signed-off-by: Daniel P. Berrange <berrange@xxxxxxxxxx> --- daemon/libvirtd.c | 4 +- src/libvirt_private.syms | 2 - src/libvirt_probes.d | 8 +-- src/remote/remote_driver.c | 2 +- src/rpc/virnetclient.c | 6 +-- src/rpc/virnetserver.c | 3 +- src/rpc/virnetserverclient.c | 11 ++--- src/rpc/virnetserverservice.c | 10 ++-- src/rpc/virnetsocket.c | 7 ++- src/rpc/virnettlscontext.c | 110 +++++++++++++++-------------------------- src/rpc/virnettlscontext.h | 10 +--- tests/virnettlscontexttest.c | 10 ++-- 12 files changed, 66 insertions(+), 117 deletions(-) diff --git a/daemon/libvirtd.c b/daemon/libvirtd.c index 79f37ae..211a4bc 100644 --- a/daemon/libvirtd.c +++ b/daemon/libvirtd.c @@ -541,7 +541,7 @@ static int daemonSetupNetworking(virNetServerPtr srv, false, config->max_client_requests, ctxt))) { - virNetTLSContextFree(ctxt); + virObjectUnref(ctxt); goto error; } if (virNetServerAddService(srv, svcTLS, @@ -549,7 +549,7 @@ static int daemonSetupNetworking(virNetServerPtr srv, !config->listen_tcp ? "_libvirt._tcp" : NULL) < 0) goto error; - virNetTLSContextFree(ctxt); + virObjectUnref(ctxt); } } diff --git a/src/libvirt_private.syms b/src/libvirt_private.syms index 3551fd0..035658e 100644 --- a/src/libvirt_private.syms +++ b/src/libvirt_private.syms @@ -1481,11 +1481,9 @@ virNetSocketWrite; # virnettlscontext.h virNetTLSContextCheckCertificate; -virNetTLSContextFree; virNetTLSContextNewClient; virNetTLSContextNewServer; virNetTLSContextNewServerPath; -virNetTLSSessionFree; virNetTLSSessionHandshake; virNetTLSSessionNew; virNetTLSSessionSetIOCallbacks; diff --git a/src/libvirt_probes.d b/src/libvirt_probes.d index ceb3caa..3b138a9 100644 --- a/src/libvirt_probes.d +++ b/src/libvirt_probes.d @@ -61,19 +61,15 @@ provider libvirt { # file: src/rpc/virnettlscontext.c # prefix: rpc - probe rpc_tls_context_new(void *ctxt, int refs, const char *cacert, const char *cacrl, + probe rpc_tls_context_new(void *ctxt, const char *cacert, const char *cacrl, const char *cert, const char *key, int sanityCheckCert, int requireValidCert, int isServer); - probe rpc_tls_context_ref(void *ctxt, int refs); - probe rpc_tls_context_free(void *ctxt, int refs); probe rpc_tls_context_session_allow(void *ctxt, void *sess, const char *dname); probe rpc_tls_context_session_deny(void *ctxt, void *sess, const char *dname); probe rpc_tls_context_session_fail(void *ctxt, void *sess); - probe rpc_tls_session_new(void *sess, void *ctxt, int refs, const char *hostname, int isServer); - probe rpc_tls_session_ref(void *sess, int refs); - probe rpc_tls_session_free(void *sess, int refs); + probe rpc_tls_session_new(void *sess, void *ctxt, const char *hostname, int isServer); probe rpc_tls_session_handshake_pass(void *sess); probe rpc_tls_session_handshake_fail(void *sess); diff --git a/src/remote/remote_driver.c b/src/remote/remote_driver.c index eac50e6..28035de 100644 --- a/src/remote/remote_driver.c +++ b/src/remote/remote_driver.c @@ -908,7 +908,7 @@ doRemoteClose (virConnectPtr conn, struct private_data *priv) (xdrproc_t) xdr_void, (char *) NULL) == -1) ret = -1; - virNetTLSContextFree(priv->tls); + virObjectUnref(priv->tls); priv->tls = NULL; virNetClientClose(priv->client); virNetClientFree(priv->client); diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c index 49d238e..2b51246 100644 --- a/src/rpc/virnetclient.c +++ b/src/rpc/virnetclient.c @@ -475,7 +475,7 @@ void virNetClientFree(virNetClientPtr client) if (client->sock) virNetSocketRemoveIOCallback(client->sock); virNetSocketFree(client->sock); - virNetTLSSessionFree(client->tls); + virObjectUnref(client->tls); #if HAVE_SASL virNetSASLSessionFree(client->sasl); #endif @@ -499,7 +499,7 @@ virNetClientCloseLocked(virNetClientPtr client) virNetSocketRemoveIOCallback(client->sock); virNetSocketFree(client->sock); client->sock = NULL; - virNetTLSSessionFree(client->tls); + virObjectUnref(client->tls); client->tls = NULL; #if HAVE_SASL virNetSASLSessionFree(client->sasl); @@ -661,7 +661,7 @@ int virNetClientSetTLSSession(virNetClientPtr client, return 0; error: - virNetTLSSessionFree(client->tls); + virObjectUnref(client->tls); client->tls = NULL; virNetClientUnlock(client); return -1; diff --git a/src/rpc/virnetserver.c b/src/rpc/virnetserver.c index 4a02aab..17da40c 100644 --- a/src/rpc/virnetserver.c +++ b/src/rpc/virnetserver.c @@ -655,8 +655,7 @@ no_memory: int virNetServerSetTLSContext(virNetServerPtr srv, virNetTLSContextPtr tls) { - srv->tls = tls; - virNetTLSContextRef(tls); + srv->tls = virObjectRef(tls); return 0; } diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c index a56031c..85a457e 100644 --- a/src/rpc/virnetserverclient.c +++ b/src/rpc/virnetserverclient.c @@ -348,7 +348,7 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock, client->sock = sock; client->auth = auth; client->readonly = readonly; - client->tlsCtxt = tls; + client->tlsCtxt = virObjectRef(tls); client->nrequests_max = nrequests_max; client->sockTimer = virEventAddTimeout(-1, virNetServerClientSockTimerFunc, @@ -356,9 +356,6 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock, if (client->sockTimer < 0) goto error; - if (tls) - virNetTLSContextRef(tls); - /* Prepare one for packet receive */ if (!(client->rx = virNetMessageNew(true))) goto error; @@ -600,8 +597,8 @@ void virNetServerClientFree(virNetServerClientPtr client) #endif if (client->sockTimer > 0) virEventRemoveTimeout(client->sockTimer); - virNetTLSSessionFree(client->tls); - virNetTLSContextFree(client->tlsCtxt); + virObjectUnref(client->tls); + virObjectUnref(client->tlsCtxt); virNetSocketFree(client->sock); virNetServerClientUnlock(client); virMutexDestroy(&client->lock); @@ -656,7 +653,7 @@ void virNetServerClientClose(virNetServerClientPtr client) virNetSocketRemoveIOCallback(client->sock); if (client->tls) { - virNetTLSSessionFree(client->tls); + virObjectUnref(client->tls); client->tls = NULL; } client->wantClose = true; diff --git a/src/rpc/virnetserverservice.c b/src/rpc/virnetserverservice.c index 28202a4..b4689b4 100644 --- a/src/rpc/virnetserverservice.c +++ b/src/rpc/virnetserverservice.c @@ -116,9 +116,7 @@ virNetServerServicePtr virNetServerServiceNewTCP(const char *nodename, svc->auth = auth; svc->readonly = readonly; svc->nrequests_client_max = nrequests_client_max; - svc->tls = tls; - if (tls) - virNetTLSContextRef(tls); + svc->tls = virObjectRef(tls); if (virNetSocketNewListenTCP(nodename, service, @@ -172,9 +170,7 @@ virNetServerServicePtr virNetServerServiceNewUNIX(const char *path, svc->auth = auth; svc->readonly = readonly; svc->nrequests_client_max = nrequests_client_max; - svc->tls = tls; - if (tls) - virNetTLSContextRef(tls); + svc->tls = virObjectRef(tls); svc->nsocks = 1; if (VIR_ALLOC_N(svc->socks, svc->nsocks) < 0) @@ -265,7 +261,7 @@ void virNetServerServiceFree(virNetServerServicePtr svc) virNetSocketFree(svc->socks[i]); VIR_FREE(svc->socks); - virNetTLSContextFree(svc->tls); + virObjectUnref(svc->tls); VIR_FREE(svc); } diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c index 0b32ffe..a851dad 100644 --- a/src/rpc/virnetsocket.c +++ b/src/rpc/virnetsocket.c @@ -748,7 +748,7 @@ void virNetSocketFree(virNetSocketPtr sock) /* Make sure it can't send any more I/O during shutdown */ if (sock->tlsSession) virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL); - virNetTLSSessionFree(sock->tlsSession); + virObjectUnref(sock->tlsSession); #if HAVE_SASL virNetSASLSessionFree(sock->saslSession); #endif @@ -909,13 +909,12 @@ void virNetSocketSetTLSSession(virNetSocketPtr sock, virNetTLSSessionPtr sess) { virMutexLock(&sock->lock); - virNetTLSSessionFree(sock->tlsSession); - sock->tlsSession = sess; + virObjectUnref(sock->tlsSession); + sock->tlsSession = virObjectRef(sess); virNetTLSSessionSetIOCallbacks(sess, virNetSocketTLSSessionWrite, virNetSocketTLSSessionRead, sock); - virNetTLSSessionRef(sess); virMutexUnlock(&sock->lock); } diff --git a/src/rpc/virnettlscontext.c b/src/rpc/virnettlscontext.c index bf92088..74e13c7 100644 --- a/src/rpc/virnettlscontext.c +++ b/src/rpc/virnettlscontext.c @@ -53,8 +53,9 @@ __FUNCTION__, __LINE__, __VA_ARGS__) struct _virNetTLSContext { + virObject object; + virMutex lock; - int refs; gnutls_certificate_credentials_t x509cred; gnutls_dh_params_t dhParams; @@ -65,9 +66,9 @@ struct _virNetTLSContext { }; struct _virNetTLSSession { - virMutex lock; + virObject object; - int refs; + virMutex lock; bool handshakeComplete; @@ -79,6 +80,29 @@ struct _virNetTLSSession { void *opaque; }; +static virClassPtr virNetTLSContextClass; +static virClassPtr virNetTLSSessionClass; +static void virNetTLSContextDispose(void *obj); +static void virNetTLSSessionDispose(void *obj); + + +static int virNetTLSContextOnceInit(void) +{ + if (!(virNetTLSContextClass = virClassNew("virNetTLSContext", + sizeof(virNetTLSContext), + virNetTLSContextDispose))) + return -1; + + if (!(virNetTLSSessionClass = virClassNew("virNetTLSSession", + sizeof(virNetTLSSession), + virNetTLSSessionDispose))) + return -1; + + return 0; +} + +VIR_ONCE_GLOBAL_INIT(virNetTLSContext) + static int virNetTLSContextCheckCertFile(const char *type, const char *file, bool allowMissing) @@ -650,10 +674,11 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert, char *gnutlsdebug; int err; - if (VIR_ALLOC(ctxt) < 0) { - virReportOOMError(); + if (virNetTLSContextInitialize() < 0) + return NULL; + + if (!(ctxt = virObjectNew(virNetTLSContextClass))) return NULL; - } if (virMutexInit(&ctxt->lock) < 0) { virNetError(VIR_ERR_INTERNAL_ERROR, "%s", @@ -662,8 +687,6 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert, return NULL; } - ctxt->refs = 1; - if ((gnutlsdebug = getenv("LIBVIRT_GNUTLS_DEBUG")) != NULL) { int val; if (virStrToLong_i(gnutlsdebug, NULL, 10, &val) < 0) @@ -719,8 +742,8 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert, ctxt->isServer = isServer; PROBE(RPC_TLS_CONTEXT_NEW, - "ctxt=%p refs=%d cacert=%s cacrl=%s cert=%s key=%s sanityCheckCert=%d requireValidCert=%d isServer=%d", - ctxt, ctxt->refs, cacert, NULLSTR(cacrl), cert, key, sanityCheckCert, requireValidCert, isServer); + "ctxt=%p cacert=%s cacrl=%s cert=%s key=%s sanityCheckCert=%d requireValidCert=%d isServer=%d", + ctxt, cacert, NULLSTR(cacrl), cert, key, sanityCheckCert, requireValidCert, isServer); return ctxt; @@ -930,17 +953,6 @@ virNetTLSContextPtr virNetTLSContextNewClient(const char *cacert, } -void virNetTLSContextRef(virNetTLSContextPtr ctxt) -{ - virMutexLock(&ctxt->lock); - ctxt->refs++; - PROBE(RPC_TLS_CONTEXT_REF, - "ctxt=%p refs=%d", - ctxt, ctxt->refs); - virMutexUnlock(&ctxt->lock); -} - - static int virNetTLSContextValidCertificate(virNetTLSContextPtr ctxt, virNetTLSSessionPtr sess) { @@ -1109,30 +1121,16 @@ cleanup: return ret; } -void virNetTLSContextFree(virNetTLSContextPtr ctxt) +void virNetTLSContextDispose(void *obj) { - if (!ctxt) - return; - - virMutexLock(&ctxt->lock); - PROBE(RPC_TLS_CONTEXT_FREE, - "ctxt=%p refs=%d", - ctxt, ctxt->refs); - ctxt->refs--; - if (ctxt->refs > 0) { - virMutexUnlock(&ctxt->lock); - return; - } + virNetTLSContextPtr ctxt = obj; gnutls_dh_params_deinit(ctxt->dhParams); gnutls_certificate_free_credentials(ctxt->x509cred); - virMutexUnlock(&ctxt->lock); virMutexDestroy(&ctxt->lock); - VIR_FREE(ctxt); } - static ssize_t virNetTLSSessionPush(void *opaque, const void *buf, size_t len) { @@ -1170,10 +1168,8 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt, VIR_DEBUG("ctxt=%p hostname=%s isServer=%d", ctxt, NULLSTR(hostname), ctxt->isServer); - if (VIR_ALLOC(sess) < 0) { - virReportOOMError(); + if (!(sess = virObjectNew(virNetTLSSessionClass))) return NULL; - } if (virMutexInit(&sess->lock) < 0) { virNetError(VIR_ERR_INTERNAL_ERROR, "%s", @@ -1182,7 +1178,6 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt, return NULL; } - sess->refs = 1; if (hostname && !(sess->hostname = strdup(hostname))) { virReportOOMError(); @@ -1233,27 +1228,17 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt, sess->isServer = ctxt->isServer; PROBE(RPC_TLS_SESSION_NEW, - "sess=%p refs=%d ctxt=%p hostname=%s isServer=%d", - sess, sess->refs, ctxt, hostname, sess->isServer); + "sess=%p ctxt=%p hostname=%s isServer=%d", + sess, ctxt, hostname, sess->isServer); return sess; error: - virNetTLSSessionFree(sess); + virObjectUnref(sess); return NULL; } -void virNetTLSSessionRef(virNetTLSSessionPtr sess) -{ - virMutexLock(&sess->lock); - sess->refs++; - PROBE(RPC_TLS_SESSION_REF, - "sess=%p refs=%d", - sess, sess->refs); - virMutexUnlock(&sess->lock); -} - void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess, virNetTLSSessionWriteFunc writeFunc, virNetTLSSessionReadFunc readFunc, @@ -1396,26 +1381,13 @@ cleanup: } -void virNetTLSSessionFree(virNetTLSSessionPtr sess) +void virNetTLSSessionDispose(void *obj) { - if (!sess) - return; - - virMutexLock(&sess->lock); - PROBE(RPC_TLS_SESSION_FREE, - "sess=%p refs=%d", - sess, sess->refs); - sess->refs--; - if (sess->refs > 0) { - virMutexUnlock(&sess->lock); - return; - } + virNetTLSSessionPtr sess = obj; VIR_FREE(sess->hostname); gnutls_deinit(sess->session); - virMutexUnlock(&sess->lock); virMutexDestroy(&sess->lock); - VIR_FREE(sess); } /* diff --git a/src/rpc/virnettlscontext.h b/src/rpc/virnettlscontext.h index fdfce6d..4821016 100644 --- a/src/rpc/virnettlscontext.h +++ b/src/rpc/virnettlscontext.h @@ -22,6 +22,7 @@ # define __VIR_NET_TLS_CONTEXT_H__ # include "internal.h" +# include "virobject.h" typedef struct _virNetTLSContext virNetTLSContext; typedef virNetTLSContext *virNetTLSContextPtr; @@ -58,13 +59,9 @@ virNetTLSContextPtr virNetTLSContextNewClient(const char *cacert, bool sanityCheckCert, bool requireValidCert); -void virNetTLSContextRef(virNetTLSContextPtr ctxt); - int virNetTLSContextCheckCertificate(virNetTLSContextPtr ctxt, virNetTLSSessionPtr sess); -void virNetTLSContextFree(virNetTLSContextPtr ctxt); - typedef ssize_t (*virNetTLSSessionWriteFunc)(const char *buf, size_t len, void *opaque); @@ -79,8 +76,6 @@ void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess, virNetTLSSessionReadFunc readFunc, void *opaque); -void virNetTLSSessionRef(virNetTLSSessionPtr sess); - ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess, const char *buf, size_t len); ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess, @@ -99,7 +94,4 @@ virNetTLSSessionGetHandshakeStatus(virNetTLSSessionPtr sess); int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess); -void virNetTLSSessionFree(virNetTLSSessionPtr sess); - - #endif diff --git a/tests/virnettlscontexttest.c b/tests/virnettlscontexttest.c index e745487..32e1f77 100644 --- a/tests/virnettlscontexttest.c +++ b/tests/virnettlscontexttest.c @@ -496,7 +496,7 @@ static int testTLSContextInit(const void *opaque) ret = 0; cleanup: - virNetTLSContextFree(ctxt); + virObjectUnref(ctxt); gnutls_x509_crt_deinit(data->careq.crt); gnutls_x509_crt_deinit(data->certreq.crt); data->careq.crt = data->certreq.crt = NULL; @@ -710,10 +710,10 @@ static int testTLSSessionInit(const void *opaque) ret = 0; cleanup: - virNetTLSContextFree(serverCtxt); - virNetTLSContextFree(clientCtxt); - virNetTLSSessionFree(serverSess); - virNetTLSSessionFree(clientSess); + virObjectUnref(serverCtxt); + virObjectUnref(clientCtxt); + virObjectUnref(serverSess); + virObjectUnref(clientSess); gnutls_x509_crt_deinit(data->careq.crt); if (data->othercareq.filename) gnutls_x509_crt_deinit(data->othercareq.crt); -- 1.7.10.2 -- libvir-list mailing list libvir-list@xxxxxxxxxx https://www.redhat.com/mailman/listinfo/libvir-list