Introduces a set of generic objects which are to be used in building RPC servers/clients based on XDR. - virNetMessageHeader - standardizes the XDR format for any RPC program. Copied from the remote protocol for backwards compatibility - virNetMessage - Provides a buffer for (de-)serializing messages, and a copy of the decoded virNetMessageHeader. Provides APIs for encoding/decoding message headers and payloads, thus isolating all the XDR API calls in one file. Callers no longer need to use XDR themselves. - virNetSocket - a wrapper around a socket file descriptor, to simplify creation of new sockets, both for clients and services. Encapsulates all the hairy getaddrinfo code and sockaddr manipulation. Will eventually include transparent support for TLS and SASL encoding of data - virNetTLSContext - encapsulates the credentials required to set up TLS sessions, e.g. the set of x509 certificates and keys, optional DH parameters and x509 DName whitelist. Provides APIs for easily validating certificates from a TLS session - virNetTLSSession - encapsulates the TLS session handling, so that callers no longer have a direct dependency on gnutls. This will facilitate adding alternate TLS implementations. Makes the read/write TLS functions work with the same semantics as the native socket read/write functions, i.e. they set errno, instead of a gnutls-specific error code. This code is taken from either the daemon/libvirtd.c, daemon/dispatch.c or src/remote/remote_driver.c files, which all duplicated a lot of functionality. 
--- src/Makefile.am | 42 +++- src/rpc/virnetmessage.c | 215 +++++++++++++ src/rpc/virnetmessage.h | 31 ++ src/rpc/virnetprotocol.c | 108 +++++++ src/rpc/virnetprotocol.h | 81 +++++ src/rpc/virnetprotocol.x | 162 ++++++++++ src/rpc/virnetsocket.c | 715 ++++++++++++++++++++++++++++++++++++++++++++ src/rpc/virnetsocket.h | 97 ++++++ src/rpc/virnettlscontext.c | 611 +++++++++++++++++++++++++++++++++++++ src/rpc/virnettlscontext.h | 63 ++++ 10 files changed, 2124 insertions(+), 1 deletions(-) create mode 100644 src/rpc/virnetmessage.c create mode 100644 src/rpc/virnetmessage.h create mode 100644 src/rpc/virnetprotocol.c create mode 100644 src/rpc/virnetprotocol.h create mode 100644 src/rpc/virnetprotocol.x create mode 100644 src/rpc/virnetsocket.c create mode 100644 src/rpc/virnetsocket.h create mode 100644 src/rpc/virnettlscontext.c create mode 100644 src/rpc/virnettlscontext.h diff --git a/src/Makefile.am b/src/Makefile.am index d71c644..613ff0a 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -530,6 +530,24 @@ else mv -f rp_qemu.c-t $(srcdir)/remote/qemu_protocol.c endif +rpcgen-net: + rm -f rp_net.c-t rp_net.h-t rp_net.c-t1 rp_net.c-t2 rp_net.h-t1 + $(RPCGEN) -h -o rp_net.h-t $(srcdir)/rpc/virnetprotocol.x + $(RPCGEN) -c -o rp_net.c-t $(srcdir)/rpc/virnetprotocol.x +if HAVE_GLIBC_RPCGEN + perl -w $(srcdir)/remote/rpcgen_fix.pl rp_net.h-t > rp_net.h-t1 + perl -w $(srcdir)/remote/rpcgen_fix.pl rp_net.c-t > rp_net.c-t1 + (echo '#include <config.h>'; cat rp_net.c-t1) > rp_net.c-t2 + chmod 0444 rp_net.c-t2 rp_net.h-t1 + mv -f rp_net.h-t1 $(srcdir)/rpc/virnetprotocol.h + mv -f rp_net.c-t2 $(srcdir)/rpc/virnetprotocol.c + rm -f rp_net.c-t rp_net.h-t rp_net.c-t1 +else + chmod 0444 rp_net.c-t rp_net.h-t + mv -f rp_net.h-t $(srcdir)/rpc/virnetprotocol.h + mv -f rp_net.c-t $(srcdir)/rpc/virnetprotocol.c +endif + # # Maintainer-only target for re-generating the derived .c/.h source # files, which are actually derived from the .x file. 
@@ -540,7 +558,7 @@ endif # Support for non-GLIB rpcgen is here as a convenience for # non-Linux people needing to test changes during dev. # -rpcgen: rpcgen-normal rpcgen-qemu +rpcgen: rpcgen-normal rpcgen-qemu rpcgen-net endif @@ -1098,6 +1116,28 @@ libvirt_qemu_la_CFLAGS = $(AM_CFLAGS) libvirt_qemu_la_LIBADD = libvirt.la $(CYGWIN_EXTRA_LIBADD) EXTRA_DIST += $(LIBVIRT_QEMU_SYMBOL_FILE) + +noinst_LTLIBRARIES += libvirt-net-rpc.la + +libvirt_net_rpc_la_SOURCES = \ + ../daemon/event.c \ + rpc/virnetprotocol.h rpc/virnetprotocol.c \ + rpc/virnetmessage.h rpc/virnetmessage.c \ + rpc/virnettlscontext.h rpc/virnettlscontext.c \ + rpc/virnetsocket.h rpc/virnetsocket.c +libvirt_net_rpc_la_CFLAGS = \ + $(GNUTLS_CFLAGS) \ + $(SASL_CFLAGS) \ + $(AM_CFLAGS) +libvirt_net_rpc_la_LDFLAGS = \ + $(GNUTLS_LIBS) \ + $(SASL_LIBS) \ + $(AM_LDFLAGS) \ + $(CYGWIN_EXTRA_LDFLAGS) \ + $(MINGW_EXTRA_LDFLAGS) +libvirt_net_rpc_la_LIBADD = \ + $(CYGWIN_EXTRA_LIBADD) + libexec_PROGRAMS = if WITH_STORAGE_DISK diff --git a/src/rpc/virnetmessage.c b/src/rpc/virnetmessage.c new file mode 100644 index 0000000..9bd7557 --- /dev/null +++ b/src/rpc/virnetmessage.c @@ -0,0 +1,215 @@ +#include <config.h> + +#include "virnetmessage.h" + +#include "virterror_internal.h" +#include "logging.h" + +#define virNetError(code, ...) \ + virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + + +int virNetMessageDecodeLength(virNetMessagePtr msg) +{ + XDR xdr; + unsigned int len; + int ret = -1; + + xdrmem_create(&xdr, msg->buffer, + msg->bufferLength, XDR_DECODE); + if (!xdr_u_int(&xdr, &len)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to decode message length")); + goto cleanup; + } + msg->bufferOffset = xdr_getpos(&xdr); + + if (len < VIR_NET_MESSAGE_LEN_MAX) { + virNetError(VIR_ERR_RPC, "%s", + _("packet received from server too small")); + goto cleanup; + } + + /* Length includes length word - adjust to real length to read. 
*/ + len -= VIR_NET_MESSAGE_LEN_MAX; + + if (len > VIR_NET_MESSAGE_MAX) { + virNetError(VIR_ERR_RPC, "%s", + _("packet received from server too large")); + goto cleanup; + } + + /* Extend our declared buffer length and carry + on reading the header + payload */ + msg->bufferLength += len; + + VIR_DEBUG("Got length, now need %d total (%d more)", msg->bufferLength, len); + + ret = 0; + +cleanup: + xdr_destroy(&xdr); + return ret; +} + + +/* + * @msg: the complete incoming message, whose header to decode + * + * Decodes the header part of the message, but does not + * validate the decoded fields in the header. It expects + * bufferLength to refer to length of the data packet. Upon + * return bufferOffset will refer to the amount of the packet + * consumed by decoding of the header. + * + * returns 0 if successfully decoded, -1 upon fatal error + */ +int virNetMessageDecodeHeader(virNetMessagePtr msg) +{ + XDR xdr; + int ret = -1; + + msg->bufferOffset = VIR_NET_MESSAGE_LEN_MAX; + + /* Parse the header. */ + xdrmem_create(&xdr, + msg->buffer + msg->bufferOffset, + msg->bufferLength - msg->bufferOffset, + XDR_DECODE); + + if (!xdr_virNetMessageHeader(&xdr, &msg->header)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to decode message header")); + goto cleanup; + } + + msg->bufferOffset += xdr_getpos(&xdr); + + ret = 0; + +cleanup: + xdr_destroy(&xdr); + return ret; +} + + +/* + * @msg: the outgoing message, whose header to encode + * + * Encodes the length word and header of the message, setting the + * message offset ready to encode the payload. Leaves space + * for the length field later. 
Upon return bufferLength will + * refer to the total available space for message, while + * bufferOffset will refer to current space used by header + * + * returns 0 if successfully encoded, -1 upon fatal error + */ +int virNetMessageEncodeHeader(virNetMessagePtr msg) +{ + XDR xdr; + int ret = -1; + unsigned int len = 0; + + msg->bufferLength = sizeof(msg->buffer); + msg->bufferOffset = 0; + + /* Format the header. */ + xdrmem_create(&xdr, + msg->buffer, + msg->bufferLength, + XDR_ENCODE); + + /* The real value is filled in shortly */ + if (!xdr_u_int(&xdr, &len)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to encode message length")); + goto cleanup; + } + + if (!xdr_virNetMessageHeader(&xdr, &msg->header)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to encode message header")); + goto cleanup; + } + + len = xdr_getpos(&xdr); + xdr_setpos(&xdr, 0); + + /* Fill in current length - may be re-written later + * if a payload is added + */ + if (!xdr_u_int(&xdr, &len)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to re-encode message length")); + goto cleanup; + } + + msg->bufferOffset += len; + + ret = 0; + +cleanup: + xdr_destroy(&xdr); + return ret; +} + + +int virNetMessageEncodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data) +{ + XDR xdr; + + /* Serialise header followed by args. */ + xdrmem_create(&xdr, msg->buffer + msg->bufferOffset, + msg->bufferLength - msg->bufferOffset, XDR_ENCODE); + + if (!(*filter)(&xdr, data)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to encode message payload")); + goto error; + } + + /* Get the length stored in buffer. */ + msg->bufferOffset += xdr_getpos(&xdr); + xdr_destroy(&xdr); + + /* Re-encode the length word. 
*/ + VIR_DEBUG("Encode length as %d", msg->bufferOffset); + xdrmem_create(&xdr, msg->buffer, VIR_NET_MESSAGE_HEADER_XDR_LEN, XDR_ENCODE); + if (!xdr_u_int(&xdr, &msg->bufferOffset)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to encode message length")); + goto error; + } + xdr_destroy(&xdr); + + msg->bufferLength = msg->bufferOffset; + msg->bufferOffset = 0; + return 0; + +error: + xdr_destroy(&xdr); + return -1; +} + + +int virNetMessageDecodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data) +{ + XDR xdr; + + /* Deserialise the payload that follows the already-decoded header. */ + xdrmem_create(&xdr, msg->buffer + msg->bufferOffset, + msg->bufferLength - msg->bufferOffset, XDR_DECODE); + + if (!(*filter)(&xdr, data)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to decode message payload")); + goto error; + } + + /* NOTE(review): this grows bufferLength, whereas virNetMessageDecodeHeader advances bufferOffset - confirm intent. */ + msg->bufferLength += xdr_getpos(&xdr); + xdr_destroy(&xdr); + return 0; + +error: + xdr_destroy(&xdr); + return -1; +} + diff --git a/src/rpc/virnetmessage.h b/src/rpc/virnetmessage.h new file mode 100644 index 0000000..73d417b --- /dev/null +++ b/src/rpc/virnetmessage.h @@ -0,0 +1,31 @@ +#ifndef __VIR_NET_MESSAGE_H__ +#define __VIR_NET_MESSAGE_H__ + +#include "virnetprotocol.h" + +typedef struct virNetMessageHeader *virNetMessageHeaderPtr; + +typedef struct _virNetMessage virNetMessage; +typedef virNetMessage *virNetMessagePtr; + +struct _virNetMessage { + char buffer[VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX]; + unsigned int bufferLength; + unsigned int bufferOffset; + + virNetMessageHeader header; +}; + +int virNetMessageEncodeHeader(virNetMessagePtr msg); +int virNetMessageDecodeLength(virNetMessagePtr msg); +int virNetMessageDecodeHeader(virNetMessagePtr msg); + +int virNetMessageEncodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data); +int virNetMessageDecodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data); + +#endif /* __VIR_NET_MESSAGE_H__ */ + diff --git a/src/rpc/virnetprotocol.c 
b/src/rpc/virnetprotocol.c new file mode 100644 index 0000000..0a803ae --- /dev/null +++ b/src/rpc/virnetprotocol.c @@ -0,0 +1,108 @@ +#include <config.h> +/* + * Please do not edit this file. + * It was generated using rpcgen. + */ + +#include "./rpc/virnetprotocol.h" +#include "internal.h" +#ifdef HAVE_XDR_U_INT64_T +# define xdr_uint64_t xdr_u_int64_t +#endif +#ifndef IXDR_PUT_INT32 +# define IXDR_PUT_INT32 IXDR_PUT_LONG +#endif +#ifndef IXDR_GET_INT32 +# define IXDR_GET_INT32 IXDR_GET_LONG +#endif +#ifndef IXDR_PUT_U_INT32 +# define IXDR_PUT_U_INT32 IXDR_PUT_U_LONG +#endif +#ifndef IXDR_GET_U_INT32 +# define IXDR_GET_U_INT32 IXDR_GET_U_LONG +#endif + +bool_t +xdr_virNetMessageType (XDR *xdrs, virNetMessageType *objp) +{ + + if (!xdr_enum (xdrs, (enum_t *) objp)) + return FALSE; + return TRUE; +} + +bool_t +xdr_virNetMessageStatus (XDR *xdrs, virNetMessageStatus *objp) +{ + + if (!xdr_enum (xdrs, (enum_t *) objp)) + return FALSE; + return TRUE; +} + +bool_t +xdr_virNetMessageHeader (XDR *xdrs, virNetMessageHeader *objp) +{ + register int32_t *buf; + + + if (xdrs->x_op == XDR_ENCODE) { + buf = (int32_t*)XDR_INLINE (xdrs, 3 * BYTES_PER_XDR_UNIT); + if (buf == NULL) { + if (!xdr_u_int (xdrs, &objp->prog)) + return FALSE; + if (!xdr_u_int (xdrs, &objp->vers)) + return FALSE; + if (!xdr_int (xdrs, &objp->proc)) + return FALSE; + + } else { + (void)IXDR_PUT_U_INT32(buf, objp->prog); + (void)IXDR_PUT_U_INT32(buf, objp->vers); + (void)IXDR_PUT_INT32(buf, objp->proc); + } + if (!xdr_virNetMessageType (xdrs, &objp->type)) + return FALSE; + if (!xdr_u_int (xdrs, &objp->serial)) + return FALSE; + if (!xdr_virNetMessageStatus (xdrs, &objp->status)) + return FALSE; + return TRUE; + } else if (xdrs->x_op == XDR_DECODE) { + buf = (int32_t*)XDR_INLINE (xdrs, 3 * BYTES_PER_XDR_UNIT); + if (buf == NULL) { + if (!xdr_u_int (xdrs, &objp->prog)) + return FALSE; + if (!xdr_u_int (xdrs, &objp->vers)) + return FALSE; + if (!xdr_int (xdrs, &objp->proc)) + return FALSE; + + } else { + 
objp->prog = IXDR_GET_U_LONG(buf); + objp->vers = IXDR_GET_U_LONG(buf); + objp->proc = IXDR_GET_INT32(buf); + } + if (!xdr_virNetMessageType (xdrs, &objp->type)) + return FALSE; + if (!xdr_u_int (xdrs, &objp->serial)) + return FALSE; + if (!xdr_virNetMessageStatus (xdrs, &objp->status)) + return FALSE; + return TRUE; + } + + if (!xdr_u_int (xdrs, &objp->prog)) + return FALSE; + if (!xdr_u_int (xdrs, &objp->vers)) + return FALSE; + if (!xdr_int (xdrs, &objp->proc)) + return FALSE; + if (!xdr_virNetMessageType (xdrs, &objp->type)) + return FALSE; + if (!xdr_u_int (xdrs, &objp->serial)) + return FALSE; + if (!xdr_virNetMessageStatus (xdrs, &objp->status)) + return FALSE; + return TRUE; +} diff --git a/src/rpc/virnetprotocol.h b/src/rpc/virnetprotocol.h new file mode 100644 index 0000000..4ef8e20 --- /dev/null +++ b/src/rpc/virnetprotocol.h @@ -0,0 +1,81 @@ +/* + * Please do not edit this file. + * It was generated using rpcgen. + */ + +#ifndef _RP_NET_H_RPCGEN +#define _RP_NET_H_RPCGEN + +#include <rpc/rpc.h> + + +#ifdef __cplusplus +extern "C" { +#endif + +#include "internal.h" +#ifdef HAVE_XDR_U_INT64_T +# define xdr_uint64_t xdr_u_int64_t +#endif +#ifndef IXDR_PUT_INT32 +# define IXDR_PUT_INT32 IXDR_PUT_LONG +#endif +#ifndef IXDR_GET_INT32 +# define IXDR_GET_INT32 IXDR_GET_LONG +#endif +#ifndef IXDR_PUT_U_INT32 +# define IXDR_PUT_U_INT32 IXDR_PUT_U_LONG +#endif +#ifndef IXDR_GET_U_INT32 +# define IXDR_GET_U_INT32 IXDR_GET_U_LONG +#endif +#define VIR_NET_MESSAGE_MAX 262144 +#define VIR_NET_MESSAGE_HEADER_MAX 24 +#define VIR_NET_MESSAGE_PAYLOAD_MAX 262120 +#define VIR_NET_MESSAGE_LEN_MAX 4 + +enum virNetMessageType { + VIR_NET_CALL = 0, + VIR_NET_REPLY = 1, + VIR_NET_MESSAGE = 2, + VIR_NET_STREAM = 3, +}; +typedef enum virNetMessageType virNetMessageType; + +enum virNetMessageStatus { + VIR_NET_OK = 0, + VIR_NET_ERROR = 1, + VIR_NET_CONTINUE = 2, +}; +typedef enum virNetMessageStatus virNetMessageStatus; +#define VIR_NET_MESSAGE_HEADER_XDR_LEN 4 + +struct 
virNetMessageHeader { + u_int prog; + u_int vers; + int proc; + virNetMessageType type; + u_int serial; + virNetMessageStatus status; +}; +typedef struct virNetMessageHeader virNetMessageHeader; + +/* the xdr functions */ + +#if defined(__STDC__) || defined(__cplusplus) +extern bool_t xdr_virNetMessageType (XDR *, virNetMessageType*); +extern bool_t xdr_virNetMessageStatus (XDR *, virNetMessageStatus*); +extern bool_t xdr_virNetMessageHeader (XDR *, virNetMessageHeader*); + +#else /* K&R C */ +extern bool_t xdr_virNetMessageType (); +extern bool_t xdr_virNetMessageStatus (); +extern bool_t xdr_virNetMessageHeader (); + +#endif /* K&R C */ + +#ifdef __cplusplus +} +#endif + +#endif /* !_RP_NET_H_RPCGEN */ diff --git a/src/rpc/virnetprotocol.x b/src/rpc/virnetprotocol.x new file mode 100644 index 0000000..1d74153 --- /dev/null +++ b/src/rpc/virnetprotocol.x @@ -0,0 +1,162 @@ +/* -*- c -*- + * virnetprotocol.x: basic protocol for all RPC services. + * + * Copyright (C) 2006-2010 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. 
+ * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Richard Jones <rjones@xxxxxxxxxx> + */ + +%#include "internal.h" + +/* cygwin's xdr implementation defines xdr_u_int64_t instead of xdr_uint64_t + * and lacks IXDR_PUT_INT32 and IXDR_GET_INT32 + */ +%#ifdef HAVE_XDR_U_INT64_T +%# define xdr_uint64_t xdr_u_int64_t +%#endif +%#ifndef IXDR_PUT_INT32 +%# define IXDR_PUT_INT32 IXDR_PUT_LONG +%#endif +%#ifndef IXDR_GET_INT32 +%# define IXDR_GET_INT32 IXDR_GET_LONG +%#endif +%#ifndef IXDR_PUT_U_INT32 +%# define IXDR_PUT_U_INT32 IXDR_PUT_U_LONG +%#endif +%#ifndef IXDR_GET_U_INT32 +%# define IXDR_GET_U_INT32 IXDR_GET_U_LONG +%#endif + +/*----- Data types. -----*/ + +/* Maximum total message size (serialised). */ +const VIR_NET_MESSAGE_MAX = 262144; + +/* Size of struct virNetMessageHeader (serialized)*/ +const VIR_NET_MESSAGE_HEADER_MAX = 24; + +/* Size of message payload */ +const VIR_NET_MESSAGE_PAYLOAD_MAX = 262120; + +/* Size of message length field. Not counted in VIR_NET_MESSAGE_MAX */ +const VIR_NET_MESSAGE_LEN_MAX = 4; + +/* + * RPC wire format + * + * Each message consists of: + * + * Name | Type | Description + * -----------+-----------------------+------------------ + * Length | int | Total number of bytes in message _including_ length. 
+ * Header | virNetMessageHeader | Control information about procedure call + * Payload | - | Variable payload data per procedure + * + * In header, the 'serial' field varies according to: + * + * - type == VIR_NET_CALL + * * serial is set by client, incrementing by 1 each time + * + * - type == VIR_NET_REPLY + * * serial matches that from the corresponding VIR_NET_CALL + * + * - type == VIR_NET_MESSAGE + * * serial is always zero + * + * - type == VIR_NET_STREAM + * * serial matches that from the corresponding VIR_NET_CALL + * + * and the 'status' field varies according to: + * + * - type == VIR_NET_CALL + * * VIR_NET_OK always + * + * - type == VIR_NET_REPLY + * * VIR_NET_OK if RPC finished successfully + * * VIR_NET_ERROR if something failed + * + * - type == VIR_NET_MESSAGE + * * VIR_NET_OK always + * + * - type == VIR_NET_STREAM + * * VIR_NET_CONTINUE if more data is following + * * VIR_NET_OK if stream is complete + * * VIR_NET_ERROR if stream had an error + * + * Payload varies according to type and status: + * + * - type == VIR_NET_CALL + * XXX_args for procedure + * + * - type == VIR_NET_REPLY + * * status == VIR_NET_OK + * XXX_ret for procedure + * * status == VIR_NET_ERROR + * remote_error Error information + * + * - type == VIR_NET_MESSAGE + * * status == VIR_NET_OK + * XXX_args for procedure + * * status == VIR_NET_ERROR + * remote_error Error information + * + * - type == VIR_NET_STREAM + * * status == VIR_NET_CONTINUE + * byte[] raw stream data + * * status == VIR_NET_ERROR + * remote_error error information + * * status == VIR_NET_OK + * <empty> + */ +enum virNetMessageType { + /* client -> server. args from a method call */ + VIR_NET_CALL = 0, + /* server -> client. reply/error from a method call */ + VIR_NET_REPLY = 1, + /* either direction. async notification */ + VIR_NET_MESSAGE = 2, + /* either direction. stream data packet */ + VIR_NET_STREAM = 3 +}; + +enum virNetMessageStatus { + /* Status is always VIR_NET_OK for calls. 
+ * For replies, indicates no error. + */ + VIR_NET_OK = 0, + + /* For replies, indicates that an error happened, and a struct + * remote_error follows. + */ + VIR_NET_ERROR = 1, + + /* For streams, indicates that more data is still expected + */ + VIR_NET_CONTINUE = 2 +}; + +/* 4 byte length word per header */ +const VIR_NET_MESSAGE_HEADER_XDR_LEN = 4; + +struct virNetMessageHeader { + unsigned prog; /* Unique ID for the program */ + unsigned vers; /* Program version number */ + int proc; /* Unique ID for the procedure within the program */ + virNetMessageType type; /* Type of message */ + unsigned serial; /* Serial number of message. */ + virNetMessageStatus status; +}; diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c new file mode 100644 index 0000000..de66103 --- /dev/null +++ b/src/rpc/virnetsocket.c @@ -0,0 +1,715 @@ +/* + * virnetsocket.h: generic network socket handling + * + * Copyright (C) 2006-2010 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@xxxxxxxxxx> + */ + +#include <config.h> + +#include <sys/stat.h> +#include <unistd.h> +#include <netinet/tcp.h> +#include <sys/wait.h> + +#include "virnetsocket.h" +#include "util.h" +#include "memory.h" +#include "virterror_internal.h" +#include "logging.h" +#include "files.h" +#include "event.h" + +#define VIR_FROM_THIS VIR_FROM_RPC + +#define virNetError(code, ...) \ + virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + + +struct _virNetSocket { + int fd; + int watch; + pid_t pid; + int errfd; + virNetSocketIOFunc func; + void *opaque; + virSocketAddr localAddr; + virSocketAddr remoteAddr; + char *localAddrStr; + char *remoteAddrStr; +}; + + +static int virNetSocketForkDaemon(const char *binary) +{ + const char *const daemonargs[] = { binary, "--timeout=30", NULL }; + pid_t pid; + + if (virExecDaemonize(daemonargs, NULL, NULL, + &pid, -1, NULL, NULL, + VIR_EXEC_CLEAR_CAPS, + NULL, NULL, NULL) < 0) + return -1; + + return 0; +} + + +static virNetSocketPtr virNetSocketNew(virSocketAddrPtr localAddr, + virSocketAddrPtr remoteAddr, + int fd, int errfd, pid_t pid) +{ + virNetSocketPtr sock; + int no_slow_start = 1; + + if (virSetCloseExec(fd) < 0 || + virSetNonBlock(fd) < 0) + return NULL; + + if (VIR_ALLOC(sock) < 0) { + virReportOOMError(); + return NULL; + } + + sock->localAddr = *localAddr; + if (remoteAddr) + sock->remoteAddr = *remoteAddr; + sock->fd = fd; + sock->errfd = errfd; + sock->pid = pid; + + /* Disable nagle */ + if (sock->localAddr.data.sa.sa_family != AF_UNIX) + setsockopt (fd, IPPROTO_TCP, TCP_NODELAY, + (void *)&no_slow_start, + sizeof(no_slow_start)); + + + if (!(sock->localAddrStr = virSocketFormatAddrFull(localAddr, true, ";"))) { + VIR_FREE(sock); + return NULL; + } + + if (remoteAddr && + !(sock->remoteAddrStr = virSocketFormatAddrFull(remoteAddr, true, ";"))) { + VIR_FREE(sock); + return NULL; + } + + VIR_DEBUG("sock=%p", sock); + + return sock; +} + + +int 
virNetSocketNewListenTCP(const char *nodename, + const char *service, + virNetSocketPtr **retsocks, + size_t *nretsocks) +{ + virNetSocketPtr *socks = NULL; + size_t nsocks = 0; + struct addrinfo *ai; + struct addrinfo hints; + int fd = -1; + + *retsocks = NULL; + *nretsocks = 0; + + memset (&hints, 0, sizeof hints); + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + hints.ai_socktype = SOCK_STREAM; + + int e = getaddrinfo (nodename, service, &hints, &ai); + if (e != 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to resolve address '%s' service '%s': %s"), + nodename, service, gai_strerror (e)); + goto error; + } + + struct addrinfo *runp = ai; + while (runp) { + virSocketAddr addr; + + if ((fd = socket(runp->ai_family, runp->ai_socktype, + runp->ai_protocol)) < 0) { + virReportSystemError(errno, "%s", _("Unable to create socket")); + goto error; + } + + int opt = 1; + setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof opt); + +#ifdef IPV6_V6ONLY + if (runp->ai_family == PF_INET6) { + int on = 1; + /* + * Normally on Linux an INET6 socket will bind to the INET4 + * address too. If getaddrinfo returns results with INET4 + * first though, this will result in INET6 binding failing. 
+ * We can trivially cope with multiple server sockets, so + * we force it to only listen on IPv6 + */ + setsockopt(fd, IPPROTO_IPV6,IPV6_V6ONLY, + (void*)&on, sizeof on); + } +#endif + + if (bind(fd, runp->ai_addr, runp->ai_addrlen) < 0) { + if (errno != EADDRINUSE) { + virReportSystemError(errno, "%s", _("Unable to bind to port")); + goto error; + } + VIR_FORCE_CLOSE(fd); + runp = runp->ai_next; /* move on to the next address; without this the loop never terminates */ + continue; + } + + addr.len = sizeof(addr.data); /* getsockname() takes a value-result length: it must hold the buffer size on entry */ + if (getsockname(fd, &addr.data.sa, &addr.len) < 0) { + virReportSystemError(errno, "%s", _("Unable to get local socket name")); + goto error; + } + if (VIR_EXPAND_N(socks, nsocks, 1) < 0) { + virReportOOMError(); + goto error; + } + + if (!(socks[nsocks-1] = virNetSocketNew(&addr, NULL, fd, -1, 0))) + goto error; + runp = runp->ai_next; + } + + freeaddrinfo (ai); + + *retsocks = socks; + *nretsocks = nsocks; + return 0; + +error: + VIR_FORCE_CLOSE(fd); + return -1; +} + +int virNetSocketNewListenUNIX(const char *path, + mode_t mask, + gid_t grp, + virNetSocketPtr *retsock) +{ + virSocketAddr addr; + mode_t oldmask; + gid_t oldgrp; + int fd; + + *retsock = NULL; + + memset(&addr, 0, sizeof(addr)); + + addr.len = sizeof(addr.data.un); + + if ((fd = socket(PF_UNIX, SOCK_STREAM, 0)) < 0) { + virReportSystemError(errno, "%s", _("Failed to create socket")); + goto error; + } + + addr.data.un.sun_family = AF_UNIX; + if (virStrcpyStatic(addr.data.un.sun_path, path) == NULL) { + virReportSystemError(ENOMEM, _("Path %s too long for unix socket"), path); + goto error; + } + if (addr.data.un.sun_path[0] == '@') + addr.data.un.sun_path[0] = '\0'; + else + unlink(addr.data.un.sun_path); + + oldgrp = getgid(); + oldmask = umask(~mask); + if (grp != 0 && setgid(grp) < 0) { + virReportSystemError(errno, + _("Failed to set group ID to %d"), grp); + goto error; + } + + if (bind(fd, &addr.data.sa, addr.len) < 0) { + virReportSystemError(errno, + _("Failed to bind socket to '%s'"), + path); + goto error; + } + umask(oldmask); + if (grp != 0 && setgid(oldgrp)) { + virReportSystemError(errno, + 
_("Failed to restore group ID to %d"), oldgrp); + goto error; + } + + if (!(*retsock = virNetSocketNew(&addr, NULL, fd, -1, 0))) + goto error; + + return 0; + +error: + VIR_FORCE_CLOSE(fd); + return -1; +} + + +int virNetSocketNewConnectTCP(const char *nodename, + const char *service, + virNetSocketPtr *retsock) +{ + virNetSocketPtr *socks = NULL; + size_t nsocks = 0; + struct addrinfo *ai; + struct addrinfo hints; + int fd = -1; + virSocketAddr localAddr; + virSocketAddr remoteAddr; + struct addrinfo *runp; + int savedErrno = ENOENT; + + *retsock = NULL; + + memset(&localAddr, 0, sizeof(localAddr)); + memset(&remoteAddr, 0, sizeof(remoteAddr)); + + memset (&hints, 0, sizeof hints); + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + hints.ai_socktype = SOCK_STREAM; + + int e = getaddrinfo (nodename, service, &hints, &ai); + if (e != 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to resolve address '%s' service '%s': %s"), + nodename, service, gai_strerror (e)); + goto error; + } + + runp = ai; + while (runp) { + int opt = 1; + + if ((fd = socket(runp->ai_family, runp->ai_socktype, + runp->ai_protocol)) < 0) { + virReportSystemError(errno, "%s", _("Unable to create socket")); + goto error; + } + + setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof opt); + + if (connect(fd, runp->ai_addr, runp->ai_addrlen) >= 0) + break; + + savedErrno = errno; + VIR_FORCE_CLOSE(fd); + runp = runp->ai_next; + } + + if (fd == -1) { + virReportSystemError(savedErrno, + _("unable to connect to server at '%s:%s'"), + nodename, service); + return -1; + } + + freeaddrinfo (ai); + + if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) { + virReportSystemError(errno, "%s", _("Unable to get local socket name")); + goto error; + } + + if (getpeername(fd, &remoteAddr.data.sa, &remoteAddr.len) < 0) { + virReportSystemError(errno, "%s", _("Unable to get remote socket name")); + goto error; + } + + if (VIR_EXPAND_N(socks, nsocks, 1) < 0) { + virReportOOMError(); + goto error; + } + 
+ + if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, fd, -1, 0))) + goto error; + + return 0; + +error: + VIR_FORCE_CLOSE(fd); + return -1; +} + + +int virNetSocketNewConnectUNIX(const char *path, + bool spawnDaemon, + const char *binary, + virNetSocketPtr *retsock) +{ + virSocketAddr localAddr; + virSocketAddr remoteAddr; + int fd; + int retries = 0; + + memset(&localAddr, 0, sizeof(localAddr)); + memset(&remoteAddr, 0, sizeof(remoteAddr)); + + remoteAddr.len = sizeof(remoteAddr.data.un); + + if ((fd = socket(PF_UNIX, SOCK_STREAM, 0)) < 0) { + virReportSystemError(errno, "%s", _("Failed to create socket")); + goto error; + } + + remoteAddr.data.un.sun_family = AF_UNIX; + if (virStrcpyStatic(remoteAddr.data.un.sun_path, path) == NULL) { + virReportSystemError(ENOMEM, _("Path %s too long for unix socket"), path); + goto error; + } + if (remoteAddr.data.un.sun_path[0] == '@') + remoteAddr.data.un.sun_path[0] = '\0'; + /* Never unlink the path here: it is the server's listening socket, + * and removing it would make the connect() below fail. */ + +retry: + if (connect(fd, &remoteAddr.data.sa, remoteAddr.len) < 0) { + if (errno == ECONNREFUSED && spawnDaemon && retries < 20) { + if (retries == 0 && + virNetSocketForkDaemon(binary) < 0) + goto error; + + retries++; + usleep(1000 * 100 * retries); + goto retry; + } + + virReportSystemError(errno, + _("Failed to connect socket to '%s'"), + path); + goto error; + } + +#if 0 + /* There is no meaningful local addr for UNIX sockets, + * and getsockname() returns something with AF_INET + * in sa_family when run against AF_UNIX sockets ! 
+ */ + if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) { + virReportSystemError(errno, "%s", _("Unable to get local socket name")); + goto error; + } +#else + localAddr.data.sa.sa_family = AF_UNIX; +#endif + + VIR_DEBUG("localAddr len=%d family=%d", localAddr.len, localAddr.data.sa.sa_family); + + if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, fd, -1, 0))) + goto error; + + return 0; + +error: + VIR_FORCE_CLOSE(fd); + return -1; +} + +int virNetSocketNewConnectSSH(const char *nodename, + const char *service, + const char *binary, + const char *username, + bool noTTY, + const char *netcat, + const char *path, + virNetSocketPtr *retsock) +{ + const char **cmdargv = NULL; + int ncmdargv; + + *retsock = NULL; + + ncmdargv = 6; + if (username) ncmdargv += 2; /* For -l username */ + if (noTTY) ncmdargv += 5; /* For -T -o BatchMode=yes -e none */ + if (service) ncmdargv += 2; /* For -p port */ + + /* + * Generate the final command argv[] array (6 fixed slots plus the optional ones counted above). + * ssh [-p $port] [-l $username] [-T -o BatchMode=yes -e none] $hostname $netcat -U $sockname [NULL] + */ + if (VIR_ALLOC_N(cmdargv, ncmdargv) < 0) + goto no_memory; + + ncmdargv = 0; + cmdargv[ncmdargv++] = binary; + if (service) { + cmdargv[ncmdargv++] = "-p"; + cmdargv[ncmdargv++] = service; + } + if (username) { + cmdargv[ncmdargv++] = "-l"; + cmdargv[ncmdargv++] = username; + } + if (noTTY) { + cmdargv[ncmdargv++] = "-T"; + cmdargv[ncmdargv++] = "-o"; + cmdargv[ncmdargv++] = "BatchMode=yes"; + cmdargv[ncmdargv++] = "-e"; + cmdargv[ncmdargv++] = "none"; + } + cmdargv[ncmdargv++] = nodename; + cmdargv[ncmdargv++] = netcat ? 
netcat : "nc"; + cmdargv[ncmdargv++] = "-U"; + cmdargv[ncmdargv++] = path; + cmdargv[ncmdargv++] = NULL; + + return virNetSocketNewConnectCommand(cmdargv, NULL, retsock); + +no_memory: + VIR_FREE(cmdargv); + virReportOOMError(); + return -1; +} + + +int virNetSocketNewConnectCommand(const char **cmdargv, + const char **cmdenv, + virNetSocketPtr *retsock) +{ + pid_t pid = 0; + int sv[2]; + int errfd[2]; + + *retsock = NULL; + + /* Fork off the external process. Use socketpair to create a private + * (unnamed) Unix domain socket to the child process so we don't have + * to faff around with two file descriptors (a la 'pipe(2)'). + */ + if (socketpair(PF_UNIX, SOCK_STREAM, 0, sv) < 0) { + virReportSystemError(errno, "%s", + _("unable to create socket pair")); + goto error; + } + + if (pipe(errfd) < 0) { + virReportSystemError(errno, "%s", + _("unable to create socket pair")); + goto error; + } + + if (virExec(cmdargv, cmdenv, NULL, + &pid, sv[1], &(sv[1]), &(errfd[1]), + VIR_EXEC_CLEAR_CAPS) < 0) + goto error; + + /* Parent continues here. 
*/ + VIR_FORCE_CLOSE(sv[1]); + VIR_FORCE_CLOSE(errfd[1]); + + if (!(*retsock = virNetSocketNew(NULL, NULL, sv[0], errfd[0], pid))) + goto error; + + return 0; + +error: + VIR_FORCE_CLOSE(sv[0]); + VIR_FORCE_CLOSE(sv[1]); + VIR_FORCE_CLOSE(errfd[1]); + VIR_FORCE_CLOSE(errfd[1]); + + if (pid > 0) { + pid_t reap; + do { +retry: + reap = waitpid(pid, NULL, 0); + if (reap == -1 && errno == EINTR) + goto retry; + } while (reap != -1 && reap != pid); + } + + return -1; +} + + +void virNetSocketFree(virNetSocketPtr sock) +{ + VIR_DEBUG("sock=%p", sock); + + if (!sock) + return; + + if (sock->watch) { + virEventRemoveHandle(sock->watch); + sock->watch = -1; + } + + if (sock->localAddr.data.sa.sa_family == AF_UNIX && + sock->localAddr.data.un.sun_path[0] != '\0') + unlink(sock->localAddr.data.un.sun_path); + + VIR_FORCE_CLOSE(sock->fd); + VIR_FORCE_CLOSE(sock->errfd); + + if (sock->pid > 0) { + pid_t reap; + kill(sock->pid, SIGTERM); + do { +retry: + reap = waitpid(sock->pid, NULL, 0); + if (reap == -1 && errno == EINTR) + goto retry; + } while (reap != -1 && reap != sock->pid); + } + + VIR_FREE(sock->localAddrStr); + VIR_FREE(sock->remoteAddrStr); + + VIR_FREE(sock); +} + + +int virNetSocketFD(virNetSocketPtr sock) +{ + return sock->fd; +} + + +const char *virNetSocketLocalAddrString(virNetSocketPtr sock) +{ + return sock->localAddrStr; +} + +const char *virNetSocketRemoteAddrString(virNetSocketPtr sock) +{ + return sock->remoteAddrStr; +} + +ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len) +{ + return read(sock->fd, buf, len); +} + +ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len) +{ + return write(sock->fd, buf, len); +} + + +int virNetSocketListen(virNetSocketPtr sock) +{ + if (listen(sock->fd, 30) < 0) { + virReportSystemError(errno, "%s", _("Unable to listen on socket")); + return -1; + } + return 0; +} + +int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr *clientsock) +{ + int fd; + virSocketAddr localAddr; + 
virSocketAddr remoteAddr; + + *clientsock = NULL; + + memset(&localAddr, 0, sizeof(localAddr)); + memset(&remoteAddr, 0, sizeof(remoteAddr)); + + remoteAddr.len = sizeof(remoteAddr.data.stor); + if ((fd = accept(sock->fd, &remoteAddr.data.sa, &remoteAddr.len)) < 0) { + if (errno == ECONNABORTED || + errno == EAGAIN) + return 0; + + virReportSystemError(errno, "%s", + _("Unable to accept client")); + return -1; + } + + /* getsockname() reads the length as the buffer size on input; leaving + * it at zero would yield a truncated (empty) local address. */ + localAddr.len = sizeof(localAddr.data.stor); + if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) { + virReportSystemError(errno, "%s", _("Unable to get local socket name")); + VIR_FORCE_CLOSE(fd); + return -1; + } + + + if (!(*clientsock = virNetSocketNew(&localAddr, &remoteAddr, fd, -1, 0))) { + /* Do not leak the accepted connection when wrapping it fails */ + VIR_FORCE_CLOSE(fd); + return -1; + } + + return 0; +} + + +static void virNetSocketEventHandle(int fd ATTRIBUTE_UNUSED, + int watch ATTRIBUTE_UNUSED, + int events, + void *opaque) +{ + virNetSocketPtr sock = opaque; + + sock->func(sock, events, sock->opaque); +} + +int virNetSocketAddIOCallback(virNetSocketPtr sock, + int events, + virNetSocketIOFunc func, + void *opaque) +{ + if (sock->watch) { + VIR_DEBUG("Watch already registered on socket %p", sock); + return -1; + } + + if ((sock->watch = virEventAddHandle(sock->fd, + events, + virNetSocketEventHandle, + sock, + NULL)) < 0) { + VIR_WARN("Failed to register watch on socket %p", sock); + /* Restore "0 == no watch", which every other check here relies on */ + sock->watch = 0; + return -1; + } + sock->func = func; + sock->opaque = opaque; + + return 0; +} + +void virNetSocketUpdateIOCallback(virNetSocketPtr sock, + int events) +{ + if (!sock->watch) { + VIR_DEBUG("Watch not registered on socket %p", sock); + return; + } + + virEventUpdateHandle(sock->watch, events); +} + +void virNetSocketRemoveIOCallback(virNetSocketPtr sock) +{ + if (!sock->watch) { + VIR_DEBUG("Watch not registered on socket %p", sock); + return; + } + + virEventRemoveHandle(sock->watch); + /* Forget the handle so it is not removed twice and can be re-added */ + sock->watch = 0; +} + diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h new file mode 100644 index 0000000..a25918d --- /dev/null +++ b/src/rpc/virnetsocket.h @@ -0,0 +1,97 @@ +/* + * virnetsocket.h: generic 
network socket handling + * + * Copyright (C) 2006-2010 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange <berrange@xxxxxxxxxx> + */ + +#ifndef __VIR_NET_SOCKET_H__ +#define __VIR_NET_SOCKET_H__ + +#include "network.h" + +typedef struct _virNetSocket virNetSocket; +typedef virNetSocket *virNetSocketPtr; + + +typedef void (*virNetSocketIOFunc)(virNetSocketPtr sock, + int events, + void *opaque); + + +int virNetSocketNewListenTCP(const char *nodename, + const char *service, + virNetSocketPtr **addrs, + size_t *naddrs); + +int virNetSocketNewListenUNIX(const char *path, + mode_t mask, + gid_t grp, + virNetSocketPtr *addr); + +int virNetSocketNewConnectTCP(const char *nodename, + const char *service, + virNetSocketPtr *addr); + +int virNetSocketNewConnectUNIX(const char *path, + bool spawnDaemon, + const char *binary, + virNetSocketPtr *addr); + +int virNetSocketNewConnectSSH(const char *nodename, + const char *service, + const char *binary, + const char *username, + bool noTTY, + const char *netcat, + const char *path, + virNetSocketPtr *addr); + +int virNetSocketNewConnectCommand(const char **cmdargv, + const char **cmdenv, + virNetSocketPtr *addr); + +/* XXX bad */ +int 
virNetSocketFD(virNetSocketPtr sock); + +ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len); +ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len); + +void virNetSocketFree(virNetSocketPtr sock); + +const char *virNetSocketLocalAddrString(virNetSocketPtr sock); +const char *virNetSocketRemoteAddrString(virNetSocketPtr sock); + +int virNetSocketListen(virNetSocketPtr sock); +int virNetSocketAccept(virNetSocketPtr sock, + virNetSocketPtr *clientsock); + +int virNetSocketAddIOCallback(virNetSocketPtr sock, + int events, + virNetSocketIOFunc func, + void *opaque); + +void virNetSocketUpdateIOCallback(virNetSocketPtr sock, + int events); + +void virNetSocketRemoveIOCallback(virNetSocketPtr sock); + + + +#endif /* __VIR_NET_SOCKET_H__ */ diff --git a/src/rpc/virnettlscontext.c b/src/rpc/virnettlscontext.c new file mode 100644 index 0000000..9690328 --- /dev/null +++ b/src/rpc/virnettlscontext.c @@ -0,0 +1,611 @@ + + +#include <config.h> + +#include <unistd.h> +#include <fnmatch.h> +#include <stdlib.h> + +#include <gnutls/gnutls.h> +#include <gnutls/x509.h> +#include "gnutls_1_0_compat.h" + +#include "virnettlscontext.h" + +#include "memory.h" +#include "virterror_internal.h" +#include "util.h" +#include "logging.h" + +#define DH_BITS 1024 + +#define VIR_FROM_THIS VIR_FROM_RPC + +#define virNetError(code, ...) 
\ + virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +struct _virNetTLSContext { + int refs; + + gnutls_certificate_credentials_t x509cred; + gnutls_dh_params_t dhParams; + + bool isServer; + bool requireValidCert; + const char *const*x509dnWhitelist; +}; + +struct _virNetTLSSession { + int refs; + + char *hostname; + gnutls_session_t session; + virNetTLSSessionWriteFunc writeFunc; + virNetTLSSessionReadFunc readFunc; + void *opaque; +}; + + +static int +virNetTLSContextCheckCertFile(const char *type, const char *file) +{ + if (access(file, R_OK) < 0) { + virReportSystemError(errno, + _("Cannot read %s '%s'"), + type, file); + return -1; + } + return 0; +} + + +static void virNetTLSLog(int level, const char *str) { + VIR_DEBUG("%d %s", level, str); +} + +static virNetTLSContextPtr virNetTLSContextNew(const char *ca_file, + const char *crl_file, + const char *cert_file, + const char *key_file, + const char *const*x509dnWhitelist, + bool requireValidCert, + bool isServer) +{ + virNetTLSContextPtr ctxt; + char *gnutlsdebug; + int err; + + if (VIR_ALLOC(ctxt) < 0) { + virReportOOMError(); + return NULL; + } + + ctxt->refs = 1; + + /* Initialise GnuTLS. 
*/ + gnutls_global_init(); + + if ((gnutlsdebug = getenv("LIBVIRT_GNUTLS_DEBUG")) != NULL) { + int val; + if (virStrToLong_i(gnutlsdebug, NULL, 10, &val) < 0) + val = 10; + gnutls_global_set_log_level(val); + gnutls_global_set_log_function(virNetTLSLog); + } + + + err = gnutls_certificate_allocate_credentials(&ctxt->x509cred); + if (err) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to allocate x509 credentials: %s"), + gnutls_strerror (err)); + goto error; + } + + if (ca_file && ca_file[0] != '\0') { + if (virNetTLSContextCheckCertFile("CA certificate", ca_file) < 0) + goto error; + + VIR_DEBUG("loading CA cert from %s", ca_file); + err = gnutls_certificate_set_x509_trust_file(ctxt->x509cred, + ca_file, + GNUTLS_X509_FMT_PEM); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to set x509 CA certificate: %s"), + gnutls_strerror (err)); + goto error; + } + } + + if (crl_file && crl_file[0] != '\0') { + if (virNetTLSContextCheckCertFile("CA revocation list", crl_file) < 0) + goto error; + + VIR_DEBUG("loading CRL from %s", crl_file); + err = gnutls_certificate_set_x509_crl_file(ctxt->x509cred, + crl_file, + GNUTLS_X509_FMT_PEM); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to set x509 certificate revocation list: %s"), + gnutls_strerror (err)); + goto error; + } + } + + if (cert_file && cert_file[0] != '\0' && key_file && key_file[0] != '\0') { + if (virNetTLSContextCheckCertFile("server certificate", cert_file) < 0) + goto error; + if (virNetTLSContextCheckCertFile("server key", key_file) < 0) + goto error; + VIR_DEBUG("loading cert and key from %s and %s", cert_file, key_file); + err = + gnutls_certificate_set_x509_key_file(ctxt->x509cred, + cert_file, key_file, + GNUTLS_X509_FMT_PEM); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to set x509 key and certificate: %s"), + gnutls_strerror (err)); + goto error; + } + } + + /* Generate Diffie Hellman parameters - for use with DHE + * kx algorithms. 
These should be discarded and regenerated + * once a day, once a week or once a month. Depending on the + * security requirements. + */ + if (isServer) { + err = gnutls_dh_params_init(&ctxt->dhParams); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to initialize diffie-hellman parameters: %s"), + gnutls_strerror (err)); + goto error; + } + err = gnutls_dh_params_generate2(ctxt->dhParams, DH_BITS); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to generate diffie-hellman parameters: %s"), + gnutls_strerror (err)); + goto error; + } + + gnutls_certificate_set_dh_params(ctxt->x509cred, + ctxt->dhParams); + } + + ctxt->requireValidCert = requireValidCert; + ctxt->x509dnWhitelist = x509dnWhitelist; + ctxt->isServer = isServer; + + return ctxt; + +error: + if (isServer) + gnutls_dh_params_deinit(ctxt->dhParams); + gnutls_certificate_free_credentials(ctxt->x509cred); + VIR_FREE(ctxt); + return NULL; +} + + +virNetTLSContextPtr virNetTLSContextNewServer(const char *ca_file, + const char *crl_file, + const char *cert_file, + const char *key_file, + const char *const*x509dnWhitelist, + bool requireValidCert) +{ + return virNetTLSContextNew(ca_file, crl_file, cert_file, key_file, + x509dnWhitelist, requireValidCert, true); +} + +virNetTLSContextPtr virNetTLSContextNewClient(const char *ca_file, + const char *cert_file, + const char *key_file, + bool requireValidCert) +{ + return virNetTLSContextNew(ca_file, NULL, cert_file, key_file, + NULL, requireValidCert, false); +} + + +void virNetTLSContextRef(virNetTLSContextPtr ctxt) +{ + ctxt->refs++; +} + + +/* Check DN is on tls_allowed_dn_list. */ +static int +virNetTLSContextCheckDN(virNetTLSContextPtr ctxt, + const char *dname) +{ + const char *const*wildcards; + + /* If the list is not set, allow any DN. 
*/ + wildcards = ctxt->x509dnWhitelist; + if (!wildcards) + return 1; + + while (*wildcards) { + if (fnmatch (*wildcards, dname, 0) == 0) + return 1; + wildcards++; + } + + /* Print the client's DN. */ + DEBUG(_("Failed whitelist check for client DN '%s'"), dname); + + return 0; // Not found. +} + +static int virNetTLSContextValidCertificate(virNetTLSContextPtr ctxt, + virNetTLSSessionPtr sess) +{ + int ret; + unsigned int status; + const gnutls_datum_t *certs; + unsigned int nCerts, i; + time_t now; + char name[256]; + size_t namesize = sizeof name; + + memset(name, 0, namesize); + + if ((ret = gnutls_certificate_verify_peers2(sess->session, &status)) < 0){ + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to verify TLS peer: %s"), + gnutls_strerror(ret)); + goto authdeny; + } + + if ((now = time(NULL)) == ((time_t)-1)) { + virReportSystemError(errno, "%s", + _("cannot get current time")); + goto authfail; + } + + if (status != 0) { + const char *reason = _("Invalid certificate"); + + if (status & GNUTLS_CERT_INVALID) + reason = _("The certificate is not trusted."); + + if (status & GNUTLS_CERT_SIGNER_NOT_FOUND) + reason = _("The certificate hasn't got a known issuer."); + + if (status & GNUTLS_CERT_REVOKED) + reason = _("The certificate has been revoked."); + +#ifndef GNUTLS_1_0_COMPAT + if (status & GNUTLS_CERT_INSECURE_ALGORITHM) + reason = _("The certificate uses an insecure algorithm"); +#endif + + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Certificate failed validation: %s"), + reason); + goto authdeny; + } + + if (gnutls_certificate_type_get(sess->session) != GNUTLS_CRT_X509) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("Only x509 certificates are supported")); + goto authdeny; + } + + if (!(certs = gnutls_certificate_get_peers(sess->session, &nCerts))) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("The certificate has no peers")); + goto authdeny; + } + + for (i = 0; i < nCerts; i++) { + gnutls_x509_crt_t cert; + + if (gnutls_x509_crt_init (&cert) < 0) { + 
virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("Unable to initialize certificate")); + goto authfail; + } + + if (gnutls_x509_crt_import(cert, &certs[i], GNUTLS_X509_FMT_DER) < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("Unable to load certificate")); + gnutls_x509_crt_deinit(cert); + goto authfail; + } + + if (gnutls_x509_crt_get_expiration_time(cert) < now) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("The client certificate has expired")); + gnutls_x509_crt_deinit(cert); + goto authdeny; + } + + if (gnutls_x509_crt_get_activation_time(cert) > now) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("The client certificate is not yet active")); + gnutls_x509_crt_deinit(cert); + goto authdeny; + } + + if (i == 0) { + ret = gnutls_x509_crt_get_dn(cert, name, &namesize); + if (ret != 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Failed to get certificate distinguished name: %s"), + gnutls_strerror(ret)); + gnutls_x509_crt_deinit(cert); + goto authfail; + } + + if (!virNetTLSContextCheckDN(ctxt, name)) { + /* This is the most common error: make it informative. */ + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("Client's Distinguished Name is not on the list " + "of allowed clients (tls_allowed_dn_list). 
Use " + "'certtool -i --infile clientcert.pem' to view the " + "Distinguished Name field in the client certificate, " + "or run this daemon with --verbose option.")); + gnutls_x509_crt_deinit(cert); + goto authdeny; + } + + if (sess->hostname && + !gnutls_x509_crt_check_hostname(cert, sess->hostname)) { + virNetError(VIR_ERR_RPC, + _("Certificate's owner does not match the hostname (%s)"), + sess->hostname); + gnutls_x509_crt_deinit(cert); + goto authdeny; + } + } + + /* Certificate passed every check; the failure paths above release + * it themselves, so release it here for the success path too. */ + gnutls_x509_crt_deinit(cert); + } + +#if 0 + PROBE(CLIENT_TLS_ALLOW, "fd=%d, name=%s", client->fd, (char *)name); +#endif + return 0; + +authdeny: +#if 0 + PROBE(CLIENT_TLS_DENY, "fd=%d, name=%s", client->fd, (char *)name); +#endif + return -1; + +authfail: +#if 0 + PROBE(CLIENT_TLS_FAIL, "fd=%d", client->fd); +#endif + return -1; +} + +int virNetTLSContextCheckCertificate(virNetTLSContextPtr ctxt, + virNetTLSSessionPtr sess) { + if (virNetTLSContextValidCertificate(ctxt, sess) < 0) { + if (ctxt->requireValidCert) { + virNetError(VIR_ERR_AUTH_FAILED, "%s", + _("Failed to verify peer's certificate")); + return -1; + } + VIR_INFO0(_("Ignoring bad certificate at user request")); + } + return 0; +} + +void virNetTLSContextFree(virNetTLSContextPtr ctxt) +{ + if (!ctxt) + return; + + ctxt->refs--; + if (ctxt->refs > 0) + return; + + gnutls_dh_params_deinit(ctxt->dhParams); + gnutls_certificate_free_credentials(ctxt->x509cred); + VIR_FREE(ctxt); +} + + + +static ssize_t +virNetTLSSessionPush(void *opaque, const void *buf, size_t len) +{ + virNetTLSSessionPtr sess = opaque; + return sess->writeFunc(buf, len, sess->opaque); +} + + +static ssize_t +virNetTLSSessionPull(void *opaque, void *buf, size_t len) +{ + virNetTLSSessionPtr sess = opaque; + return sess->readFunc(buf, len, sess->opaque); +} + + +virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt, + const char *hostname, + virNetTLSSessionWriteFunc writeFunc, + virNetTLSSessionReadFunc readFunc, + void *opaque) +{ + virNetTLSSessionPtr sess; + int err; + static const int 
cert_type_priority[] = { GNUTLS_CRT_X509, 0 }; + + if (VIR_ALLOC(sess) < 0) { + virReportOOMError(); + return NULL; + } + + sess->refs = 1; + sess->writeFunc = writeFunc; + sess->readFunc = readFunc; + sess->opaque = opaque; + if (!(sess->hostname = strdup(hostname))) { + virReportOOMError(); + goto error; + } + + if ((err = gnutls_init(&sess->session, + ctxt->isServer ? GNUTLS_SERVER : GNUTLS_CLIENT)) != 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Failed to initialize TLS session: %s"), + gnutls_strerror(err)); + goto error; + } + + /* avoid calling all the priority functions, since the defaults + * are adequate. + */ + if ((err = gnutls_set_default_priority(sess->session)) != 0 || + (err = gnutls_certificate_type_set_priority(sess->session, + cert_type_priority))) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Failed to set TLS session priority %s"), + gnutls_strerror(err)); + goto error; + } + + if ((err = gnutls_credentials_set(sess->session, + GNUTLS_CRD_CERTIFICATE, + ctxt->x509cred)) != 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Failed set TLS x509 credentials: %s"), + gnutls_strerror(err)); + goto error; + } + + /* request client certificate if any. 
+ */ + if (ctxt->isServer) { + gnutls_certificate_server_set_request(sess->session, GNUTLS_CERT_REQUEST); + + gnutls_dh_set_prime_bits(sess->session, DH_BITS); + } + + gnutls_transport_set_ptr(sess->session, sess); + gnutls_transport_set_push_function(sess->session, + virNetTLSSessionPush); + gnutls_transport_set_pull_function(sess->session, + virNetTLSSessionPull); + + return sess; + +error: + virNetTLSSessionFree(sess); + return NULL; +} + + +void virNetTLSSessionRef(virNetTLSSessionPtr sess) +{ + sess->refs++; +} + + +ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess, + const char *buf, size_t len) +{ + int ret; + ret = gnutls_record_send(sess->session, buf, len); + + switch (ret) { + case GNUTLS_E_AGAIN: + errno = EAGAIN; + break; + case GNUTLS_E_INTERRUPTED: + errno = EINTR; + break; + case 0: + break; + default: + errno = EIO; + } + + return ret >= 0 ? ret : -1; +} + +ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess, + char *buf, size_t len) +{ + int ret; + + ret = gnutls_record_recv(sess->session, buf, len); + + switch (ret) { + case GNUTLS_E_AGAIN: + errno = EAGAIN; + break; + case GNUTLS_E_INTERRUPTED: + errno = EINTR; + break; + case 0: + break; + default: + errno = EIO; + } + + return ret >= 0 ? 
ret : -1; +} + +int virNetTLSSessionHandshake(virNetTLSSessionPtr sess) +{ + int ret = gnutls_handshake(sess->session); + + if (ret == 0) + return 0; + if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) + return 1; + + virNetError(VIR_ERR_AUTH_FAILED, + _("TLS handshake failed %s"), + gnutls_strerror (ret)); + return -1; +} + +int virNetTLSSessionHandshakeDirection(virNetTLSSessionPtr sess) +{ + if (gnutls_record_get_direction (sess->session) == 0) + return 0; + else + return 1; +} + +int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess) +{ + gnutls_cipher_algorithm_t cipher; + int ssf; + + cipher = gnutls_cipher_get(sess->session); + if (!(ssf = gnutls_cipher_get_key_size(cipher))) { + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("invalid cipher size for TLS session")); + return -1; + } + + return ssf; +} + + +void virNetTLSSessionFree(virNetTLSSessionPtr sess) +{ + if (!sess) + return; + + sess->refs--; + if (sess->refs > 0) + return; + + VIR_FREE(sess->hostname); + gnutls_deinit(sess->session); + VIR_FREE(sess); +} diff --git a/src/rpc/virnettlscontext.h b/src/rpc/virnettlscontext.h new file mode 100644 index 0000000..b93b379 --- /dev/null +++ b/src/rpc/virnettlscontext.h @@ -0,0 +1,63 @@ + + +#ifndef __VIR_NET_TLS_CONTEXT_H__ +#define __VIR_NET_TLS_CONTEXT_H__ + +#include <stdbool.h> +#include <sys/types.h> + +typedef struct _virNetTLSContext virNetTLSContext; +typedef virNetTLSContext *virNetTLSContextPtr; + +typedef struct _virNetTLSSession virNetTLSSession; +typedef virNetTLSSession *virNetTLSSessionPtr; + + +virNetTLSContextPtr virNetTLSContextNewServer(const char *ca_file, + const char *crl_file, + const char *cert_file, + const char *key_file, + const char *const*x509dnWhitelist, + bool requireValidCert); + +virNetTLSContextPtr virNetTLSContextNewClient(const char *ca_file, + const char *cert_file, + const char *key_file, + bool requireValidCert); + +void virNetTLSContextRef(virNetTLSContextPtr ctxt); + +int 
virNetTLSContextCheckCertificate(virNetTLSContextPtr ctxt, + virNetTLSSessionPtr sess); + +void virNetTLSContextFree(virNetTLSContextPtr ctxt); + + +typedef ssize_t (*virNetTLSSessionWriteFunc)(const char *buf, size_t len, + void *opaque); +typedef ssize_t (*virNetTLSSessionReadFunc)(char *buf, size_t len, + void *opaque); + +virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt, + const char *hostname, + virNetTLSSessionWriteFunc writeFunc, + virNetTLSSessionReadFunc readFunc, + void *opaque); + +void virNetTLSSessionRef(virNetTLSSessionPtr sess); + +ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess, + const char *buf, size_t len); +ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess, + char *buf, size_t len); + +int virNetTLSSessionHandshake(virNetTLSSessionPtr sess); + +int virNetTLSSessionHandshakeDirection(virNetTLSSessionPtr sess); + +int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess); + +void virNetTLSSessionFree(virNetTLSSessionPtr sess); + + +#endif -- 1.7.2.3 -- libvir-list mailing list libvir-list@xxxxxxxxxx https://www.redhat.com/mailman/listinfo/libvir-list