> > Wrap the streaming virtio port along with the mutex to lock it in a > class. Pass the class temporarily around to functions that need it until > the functions too are consolidated into the class. > > The locking needs to be outside the class for now to prevent situations > like: > 1 - data header > 2 - cursor header > 3 - data message > 4 - cursor message > > Signed-off-by: Lukáš Hrázký <lhrazky@xxxxxxxxxx> > --- > src/spice-streaming-agent.cpp | 111 > +++++++++++++++++++----------------------- > src/stream-port.cpp | 23 +++++++++ > src/stream-port.hpp | 14 ++++++ > 3 files changed, 86 insertions(+), 62 deletions(-) > > diff --git a/src/spice-streaming-agent.cpp b/src/spice-streaming-agent.cpp > index 692f067..a9e5bf5 100644 > --- a/src/spice-streaming-agent.cpp > +++ b/src/spice-streaming-agent.cpp > @@ -32,7 +32,6 @@ > #include <exception> > #include <stdexcept> > #include <memory> > -#include <mutex> > #include <thread> > #include <vector> > #include <string> > @@ -60,12 +59,10 @@ static bool streaming_requested = false; > static bool quit_requested = false; > static bool log_binary = false; > static std::set<SpiceVideoCodecType> client_codecs; > -static int streamfd = -1; > -static std::mutex stream_mtx; > > -static int have_something_to_read(int timeout) > +static int have_something_to_read(StreamPort &stream_port, int timeout) > { > - struct pollfd pollfd = {streamfd, POLLIN, 0}; > + struct pollfd pollfd = {stream_port.fd, POLLIN, 0}; > > if (poll(&pollfd, 1, timeout) < 0) { > syslog(LOG_ERR, "poll FAILED\n"); > @@ -79,7 +76,7 @@ static int have_something_to_read(int timeout) > return 0; > } > > -static void handle_stream_start_stop(uint32_t len) > +static void handle_stream_start_stop(StreamPort &stream_port, uint32_t len) > { > uint8_t msg[256]; > > @@ -88,7 +85,7 @@ static void handle_stream_start_stop(uint32_t len) > "(longer than " + > std::to_string(sizeof(msg)) + ")"); > } > > - read_all(streamfd, msg, len); > + stream_port.read(msg, len); > streaming_requested = (msg[0] != 0); /* num_codecs */ > syslog(LOG_INFO, "GOT START_STOP message -- request to %s streaming\n", > streaming_requested ? "START" : "STOP"); > @@ -98,7 +95,7 @@ static void handle_stream_start_stop(uint32_t len) > } > } > > -static void handle_stream_capabilities(uint32_t len) > +static void handle_stream_capabilities(StreamPort &stream_port, uint32_t > len) > { > uint8_t caps[STREAM_MSG_CAPABILITIES_MAX_BYTES]; > > @@ -106,7 +103,7 @@ static void handle_stream_capabilities(uint32_t len) > throw std::runtime_error("capability message too long"); > } > > - read_all(streamfd, caps, len); > + stream_port.read(caps, len); > // we currently do not support extensions so just reply so > StreamDevHeader hdr = { > STREAM_DEVICE_PROTOCOL, > @@ -114,10 +111,11 @@ static void handle_stream_capabilities(uint32_t len) > STREAM_TYPE_CAPABILITIES, > 0 > }; > - write_all(streamfd, &hdr, sizeof(hdr)); > + > + stream_port.write(&hdr, sizeof(hdr)); > } > > -static void handle_stream_error(size_t len) > +static void handle_stream_error(StreamPort &stream_port, size_t len) > { > if (len < sizeof(StreamMsgNotifyError)) { > throw std::runtime_error("Received NotifyError message size " + > std::to_string(len) + > @@ -131,7 +129,7 @@ static void handle_stream_error(size_t len) > > size_t len_to_read = std::min(len, sizeof(msg) - 1); > > - read_all(streamfd, &msg, len_to_read); > + stream_port.read(&msg, len_to_read); > msg.msg[len_to_read - sizeof(StreamMsgNotifyError)] = '\0'; > > syslog(LOG_ERR, "Received NotifyError message from the server: %d - > %s\n", > @@ -143,13 +141,12 @@ static void handle_stream_error(size_t len) > } > } > > -static void read_command_from_device(void) > +static void read_command_from_device(StreamPort &stream_port) > { > StreamDevHeader hdr; > > - std::lock_guard<std::mutex> stream_guard(stream_mtx); > - > - read_all(streamfd, &hdr, sizeof(hdr)); > + std::lock_guard<std::mutex> guard(stream_port.mutex); > + stream_port.read(&hdr, sizeof(hdr)); > > if (hdr.protocol_version != STREAM_DEVICE_PROTOCOL) { > throw std::runtime_error("BAD VERSION " + > std::to_string(hdr.protocol_version) + > @@ -158,34 +155,34 @@ static void read_command_from_device(void) > > switch (hdr.type) { > case STREAM_TYPE_CAPABILITIES: > - return handle_stream_capabilities(hdr.size); > + return handle_stream_capabilities(stream_port, hdr.size); > case STREAM_TYPE_NOTIFY_ERROR: > - return handle_stream_error(hdr.size); > + return handle_stream_error(stream_port, hdr.size); > case STREAM_TYPE_START_STOP: > - return handle_stream_start_stop(hdr.size); > + return handle_stream_start_stop(stream_port, hdr.size); > } > throw std::runtime_error("UNKNOWN msg of type " + > std::to_string(hdr.type)); > } > > -static int read_command(bool blocking) > +static int read_command(StreamPort &stream_port, bool blocking) > { > int timeout = blocking?-1:0; > while (!quit_requested) { > - if (!have_something_to_read(timeout)) { > + if (!have_something_to_read(stream_port, timeout)) { > if (!blocking) { > return 0; > } > sleep(1); > continue; > } > - read_command_from_device(); > + read_command_from_device(stream_port); > break; > } > > return 1; > } > > -static void spice_stream_send_format(unsigned w, unsigned h, unsigned c) > +static void spice_stream_send_format(StreamPort &stream_port, unsigned w, > unsigned h, unsigned c) > { > > SpiceStreamFormatMessage msg; > @@ -198,12 +195,13 @@ static void spice_stream_send_format(unsigned w, > unsigned h, unsigned c) > msg.msg.width = w; > msg.msg.height = h; > msg.msg.codec = c; > + > syslog(LOG_DEBUG, "writing format\n"); > - std::lock_guard<std::mutex> stream_guard(stream_mtx); > - write_all(streamfd, &msg, msgsize); > + std::lock_guard<std::mutex> guard(stream_port.mutex); > + stream_port.write(&msg, msgsize); > } > > -static void spice_stream_send_frame(const void *buf, const unsigned size) > +static void spice_stream_send_frame(StreamPort &stream_port, const void > *buf, const unsigned size) > { > SpiceStreamDataMessage msg; > const size_t msgsize = sizeof(msg); > @@ -212,9 +210,10 @@ static void spice_stream_send_frame(const void *buf, > const unsigned size) > msg.hdr.protocol_version = STREAM_DEVICE_PROTOCOL; > msg.hdr.type = STREAM_TYPE_DATA; > msg.hdr.size = size; /* includes only the body? */ > - std::lock_guard<std::mutex> stream_guard(stream_mtx); > - write_all(streamfd, &msg, msgsize); > - write_all(streamfd, buf, size); > + > + std::lock_guard<std::mutex> guard(stream_port.mutex); > + stream_port.write(&msg, msgsize); > + stream_port.write(buf, size); > > syslog(LOG_DEBUG, "Sent a frame of size %u\n", size); > } > @@ -264,7 +263,7 @@ static void usage(const char *progname) > } > > static void > -send_cursor(unsigned width, unsigned height, int hotspot_x, int hotspot_y, > +send_cursor(StreamPort &stream_port, unsigned width, unsigned height, int > hotspot_x, int hotspot_y, > std::function<void(uint32_t *)> fill_cursor) > { > if (width >= STREAM_MSG_CURSOR_SET_MAX_WIDTH || height >= > STREAM_MSG_CURSOR_SET_MAX_HEIGHT) { > @@ -294,11 +293,11 @@ send_cursor(unsigned width, unsigned height, int > hotspot_x, int hotspot_y, > uint32_t *pixels = reinterpret_cast<uint32_t *>(cursor_msg.data); > fill_cursor(pixels); > > - std::lock_guard<std::mutex> stream_guard(stream_mtx); > - write_all(streamfd, msg.get(), cursor_size); > + std::lock_guard<std::mutex> guard(stream_port.mutex); > + stream_port.write(msg.get(), cursor_size); > } > > -static void cursor_changes(Display *display, int event_base) > +static void cursor_changes(StreamPort *stream_port, Display *display, int > event_base) > { > unsigned long last_serial = 0; > > @@ -323,26 +322,18 @@ static void cursor_changes(Display *display, int > event_base) > for (unsigned i = 0; i < cursor->width * cursor->height; ++i) > pixels[i] = cursor->pixels[i]; > }; > - send_cursor(cursor->width, cursor->height, cursor->xhot, > cursor->yhot, fill_cursor); > + send_cursor(*stream_port, cursor->width, cursor->height, > cursor->xhot, cursor->yhot, fill_cursor); > } > } > > static void > -do_capture(const char *streamport, FILE *f_log) > +do_capture(StreamPort &stream_port, FILE *f_log) > { > - streamfd = open(streamport, O_RDWR | O_NONBLOCK); > - if (streamfd < 0) { > - throw std::runtime_error("failed to open the streaming device (" + > - std::string(streamport) + "): " > - + strerror(errno)); > - } > - > unsigned int frame_count = 0; > while (!quit_requested) { > while (!quit_requested && !streaming_requested) { > - if (read_command(true) < 0) { > - syslog(LOG_ERR, "FAILED to read command\n"); > - goto done; > + if (read_command(stream_port, true) < 0) { > + throw std::runtime_error("FAILED to read command"); > } > } > > @@ -385,7 +376,7 @@ do_capture(const char *streamport, FILE *f_log) > > syslog(LOG_DEBUG, "wXh %uX%u codec=%u\n", width, height, > codec); > > - spice_stream_send_format(width, height, codec); > + spice_stream_send_format(stream_port, width, height, codec); > } > if (f_log) { > if (log_binary) { > @@ -398,32 +389,25 @@ do_capture(const char *streamport, FILE *f_log) > } > > try { > - spice_stream_send_frame(frame.buffer, frame.buffer_size); > + spice_stream_send_frame(stream_port, frame.buffer, > frame.buffer_size); > } catch (const WriteError& e) { > syslog(e); > break; > } > > //usleep(1); > - if (read_command(false) < 0) { > - syslog(LOG_ERR, "FAILED to read command\n"); > - goto done; > + if (read_command(stream_port, false) < 0) { > + throw std::runtime_error("FAILED to read command"); > } > } > } > - > -done: > - if (streamfd >= 0) { > - close(streamfd); > - streamfd = -1; > - } > } > > #define arg_error(...) syslog(LOG_ERR, ## __VA_ARGS__); > > int main(int argc, char* argv[]) > { > - const char *streamport = "/dev/virtio-ports/org.spice-space.stream.0"; > + const char *stream_port_name = > "/dev/virtio-ports/org.spice-space.stream.0"; > int opt; > const char *log_filename = NULL; > int logmask = LOG_UPTO(LOG_WARNING); > @@ -454,7 +438,7 @@ int main(int argc, char* argv[]) > pluginsdir = optarg; > break; > case 'p': > - streamport = optarg; > + stream_port_name = optarg; > break; > case 'c': { > char *p = strchr(optarg, '='); > @@ -512,12 +496,15 @@ int main(int argc, char* argv[]) > Window rootwindow = DefaultRootWindow(display); > XFixesSelectCursorInput(display, rootwindow, > XFixesDisplayCursorNotifyMask); > > - std::thread cursor_th(cursor_changes, display, event_base); > - cursor_th.detach(); > - > int ret = EXIT_SUCCESS; > + > try { > - do_capture(streamport, f_log); > + StreamPort stream_port(stream_port_name); > + > + std::thread cursor_th(cursor_changes, &stream_port, display, > event_base); > + cursor_th.detach(); > + > + do_capture(stream_port, f_log); > } > catch (std::exception &err) { > syslog(LOG_ERR, "%s\n", err.what()); > diff --git a/src/stream-port.cpp b/src/stream-port.cpp > index 72364bd..5528854 100644 > --- a/src/stream-port.cpp > +++ b/src/stream-port.cpp > @@ -8,6 +8,7 @@ > #include "error.hpp" > > #include <errno.h> > +#include <fcntl.h> > #include <poll.h> > #include <string.h> > #include <syslog.h> > @@ -18,6 +19,28 @@ > namespace spice { > namespace streaming_agent { > > +StreamPort::StreamPort(const std::string &port_name) : > fd(open(port_name.c_str(), O_RDWR | O_NONBLOCK)) > +{ > + if (fd < 0) { > + throw IOError("Failed to open the streaming device \"" + port_name + > "\"", errno); > + } > +} > + > +StreamPort::~StreamPort() > +{ > + close(fd); > +} > + > +void StreamPort::read(void *buf, size_t len) > +{ > + read_all(fd, buf, len); > +} > + > +void StreamPort::write(const void *buf, size_t len) > +{ > + write_all(fd, buf, len); > +} > + > void read_all(int fd, void *buf, size_t len) > { > while (len > 0) { > diff --git a/src/stream-port.hpp b/src/stream-port.hpp > index b2d8352..9187cf5 100644 > --- a/src/stream-port.hpp > +++ b/src/stream-port.hpp > @@ -8,11 +8,25 @@ > #define SPICE_STREAMING_AGENT_STREAM_PORT_HPP > > #include <cstddef> > +#include <string> > +#include <mutex> > > > namespace spice { > namespace streaming_agent { > > +class StreamPort { > +public: > + StreamPort(const std::string &port_name); > + ~StreamPort(); > + > + void read(void *buf, size_t len); > + void write(const void *buf, size_t len); > + > + int fd; > + std::mutex mutex; > +}; > + > void read_all(int fd, void *buf, size_t len); > void write_all(int fd, const void *buf, size_t len); > Acked-by: Frediano Ziglio <fziglio@xxxxxxxxxx> Frediano _______________________________________________ Spice-devel mailing list Spice-devel@xxxxxxxxxxxxxxxxxxxxx https://lists.freedesktop.org/mailman/listinfo/spice-devel