Re: [PATCH spice-streaming-agent 8/9] Encapsulate the stream port fd and locking

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



> Wrap the streaming virtio port along with the mutex to lock it in a
> class. Pass the class temporarily around to functions that need it until
> the functions too are consolidated into the class.
> 
> Signed-off-by: Lukáš Hrázký <lhrazky@xxxxxxxxxx>

The mutex is supposed to prevent interleavings such as:

1- data header
2- cursor header
3- data payload
4- cursor payload

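Protecting single writes instead of whole messages allows these
situations. For example, spice_stream_send_frame() now sends the header
and the payload with two separate stream_port.write() calls, each taking
and releasing the lock, so the cursor thread can slip its message in
between them.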

To fix this I would do one of the following:
1- document that synchronization must be done outside the class
   (either removing the mutex from the class or having callers use it);
2- add lock()/unlock() methods;
3- add a StreamPortLocked class and move write/read to it. The locker
   class will hold the lock and make sure read/write are done with
   the mutex held. This also allows making the mutex field private
   (using a friend class) and prevents using read/write without
   the mutex held.

(I prefer option 3, but all of them are fine.)

Frediano

> ---
>  src/spice-streaming-agent.cpp | 104
>  +++++++++++++++++-------------------------
>  src/stream-port.cpp           |  25 ++++++++++
>  src/stream-port.hpp           |  14 ++++++
>  3 files changed, 81 insertions(+), 62 deletions(-)
> 
> diff --git a/src/spice-streaming-agent.cpp b/src/spice-streaming-agent.cpp
> index 692f067..2fdd02f 100644
> --- a/src/spice-streaming-agent.cpp
> +++ b/src/spice-streaming-agent.cpp
> @@ -32,7 +32,6 @@
>  #include <exception>
>  #include <stdexcept>
>  #include <memory>
> -#include <mutex>
>  #include <thread>
>  #include <vector>
>  #include <string>
> @@ -60,12 +59,10 @@ static bool streaming_requested = false;
>  static bool quit_requested = false;
>  static bool log_binary = false;
>  static std::set<SpiceVideoCodecType> client_codecs;
> -static int streamfd = -1;
> -static std::mutex stream_mtx;
>  
> -static int have_something_to_read(int timeout)
> +static int have_something_to_read(StreamPort &stream_port, int timeout)
>  {
> -    struct pollfd pollfd = {streamfd, POLLIN, 0};
> +    struct pollfd pollfd = {stream_port.fd, POLLIN, 0};
>  
>      if (poll(&pollfd, 1, timeout) < 0) {
>          syslog(LOG_ERR, "poll FAILED\n");
> @@ -79,7 +76,7 @@ static int have_something_to_read(int timeout)
>      return 0;
>  }
>  
> -static void handle_stream_start_stop(uint32_t len)
> +static void handle_stream_start_stop(StreamPort &stream_port, uint32_t len)
>  {
>      uint8_t msg[256];
>  
> @@ -88,7 +85,7 @@ static void handle_stream_start_stop(uint32_t len)
>                                   "(longer than " +
>                                   std::to_string(sizeof(msg)) + ")");
>      }
>  
> -    read_all(streamfd, msg, len);
> +    stream_port.read(msg, len);
>      streaming_requested = (msg[0] != 0); /* num_codecs */
>      syslog(LOG_INFO, "GOT START_STOP message -- request to %s streaming\n",
>             streaming_requested ? "START" : "STOP");
> @@ -98,7 +95,7 @@ static void handle_stream_start_stop(uint32_t len)
>      }
>  }
>  
> -static void handle_stream_capabilities(uint32_t len)
> +static void handle_stream_capabilities(StreamPort &stream_port, uint32_t
> len)
>  {
>      uint8_t caps[STREAM_MSG_CAPABILITIES_MAX_BYTES];
>  
> @@ -106,7 +103,7 @@ static void handle_stream_capabilities(uint32_t len)
>          throw std::runtime_error("capability message too long");
>      }
>  
> -    read_all(streamfd, caps, len);
> +    stream_port.read(caps, len);
>      // we currently do not support extensions so just reply so
>      StreamDevHeader hdr = {
>          STREAM_DEVICE_PROTOCOL,
> @@ -114,10 +111,10 @@ static void handle_stream_capabilities(uint32_t len)
>          STREAM_TYPE_CAPABILITIES,
>          0
>      };
> -    write_all(streamfd, &hdr, sizeof(hdr));
> +    stream_port.write(&hdr, sizeof(hdr));
>  }
>  
> -static void handle_stream_error(size_t len)
> +static void handle_stream_error(StreamPort &stream_port, size_t len)
>  {
>      if (len < sizeof(StreamMsgNotifyError)) {
>          throw std::runtime_error("Received NotifyError message size " +
>          std::to_string(len) +
> @@ -131,7 +128,7 @@ static void handle_stream_error(size_t len)
>  
>      size_t len_to_read = std::min(len, sizeof(msg) - 1);
>  
> -    read_all(streamfd, &msg, len_to_read);
> +    stream_port.read(&msg, len_to_read);
>      msg.msg[len_to_read - sizeof(StreamMsgNotifyError)] = '\0';
>  
>      syslog(LOG_ERR, "Received NotifyError message from the server: %d -
>      %s\n",
> @@ -143,13 +140,11 @@ static void handle_stream_error(size_t len)
>      }
>  }
>  
> -static void read_command_from_device(void)
> +static void read_command_from_device(StreamPort &stream_port)
>  {
>      StreamDevHeader hdr;
>  
> -    std::lock_guard<std::mutex> stream_guard(stream_mtx);
> -
> -    read_all(streamfd, &hdr, sizeof(hdr));
> +    stream_port.read(&hdr, sizeof(hdr));
>  
>      if (hdr.protocol_version != STREAM_DEVICE_PROTOCOL) {
>          throw std::runtime_error("BAD VERSION " +
>          std::to_string(hdr.protocol_version) +
> @@ -158,34 +153,34 @@ static void read_command_from_device(void)
>  
>      switch (hdr.type) {
>      case STREAM_TYPE_CAPABILITIES:
> -        return handle_stream_capabilities(hdr.size);
> +        return handle_stream_capabilities(stream_port, hdr.size);
>      case STREAM_TYPE_NOTIFY_ERROR:
> -        return handle_stream_error(hdr.size);
> +        return handle_stream_error(stream_port, hdr.size);
>      case STREAM_TYPE_START_STOP:
> -        return handle_stream_start_stop(hdr.size);
> +        return handle_stream_start_stop(stream_port, hdr.size);
>      }
>      throw std::runtime_error("UNKNOWN msg of type " +
>      std::to_string(hdr.type));
>  }
>  
> -static int read_command(bool blocking)
> +static int read_command(StreamPort &stream_port, bool blocking)
>  {
>      int timeout = blocking?-1:0;
>      while (!quit_requested) {
> -        if (!have_something_to_read(timeout)) {
> +        if (!have_something_to_read(stream_port, timeout)) {
>              if (!blocking) {
>                  return 0;
>              }
>              sleep(1);
>              continue;
>          }
> -        read_command_from_device();
> +        read_command_from_device(stream_port);
>          break;
>      }
>  
>      return 1;
>  }
>  
> -static void spice_stream_send_format(unsigned w, unsigned h, unsigned c)
> +static void spice_stream_send_format(StreamPort &stream_port, unsigned w,
> unsigned h, unsigned c)
>  {
>  
>      SpiceStreamFormatMessage msg;
> @@ -199,11 +194,10 @@ static void spice_stream_send_format(unsigned w,
> unsigned h, unsigned c)
>      msg.msg.height = h;
>      msg.msg.codec = c;
>      syslog(LOG_DEBUG, "writing format\n");
> -    std::lock_guard<std::mutex> stream_guard(stream_mtx);
> -    write_all(streamfd, &msg, msgsize);
> +    stream_port.write(&msg, msgsize);
>  }
>  
> -static void spice_stream_send_frame(const void *buf, const unsigned size)
> +static void spice_stream_send_frame(StreamPort &stream_port, const void
> *buf, const unsigned size)
>  {
>      SpiceStreamDataMessage msg;
>      const size_t msgsize = sizeof(msg);
> @@ -212,9 +206,8 @@ static void spice_stream_send_frame(const void *buf,
> const unsigned size)
>      msg.hdr.protocol_version = STREAM_DEVICE_PROTOCOL;
>      msg.hdr.type = STREAM_TYPE_DATA;
>      msg.hdr.size = size; /* includes only the body? */
> -    std::lock_guard<std::mutex> stream_guard(stream_mtx);
> -    write_all(streamfd, &msg, msgsize);
> -    write_all(streamfd, buf, size);
> +    stream_port.write(&msg, msgsize);
> +    stream_port.write(buf, size);
>  
>      syslog(LOG_DEBUG, "Sent a frame of size %u\n", size);
>  }
> @@ -264,7 +257,7 @@ static void usage(const char *progname)
>  }
>  
>  static void
> -send_cursor(unsigned width, unsigned height, int hotspot_x, int hotspot_y,
> +send_cursor(StreamPort &stream_port, unsigned width, unsigned height, int
> hotspot_x, int hotspot_y,
>              std::function<void(uint32_t *)> fill_cursor)
>  {
>      if (width >= STREAM_MSG_CURSOR_SET_MAX_WIDTH || height >=
>      STREAM_MSG_CURSOR_SET_MAX_HEIGHT) {
> @@ -294,11 +287,10 @@ send_cursor(unsigned width, unsigned height, int
> hotspot_x, int hotspot_y,
>      uint32_t *pixels = reinterpret_cast<uint32_t *>(cursor_msg.data);
>      fill_cursor(pixels);
>  
> -    std::lock_guard<std::mutex> stream_guard(stream_mtx);
> -    write_all(streamfd, msg.get(), cursor_size);
> +    stream_port.write(msg.get(), cursor_size);
>  }
>  
> -static void cursor_changes(Display *display, int event_base)
> +static void cursor_changes(StreamPort *stream_port, Display *display, int
> event_base)
>  {
>      unsigned long last_serial = 0;
>  
> @@ -323,26 +315,18 @@ static void cursor_changes(Display *display, int
> event_base)
>              for (unsigned i = 0; i < cursor->width * cursor->height; ++i)
>                  pixels[i] = cursor->pixels[i];
>          };
> -        send_cursor(cursor->width, cursor->height, cursor->xhot,
> cursor->yhot, fill_cursor);
> +        send_cursor(*stream_port, cursor->width, cursor->height,
> cursor->xhot, cursor->yhot, fill_cursor);
>      }
>  }
>  
>  static void
> -do_capture(const char *streamport, FILE *f_log)
> +do_capture(StreamPort &stream_port, FILE *f_log)
>  {
> -    streamfd = open(streamport, O_RDWR | O_NONBLOCK);
> -    if (streamfd < 0) {
> -        throw std::runtime_error("failed to open the streaming device (" +
> -                                 std::string(streamport) + "): "
> -                                 + strerror(errno));
> -    }
> -
>      unsigned int frame_count = 0;
>      while (!quit_requested) {
>          while (!quit_requested && !streaming_requested) {
> -            if (read_command(true) < 0) {
> -                syslog(LOG_ERR, "FAILED to read command\n");
> -                goto done;
> +            if (read_command(stream_port, true) < 0) {
> +                throw std::runtime_error("FAILED to read command");
>              }
>          }
>  
> @@ -385,7 +369,7 @@ do_capture(const char *streamport, FILE *f_log)
>  
>                  syslog(LOG_DEBUG, "wXh %uX%u  codec=%u\n", width, height,
>                  codec);
>  
> -                spice_stream_send_format(width, height, codec);
> +                spice_stream_send_format(stream_port, width, height, codec);
>              }
>              if (f_log) {
>                  if (log_binary) {
> @@ -398,32 +382,25 @@ do_capture(const char *streamport, FILE *f_log)
>              }
>  
>              try {
> -                spice_stream_send_frame(frame.buffer, frame.buffer_size);
> +                spice_stream_send_frame(stream_port, frame.buffer,
> frame.buffer_size);
>              } catch (const WriteError& e) {
>                  syslog(e);
>                  break;
>              }
>  
>              //usleep(1);
> -            if (read_command(false) < 0) {
> -                syslog(LOG_ERR, "FAILED to read command\n");
> -                goto done;
> +            if (read_command(stream_port, false) < 0) {
> +                throw std::runtime_error("FAILED to read command");
>              }
>          }
>      }
> -
> -done:
> -    if (streamfd >= 0) {
> -        close(streamfd);
> -        streamfd = -1;
> -    }
>  }
>  
>  #define arg_error(...) syslog(LOG_ERR, ## __VA_ARGS__);
>  
>  int main(int argc, char* argv[])
>  {
> -    const char *streamport = "/dev/virtio-ports/org.spice-space.stream.0";
> +    const char *stream_port_name =
> "/dev/virtio-ports/org.spice-space.stream.0";
>      int opt;
>      const char *log_filename = NULL;
>      int logmask = LOG_UPTO(LOG_WARNING);
> @@ -454,7 +431,7 @@ int main(int argc, char* argv[])
>              pluginsdir = optarg;
>              break;
>          case 'p':
> -            streamport = optarg;
> +            stream_port_name = optarg;
>              break;
>          case 'c': {
>              char *p = strchr(optarg, '=');
> @@ -512,12 +489,15 @@ int main(int argc, char* argv[])
>      Window rootwindow = DefaultRootWindow(display);
>      XFixesSelectCursorInput(display, rootwindow,
>      XFixesDisplayCursorNotifyMask);
>  
> -    std::thread cursor_th(cursor_changes, display, event_base);
> -    cursor_th.detach();
> -
>      int ret = EXIT_SUCCESS;
> +
>      try {
> -        do_capture(streamport, f_log);
> +        StreamPort stream_port(stream_port_name);
> +
> +        std::thread cursor_th(cursor_changes, &stream_port, display,
> event_base);
> +        cursor_th.detach();
> +
> +        do_capture(stream_port, f_log);
>      }
>      catch (std::exception &err) {
>          syslog(LOG_ERR, "%s\n", err.what());
> diff --git a/src/stream-port.cpp b/src/stream-port.cpp
> index ee85179..3cd4753 100644
> --- a/src/stream-port.cpp
> +++ b/src/stream-port.cpp
> @@ -8,6 +8,7 @@
>  #include "error.hpp"
>  
>  #include <errno.h>
> +#include <fcntl.h>
>  #include <poll.h>
>  #include <string.h>
>  #include <syslog.h>
> @@ -18,6 +19,30 @@
>  namespace spice {
>  namespace streaming_agent {
>  
> +StreamPort::StreamPort(const std::string &port_name) :
> fd(open(port_name.c_str(), O_RDWR | O_NONBLOCK))
> +{
> +    if (fd < 0) {
> +        throw IOError("Failed to open the streaming device \"" + port_name +
> "\"", errno);
> +    }
> +}
> +
> +StreamPort::~StreamPort()
> +{
> +    close(fd);
> +}
> +
> +void StreamPort::read(void *buf, size_t len)
> +{
> +    std::lock_guard<std::mutex> guard(mutex);
> +    read_all(fd, buf, len);
> +}
> +
> +void StreamPort::write(const void *buf, size_t len)
> +{
> +    std::lock_guard<std::mutex> guard(mutex);
> +    write_all(fd, buf, len);
> +}
> +
>  void read_all(int fd, void *buf, size_t len)
>  {
>      while (len > 0) {
> diff --git a/src/stream-port.hpp b/src/stream-port.hpp
> index b2d8352..9187cf5 100644
> --- a/src/stream-port.hpp
> +++ b/src/stream-port.hpp
> @@ -8,11 +8,25 @@
>  #define SPICE_STREAMING_AGENT_STREAM_PORT_HPP
>  
>  #include <cstddef>
> +#include <string>
> +#include <mutex>
>  
>  
>  namespace spice {
>  namespace streaming_agent {
>  
> +class StreamPort {
> +public:
> +    StreamPort(const std::string &port_name);
> +    ~StreamPort();
> +
> +    void read(void *buf, size_t len);
> +    void write(const void *buf, size_t len);
> +
> +    int fd;
> +    std::mutex mutex;
> +};
> +
>  void read_all(int fd, void *buf, size_t len);
>  void write_all(int fd, const void *buf, size_t len);
>  
_______________________________________________
Spice-devel mailing list
Spice-devel@xxxxxxxxxxxxxxxxxxxxx
https://lists.freedesktop.org/mailman/listinfo/spice-devel




[Index of Archives]     [Linux ARM Kernel]     [Linux ARM]     [Linux Omap]     [Fedora ARM]     [IETF Announce]     [Security]     [Bugtraq]     [Linux]     [Linux OMAP]     [Linux MIPS]     [ECOS]     [Asterisk Internet PBX]     [Linux API]     [Monitors]