This extends the basic virNetSocket APIs to allow them to have
a handle to the TLS/SASL session objects, once established.
This ensures that any data reads/writes are automagically
passed through the TLS/SASL encryption layers if required.

* src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Wire up
  SASL/TLS encryption
---
 src/rpc/virnetsocket.c |  288 +++++++++++++++++++++++++++++++++++++++++++++++-
 src/rpc/virnetsocket.h |   11 ++
 2 files changed, 296 insertions(+), 3 deletions(-)

diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index 6855397..daa40f4 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -59,6 +59,19 @@ struct _virNetSocket {
     virSocketAddr remoteAddr;
     char *localAddrStr;
     char *remoteAddrStr;
+
+    virNetTLSSessionPtr tlsSession;
+#ifdef HAVE_SASL
+    virNetSASLSessionPtr saslSession;
+
+    const char *saslDecoded;
+    size_t saslDecodedLength;
+    size_t saslDecodedOffset;
+
+    const char *saslEncoded;
+    size_t saslEncodedLength;
+    size_t saslEncodedOffset;
+#endif
 };
 
 
@@ -416,7 +429,7 @@ error:
 }
 
 
-#if HAVE_SYS_UN_H
+#ifdef HAVE_SYS_UN_H
 int virNetSocketNewConnectUNIX(const char *path,
                                bool spawnDaemon,
                                const char *binary,
@@ -632,6 +645,14 @@ void virNetSocketFree(virNetSocketPtr sock)
         unlink(sock->localAddr.data.un.sun_path);
 #endif
 
+    /* Make sure it can't send any more I/O during shutdown */
+    if (sock->tlsSession)
+        virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL);
+    virNetTLSSessionFree(sock->tlsSession);
+#ifdef HAVE_SASL
+    virNetSASLSessionFree(sock->saslSession);
+#endif
+
     VIR_FORCE_CLOSE(sock->fd);
     VIR_FORCE_CLOSE(sock->errfd);
 
@@ -717,14 +738,275 @@ const char *virNetSocketRemoteAddrString(virNetSocketPtr sock)
     return sock->remoteAddrStr;
 }
 
-ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
+
+static ssize_t virNetSocketTLSSessionWrite(const char *buf,
+                                           size_t len,
+                                           void *opaque)
 {
+    virNetSocketPtr sock = opaque;
+    return write(sock->fd, buf, len);
+}
+
+
+static ssize_t virNetSocketTLSSessionRead(char *buf,
+                                          size_t len,
+                                          void *opaque)
+{
+    virNetSocketPtr sock = opaque;
     return read(sock->fd, buf, len);
 }
 
+
+void virNetSocketSetTLSSession(virNetSocketPtr sock,
+                               virNetTLSSessionPtr sess)
+{
+    /* Take the new reference before dropping the old one, so that
+     * re-setting the session already attached cannot free it */
+    if (sess)
+        virNetTLSSessionRef(sess);
+
+    if (sock->tlsSession) {
+        /* Stop the old session doing I/O on our fd once released */
+        virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL);
+        virNetTLSSessionFree(sock->tlsSession);
+    }
+
+    sock->tlsSession = sess;
+    if (sess)
+        virNetTLSSessionSetIOCallbacks(sess,
+                                       virNetSocketTLSSessionWrite,
+                                       virNetSocketTLSSessionRead,
+                                       sock);
+}
+
+
+#ifdef HAVE_SASL
+void virNetSocketSetSASLSession(virNetSocketPtr sock,
+                                virNetSASLSessionPtr sess)
+{
+    /* Reference the new session before releasing the old one */
+    if (sess)
+        virNetSASLSessionRef(sess);
+    virNetSASLSessionFree(sock->saslSession);
+    sock->saslSession = sess;
+}
+#endif
+
+
+bool virNetSocketHasCachedData(virNetSocketPtr sock ATTRIBUTE_UNUSED)
+{
+#ifdef HAVE_SASL
+    if (sock->saslDecoded)
+        return true;
+#endif
+    return false;
+}
+
+
+static ssize_t virNetSocketReadWire(virNetSocketPtr sock, char *buf, size_t len)
+{
+    char *errout = NULL;
+    ssize_t ret;
+    int savedErrno;
+reread:
+    if (sock->tlsSession &&
+        virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
+        VIR_NET_TLS_HANDSHAKE_COMPLETE) {
+        ret = virNetTLSSessionRead(sock->tlsSession, buf, len);
+    } else {
+        ret = read(sock->fd, buf, len);
+    }
+
+    if ((ret < 0) && (errno == EINTR))
+        goto reread;
+    if ((ret < 0) && (errno == EAGAIN))
+        return 0;
+
+    /* Reading errfd below may clobber errno, so keep the real one */
+    savedErrno = errno;
+
+    if (ret <= 0 &&
+        sock->errfd != -1 &&
+        virFileReadLimFD(sock->errfd, 1024, &errout) >= 0 &&
+        errout != NULL) {
+        size_t elen = strlen(errout);
+        if (elen && errout[elen-1] == '\n')
+            errout[elen-1] = '\0';
+    }
+
+    if (ret < 0) {
+        if (errout)
+            virReportSystemError(savedErrno,
+                                 _("Cannot recv data: %s"), errout);
+        else
+            virReportSystemError(savedErrno, "%s",
+                                 _("Cannot recv data"));
+        ret = -1;
+    } else if (ret == 0) {
+        if (errout)
+            virReportSystemError(EIO,
+                                 _("End of file while reading data: %s"), errout);
+        else
+            virReportSystemError(EIO, "%s",
+                                 _("End of file while reading data"));
+        ret = -1;
+    }
+
+    VIR_FREE(errout);
+    return ret;
+}
+
+static ssize_t virNetSocketWriteWire(virNetSocketPtr sock, const char *buf, size_t len)
+{
+    ssize_t ret;
+rewrite:
+    if (sock->tlsSession &&
+        virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
+        VIR_NET_TLS_HANDSHAKE_COMPLETE) {
+        ret = virNetTLSSessionWrite(sock->tlsSession, buf, len);
+    } else {
+        ret = write(sock->fd, buf, len);
+    }
+
+    if (ret < 0) {
+        if (errno == EINTR)
+            goto rewrite;
+        if (errno == EAGAIN)
+            return 0;
+
+        virReportSystemError(errno, "%s",
+                             _("Cannot write data"));
+        return -1;
+    }
+    if (ret == 0) {
+        virReportSystemError(EIO, "%s",
+                             _("End of file while writing data"));
+        return -1;
+    }
+
+    return ret;
+}
+
+
+#ifdef HAVE_SASL
+static ssize_t virNetSocketReadSASL(virNetSocketPtr sock, char *buf, size_t len)
+{
+    size_t got;
+
+    /* Need to read some more data off the wire */
+    if (sock->saslDecoded == NULL) {
+        ssize_t encodedLen = virNetSASLSessionGetMaxBufSize(sock->saslSession);
+        char *encoded;
+        if (VIR_ALLOC_N(encoded, encodedLen) < 0) {
+            virReportOOMError();
+            return -1;
+        }
+        encodedLen = virNetSocketReadWire(sock, encoded, encodedLen);
+
+        if (encodedLen <= 0) {
+            VIR_FREE(encoded);
+            return encodedLen;
+        }
+
+        if (virNetSASLSessionDecode(sock->saslSession,
+                                    encoded, encodedLen,
+                                    &sock->saslDecoded, &sock->saslDecodedLength) < 0) {
+            VIR_FREE(encoded);
+            return -1;
+        }
+        VIR_FREE(encoded);
+
+        sock->saslDecodedOffset = 0;
+    }
+
+    /* Some buffered decoded data to return now */
+    got = sock->saslDecodedLength - sock->saslDecodedOffset;
+
+    if (len > got)
+        len = got;
+
+    memcpy(buf, sock->saslDecoded + sock->saslDecodedOffset, len);
+    sock->saslDecodedOffset += len;
+
+    if (sock->saslDecodedOffset == sock->saslDecodedLength) {
+        sock->saslDecoded = NULL;
+        sock->saslDecodedOffset = sock->saslDecodedLength = 0;
+    }
+
+    return len;
+}
+
+
+static ssize_t virNetSocketWriteSASL(virNetSocketPtr sock,
+                                     const char *buf, size_t len)
+{
+    ssize_t ret;
+    size_t tosend = virNetSASLSessionGetMaxBufSize(sock->saslSession);
+
+    /* SASL doesn't necessarily let us send the whole
+       buffer at once */
+    if (tosend > len)
+        tosend = len;
+
+    /* Not got any pending encoded data, so we need to encode raw stuff */
+    if (sock->saslEncoded == NULL) {
+        if (virNetSASLSessionEncode(sock->saslSession,
+                                    buf, tosend,
+                                    &sock->saslEncoded,
+                                    &sock->saslEncodedLength) < 0)
+            return -1;
+
+        sock->saslEncodedOffset = 0;
+    }
+
+    /* Send some of the encoded stuff out on the wire */
+    ret = virNetSocketWriteWire(sock,
+                                sock->saslEncoded + sock->saslEncodedOffset,
+                                sock->saslEncodedLength - sock->saslEncodedOffset);
+
+    if (ret <= 0)
+        return ret; /* -1 error, 0 == EAGAIN */
+
+    /* Note how much we sent */
+    sock->saslEncodedOffset += ret;
+
+    /* Sent all encoded, so update raw buffer to indicate completion */
+    if (sock->saslEncodedOffset == sock->saslEncodedLength) {
+        sock->saslEncoded = NULL;
+        sock->saslEncodedOffset = sock->saslEncodedLength = 0;
+
+        /* Mark as complete, so caller detects completion */
+        return tosend;
+    } else {
+        /* Still have stuff pending in saslEncoded buffer.
+         * Pretend to caller that we didn't send any yet.
+         * The caller will then retry with same buffer
+         * shortly, which lets us finish saslEncoded.
+         */
+        return 0;
+    }
+}
+#endif
+
+
+ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
+{
+#ifdef HAVE_SASL
+    if (sock->saslSession)
+        return virNetSocketReadSASL(sock, buf, len);
+    else
+#endif
+        return virNetSocketReadWire(sock, buf, len);
+}
+
 ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
 {
-    return write(sock->fd, buf, len);
+#ifdef HAVE_SASL
+    if (sock->saslSession)
+        return virNetSocketWriteSASL(sock, buf, len);
+    else
+#endif
+        return virNetSocketWriteWire(sock, buf, len);
 }
 
 
diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h
index 218fe8f..59ff288 100644
--- a/src/rpc/virnetsocket.h
+++ b/src/rpc/virnetsocket.h
@@ -26,6 +26,10 @@
 
 # include "network.h"
 # include "command.h"
+# include "virnettlscontext.h"
+# ifdef HAVE_SASL
+#  include "virnetsaslcontext.h"
+# endif
 
 typedef struct _virNetSocket virNetSocket;
 typedef virNetSocket *virNetSocketPtr;
@@ -83,6 +87,13 @@ int virNetSocketSetBlocking(virNetSocketPtr sock,
 
 ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len);
 ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len);
+void virNetSocketSetTLSSession(virNetSocketPtr sock,
+                               virNetTLSSessionPtr sess);
+# ifdef HAVE_SASL
+void virNetSocketSetSASLSession(virNetSocketPtr sock,
+                                virNetSASLSessionPtr sess);
+# endif
+bool virNetSocketHasCachedData(virNetSocketPtr sock);
 void virNetSocketFree(virNetSocketPtr sock);
 
 const char *virNetSocketLocalAddrString(virNetSocketPtr sock);
-- 
1.7.4

--
libvir-list mailing list
libvir-list@xxxxxxxxxx
https://www.redhat.com/mailman/listinfo/libvir-list