[PATCH spice-streaming-agent v3 1/2] Introduce InboundMessages for the StreamPort class

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



Wraps the deserialization of the received messages in an InboundMessage
class. The class is created with the deserialized header and the raw
data of the message. A template function get_payload() returns the
deserialized payload struct of the concrete message type. The function is
specialized for each incoming message type.

While this leaves it to the caller to call get_payload() with the payload
type matching the message type in the header, the solution preserves the
efficiency of the original implementation without introducing too much
complexity around the separation of the code.

Signed-off-by: Lukáš Hrázký <lhrazky@xxxxxxxxxx>
---
 src/spice-streaming-agent.cpp | 114 +++++++++-------------------------
 src/stream-port.cpp           |  81 +++++++++++++++++++++++-
 src/stream-port.hpp           |  41 +++++++++++-
 3 files changed, 148 insertions(+), 88 deletions(-)

diff --git a/src/spice-streaming-agent.cpp b/src/spice-streaming-agent.cpp
index 6b2be5b..b6e77de 100644
--- a/src/spice-streaming-agent.cpp
+++ b/src/spice-streaming-agent.cpp
@@ -77,97 +77,41 @@ static bool have_something_to_read(StreamPort &stream_port, bool blocking)
     return false;
 }
 
-static void handle_stream_start_stop(StreamPort &stream_port, uint32_t len)
-{
-    uint8_t msg[256];
-
-    if (len >= sizeof(msg)) {
-        throw std::runtime_error("msg size (" + std::to_string(len) + ") is too long "
-                                 "(longer than " + std::to_string(sizeof(msg)) + ")");
-    }
-
-    stream_port.read(msg, len);
-    streaming_requested = (msg[0] != 0); /* num_codecs */
-    syslog(LOG_INFO, "GOT START_STOP message -- request to %s streaming",
-           streaming_requested ? "START" : "STOP");
-    client_codecs.clear();
-    const int max_codecs = len - 1; /* see struct StreamMsgStartStop */
-    if (msg[0] > max_codecs) {
-        throw std::runtime_error("num_codecs=" + std::to_string(msg[0]) +
-                                 " > max_codecs=" + std::to_string(max_codecs));
-    }
-    for (int i = 1; i <= msg[0]; ++i) {
-        client_codecs.insert((SpiceVideoCodecType) msg[i]);
-    }
-}
-
-static void handle_stream_capabilities(StreamPort &stream_port, uint32_t len)
-{
-    uint8_t caps[STREAM_MSG_CAPABILITIES_MAX_BYTES];
-
-    if (len > sizeof(caps)) {
-        throw std::runtime_error("capability message too long");
-    }
-
-    stream_port.read(caps, len);
-    // we currently do not support extensions so just reply so
-    StreamDevHeader hdr = {
-        STREAM_DEVICE_PROTOCOL,
-        0,
-        STREAM_TYPE_CAPABILITIES,
-        0
-    };
-
-    stream_port.write(&hdr, sizeof(hdr));
-}
-
-static void handle_stream_error(StreamPort &stream_port, size_t len)
+static void read_command_from_device(StreamPort &stream_port)  // dispatch one inbound message by header.type
 {
-    if (len < sizeof(StreamMsgNotifyError)) {
-        throw std::runtime_error("Received NotifyError message size " + std::to_string(len) +
-                                 " is too small (smaller than " +
-                                 std::to_string(sizeof(StreamMsgNotifyError)) + ")");
+    InboundMessage in_message = stream_port.receive();  // header is validated and payload read by receive()
+
+    switch (in_message.header.type) {
+    case STREAM_TYPE_CAPABILITIES: {  // no extensions are supported: reply with an empty capabilities message
+        StreamDevHeader hdr = {
+            STREAM_DEVICE_PROTOCOL,
+            0,
+            STREAM_TYPE_CAPABILITIES,
+            0
+        };
+
+        std::lock_guard<std::mutex> guard(stream_port.mutex);  // receive() released the port lock; re-take it for the write
+        stream_port.write(&hdr, sizeof(hdr));
+        return;
     }
+    case STREAM_TYPE_NOTIFY_ERROR: {
+        NotifyErrorMessage msg = in_message.get_payload<NotifyErrorMessage>();
 
-    struct StreamMsgNotifyError1K : StreamMsgNotifyError {
-        uint8_t msg[1024];
-    } msg;
-
-    size_t len_to_read = std::min(len, sizeof(msg) - 1);
-
-    stream_port.read(&msg, len_to_read);
-    msg.msg[len_to_read - sizeof(StreamMsgNotifyError)] = '\0';
-
-    syslog(LOG_ERR, "Received NotifyError message from the server: %d - %s",
-        msg.error_code, msg.msg);
-
-    if (len_to_read < len) {
-        throw std::runtime_error("Received NotifyError message size " + std::to_string(len) +
-                                 " is too big (bigger than " + std::to_string(sizeof(msg)) + ")");
+        syslog(LOG_ERR, "Received NotifyError message from the server: %u - %s",  // error_code is uint32_t
+               msg.error_code, msg.message);
+        return;
     }
-}
+    case STREAM_TYPE_START_STOP: {  // an empty codec list is a request to stop streaming
+        StartStopMessage msg = in_message.get_payload<StartStopMessage>();
+        streaming_requested = msg.start_streaming;
+        client_codecs = std::move(msg.client_codecs);  // msg is not used after this point
 
-static void read_command_from_device(StreamPort &stream_port)
-{
-    StreamDevHeader hdr;
-
-    std::lock_guard<std::mutex> guard(stream_port.mutex);
-    stream_port.read(&hdr, sizeof(hdr));
+        syslog(LOG_INFO, "GOT START_STOP message -- request to %s streaming",
+               streaming_requested ? "START" : "STOP");
+        return;
+    }}
 
-    if (hdr.protocol_version != STREAM_DEVICE_PROTOCOL) {
-        throw std::runtime_error("BAD VERSION " + std::to_string(hdr.protocol_version) +
-                                 " (expected is " + std::to_string(STREAM_DEVICE_PROTOCOL) + ")");
-    }
-
-    switch (hdr.type) {
-    case STREAM_TYPE_CAPABILITIES:
-        return handle_stream_capabilities(stream_port, hdr.size);
-    case STREAM_TYPE_NOTIFY_ERROR:
-        return handle_stream_error(stream_port, hdr.size);
-    case STREAM_TYPE_START_STOP:
-        return handle_stream_start_stop(stream_port, hdr.size);
-    }
-    throw std::runtime_error("UNKNOWN msg of type " + std::to_string(hdr.type));
+    throw std::runtime_error("UNKNOWN msg of type " + std::to_string(in_message.header.type));  // only for types not handled above
 }
 
 static void read_command(StreamPort &stream_port, bool blocking)
diff --git a/src/stream-port.cpp b/src/stream-port.cpp
index 5528854..afef2e9 100644
--- a/src/stream-port.cpp
+++ b/src/stream-port.cpp
@@ -19,6 +19,66 @@
 namespace spice {
 namespace streaming_agent {
 
+InboundMessage::InboundMessage(const StreamDevHeader &hdr, std::unique_ptr<uint8_t[]> &&payload) :
+    header(hdr),
+    data(std::move(payload))  // the message takes ownership of the raw payload bytes
+{}
+
+template<>
+StartStopMessage InboundMessage::get_payload<StartStopMessage>()
+{
+    // Payload layout (see StreamMsgStartStop): data[0] is num_codecs, then num_codecs codec bytes.
+    if (header.size < 1) {
+        throw std::runtime_error("Malformed StartStop message: empty payload");
+    }
+    const size_t num_codecs = data[0];
+    const size_t max_codecs = header.size - 1;  // safe: header.size >= 1 here, so no unsigned wrap-around
+    if (num_codecs > max_codecs) {
+        throw std::runtime_error("Malformed StartStop message: num_codecs (" +
+                                 std::to_string(num_codecs) + ") is greater than the message size (" +
+                                 std::to_string(max_codecs) + ")");
+    }
+    StartStopMessage msg;
+    msg.start_streaming = num_codecs > 0;  // no codecs in the message means to stop streaming
+    for (size_t i = 1; i <= num_codecs; ++i) {
+        msg.client_codecs.insert(static_cast<SpiceVideoCodecType>(data[i]));
+    }
+    return msg;
+}
+
+template<>
+InCapabilitiesMessage InboundMessage::get_payload<InCapabilitiesMessage>()
+{
+    // The payload is not inspected: no capabilities are defined yet.
+    return {};
+}
+
+template<>
+NotifyErrorMessage InboundMessage::get_payload<NotifyErrorMessage>()
+{
+    if (header.size < sizeof(StreamMsgNotifyError)) {
+        throw std::runtime_error("Received NotifyError message size " + std::to_string(header.size) +
+                                 " is too small (smaller than " +
+                                 std::to_string(sizeof(StreamMsgNotifyError)) + ")");
+    }
+
+    NotifyErrorMessage msg;
+    const size_t max_len = sizeof(msg.message) - 1;  // leave room for the terminating NUL
+    const size_t msg_len = header.size - sizeof(StreamMsgNotifyError);
+    if (msg_len > max_len) {
+        throw std::runtime_error("Received NotifyError message is too long (" +
+                                 std::to_string(msg_len) + " > " + std::to_string(max_len) + ")");
+    }
+
+    // copy fields byte-wise rather than reading through a type-punned struct pointer
+    const uint8_t *raw = data.get();
+    memcpy(&msg.error_code, raw + offsetof(StreamMsgNotifyError, error_code), sizeof(msg.error_code));
+    strncpy(msg.message, reinterpret_cast<const char*>(raw + offsetof(StreamMsgNotifyError, msg)), msg_len);
+    // make sure the string is terminated
+    msg.message[msg_len] = '\0';
+    return msg;
+}
+
 StreamPort::StreamPort(const std::string &port_name) : fd(open(port_name.c_str(), O_RDWR | O_NONBLOCK))
 {
     if (fd < 0) {
@@ -31,9 +91,26 @@ StreamPort::~StreamPort()
     close(fd);
 }
 
-void StreamPort::read(void *buf, size_t len)
+InboundMessage StreamPort::receive()
 {
-    read_all(fd, buf, len);
+    // a generic 4kB upper bound on the payload size of any inbound message
+    const uint32_t max_payload_size = 4 * 1024;
+    std::lock_guard<std::mutex> stream_guard(mutex);  // hold the port while reading header and payload
+
+    StreamDevHeader header;
+    read_all(fd, &header, sizeof(header));
+
+    if (header.protocol_version != STREAM_DEVICE_PROTOCOL) {
+        throw std::runtime_error("Bad protocol version: " + std::to_string(header.protocol_version) +
+                                 ", expected: " + std::to_string(STREAM_DEVICE_PROTOCOL));
+    }
+    if (header.size > max_payload_size) {
+        throw std::runtime_error("Inbound message too big, exceeding the 4kB limit.");
+    }
+
+    std::unique_ptr<uint8_t[]> payload(new uint8_t[header.size]);
+    read_all(fd, payload.get(), header.size);
+    return InboundMessage(header, std::move(payload));
 }
 
 void StreamPort::write(const void *buf, size_t len)
diff --git a/src/stream-port.hpp b/src/stream-port.hpp
index 48f843c..136ff25 100644
--- a/src/stream-port.hpp
+++ b/src/stream-port.hpp
@@ -7,20 +7,59 @@
 #ifndef SPICE_STREAMING_AGENT_STREAM_PORT_HPP
 #define SPICE_STREAMING_AGENT_STREAM_PORT_HPP
 
+#include <spice/stream-device.h>
+#include <spice/enums.h>
+
 #include <cstddef>
 #include <string>
+#include <memory>
 #include <mutex>
+#include <set>
 
 
 namespace spice {
 namespace streaming_agent {
 
+struct StartStopMessage  // decoded STREAM_TYPE_START_STOP payload
+{
+    bool start_streaming{false};                    // true when the message lists at least one codec
+    std::set<SpiceVideoCodecType> client_codecs{};  // codec types listed in the message
+};
+
+struct InCapabilitiesMessage {};  // STREAM_TYPE_CAPABILITIES payload; no capabilities are decoded yet
+
+struct NotifyErrorMessage  // decoded STREAM_TYPE_NOTIFY_ERROR payload
+{
+    uint32_t error_code;
+    char message[1024 + 1];  // up to 1024 bytes of text plus a terminating NUL
+};
+
+class InboundMessage  // one message read from the stream port: validated header plus raw payload bytes
+{
+public:
+    InboundMessage(const StreamDevHeader &header, std::unique_ptr<uint8_t[]> &&data);
+
+    template<class Payload> Payload get_payload();  // decode the payload; specialized per payload type below
+
+    const StreamDevHeader header;  // the caller picks the Payload type matching header.type
+private:
+    std::unique_ptr<uint8_t[]> data;  // header.size bytes of raw payload as read by StreamPort::receive()
+};
+
+template<>
+StartStopMessage InboundMessage::get_payload<StartStopMessage>();  // throws std::runtime_error on a malformed payload
+template<>
+InCapabilitiesMessage InboundMessage::get_payload<InCapabilitiesMessage>();
+template<>
+NotifyErrorMessage InboundMessage::get_payload<NotifyErrorMessage>();  // throws std::runtime_error on a malformed payload
+
 class StreamPort {
 public:
     StreamPort(const std::string &port_name);
     ~StreamPort();
 
-    void read(void *buf, size_t len);
+    InboundMessage receive();
+
     void write(const void *buf, size_t len);
 
     const int fd;
-- 
2.19.1

_______________________________________________
Spice-devel mailing list
Spice-devel@xxxxxxxxxxxxxxxxxxxxx
https://lists.freedesktop.org/mailman/listinfo/spice-devel




[Index of Archives]     [Linux Virtualization]     [Linux Virtualization]     [Linux ARM Kernel]     [Linux ARM]     [Linux Omap]     [Fedora ARM]     [IETF Annouce]     [Security]     [Bugtraq]     [Linux OMAP]     [Linux MIPS]     [ECOS]     [Asterisk Internet PBX]     [Linux API]     [Monitors]