Wraps the deserialization of the received messages in an InboundMessage
class. The class is created with the deserialized header and the raw
data of the message. A template function get_payload() returns the
struct of the concrete message; it is specialized for each incoming
message type, and the specializations that read the payload validate it
against the size announced in the header before accessing it.

This leaves it to the caller to call get_payload() with the payload type
that matches the message type in the header. In exchange, the solution
preserves the efficiency of the original implementation without adding
much complexity to separate the message parsing from the message
handling.

Signed-off-by: Lukáš Hrázký <lhrazky@xxxxxxxxxx>
---
 src/spice-streaming-agent.cpp | 115 +++++++++-------------------------
 src/stream-port.cpp           |  78 ++++++++++++++++++++++-
 src/stream-port.hpp           |  41 +++++++++++-
 3 files changed, 147 insertions(+), 87 deletions(-)

diff --git a/src/spice-streaming-agent.cpp b/src/spice-streaming-agent.cpp
index a9baf4d..a89ba3f 100644
--- a/src/spice-streaming-agent.cpp
+++ b/src/spice-streaming-agent.cpp
@@ -77,92 +77,39 @@ static bool have_something_to_read(StreamPort &stream_port, bool blocking)
     return false;
 }
 
-static void handle_stream_start_stop(StreamPort &stream_port, uint32_t len)
-{
-    uint8_t msg[256];
-
-    if (len >= sizeof(msg)) {
-        throw std::runtime_error("msg size (" + std::to_string(len) + ") is too long "
-                                 "(longer than " + std::to_string(sizeof(msg)) + ")");
-    }
-
-    stream_port.read(msg, len);
-    streaming_requested = (msg[0] != 0); /* num_codecs */
-    syslog(LOG_INFO, "GOT START_STOP message -- request to %s streaming",
-           streaming_requested ? "START" : "STOP");
-    client_codecs.clear();
-    for (int i = 1; i <= msg[0]; ++i) {
-        client_codecs.insert((SpiceVideoCodecType) msg[i]);
-    }
-}
-
-static void handle_stream_capabilities(StreamPort &stream_port, uint32_t len)
-{
-    uint8_t caps[STREAM_MSG_CAPABILITIES_MAX_BYTES];
-
-    if (len > sizeof(caps)) {
-        throw std::runtime_error("capability message too long");
-    }
-
-    stream_port.read(caps, len);
-    // we currently do not support extensions so just reply so
-    StreamDevHeader hdr = {
-        STREAM_DEVICE_PROTOCOL,
-        0,
-        STREAM_TYPE_CAPABILITIES,
-        0
-    };
-
-    stream_port.write(&hdr, sizeof(hdr));
-}
-
-static void handle_stream_error(StreamPort &stream_port, size_t len)
-{
-    if (len < sizeof(StreamMsgNotifyError)) {
-        throw std::runtime_error("Received NotifyError message size " + std::to_string(len) +
-                                 " is too small (smaller than " +
-                                 std::to_string(sizeof(StreamMsgNotifyError)) + ")");
-    }
-
-    struct StreamMsgNotifyError1K : StreamMsgNotifyError {
-        uint8_t msg[1024];
-    } msg;
-
-    size_t len_to_read = std::min(len, sizeof(msg) - 1);
-
-    stream_port.read(&msg, len_to_read);
-    msg.msg[len_to_read - sizeof(StreamMsgNotifyError)] = '\0';
-
-    syslog(LOG_ERR, "Received NotifyError message from the server: %d - %s",
-           msg.error_code, msg.msg);
-
-    if (len_to_read < len) {
-        throw std::runtime_error("Received NotifyError message size " + std::to_string(len) +
-                                 " is too big (bigger than " + std::to_string(sizeof(msg)) + ")");
-    }
-}
-
 static void read_command_from_device(StreamPort &stream_port)
 {
-    StreamDevHeader hdr;
-
-    std::lock_guard<std::mutex> guard(stream_port.mutex);
-    stream_port.read(&hdr, sizeof(hdr));
-
-    if (hdr.protocol_version != STREAM_DEVICE_PROTOCOL) {
-        throw std::runtime_error("BAD VERSION " + std::to_string(hdr.protocol_version) +
-                                 " (expected is " + std::to_string(STREAM_DEVICE_PROTOCOL) + ")");
-    }
-
-    switch (hdr.type) {
-    case STREAM_TYPE_CAPABILITIES:
-        return handle_stream_capabilities(stream_port, hdr.size);
-    case STREAM_TYPE_NOTIFY_ERROR:
-        return handle_stream_error(stream_port, hdr.size);
-    case STREAM_TYPE_START_STOP:
-        return handle_stream_start_stop(stream_port, hdr.size);
-    }
-    throw std::runtime_error("UNKNOWN msg of type " + std::to_string(hdr.type));
+    InboundMessage in_message = stream_port.receive();
+
+    switch (in_message.header.type) {
+    case STREAM_TYPE_CAPABILITIES: {
+        StreamDevHeader hdr = {
+            STREAM_DEVICE_PROTOCOL,
+            0,
+            STREAM_TYPE_CAPABILITIES,
+            0
+        };
+
+        std::lock_guard<std::mutex> guard(stream_port.mutex);
+        stream_port.write(&hdr, sizeof(hdr));
+        return;
+    } case STREAM_TYPE_NOTIFY_ERROR: {
+        NotifyErrorMessage msg = in_message.get_payload<NotifyErrorMessage>();
+
+        syslog(LOG_ERR, "Received NotifyError message from the server: %d - %s",
+               msg.error_code, msg.message);
+        return;
+    } case STREAM_TYPE_START_STOP: {
+        StartStopMessage msg = in_message.get_payload<StartStopMessage>();
+        streaming_requested = msg.start_streaming;
+        client_codecs = msg.client_codecs;
+
+        syslog(LOG_INFO, "GOT START_STOP message -- request to %s streaming",
+               streaming_requested ? "START" : "STOP");
+        return;
+    }}
+
+    throw std::runtime_error("UNKNOWN msg of type " + std::to_string(in_message.header.type));
 }
 
 static void read_command(StreamPort &stream_port, bool blocking)
diff --git a/src/stream-port.cpp b/src/stream-port.cpp
index 5528854..56747fd 100644
--- a/src/stream-port.cpp
+++ b/src/stream-port.cpp
@@ -19,6 +19,66 @@
 namespace spice {
 namespace streaming_agent {
 
+InboundMessage::InboundMessage(const StreamDevHeader &header, std::unique_ptr<uint8_t[]> &&data) :
+    header(header),
+    data(std::move(data))
+{}
+
+template<>
+StartStopMessage InboundMessage::get_payload<StartStopMessage>()
+{
+    // The payload is num_codecs (1 byte) followed by num_codecs codec bytes.
+    // Validate it against the size from the header so a malformed message
+    // cannot make us read past the end of the data buffer.
+    if (header.size < 1 || header.size - 1 < data[0]) {
+        throw std::runtime_error("Received StartStop message size " + std::to_string(header.size) +
+                                 " is too small for its codec list");
+    }
+
+    StartStopMessage msg;
+
+    msg.start_streaming = data[0]; // num_codecs
+
+    for (size_t i = 1; i <= data[0]; ++i) {
+        msg.client_codecs.insert((SpiceVideoCodecType) data[i]);
+    }
+
+    return msg;
+}
+
+template<>
+InCapabilitiesMessage InboundMessage::get_payload<InCapabilitiesMessage>()
+{
+    // no capabilities yet
+    return InCapabilitiesMessage();
+}
+
+template<>
+NotifyErrorMessage InboundMessage::get_payload<NotifyErrorMessage>()
+{
+    if (header.size < sizeof(StreamMsgNotifyError)) {
+        throw std::runtime_error("Received NotifyError message size " + std::to_string(header.size) +
+                                 " is too small (smaller than " +
+                                 std::to_string(sizeof(StreamMsgNotifyError)) + ")");
+    }
+
+    size_t msg_len = header.size - sizeof(StreamMsgNotifyError);
+    if (msg_len > 1024) {
+        throw std::runtime_error("Received NotifyError message is too long (" +
+                                 std::to_string(msg_len) + " > 1024)");
+    }
+
+    StreamMsgNotifyError *raw_message = reinterpret_cast<StreamMsgNotifyError*>(data.get());
+
+    NotifyErrorMessage msg;
+    msg.error_code = raw_message->error_code;
+    strncpy(msg.message, reinterpret_cast<char*>(raw_message->msg), msg_len);
+    // make sure the string is terminated
+    msg.message[msg_len] = '\0';
+
+    return msg;
+}
+
 StreamPort::StreamPort(const std::string &port_name) : fd(open(port_name.c_str(), O_RDWR | O_NONBLOCK))
 {
     if (fd < 0) {
@@ -31,9 +91,23 @@
StreamPort::~StreamPort() close(fd); } -void StreamPort::read(void *buf, size_t len) +InboundMessage StreamPort::receive() { - read_all(fd, buf, len); + std::lock_guard<std::mutex> stream_guard(mutex); + + StreamDevHeader header; + read_all(fd, &header, sizeof(header)); + + if (header.protocol_version != STREAM_DEVICE_PROTOCOL) { + throw std::runtime_error("Bad protocol version: " + std::to_string(header.protocol_version) + + ", expected: " + std::to_string(STREAM_DEVICE_PROTOCOL)); + } + + // TODO header.size is read straight from the device and is not bounded here, so a malformed header can make us allocate a huge buffer; consider enforcing a maximum (e.g. per message type, like the STREAM_MSG_CAPABILITIES_MAX_BYTES limit for capabilities) + std::unique_ptr<uint8_t[]> data(new uint8_t[header.size]); + read_all(fd, data.get(), header.size); + + return InboundMessage(header, std::move(data)); } void StreamPort::write(const void *buf, size_t len) diff --git a/src/stream-port.hpp b/src/stream-port.hpp index 9187cf5..090930b 100644 --- a/src/stream-port.hpp +++ b/src/stream-port.hpp @@ -7,20 +7,59 @@ #ifndef SPICE_STREAMING_AGENT_STREAM_PORT_HPP #define SPICE_STREAMING_AGENT_STREAM_PORT_HPP +#include <spice/stream-device.h> +#include <spice/enums.h> + #include <cstddef> #include <string> +#include <memory> #include <mutex> +#include <set> namespace spice { namespace streaming_agent { +struct StartStopMessage +{ + bool start_streaming = false; + std::set<SpiceVideoCodecType> client_codecs; +}; + +struct InCapabilitiesMessage {}; + +struct NotifyErrorMessage +{ + uint32_t error_code; + char message[1025]; // up to 1024 characters of the error text plus the terminating NUL + }; + +class InboundMessage +{ +public: + InboundMessage(const StreamDevHeader &header, std::unique_ptr<uint8_t[]> &&data); + + template<class Payload> Payload get_payload(); // only the specializations declared below are defined; the caller must pick the Payload matching header.type + + const StreamDevHeader header; // header as read by StreamPort::receive(), protocol_version already checked +private: + std::unique_ptr<uint8_t[]> data; +}; + +template<> +StartStopMessage InboundMessage::get_payload<StartStopMessage>(); +template<> +InCapabilitiesMessage InboundMessage::get_payload<InCapabilitiesMessage>(); +template<> +NotifyErrorMessage InboundMessage::get_payload<NotifyErrorMessage>(); + class StreamPort { public: StreamPort(const std::string
&port_name); ~StreamPort(); - void read(void *buf, size_t len); + InboundMessage receive(); + void write(const void *buf, size_t len); int fd; -- 2.17.1 _______________________________________________ Spice-devel mailing list Spice-devel@xxxxxxxxxxxxxxxxxxxxx https://lists.freedesktop.org/mailman/listinfo/spice-devel