Some TLS-interacting functions can be called from two or more threads with the same session pointer. Therefore we need to protect virNetTLSSessionPtr with a mutex to avoid inconsistent states. --- src/rpc/virnettlscontext.c | 41 +++++++++++++++++++++++++++++++++++++++-- 1 files changed, 39 insertions(+), 2 deletions(-) diff --git a/src/rpc/virnettlscontext.c b/src/rpc/virnettlscontext.c index bde4e7a..a0f7a3f 100644 --- a/src/rpc/virnettlscontext.c +++ b/src/rpc/virnettlscontext.c @@ -35,6 +35,7 @@ #include "util.h" #include "logging.h" #include "configmake.h" +#include "threads.h" #define DH_BITS 1024 @@ -63,6 +64,7 @@ struct _virNetTLSContext { }; struct _virNetTLSSession { + virMutex lock; int refs; bool handshakeComplete; @@ -1083,6 +1085,16 @@ void virNetTLSContextFree(virNetTLSContextPtr ctxt) +static void virNetTLSSessionLock(virNetTLSSessionPtr session) +{ + virMutexLock(&session->lock); +} + +static void virNetTLSSessionUnlock(virNetTLSSessionPtr session) +{ + virMutexUnlock(&session->lock); +} + static ssize_t virNetTLSSessionPush(void *opaque, const void *buf, size_t len) { @@ -1124,6 +1136,9 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt, return NULL; } + if (virMutexInit(&sess->lock) < 0) + goto error; + sess->refs = 1; if (hostname && !(sess->hostname = strdup(hostname))) { @@ -1184,7 +1199,9 @@ error: void virNetTLSSessionRef(virNetTLSSessionPtr sess) { + virNetTLSSessionLock(sess); sess->refs++; + virNetTLSSessionUnlock(sess); } void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess, @@ -1192,9 +1209,11 @@ void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess, virNetTLSSessionReadFunc readFunc, void *opaque) { + virNetTLSSessionLock(sess); sess->writeFunc = writeFunc; sess->readFunc = readFunc; sess->opaque = opaque; + virNetTLSSessionUnlock(sess); } @@ -1202,7 +1221,10 @@ ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess, const char *buf, size_t len) { ssize_t ret; + + virNetTLSSessionLock(sess); ret = 
gnutls_record_send(sess->session, buf, len); + virNetTLSSessionUnlock(sess); if (ret >= 0) return ret; @@ -1230,7 +1252,9 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess, { ssize_t ret; + virNetTLSSessionLock(sess); ret = gnutls_record_recv(sess->session, buf, len); + virNetTLSSessionUnlock(sess); if (ret >= 0) return ret; @@ -1253,15 +1277,19 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess, int virNetTLSSessionHandshake(virNetTLSSessionPtr sess) { VIR_DEBUG("sess=%p", sess); + virNetTLSSessionLock(sess); int ret = gnutls_handshake(sess->session); VIR_DEBUG("Ret=%d", ret); if (ret == 0) { sess->handshakeComplete = true; VIR_DEBUG("Handshake is complete"); + virNetTLSSessionUnlock(sess); return 0; } - if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) + if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { + virNetTLSSessionUnlock(sess); return 1; + } #if 0 PROBE(CLIENT_TLS_FAIL, "fd=%d", @@ -1271,6 +1299,7 @@ int virNetTLSSessionHandshake(virNetTLSSessionPtr sess) virNetError(VIR_ERR_AUTH_FAILED, _("TLS handshake failed %s"), gnutls_strerror(ret)); + virNetTLSSessionUnlock(sess); return -1; } @@ -1290,12 +1319,15 @@ int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess) gnutls_cipher_algorithm_t cipher; int ssf; + virNetTLSSessionLock(sess); cipher = gnutls_cipher_get(sess->session); if (!(ssf = gnutls_cipher_get_key_size(cipher))) { virNetError(VIR_ERR_INTERNAL_ERROR, "%s", _("invalid cipher size for TLS session")); + virNetTLSSessionUnlock(sess); return -1; } + virNetTLSSessionUnlock(sess); return ssf; } @@ -1306,11 +1338,16 @@ void virNetTLSSessionFree(virNetTLSSessionPtr sess) if (!sess) return; + virNetTLSSessionLock(sess); sess->refs--; - if (sess->refs > 0) + if (sess->refs > 0) { + virNetTLSSessionUnlock(sess); return; + } VIR_FREE(sess->hostname); gnutls_deinit(sess->session); + virNetTLSSessionUnlock(sess); + virMutexDestroy(&sess->lock); VIR_FREE(sess); } -- 1.7.5.rc3 -- libvir-list mailing list libvir-list@xxxxxxxxxx 
https://www.redhat.com/mailman/listinfo/libvir-list