[PATCH spice-streaming-agent v2 8/9] Encapsulate the stream port fd and locking

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



Wrap the streaming virtio port and the mutex that guards it in a
class. For now, pass an instance of the class to the functions that need
it; those functions will later be consolidated into the class as well.

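The locking has to stay outside the class for now. A single message is
written with more than one write (e.g. a header followed by its body), and
the lock must be held across all of them. Locking only inside the
individual read/write calls would let the frame and cursor threads
interleave their writes like this:
1 - data header
2 - cursor header
3 - data message
4 - cursor message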

Signed-off-by: Lukáš Hrázký <lhrazky@xxxxxxxxxx>
---
 src/spice-streaming-agent.cpp | 111 +++++++++++++++++++-----------------------
 src/stream-port.cpp           |  23 +++++++++
 src/stream-port.hpp           |  14 ++++++
 3 files changed, 86 insertions(+), 62 deletions(-)

diff --git a/src/spice-streaming-agent.cpp b/src/spice-streaming-agent.cpp
index 692f067..a9e5bf5 100644
--- a/src/spice-streaming-agent.cpp
+++ b/src/spice-streaming-agent.cpp
@@ -32,7 +32,6 @@
 #include <exception>
 #include <stdexcept>
 #include <memory>
-#include <mutex>
 #include <thread>
 #include <vector>
 #include <string>
@@ -60,12 +59,10 @@ static bool streaming_requested = false;
 static bool quit_requested = false;
 static bool log_binary = false;
 static std::set<SpiceVideoCodecType> client_codecs;
-static int streamfd = -1;
-static std::mutex stream_mtx;
 
-static int have_something_to_read(int timeout)
+static int have_something_to_read(StreamPort &stream_port, int timeout)
 {
-    struct pollfd pollfd = {streamfd, POLLIN, 0};
+    struct pollfd pollfd = {stream_port.fd, POLLIN, 0};
 
     if (poll(&pollfd, 1, timeout) < 0) {
         syslog(LOG_ERR, "poll FAILED\n");
@@ -79,7 +76,7 @@ static int have_something_to_read(int timeout)
     return 0;
 }
 
-static void handle_stream_start_stop(uint32_t len)
+static void handle_stream_start_stop(StreamPort &stream_port, uint32_t len)
 {
     uint8_t msg[256];
 
@@ -88,7 +85,7 @@ static void handle_stream_start_stop(uint32_t len)
                                  "(longer than " + std::to_string(sizeof(msg)) + ")");
     }
 
-    read_all(streamfd, msg, len);
+    stream_port.read(msg, len);
     streaming_requested = (msg[0] != 0); /* num_codecs */
     syslog(LOG_INFO, "GOT START_STOP message -- request to %s streaming\n",
            streaming_requested ? "START" : "STOP");
@@ -98,7 +95,7 @@ static void handle_stream_start_stop(uint32_t len)
     }
 }
 
-static void handle_stream_capabilities(uint32_t len)
+static void handle_stream_capabilities(StreamPort &stream_port, uint32_t len)
 {
     uint8_t caps[STREAM_MSG_CAPABILITIES_MAX_BYTES];
 
@@ -106,7 +103,7 @@ static void handle_stream_capabilities(uint32_t len)
         throw std::runtime_error("capability message too long");
     }
 
-    read_all(streamfd, caps, len);
+    stream_port.read(caps, len);
     // we currently do not support extensions so just reply so
     StreamDevHeader hdr = {
         STREAM_DEVICE_PROTOCOL,
@@ -114,10 +111,11 @@ static void handle_stream_capabilities(uint32_t len)
         STREAM_TYPE_CAPABILITIES,
         0
     };
-    write_all(streamfd, &hdr, sizeof(hdr));
+
+    stream_port.write(&hdr, sizeof(hdr));
 }
 
-static void handle_stream_error(size_t len)
+static void handle_stream_error(StreamPort &stream_port, size_t len)
 {
     if (len < sizeof(StreamMsgNotifyError)) {
         throw std::runtime_error("Received NotifyError message size " + std::to_string(len) +
@@ -131,7 +129,7 @@ static void handle_stream_error(size_t len)
 
     size_t len_to_read = std::min(len, sizeof(msg) - 1);
 
-    read_all(streamfd, &msg, len_to_read);
+    stream_port.read(&msg, len_to_read);
     msg.msg[len_to_read - sizeof(StreamMsgNotifyError)] = '\0';
 
     syslog(LOG_ERR, "Received NotifyError message from the server: %d - %s\n",
@@ -143,13 +141,12 @@ static void handle_stream_error(size_t len)
     }
 }
 
-static void read_command_from_device(void)
+static void read_command_from_device(StreamPort &stream_port)
 {
     StreamDevHeader hdr;
 
-    std::lock_guard<std::mutex> stream_guard(stream_mtx);
-
-    read_all(streamfd, &hdr, sizeof(hdr));
+    std::lock_guard<std::mutex> guard(stream_port.mutex);
+    stream_port.read(&hdr, sizeof(hdr));
 
     if (hdr.protocol_version != STREAM_DEVICE_PROTOCOL) {
         throw std::runtime_error("BAD VERSION " + std::to_string(hdr.protocol_version) +
@@ -158,34 +155,34 @@ static void read_command_from_device(void)
 
     switch (hdr.type) {
     case STREAM_TYPE_CAPABILITIES:
-        return handle_stream_capabilities(hdr.size);
+        return handle_stream_capabilities(stream_port, hdr.size);
     case STREAM_TYPE_NOTIFY_ERROR:
-        return handle_stream_error(hdr.size);
+        return handle_stream_error(stream_port, hdr.size);
     case STREAM_TYPE_START_STOP:
-        return handle_stream_start_stop(hdr.size);
+        return handle_stream_start_stop(stream_port, hdr.size);
     }
     throw std::runtime_error("UNKNOWN msg of type " + std::to_string(hdr.type));
 }
 
-static int read_command(bool blocking)
+static int read_command(StreamPort &stream_port, bool blocking)
 {
     int timeout = blocking?-1:0;
     while (!quit_requested) {
-        if (!have_something_to_read(timeout)) {
+        if (!have_something_to_read(stream_port, timeout)) {
             if (!blocking) {
                 return 0;
             }
             sleep(1);
             continue;
         }
-        read_command_from_device();
+        read_command_from_device(stream_port);
         break;
     }
 
     return 1;
 }
 
-static void spice_stream_send_format(unsigned w, unsigned h, unsigned c)
+static void spice_stream_send_format(StreamPort &stream_port, unsigned w, unsigned h, unsigned c)
 {
 
     SpiceStreamFormatMessage msg;
@@ -198,12 +195,13 @@ static void spice_stream_send_format(unsigned w, unsigned h, unsigned c)
     msg.msg.width = w;
     msg.msg.height = h;
     msg.msg.codec = c;
+
     syslog(LOG_DEBUG, "writing format\n");
-    std::lock_guard<std::mutex> stream_guard(stream_mtx);
-    write_all(streamfd, &msg, msgsize);
+    std::lock_guard<std::mutex> guard(stream_port.mutex);
+    stream_port.write(&msg, msgsize);
 }
 
-static void spice_stream_send_frame(const void *buf, const unsigned size)
+static void spice_stream_send_frame(StreamPort &stream_port, const void *buf, const unsigned size)
 {
     SpiceStreamDataMessage msg;
     const size_t msgsize = sizeof(msg);
@@ -212,9 +210,10 @@ static void spice_stream_send_frame(const void *buf, const unsigned size)
     msg.hdr.protocol_version = STREAM_DEVICE_PROTOCOL;
     msg.hdr.type = STREAM_TYPE_DATA;
     msg.hdr.size = size; /* includes only the body? */
-    std::lock_guard<std::mutex> stream_guard(stream_mtx);
-    write_all(streamfd, &msg, msgsize);
-    write_all(streamfd, buf, size);
+
+    std::lock_guard<std::mutex> guard(stream_port.mutex);
+    stream_port.write(&msg, msgsize);
+    stream_port.write(buf, size);
 
     syslog(LOG_DEBUG, "Sent a frame of size %u\n", size);
 }
@@ -264,7 +263,7 @@ static void usage(const char *progname)
 }
 
 static void
-send_cursor(unsigned width, unsigned height, int hotspot_x, int hotspot_y,
+send_cursor(StreamPort &stream_port, unsigned width, unsigned height, int hotspot_x, int hotspot_y,
             std::function<void(uint32_t *)> fill_cursor)
 {
     if (width >= STREAM_MSG_CURSOR_SET_MAX_WIDTH || height >= STREAM_MSG_CURSOR_SET_MAX_HEIGHT) {
@@ -294,11 +293,11 @@ send_cursor(unsigned width, unsigned height, int hotspot_x, int hotspot_y,
     uint32_t *pixels = reinterpret_cast<uint32_t *>(cursor_msg.data);
     fill_cursor(pixels);
 
-    std::lock_guard<std::mutex> stream_guard(stream_mtx);
-    write_all(streamfd, msg.get(), cursor_size);
+    std::lock_guard<std::mutex> guard(stream_port.mutex);
+    stream_port.write(msg.get(), cursor_size);
 }
 
-static void cursor_changes(Display *display, int event_base)
+static void cursor_changes(StreamPort *stream_port, Display *display, int event_base)
 {
     unsigned long last_serial = 0;
 
@@ -323,26 +322,18 @@ static void cursor_changes(Display *display, int event_base)
             for (unsigned i = 0; i < cursor->width * cursor->height; ++i)
                 pixels[i] = cursor->pixels[i];
         };
-        send_cursor(cursor->width, cursor->height, cursor->xhot, cursor->yhot, fill_cursor);
+        send_cursor(*stream_port, cursor->width, cursor->height, cursor->xhot, cursor->yhot, fill_cursor);
     }
 }
 
 static void
-do_capture(const char *streamport, FILE *f_log)
+do_capture(StreamPort &stream_port, FILE *f_log)
 {
-    streamfd = open(streamport, O_RDWR | O_NONBLOCK);
-    if (streamfd < 0) {
-        throw std::runtime_error("failed to open the streaming device (" +
-                                 std::string(streamport) + "): "
-                                 + strerror(errno));
-    }
-
     unsigned int frame_count = 0;
     while (!quit_requested) {
         while (!quit_requested && !streaming_requested) {
-            if (read_command(true) < 0) {
-                syslog(LOG_ERR, "FAILED to read command\n");
-                goto done;
+            if (read_command(stream_port, true) < 0) {
+                throw std::runtime_error("FAILED to read command");
             }
         }
 
@@ -385,7 +376,7 @@ do_capture(const char *streamport, FILE *f_log)
 
                 syslog(LOG_DEBUG, "wXh %uX%u  codec=%u\n", width, height, codec);
 
-                spice_stream_send_format(width, height, codec);
+                spice_stream_send_format(stream_port, width, height, codec);
             }
             if (f_log) {
                 if (log_binary) {
@@ -398,32 +389,25 @@ do_capture(const char *streamport, FILE *f_log)
             }
 
             try {
-                spice_stream_send_frame(frame.buffer, frame.buffer_size);
+                spice_stream_send_frame(stream_port, frame.buffer, frame.buffer_size);
             } catch (const WriteError& e) {
                 syslog(e);
                 break;
             }
 
             //usleep(1);
-            if (read_command(false) < 0) {
-                syslog(LOG_ERR, "FAILED to read command\n");
-                goto done;
+            if (read_command(stream_port, false) < 0) {
+                throw std::runtime_error("FAILED to read command");
             }
         }
     }
-
-done:
-    if (streamfd >= 0) {
-        close(streamfd);
-        streamfd = -1;
-    }
 }
 
 #define arg_error(...) syslog(LOG_ERR, ## __VA_ARGS__);
 
 int main(int argc, char* argv[])
 {
-    const char *streamport = "/dev/virtio-ports/org.spice-space.stream.0";
+    const char *stream_port_name = "/dev/virtio-ports/org.spice-space.stream.0";
     int opt;
     const char *log_filename = NULL;
     int logmask = LOG_UPTO(LOG_WARNING);
@@ -454,7 +438,7 @@ int main(int argc, char* argv[])
             pluginsdir = optarg;
             break;
         case 'p':
-            streamport = optarg;
+            stream_port_name = optarg;
             break;
         case 'c': {
             char *p = strchr(optarg, '=');
@@ -512,12 +496,20 @@ int main(int argc, char* argv[])
     Window rootwindow = DefaultRootWindow(display);
     XFixesSelectCursorInput(display, rootwindow, XFixesDisplayCursorNotifyMask);

-    std::thread cursor_th(cursor_changes, display, event_base);
-    cursor_th.detach();
-
     int ret = EXIT_SUCCESS;
+
     try {
-        do_capture(streamport, f_log);
+        // The cursor thread is detached and may still be running after this
+        // scope is left, so it holds its own reference to the port to keep
+        // the fd and the mutex alive for as long as it uses them.
+        auto stream_port = std::make_shared<StreamPort>(stream_port_name);
+
+        std::thread cursor_th([stream_port, display, event_base] {
+            cursor_changes(stream_port.get(), display, event_base);
+        });
+        cursor_th.detach();
+
+        do_capture(*stream_port, f_log);
     }
     catch (std::exception &err) {
         syslog(LOG_ERR, "%s\n", err.what());
diff --git a/src/stream-port.cpp b/src/stream-port.cpp
index 72364bd..5528854 100644
--- a/src/stream-port.cpp
+++ b/src/stream-port.cpp
@@ -8,6 +8,7 @@
 #include "error.hpp"
 
 #include <errno.h>
+#include <fcntl.h>
 #include <poll.h>
 #include <string.h>
 #include <syslog.h>
@@ -18,6 +19,33 @@
 namespace spice {
 namespace streaming_agent {

+StreamPort::StreamPort(const std::string &port_name) : fd(open(port_name.c_str(), O_RDWR | O_NONBLOCK))
+{
+    if (fd < 0) {
+        // Capture errno before building the message: the order in which the
+        // arguments are evaluated is unspecified and the string operations may
+        // overwrite errno.
+        const int saved_errno = errno;
+        throw IOError("Failed to open the streaming device \"" + port_name + "\"", saved_errno);
+    }
+}
+
+StreamPort::~StreamPort()
+{
+    // fd is always valid here: the constructor throws if open() failed.
+    close(fd);
+}
+
+void StreamPort::read(void *buf, size_t len)
+{
+    read_all(fd, buf, len);
+}
+
+void StreamPort::write(const void *buf, size_t len)
+{
+    write_all(fd, buf, len);
+}
+
 void read_all(int fd, void *buf, size_t len)
 {
     while (len > 0) {
diff --git a/src/stream-port.hpp b/src/stream-port.hpp
index b2d8352..9187cf5 100644
--- a/src/stream-port.hpp
+++ b/src/stream-port.hpp
@@ -8,11 +8,38 @@
 #define SPICE_STREAMING_AGENT_STREAM_PORT_HPP

 #include <cstddef>
+#include <string>
+#include <mutex>


 namespace spice {
 namespace streaming_agent {

+/**
+ * The streaming virtio port. Owns the file descriptor: it is opened in the
+ * constructor and closed in the destructor.
+ *
+ * read() and write() do not take the mutex. Callers must hold it for the
+ * whole message (e.g. a header and its body) so that messages written by
+ * different threads do not interleave.
+ */
+class StreamPort {
+public:
+    // Opens port_name read-write and non-blocking; throws IOError on failure.
+    explicit StreamPort(const std::string &port_name);
+    ~StreamPort();
+
+    // The object owns the fd, so copying it would lead to a double close.
+    StreamPort(const StreamPort &) = delete;
+    StreamPort &operator=(const StreamPort &) = delete;
+
+    void read(void *buf, size_t len);
+    void write(const void *buf, size_t len);
+
+    int fd;
+    std::mutex mutex;
+};
+
 void read_all(int fd, void *buf, size_t len);
 void write_all(int fd, const void *buf, size_t len);
 
-- 
2.16.2

_______________________________________________
Spice-devel mailing list
Spice-devel@xxxxxxxxxxxxxxxxxxxxx
https://lists.freedesktop.org/mailman/listinfo/spice-devel




[Index of Archives]     [Linux ARM Kernel]     [Linux ARM]     [Linux Omap]     [Fedora ARM]     [IETF Annouce]     [Security]     [Bugtraq]     [Linux]     [Linux OMAP]     [Linux MIPS]     [ECOS]     [Asterisk Internet PBX]     [Linux API]     [Monitors]