To facilitate creation of new daemons providing XDR RPC services, pull
a lot of the libvirtd daemon code into a set of reusable objects.

 * virNetServer: A server contains one or more services which accept
   incoming clients. It maintains the list of active clients. It has a
   list of RPC programs which can be used by clients. When clients
   produce a complete RPC message, the server passes it on to the
   corresponding program for handling, and queues any response back
   with the client.

 * virNetServerClient: Encapsulates a single client connection. It
   handles all I/O for the client, reading & writing RPC messages. It
   also contains the SASL/TLS code, but this will eventually move into
   the virNetSocket object.

 * virNetServerProgram: Handles processing and dispatch of RPC method
   calls for a single RPC (program, version) pair. Multiple programs
   can be registered with the server.

 * virNetServerService: Encapsulates socket(s) listening for new
   connections. Each service listens on a single host/port, but may
   have multiple sockets if on a dual IPv4/6 host.

Each new daemon now merely has to define the list of RPC procedures &
their handlers. It does not need to deal with any network related
functionality at all.
---
 src/Makefile.am               |   18 +-
 src/rpc/virnetserver.c        |  654 +++++++++++++++++++++++++++
 src/rpc/virnetserver.h        |   74 ++++
 src/rpc/virnetserverclient.c  |  974 +++++++++++++++++++++++++++++++++++++++++
 src/rpc/virnetserverclient.h  |   40 ++
 src/rpc/virnetservermessage.h |   20 +
 src/rpc/virnetserverprogram.c |  437 ++++++++++++++++++
 src/rpc/virnetserverprogram.h |   76 ++++
 src/rpc/virnetserverservice.c |  208 +++++++++
 src/rpc/virnetserverservice.h |   32 ++
 10 files changed, 2532 insertions(+), 1 deletions(-)
 create mode 100644 src/rpc/virnetserver.c
 create mode 100644 src/rpc/virnetserver.h
 create mode 100644 src/rpc/virnetserverclient.c
 create mode 100644 src/rpc/virnetserverclient.h
 create mode 100644 src/rpc/virnetservermessage.h
 create mode 100644 src/rpc/virnetserverprogram.c
 create mode 100644 src/rpc/virnetserverprogram.h
 create mode 100644 src/rpc/virnetserverservice.c
 create mode 100644 src/rpc/virnetserverservice.h

diff --git a/src/Makefile.am b/src/Makefile.am
index 613ff0a..e78a0af 100644
--- a/src/Makefile.am
+++ b/src/Makefile.am
@@ -1117,7 +1117,7 @@ libvirt_qemu_la_LIBADD = libvirt.la $(CYGWIN_EXTRA_LIBADD)
 
 EXTRA_DIST += $(LIBVIRT_QEMU_SYMBOL_FILE)
 
-noinst_LTLIBRARIES += libvirt-net-rpc.la
+noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt-net-rpc-server.la
 
 libvirt_net_rpc_la_SOURCES = \
 	../daemon/event.c \
@@ -1138,6 +1138,22 @@ libvirt_net_rpc_la_LDFLAGS = \
 libvirt_net_rpc_la_LIBADD = \
 	$(CYGWIN_EXTRA_LIBADD)
 
+libvirt_net_rpc_server_la_SOURCES = \
+	rpc/virnetservermessage.h \
+	rpc/virnetserverprogram.h rpc/virnetserverprogram.c \
+	rpc/virnetserverservice.h rpc/virnetserverservice.c \
+	rpc/virnetserverclient.h rpc/virnetserverclient.c \
+	rpc/virnetserver.h rpc/virnetserver.c
+libvirt_net_rpc_server_la_CFLAGS = \
+	$(AM_CFLAGS)
+libvirt_net_rpc_server_la_LDFLAGS = \
+	$(AM_LDFLAGS) \
+	$(CYGWIN_EXTRA_LDFLAGS) \
+	$(MINGW_EXTRA_LDFLAGS)
+libvirt_net_rpc_server_la_LIBADD = \
+	$(CYGWIN_EXTRA_LIBADD)
+
+
 libexec_PROGRAMS =
 
 if WITH_STORAGE_DISK
diff --git
a/src/rpc/virnetserver.c b/src/rpc/virnetserver.c
new file mode 100644
index 0000000..0384bb9
--- /dev/null
+++ b/src/rpc/virnetserver.c
@@ -0,0 +1,781 @@
+/*
+ * virnetserver.c: generic network RPC server
+ *
+ * Copyright (C) 2006-2010 Red Hat, Inc.
+ * Copyright (C) 2006 Daniel P. Berrange
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+ *
+ * Author: Daniel P. Berrange <berrange@xxxxxxxxxx>
+ */
+
+#include <config.h>
+
+#include <unistd.h>
+#include <string.h>
+#include <errno.h>
+
+#include "virnetserver.h"
+#include "logging.h"
+#include "memory.h"
+#include "virterror_internal.h"
+#include "threads.h"
+#include "threadpool.h"
+#include "util.h"
+#include "files.h"
+#include "event.h"
+#include "../daemon/event.h"
+
+#define VIR_FROM_THIS VIR_FROM_RPC
+
+#define virNetError(code, ...)                                    \
+    virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__,      \
+                         __FUNCTION__, __LINE__, __VA_ARGS__)
+
+/* One handler registered with virNetServerAddSignalHandler(); the
+ * previous disposition is kept so virNetServerFree() can restore it. */
+typedef struct _virNetServerSignal virNetServerSignal;
+typedef virNetServerSignal *virNetServerSignalPtr;
+
+struct _virNetServerSignal {
+    struct sigaction oldaction;
+    int signum;
+    virNetServerSignalFunc func;
+    void *opaque;
+};
+
+/* One complete RPC message from one client, queued on the worker pool */
+typedef struct _virNetServerJob virNetServerJob;
+typedef virNetServerJob *virNetServerJobPtr;
+
+struct _virNetServerJob {
+    virNetServerClientPtr client;
+    virNetServerMessagePtr msg;
+};
+
+struct _virNetServer {
+    int refs;
+
+    virMutex lock;
+
+    virThreadPoolPtr workers;
+
+    /* The async signal handler writes each siginfo_t to 'sigwrite';
+     * 'sigread' is watched by the event loop via 'sigwatch'. */
+    size_t nsignals;
+    virNetServerSignalPtr *signals;
+    int sigread;
+    int sigwrite;
+    int sigwatch;
+
+    size_t nservices;
+    virNetServerServicePtr *services;
+
+    size_t nprograms;
+    virNetServerProgramPtr *programs;
+
+    size_t nclients;
+    size_t nclients_alloc;
+    size_t nclients_max;
+    virNetServerClientPtr *clients;
+
+    unsigned int quit :1;
+
+    virNetTLSContextPtr tls;
+
+    unsigned int autoShutdownTimeout;
+    virNetServerAutoShutdownFunc autoShutdownFunc;
+    void *autoShutdownOpaque;
+};
+
+
+static void virNetServerLock(virNetServerPtr srv)
+{
+    virMutexLock(&srv->lock);
+}
+
+static void virNetServerUnlock(virNetServerPtr srv)
+{
+    virMutexUnlock(&srv->lock);
+}
+
+
+/*
+ * Worker pool callback: route one received message to the program
+ * registered for its (program, version) header and dispatch it.
+ *
+ * The server lock is held only while looking up the program. A
+ * reference on the program keeps it alive while the dispatch runs
+ * unlocked, so workers can process calls concurrently and handlers
+ * can call back into the server API without self-deadlocking.
+ *
+ * On any failure the client connection is closed.
+ */
+static void virNetServerHandleJob(void *jobOpaque, void *opaque)
+{
+    virNetServerPtr srv = opaque;
+    virNetServerJobPtr job = jobOpaque;
+    virNetServerProgramPtr prog = NULL;
+    size_t i;
+
+    VIR_DEBUG("server=%p client=%p message=%p",
+              srv, job->client, job->msg);
+
+    virNetServerLock(srv);
+    for (i = 0 ; i < srv->nprograms ; i++) {
+        if (virNetServerProgramMatches(srv->programs[i], job->msg)) {
+            prog = srv->programs[i];
+            virNetServerProgramRef(prog);
+            break;
+        }
+    }
+    virNetServerUnlock(srv);
+
+    if (!prog) {
+        VIR_DEBUG("Cannot find program %d version %d",
+                  job->msg->msg.header.prog,
+                  job->msg->msg.header.vers);
+        goto error;
+    }
+
+    if (virNetServerProgramDispatch(prog,
+                                    srv,
+                                    job->client,
+                                    job->msg) < 0) {
+        /* Dispatch is treated as owning the message even when it
+         * fails, so it must not be finished again here. */
+        job->msg = NULL;
+        goto error;
+    }
+
+    virNetServerLock(srv);
+    virNetServerProgramFree(prog);
+    virNetServerUnlock(srv);
+    VIR_FREE(job);
+    return;
+
+error:
+    if (prog) {
+        virNetServerLock(srv);
+        virNetServerProgramFree(prog);
+        virNetServerUnlock(srv);
+    }
+    if (job->msg)
+        virNetServerClientFinishMessage(job->client, job->msg);
+    virNetServerClientClose(job->client);
+    VIR_FREE(job);
+}
+
+
+/*
+ * Client dispatcher callback: queue one complete incoming message on
+ * the worker pool. On failure the message is finished and the client
+ * closed.
+ *
+ * NOTE(review): virNetServerClientDispatchRead appears to invoke this
+ * callback with the client lock held, while virNetServerClientClose()
+ * takes that lock itself (and virNetServerClientFinishMessage()
+ * presumably does too); confirm the error path cannot self-deadlock.
+ */
+static int virNetServerDispatchNewMessage(virNetServerClientPtr client,
+                                          virNetServerMessagePtr msg,
+                                          void *opaque)
+{
+    virNetServerPtr srv = opaque;
+    virNetServerJobPtr job;
+    int ret;
+
+    VIR_DEBUG("server=%p client=%p message=%p",
+              srv, client, msg);
+
+    if (VIR_ALLOC(job) < 0) {
+        virReportOOMError();
+        goto error;
+    }
+
+    job->client = client;
+    job->msg = msg;
+
+    virNetServerLock(srv);
+    ret = virThreadPoolSendJob(srv->workers, job);
+    virNetServerUnlock(srv);
+
+    if (ret < 0) {
+        VIR_FREE(job);
+        goto error;
+    }
+
+    return 0;
+
+error:
+    virNetServerClientFinishMessage(client, msg);
+    virNetServerClientClose(client);
+    return -1;
+}
+
+
+/*
+ * Service dispatcher callback: adopt a newly accepted client, unless
+ * the nclients_max limit has been reached.
+ *
+ * Room in the client list is reserved before the client is initialized,
+ * so an allocation failure cannot leave an initialized (event-registered)
+ * client behind when -1 is returned.
+ */
+static int virNetServerDispatchNewClient(virNetServerServicePtr svc ATTRIBUTE_UNUSED,
+                                         virNetServerClientPtr client,
+                                         void *opaque)
+{
+    virNetServerPtr srv = opaque;
+
+    virNetServerLock(srv);
+
+    if (srv->nclients >= srv->nclients_max) {
+        virNetError(VIR_ERR_RPC,
+                    _("Too many active clients (%d), dropping connection from %s"),
+                    (int)srv->nclients_max, virNetServerClientAddrString(client));
+        goto error;
+    }
+
+    if (VIR_RESIZE_N(srv->clients, srv->nclients_alloc,
+                     srv->nclients, 1) < 0) {
+        virReportOOMError();
+        goto error;
+    }
+
+    if (virNetServerClientInit(client) < 0)
+        goto error;
+
+    srv->clients[srv->nclients++] = client;
+    virNetServerClientRef(client);
+
+    virNetServerClientSetDispatcher(client,
+                                    virNetServerDispatchNewMessage,
+                                    srv);
+
+    virNetServerUnlock(srv);
+    return 0;
+
+error:
+    virNetServerUnlock(srv);
+    return -1;
+}
+
+
+/*
+ * Create a server whose worker pool has between @min_workers and
+ * @max_workers threads and which accepts at most @max_clients clients.
+ *
+ * Also initializes and registers the event loop implementation and
+ * ignores SIGPIPE for the whole process. Returns NULL on error.
+ */
+virNetServerPtr virNetServerNew(size_t min_workers,
+                                size_t max_workers,
+                                size_t max_clients)
+{
+    virNetServerPtr srv;
+    struct sigaction sig_action;
+
+    if (VIR_ALLOC(srv) < 0) {
+        virReportOOMError();
+        return NULL;
+    }
+
+    srv->refs = 1;
+    srv->nclients_max = max_clients;
+    /* Must be set before any 'goto error': virNetServerFree() closes
+     * these, and a zero-filled fd field would mean stdin. */
+    srv->sigwrite = srv->sigread = -1;
+    srv->sigwatch = -1;
+
+    /* virNetServerFree() locks the mutex, so it cannot be used until
+     * the mutex is initialized */
+    if (virMutexInit(&srv->lock) < 0) {
+        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                    _("cannot initialize mutex"));
+        VIR_FREE(srv);
+        return NULL;
+    }
+
+    if (!(srv->workers = virThreadPoolNew(min_workers, max_workers,
+                                          virNetServerHandleJob,
+                                          srv)))
+        goto error;
+
+    if (virEventInit() < 0) {
+        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                    _("Failed to initialize event system"));
+        goto error;
+    }
+
+    virEventRegisterImpl(virEventAddHandleImpl,
+                         virEventUpdateHandleImpl,
+                         virEventRemoveHandleImpl,
+                         virEventAddTimeoutImpl,
+                         virEventUpdateTimeoutImpl,
+                         virEventRemoveTimeoutImpl);
+
+    memset(&sig_action, 0, sizeof(sig_action));
+    sig_action.sa_handler = SIG_IGN;
+    sigaction(SIGPIPE, &sig_action, NULL);
+
+    return srv;
+
+error:
+    virNetServerFree(srv);
+    return NULL;
+}
+
+
+/* Take an additional reference on @srv; release with virNetServerFree() */
+void virNetServerRef(virNetServerPtr srv)
+{
+    virNetServerLock(srv);
+    srv->refs++;
+    virNetServerUnlock(srv);
+}
+
+
+/*
+ * Make virNetServerRun() quit after no clients have been connected for
+ * @timeout seconds, provided @func returns non-zero (a NULL @func means
+ * quit unconditionally). A @timeout of 0 disables the feature.
+ */
+void virNetServerAutoShutdown(virNetServerPtr srv,
+                              unsigned int timeout,
+                              virNetServerAutoShutdownFunc func,
+                              void *opaque)
+{
+    virNetServerLock(srv);
+
+    srv->autoShutdownTimeout = timeout;
+    srv->autoShutdownFunc = func;
+    srv->autoShutdownOpaque = opaque;
+
+    virNetServerUnlock(srv);
+}
+
+
+/* State shared with the async-signal handler. Only one signal pipe is
+ * active per process: the most recently set up server's. */
+static volatile sig_atomic_t sigErrors = 0;
+static int sigLastErrno = 0;
+static int sigWrite = -1;
+
+/* Async-signal handler: forward the siginfo to the signal pipe, where
+ * virNetServerSignalEvent() picks it up in the event loop. */
+static void virNetServerSignalHandler(int sig, siginfo_t * siginfo,
+                                      void* context ATTRIBUTE_UNUSED)
+{
+    int origerrno;
+    int r;
+
+    /* set the sig num in the struct */
+    siginfo->si_signo = sig;
+
+    origerrno = errno;
+    r = safewrite(sigWrite, siginfo, sizeof(*siginfo));
+    if (r == -1) {
+        sigErrors++;
+        sigLastErrno = errno;
+    }
+    errno = origerrno;
+}
+
+/* Event loop callback: read one siginfo_t from the signal pipe and run
+ * the matching handler with the server unlocked. */
+static void
+virNetServerSignalEvent(int watch,
+                        int fd ATTRIBUTE_UNUSED,
+                        int events ATTRIBUTE_UNUSED,
+                        void *opaque) {
+    virNetServerPtr srv = opaque;
+    siginfo_t siginfo;
+    size_t i;
+
+    virNetServerLock(srv);
+
+    if (saferead(srv->sigread, &siginfo, sizeof(siginfo)) != sizeof(siginfo)) {
+        virReportSystemError(errno, "%s",
+                             _("Failed to read from signal pipe"));
+        virEventRemoveHandle(watch);
+        srv->sigwatch = -1;
+        goto cleanup;
+    }
+
+    for (i = 0 ; i < srv->nsignals ; i++) {
+        if (siginfo.si_signo == srv->signals[i]->signum) {
+            virNetServerSignalFunc func = srv->signals[i]->func;
+            void *funcopaque = srv->signals[i]->opaque;
+            virNetServerUnlock(srv);
+            func(srv, &siginfo, funcopaque);
+            return;
+        }
+    }
+
+    virNetError(VIR_ERR_INTERNAL_ERROR,
+                _("Unexpected signal received: %d"), siginfo.si_signo);
+
+cleanup:
+    virNetServerUnlock(srv);
+}
+
+/* Create the signal pipe and watch its read end, once per server.
+ * Returns 0 on success (or if already set up), -1 on error. */
+static int virNetServerSignalSetup(virNetServerPtr srv)
+{
+    int fds[2];
+
+    if (srv->sigwrite != -1)
+        return 0;
+
+    if (pipe(fds) < 0) {
+        virReportSystemError(errno, "%s",
+                             _("Unable to create signal pipe"));
+        return -1;
+    }
+
+    if (virSetNonBlock(fds[0]) < 0 ||
+        virSetNonBlock(fds[1]) < 0 ||
+        virSetCloseExec(fds[0]) < 0 ||
+        virSetCloseExec(fds[1]) < 0) {
+        virReportSystemError(errno, "%s",
+                             _("Failed to setup pipe flags"));
+        goto error;
+    }
+
+    if ((srv->sigwatch = virEventAddHandle(fds[0],
+                                           VIR_EVENT_HANDLE_READABLE,
+                                           virNetServerSignalEvent,
+                                           srv, NULL)) < 0) {
+        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                    _("Failed to add signal handle watch"));
+        goto error;
+    }
+
+    srv->sigread = fds[0];
+    srv->sigwrite = fds[1];
+    sigWrite = fds[1];
+
+    return 0;
+
+error:
+    VIR_FORCE_CLOSE(fds[0]);
+    VIR_FORCE_CLOSE(fds[1]);
+    return -1;
+}
+
+/*
+ * Run @func (with the server unlocked) from the event loop whenever
+ * @signum is delivered. Returns 0 on success, -1 on error.
+ */
+int virNetServerAddSignalHandler(virNetServerPtr srv,
+                                 int signum,
+                                 virNetServerSignalFunc func,
+                                 void *opaque)
+{
+    virNetServerSignalPtr sigdata = NULL;
+    struct sigaction sig_action;
+
+    virNetServerLock(srv);
+
+    if (virNetServerSignalSetup(srv) < 0)
+        goto error;
+
+    /* Allocate the entry before growing the array, so that a failure
+     * never leaves a NULL slot counted in nsignals */
+    if (VIR_ALLOC(sigdata) < 0)
+        goto no_memory;
+
+    if (VIR_EXPAND_N(srv->signals, srv->nsignals, 1) < 0)
+        goto no_memory;
+
+    sigdata->signum = signum;
+    sigdata->func = func;
+    sigdata->opaque = opaque;
+
+    memset(&sig_action, 0, sizeof(sig_action));
+    sig_action.sa_sigaction = virNetServerSignalHandler;
+    sig_action.sa_flags = SA_SIGINFO;
+    sigemptyset(&sig_action.sa_mask);
+
+    sigaction(signum, &sig_action, &sigdata->oldaction);
+
+    srv->signals[srv->nsignals-1] = sigdata;
+
+    virNetServerUnlock(srv);
+    return 0;
+
+no_memory:
+    virReportOOMError();
+error:
+    VIR_FREE(sigdata);
+    virNetServerUnlock(srv);
+    return -1;
+}
+
+
+/* Register @svc so its new clients are handed to this server. The
+ * server takes its own reference on @svc. */
+int virNetServerAddService(virNetServerPtr srv,
+                           virNetServerServicePtr svc)
+{
+    virNetServerLock(srv);
+
+    if (VIR_EXPAND_N(srv->services, srv->nservices, 1) < 0)
+        goto no_memory;
+
+    srv->services[srv->nservices-1] = svc;
+    virNetServerServiceRef(svc);
+
+    virNetServerServiceSetDispatcher(svc,
+                                     virNetServerDispatchNewClient,
+                                     srv);
+
+    virNetServerUnlock(srv);
+    return 0;
+
+no_memory:
+    virReportOOMError();
+    virNetServerUnlock(srv);
+    return -1;
+}
+
+/* Register @prog to handle calls for its (program, version). The server
+ * takes its own reference on @prog. */
+int virNetServerAddProgram(virNetServerPtr srv,
+                           virNetServerProgramPtr prog)
+{
+    virNetServerLock(srv);
+
+    if (VIR_EXPAND_N(srv->programs, srv->nprograms, 1) < 0)
+        goto no_memory;
+
+    srv->programs[srv->nprograms-1] = prog;
+    virNetServerProgramRef(prog);
+
+    virNetServerUnlock(srv);
+    return 0;
+
+no_memory:
+    virReportOOMError();
+    virNetServerUnlock(srv);
+    return -1;
+}
+
+/* Set the TLS context, replacing (and releasing) any previous one.
+ * The server takes its own reference on @tls. */
+int virNetServerSetTLSContext(virNetServerPtr srv,
+                              virNetTLSContextPtr tls)
+{
+    virNetServerLock(srv);
+    /* Reference the new context before releasing the old one, in case
+     * they are the same object */
+    if (tls)
+        virNetTLSContextRef(tls);
+    if (srv->tls)
+        virNetTLSContextFree(srv->tls);
+    srv->tls = tls;
+    virNetServerUnlock(srv);
+    return 0;
+}
+
+
+/* Timer callback armed by virNetServerRun() once the server is idle */
+static void virNetServerAutoShutdownTimer(int timerid ATTRIBUTE_UNUSED,
+                                          void *opaque) {
+    virNetServerPtr srv = opaque;
+
+    virNetServerLock(srv);
+
+    if (!srv->autoShutdownFunc ||
+        srv->autoShutdownFunc(srv, srv->autoShutdownOpaque)) {
+        VIR_DEBUG0("Automatic shutdown triggered");
+        srv->quit = 1;
+    }
+
+    virNetServerUnlock(srv);
+}
+
+
+/* Toggle every registered service (presumably accept on/off) */
+void virNetServerUpdateServices(virNetServerPtr srv,
+                                bool enabled)
+{
+    size_t i;
+
+    virNetServerLock(srv);
+    for (i = 0 ; i < srv->nservices ; i++)
+        virNetServerServiceToggle(srv->services[i], enabled);
+
+    virNetServerUnlock(srv);
+}
+
+
+/*
+ * Run the event loop until virNetServerQuit() is called, the auto
+ * shutdown timer fires, or an event loop iteration fails.
+ */
+void virNetServerRun(virNetServerPtr srv)
+{
+    int timerid = -1;
+    int timerActive = 0;
+
+    virNetServerLock(srv);
+
+    if (srv->autoShutdownTimeout &&
+        (timerid = virEventAddTimeout(-1,
+                                      virNetServerAutoShutdownTimer,
+                                      srv, NULL)) < 0) {
+        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                    _("Failed to register shutdown timeout"));
+        goto cleanup;
+    }
+
+    while (!srv->quit) {
+        /* A shutdown timeout is specified, so arm the timer while no
+         * clients are connected and disarm it while any are.
+         * NB: clients are not yet removed from the list (see the
+         * disabled reaping code below). */
+        if (srv->autoShutdownTimeout) {
+            if (timerActive) {
+                if (srv->nclients) {
+                    VIR_DEBUG("Deactivating shutdown timer %d", timerid);
+                    virEventUpdateTimeout(timerid, -1);
+                    timerActive = 0;
+                }
+            } else {
+                if (!srv->nclients) {
+                    VIR_DEBUG("Activating shutdown timer %d", timerid);
+                    virEventUpdateTimeout(timerid,
+                                          srv->autoShutdownTimeout * 1000);
+                    timerActive = 1;
+                }
+            }
+        }
+
+        virNetServerUnlock(srv);
+        if (virEventRunOnce() < 0) {
+            virNetServerLock(srv);
+            VIR_DEBUG0("Loop iteration error, exiting");
+            break;
+        }
+        virNetServerLock(srv);
+
+#if 0
+    reprocess:
+        for (i = 0 ; i < srv->nclients ; i++) {
+            int inactive;
+            virMutexLock(&srv->clients[i]->lock);
+            inactive = srv->clients[i]->fd == -1
+                && srv->clients[i]->refs == 0;
+            virMutexUnlock(&srv->clients[i]->lock);
+            if (inactive) {
+                qemudFreeClient(srv->clients[i]);
+                srv->nclients--;
+                if (i < srv->nclients)
+                    memmove(srv->clients + i,
+                            srv->clients + i + 1,
+                            sizeof (*srv->clients) * (srv->nclients - i));
+
+                VIR_SHRINK_N(srv->clients, srv->nclients, 0);
+                goto reprocess;
+            }
+        }
+#endif
+
+    }
+
+cleanup:
+    /* Don't leave a timer behind that still references this server */
+    if (timerid >= 0)
+        virEventRemoveTimeout(timerid);
+    virNetServerUnlock(srv);
+}
+
+
+/* Ask virNetServerRun() to return after the current loop iteration */
+void virNetServerQuit(virNetServerPtr srv)
+{
+    virNetServerLock(srv);
+
+    srv->quit = 1;
+
+    virNetServerUnlock(srv);
+}
+
+/*
+ * Release a reference on @srv; the last reference stops all services,
+ * shuts down the worker pool, restores previous signal handlers and
+ * frees all resources. NULL is accepted.
+ */
+void virNetServerFree(virNetServerPtr srv)
+{
+    size_t i;
+
+    if (!srv)
+        return;
+
+    virNetServerLock(srv);
+    srv->refs--;
+    if (srv->refs > 0) {
+        virNetServerUnlock(srv);
+        return;
+    }
+
+    for (i = 0 ; i < srv->nservices ; i++)
+        virNetServerServiceToggle(srv->services[i], false);
+
+    /* Worker threads take the server lock in virNetServerHandleJob and
+     * the pool presumably waits for running jobs when freed, so drop
+     * the lock while it shuts down to avoid deadlocking with them. */
+    virNetServerUnlock(srv);
+    virThreadPoolFree(srv->workers);
+    virNetServerLock(srv);
+
+    for (i = 0 ; i < srv->nsignals ; i++) {
+        sigaction(srv->signals[i]->signum,
+                  &srv->signals[i]->oldaction, NULL);
+        VIR_FREE(srv->signals[i]);
+    }
+    VIR_FREE(srv->signals);
+    /* Stop watching the pipe, and stop the async handler writing to
+     * it, before closing it */
+    if (srv->sigwatch >= 0)
+        virEventRemoveHandle(srv->sigwatch);
+    if (sigWrite == srv->sigwrite)
+        sigWrite = -1;
+    VIR_FORCE_CLOSE(srv->sigread);
+    VIR_FORCE_CLOSE(srv->sigwrite);
+
+    for (i = 0 ; i < srv->nservices ; i++)
+        virNetServerServiceFree(srv->services[i]);
+    VIR_FREE(srv->services);
+
+    for (i = 0 ; i < srv->nprograms ; i++)
+        virNetServerProgramFree(srv->programs[i]);
+    VIR_FREE(srv->programs);
+
+    for (i = 0 ; i < srv->nclients ; i++)
+        virNetServerClientFree(srv->clients[i]);
+    VIR_FREE(srv->clients);
+
+    if (srv->tls)
+        virNetTLSContextFree(srv->tls);
+
+    virNetServerUnlock(srv);
+    virMutexDestroy(&srv->lock);
+    VIR_FREE(srv);
+}
diff --git a/src/rpc/virnetserver.h b/src/rpc/virnetserver.h
new file mode 100644
index 0000000..df9e714
--- /dev/null
+++ b/src/rpc/virnetserver.h
@@ -0,0 +1,74 @@
+/*
+ * virnetserver.h: generic network RPC server
+ *
+ * Copyright (C) 2006-2010 Red Hat, Inc.
+ * Copyright (C) 2006 Daniel P. Berrange
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+ *
+ * Author: Daniel P.
Berrange <berrange@xxxxxxxxxx> + */ + +#ifndef __VIR_NET_SERVER_H__ +#define __VIR_NET_SERVER_H__ + +#include <stdbool.h> +#include <signal.h> + +#include "virnettlscontext.h" +#include "virnetserverprogram.h" +#include "virnetserverclient.h" +#include "virnetserverservice.h" + +virNetServerPtr virNetServerNew(size_t min_workers, + size_t max_workers, + size_t max_clients); + +typedef int (*virNetServerAutoShutdownFunc)(virNetServerPtr srv, void *opaque); + +void virNetServerRef(virNetServerPtr srv); + +void virNetServerAutoShutdown(virNetServerPtr srv, + unsigned int timeout, + virNetServerAutoShutdownFunc func, + void *opaque); + +typedef void (*virNetServerSignalFunc)(virNetServerPtr srv, siginfo_t *info, void *opaque); + +int virNetServerAddSignalHandler(virNetServerPtr srv, + int signum, + virNetServerSignalFunc func, + void *opaque); + +int virNetServerAddService(virNetServerPtr srv, + virNetServerServicePtr svc); + +int virNetServerAddProgram(virNetServerPtr srv, + virNetServerProgramPtr prog); + +int virNetServerSetTLSContext(virNetServerPtr srv, + virNetTLSContextPtr tls); + +void virNetServerUpdateServices(virNetServerPtr srv, + bool enabled); + +void virNetServerRun(virNetServerPtr srv); + +void virNetServerQuit(virNetServerPtr srv); + +void virNetServerFree(virNetServerPtr srv); + + +#endif diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c new file mode 100644 index 0000000..76d4b33 --- /dev/null +++ b/src/rpc/virnetserverclient.c @@ -0,0 +1,974 @@ + +#include <config.h> + +# if HAVE_SASL +# include <sasl/sasl.h> +# endif + +#include "virnetserverclient.h" + +#include "logging.h" +#include "virterror_internal.h" +#include "memory.h" +#include "threads.h" + +#define VIR_FROM_THIS VIR_FROM_RPC + +#define virNetError(code, ...) 
\ + virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +/* Allow for filtering of incoming messages to a custom + * dispatch processing queue, instead of client->dx. + */ + +typedef struct _virNetServerClientFilter virNetServerClientFilter; +typedef virNetServerClientFilter *virNetServerClientFilterPtr; + +typedef int (*virNetServerClientFilterFunc)(virNetServerClientPtr client, + virNetServerMessagePtr msg, + void *opaque); + +struct _virNetServerClientFilter { + virNetServerClientFilterFunc func; + void *opaque; + + virNetServerClientFilterPtr next; +}; + + +typedef struct _virNetServerClientStream virNetServerClientStream; +typedef virNetServerClientStream *virNetServerClientStreamPtr; + +struct _virNetServerClientStream { + void *opaque; + + int procedure; + int serial; + + unsigned int recvEOF : 1; + unsigned int closed : 1; + + virNetServerClientFilter filter; + + virNetServerMessagePtr rx; + int tx; + + virNetServerClientStreamPtr next; +}; + +#if HAVE_SASL +/* Whether we're passing reads & writes through a sasl SSF */ +enum virNetServerClientSSF { + VIR_NET_CLIENT_SSF_NONE = 0, + VIR_NET_CLIENT_SSF_READ = 1, + VIR_NET_CLIENT_SSF_WRITE = 2, +}; +#endif + +struct _virNetServerClient +{ + int refs; + + virMutex lock; + virNetSocketPtr sock; + int auth; + virNetTLSContextPtr tlsCtxt; + virNetTLSSessionPtr tls; + unsigned int handshake : 1; + unsigned int closing : 1; + +# if HAVE_SASL + sasl_conn_t *saslconn; + int saslSSF; + const char *saslDecoded; + unsigned int saslDecodedLength; + unsigned int saslDecodedOffset; + const char *saslEncoded; + unsigned int saslEncodedLength; + unsigned int saslEncodedOffset; + char *saslUsername; +# endif + + /* Count of messages in the 'tx' queue, + * and the server worker pool queue + * ie RPC calls in progress. 
Does not count + * async events which are not used for + * throttling calculations */ + size_t nrequests; + size_t nrequests_max; + /* Zero or one messages being received. Zero if + * nrequests >= max_clients and throttling */ + virNetServerMessagePtr rx; + /* Zero or many messages waiting for transmit + * back to client, including async events */ + virNetServerMessagePtr tx; + + /* Filters to capture messages that would otherwise + * end up on the 'dx' queue */ + virNetServerClientFilterPtr filters; + + /* Data streams */ + virNetServerClientStreamPtr streams; + + virNetServerClientDispatchFunc dispatchFunc; + void *dispatchOpaque; +}; + + +static void virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque); +static void virNetServerClientFinishMessageLocked(virNetServerClientPtr client, + virNetServerMessagePtr msg); + +static void virNetServerClientLock(virNetServerClientPtr client) +{ + virMutexLock(&client->lock); +} + +static void virNetServerClientUnlock(virNetServerClientPtr client) +{ + virMutexUnlock(&client->lock); +} + +/* + * @client: a locked client object + */ +static int +virNetServerClientCalculateHandleMode(virNetServerClientPtr client) { + int mode = 0; + + if (client->handshake) { + if (virNetTLSSessionHandshakeDirection(client->tls) == 0) + mode |= VIR_EVENT_HANDLE_READABLE; + else + mode |= VIR_EVENT_HANDLE_WRITABLE; + } else { + /* If there is a message on the rx queue then + * we're wanting more input */ + if (client->rx) + mode |= VIR_EVENT_HANDLE_READABLE; + + /* If there are one or more messages to send back to client, + then monitor for writability on socket */ + if (client->tx) + mode |= VIR_EVENT_HANDLE_WRITABLE; + } + + return mode; +} + +/* + * @server: a locked or unlocked server object + * @client: a locked client object + */ +static int virNetServerClientRegisterEvent(virNetServerClientPtr client) +{ + int mode = virNetServerClientCalculateHandleMode(client); + + VIR_DEBUG("Registering client event callback 
%d", mode); + if (virNetSocketAddIOCallback(client->sock, + mode, + virNetServerClientDispatchEvent, + client) < 0) + return -1; + + return 0; +} + +/* + * @client: a locked client object + */ +static void virNetServerClientUpdateEvent(virNetServerClientPtr client) +{ + int mode; + + if (!client->sock) + return; + + mode = virNetServerClientCalculateHandleMode(client); + + virNetSocketUpdateIOCallback(client->sock, mode); +} + + +static void virNetServerClientMessageQueuePush(virNetServerMessagePtr *queue, + virNetServerMessagePtr msg) +{ + virNetServerMessagePtr tmp = *queue; + + if (tmp) { + while (tmp->next) + tmp = tmp->next; + tmp->next = msg; + } else { + *queue = msg; + } +} + +static virNetServerMessagePtr +virNetServerClientMessageQueueServe(virNetServerMessagePtr *queue) +{ + virNetServerMessagePtr tmp = *queue; + + if (tmp) { + *queue = tmp->next; + tmp->next = NULL; + } + + return tmp; +} + + +static ssize_t virNetServerClientTLSWriteFunc(const char *buf, size_t len, + void *opaque) +{ + virNetServerClientPtr client = opaque; + + return virNetSocketWrite(client->sock, buf, len); +} + +static ssize_t virNetServerClientTLSReadFunc(char *buf, size_t len, + void *opaque) +{ + virNetServerClientPtr client = opaque; + + return virNetSocketRead(client->sock, buf, len); +} + + +/* Check the client's access. */ +static int +virNetServerClientCheckAccess(virNetServerClientPtr client) +{ + virNetServerMessagePtr confirm; + + /* Verify client certificate. */ + if (virNetTLSContextCheckCertificate(client->tlsCtxt, client->tls) < 0) + return -1; + + if (client->tx) { + VIR_INFO0(_("client had unexpected data pending tx after access check")); + return -1; + } + + if (VIR_ALLOC(confirm) < 0) { + virReportOOMError(); + return -1; + } + + /* Checks have succeeded. Write a '\1' byte back to the client to + * indicate this (otherwise the socket is abruptly closed). + * (NB. The '\1' byte is sent in an encrypted record). 
+ */ + confirm->async = 1; + confirm->msg.bufferLength = 1; + confirm->msg.bufferOffset = 0; + confirm->msg.buffer[0] = '\1'; + + client->tx = confirm; + + return 0; +} + + +virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock, + int auth, + virNetTLSContextPtr tls) +{ + virNetServerClientPtr client; + + VIR_DEBUG("sock=%p auth=%d tls=%p", sock, auth, tls); + + if (VIR_ALLOC(client) < 0) { + virReportOOMError(); + return NULL; + } + + if (virMutexInit(&client->lock) < 0) + goto error; + + client->refs = 1; + client->sock = sock; + client->auth = auth; + client->tlsCtxt = tls; + client->nrequests_max = 10; /* XXX */ + + virNetTLSContextRef(tls); + + /* Prepare one for packet receive */ + if (VIR_ALLOC(client->rx) < 0) + goto error; + client->rx->msg.bufferLength = VIR_NET_MESSAGE_LEN_MAX; + + VIR_DEBUG("client=%p", client); + + return client; + +error: + /* XXX ref counting is better than this */ + client->sock = NULL; /* Caller owns 'sock' upon failure */ + virNetServerClientFree(client); + return NULL; +} + +void virNetServerClientRef(virNetServerClientPtr client) +{ + virNetServerClientLock(client); + client->refs++; + virNetServerClientUnlock(client); +} + + +void virNetServerClientSetDispatcher(virNetServerClientPtr client, + virNetServerClientDispatchFunc func, + void *opaque) +{ + virNetServerClientLock(client); + client->dispatchFunc = func; + client->dispatchOpaque = opaque; + virNetServerClientUnlock(client); +} + + +const char *virNetServerClientAddrString(virNetServerClientPtr client) +{ + return virNetSocketRemoteAddrString(client->sock); +} + + +void virNetServerClientFree(virNetServerClientPtr client) +{ + VIR_DEBUG("client=%p", client); + + if (!client) + return; + + virNetServerClientLock(client); + + client->refs--; + if (client->refs > 0) { + virNetServerClientUnlock(client); + return; + } + + while (client->rx) { + virNetServerMessagePtr msg + = virNetServerClientMessageQueueServe(&client->rx); + VIR_FREE(msg); + } + while 
(client->tx) { + virNetServerMessagePtr msg + = virNetServerClientMessageQueueServe(&client->tx); + VIR_FREE(msg); + } + + virNetTLSSessionFree(client->tls); + virNetTLSContextFree(client->tlsCtxt); + virNetSocketFree(client->sock); + virNetServerClientUnlock(client); + virMutexDestroy(&client->lock); + VIR_FREE(client); +} + + +/* + * You must hold lock for the client + * + * We don't free stuff here, merely disconnect the client's + * network socket & resources. + * + * Full free of the client is done later in a safe point + * where it can be guaranteed it is no longer in use + */ +static void virNetServerClientCloseLocked(virNetServerClientPtr client) +{ + /* Do now, even though we don't close the socket + * until end, to ensure we don't get invoked + * again due to tls shutdown */ + if (client->sock) + virNetSocketRemoveIOCallback(client->sock); + +#if HAVE_SASL + if (client->saslconn) { + sasl_dispose(&client->saslconn); + client->saslconn = NULL; + } + VIR_FREE(client->saslUsername); +#endif + if (client->tls) { + virNetTLSSessionFree(client->tls); + client->tls = NULL; + } + if (client->sock) { + virNetSocketFree(client->sock); + client->sock = NULL; + } + +} + + +/* Client must be unlocked */ +void virNetServerClientClose(virNetServerClientPtr client) +{ + virNetServerClientLock(client); + virNetServerClientCloseLocked(client); + virNetServerClientUnlock(client); +} + + +int virNetServerClientInit(virNetServerClientPtr client) +{ + virNetServerClientLock(client); + + if (!client->tlsCtxt) { + /* Plain socket, so prepare to read first message */ + if (virNetServerClientRegisterEvent(client) < 0) + goto error; + } else { + int ret; + + if (!(client->tls = virNetTLSSessionNew(client->tlsCtxt, + NULL, + virNetServerClientTLSWriteFunc, + virNetServerClientTLSReadFunc, + client))) + goto error; + + /* Begin the TLS handshake. */ + ret = virNetTLSSessionHandshake(client->tls); + if (ret == 0) { + client->handshake = 0; + /* Unlikely, but ... 
Next step is to check the certificate. */ + if (virNetServerClientCheckAccess(client) < 0) + goto error; + + /* Handshake & cert check OK, so prepare to read first message */ + if (virNetServerClientRegisterEvent(client) < 0) + goto error; + } else if (ret > 0) { + /* Most likely, need to do more handshake data */ + client->handshake = 1; + + if (virNetServerClientRegisterEvent(client) < 0) + goto error; + } else { +#if 0 + PROBE(CLIENT_TLS_FAIL, "fd=%d", client->fd); +#endif + goto error; + } + } + + virNetServerClientUnlock(client); + return 0; + +error: + virNetServerClientUnlock(client); + return -1; +} + + + +/* + * Read data into buffer using wire decoding (plain or TLS) + * + * Returns: + * -1 on error or EOF + * 0 on EAGAIN + * n number of bytes + */ +static ssize_t virNetServerClientReadBuf(virNetServerClientPtr client, + char *data, ssize_t len) +{ + ssize_t ret; + + if (len < 0) { + virNetError(VIR_ERR_RPC, + _("unexpected negative length request %lld"), + (long long int) len); + virNetServerClientCloseLocked(client); + return -1; + } + + /*virNetServerClientDebug ("virNetServerClientRead: len = %d", len);*/ + + if (client->tls) + ret = virNetTLSSessionRead(client->tls, data, len); + else + ret = virNetSocketRead(client->sock, data, len); + + if (ret == -1 && (errno == EAGAIN || + errno == EINTR)) + return 0; + if (ret <= 0) { + if (ret != 0) + virReportSystemError(errno, "%s", + _("Unable to read from client")); + else + VIR_DEBUG0("EOF from client connection"); + virNetServerClientCloseLocked(client); + return -1; + } + + return ret; +} + +/* + * Read data into buffer without decoding + * + * Returns: + * -1 on error or EOF + * 0 on EAGAIN + * n number of bytes + */ +static ssize_t virNetServerClientReadPlain(virNetServerClientPtr client) +{ + ssize_t ret; + ret = virNetServerClientReadBuf(client, + client->rx->msg.buffer + client->rx->msg.bufferOffset, + client->rx->msg.bufferLength - client->rx->msg.bufferOffset); + if (ret <= 0) + return ret; /* -1 
error, 0 eagain */
+
+    client->rx->msg.bufferOffset += ret;
+    return ret;
+}
+
+#if HAVE_SASL
+/*
+ * Read data into buffer decoding with SASL
+ *
+ * Returns:
+ *   -1 on error or EOF
+ *    0 on EAGAIN
+ *    n number of bytes
+ */
+static ssize_t virNetServerClientReadSASL(virNetServerClientPtr client)
+{
+    ssize_t got, want;
+
+    /* We're doing a SSF data read, so now it's time to ensure
+     * future writes are under SSF too.
+     *
+     * cf remoteSASLCheckSSF in remote.c
+     */
+    client->saslSSF |= VIR_NET_CLIENT_SSF_WRITE;
+
+    /* Need to read some more data off the wire */
+    if (client->saslDecoded == NULL) {
+        int ret;
+        char encoded[8192];
+        ssize_t encodedLen = sizeof(encoded);
+        encodedLen = virNetServerClientReadBuf(client, encoded, encodedLen);
+
+        if (encodedLen <= 0)
+            return encodedLen;
+
+        ret = sasl_decode(client->saslconn, encoded, encodedLen,
+                          &client->saslDecoded, &client->saslDecodedLength);
+        if (ret != SASL_OK) {
+            virNetError(VIR_ERR_INTERNAL_ERROR,
+                        _("failed to decode SASL data %s"),
+                        sasl_errstring(ret, NULL, NULL));
+            virNetServerClientCloseLocked(client);
+            return -1;
+        }
+
+        client->saslDecodedOffset = 0;
+    }
+
+    /* Some buffered decoded data to return now */
+    got = client->saslDecodedLength - client->saslDecodedOffset;
+    want = client->rx->msg.bufferLength - client->rx->msg.bufferOffset;
+
+    if (want > got)
+        want = got;
+
+    memcpy(client->rx->msg.buffer + client->rx->msg.bufferOffset,
+           client->saslDecoded + client->saslDecodedOffset, want);
+    client->saslDecodedOffset += want;
+    client->rx->msg.bufferOffset += want;
+
+    /* All decoded data consumed: forget the buffer (the output of
+     * sasl_decode() is owned by the SASL connection, so it is not
+     * freed here) so the next call decodes fresh wire data */
+    if (client->saslDecodedOffset == client->saslDecodedLength) {
+        client->saslDecoded = NULL;
+        client->saslDecodedOffset = client->saslDecodedLength = 0;
+    }
+
+    return want;
+}
+#endif
+
+/*
+ * Read as much data off wire as possible till we fill our
+ * buffer, or would block on I/O
+ */
+static ssize_t virNetServerClientRead(virNetServerClientPtr client)
+{
+#if HAVE_SASL
+    if (client->saslSSF & VIR_NET_CLIENT_SSF_READ)
+        return virNetServerClientReadSASL(client);
+    else
+#endif
+        return virNetServerClientReadPlain(client);
+}
+
+
+/*
+ * Read data until we get a complete message to process.
+ *
+ * While rx->msg.bufferLength is VIR_NET_MESSAGE_LEN_MAX only the
+ * length word is being read; once complete it is decoded and reading
+ * continues straight on with the payload. A complete message has its
+ * header decoded and is offered to each registered filter (a filter
+ * returning 1 consumes it, -1 closes the client). Otherwise it is
+ * passed to client->dispatchFunc, or finished immediately when no
+ * dispatcher is set.
+ *
+ * A failure to decode either the length word or the header closes
+ * the client.
+ */
+static void virNetServerClientDispatchRead(virNetServerClientPtr client)
+{
+readmore:
+    if (virNetServerClientRead(client) < 0)
+        return; /* Error */
+
+    if (client->rx->msg.bufferOffset < client->rx->msg.bufferLength)
+        return; /* Still not read enough */
+
+    /* Either done with length word header */
+    if (client->rx->msg.bufferLength == VIR_NET_MESSAGE_LEN_MAX) {
+        if (virNetMessageDecodeLength(&client->rx->msg) < 0) {
+            /* A bogus length word leaves the stream unparseable */
+            virNetServerClientCloseLocked(client);
+            return;
+        }
+
+        virNetServerClientUpdateEvent(client);
+
+        /* Try and read payload immediately instead of going back
+           into poll() because chances are the data is already
+           waiting for us */
+        goto readmore;
+    } else {
+        /* Grab the completed message */
+        virNetServerMessagePtr msg = virNetServerClientMessageQueueServe(&client->rx);
+        virNetServerClientFilterPtr filter;
+
+        /* Decode the header so we can use it for routing decisions */
+        if (virNetMessageDecodeHeader(&msg->msg) < 0) {
+            VIR_FREE(msg);
+            virNetServerClientCloseLocked(client);
+            /* msg is gone: must not fall through to the filters,
+             * request accounting or dispatch below */
+            return;
+        }
+
+        /* Check if any filters match this message */
+        filter = client->filters;
+        while (filter) {
+            int ret;
+            ret = (filter->func)(client, msg, filter->opaque);
+            if (ret == 1) {
+                msg = NULL;
+                break;
+            } else if (ret == -1) {
+                VIR_FREE(msg);
+                virNetServerClientCloseLocked(client);
+                return;
+            }
+            filter = filter->next;
+        }
+
+        client->nrequests++;
+
+        /* Possibly need to create another receive buffer */
+        if ((client->nrequests < client->nrequests_max) &&
+            VIR_ALLOC(client->rx) < 0) {
+            virNetServerClientCloseLocked(client);
+        } else {
+            if (client->rx)
+                client->rx->msg.bufferLength = VIR_NET_MESSAGE_LEN_MAX;
+
+            virNetServerClientUpdateEvent(client);
+        }
+
+        /* Send it off for processing */
+        if (msg) {
+            if (client->dispatchFunc)
+                client->dispatchFunc(client, msg, client->dispatchOpaque);
+            else
+                virNetServerClientFinishMessageLocked(client, msg);
+        }
+    }
+}
+
+
+/*
+ * Send a chunk of data using wire encoding (plain or TLS)
+ *
+ * Returns:
+ *   -1 on error
+ *    0 on EAGAIN
+ *    n number of bytes
+ */
+static ssize_t virNetServerClientWriteBuf(virNetServerClientPtr client,
+                                          const char *data, ssize_t len)
+{
+    ssize_t ret;
+
+    if (len < 0) {
+        virNetError(VIR_ERR_INTERNAL_ERROR,
+                    _("unexpected negative length request %lld"),
+                    (long long int) len);
+        virNetServerClientCloseLocked(client);
+        return -1;
+    }
+
+    if (client->tls)
+        ret = virNetTLSSessionWrite(client->tls, data, len);
+    else
+        ret = virNetSocketWrite(client->sock, data, len);
+
+    if (ret == -1 && (errno == EAGAIN ||
+                      errno == EINTR))
+        return 0;
+    if (ret == -1) {
+        virReportSystemError(errno, "%s",
+                             _("Unable to write to client"));
+        virNetServerClientCloseLocked(client);
+        return -1;
+    }
+    return ret;
+}
+
+
+/*
+ * Send client->tx using no encoding
+ *
+ * Returns:
+ *   -1 on error or EOF
+ *    0 on EAGAIN
+ *    n number of bytes
+ */
+static int virNetServerClientWritePlain(virNetServerClientPtr client)
+{
+    int ret = virNetServerClientWriteBuf(client,
+                                         client->tx->msg.buffer + client->tx->msg.bufferOffset,
+                                         client->tx->msg.bufferLength - client->tx->msg.bufferOffset);
+    if (ret <= 0)
+        return ret; /* -1 error, 0 = EAGAIN */
+    client->tx->msg.bufferOffset += ret;
+    return ret;
+}
+
+
+#if HAVE_SASL
+/*
+ * Send client->tx using SASL encoding
+ *
+ * The whole remaining raw buffer is encoded in one go; the raw
+ * message is only marked complete once all encoded bytes are sent.
+ *
+ * Returns:
+ *   -1 on error
+ *    0 on EAGAIN
+ *    n number of bytes
+ */
+static int virNetServerClientWriteSASL(virNetServerClientPtr client)
+{
+    int ret;
+
+    /* Not got any pending encoded data, so we need to encode raw stuff */
+    if (client->saslEncoded == NULL) {
+        ret = sasl_encode(client->saslconn,
+                          client->tx->msg.buffer + client->tx->msg.bufferOffset,
+                          client->tx->msg.bufferLength - client->tx->msg.bufferOffset,
+                          &client->saslEncoded,
+                          &client->saslEncodedLength);
+
+        if (ret != SASL_OK) {
+            virNetError(VIR_ERR_INTERNAL_ERROR,
+                        _("failed to encode SASL data %s"),
+                        sasl_errstring(ret, NULL, NULL));
+            virNetServerClientCloseLocked(client);
+            return -1;
+        }
+
+        client->saslEncodedOffset = 0;
+    }
+
+    /* Send some of the encoded stuff out on the wire */
+    ret = virNetServerClientWriteBuf(client,
+                                     client->saslEncoded + client->saslEncodedOffset,
+                                     client->saslEncodedLength - client->saslEncodedOffset);
+
+    if (ret <= 0)
+        return ret; /* -1 error, 0 == EAGAIN */
+
+    /* Note how much we sent */
+    client->saslEncodedOffset += ret;
+
+    /* Sent all encoded, so update raw buffer to indicate completion */
+    if (client->saslEncodedOffset == client->saslEncodedLength) {
+        client->saslEncoded = NULL;
+        client->saslEncodedOffset = client->saslEncodedLength = 0;
+
+        /* Mark as complete, so caller detects completion */
+        client->tx->msg.bufferOffset = client->tx->msg.bufferLength;
+    }
+
+    return ret;
+}
+#endif
+
+/*
+ * Send as much data in the client->tx as possible
+ *
+ * Returns:
+ *   -1 on error or EOF
+ *    0 on EAGAIN
+ *    n number of bytes
+ */
+static ssize_t virNetServerClientWrite(virNetServerClientPtr client)
+{
+#if HAVE_SASL
+    if (client->saslSSF & VIR_NET_CLIENT_SSF_WRITE)
+        return virNetServerClientWriteSASL(client);
+    else
+#endif
+        return virNetServerClientWritePlain(client);
+}
+
+
+/*
+ * Process all queued client->tx messages until
+ * we would block on I/O. Fully sent messages are finished, and the
+ * client is closed after a send if client->closing is set.
+ */
+static void
+virNetServerClientDispatchWrite(virNetServerClientPtr client)
+{
+    while (client->tx) {
+        ssize_t ret;
+
+        ret = virNetServerClientWrite(client);
+        if (ret < 0) {
+            virNetServerClientCloseLocked(client);
+            return;
+        }
+        if (ret == 0)
+            return; /* Would block on write EAGAIN */
+
+        if (client->tx->msg.bufferOffset == client->tx->msg.bufferLength) {
+            virNetServerMessagePtr reply;
+
+            /* Get finished reply from head of tx queue */
+            reply = virNetServerClientMessageQueueServe(&client->tx);
+
+            virNetServerClientFinishMessageLocked(client, reply);
+
+            if (client->closing)
+                virNetServerClientCloseLocked(client);
+        }
+    }
+}
+
+/*
+ * Advance the TLS handshake. virNetTLSSessionHandshake() returning 0
+ * means it finished (the client's access is then checked), > 0 means
+ * more I/O is needed, < 0 is fatal and closes the client.
+ */
+static void
+virNetServerClientDispatchHandshake(virNetServerClientPtr client)
+{
+    int ret;
+    /* Continue the handshake. */
+    ret = virNetTLSSessionHandshake(client->tls);
+    if (ret == 0) {
+        client->handshake = 0;
+
+        /* Finished. Next step is to check the certificate. */
+        if (virNetServerClientCheckAccess(client) < 0)
+            virNetServerClientCloseLocked(client);
+        else
+            virNetServerClientUpdateEvent(client);
+    } else if (ret > 0) {
+        /* Carry on waiting for more handshake. Update
+           the events just in case handshake data flow
+           direction has changed */
+        virNetServerClientUpdateEvent(client);
+    } else {
+#if 0
+        PROBE(CLIENT_TLS_FAIL, "fd=%d", client->fd);
+#endif
+        /* Fatal error in handshake */
+        virNetServerClientCloseLocked(client);
+    }
+}
+
+/*
+ * Event loop callback for the client's socket. Events from a socket
+ * that no longer belongs to this client just unregister the callback.
+ * Otherwise the TLS handshake, or the write and read paths, are driven,
+ * and the client is closed on error or hangup.
+ */
+static void
+virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque)
+{
+    virNetServerClientPtr client = opaque;
+
+    virNetServerClientLock(client);
+
+    if (client->sock != sock) {
+        virNetSocketRemoveIOCallback(sock);
+        virNetServerClientUnlock(client);
+        return;
+    }
+
+    if (events & (VIR_EVENT_HANDLE_WRITABLE |
+                  VIR_EVENT_HANDLE_READABLE)) {
+        if (client->handshake) {
+            virNetServerClientDispatchHandshake(client);
+        } else {
+            if (events & VIR_EVENT_HANDLE_WRITABLE)
+                virNetServerClientDispatchWrite(client);
+            if (events & VIR_EVENT_HANDLE_READABLE)
+                virNetServerClientDispatchRead(client);
+        }
+    }
+
+    /* NB, will get HANGUP + READABLE at same time upon
+     * disconnect */
+    if (events & (VIR_EVENT_HANDLE_ERROR |
+                  VIR_EVENT_HANDLE_HANGUP))
+        virNetServerClientCloseLocked(client);
+
+    virNetServerClientUnlock(client);
+}
+
+
+/*
+ * Queue @msg on the client's transmit queue and refresh the I/O
+ * events. Takes the client lock.
+ */
+void virNetServerClientSendMessage(virNetServerClientPtr client,
+                                   virNetServerMessagePtr msg)
+{
+    virNetServerClientLock(client);
+
+    virNetServerClientMessageQueuePush(&client->tx, msg);
+
+    virNetServerClientUpdateEvent(client);
+
+    virNetServerClientUnlock(client);
+}
+
+/*
+ * Release @msg once it has been fully processed (caller holds the
+ * client lock). Non-async, non-stream messages give back their request
+ * slot. If the receive side was throttled (rx == NULL) and a slot is
+ * now free, @msg is recycled as the next receive buffer instead of
+ * being freed.
+ */
+static void virNetServerClientFinishMessageLocked(virNetServerClientPtr client,
+                                                  virNetServerMessagePtr msg)
+{
+    if (msg->streamTX) {
+#if 0
+        XXX
+        remoteStreamMessageFinished(client, msg);
+#endif
+    } else if (!msg->async)
+        client->nrequests--;
+
+    /* See if the recv queue is currently throttled */
+    if (!client->rx &&
+        client->nrequests < client->nrequests_max) {
+        /* Reset message record for next RX attempt */
+        memset(msg, 0, sizeof(*msg));
+        client->rx = msg;
+        /* Get ready to receive next message */
+        client->rx->msg.bufferLength = VIR_NET_MESSAGE_LEN_MAX;
+    } else {
+        VIR_FREE(msg);
+    }
+
+    virNetServerClientUpdateEvent(client);
+}
+
+/* Locking wrapper around virNetServerClientFinishMessageLocked() */
+void virNetServerClientFinishMessage(virNetServerClientPtr client,
+                                     virNetServerMessagePtr msg)
+{
+    virNetServerClientLock(client);
+    virNetServerClientFinishMessageLocked(client, msg);
+    virNetServerClientUnlock(client);
+}
+
+/*
+ * Returns true if client->auth is non-zero, i.e. presumably the client
+ * has not yet completed authentication. Takes the client lock.
+ */
+bool virNetServerClientNeedAuth(virNetServerClientPtr client)
+{
+    bool need = false;
+    virNetServerClientLock(client);
+    if (client->auth)
+        need = true;
+    virNetServerClientUnlock(client);
+    return need;
+}
diff --git a/src/rpc/virnetserverclient.h b/src/rpc/virnetserverclient.h
new file mode 100644
index 0000000..5fe239c
--- /dev/null
+++ b/src/rpc/virnetserverclient.h
@@ -0,0 +1,40 @@
+
+
+#ifndef __VIR_NET_SERVER_CLIENT_H__
+#define __VIR_NET_SERVER_CLIENT_H__
+
+#include "virnetsocket.h"
+#include "virnetserverprogram.h"
+#include "virnettlscontext.h"
+
+typedef int (*virNetServerClientDispatchFunc)(virNetServerClientPtr client,
+                                              virNetServerMessagePtr msg,
+                                              void *opaque);
+
+virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock,
+                                            int auth,
+                                            virNetTLSContextPtr tls);
+
+void virNetServerClientRef(virNetServerClientPtr client);
+
+void virNetServerClientSetDispatcher(virNetServerClientPtr client,
+                                     virNetServerClientDispatchFunc func,
+                                     void *opaque);
+void virNetServerClientClose(virNetServerClientPtr client);
+
+int virNetServerClientInit(virNetServerClientPtr client);
+
+const char *virNetServerClientAddrString(virNetServerClientPtr client);
+
+void virNetServerClientSendMessage(virNetServerClientPtr client,
+                                   virNetServerMessagePtr msg);
+
+void virNetServerClientFinishMessage(virNetServerClientPtr client,
+                                     virNetServerMessagePtr msg);
+
+bool virNetServerClientNeedAuth(virNetServerClientPtr client);
+
+void virNetServerClientFree(virNetServerClientPtr client);
+
+
+#endif /* __VIR_NET_SERVER_CLIENT_H__ */
diff --git a/src/rpc/virnetservermessage.h b/src/rpc/virnetservermessage.h
new file mode 100644
index 0000000..88624fe
--- /dev/null
+++ b/src/rpc/virnetservermessage.h
@@ -0,0 +1,20 @@
+
+#ifndef __VIR_NET_SERVER_MESSAGE_H__
+#define __VIR_NET_SERVER_MESSAGE_H__
+
+#include "virnetmessage.h"
+
+typedef struct _virNetServerMessage virNetServerMessage;
+typedef virNetServerMessage *virNetServerMessagePtr;
+
+struct _virNetServerMessage {
+    virNetMessage msg;
+
+    unsigned int async : 1;
+    unsigned int streamTX : 1;
+
+    virNetServerMessagePtr next;
+};
+
+#endif /* __VIR_NET_SERVER_MESSAGE_H__ */
+
diff --git a/src/rpc/virnetserverprogram.c b/src/rpc/virnetserverprogram.c
new file mode 100644
index 0000000..59328ac
--- /dev/null
+++ b/src/rpc/virnetserverprogram.c
@@ -0,0 +1,437 @@
+
+
+#include <config.h>
+
+#include "virnetserverprogram.h"
+#include "virnetserverclient.h"
+
+#include "memory.h"
+#include "virterror_internal.h"
+#include "logging.h"
+
+#define VIR_FROM_THIS VIR_FROM_RPC
+
+struct _virNetServerProgram {
+    int refs;
+
+    unsigned program;
+    unsigned version;
+    virNetServerProgramProcPtr procs;
+    size_t nprocs;
+    virNetServerProgramErrorHanderPtr err;
+};
+
+virNetServerProgramPtr virNetServerProgramNew(unsigned program,
+                                              unsigned version,
+                                              virNetServerProgramProcPtr procs,
+                                              size_t nprocs,
+                                              virNetServerProgramErrorHanderPtr err)
+{
+    virNetServerProgramPtr prog;
+
+    if (VIR_ALLOC(prog) < 0) {
+        virReportOOMError();
+        return NULL;
+    }
+
+    prog->refs = 1;
+    prog->program = program;
+    prog->version = version;
+    prog->procs = procs;
+    prog->nprocs = nprocs;
+    prog->err = err;
+
+    return prog;
+}
+
+
+/* Take an additional reference; release it with virNetServerProgramFree() */
+void virNetServerProgramRef(virNetServerProgramPtr prog)
+{
+    prog->refs++;
+}
+
+
+/*
+ * Returns 1 if @msg's header addresses this program's (program, version)
+ * pair, 0 otherwise.
+ */
+int virNetServerProgramMatches(virNetServerProgramPtr prog,
+                               virNetServerMessagePtr msg)
+{
+    if (prog->program == msg->msg.header.prog &&
+        prog->version == msg->msg.header.vers)
+        return 1;
+    return 0;
+}
+
+
+/*
+ * Look up the handler entry for @procedure.
+ *
+ * Returns NULL if @procedure is outside the table, or if its slot has
+ * no dispatch function.
+ */
+static virNetServerProgramProcPtr virNetServerProgramGetProc(virNetServerProgramPtr prog,
+                                                             int procedure)
+{
+    virNetServerProgramProcPtr proc;
+
+    if (procedure < 0)
+        return NULL;
+    if ((size_t)procedure >= prog->nprocs)
+        return NULL;
+
+    proc = &prog->procs[procedure];
+
+    /* An empty slot must be reported as an unknown procedure rather
+     * than dispatched through a NULL function pointer */
+    if (!proc->func)
+        return NULL;
+
+    return proc;
+}
+
+/* Fill @rerr with @code and message @str via the program's error handler */
+static void
+virNetServerProgramStringError(virNetServerProgramPtr prog,
+                               void *rerr,
+                               int code,
+                               const char *str)
+{
+    prog->err->func(prog, rerr, code, str);
+}
+
+/*
+ * printf-style variant of virNetServerProgramStringError(). The
+ * formatted message is truncated to fit a 1024 byte buffer.
+ */
+static void
+virNetServerProgramFormatError(virNetServerProgramPtr prog,
+                               void *rerr,
+                               int code,
+                               const char *fmt,
+                               ...)
+{
+    va_list args;
+    char msgbuf[1024];
+
+    va_start(args, fmt);
+    vsnprintf(msgbuf, sizeof msgbuf, fmt, args);
+    va_end(args);
+
+    virNetServerProgramStringError(prog, rerr, code, msgbuf);
+}
+
+
+#if 0
+static void
+virNetServerProgramGenericError(virNetServerProgramPtr prog,
+                                void *rerr)
+{
+    virNetServerProgramStringError(prog, rerr,
+                                   VIR_ERR_INTERNAL_ERROR,
+                                   _("function returned error code but did not set an error message"));
+}
+#endif
+
+
+/* Fill @rerr with an out-of-memory error */
+static void
+virNetServerProgramOOMError(virNetServerProgramPtr prog,
+                            void *rerr)
+{
+    virNetServerProgramStringError(prog, rerr,
+                                   VIR_ERR_NO_MEMORY,
+                                   _("out of memory"));
+}
+
+
+/*
+ * Build a new message with the given header fields carrying the error
+ * object @rerr (status VIR_NET_ERROR) and queue it for sending to
+ * @client.
+ *
+ * The contents of @rerr are released with xdr_free() on every path;
+ * the @rerr buffer itself remains owned by the caller.
+ *
+ * Returns 0 if the error was queued, -1 on failure
+ */
+static int
+virNetServerProgramSerializeError(virNetServerProgramPtr prog,
+                                  virNetServerClientPtr client,
+                                  void *rerr,
+                                  int program,
+                                  int version,
+                                  int procedure,
+                                  int type,
+                                  int serial)
+{
+    virNetServerMessagePtr msg = NULL;
+
+    DEBUG("prog=%d ver=%d proc=%d type=%d serial=%d, rerr=%p",
+          program, version, procedure, type, serial, rerr);
+
+    if (VIR_ALLOC(msg) < 0) {
+        virReportOOMError();
+        /* Honour the contract above even when nothing could be sent */
+        xdr_free(prog->err->filter, rerr);
+        return -1;
+    }
+
+
+    /* Return header. */
+    msg->msg.header.prog = program;
+    msg->msg.header.vers = version;
+    msg->msg.header.proc = procedure;
+    msg->msg.header.type = type;
+    msg->msg.header.serial = serial;
+    msg->msg.header.status = VIR_NET_ERROR;
+
+    if (virNetMessageEncodeHeader(&msg->msg) < 0)
+        goto error;
+
+    if (virNetMessageEncodePayload(&msg->msg, prog->err->filter, rerr) < 0)
+        goto error;
+
+    /* Put reply on end of tx queue to send out */
+    virNetServerClientSendMessage(client, msg);
+    xdr_free(prog->err->filter, rerr);
+
+    return 0;
+
+error:
+    VIR_WARN("Failed to serialize remote error '%p'", rerr);
+    VIR_FREE(msg);
+    xdr_free(prog->err->filter, rerr);
+    return -1;
+}
+
+
+/*
+ * @prog: the program whose error handler encodes @rerr
+ * @client: the client to send the error to
+ * @rerr: the error object to send
+ * @req: the message this error is in reply to
+ *
+ * Send an error message to the client
+ *
+ * Returns 0 if the error was sent, -1 upon fatal error
+ */
+static int
+virNetServerProgramSerializeReplyError(virNetServerProgramPtr prog,
+                                       virNetServerClientPtr client,
+                                       void *rerr,
+                                       virNetMessageHeaderPtr req)
+{
+    /*
+     * For data streams, errors are sent back as data streams
+     * For method calls, errors are sent back as method replies
+     */
+    return virNetServerProgramSerializeError(prog,
+                                             client,
+                                             rerr,
+                                             req->prog,
+                                             req->vers,
+                                             req->proc,
+                                             req->type == VIR_NET_STREAM ? VIR_NET_STREAM : VIR_NET_REPLY,
+                                             req->serial);
+}
+
+static int
+virNetServerProgramDispatchCall(virNetServerProgramPtr prog,
+                                void *rerr,
+                                virNetServerPtr server,
+                                virNetServerClientPtr client,
+                                virNetServerMessagePtr msg);
+
+/*
+ * @prog: the program the message is addressed to
+ * @server: the unlocked server object
+ * @client: the client object. This function may call
+ *          virNetServerClientFinishMessage()/SendMessage(), which take
+ *          the client lock, so presumably the caller must not hold it
+ * @msg: the complete incoming message packet, with header already decoded
+ *
+ * Process a complete incoming remote protocol message. Calls are
+ * dispatched to their handler, stray stream packets are dropped, and
+ * anything else is answered with an RPC error.
+ *
+ * The @msg parameter must have had its header decoded already by
+ * calling virNetMessageDecodeHeader
+ *
+ * Returns 0 if the message was dispatched, -1 upon fatal error
+ */
+int virNetServerProgramDispatch(virNetServerProgramPtr prog,
+                                virNetServerPtr server,
+                                virNetServerClientPtr client,
+                                virNetServerMessagePtr msg)
+{
+    int ret = -1;
+    char *rerr;
+
+    if (VIR_ALLOC_N(rerr, prog->err->len) < 0) {
+        virReportOOMError();
+        return -1;
+    }
+
+    DEBUG("prog=%d ver=%d type=%d status=%d serial=%d proc=%d",
+          msg->msg.header.prog, msg->msg.header.vers, msg->msg.header.type,
+          msg->msg.header.status, msg->msg.header.serial, msg->msg.header.proc);
+
+    /* Check version, etc. */
+    if (msg->msg.header.prog != prog->program) {
+        virNetServerProgramFormatError(prog, rerr, VIR_ERR_RPC,
+                                       _("program mismatch (actual %x, expected %x)"),
+                                       msg->msg.header.prog, prog->program);
+        goto error;
+    }
+
+    if (msg->msg.header.vers != prog->version) {
+        virNetServerProgramFormatError(prog, rerr, VIR_ERR_RPC,
+                                       _("version mismatch (actual %x, expected %x)"),
+                                       msg->msg.header.vers, prog->version);
+        goto error;
+    }
+
+    switch (msg->msg.header.type) {
+    case VIR_NET_CALL:
+        ret = virNetServerProgramDispatchCall(prog, rerr, server, client, msg);
+        break;
+
+    case VIR_NET_STREAM:
+        /* Since stream data is non-acked, async, we may continue to receive
+         * stream packets after we closed down a stream. Just drop & ignore
+         * these.
+         */
+        VIR_INFO("Ignoring unexpected stream data serial=%d proc=%d status=%d",
+                 msg->msg.header.serial, msg->msg.header.proc, msg->msg.header.status);
+        virNetServerClientFinishMessage(client, msg);
+        ret = 0;
+        break;
+
+    default:
+        virNetServerProgramFormatError(prog, rerr, VIR_ERR_RPC,
+                                       _("Unexpected message type %d"),
+                                       (int)msg->msg.header.type);
+        goto error;
+    }
+
+    VIR_FREE(rerr);
+
+    return ret;
+
+error:
+    ret = virNetServerProgramSerializeReplyError(prog, client, rerr, &msg->msg.header);
+
+    if (ret >= 0)
+        VIR_FREE(msg);
+
+    VIR_FREE(rerr);
+
+    return ret;
+}
+
+
+/*
+ * @prog: the program the call is addressed to
+ * @rerr: error object buffer of prog->err->len bytes (cleared here)
+ * @server: the unlocked server object
+ * @client: the unlocked client object
+ * @msg: the complete incoming method call, with header already decoded
+ *
+ * This method is used to dispatch a message representing an
+ * incoming method call from a client. It decodes the payload
+ * to obtain method call arguments, invokes the method and
+ * then sends a reply packet with the return values. The same
+ * message object is reused for the reply.
+ *
+ * Returns 0 if the reply was sent, or -1 upon fatal error
+ */
+static int
+virNetServerProgramDispatchCall(virNetServerProgramPtr prog,
+                                void *rerr,
+                                virNetServerPtr server,
+                                virNetServerClientPtr client,
+                                virNetServerMessagePtr msg)
+{
+    char *arg = NULL;
+    char *ret = NULL;
+    int rv = -1;
+    virNetServerProgramProcPtr dispatcher;
+
+    /* rerr is a buffer of prog->err->len bytes; sizeof(rerr) would
+     * only clear as many bytes as a pointer occupies */
+    memset(rerr, 0, prog->err->len);
+
+    if (msg->msg.header.status != VIR_NET_OK) {
+        virNetServerProgramFormatError(prog, rerr, VIR_ERR_RPC,
+                                       _("Unexpected message status %d"),
+                                       (int)msg->msg.header.status);
+        goto error;
+    }
+
+    dispatcher = virNetServerProgramGetProc(prog, msg->msg.header.proc);
+
+    if (!dispatcher) {
+        virNetServerProgramFormatError(prog, rerr, VIR_ERR_RPC,
+                                       _("unknown procedure: %d"),
+                                       msg->msg.header.proc);
+        goto error;
+    }
+
+    /* If client is marked as needing auth, don't allow any RPC ops
+     * except for the authentication ones (those with needAuth unset)
+     */
+    if (virNetServerClientNeedAuth(client) &&
+        dispatcher->needAuth) {
+        /* Explicitly *NOT* calling remoteDispatchAuthError() because
+           we want back-compatibility with libvirt clients which don't
+           support the VIR_ERR_AUTH_FAILED error code */
+        virNetServerProgramFormatError(prog, rerr, VIR_ERR_RPC,
+                                       "%s", _("authentication required"));
+        goto error;
+    }
+
+    if (VIR_ALLOC_N(arg, dispatcher->arg_len) < 0) {
+        virNetServerProgramOOMError(prog, rerr);
+        goto error;
+    }
+    if (VIR_ALLOC_N(ret, dispatcher->ret_len) < 0) {
+        virNetServerProgramOOMError(prog, rerr);
+        goto error;
+    }
+
+    if (virNetMessageDecodePayload(&msg->msg, dispatcher->arg_filter, arg) < 0)
+        goto error;
+
+    /*
+     * When the RPC handler is called:
+     *
+     *  - Server object is unlocked
+     *  - Client object is unlocked
+     *
+     * Without locking, it is safe to use:
+     *
+     *   'rerr', 'arg' and 'ret'
+     */
+    rv = (dispatcher->func)(server, client, &msg->msg.header, rerr, arg, ret);
+
+    xdr_free(dispatcher->arg_filter, arg);
+
+    if (rv < 0)
+        goto error;
+
+    /* Return header. We're re-using same message object, so only
+     * need to tweak type/status fields; prog, vers, proc and serial
+     * are echoed back unchanged */
+    msg->msg.header.type = VIR_NET_REPLY;
+    msg->msg.header.status = VIR_NET_OK;
+
+    if (virNetMessageEncodeHeader(&msg->msg) < 0) {
+        xdr_free(dispatcher->ret_filter, ret);
+        goto error;
+    }
+
+    if (virNetMessageEncodePayload(&msg->msg, dispatcher->ret_filter, ret) < 0) {
+        xdr_free(dispatcher->ret_filter, ret);
+        goto error;
+    }
+
+    /* The reply is now serialized into msg's buffer, so release any
+     * data the handler allocated inside 'ret' */
+    xdr_free(dispatcher->ret_filter, ret);
+
+    /* NB: no buffer offset/length reset here. Like the error path in
+     * virNetServerProgramSerializeError(), this relies on
+     * virNetMessageEncodePayload() leaving the buffer ready for
+     * transmission - TODO confirm against virnetmessage.c */
+
+    VIR_FREE(arg);
+    VIR_FREE(ret);
+
+    /* Put reply on end of tx queue to send out */
+    virNetServerClientSendMessage(client, msg);
+
+    return 0;
+
+error:
+    /* Bad stuff (de-)serializing message, but we have an
+     * RPC error message we can send back to the client */
+    rv = virNetServerProgramSerializeReplyError(prog, client, rerr, &msg->msg.header);
+
+    if (rv >= 0)
+        VIR_FREE(msg);
+
+    VIR_FREE(arg);
+    VIR_FREE(ret);
+
+    return rv;
+}
+
+
+/*
+ * Drop a reference to @prog, freeing it when the last one goes.
+ * The procs and err tables passed to virNetServerProgramNew() are
+ * not freed here.
+ */
+void virNetServerProgramFree(virNetServerProgramPtr prog)
+{
+    if (!prog)
+        return;
+
+    prog->refs--;
+    if (prog->refs > 0)
+        return;
+
+    VIR_FREE(prog);
+}
+
+
diff --git a/src/rpc/virnetserverprogram.h b/src/rpc/virnetserverprogram.h
new file mode 100644
index 0000000..d116867
--- /dev/null
+++ b/src/rpc/virnetserverprogram.h
@@ -0,0 +1,76 @@
+
+#ifndef __VIR_NET_SERVER_PROGRAM_H__
+#define __VIR_NET_SERVER_PROGRAM_H__
+
+#include <stdbool.h>
+
+#include "virnetservermessage.h"
+
+typedef struct _virNetServer virNetServer;
+typedef virNetServer *virNetServerPtr;
+
+typedef struct _virNetServerClient virNetServerClient;
+typedef virNetServerClient *virNetServerClientPtr;
+
+typedef struct _virNetServerService virNetServerService;
+typedef virNetServerService *virNetServerServicePtr;
+
+typedef struct _virNetServerProgram virNetServerProgram;
+typedef virNetServerProgram *virNetServerProgramPtr;
+
+typedef struct _virNetServerProgramProc virNetServerProgramProc;
+typedef virNetServerProgramProc *virNetServerProgramProcPtr;
+
+typedef struct _virNetServerProgramErrorHandler virNetServerProgramErrorHander;
+typedef virNetServerProgramErrorHander *virNetServerProgramErrorHanderPtr;
+
+typedef int (*virNetServerProgramErrorFunc)(virNetServerProgramPtr prog,
+                                            void *rerr,
+                                            int code,
+                                            const char *msg);
+
+struct _virNetServerProgramErrorHandler {
+    virNetServerProgramErrorFunc func;
+    size_t len;
+    xdrproc_t filter;
+};
+
+
+typedef int (*virNetServerProgramDispatchFunc)(virNetServerPtr server,
+                                               virNetServerClientPtr client,
+                                               virNetMessageHeader *hdr,
+                                               void *err,
+                                               void *args,
+                                               void *ret);
+
+struct _virNetServerProgramProc {
+    virNetServerProgramDispatchFunc func;
+    size_t arg_len;
+    xdrproc_t arg_filter;
+    size_t ret_len;
+    xdrproc_t ret_filter;
+    bool needAuth;
+};
+
+virNetServerProgramPtr virNetServerProgramNew(unsigned program,
+                                              unsigned version,
+                                              virNetServerProgramProcPtr procs,
+                                              size_t nprocs,
+                                              virNetServerProgramErrorHanderPtr err);
+
+void virNetServerProgramRef(virNetServerProgramPtr prog);
+
+int virNetServerProgramMatches(virNetServerProgramPtr prog,
+                               virNetServerMessagePtr msg);
+
+int virNetServerProgramDispatch(virNetServerProgramPtr prog,
+                                virNetServerPtr server,
+                                virNetServerClientPtr client,
+                                virNetServerMessagePtr msg);
+
+void virNetServerProgramFree(virNetServerProgramPtr prog);
+
+
+
+
+#endif /* __VIR_NET_SERVER_PROGRAM_H__ */
diff --git a/src/rpc/virnetserverservice.c b/src/rpc/virnetserverservice.c
new file mode 100644
index 0000000..f985603
--- /dev/null
+++ b/src/rpc/virnetserverservice.c
@@ -0,0 +1,235 @@
+
+#include <config.h>
+
+
+#include "virnetserverservice.h"
+
+#include "memory.h"
+#include "virterror_internal.h"
+
+
+#define VIR_FROM_THIS VIR_FROM_RPC
+
+struct _virNetServerService {
+    int refs;
+
+    size_t nsocks;
+    virNetSocketPtr *socks;
+
+    int auth;
+
+    virNetTLSContextPtr tls;
+
+    virNetServerServiceDispatchFunc dispatchFunc;
+    void *dispatchOpaque;
+};
+
+
+
+/*
+ * I/O callback for a listening socket: accept the pending connection,
+ * wrap it in a new client object and hand it to the service dispatcher.
+ */
+static void virNetServerServiceAccept(virNetSocketPtr sock,
+                                      int events ATTRIBUTE_UNUSED,
+                                      void *opaque)
+{
+    virNetServerServicePtr svc = opaque;
+    virNetServerClientPtr client = NULL;
+    virNetSocketPtr clientsock = NULL;
+
+    if (virNetSocketAccept(sock, &clientsock) < 0)
+        goto error;
+
+    if (!clientsock) /* Connection already went away */
+        goto cleanup;
+
+    /* Check this before creating the client: bailing out afterwards
+     * would leak the client while freeing the socket it still uses */
+    if (!svc->dispatchFunc)
+        goto error;
+
+    if (!(client = virNetServerClientNew(clientsock,
+                                         svc->auth,
+                                         svc->tls)))
+        goto error;
+
+    svc->dispatchFunc(svc, client, svc->dispatchOpaque);
+
+    virNetServerClientFree(client);
+
+cleanup:
+    return;
+
+error:
+    virNetSocketFree(clientsock);
+}
+
+
+/*
+ * Create a service listening on @nodename / @service. Every resulting
+ * socket is registered with its I/O callback disabled; call
+ * virNetServerServiceToggle() to start accepting clients.
+ *
+ * Returns the new service, or NULL on error
+ */
+virNetServerServicePtr virNetServerServiceNewTCP(const char *nodename,
+                                                 const char *service,
+                                                 int auth,
+                                                 virNetTLSContextPtr tls)
+{
+    virNetServerServicePtr svc;
+    size_t i;
+
+    if (VIR_ALLOC(svc) < 0)
+        goto no_memory;
+
+    svc->refs = 1;
+    svc->auth = auth;
+    svc->tls = tls;
+    if (tls)
+        virNetTLSContextRef(tls);
+
+    if (virNetSocketNewListenTCP(nodename,
+                                 service,
+                                 &svc->socks,
+                                 &svc->nsocks) < 0)
+        goto error;
+
+    for (i = 0 ; i < svc->nsocks ; i++) {
+        if (virNetSocketListen(svc->socks[i]) < 0)
+            goto error;
+
+        /* IO callback is initially disabled, until we're ready
+         * to deal with incoming clients */
+        if (virNetSocketAddIOCallback(svc->socks[i],
+                                      0,
+                                      virNetServerServiceAccept,
+                                      svc) < 0)
+            goto error;
+    }
+
+
+    return svc;
+
+no_memory:
+    virReportOOMError();
+error:
+    virNetServerServiceFree(svc);
+    return NULL;
+}
+
+
+/*
+ * Create a service listening on the UNIX socket @path; @mask and @grp
+ * are passed through to virNetSocketNewListenUNIX(). As for the TCP
+ * variant, the I/O callback starts out disabled.
+ *
+ * Returns the new service, or NULL on error
+ */
+virNetServerServicePtr virNetServerServiceNewUNIX(const char *path,
+                                                  mode_t mask,
+                                                  gid_t grp,
+                                                  int auth,
+                                                  virNetTLSContextPtr tls)
+{
+    virNetServerServicePtr svc;
+    size_t i;
+
+    if (VIR_ALLOC(svc) < 0)
+        goto no_memory;
+
+    svc->refs = 1;
+    svc->auth = auth;
+    svc->tls = tls;
+    if (tls)
+        virNetTLSContextRef(tls);
+
+    svc->nsocks = 1;
+    if (VIR_ALLOC_N(svc->socks, svc->nsocks) < 0)
+        goto no_memory;
+
+    if (virNetSocketNewListenUNIX(path,
+                                  mask,
+                                  grp,
+                                  &svc->socks[0]) < 0)
+        goto error;
+
+    for (i = 0 ; i < svc->nsocks ; i++) {
+        if (virNetSocketListen(svc->socks[i]) < 0)
+            goto error;
+
+        /* IO callback is initially disabled, until we're ready
+         * to deal with incoming clients */
+        if (virNetSocketAddIOCallback(svc->socks[i],
+                                      0,
+                                      virNetServerServiceAccept,
+                                      svc) < 0)
+            goto error;
+    }
+
+
+    return svc;
+
+no_memory:
+    virReportOOMError();
+error:
+    virNetServerServiceFree(svc);
+    return NULL;
+}
+
+
+/* Take an additional reference; release it with virNetServerServiceFree() */
+void virNetServerServiceRef(virNetServerServicePtr svc)
+{
+    svc->refs++;
+}
+
+
+/* Set the callback that receives each newly accepted client */
+void virNetServerServiceSetDispatcher(virNetServerServicePtr svc,
+                                      virNetServerServiceDispatchFunc func,
+                                      void *opaque)
+{
+    svc->dispatchFunc = func;
+    svc->dispatchOpaque = opaque;
+}
+
+
+/*
+ * Drop a reference to @svc. When the last one goes, the listening
+ * sockets and the TLS context reference are released too.
+ */
+void virNetServerServiceFree(virNetServerServicePtr svc)
+{
+    size_t i;
+
+    if (!svc)
+        return;
+
+    svc->refs--;
+    if (svc->refs > 0)
+        return;
+
+    for (i = 0 ; i < svc->nsocks ; i++)
+        virNetSocketFree(svc->socks[i]);
+    VIR_FREE(svc->socks);
+
+    virNetTLSContextFree(svc->tls);
+
+    VIR_FREE(svc);
+}
+
+/* Enable or disable accepting new clients on all listening sockets */
+void virNetServerServiceToggle(virNetServerServicePtr svc,
+                               bool enabled)
+{
+    size_t i;
+
+    for (i = 0 ; i < svc->nsocks ; i++)
+        virNetSocketUpdateIOCallback(svc->socks[i],
+                                     enabled ?
+                                     VIR_EVENT_HANDLE_READABLE :
+                                     0);
+}
+
diff --git a/src/rpc/virnetserverservice.h b/src/rpc/virnetserverservice.h
new file mode 100644
index 0000000..4bedde7
--- /dev/null
+++ b/src/rpc/virnetserverservice.h
@@ -0,0 +1,32 @@
+
+#ifndef __VIR_NET_SERVER_SERVICE_H__
+#define __VIR_NET_SERVER_SERVICE_H__
+
+#include "virnetserverclient.h"
+
+typedef int (*virNetServerServiceDispatchFunc)(virNetServerServicePtr svc,
+                                               virNetServerClientPtr client,
+                                               void *opaque);
+
+virNetServerServicePtr virNetServerServiceNewTCP(const char *nodename,
+                                                 const char *service,
+                                                 int auth,
+                                                 virNetTLSContextPtr tls);
+virNetServerServicePtr virNetServerServiceNewUNIX(const char *path,
+                                                  mode_t mask,
+                                                  gid_t grp,
+                                                  int auth,
+                                                  virNetTLSContextPtr tls);
+
+void virNetServerServiceRef(virNetServerServicePtr svc);
+
+void virNetServerServiceSetDispatcher(virNetServerServicePtr svc,
+                                      virNetServerServiceDispatchFunc func,
+                                      void *opaque);
+
+void virNetServerServiceFree(virNetServerServicePtr svc);
+
+void virNetServerServiceToggle(virNetServerServicePtr svc,
+                               bool enabled);
+
+#endif /* __VIR_NET_SERVER_SERVICE_H__ */
-- 
1.7.2.3

--
libvir-list mailing list
libvir-list@xxxxxxxxxx
https://www.redhat.com/mailman/listinfo/libvir-list